# Setup

**Clone and install the `deepdrive_course` repository when running in Google Colab, so that its library code can be imported.**

In [None]:
import sys

in_colab = "google.colab" in sys.modules

if in_colab:
  !git clone https://github.com/abojda/deepdrive_course.git dd_course
  !pip install dd_course/ -q

In [15]:
!python3 -m pip install einops wandb pytorch-lightning seaborn matplotlib -qU

## wandb login

In [None]:
import wandb

# Log in to Weights & Biases up front; the training cell later creates a
# WandbLogger and closes the run with wandb.finish().
# NOTE(review): presumably prompts for an API key when none is stored - confirm.
wandb.login()

## Main imports

In [17]:
import torch
from torch import nn
import pytorch_lightning as pl

# Config

In [18]:
# Single source of settings for this run. The whole mapping is handed to the
# QuickdrawLit wrapper and uploaded to the experiment logger further down.
config = {
    # Identifiers: used for the logger project/run name and the checkpoint directory.
    "project_name": "quickdraw10",
    "run_name": "cnnmed-maxpool2-dropout_0.3",
    # Data description.
    "image_size": (28, 28),
    "classes": [
        "banana",
        "baseball bat",
        "carrot",
        "clarinet",
        "crayon",
        "pencil",
        "boomerang",
        "hockey stick",
        "fork",
        "knife",
    ],
    # Training loop settings.
    "epochs": 20,
    "batch_size": 64,
    "lr": 1e-3,
    "seed": 42,
    # Optimisation settings.
    # NOTE(review): presumably resolved by name inside QuickdrawLit - confirm.
    "optimizer": "Adam",
    "optimizer_kwargs": {},
    "scheduler": None,
    "scheduler_kwargs": {},
}

# Training and validation

## Initialize model

In [19]:
from deepdrive_course.quickdraw.models import (
    CNN,
    CNN_MaxPool2_Dropout,
    CNNMed_MaxPool2_Dropout,
)
from deepdrive_course.quickdraw.modules import QuickdrawLit
from pytorch_lightning.utilities.model_summary import ModelSummary

# Seed the RNGs *before* building the network: the layers draw their initial
# weights at construction time, so seeding only in a later cell would leave
# the initialisation different on every run.
pl.seed_everything(config["seed"])

# Select the architecture.
# Alternatives imported above: CNN(), CNN_MaxPool2_Dropout(0.3)
network = CNNMed_MaxPool2_Dropout(0.3)

# Wrap the network in the project's LightningModule. `model` is the object
# the later cells log and train; the bare network keeps its own name so that
# `model` never refers to two different kinds of object.
model = QuickdrawLit(network, config)

# Per-layer parameter summary (max_depth=-1 expands every nested submodule).
summary = ModelSummary(model, max_depth=-1)
print(summary)

   | Name           | Type                    | Params
------------------------------------------------------------
0  | model          | CNNMed_MaxPool2_Dropout | 167 K 
1  | model.model    | Sequential              | 167 K 
2  | model.model.0  | Conv2d                  | 80    
3  | model.model.1  | ReLU                    | 0     
4  | model.model.2  | Conv2d                  | 2.3 K 
5  | model.model.3  | ReLU                    | 0     
6  | model.model.4  | MaxPool2d               | 0     
7  | model.model.5  | Dropout                 | 0     
8  | model.model.6  | Conv2d                  | 37.0 K
9  | model.model.7  | ReLU                    | 0     
10 | model.model.8  | Flatten                 | 0     
11 | model.model.9  | Linear                  | 128 K 
12 | model.model.10 | LogSoftmax              | 0     
------------------------------------------------------------
167 K     Trainable params
0         Non-trainable params
167 K     Total params
0.670     Total estimated m

## Reproducibility

In [None]:
# Fix the global random seed (taken from the config) before the dataset split
# and the shuffled training DataLoader are created in the next cell.
pl.seed_everything(config["seed"])

## Setup datasets and dataloaders

In [21]:
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from deepdrive_course.quickdraw.datasets import QuickdrawDatasetInMemory
from deepdrive_course.utils import stratified_train_test_split

# Take the class list from the config instead of repeating it here, so the
# dataset always matches the settings that are logged with the run.
class_names = list(config["classes"])

transform = ToTensor()

full_ds = QuickdrawDatasetInMemory(
    root="data", classes=class_names, transform=transform
)

# 80 % of the samples go to training; the last two return values of the
# split helper are not needed in this notebook.
train_ds, val_ds, _, _ = stratified_train_test_split(full_ds, train_size=0.8)

# Shuffle only the training data; validation keeps a fixed order.
train_dl = DataLoader(train_ds, batch_size=config["batch_size"], shuffle=True)

val_dl = DataLoader(val_ds, batch_size=config["batch_size"], shuffle=False)

## Define callbacks

In [22]:
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)

# Stops training once `val_loss` has not improved for 10 validation checks.
# Defined for convenience but currently disabled in `callbacks` below.
early_stopping_cb = EarlyStopping(
    monitor="val_loss", mode="min", patience=10, check_on_train_epoch_end=False
)

# Keeps the three checkpoints with the lowest `val_loss`.
# Lightning prefixes every `{metric}` placeholder with "metric=" by default,
# so the template must not spell the metric name out a second time:
# "{epoch}-{val_loss:.2f}" -> "epoch=7-val_loss=0.12.ckpt".
checkpoint_cb = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_top_k=3,
    dirpath=f'{config["project_name"]}/best/{config["run_name"]}',
    filename="{epoch}-{val_loss:.2f}",
)

# Records the learning rate at every optimisation step.
lr_monitor_cb = LearningRateMonitor(logging_interval="step")

callbacks = [
    # early_stopping_cb,
    checkpoint_cb,
    lr_monitor_cb,
]

## Training and validation loops

In [None]:
from pytorch_lightning.loggers import WandbLogger

try:
    # The logger is set up inside the try-block on purpose: accessing
    # `logger.experiment` starts the W&B run, so a failure while recording the
    # config would otherwise skip `wandb.finish()` and leave the run open.
    logger = WandbLogger(project=config["project_name"], name=config["run_name"])
    logger.experiment.config.update(config)
    logger.experiment.config["model"] = model
    # logger.watch(model)  # optionally log gradients and model topology

    trainer = pl.Trainer(
        max_epochs=config["epochs"],
        logger=logger,
        callbacks=callbacks,
        # Skip the validation sanity check that normally runs before training.
        num_sanity_val_steps=0,
    )
    trainer.fit(model, train_dl, val_dl)
finally:
    # Always close the W&B run, even when setup or training is interrupted.
    wandb.finish()