# Transfer Learning

Transfer learning takes a pre-trained model and adapts it to a new, different dataset. In this notebook, we demonstrate how to use transfer learning to train an image classifier on a dataset that differs from the one the pre-trained model was originally trained on.


Transfer learning is especially useful when we have only a small dataset to train on and the pre-trained model was trained on a much larger one. A model trained from scratch on a small dataset tends to memorize the training examples (overfit) and then performs poorly on new data.

In the previous notebook, we worked with the pre-trained vgg16 model on animal images. We will use the same model here and train it on the new images.

In [1]:
from imports import *


sys.path.append("../")
import Utils

device = Utils.get_device()

DATASET_LOCATION = ""
if Utils.in_lab():
    DATASET_LOCATION = "/transfer/pokemon/"
else:
    DATASET_LOCATION = "./pokemon/"

print(f"Dataset location: {DATASET_LOCATION}")
pathlib.Path(DATASET_LOCATION).mkdir(parents=True, exist_ok=True)

Dataset location: ./pokemon/


## Dataset download

In this demo we are going to add a number of images of different types of Pokemon and train the model to recognize them. We also need images of things that are not Pokemon, so we need to download other images too. We will then use the model to determine whether an image shows a Pokemon or an animal.

There are a number of datasets on Kaggle we can use. In this case we are going to use the following dataset:

https://www.kaggle.com/api/v1/datasets/download/vishalsubbiah/pokemon-images-and-types



In [2]:
url = "https://www.kaggle.com/api/v1/datasets/download/vishalsubbiah/pokemon-images-and-types"

destination = DATASET_LOCATION + "pokemon.zip"
if not pathlib.Path(destination).exists():
    Utils.download(url, destination)
    Utils.unzip_file(destination, DATASET_LOCATION)
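
`Utils.download` and `Utils.unzip_file` are helpers from the course repository, and their source is not shown in this notebook. As a rough idea of what such helpers might do, the sketch below uses only the standard library. `download_and_extract` and its `fetch` parameter are hypothetical names for illustration, not the actual `Utils` API.

```python
import pathlib
import urllib.request
import zipfile


def download_and_extract(url, archive_path, extract_to,
                         fetch=urllib.request.urlretrieve):
    """Download a zip archive once and unpack it (hypothetical helper)."""
    archive_path = pathlib.Path(archive_path)
    extract_to = pathlib.Path(extract_to)
    extract_to.mkdir(parents=True, exist_ok=True)

    # Skip the download when the archive is already on disk.
    if archive_path.exists():
        return False

    # fetch(url, filename) is expected to save the remote file locally.
    fetch(url, str(archive_path))
    with zipfile.ZipFile(archive_path) as archive:
        archive.extractall(extract_to)
    return True
```

In this notebook the equivalent call would take `url`, the path `DATASET_LOCATION + "pokemon.zip"`, and `DATASET_LOCATION` as arguments. Making `fetch` a parameter is a design choice that lets the function be exercised without network access.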

## VGG16 Model

We are going to use the vgg16 model, which has 1000 output categories. We will replace the last layer with a new layer that has 1001 outputs so that we can add Pokemon as a new category.

We also need to add the new images to the dataset and retrain the model with a new label for the Pokemon images. Because class indices are zero-based, the existing categories use indices 0 to 999 and the new, 1001st category gets index 1000.
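
To make the label format concrete: a 1001-way output layer accepts class indices 0 through 1000, so the new category is index 1000. The minimal sketch below uses random logits as a stand-in for the model output (no pre-trained weights are involved, and `NUM_CLASSES` and `POKEMON_CLASS` are illustrative names). It shows the target format `torch.nn.CrossEntropyLoss` expects for class-index targets: one integer (`torch.long`) index per sample.

```python
import torch

NUM_CLASSES = 1001               # 1000 ImageNet classes + 1 new class
POKEMON_CLASS = NUM_CLASSES - 1  # zero-based index of the new class (1000)

# Random logits stand in for the model output on a batch of 4 images.
logits = torch.randn(4, NUM_CLASSES)

# Class-index targets must be torch.long and lie in [0, NUM_CLASSES).
labels = torch.full((4,), POKEMON_CLASS, dtype=torch.long)

loss = torch.nn.CrossEntropyLoss()(logits, labels)
print(loss.item())
```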


We will now download the pre-trained model and do the setup as before.

In [3]:
from torchvision.models import vgg16
from torchvision.models import VGG16_Weights

# load the VGG16 network *pre-trained* on the ImageNet dataset
weights = VGG16_Weights.DEFAULT
vgg_model = vgg16(weights=weights)
vgg_model.to(device)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [4]:
import pathlib

pre_trans = weights.transforms()

# IMAGE_WIDTH, IMAGE_HEIGHT = (224, 224)

# pre_trans = transforms.Compose([
#     transforms.ToDtype(torch.float32, scale=True), # Converts [0, 255] to [0, 1]
#     transforms.Resize((IMAGE_WIDTH, IMAGE_HEIGHT)),
#     transforms.Normalize(
#         mean=[0.485, 0.456, 0.406],
#         std=[0.229, 0.224, 0.225],
#     ),
#     transforms.CenterCrop(224)
# ])


class MyDataset(Dataset):
    """Loads every PNG under data_dir, preprocessed for VGG16 and labelled as the new class."""

    def __init__(self, data_dir):
        self.imgs = []
        self.labels = []
        images = list(pathlib.Path(data_dir).rglob("*.png"))
        for image in images:
            img = Image.open(image).convert("RGB")
            img_transformed = pre_trans(img)
            self.imgs.append(img_transformed.to(device))
            # The new class is index 1000 (the 1001st output); CrossEntropyLoss
            # expects integer (torch.long) class indices, not floats.
            self.labels.append(torch.tensor(1000, dtype=torch.long).to(device))
        print(f"Loaded {len(self.imgs)} images")

    def __getitem__(self, idx):
        img = self.imgs[idx]
        label = self.labels[idx]
        return img, label

    def __len__(self):
        return len(self.imgs)


data_loader = DataLoader(MyDataset(DATASET_LOCATION))

Loaded 809 images


The vgg16 model has a classifier attribute, which is a sequential module defining the fully connected layers. The last layer is the classification layer.

In [5]:
print(vgg_model.classifier)

Sequential(
  (0): Linear(in_features=25088, out_features=4096, bias=True)
  (1): ReLU(inplace=True)
  (2): Dropout(p=0.5, inplace=False)
  (3): Linear(in_features=4096, out_features=4096, bias=True)
  (4): ReLU(inplace=True)
  (5): Dropout(p=0.5, inplace=False)
  (6): Linear(in_features=4096, out_features=1000, bias=True)
)


The final layer outputs logits for 1000 classes (the ImageNet categories). To add a new category, we replace this layer with one that has 1001 outputs.

In [10]:
# Number of input features to the final layer
num_features = vgg_model.classifier[6].in_features

# Replace the final layer
vgg_model.classifier[6] = torch.nn.Linear(num_features, 1001).to(device)
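
Replacing the layer with a fresh `torch.nn.Linear` also re-initialises the weights of the original 1000 ImageNet outputs. One possible variation, which this notebook does not use, is to copy the trained rows into the wider layer so that only the new output starts from random values. The sketch below assumes a hypothetical helper named `expand_linear` and uses a small stand-in layer instead of the real 4096-to-1000 layer.

```python
import torch


def expand_linear(old_layer, extra_outputs=1):
    """Return a wider Linear layer that reuses the rows of old_layer (hypothetical helper)."""
    n_old = old_layer.out_features
    new_layer = torch.nn.Linear(
        old_layer.in_features,
        n_old + extra_outputs,
        bias=old_layer.bias is not None,
    )
    with torch.no_grad():
        # Copy the trained rows; the extra rows keep their random initialisation.
        new_layer.weight[:n_old].copy_(old_layer.weight)
        if old_layer.bias is not None:
            new_layer.bias[:n_old].copy_(old_layer.bias)
    return new_layer


# Small stand-in for the final classifier layer so the sketch runs quickly.
old = torch.nn.Linear(8, 5)
new = expand_linear(old)
print(new)
```

Applied to the original 1000-output layer, this would take the form `expand_linear(vgg_model.classifier[6]).to(device)`.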

## Freezing the Layers

When we use transfer learning, we typically only want to train the final layer or a few layers of the model. We want to keep the weights of the other layers the same as they were during the initial training. This is known as "freezing" the layers.

If we were to train all the layers from the start, we would risk destroying the pre-trained weights. These weights are very useful for image classification tasks because the model has already learned to recognize many features in images. If required, we can unfreeze the layers later and continue training, a process called "fine tuning", to further improve the model's accuracy.

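Note that calling `vgg_model.requires_grad_(False)` would freeze every parameter in the model, including the new final layer, so nothing would be left to train. Instead, we freeze only the convolutional `features` layers: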

In [11]:
for param in vgg_model.features.parameters():
    param.requires_grad = False
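
A quick way to check that freezing worked is to count the parameters that still require gradients. The sketch below does this on a tiny stand-in network with the same `features` / `classifier` layout. `TinyNet` and `count_parameters` are illustrative names, not part of the notebook.

```python
import torch


class TinyNet(torch.nn.Module):
    """Tiny stand-in with the same two-part layout as VGG16."""

    def __init__(self):
        super().__init__()
        self.features = torch.nn.Sequential(
            torch.nn.Conv2d(3, 4, kernel_size=3, padding=1),
            torch.nn.ReLU(),
        )
        self.classifier = torch.nn.Linear(4 * 8 * 8, 3)

    def forward(self, x):
        return self.classifier(torch.flatten(self.features(x), 1))


def count_parameters(model):
    """Return (trainable, frozen) parameter counts."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    return trainable, frozen


model = TinyNet()
for param in model.features.parameters():
    param.requires_grad = False  # freeze the feature extractor only

trainable, frozen = count_parameters(model)
print(f"trainable: {trainable}, frozen: {frozen}")
```

The same `count_parameters` function could be called on `vgg_model` to confirm that only the classifier parameters remain trainable.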
    

In [None]:
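# Train the classifier (the frozen feature layers are left unchanged)

loss_function = torch.nn.CrossEntropyLoss()
# Only the classifier parameters are passed to the optimizer
optimizer = torch.optim.Adam(vgg_model.classifier.parameters(), lr=0.001)

# Example training loop
num_epochs = 10
vgg_model.train()  # enable dropout in the classifier during training
for epoch in range(num_epochs):
    running_loss = 0.0
    for inputs, labels in data_loader:
        optimizer.zero_grad()
        outputs = vgg_model(inputs)
        # CrossEntropyLoss expects integer (torch.long) class indices
        loss = loss_function(outputs, labels.long())
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch + 1}/{num_epochs}, loss: {running_loss / len(data_loader):.4f}")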
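
The loop above updates only the classifier while the feature layers stay frozen. The fine-tuning step described earlier, unfreezing the layers and continuing to train, is usually done with a much smaller learning rate so that the pre-trained weights change only slightly. The sketch below illustrates one such step on a tiny stand-in model with random data. The learning rate `1e-5` is an illustrative value, not one taken from this notebook.

```python
import torch

torch.manual_seed(0)

# Tiny stand-in model; in the notebook this role is played by vgg_model.
model = torch.nn.Sequential(
    torch.nn.Linear(10, 16),  # plays the role of the frozen, pre-trained layers
    torch.nn.ReLU(),
    torch.nn.Linear(16, 3),   # plays the role of the new classification layer
)
for param in model[0].parameters():
    param.requires_grad = False  # frozen during the first training phase

# Fine tuning: unfreeze everything and use a much smaller learning rate.
model.requires_grad_(True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
loss_function = torch.nn.CrossEntropyLoss()

inputs = torch.randn(8, 10)         # random stand-in batch
labels = torch.randint(0, 3, (8,))  # integer class indices
before = model[0].weight.detach().clone()

model.train()
optimizer.zero_grad()
loss = loss_function(model(inputs), labels)
loss.backward()
optimizer.step()

# The previously frozen layer now receives (small) updates.
max_change = (model[0].weight.detach() - before).abs().max().item()
print(f"largest weight change in the unfrozen layer: {max_change:.2e}")
```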
