In [None]:
# Work from the repository root so the `from src...` imports below and relative
# paths such as "logs/" resolve. Guarded so that re-running this cell (or running
# the notebook from the root already) does not keep climbing one directory per
# execution, which a bare `%cd ..` would do.
import os

if not os.path.isdir("src"):
    os.chdir("..")
assert os.path.isdir("src"), f"`src` package not found from {os.getcwd()!r}"
os.getcwd()

In [None]:
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from src.data.ConsPracDataModule import ConsPracDataModule
from src.models.ResNet50 import ResNet50

In [None]:
# Seed Python/NumPy/torch RNGs (and dataloader workers) before anything random
# happens, so augmentation and any randomness inside setup() are reproducible
# under Restart & Run All.
pl.seed_everything(42, workers=True)

# Data module with image augmentation enabled. setup() is called explicitly
# because the dataloaders are requested directly in the next cell rather than
# through the Trainer.
dm = ConsPracDataModule(augment_images=True)
dm.setup()

In [None]:
# Build the loaders once; the class-balance cells and trainer.fit() below reuse them.
dl_train, dl_val = dm.train_dataloader(), dm.val_dataloader()

In [None]:
# Class balance of the training split: for every row of `label`, take the name of
# the column holding the largest value (presumably one column per class — confirm),
# then show each class's share of rows.
(
    dl_train.dataset.label
    .idxmax(axis="columns")
    .value_counts(normalize=True)
)

In [None]:
# Class balance of the validation split, computed the same way as for training:
# per-row arg-max column of `label`, then the relative frequency of each class.
(
    dl_val.dataset.label
    .idxmax(axis="columns")
    .value_counts(normalize=True)
)

In [None]:
# Train a ResNet50 classifier. Training stops once val_loss has not improved for
# EARLY_STOP_PATIENCE epochs, and only the single best checkpoint (lowest
# val_loss) is kept.
NUM_CLASSES = 8
MAX_EPOCHS = 20
EARLY_STOP_PATIENCE = 5
SEED = 42

# The class-balance cells treat each column of `label` as a class, so the
# model's output size should match that column count (assumption — confirm).
assert dl_train.dataset.label.shape[1] == NUM_CLASSES, (
    f"NUM_CLASSES={NUM_CLASSES} but label has {dl_train.dataset.label.shape[1]} columns"
)

# Re-seed so weight init and shuffling are reproducible even if only this cell is re-run.
pl.seed_everything(SEED, workers=True)

# Global switch: lets float32 matmuls use faster reduced-precision internal math
# on hardware that supports it. Set before the model is built and trained.
torch.set_float32_matmul_precision("medium")

model = ResNet50(num_classes=NUM_CLASSES)

checkpoint_cb = ModelCheckpoint(
    monitor="val_loss",
    filename="best-checkpoint-{epoch:02d}-{val_loss:.2f}",
    save_top_k=1,  # keep only the single best checkpoint
    mode="min",  # lower val_loss is better
    # On a filename collision, append a version suffix (-v1, ...) instead of
    # overwriting the existing file. This does NOT keep all checkpoints.
    enable_version_counter=True,
)

trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    accelerator="auto",
    devices="auto",
    logger=TensorBoardLogger(save_dir="logs/"),
    callbacks=[
        EarlyStopping("val_loss", patience=EARLY_STOP_PATIENCE),
        checkpoint_cb,
    ],
)

trainer.fit(model, dl_train, dl_val)

# Show where the best model was written.
checkpoint_cb.best_model_path