# Setup

**Clone and install the deepdrive_course repository when running in Google Colab (so that its libraries can be imported)**

In [None]:
import sys

in_colab = "google.colab" in sys.modules

if in_colab:
  !git clone https://github.com/abojda/deepdrive_course.git dd_course
  !pip install dd_course/ -q

In [2]:
!python3 -m pip install pytorch-lightning timm opencv-python gdown patool mega.py -q

In [3]:
import timm
import pytorch_lightning as pl

## wandb login

In [None]:
import wandb

# Log in to Weights & Biases up front; the training cell further down logs
# the run through WandbLogger and closes it with wandb.finish().
wandb.login()

# Config

In [5]:
from deepdrive_course.resisc45.datasets import RESISC45

# Hyperparameters that are consumed in more than one place (trainer, datamodule
# and the OneCycleLR schedule). They are defined once here so the scheduler
# cannot silently drift out of sync when one of them is edited.
EPOCHS = 20
BATCH_SIZE = 64
LEARNING_RATE = 5e-4
# NOTE(review): presumably the number of samples the training dataloader
# iterates over per epoch - confirm against the split used by
# RESISC45DataModule, because OneCycleLR needs the real number of steps.
NUM_SAMPLES = 31500

config = dict(
    project_name = "resisc",
    run_name = "resnet50-tl-ft_onecycle_lr0.0005-drop_0.3-randaugment",

    classes = RESISC45.classes,

    training_type = "full",
    # training_type = "transfer_learning",

    # checkpoint = None,
    # checkpoint = "resnet50-epoch=19-val_loss=0.53.ckpt",
    checkpoint = "https://mega.nz/file/CxtwyC4R#SfyDcxF4CKKhe2LDUT9Ssk3l6zH2Bct9pUhi7PTznjY",  # resnet50-epoch=19-val_loss=0.53.ckpt

    timm_model = "resnet50",
    timm_pretrained = True,
    timm_dropout = 0.3,

    epochs = EPOCHS,
    batch_size = BATCH_SIZE,
    lr = LEARNING_RATE,
    seed = 42,

    optimizer = "Adam",
    # optimizer = "RMSprop",
    optimizer_kwargs = {},

    # scheduler = None,
    # scheduler_interval = "step",
    # scheduler_kwargs = {},

    scheduler = "OneCycleLR",
    scheduler_interval = "step",
    scheduler_kwargs = dict(
        epochs = EPOCHS,
        max_lr = LEARNING_RATE,
        steps_per_epoch = NUM_SAMPLES // BATCH_SIZE,  # number of batches
    ),

    # train_transform = "albumentations_basic_aug",
    # test_transform = "albumentations_imagenet_norm",

    train_transform = "torchvision_randaugment",
    # train_transform = "torchvision_imagenet_norm",
    test_transform = "torchvision_imagenet_norm",
)


# Train and test transform must be from the same library (torchvision or albumentations)
assert (
    config["train_transform"].startswith("albumentations_")
    and config["test_transform"].startswith("albumentations_")
) or (
    config["train_transform"].startswith("torchvision_")
    and config["test_transform"].startswith("torchvision_")
), (
    "train_transform and test_transform must both be albumentations_* or both be torchvision_*, "
    f'got {config["train_transform"]!r} and {config["test_transform"]!r}'
)

# Set albumentations flag for RESISC45DataModule (modifies __getitem__ to use either cv2 or PIL)
config["albumentations"] = config["train_transform"].startswith("albumentations_")

# Load model

In [6]:
from deepdrive_course.resisc45.modules import ResiscLit
from deepdrive_course.utils import (
    download_from_mega_nz,
    timm_prepare_params_for_training,
)


# Create TIMM model
# NOTE(review): with timm_pretrained=True the pretrained weights are fetched
# even when a checkpoint is loaded right afterwards - consider whether that
# download is needed in the checkpoint branches.
timm_model = timm.create_model(
    config["timm_model"],
    pretrained=config["timm_pretrained"],
    num_classes=len(RESISC45.classes),
    drop_rate=config["timm_dropout"],
)


# Create ResiscLit (pl.LightningModule).
# The checkpoint setting selects one of three sources:
#   * None              -> fresh ResiscLit wrapped around the timm model
#   * path ending .ckpt -> local Lightning checkpoint
#   * mega.nz link      -> downloaded first, then loaded like a local checkpoint
if config["checkpoint"] is None:
    model = ResiscLit(timm_model, config)
    print("[ResiscLit] No checkpoint - training from scratch")

