# Setup

**Download and install the `deepdrive_course` repository when running in Google Colab (to have access to its libraries)**

In [None]:
import sys

in_colab = "google.colab" in sys.modules

if in_colab:
    !git clone https://github.com/abojda/deepdrive_course.git dd_course
    !pip install dd_course/ -q

In [None]:
!pip install lightly

In [2]:
import timm
import pytorch_lightning as pl

## wandb login

In [None]:
import wandb

# Log in to Weights & Biases up front; the training cell logs through
# WandbLogger and closes the run with wandb.finish().
wandb.login()

# Config

In [4]:
# Single source of truth for the experiment's hyperparameters.
config = {
    # Identifiers used for W&B logging and for the checkpoint directory
    "project_name": "stl10_ssl",
    "run_name": "simclr-onecycle_lr0.004",
    # Image size and layer dimensions
    "image_size": 96,
    "input_dim": 2048,  # Resnet50 features have 2048 dimensions
    "hidden_dim": 2048,
    "output_dim": 128,
    # Backbone created through timm
    "timm_model": "resnet50",
    "timm_dropout": 0.3,
    # Training schedule
    "epochs": 30,
    "batch_size": 64,
    "lr": 4e-3,
    "seed": 42,
    # Optimizer (the original notebook also listed "RMSprop" as an alternative)
    "optimizer": "Adam",
    "optimizer_kwargs": {},
}

# Learning-rate scheduler settings, kept separate so they can be swapped as a unit.
# Alternative left by the author: scheduler=None, scheduler_interval="step",
# scheduler_kwargs={}.
scheduler_config = {
    "scheduler": "OneCycleLR",
    "scheduler_interval": "step",
    "scheduler_kwargs": {
        "epochs": config["epochs"],
        "max_lr": config["lr"],
        # steps_per_epoch is added after the training DataLoader is instantiated
    },
}

# Merge the scheduler settings into the main configuration.
config.update(scheduler_config)

# Prepare data

## Initialize datasets

In [5]:
from torchvision.datasets import STL10
from lightly.data import LightlyDataset, SimCLRCollateFunction

root = "stl10_data"

# Plain torchvision STL10 splits stored under `root` (download=True for both).
# The generator is consumed in order, so "unlabeled" is built before "test".
unlabeled_ds_base, test_ds_base = (
    STL10(root=root, split=split_name, download=True)
    for split_name in ("unlabeled", "test")
)

# Wrap both splits as lightly datasets for the SimCLR data loaders.
unlabeled_ds_simclr = LightlyDataset.from_torch_dataset(unlabeled_ds_base)
test_ds_simclr = LightlyDataset.from_torch_dataset(test_ds_base)

Downloading http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz to stl10_data/stl10_binary.tar.gz


100%|██████████| 2640397119/2640397119 [02:14<00:00, 19684917.03it/s]


Extracting stl10_data/stl10_binary.tar.gz to stl10_data
Files already downloaded and verified


## Reproducibility

In [None]:
from pytorch_lightning import seed_everything

# workers=True also lets Lightning seed the DataLoader worker processes,
# which is where the collate function (and its random augmentations) runs
# when num_workers > 0.
seed_everything(config["seed"], workers=True)

## Initialize dataloader
`SimCLRCollateFunction` applies ImageNet normalization by default.

In [None]:
from torch.utils.data import DataLoader
import multiprocessing

# Collate function shared by both loaders.
# NOTE(review): vf_prob / rr_prob are presumably the vertical-flip and
# random-rotation probabilities — confirm against the lightly documentation.
collate_fn_simclr = SimCLRCollateFunction(
    input_size=config["image_size"],
    vf_prob=0.5,
    rr_prob=0.5,
)

# Settings that are identical for the training and validation loaders.
_shared_loader_kwargs = dict(
    batch_size=config["batch_size"],
    collate_fn=collate_fn_simclr,
    drop_last=True,
    num_workers=multiprocessing.cpu_count(),
    pin_memory=True,
)

# Training loader (shuffled) over the unlabeled split.
unlabeled_dl_simclr = DataLoader(
    unlabeled_ds_simclr, shuffle=True, **_shared_loader_kwargs
)

# Validation loader (fixed order) over the test split.
test_dl_simclr = DataLoader(test_ds_simclr, shuffle=False, **_shared_loader_kwargs)

# The scheduler needs the number of optimizer steps per epoch, which is only
# known once the training DataLoader exists.
config["scheduler_kwargs"]["steps_per_epoch"] = len(unlabeled_dl_simclr)
print(config["scheduler_kwargs"]["steps_per_epoch"])

# Models

## Instantiate model

In [None]:
from deepdrive_course.stl10.modules import LitSimCLR

# The backbone is trained from scratch: STL10 contains images from ImageNet,
# so starting from a pretrained model would be cheating.
backbone = timm.create_model(
    config["timm_model"],
    pretrained=False,
    drop_rate=config["timm_dropout"],
    num_classes=0,
)

# Lightning module built from the backbone and the shared configuration.
simclr_model = LitSimCLR(backbone, config)

# Training

## Define callbacks

In [9]:
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint

# Keep the three best checkpoints ranked by val_loss_ssl, stored under
# <project_name>/best/<run_name>.
checkpoint_cb = ModelCheckpoint(
    dirpath=f'{config["project_name"]}/best/{config["run_name"]}',
    filename="{epoch}-{val_loss_ssl:.2f}",
    monitor="val_loss_ssl",
    save_top_k=3,
)

# Log the learning rate at every step.
lr_monitor_cb = LearningRateMonitor(logging_interval="step")

callbacks = [checkpoint_cb, lr_monitor_cb]

## Training and validation loops

In [None]:
from pytorch_lightning.loggers import WandbLogger

# Logger setup is inside the try block as well: accessing `logger.experiment`
# starts a W&B run, and the finally clause must close it even if configuring
# the run (or training itself) raises.
try:
    # Define logger
    logger = WandbLogger(project=config["project_name"], name=config["run_name"])
    logger.experiment.config.update(config)

    # Setup summary metrics: report the minimum of each loss in the run summary
    logger.experiment.define_metric("train_loss_ssl", summary="min")
    logger.experiment.define_metric("val_loss_ssl", summary="min")

    trainer = pl.Trainer(
        max_epochs=config["epochs"],
        logger=logger,
        callbacks=callbacks,
        num_sanity_val_steps=0,
    )

    # Train on the unlabeled split and validate on the test split.
    trainer.fit(simclr_model, unlabeled_dl_simclr, test_dl_simclr)
finally:
    wandb.finish()