## Transfer-Learning

We fine-tune only the last fully-connected layer of a ResNet18 pretrained on ImageNet1k to classify images from the Places365 dataset.

In [None]:
# Header: add the current directory ("./") to the module search path.
import sys
sys.path.append("./")

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.models import ResNet18_Weights
import matplotlib.pyplot as plt


# Setup

In [None]:
# Please set this to the directory that holds the dataset.
data_root = "./data"
if not os.path.exists(data_root):
    # Raise explicitly rather than `assert`: assertions are removed when Python
    # runs with -O, which would let a missing directory slip through unnoticed.
    raise FileNotFoundError("Data directory not found!")


In [None]:
# Data-loading settings and compute device selection.
num_workers = 4
pin_memory = False

# Preference order: CUDA GPU, then Apple-Silicon MPS, then plain CPU.
if torch.cuda.is_available():
    device_name = "cuda:0"  # NVIDIA GPU
else:
    # Older torch builds have no `torch.backends.mps`, hence the hasattr guard.
    mps_usable = (
        hasattr(torch.backends, "mps")
        and torch.backends.mps.is_available()
        and torch.backends.mps.is_built()
    )
    # Apple Silicon (Metal) if usable, otherwise CPU (slowest option).
    device_name = "mps" if mps_usable else "cpu"
device = torch.device(device_name)

print(f"Device set to {device}!")


In [None]:
def count_parameters(model):
    """Return the number of learnable parameters in ``model``.

    Only parameters whose ``requires_grad`` flag is set are counted; each one
    contributes its element count (``numel()``).
    """
    n_learnable = 0
    for param in model.parameters():
        if param.requires_grad:
            n_learnable += param.numel()
    return n_learnable


def freez_not_fc(model: nn.Module):
    """Freeze every parameter that does not belong to a component named ``fc``.

    Parameters whose dotted name has ``fc`` as one of its components (for
    example ``fc.weight`` or ``head.fc.bias``) are left untouched; all other
    parameters get ``requires_grad = False``. The model is modified in place
    and nothing is returned.

    Args:
        model: Module whose parameters are frozen selectively.
    """
    for name, param in model.named_parameters():
        # Compare whole path components, not substrings, so that a name that
        # merely contains the letters "fc" (e.g. "prefc.weight") is frozen too.
        if "fc" not in name.split("."):
            param.requires_grad = False

def evaluate(model, val_loader, criterion, device, pre_process):
    """Compute the mean batch loss and the accuracy of ``model`` on ``val_loader``.

    The model is switched to eval mode and moved to ``device``; it is left in
    eval mode when the function returns. Gradients are disabled throughout.

    Args:
        model: Module to evaluate.
        val_loader: Iterable of ``(inputs, targets)`` batches that supports ``len()``.
        criterion: Callable mapping ``(predictions, targets)`` to a scalar loss tensor.
        device: Device the model and every batch are moved to.
        pre_process: Callable applied to each input batch after the move to ``device``.

    Returns:
        ``(loss, accuracy)``: the per-batch losses summed and divided by the
        number of batches, and the percentage of samples whose arg-max
        prediction (along dimension 1) equals the target.
    """
    model.eval()
    model.to(device)
    loss_sum = 0
    n_correct = 0
    n_samples = 0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = pre_process(inputs.to(device))
            targets = targets.to(device)
            logits = model(inputs)
            loss_sum += criterion(logits, targets).item()
            hits = logits.argmax(1) == targets
            n_correct += hits.sum().item()
            n_samples += len(targets)
    mean_loss = loss_sum / len(val_loader)
    accuracy = 100 * n_correct / n_samples
    return mean_loss, accuracy


def train(model, train_loader, val_loader, criterion, optimizer, device, pre_process, epochs=10):
    """Train ``model`` for ``epochs`` passes and report validation metrics per epoch.

    Args:
        model: Module to train; moved to ``device`` once before the first epoch.
        train_loader: Iterable of ``(inputs, targets)`` training batches.
        val_loader: Batches passed to ``evaluate`` at the end of every epoch.
        criterion: Callable mapping ``(predictions, targets)`` to a loss tensor.
        optimizer: Optimizer whose ``zero_grad``/``step`` drive the updates.
        device: Device the model and every batch are moved to.
        pre_process: Callable applied to each input batch after the move to ``device``.
        epochs: Number of passes over ``train_loader``.

    Returns:
        None. Validation loss and accuracy are printed after each epoch.
    """
    model.to(device)
    for epoch in range(epochs):
        # Re-enter training mode every epoch, since validation runs in between.
        model.train()
        for inputs, targets in train_loader:
            inputs = pre_process(inputs.to(device))
            targets = targets.to(device)
            optimizer.zero_grad(set_to_none=True)
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            optimizer.step()
        val_loss, val_acc = evaluate(model, val_loader, criterion, device, pre_process)
        print(f"Epoch {1 + epoch:3d} | Val loss {val_loss:.4f} | Val acc {val_acc:.2f}%")

