# ⚡ Lightning + WandB - Custom Scratch Models (CNN & Fully Connected)

### Imports

In [1]:
import os
import shutil
import pathlib

from PIL import Image
import numpy as np
import cv2 as cv
import random
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

from torch.utils.data import DataLoader, random_split
from torch.utils.data import Dataset

import torchvision
from torchvision import datasets

import torchvision.transforms as T
from torchvision.transforms import Compose, ToTensor, Resize

import wandb
import pytorch_lightning as pl
from pytorch_lightning.utilities.model_summary import ModelSummary
from pytorch_lightning import seed_everything
from pytorch_lightning.loggers import WandbLogger

import torchmetrics


In [2]:
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33maryangarg019[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
seed_everything(42)

Global seed set to 42


42

### Custom Utils

In [4]:
# Folder Utilities ----------------------------

## Create dir if it doesn't exist
def create_dir(dir_name):
  if not os.path.exists(f'{dir_name}'):
    os.mkdir(f'{dir_name}')

## Delete dir: checkpoints
def delete_dir(dir_name):
  if os.path.isdir(f'{dir_name}'):
    shutil.rmtree(f'{dir_name}')

# ---------------------------------------------

### Transforms

In [5]:
# TODO: For experiments later on!
# train_transform = A.Compose(
#     [
#         A.SmallestMaxSize(max_size=160),
#         A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
#         A.RandomCrop(height=128, width=128),
#         A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
#         A.RandomBrightnessContrast(p=0.5),
#         A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
#         ToTensorV2(),
#     ]
# )
train_transform = T.Compose([T.ToTensor(), T.Resize((224, 224))])
test_transform = T.Compose([T.ToTensor(), T.Resize((224, 224))])

### Config

In [6]:
CONFIG = {
    'mode': 'train',
    'train_path': 'dataset_FER/train/',
    'test_path': 'dataset_FER/test/',
    'epochs': 10,
    'batch_size': 16,
    'lr': 0.001,
    'num_workers': 4,
    'device': 'cuda',
    'device_ids': [0,1],
    'load_model': False,
    'checkpoint_path': 'ckpts/scratch_CNN/',
    'save_every': 10,
    'device': 'cuda:1' if torch.cuda.is_available() else 'cpu',
}

### Dataset & DataLoader

In [7]:
train_data = datasets.ImageFolder(root='dataset_FER/train/', transform=train_transform)
test_data = os.listdir('dataset_FER/test/')

In [8]:
trainLoader = DataLoader(train_data, batch_size=16, shuffle=True, num_workers=4)

In [9]:
sample, y = next(iter(trainLoader))
sample.shape, y

(torch.Size([16, 3, 224, 224]),
 tensor([4, 5, 0, 4, 2, 2, 4, 3, 2, 5, 3, 3, 2, 3, 6, 3]))

### Models

In [10]:
from models.scratch_cnn import ScratchCNN
from models.scratch_fully_connected import FullyConnected

In [11]:
# cnn = ScratchCNN()
# cnn.to('cuda')

# from torchsummary import summary
# summary(cnn, (3, 224, 224))

In [12]:
cnn = ScratchCNN()

### LIT Trainer

In [13]:
class LIT_Scratch(pl.LightningModule):
  
  def __init__(self, model):
    super().__init__()
    self.model = model
    self.save_hyperparameters()

    self.roc = torchmetrics.ROC(task='multiclass', num_classes=7)
    self.acc = torchmetrics.Accuracy(task='multiclass', num_classes=7)
    self.auroc = torchmetrics.classification.MulticlassAUROC(num_classes=7)
    self.f1 = torchmetrics.F1Score(task='multiclass', num_classes=7)


  def configure_optimizers(self):
    optimizer = torch.optim.AdamW(self.parameters(), lr=3e-4, weight_decay=1e-3, betas=(0.9, 0.999), eps=1e-8)
    return optimizer


  def forward(self, z):
    return self.model(z)


  def training_step(self, batch, batch_idx):
    imgs, y = batch
    y_hat = self(imgs)
    loss = F.cross_entropy(y_hat, y)

    self.log('train_loss_CE', loss, prog_bar=True)
    
    # np_y = y.clone().detach().view(-1).cpu().numpy()
    # np_yhat = y_hat.clone().detach().view(-1).cpu().numpy()

    # metrics
    preds = torch.argmax(y_hat, dim=1)
    self.log('train_acc', self.acc(preds, y), prog_bar=True)
    self.log('train_F1', self.f1(preds, y), prog_bar=False)

    self.roc.update(y_hat, y)
    # print(f"{batch_idx+1} ROC --> FPR: {fpr} | TPR: {tpr}")

    # self.log({"ROC" : wandb.plot.roc_curve(np_y, np_yhat, \
    #   labels=["Angry","Disgust","Fear","Happy","Neutral","Sad","Surprise"], classes_to_plot=None)}
    #   , prog_bar=False)
    
    # wandb.sklearn.plot_confusion_matrix(np_y, np_yhat, ["Angry","Disgust","Fear","Happy","Neutral","Sad","Surprise"]) 
    
    # self.log({"Precision-Recall": wandb.plot.pr_curve(np_y, np_yhat, \
    #   labels=["Angry","Disgust","Fear","Happy","Neutral","Sad","Surprise"], classes_to_plot=None)})

    self.log('train_AUROC', self.auroc(y_hat, y), prog_bar=True)

    return loss
    

  def on_train_epoch_end(self):
    roc = self.roc.compute()
    # print( f"roc: {roc}")
    self.roc.reset()

In [14]:
from pytorch_lightning import Callback
from pytorch_lightning.callbacks import DeviceStatsMonitor, TQDMProgressBar, ModelCheckpoint, EarlyStopping, LearningRateMonitor

# Checkpoint
checkpoint_callback = ModelCheckpoint(dirpath=CONFIG['checkpoint_path'],
                                      filename='{epoch}-{train_loss_CE:.2f}',
                                      monitor='train_loss_CE',
                                      save_top_k=-1,
                                      save_last=True,
                                      save_weights_only=True,
                                      verbose=True,
                                      mode='min')

# Exp2: Learning Rate Monitor
lr_monitor = LearningRateMonitor(logging_interval='step', log_momentum=True)

# Earlystopping
# earlystopping = EarlyStopping(monitor='train_loss_CE', patience=3, mode='min')

In [15]:
wandb_logger = WandbLogger(project='MMU-FER', 
                           name='CNN_fixed_10epochs',
                           config=CONFIG,
                           job_type='train',
                           log_model="all")

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016670011666913828, max=1.0…

In [16]:
# train model

trainer = pl.Trainer(fast_dev_run=False,            # For debugging purposes
                    log_every_n_steps=1,           # set the logging frequency
                    accelerator='gpu',            # Precedence: tpu > gpu >> cpu
                    devices=[1],                       # all
                    # strategy="ddp_notebook",       # distributed data parallel
                    max_epochs= CONFIG['epochs'],   # number of epochs
                    precision=16,
                    callbacks=[TQDMProgressBar(refresh_rate=25), 
                               checkpoint_callback, 
                               lr_monitor],
                    logger=wandb_logger,           # wandb <3
                    benchmark=True)

  rank_zero_warn(
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [17]:
# Train model:
model = LIT_Scratch(cnn)
trainer.fit(model, trainLoader)

  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type               | Params
---------------------------------------------
0 | model | ScratchCNN         | 51.9 M
1 | roc   | MulticlassROC      | 0     
2 | acc   | MulticlassAccuracy | 0     
3 | auroc | MulticlassAUROC    | 0     
4 | f1    | MulticlassF1Score  | 0     
---------------------------------------------
51.9 M    Trainable params
0         Non-trainable params
51.9 M    Total params
207.733   Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]



### Testing

In [None]:
model_ckpt_path = os.path.join(os.getcwd(), 'ckpts', 'scratch_CNN', )

In [None]:
wandb.finish()

0,1
epoch,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇████
lr-AdamW,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr-AdamW-momentum,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_AUROC,▃▄▅▃▂▂▅▃▅▃▆▃▅▃▂▆▅▆▆█▄▃▃▄▅▅▃▅▄▅▆▅▂▂▄▄▃▅▁▁
train_F1,▃▄▃▇▅▅▂▆▄▃▅▄▃▂█▂▅▄▂▆▁▇▅▃▃▂▄▄▃▄▄▄▂▆▅▄▄▆▇▇
train_acc,▃▄▃▇▅▅▂▆▄▃▅▄▃▂█▂▅▄▂▆▁▇▅▃▃▂▄▄▃▄▄▄▂▆▅▄▄▆▇▇
train_loss_CE,▆▅▆▂▄▄▇▃▅▆▄▅▆▇▁▇▄▅▇▃█▂▄▆▆▇▅▅▆▅▅▅▇▃▄▅▅▃▂▂
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███

0,1
epoch,9.0
lr-AdamW,0.0003
lr-AdamW-momentum,0.9
train_AUROC,0.39484
train_F1,0.3
train_acc,0.3
train_loss_CE,1.86542
trainer/global_step,22279.0


: 