# Setup

**Download and install the deepdrive_course repository when running in Google Colab (to get access to its libraries)**

In [None]:
import sys
in_colab = 'google.colab' in sys.modules

if in_colab:
  !git clone https://github.com/abojda/deepdrive_course.git dd_course
  !pip install dd_course/ -q

In [2]:
import timm
import pytorch_lightning as pl

## wandb login

In [None]:
import wandb
# Log in to Weights & Biases up front; the WandbLogger used for training is
# created later in this notebook.
wandb.login()

# Config

In [46]:
from torchvision.datasets import STL10

# Experiment configuration: run naming, model, checkpoint and optimisation settings.
config = {
  'project_name': 'stl10_supervised',

  'run_name': 'scratch-onecycle_lr0.001-randaug_3',
  # 'run_name': 'simclr_tl-onecycle_lr0.001-randaug_3',

  'image_size': 96,
  'input_dim': 2048,  # Resnet50 features have 2048 dimensions

  'timm_model': 'resnet50',
  'timm_dropout': 0.3,

  'checkpoint': None,
  # 'checkpoint': 'https://mega.nz/file/47E1gBQD#nlwXN6ygtYUuH6K9RFFgGu7i6pmDagtfE3TTxb7wmJw',  # SimCLR checkpoint

  'epochs': 100,
  'batch_size': 64,
  'lr': 1e-3,
  'seed': 42,

  'optimizer': 'Adam',
  # 'optimizer': 'RMSprop',
  'optimizer_kwargs': {},
}

# Learning-rate scheduler settings; they are merged into `config` right below.
scheduler_config = {
  # 'scheduler': None,
  # 'scheduler_interval': 'step',
  # 'scheduler_kwargs': {},

  'scheduler': 'OneCycleLR',
  'scheduler_interval': 'step',
  'scheduler_kwargs': {
    'epochs': config['epochs'],
    'max_lr': config['lr'],
    # steps_per_epoch is updated after training DataLoader instantiation
  },
}

# Scheduler keys live alongside the other settings in the single `config` dict.
config.update(scheduler_config)

# Prepare data

## Define data transforms

In [47]:
from torchvision.transforms import Compose, ToTensor, Normalize, Resize, RandAugment

