# Setup

**Clone and install the deepdrive_course repository when running in Google Colab (so its `deepdrive_course` modules can be imported)**

In [None]:
import sys

in_colab = "google.colab" in sys.modules

if in_colab:
  !git clone https://github.com/abojda/deepdrive_course.git dd_course
  !pip install dd_course/ -q

In [None]:
!python3 -m pip install pytorch-lightning timm opencv-python gdown patool optuna mega.py -q

In [3]:
import timm
import pytorch_lightning as pl

## wandb login

In [None]:
import wandb

# Authenticate with Weights & Biases so the WandbLogger runs created in
# get_wandb_logger (one per Optuna trial) can be uploaded.
wandb.login()

## Setup model

In [38]:
from deepdrive_course.resisc45.modules import ResiscLit
from deepdrive_course.utils import (
    download_from_mega_nz,
    timm_prepare_params_for_training,
)


def get_model(config):
    """Build the ResiscLit LightningModule for one run, optionally from a checkpoint.

    Args:
        config: Run configuration dict. Keys read here: ``timm_model``,
            ``timm_pretrained``, ``timm_dropout``, ``checkpoint``,
            ``training_type`` and (optionally) ``classes``.
            ``checkpoint`` may be ``None`` (train from scratch), a local path
            ending in ``.ckpt``, or a mega.nz link (downloaded first).

    Returns:
        ResiscLit wrapping the timm backbone, with its parameters prepared by
        ``timm_prepare_params_for_training`` for ``config["training_type"]``.

    Raises:
        ValueError: if ``config["checkpoint"]`` is none of the supported forms.
    """
    # Take the class list from the config; fall back to the dataset definition.
    # Imported locally so this cell does not rely on a global defined in a later cell.
    classes = config.get("classes")
    if classes is None:
        from deepdrive_course.resisc45.datasets import RESISC45

        classes = RESISC45.classes

    # Create TIMM model
    timm_model = timm.create_model(
        config["timm_model"],
        pretrained=config["timm_pretrained"],
        num_classes=len(classes),
        drop_rate=config["timm_dropout"],
    )

    # Create ResiscLit (pl.LightningModule)
    checkpoint = config["checkpoint"]

    if checkpoint is None:
        model = ResiscLit(timm_model, config)
        print("[ResiscLit] No checkpoint - training from scratch")
    else:
        # Resolve the checkpoint to a local file; local paths are checked first.
        if checkpoint.endswith(".ckpt"):
            checkpoint_path, source = checkpoint, "local"
        elif "mega.nz" in checkpoint:
            checkpoint_path, source = download_from_mega_nz(checkpoint), "mega.nz"
        else:
            raise ValueError(
                "Unsupported checkpoint (expected None, a local .ckpt path "
                f"or a mega.nz link): {checkpoint!r}"
            )

        model = ResiscLit.load_from_checkpoint(
            checkpoint_path, model=timm_model, config=config
        )
        print(f"[ResiscLit] Loaded {source} checkpoint: {checkpoint_path}")

    # Transfer learning / full training setup
    timm_prepare_params_for_training(model.model, config["training_type"])
    print(f'Training type: {config["training_type"]}')

    return model

## Setup datamodule

In [6]:
from deepdrive_course.resisc45.datamodules import RESISC45DataModule
from deepdrive_course.resisc45.transforms import get_transform
from deepdrive_course.utils import timm_get_pretrained_data_transform


def get_datamodule(config):
    """Create the RESISC45 Lightning datamodule described by *config*.

    Reads ``train_transform`` / ``test_transform`` (names resolved through
    ``get_transform``), ``batch_size`` and ``albumentations``. The dataset is
    expected under ``data/`` already (it is fetched once in the "Download
    dataset" cell), hence ``download=False``.
    """
    return RESISC45DataModule(
        train_transform=get_transform(config["train_transform"]),
        test_transform=get_transform(config["test_transform"]),
        root="data",
        batch_size=config["batch_size"],
        download=False,
        albumentations=config["albumentations"],
    )

## Setup Logger

In [7]:
from pytorch_lightning.loggers import WandbLogger