elif config["checkpoint"].endswith(".ckpt"):
    model = ResiscLit.load_from_checkpoint(
        config["checkpoint"], model=timm_model, config=config
    )

    print(f'[ResiscLit] Loaded local checkpoint: {config["checkpoint"]}')

elif "mega.nz" in config["checkpoint"]:
    checkpoint_path = download_from_mega_nz(config["checkpoint"])

    model = ResiscLit.load_from_checkpoint(
        checkpoint_path, model=timm_model, config=config
    )

    print(f"[ResiscLit] Loaded mega.nz checkpoint: {checkpoint_path}")

else:
    # Neither None, a *.ckpt path nor a mega.nz link: fail with an explicit message.
    raise ValueError(
        f'Unsupported checkpoint (expected None, a ".ckpt" path or a mega.nz link): '
        f'{config["checkpoint"]!r}'
    )


# Transfer learning / full training setup
timm_prepare_params_for_training(model.model, config["training_type"])
print(f'Training type: {config["training_type"]}')

Downloading model.safetensors:   0%|          | 0.00/102M [00:00<?, ?B/s]

[ResiscLit] Loaded mega.nz checkpoint: resnet50-epoch=19-val_loss=0.53.ckpt
Training type: full


## Print model summary

In [7]:
from deepdrive_course.utils import pl_print_model_summary, pl_find_max_batch_size

# Print the layer/parameter summary of the LightningModule, one level deep.
pl_print_model_summary(model, depth=1)
# Optional helper, disabled by default; presumably probes for the largest
# batch size that fits - see deepdrive_course.utils before enabling.
# pl_find_max_batch_size(model)

  | Name  | Type   | Params
---------------------------------
0 | model | ResNet | 23.6 M
---------------------------------
23.6 M    Trainable params
0         Non-trainable params
23.6 M    Total params
94.401    Total estimated model params size (MB)


# Training and validation

## Reproducibility

In [None]:
# Seed the random number generators through Lightning. workers=True also makes
# Lightning seed the dataloader worker processes, which is where the random
# training augmentations run.
# NOTE(review): the model is created in an earlier cell, so this seed does not
# cover any random initialisation performed there.
pl.seed_everything(config["seed"], workers=True)

## Setup datamodule

In [9]:
import multiprocessing

from deepdrive_course.resisc45.datamodules import RESISC45DataModule
from deepdrive_course.resisc45.transforms import get_transform
from deepdrive_course.utils import timm_get_pretrained_data_transform

# Look up the transform objects by the names chosen in the config
# (train first, then test).
train_transform, test_transform = (
    get_transform(config[key]) for key in ("train_transform", "test_transform")
)

# Datamodule rooted at ./data, using one loader worker per CPU core.
datamodule = RESISC45DataModule(
    root="data",
    train_transform=train_transform,
    test_transform=test_transform,
    albumentations=config["albumentations"],
    batch_size=config["batch_size"],
    num_workers=multiprocessing.cpu_count(),
    pin_memory=True,
)

## Define callbacks

In [10]:
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint

# Early stopping on val_loss with patience=10. It is constructed here but
# deliberately left out of the `callbacks` list below.
early_stopping_cb = EarlyStopping(
    monitor="val_loss",
    mode="min",
    patience=10,
    check_on_train_epoch_end=False,
)

# Checkpoints are ranked by val_loss (save_top_k=3) and written to
# <project_name>/best/<run_name>/.
checkpoint_cb = ModelCheckpoint(
    dirpath=f'{config["project_name"]}/best/{config["run_name"]}',
    filename="{epoch}-{val_loss:.2f}",
    monitor="val_loss",
    save_top_k=3,
)

# Log the learning rate at every step.
lr_monitor_cb = LearningRateMonitor(logging_interval="step")

# Add early_stopping_cb to this list to enable early stopping.
callbacks = [checkpoint_cb, lr_monitor_cb]

## Training and validation loops

In [None]:
from pytorch_lightning.loggers import WandbLogger

# Everything that touches the W&B run lives inside the try block so that
# wandb.finish() also runs when logger or metric setup fails, not only when
# training itself raises.
try:
    logger = WandbLogger(project=config["project_name"], name=config["run_name"])
    logger.experiment.config.update(config)

    # Setup summary metrics
    logger.experiment.define_metric("val_loss", summary="min")
    logger.experiment.define_metric("val_acc", summary="max")
    logger.experiment.define_metric("val_f1_score", summary="max")

    trainer = pl.Trainer(
        max_epochs=config["epochs"],
        logger=logger,
        callbacks=callbacks,
        num_sanity_val_steps=0,
    )
    trainer.fit(model, datamodule=datamodule)
finally:
    wandb.finish()