# Transfer Learning

Transfer learning takes a pre-trained model and adapts it to a new, different dataset. In this notebook, we demonstrate how to use transfer learning to train an image classifier on a dataset that differs from the one the pre-trained model was originally trained on.


Transfer learning is especially useful when we only have a small dataset to train on and the pre-trained model was trained on a much larger one. A model trained from scratch on a small dataset tends to memorize the training data quickly (overfit) and then performs poorly on new data.

In the previous notebook, we worked with the VGG16 model on animal images. We will use the same model here and train it on the new images.

In [6]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.v2 as transforms
import torchvision.io as tv_io
from torchvision.models import vgg16
from torchvision.models import VGG16_Weights
import sys
import pathlib
import glob
import json
from PIL import Image


sys.path.append("../")
import Utils

device = Utils.get_device()

DATASET_LOCATION = ""
if Utils.in_lab():
    DATASET_LOCATION = "/transfer/pokemon/"
else:
    DATASET_LOCATION = "./pokemon/"

print(f"Dataset location: {DATASET_LOCATION}")
pathlib.Path(DATASET_LOCATION).mkdir(parents=True, exist_ok=True)


Dataset location: ./pokemon/
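
The `Utils` module is imported from the parent directory and its source is not shown here. As a rough sketch of what a helper like `Utils.get_device()` might do (an assumption, not the actual implementation), the hypothetical `pick_device` below prefers a CUDA GPU when PyTorch can see one and falls back to the CPU otherwise:

```python
import torch


def pick_device() -> torch.device:
    """Hypothetical stand-in for Utils.get_device(): GPU if available, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


example_device = pick_device()

# Tensors and models are moved to the chosen device with .to(...)
example_tensor = torch.ones(2, 2).to(example_device)
print(f"Using device: {example_device}")
```

Models and their input batches must live on the same device, which is why the model is moved with `.to(device)` right after it is loaded.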


We will now download the pre-trained model and set it up as before.

In [2]:
# load the VGG16 network *pre-trained* on the ImageNet dataset
weights = VGG16_Weights.DEFAULT
vgg_model = vgg16(weights=weights)
vgg_model.to(device)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

## Freezing the Layers

When we use transfer learning, we typically only want to train the final layer or a few layers of the model. We want to keep the weights of the other layers the same as they were during the initial training. This is known as "freezing" the layers.

If we left all the layers trainable, we would risk overwriting the pre-trained weights during training. Those weights are valuable for image classification tasks because they already encode many useful image features. We can unfreeze the layers later, in a process called "fine-tuning", to further improve the model's accuracy if required.

vgg_model.requires_grad_(False)
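
To make the effect of freezing concrete, here is a minimal sketch that uses a tiny stand-in network rather than VGG16 (the layer sizes, `NUM_CLASSES`, and `count_parameters` are illustrative assumptions, not part of this notebook). It freezes the base, attaches a new trainable head, and checks that only the head receives gradients:

```python
import torch
import torch.nn as nn

# Tiny stand-in for a pre-trained network (hypothetical, NOT VGG16)
base = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 16),
)

# Freeze every parameter in the base, as done for vgg_model above
base.requires_grad_(False)

# New layers are created after freezing, so they are trainable by default
NUM_CLASSES = 2  # assumption: e.g. "pokemon" vs. "animal"
head = nn.Linear(16, NUM_CLASSES)
model = nn.Sequential(base, nn.ReLU(), head)


def count_parameters(module: nn.Module):
    """Return (trainable, frozen) parameter counts for a module."""
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in module.parameters() if not p.requires_grad)
    return trainable, frozen


trainable, frozen = count_parameters(model)
print(f"trainable: {trainable}, frozen: {frozen}")

# Give the optimizer only the parameters that are still trainable
optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad])

# One backward pass: frozen parameters accumulate no gradient
batch = torch.randn(4, 3, 32, 32)
loss = model(batch).sum()
loss.backward()
print("base has gradients:", any(p.grad is not None for p in base.parameters()))
print("head has gradients:", all(p.grad is not None for p in head.parameters()))
```

Fine-tuning later would amount to calling `requires_grad_(True)` on the base and continuing training, typically with a much smaller learning rate.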

## Dataset Download

In this demo, we are going to add images of different types of Pokémon and train the model to recognize them. We also need examples of something that is not a Pokémon, so we need to download other images too. We will then use the model to determine whether an image shows a Pokémon or an animal.

There are a number of datasets on Kaggle we can use. In this case we are going to use the following datasets:

https://www.kaggle.com/api/v1/datasets/download/vishalsubbiah/pokemon-images-and-types



In [13]:
url = "https://www.kaggle.com/api/v1/datasets/download/vishalsubbiah/pokemon-images-and-types"

# Download and extract the archive only if it is not already present
destination = DATASET_LOCATION + "pokemon.zip"
if not pathlib.Path(destination).exists():
    Utils.download(url, destination)
    Utils.unzip_file(destination, DATASET_LOCATION)


./pokemon/pokemon.zip: 100%|██████████| 3.68M/3.68M [00:01<00:00, 2.48MiB/s]
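
The source of `Utils.download` and `Utils.unzip_file` is not shown in this notebook. For readers without that module, the following is a minimal standard-library sketch of the same download-once-then-extract pattern. The function names and the `fetch_dataset` wrapper are hypothetical stand-ins, and this version does not draw the progress bar seen in the output above:

```python
import pathlib
import urllib.request
import zipfile


def download(url, destination):
    """Hypothetical stand-in for Utils.download: stream a URL to a local file."""
    with urllib.request.urlopen(url) as response, open(destination, "wb") as out:
        # Read in 1 MiB chunks so large archives are never held fully in memory
        while chunk := response.read(1 << 20):
            out.write(chunk)


def unzip_file(archive, target_dir):
    """Hypothetical stand-in for Utils.unzip_file: extract and list the members."""
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(target_dir)
        return zf.namelist()


def fetch_dataset(url, target_dir, archive_name="dataset.zip"):
    """Download and extract an archive unless it is already on disk."""
    target = pathlib.Path(target_dir)
    target.mkdir(parents=True, exist_ok=True)
    archive = target / archive_name
    if not archive.exists():
        download(url, str(archive))
        unzip_file(str(archive), str(target))
    return archive


# Example usage (not executed here because it needs network access):
# fetch_dataset(url, DATASET_LOCATION, "pokemon.zip")
```

Checking for the archive before downloading keeps the cell safe to re-run: a second execution finds the file and skips both the download and the extraction.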
