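# ⚡ Lightning + Weights & Biases (W&B) — Custom From-Scratch Models (CNN & Fully Connected)

Trains from-scratch CNN and fully connected models on the FER image dataset (`dataset_FER/`) with PyTorch Lightning, logging runs to W&B.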

### Imports

In [2]:
!pip -qqq install wandb pytorch-lightning torchmetrics

import wandb
from pytorch_lightning.loggers import WandbLogger

wandb.login()

try:
  import lightning.pytorch as pl
except:
  print("[!] Couldn't find pytorch-lightning.\nInstalling it...\n")
  !pip install lightning -Uq
  import lightning.pytorch as pl

from lightning.pytorch.utilities.model_summary import ModelSummary
from pytorch_lightning import seed_everything

import os
import shutil
import pathlib

from PIL import Image
import numpy as np
import cv2 as cv
import random
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

from torch.utils.data import DataLoader, random_split
from torch.utils.data import Dataset

import torchvision
from torchvision import datasets

import torchvision.transforms as T
from torchvision.transforms import Compose, ToTensor, Resize

try:
  import albumentations as A
  from albumentations.pytorch import ToTensorV2
except:
  print("[!] Couldn't find albumentations... installing it.")
  !pip install -Uq albumentations
  import albumentations as A
  from albumentations.pytorch import ToTensorV2

try:
  import torchmetrics
except:
  print(f"[!] Torchmetrics couldn't be imported.\nInstalling...")
  !pip install torchmetrics -Uq

[!] Couldn't find pytorch-lightning.
Installing it...



TypeError: Argument 'bases' has incorrect type (expected list, got tuple)

### Custom Utils

In [None]:
# Folder Utilities ----------------------------

## Create dir if it doesn't exist
def create_dir(dir_name):
  if not os.path.exists(f'{dir_name}'):
    os.mkdir(f'{dir_name}')

def delete_dir(dir_name):
  """Recursively remove the directory ``dir_name`` (e.g. old checkpoints).

  Does nothing when ``dir_name`` does not exist or is not a directory.
  """
  target = f'{dir_name}'
  if not os.path.isdir(target):
    return
  shutil.rmtree(target)

# ---------------------------------------------

### Transforms

In [None]:
# TODO: For experiments later on!
# train_transform = A.Compose(
#     [
#         A.SmallestMaxSize(max_size=160),
#         A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
#         A.RandomCrop(height=128, width=128),
#         A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
#         A.RandomBrightnessContrast(p=0.5),
#         A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
#         ToTensorV2(),
#     ]
# )

# Target spatial size (H, W) that every image is resized to.
IMAGE_SIZE = (224, 224)

# ToTensor runs first, so Resize operates on tensors. antialias=True is set
# explicitly because the tensor-path default differs across torchvision
# versions (some versions also warn when it is left unset).
train_transform = T.Compose([T.ToTensor(), T.Resize(IMAGE_SIZE, antialias=True)])
test_transform = T.Compose([T.ToTensor(), T.Resize(IMAGE_SIZE, antialias=True)])

### Dataset & DataLoader

In [None]:
# Paths mirror CONFIG['train_path'] / CONFIG['test_path'] (CONFIG is defined in a later cell).
# ImageFolder convention: one sub-folder per class under each root — verify the dataset layout.
train_data = datasets.ImageFolder('dataset_FER/train/', transform=train_transform)
test_data = datasets.ImageFolder('dataset_FER/test/', transform=test_transform)

### Config

In [None]:
# Run configuration: all tunable values live here; the dict is also handed to
# the W&B logger as the run config.
CONFIG = {
    'mode': 'train',
    'train_path': 'dataset_FER/train/',
    'test_path': 'dataset_FER/test/',
    'epochs': 100,
    'batch_size': 64,
    'lr': 0.001,
    'num_workers': 4,
    'device': 'cuda',
    'device_ids': [0,1],
    'load_model': False,
    'load_path': 'models/epoch_100.pth',
    'checkpoint_path': 'ckpts/',
    'save_every': 10,
    'seed': 42,
}

# Seed the global RNGs (and dataloader workers) before any model or loader is
# built, so shuffling and weight init are reproducible across runs.
seed_everything(CONFIG['seed'], workers=True)

### Lightning Module

In [None]:
class LIT_Scratch(pl.LightningModule):
  """Lightning wrapper that trains an arbitrary ``nn.Module`` as an image classifier.

  Batches are expected to be ``(images, labels)`` pairs, as produced by a
  DataLoader over ``datasets.ImageFolder``.

  Args:
    model: network mapping an image batch to class logits, assumed to be of
      shape (batch, num_classes) — required by ``F.cross_entropy`` below.
    lr: Adam learning rate; the default matches CONFIG['lr'] (0.001).
  """

  def __init__(self, model, lr=1e-3):
    super().__init__()
    # Keep the network as a submodule so its parameters are registered,
    # optimized and checkpointed (it was previously discarded).
    self.model = model
    self.lr = lr


  def configure_optimizers(self):
    """Optimize all parameters of the wrapped model with Adam."""
    return optim.Adam(self.parameters(), lr=self.lr)


  def forward(self, z):
    """Run the wrapped model on an input batch ``z`` (images) and return its logits."""
    return self.model(z)


  def _shared_step(self, batch, stage):
    """Compute cross-entropy loss and accuracy for one batch; log as ``{stage}_loss`` / ``{stage}_acc``."""
    imgs, labels = batch
    logits = self(imgs)
    loss = F.cross_entropy(logits, labels)
    acc = (logits.argmax(dim=1) == labels).float().mean()
    self.log(f'{stage}_loss', loss, prog_bar=True)
    self.log(f'{stage}_acc', acc, prog_bar=True)
    return loss


  def validation_step(self, batch, batch_idx):
    # Logs 'val_loss' / 'val_acc', which callbacks can monitor.
    return self._shared_step(batch, 'val')


  def training_step(self, batch, batch_idx):
    # The returned loss is what Lightning backpropagates.
    return self._shared_step(batch, 'train')

In [None]:
from pytorch_lightning import Callback
from lightning.pytorch.callbacks import DeviceStatsMonitor, TQDMProgressBar, ModelCheckpoint, EarlyStopping, LearningRateMonitor

# Checkpoint: `monitor` must name a metric the LightningModule logs from its
# validation step; this expects 'val_loss' ('val_g_loss' is a GAN-style name
# that this classifier setup does not log). save_top_k=-1 keeps every
# checkpoint; save_last=True additionally keeps the latest one.
checkpoint_callback = ModelCheckpoint(dirpath=CONFIG['checkpoint_path'],
                                      filename='{epoch}-{val_loss:.3f}',
                                      monitor='val_loss',
                                      save_top_k=-1,
                                      save_last=True,
                                      save_weights_only=True,
                                      verbose=True,
                                      mode='min')

# Exp2: Learning Rate Monitor (logs LR and momentum every optimizer step)
lr_monitor = LearningRateMonitor(logging_interval='step', log_momentum=True)

# Earlystopping (disabled); a loss is minimized, hence mode='min'.
# earlystopping = EarlyStopping(monitor='val_loss', patience=3, mode='min')

In [None]:
# Take WandbLogger from `lightning.pytorch`, the same namespace as the Trainer
# and callbacks; mixing `pytorch_lightning` and `lightning.pytorch` objects in
# one Trainer is not supported.
from lightning.pytorch.loggers import WandbLogger

# NOTE(review): the run name ends with '-' — presumably a suffix is still to be appended.
wandb_logger = WandbLogger(project='MMU-FER',
                           name='FCN-hs-',
                           config=CONFIG,
                           job_type='train_val',
                           log_model="all")   # upload every saved checkpoint as an artifact

In [None]:
# Configure the Lightning Trainer.

trainer = pl.Trainer(fast_dev_run=False,    # For debugging purposes
                     log_every_n_steps=1,   # set the logging frequency
                     accelerator='auto',    # Precedence: tpu > gpu >> cpu
                     devices="auto",        # all
                     # CONFIG stores the epoch count under 'epochs' (there is no 'NUM_EPOCHS' key).
                     max_epochs=CONFIG['epochs'],
                     callbacks=[TQDMProgressBar(refresh_rate=25),
                                checkpoint_callback,
                                lr_monitor],
                     logger=wandb_logger,    # wandb <3
                     )

In [None]:
# Train model:
# Loaders are built from the datasets and CONFIG defined above; the test split
# doubles as the validation set (there is no separate val split).
train_dataloader = DataLoader(train_data,
                              batch_size=CONFIG['batch_size'],
                              shuffle=True,
                              num_workers=CONFIG['num_workers'])
val_dataloader = DataLoader(test_data,
                            batch_size=CONFIG['batch_size'],
                            shuffle=False,
                            num_workers=CONFIG['num_workers'])

# TODO: build the scratch network (CNN / fully connected) and wrap it before
# running this cell, e.g. `lit_model = LIT_Scratch(my_cnn)`.
trainer.fit(lit_model,
            train_dataloaders=train_dataloader,
            val_dataloaders=val_dataloader)

In [None]:
# Mark the active W&B run (the one used by the WandbLogger) as finished, so
# logging for this notebook session ends explicitly.
wandb.finish()