def get_wandb_logger(config):
    """Start a W&B run for *config* and return its Lightning ``WandbLogger``.

    The whole config is stored on the run, and summary aggregations are
    registered so the run table shows the best value of each validation metric.
    """
    wandb_logger = WandbLogger(project=config["project_name"], name=config["run_name"])
    run = wandb_logger.experiment
    run.config.update(config)

    # Best loss is the lowest; best accuracy / F1 are the highest.
    summary_modes = (
        ("val_loss", "min"),
        ("val_acc", "max"),
        ("val_f1_score", "max"),
    )
    for metric_name, mode in summary_modes:
        run.define_metric(metric_name, summary=mode)

    return wandb_logger

## Setup Trainer

In [8]:
def get_trainer(config, logger, callbacks):
    """Build the Lightning ``Trainer`` for a single run.

    The epoch count and the fraction of train/val batches used per epoch come
    from *config*; *logger* and *callbacks* are passed through unchanged.
    """
    run_limits = {
        "max_epochs": config["epochs"],
        "limit_train_batches": config["limit_train_batches"],
        "limit_val_batches": config["limit_val_batches"],
    }
    return pl.Trainer(logger=logger, callbacks=callbacks, **run_limits)

## Setup Optuna objective function

In [39]:
from pytorch_lightning.callbacks import LearningRateMonitor
from optuna.integration import PyTorchLightningPruningCallback
from deepdrive_course.pl_callbacks import CollectValidationMetrics
from deepdrive_course.resisc45.datasets import RESISC45


def objective(trial):
    """Optuna objective: fine-tune ResNet-50 from a checkpoint with sampled hyperparameters.

    Samples ``dropout``, ``lr`` and ``optimizer`` from *trial*, trains on a
    fraction of RESISC45, and reports ``val_loss`` to Optuna for pruning.

    Returns:
        The lowest ``val_loss`` collected by ``CollectValidationMetrics``.

    Raises:
        optuna.TrialPruned: propagated from ``trainer.fit`` when the pruning
            callback stops the trial; ``study.optimize`` handles it.
    """
    config = dict(
        project_name="resisc-optuna",
        run_name=f"resnet50-ft-optuna_{trial.number}",

        classes=RESISC45.classes,

        training_type="full",
        # Must match the file fetched in "Download checkpoint" exactly (no spaces).
        checkpoint="resnet50-epoch=19-val_loss=0.53.ckpt",

        timm_model="resnet50",
        timm_pretrained=False,
        timm_dropout=trial.suggest_float("dropout", 0.0, 0.7),

        # Study only on part of the dataset for faster training
        limit_train_batches=0.1,
        limit_val_batches=0.1,

        epochs=25,
        batch_size=64,
        lr=trial.suggest_float("lr", 1e-6, 1e-2, log=True),

        optimizer=trial.suggest_categorical("optimizer", ["Adam", "SGD", "RMSprop"]),
        optimizer_kwargs={},

        scheduler=None,
        scheduler_interval="epoch",
        scheduler_kwargs={},

        train_transform="albumentations_basic_aug",
        test_transform="albumentations_imagenet_norm",
        albumentations=True,
    )

    collect_val_loss = CollectValidationMetrics("val_loss")

    callbacks = [
        LearningRateMonitor(logging_interval="step"),
        PyTorchLightningPruningCallback(trial, monitor="val_loss"),
        collect_val_loss,
    ]

    model = get_model(config)
    logger = get_wandb_logger(config)  # starts the W&B run (accesses logger.experiment)

    try:
        datamodule = get_datamodule(config)
        trainer = get_trainer(config, logger, callbacks)
        trainer.fit(model, datamodule=datamodule)
    finally:
        # Always close the run — whether the trial completed, was pruned or
        # crashed — so the next trial's WandbLogger starts a fresh run instead
        # of attaching to this one.
        wandb.finish()

    # return trainer.callback_metrics["val_loss"].item()
    return min(collect_val_loss.metric_history)

## Setup Optuna helpers

In [40]:
import pickle


def save_sampler(sampler, file):
    """Pickle *sampler* to *file* so an interrupted study can resume with the same state.

    The sampler is serialized in memory first, then written to a temporary
    file that atomically replaces *file*. A pickling error or an interruption
    mid-write therefore never leaves a truncated pickle that would break
    ``load_sampler`` on the next run.
    """
    import os

    payload = pickle.dumps(sampler)
    tmp_file = f"{file}.tmp"
    with open(tmp_file, "wb") as f:
        f.write(payload)
    os.replace(tmp_file, file)


