In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
from pathlib import Path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

import torch
import torch.nn as nn
import lightning as pl
from lightning.pytorch.tuner import Tuner
from torchvision.transforms import v2
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
from clearml import Task
from omegaconf import OmegaConf
from torch.optim.lr_scheduler import CosineAnnealingLR
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
from torchmetrics.classification import Accuracy, Precision, ConfusionMatrix

from weather_classification import PROJ_ROOT, CONFIG_DIR, PROCESSED_DATA_DIR
from weather_classification.pl_data import WeatherDataModule
from weather_classification.pl_model import LiEfficientNet
from weather_classification.custom_callbacks import CustomTensorBoardLogger, CustomTQDMProgressBar
from weather_classification.utils import get_params_num, load_weights_lt_model

In [None]:
# Training settings are read from a YAML file located in the project's config directory.
config_fname = "train.yaml"
config_fpath = CONFIG_DIR / config_fname
cfg = OmegaConf.load(config_fpath)  # parsed config object used by all later cells

In [None]:
# Seeding is opt-in: the seed value and the worker-seeding flag both come from the config.
if cfg.seed_everything_enable:
    pl.seed_everything(cfg.seed, workers=cfg.seed_workers)

#### ClearML

In [None]:
# Experiment tracking is opt-in; note that `task` is only defined when ClearML is enabled.
if cfg.clearml_enable:
    # Reuse the previous ClearML task only when training is being resumed.
    reuse_last_task_id = bool(cfg.resume_train)

    task = Task.init(
        task_name=cfg.task_name,
        project_name=cfg.clearml_proj_name,
        reuse_last_task_id=reuse_last_task_id,
    )

#### Dataset

In [None]:
# Location of the processed dataset selected in the config.
dataset_dpath = PROCESSED_DATA_DIR / cfg.dataset_dname