## Dataset

`Places365` dataset

Since the full dataset is very large, for this exercise we use only its validation split, which we divide randomly into two parts (90% / 10%) that serve as our training and validation sets.

[Torchvision built-in datasets](https://pytorch.org/vision/stable/datasets.html)

In [None]:
# Load the Places365 validation split; ToTensor converts each image to a tensor.
full_ds = datasets.Places365(
    root=data_root,
    split="val",
    small=True,
    # download=True,  # uncomment to fetch the data on first use
    transform=transforms.ToTensor(),
)
class_names = full_ds.classes

# Randomly split the dataset into 90% training and 10% validation.
# The generator is seeded so the split is identical on every run of this cell;
# otherwise re-running it could move images the model was already trained on
# into the validation subset and inflate the reported accuracy.
n_train = int(0.9 * len(full_ds))
n_val = len(full_ds) - n_train
train_ds, val_ds = torch.utils.data.random_split(
    full_ds,
    [n_train, n_val],
    generator=torch.Generator().manual_seed(0),
)

# Use the loader settings chosen in the setup cell (num_workers, pin_memory)
# instead of silently falling back to the DataLoader defaults.
train_dl = torch.utils.data.DataLoader(
    train_ds,
    batch_size=32,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
)
val_dl = torch.utils.data.DataLoader(
    val_ds,
    batch_size=32,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

In [None]:
# Show the first six validation images in a 2x3 grid, titled with their class.
fig, axes = plt.subplots(2, 3, figsize=(9, 6))
for idx, axis in enumerate(axes.flat):
    image, y = val_ds[idx]
    # Move the first tensor axis to the end before handing the image to imshow.
    channels_last = image.permute(1, 2, 0)
    axis.imshow(channels_last)
    # Keep only the part of the class name after the last "/".
    axis.set_title(f"Class: {class_names[y].split('/')[-1]}")
    axis.axis("off")
plt.show()

## Model

We use a ResNet18 model pre-trained on the ImageNet1k dataset.

**NOTE:** Most pre-trained models come with a pre-processing step that must be applied to the inputs whenever the pre-trained model is used.

[Models and Pre-Trained weights](https://pytorch.org/vision/stable/models.html)

In [None]:
model_name = "resnet18"
# Bind the weights once so the model and its pre-processing are guaranteed to
# come from the same weights entry.
weights = ResNet18_Weights.IMAGENET1K_V1
# Load the architecture named by `model_name` from the pytorch/vision hub repo.
resnet_model = torch.hub.load("pytorch/vision", model_name, weights=weights)
resnet_model.eval()
# Pre-processing pipeline that ships with these weights.
pre_process = weights.transforms()
print(pre_process)
print(resnet_model)

## Fine-Tuning through the FC layer

Often, it is enough to train only the fully-connected layers. Here we replace the last layer (called `fc` in ResNet18) with a randomly initialized linear layer with 365 outputs, matching the number of classes in our target dataset.

In [None]:
# Replace the last layer with a freshly initialised linear layer that has one
# output per target class. The input width is read from the existing head and
# the output width from the dataset's class list, rather than hard-coding
# 512 and 365.
resnet_model.fc = nn.Linear(resnet_model.fc.in_features, len(class_names))

In [None]:
# Baseline: validation loss and accuracy before any training has happened.
loss, acc = evaluate(
    model=resnet_model,
    val_loader=val_dl,
    criterion=nn.CrossEntropyLoss(),
    device=device,
    pre_process=pre_process,
)
print(f"Before training | Val loss {loss:.4f} | Val acc {acc:.2f}%")

In [None]:
# Train the model: freeze everything except the `fc` head first.
freez_not_fc(resnet_model)
# Hand the optimizer only the parameters that are still trainable, so it is
# explicit which weights this run updates.
trainable_params = [p for p in resnet_model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=1e-3)
criterion = nn.CrossEntropyLoss()
train(resnet_model, train_dl, val_dl, criterion, optimizer, device, pre_process, epochs=1)