In [None]:
from google.colab import drive
# Attach Google Drive to the Colab VM; the project directory used below lives under it.
drive.mount(mountpoint='/content/drive/')

In [None]:
!pip install jsonargparse==4.28.0 lightning==2.2.3 seaborn==0.13.2 tabulate==0.9.0 termcolor==2.4.0 torch==2.2.2 torchmetrics==1.3.2 torchvision==0.17.2 wandb==0.16.6 wget==3.2 tqdm==4.66.2

In [None]:
cd /content/drive/MyDrive/hyu/aue8088-pa1

In [None]:
# PyTorch & PyTorch Lightning
import torch
from lightning import Trainer
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers.wandb import WandbLogger

# Custom packages
import src.config as cfg
from src.dataset import TinyImageNetDatasetModule
from src.network import SimpleClassifier

# Trade float32 matmul accuracy for speed: 'medium' is the most permissive of
# PyTorch's levels ('highest' / 'high' / 'medium').
torch.set_float32_matmul_precision(precision='medium')

In [None]:
# Seed Python / NumPy / PyTorch RNGs (and dataloader workers) *before* the model is
# built, so weight initialisation and data shuffling are reproducible. Re-running this
# cell re-seeds, which keeps it idempotent. Override via `SEED` in src.config.
SEED = getattr(cfg, 'SEED', 42)
seed_everything(SEED, workers=True)

# Classifier module; architecture, class count, optimizer and scheduler come from src.config.
model = SimpleClassifier(
    model_name = cfg.MODEL_NAME,
    num_classes = cfg.NUM_CLASSES,
    optimizer_params = cfg.OPTIMIZER_PARAMS,
    scheduler_params = cfg.SCHEDULER_PARAMS,
)

# Tiny ImageNet data module (the dataset is downloaded on the first `trainer.fit`).
datamodule = TinyImageNetDatasetModule(
    batch_size = cfg.BATCH_SIZE,
)

# Weights & Biases logger; project, entity, run name and local save dir from src.config.
wandb_logger = WandbLogger(
    project = cfg.WANDB_PROJECT,
    save_dir = cfg.WANDB_SAVE_DIR,
    entity = cfg.WANDB_ENTITY,
    name = cfg.WANDB_NAME,
)

# Validation runs only every VAL_EVERY_N_EPOCH epochs, so checkpoint selection happens
# at those epochs. ModelCheckpoint keeps the single best checkpoint by 'accuracy/val'
# (assumed to be the metric key logged by SimpleClassifier — confirm).
trainer = Trainer(
    accelerator = cfg.ACCELERATOR,
    devices = cfg.DEVICES,
    precision = cfg.PRECISION_STR,
    max_epochs = cfg.NUM_EPOCHS,
    check_val_every_n_epoch = cfg.VAL_EVERY_N_EPOCH,
    logger = wandb_logger,
    callbacks = [
        LearningRateMonitor(logging_interval='epoch'),
        ModelCheckpoint(save_top_k=1, monitor='accuracy/val', mode='max'),
    ],
)

In [None]:
# Train, validating periodically. The first run is slow because the dataset must be downloaded.
trainer.fit(model=model, datamodule=datamodule)

In [None]:
# Re-evaluate on the validation set, restoring the best checkpoint kept by ModelCheckpoint during fit.
trainer.validate(datamodule=datamodule, ckpt_path='best')