# Per-channel statistics handed to Normalize; these are the usual ImageNet values,
# which go with the ImageNet-pretrained backbone loaded further down.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, crop to the 224x224 network input, random horizontal flip,
# then convert to a float tensor scaled to [0, 1] and normalize.
# v2.ToTensor() is deprecated in the v2 API; ToImage + ToDtype(scale=True) is the
# documented replacement and yields the same float32 values.
train_transforms = v2.Compose([
    v2.Resize((256, 256), interpolation=v2.InterpolationMode.BICUBIC),
    v2.CenterCrop((224, 224)),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Validation/test pipeline: identical preprocessing, but without the random flip.
val_test_transforms = v2.Compose([
    v2.Resize((256, 256), interpolation=v2.InterpolationMode.BICUBIC),
    v2.CenterCrop((224, 224)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Build the data module from the config-driven loader settings and the two transform pipelines.
dm = WeatherDataModule(
    data_dir=dataset_dpath,
    train_transforms=train_transforms,
    val_transforms=val_test_transforms,
    batch_size=cfg.batch_size,
    num_workers=cfg.num_workers,
    persistent_workers=cfg.persistent_workers,
)
# Run the "fit" setup eagerly: later cells read `dm.num_cls`
# (presumably populated by setup -- TODO confirm against WeatherDataModule).
dm.setup(stage="fit")

#### Model

In [None]:
# Start from EfficientNet-B0 with its ImageNet-pretrained weights.
efficientnet_model = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)

# Freeze every pretrained parameter; only the layer added below stays trainable.
efficientnet_model.requires_grad_(False)

# Swap the last classifier layer for a fresh linear head sized to the dataset's class count.
head_in_features = efficientnet_model.classifier[1].in_features
fc_layer = nn.Linear(head_in_features, dm.num_cls)
efficientnet_model.classifier[1] = fc_layer

# Report the parameter counts (overall, and those that still require gradients).
print(f"Total parameters: {get_params_num(efficientnet_model)}")
print(f"Trainable parameters: {get_params_num(efficientnet_model, with_grad=True)}")

#### Loss, Optimizer, Scheduler

In [None]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Cross-entropy loss and plain SGD over the model's parameters, with the learning rate from the config.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(efficientnet_model.parameters(), lr=cfg.lr)

# The cosine-annealing schedule is optional; when enabled, its period spans the configured epoch count.
lr_scheduler = (
    CosineAnnealingLR(optimizer, T_max=cfg.num_epochs)
    if cfg.lr_scheduler_enable
    else None
)

#### Metrics

In [None]:
# Arguments shared by every metric: multiclass task with the dataset's number of classes.
multiclass_kwargs = dict(task="multiclass", num_classes=dm.num_cls)

# Metrics are kept in a ModuleDict keyed by name.
# NOTE(review): Precision is created with its default averaging mode -- confirm that this
# is the intended one, since it determines how per-class results are combined.
metrics = nn.ModuleDict(
    dict(
        accuracy=Accuracy(**multiclass_kwargs),
        precision=Precision(**multiclass_kwargs),
        conf_matrix=ConfusionMatrix(normalize="true", **multiclass_kwargs),
    )
)

#### Lightning Model

In [None]:
# Bundle the network with its loss, optimizer, optional LR scheduler and metrics
# into the project's Lightning module.
model = LiEfficientNet(
    model=efficientnet_model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    metrics=metrics,
)

#### Callbacks

In [None]:
# Lightning's "simple" profiler is used only when the config enables profiling.
profiler = "simple" if cfg.profiler_enable else None

# Logs are written under the configured directory, resolved against the project root.
save_dir = PROJ_ROOT / cfg.save_dir
tb_logger = CustomTensorBoardLogger(save_dir=save_dir)

# Early stopping and checkpointing both monitor "val_loss"; the checkpoint callback
# writes its monitored checkpoint as "best" and additionally keeps the last one.
early_stopping = EarlyStopping(monitor="val_loss", patience=cfg.early_stopping_patience)
progress_bar = CustomTQDMProgressBar(leave=True)
lr_monitor = LearningRateMonitor(logging_interval="epoch")
checkpointing = ModelCheckpoint(monitor="val_loss", filename="best", save_last=True)

callbacks = [early_stopping, progress_bar, lr_monitor, checkpointing]

#### Trainer

In [None]:
# Assemble the trainer from the objects prepared above; the commented-out options
# are kept as quick switches for debugging runs.
trainer = pl.Trainer(
    logger=tb_logger,
    callbacks=callbacks,
    profiler=profiler,
    max_epochs=cfg.num_epochs,
    deterministic=cfg.deterministic,
    # accelerator=device.type,
    # fast_dev_run=True,
    # limit_train_batches=0.2,
)
# Give the custom logger a handle on the trainer that owns it.
tb_logger.trainer = trainer

#### Learning Rate Finder

In [None]:
# tuner = Tuner(trainer)
# tuner.lr_find(model, datamodule=dm)
# tuner.scale_batch_size(model, datamodule=dm)

#### Train && Validate

In [None]:
# When resuming, continue from the checkpoint named in the config; otherwise start from scratch.
ckpt_path = PROJ_ROOT / cfg.model_fpath if cfg.resume_train else None

trainer.fit(model=model, datamodule=dm, ckpt_path=ckpt_path)

#### Test

In [None]:
# Evaluate the in-memory model on the data module's test split (no checkpoint path is passed).
trainer.test(model=model, datamodule=dm)

# Finish the ClearML task, which only exists when tracking was enabled.
if cfg.clearml_enable:
    task.close()

#### Export Model

In [None]:
# Pick the checkpoint to export. Prefer the best checkpoint recorded by this session's
# ModelCheckpoint callback, so the notebook survives "Restart & Run All" without editing a
# run-specific path. `best_model_path` is an empty string when no checkpoint has been saved
# (e.g. the training cell was skipped); in that case fall back to the previously
# hard-coded run directory.
fallback_ckpt_fpath = PROJ_ROOT / "lightning_logs/lightning_logs/6 epochs/checkpoints/best.ckpt"
checkpoint_callback = trainer.checkpoint_callback
best_ckpt_fpath = checkpoint_callback.best_model_path if checkpoint_callback is not None else ""
trained_model_fpath = Path(best_ckpt_fpath) if best_ckpt_fpath else fallback_ckpt_fpath

# Load the selected weights into the Lightning module before exporting it.
model.load_model_weights(trained_model_fpath)

In [None]:
# Convert the Lightning module to TorchScript (scripting is the default export method)
# and write it to a file in the current working directory.
script = model.to_torchscript(method="script")
torch.jit.save(script, f="efficientnet.torchscript")