def load_sampler(file):
    """Unpickle and return a sampler previously written by ``save_sampler``.

    Security: ``pickle.load`` can execute arbitrary code — only load files
    this notebook created itself.
    """
    # `with` closes the handle; the original left it open until garbage collection.
    with open(file, "rb") as f:
        return pickle.load(f)


class SaveSamplerToPickleCallback:
    """``study.optimize`` callback that re-pickles the study's sampler after each trial.

    Args:
        file: Destination path handed to ``save_sampler``.
    """

    def __init__(self, file):
        self.file = file

    def __call__(self, study, trial):
        # `trial` belongs to Optuna's callback signature; only the study is needed.
        current_sampler = study.sampler
        save_sampler(current_sampler, self.file)

## Download checkpoint

In [None]:
from deepdrive_course.utils import download_from_mega_nz

# Starting checkpoint for the study: resnet50-epoch=19-val_loss=0.53.ckpt
RESNET50_CHECKPOINT_URL = (
    "https://mega.nz/file/CxtwyC4R#SfyDcxF4CKKhe2LDUT9Ssk3l6zH2Bct9pUhi7PTznjY"
)
download_from_mega_nz(RESNET50_CHECKPOINT_URL)

## Download dataset (done once here instead of in every trial)

In [None]:
import os

if os.path.isdir("data"):
    print("Already downloaded...")
else:
    # The datamodule is only instantiated here to call prepare_data(), which downloads the files.
    RESISC45DataModule(root="data", batch_size=1, download=True).prepare_data()

## Reproducibility

In [None]:
# Seed Python, NumPy and PyTorch RNGs before the study starts.
SEED = 42
pl.seed_everything(SEED)

# Run study

In [None]:
import optuna
import os

# Number of trials added by this call. With load_if_exists=True, a resumed study
# runs this many *additional* trials.
n_trials = 40
study_name = "resnet50-ft"

# On Colab, keep the study DB and sampler on Google Drive to survive runtime resets:
# storage = f'sqlite:////content/drive/MyDrive/Colab Notebooks/lib/optuna/{study_name}.db'
# sampler_file = f'/content/drive/MyDrive/Colab Notebooks/lib/optuna/{study_name}_sampler.pkl'
storage = f"sqlite:///{study_name}.db"
sampler_file = f"{study_name}_sampler.pkl"

# Trial history lives in `storage`; the sampler's RNG state does not, so it is
# pickled separately. A fresh study gets a seeded TPE sampler (Optuna's default
# sampler type), so the hyperparameter search is reproducible.
if os.path.isfile(sampler_file):
    sampler = load_sampler(sampler_file)
else:
    sampler = optuna.samplers.TPESampler(seed=42)

study = optuna.create_study(
    study_name=study_name, storage=storage, load_if_exists=True, sampler=sampler
)

# Save to the same file that is loaded above. Previously the sampler went to
# `study_name`, so it was never restored on resume.
optuna_callbacks = [
    SaveSamplerToPickleCallback(sampler_file),
]

study.optimize(objective, n_trials=n_trials, callbacks=optuna_callbacks)

save_sampler(study.sampler, sampler_file)

## Study summary

In [None]:
# `study.trials` also contains RUNNING/WAITING trials (e.g. left behind by an
# interrupted session), so count only finished ones and break them down by state.
TrialState = optuna.trial.TrialState
finished_trials = [t for t in study.trials if t.state.is_finished()]
complete_trials = [t for t in finished_trials if t.state == TrialState.COMPLETE]
pruned_trials = [t for t in finished_trials if t.state == TrialState.PRUNED]

print(f"Number of finished trials: {len(finished_trials)}")
print(f"  Complete: {len(complete_trials)}")
print(f"  Pruned: {len(pruned_trials)}")

# study.best_trial raises ValueError when no trial has completed (e.g. all pruned).
if complete_trials:
    print("Best trial:")
    trial = study.best_trial

    print(f"  Value: {trial.value}")

    print("  Params: ")
    for key, value in trial.params.items():
        print(f"    {key}: {value}")
else:
    print("Best trial: none - no trial has completed yet")