## Imports

In [None]:
%load_ext autoreload
%autoreload 2

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Trainer

from dataset_classes.mvcnn_data_module import MVCNNDataModule
from model_classes.mvcnn import MVCNNClassifier
from model_classes.callbacks import UnfreezePretrainedWeights, ResetEvalResults

## Parameters

In [None]:
# --- Task -------------------------------------------------------------------
NUM_CLASSES = 4  # forwarded to MVCNNDataModule and ResetEvalResults

# --- Optimisation schedule --------------------------------------------------
LEARNING_RATE = 0.001  # forwarded to MVCNNClassifier
# Forwarded to UnfreezePretrainedWeights; how it is applied (presumably a
# divisor on the learning rate once the backbone is unfrozen) is defined in
# model_classes.callbacks — TODO confirm there.
LEARNING_RATE_REDUCTION_FACTOR = 1_000.0
NUM_EPOCHS = 24  # Trainer(max_epochs=...)
NUM_EPOCHS_FREEZE_PRETRAINED = 12  # forwarded to MVCNNClassifier

# --- Batching and regularisation --------------------------------------------
BATCH_SIZE = 1
DROPOUT_RATE = 0.3

# --- Output -----------------------------------------------------------------
SAVE_PATH = './output'  # Trainer(default_root_dir=...)

## Class initialization

In [None]:
# Data module built from the class count and batch size, passed positionally
# in the same order as MVCNNDataModule's constructor expects (see
# dataset_classes.mvcnn_data_module for the parameter names).
data_module = MVCNNDataModule(
    NUM_CLASSES,
    BATCH_SIZE,
)

In [None]:
# MVCNN classifier. NUM_EPOCHS_FREEZE_PRETRAINED presumably sets how many
# epochs the pretrained weights stay frozen — confirm in model_classes.mvcnn.
model = MVCNNClassifier(
    dropout_rate=DROPOUT_RATE,
    learning_rate=LEARNING_RATE,
    num_epochs_freeze_pretrained=NUM_EPOCHS_FREEZE_PRETRAINED,
)

In [None]:
# Trainer callbacks. The list order is kept as written in case the callbacks
# depend on running in sequence:
#   - ModelCheckpoint keeps the checkpoint with the highest logged 'val_f1'
#     (the model must log a metric under that exact name);
#   - UnfreezePretrainedWeights receives the learning-rate reduction factor;
#   - ResetEvalResults receives the number of classes.
callbacks = [
    ModelCheckpoint(mode='max', monitor='val_f1', verbose=True),
    UnfreezePretrainedWeights(LEARNING_RATE_REDUCTION_FACTOR),
    ResetEvalResults(NUM_CLASSES),
]

In [None]:
# Lightning trainer. default_root_dir is Lightning's fallback location for
# logs and checkpoints; fast_dev_run is spelled out so it is easy to flip on
# for a quick smoke test.
trainer = Trainer(
    callbacks=callbacks,
    default_root_dir=SAVE_PATH,
    fast_dev_run=False,
    max_epochs=NUM_EPOCHS,
)

## Model training

In [None]:
# Train and validate.
# The dataloaders are passed positionally on purpose: Trainer.fit's second and
# third parameters are the training and validation dataloaders in every
# Lightning release. The keyword spelling differs, though: `train_dataloader=`
# was deprecated in v1.4 in favour of `train_dataloaders=` and removed in v1.6,
# where it raises TypeError.
# NOTE(review): if MVCNNDataModule is a LightningDataModule,
# `trainer.fit(model, datamodule=data_module)` is the idiomatic alternative and
# also lets the Trainer run prepare_data()/setup() — confirm before switching.
trainer.fit(
    model,
    data_module.train_dataloader(),
    data_module.val_dataloader(),
)