# Week 11: Transfer learning

## Don't forget to install skorch and other packages that aren't included by default

In [1]:
!pip install skorch

## See what torchvision has to offer

In [2]:
import torch
import torch.nn
import torch.optim
import numpy as np
import torchvision
import torchvision.models
import torchvision.transforms
import torchvision.datasets
import skorch
import skorch.helper

torch.manual_seed(2)

<torch._C.Generator at 0x224a2f87470>

In [3]:
torchvision.models.list_models()

['alexnet',
 'convnext_base',
 'convnext_large',
 'convnext_small',
 'convnext_tiny',
 'deeplabv3_mobilenet_v3_large',
 'deeplabv3_resnet101',
 'deeplabv3_resnet50',
 'densenet121',
 'densenet161',
 'densenet169',
 'densenet201',
 'efficientnet_b0',
 'efficientnet_b1',
 'efficientnet_b2',
 'efficientnet_b3',
 'efficientnet_b4',
 'efficientnet_b5',
 'efficientnet_b6',
 'efficientnet_b7',
 'efficientnet_v2_l',
 'efficientnet_v2_m',
 'efficientnet_v2_s',
 'fasterrcnn_mobilenet_v3_large_320_fpn',
 'fasterrcnn_mobilenet_v3_large_fpn',
 'fasterrcnn_resnet50_fpn',
 'fasterrcnn_resnet50_fpn_v2',
 'fcn_resnet101',
 'fcn_resnet50',
 'fcos_resnet50_fpn',
 'googlenet',
 'inception_v3',
 'keypointrcnn_resnet50_fpn',
 'lraspp_mobilenet_v3_large',
 'maskrcnn_resnet50_fpn',
 'maskrcnn_resnet50_fpn_v2',
 'maxvit_t',
 'mc3_18',
 'mnasnet0_5',
 'mnasnet0_75',
 'mnasnet1_0',
 'mnasnet1_3',
 'mobilenet_v2',
 'mobilenet_v3_large',
 'mobilenet_v3_small',
 'mvit_v1_b',
 'mvit_v2_s',
 'quantized_googlenet',
 '

In [4]:
import subprocess

# Installation on Google Colab
import os
import google.colab
subprocess.run(['mkdir', '-p', 'datasets'])
subprocess.run(['wget', '-nc', '--no-check-certificate', 'https://download.pytorch.org/tutorial/hymenoptera_data.zip', '-P', 'datasets'])
subprocess.run(['unzip', '-u', 'datasets/hymenoptera_data.zip', '-d' 'datasets'])

ModuleNotFoundError: No module named 'google'

### Data augmentation

In [10]:
data_dir = 'datasets/hymenoptera_data'
train_transforms = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop(224),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])
val_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

train_ds = torchvision.datasets.ImageFolder(
    os.path.join(data_dir, 'train'), train_transforms)
val_ds = torchvision.datasets.ImageFolder(
    os.path.join(data_dir, 'val'), val_transforms)

In [11]:
train_transforms

Compose(
    RandomResizedCrop(size=(224, 224), scale=(0.08, 1.0), ratio=(0.75, 1.3333), interpolation=bilinear, antialias=warn)
    RandomHorizontalFlip(p=0.5)
    ToTensor()
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)

### Module with pretrained model definition

In [39]:
class PretrainedModel(torch.nn.Module):
    def __init__(self, output_features):
        super().__init__()
        model = torchvision.models.resnet18(
            weights=torchvision.models.ResNet18_Weights.DEFAULT)
        num_ftrs = model.fc.in_features
        model.fc = torch.nn.Linear(num_ftrs, output_features)  # we've swapped it out
        self.model = model

    def forward(self, x):
        return self.model(x)

### Set up skorch NeuralNetClassifier and callbacks

In [40]:
from skorch.callbacks import Checkpoint

checkpoint = Checkpoint(
    f_params='best_model.pt', monitor='valid_acc_best')

In [41]:
from skorch.callbacks import Freezer

freezer = Freezer(lambda x: not x.startswith('model.fc'))

In [42]:
net = skorch.NeuralNetClassifier(
    PretrainedModel,
    criterion=torch.nn.CrossEntropyLoss,
    lr=0.2,
    batch_size=4,
    max_epochs=15,
    module__output_features=2,
    optimizer=torch.optim.Adam,
    iterator_train__shuffle=True,
    iterator_train__num_workers=2,
    iterator_valid__num_workers=2,
    train_split=skorch.helper.predefined_split(val_ds),
    callbacks=[checkpoint, freezer],
    device='cuda' # comment to train on cpu
)

### Train the model's new FC layer

In [43]:
net.fit(train_ds, y=None)

  epoch    train_loss    valid_acc    valid_loss    cp     dur
-------  ------------  -----------  ------------  ----  ------
      1       [36m10.1440[0m       [32m0.9477[0m        [35m2.0972[0m     +  2.5143
      2       30.5673       0.9020        6.0247        2.5552
      3        [36m8.4168[0m       0.9020        6.5705        3.7246
      4       26.1212       0.6340       44.9682        2.9335
      5       19.0596       0.8301       16.9436        2.4438
      6       25.4040       [32m0.9608[0m        4.7514     +  2.4533
      7       20.3301       0.8889        9.9999        2.5219
      8       27.3555       0.9346        6.8567        4.8134
      9        9.6639       0.9477        6.5245        2.8238
     10       12.0767       0.8039       25.2809        2.4685
     11       19.9669       0.9412        6.1619        2.4796
     12       11.4889       0.9542        6.9993        2.5075
     13       12.0116       0.8954        9.9215        4.2529
     14   

<class 'skorch.classifier.NeuralNetClassifier'>[initialized](
  module_=PretrainedModel(
    (model): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1)