<a href="https://colab.research.google.com/github/Dimildizio/DS_course/blob/main/Neural_networks/Transfer_learning/imagenette_transfer_learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Transfer learning

## Downloading and importing libs

In [1]:
from IPython import display

In [22]:
!pip install -U albumentations
!pip install -q --upgrade wandb
!pip install timm

display.clear_output()

In [5]:
import os
import wandb
import shutil
from pathlib import Path

import numpy as np
import random

import cv2
from tqdm import tqdm
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
from torch.optim.lr_scheduler import StepLR

try:
    from torchinfo import summary
except:
    print("[INFO] Couldn't find torchinfo. Installing...")
    !pip install -q torchinfo
    from torchinfo import summary

import timm

import albumentations as A
import albumentations.pytorch as AP

from albumentations import (
    HorizontalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90, Resize, RandomCrop,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, RandomBrightnessContrast, IAAPiecewiseAffine,
    IAASharpen, IAAEmboss, Flip, OneOf, Compose, Rotate, RandomScale, RandomGridShuffle,
    RandomContrast, RandomGamma, RandomBrightness, CenterCrop, VerticalFlip, ColorJitter,
    ChannelShuffle, InvertImg, RGBShift, ElasticTransform, Equalize, RandomResizedCrop, ChannelDropout
)

import matplotlib.pyplot as plt
%matplotlib inline

[INFO] Couldn't find torchinfo. Installing...


# Setting things up

##Creating config class to utilize global variables

In [18]:
class CFG:

    num_workers=2
    model_name = 'inception_v3'
    size = 224
    scheduler = 'StepLR'
    epochs = 15
    step_size = 20 # StepLR
    gamma = 0.1 # StepLR
    lr = 3e-4
    min_lr = 1e-6  #
    batch_size = 8
    seed = 42

    api = ""
    project = "imagenette"
    entity = "dimildizio"
    wandb = False


    @classmethod
    @ property
    def _get_config(cls):
      return {key:value for key, value in CFG.__dict__.items() if (
                key[:1]!= '_' and key not in ('api', 'project', 'entity', 'wandb'))}


In [7]:
 CFG._get_config

{'num_workers': 2,
 'model_name': 'inception_v3',
 'size': 224,
 'scheduler': 'StepLR',
 'epochs': 15,
 'step_size': 20,
 'gamma': 0.1,
 'lr': 0.0003,
 'min_lr': 1e-06,
 'batch_size': 8,
 'seed': 42}

## Prepare data transformations - augmentation, normalization

### Modify torch.utils.data.Dataset class for augmented data