# Both pipelines resize to a square of side config['image_size'] and normalise
# with the same per-channel statistics; only the training pipeline adds RandAugment.
_target_size = (config['image_size'], config['image_size'])
_norm_stats = dict(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

train_transform = Compose([
    Resize(_target_size),
    RandAugment(num_ops=3),
    ToTensor(),
    Normalize(**_norm_stats),
])

test_transform = Compose([
    Resize(_target_size),
    ToTensor(),
    Normalize(**_norm_stats),
])

## Initialize datasets

In [None]:
from torchvision.datasets import STL10

root = 'stl10_data'

# One STL10 instance per split, each paired with its own transform pipeline.
# The 'train' split is constructed first, then 'test', both under the same `root`.
train_ds, test_ds = (
    STL10(root=root, split=split_name, transform=split_transform, download=True)
    for split_name, split_transform in (
        ('train', train_transform),
        ('test', test_transform),
    )
)

# Update config
config["classes"] = train_ds.classes

## Reproducibility

In [None]:
from pytorch_lightning import seed_everything

# Apply the configured seed (config['seed']) through Lightning's seed_everything.
# This cell sits before the DataLoaders and the model are created.
seed_everything(config['seed'])

## Initialize dataloaders

In [None]:
from torch.utils.data import DataLoader
import multiprocessing

# Settings shared by the training and the evaluation loader.
_loader_kwargs = dict(
    batch_size=config['batch_size'],
    drop_last=False,
    num_workers=multiprocessing.cpu_count(),
    pin_memory=True,
)

# Only the training loader shuffles its samples.
train_dl = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
test_dl = DataLoader(test_ds, shuffle=False, **_loader_kwargs)


# Update steps_per_epoch in configuration dictionary.
# `limit_train_batches` is an optional config entry: when it is missing (or None)
# every training batch is used. Reading it with .get() avoids the KeyError that
# config["limit_train_batches"] raises when the key was never set.
_limit_train_batches = config.get("limit_train_batches")
if _limit_train_batches is None:
  _limit_train_batches = 1.0

if isinstance(_limit_train_batches, float):
  # Float: fraction of the training batches used per epoch.
  _steps_per_epoch = int(len(train_dl) * _limit_train_batches)
else:
  # Int: absolute number of training batches per epoch (Lightning Trainer
  # semantics), capped by what the loader can provide.
  _steps_per_epoch = min(len(train_dl), int(_limit_train_batches))

config["scheduler_kwargs"]["steps_per_epoch"] = _steps_per_epoch
print(config["scheduler_kwargs"]["steps_per_epoch"])

# Models

## Prepare backbone model

In [51]:
from deepdrive_course.stl10.modules import LitSimCLR, LitClassifier
from deepdrive_course.utils import download_from_mega_nz

def get_model(config):
  """Build the LitClassifier for the run described by `config`.

  The backbone is selected by config["checkpoint"]:
    * None - a timm model named config["timm_model"], created with
      pretrained=False, num_classes=0 and drop_rate=config["timm_dropout"];
    * a string ending in '.ckpt' - the `backbone` attribute of the LitSimCLR
      module loaded from that checkpoint;
    * a string containing 'mega.nz' - the checkpoint is first fetched with
      download_from_mega_nz and then loaded like a '.ckpt' file.

  Args:
    config: configuration dictionary; it is also passed on to LitClassifier.

  Returns:
    LitClassifier wrapping the selected backbone.

  Raises:
    ValueError: if config["checkpoint"] matches none of the supported forms.
  """
  checkpoint = config["checkpoint"]

  if checkpoint is None:
    # We don't use pretrained model. STL10 dataset contains images from Imagenet, so that would be cheating!
    backbone = timm.create_model(
      config['timm_model'],
      num_classes=0,
      pretrained=False,
      drop_rate=config['timm_dropout'])
    return LitClassifier(backbone, config)

  # Resolve where the SimCLR checkpoint comes from; the '.ckpt' suffix is
  # checked first, exactly as before.
  if checkpoint.endswith('.ckpt'):
    checkpoint_path = checkpoint
  elif 'mega.nz' in checkpoint:
    checkpoint_path = download_from_mega_nz(checkpoint)
  else:
    raise ValueError(
      f"Unsupported checkpoint {checkpoint!r}: expected None, "
      "a path ending in '.ckpt' or a mega.nz link")

  # Both SimCLR sources share the same loading step.
  simclr_model = LitSimCLR.load_from_checkpoint(checkpoint_path)
  return LitClassifier(simclr_model.backbone, config)


In [52]:
# Build the classifier for the current `config` (backbone choice is handled by get_model).
model = get_model(config)

# Training

## Define callbacks

In [53]:
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint

# Checkpoints are grouped per project and per run.
_checkpoint_dir = f'{config["project_name"]}/best/{config["run_name"]}'

# Keep the top 3 checkpoints ranked by the monitored 'val_loss' metric.
checkpoint_cb = ModelCheckpoint(
    dirpath=_checkpoint_dir,
    filename='{epoch}-{val_loss:.2f}',
    monitor='val_loss',
    save_top_k=3,
)

# Log the learning rate at every step.
lr_monitor_cb = LearningRateMonitor(logging_interval='step')

callbacks = [checkpoint_cb, lr_monitor_cb]

## Training and validation loops

In [None]:
from pytorch_lightning.loggers import WandbLogger

# Define logger
logger = WandbLogger(project=config['project_name'], name=config['run_name'])

# Optional config entry; a missing key or None means "use every training batch".
_limit_train_batches = config.get('limit_train_batches')
if _limit_train_batches is None:
  _limit_train_batches = 1.0

# Everything that touches the wandb run happens inside the try block, so that
# wandb.finish() is also reached when the setup (not only the training) fails.
try:
  logger.experiment.config.update(config)

  # Setup summary metrics
  logger.experiment.define_metric("val_loss", summary="min")
  logger.experiment.define_metric("val_acc", summary="max")
  logger.experiment.define_metric("train_loss", summary="min")
  logger.experiment.define_metric("train_acc", summary="max")

  trainer = pl.Trainer(
      max_epochs=config['epochs'],
      # Keep the number of training batches per epoch in sync with the value
      # used to derive the scheduler's steps_per_epoch.
      limit_train_batches=_limit_train_batches,
      logger=logger,
      callbacks=callbacks,
      num_sanity_val_steps=0,
  )

  # test_dl is passed as the third positional argument of fit().
  trainer.fit(model, train_dl, test_dl)
finally:
  wandb.finish()