## Imports

In [1]:
%load_ext autoreload
%autoreload 2

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Trainer

from dataset.aml_data_module import AMLDataModule
from model.aml_classifier import AMLClassifier
from model.callbacks import SwitchPretrainedWeightsState

## Model training

Hyperparameters and training setup (data module, model, callbacks and trainer):

In [None]:
# Training hyperparameters.
NUM_EPOCHS = 20
SEQUENCE_LENGTH = 200   # forwarded to AMLDataModule as `sequence_length`
SEQUENCE_OVERLAP = 50   # forwarded to AMLDataModule as `overlap`
# NOTE(review): BATCH_SIZE is not passed to AMLDataModule (or anything else in
# this cell), so it currently only documents the constraint below.
BATCH_SIZE = 1  # TODO: at the moment model can only handle online learning
# NOTE(review): NUM_FREEZE_PRETRAINED is not passed to
# SwitchPretrainedWeightsState() below; confirm whether the callback is meant
# to receive it.
NUM_FREEZE_PRETRAINED = NUM_EPOCHS


data_module = AMLDataModule(
    sequence_length=SEQUENCE_LENGTH,
    overlap=SEQUENCE_OVERLAP,
)

model = AMLClassifier()

callbacks = [
    # Keep only the single best checkpoint, ranked by validation accuracy.
    # The monitored metric name 'val/accuracy' contains a '/', so letting
    # Lightning auto-insert "<metric>=" into the filename would produce
    # ".../...-val/accuracy=0.812.ckpt", i.e. an extra sub-directory.
    # Auto-insertion is therefore disabled and the labels are written out
    # explicitly with '/'-free text.
    ModelCheckpoint(
        filename=f'AMLClassifier-seq_len{SEQUENCE_LENGTH}-ovlp{SEQUENCE_OVERLAP}'
                 + '-epoch={epoch}-val_accuracy={val/accuracy:.3f}',
        auto_insert_metric_name=False,
        monitor='val/accuracy',
        mode='max',
        save_top_k=1,
        verbose=True,
    ),
    SwitchPretrainedWeightsState(),
]

trainer = Trainer(
    max_epochs=NUM_EPOCHS,
    fast_dev_run=False,
    default_root_dir='../output',
    callbacks=callbacks,
)

Training loop:

In [None]:
# Fit the model on the data module's training/validation loaders.
# The loaders are passed positionally because Lightning renamed the training
# keyword (`train_dataloader` -> `train_dataloaders`; the old name was
# deprecated in 1.4 and removed in 1.6). Positional arguments work with both
# signatures.
# NOTE(review): the loaders are built by calling the data module's hooks
# directly, so Lightning does not invoke its prepare_data()/setup() hooks here;
# confirm AMLDataModule does not rely on them (or pass datamodule=data_module).
trainer.fit(
    model,
    data_module.train_dataloader(),
    data_module.val_dataloader(),
)