In [8]:
class MakeDataset(torch.utils.data.Dataset):
    def __init__(self, files, transform=None):
        super().__init__()
        self.files = files
        self.labels = [path.parent.name for path in self.files]
        self.len_ = len(self.files)
        self.transform = transform
        self.classes = ['n01440764', 'n02102040', 'n02979186', 'n03000684', 'n03028079', 'n03394916',
                        'n03417042', 'n03425413', 'n03445777', 'n03888257']

    def __len__(self):
        return self.len_

    def __getitem__(self, index):
        image = cv2.imread(f'{self.files[index]}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        label = self.labels[index]
        for i in range(len(self.classes)):
          if label == self.classes[i]:
            y = i

        return image, y

### Augmentations, transformations, normalization

In [9]:
def get_transforms(*, data):

    if data == 'train':
        transforma = A.Compose([Resize(256,256), RandomCrop(CFG.size,CFG.size), Rotate(limit = 90, p=0.5),
                                    GaussNoise(p=0.5), A.Sharpen(p=0.5), ChannelShuffle(p=0.5),
                                    HorizontalFlip(p=0.5), A.Normalize(), AP.ToTensorV2()])
    elif data == 'valid':
        transforma = A.Compose([Resize(256,256), CenterCrop(CFG.size,CFG.size), A.Normalize(), AP.ToTensorV2()])
    return transforma

In [10]:
def get_loaders(train_df, val_df):
  train_loader = torch.utils.data.DataLoader(train_df, batch_size=CFG.batch_size,shuffle=True)
  valid_loader = torch.utils.data.DataLoader(val_df, batch_size = CFG.batch_size)

  print('Train and Valid datasets are loaded:\n')
  print('{:<7s}{:>10s}{:>10s}'.format('Dataset', 'Batches', 'Pictures')), print('-' * 28)
  print('{:<7s}{:>10d}{:>10d}'.format('Train', len(train_loader), len(train_df)))
  print('{:<7s}{:>10d}{:>10d}'.format('Valid', len(valid_loader), len(val_df)))
  return train_loader, valid_loader

# Downloading dataset and preparing data

In [11]:
dataset_path = 'imagenette2-160/'
if not os.path.exists(dataset_path):
    !wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz
    !tar zxvf imagenette2-160.tgz

TRAIN_DIR = Path('/content/imagenette2-160/train')
VAL_DIR = Path('/content/imagenette2-160/val')

train_files = sorted(list(TRAIN_DIR.rglob('*.JPEG')))
val_files = sorted(list(VAL_DIR.rglob('*.JPEG')))
display.clear_output()

In [12]:
train_dataset = MakeDataset(train_files, transform = get_transforms(data='train'))
val_dataset  = MakeDataset(val_files, transform = get_transforms(data='valid'))

In [13]:
train_loader, valid_loader = get_loaders(train_dataset, val_dataset)

Train and Valid datasets are loaded:

Dataset   Batches  Pictures
----------------------------
Train        1184      9469
Valid         491      3925


In [15]:
class Performer:
  trainloader = train_loader
  validloader = valid_loader
  num_epochs = CFG.epochs


  @classmethod
  def get_device(cls):
    cuda = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(cuda)
    print(device)
    return device

  @classmethod
  def trainval_hist(cls):
    return {'train':[], 'val':[]}, {'train':[], 'val':[]}

  @classmethod
  def run_wandb(cls):
    if CFG.wandb:
      os.environ['WANDB_API_KEY'] = CFG.api
      wandb.init(project = CFG.project, entity = CFG.entity,
                 name = CFG.model_name, reinit = True, config = CFG._get_config)

  @classmethod
  def log_wandb(cls, name, loss, acc):
    if CFG.wandb:
      wandb.log({name+'_loss': loss, name+'_accuracy': acc})

  @classmethod
  def seed_everything(cls):
    seed = CFG.seed
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] =str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

  @classmethod
  def switch_trainval(cls, model, phase):
    if phase == 'train':
      model.train()
      return cls.trainloader
    else:
      model.eval()
      return cls.validloader

  @classmethod
  def train_val(cls, model, optimizer, loss = nn.CrossEntropyLoss(), scheduler = False):
    #pre-setup
    cls.seed_everything()
    cls.run_wandb()
    device = cls.get_device()
    #set local variables for all epochs
    loss_hist, acc_hist = cls.trainval_hist()
    best_acc = 0.
    best_weights = model.state_dict()

    for epoch in range(cls.num_epochs):
      print(f'Epoch: {epoch+1}/{cls.num_epochs}')
      #each epoch a phase of train and a phase of val
      for phase in ['train', 'val']:
        dataloader = cls.switch_trainval(model, phase)

        model.to(device)
        running_loss, running_acc = 0., 0.

        #set X, y
        for inputs, labels in tqdm(dataloader):
          inputs = inputs.to(device)
          labels = labels.to(device)
          #predict and adjust if train
          optimizer.zero_grad()
          with torch.set_grad_enabled(phase == 'train'):
            y_pred = model(inputs)
            loss_v = loss(y_pred, labels)
            if phase == 'train':
              loss_v.backward()
              optimizer.step()
          #calc train val stats
          running_loss += loss_v.item()
          running_acc += (y_pred.argmax(dim=1) == labels).float().mean().data.cpu().numpy()
        if phase == 'train' and scheduler:
          scheduler.step()
        #calc epoch stats
        epoch_loss = running_loss / len(dataloader)
        epoch_acc = running_acc / len(dataloader)
        cls.log_wandb(phase, epoch_loss, epoch_acc)
        display.clear_output()
        print(f'Epoch:{epoch+1}, loss:{round(epoch_loss,3)}, accuracy: {round(epoch_acc,3)}')
        #select best weights
        if phase == 'val' and epoch_acc > best_acc:
          best_acc = epoch_acc
          best_weights = model.state_dict()
        #write for history
        loss_hist[phase].append(epoch_loss)
        acc_hist[phase].append(epoch_acc)
    torch.save(best_weights, '/content/'+f'{CFG.model_name}_best.pth')
    return model, loss_hist, acc_hist


In [16]:
class SimpleModel(nn.Module):
    def __init__(self, num_classes):
        super(SimpleModel, self).__init__()

        # Define your layers
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(16 * 224 * 224, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x


In [20]:
CFG.model_name = 'SimpleModel'
CFG.model_name

'SimpleModel'

In [21]:
qq = SimpleModel(10)
Performer.train_val(qq, torch.optim.Adam(qq.parameters(), CFG.lr))

Epoch:14, loss:1.938, accuracy: 0.334


(SimpleModel(
   (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
   (relu): ReLU()
   (flatten): Flatten(start_dim=1, end_dim=-1)
   (fc): Linear(in_features=802816, out_features=10, bias=True)
 ),
 {'train': [8.482652222133568,
   2.8098356569236196,
   2.3090561161170133,
   2.1948584191803193,
   2.090530858349961,
   2.0608900489436612,
   2.0403361814851695,
   2.0370661336425187,
   2.0051828849355915,
   2.0083602048658036,
   2.01857825373677,
   2.001336330346562,
   1.9791055603103864,
   1.991113272671764,
   1.9940574452079631],
  'val': [2.4046090827939954,
   1.896533402733793,
   1.8073172530428698,
   1.8655767985856704,
   1.9075227787684035,
   1.956791252565481,
   1.9468285604067102,
   1.893396975066657,
   1.9203462089637633,
   1.9391166715903583,
   1.9321798231839649,
   1.9352761017814917,
   1.9393043714725315,
   1.9539880109173218,
   1.9376340775528653]},
 {'train': [0.2341849662212504,
   0.28669763513765223,
   0.322086148668785

### Importing a model and freezing layers

In [None]:
class MyNet:
  pass

### Freezing all layers except the last one

In [None]:
model = MyNet()

In [None]:
last_layer_num = len(list(model.model.parameters()))-2
retrain_layers = list(model.model.parameters())[:last_layer_num] + list(model.parameters())[-2:]

optimizer = torch.optim.AdamW(retrain_layers, lr = CFG.lr)
scheduler = StepLR(optimizer, step_size=CFG.step_size, gamma=CFG.gamma)

In [None]:
Performer.train_val(model, optimizer, scheduler = scheduler)