<a href="https://colab.research.google.com/github/DragosTana/cv_homework/blob/main/Using_pre_trained_CNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Using pre-trained CNN

In this lab, we will see:

- Performance of a frozen pre-trained backbone, training only a new classifier on top (linear probing)
- Use a pre-trained CNN as a backbone
- Fine-tune the pre-trained CNN

In [None]:
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
import torchvision.models as models
import torchvision.transforms as transforms

In [None]:
batch_size = 64
lr = 0.01
epochs = 10
# Use the GPU when one is available; fall back to the CPU so the notebook
# still runs (slowly) on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Training-time augmentation: random 32x32 crops of the 4-pixel-padded image,
# random horizontal flips, then per-channel normalization.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Evaluation transform: no augmentation, same normalization as training.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

dataset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                       download=True, transform=transform_train)
# A second view of the same training images that uses the evaluation
# transform, so the validation split is not randomly cropped/flipped.
val_source = torchvision.datasets.CIFAR10(root='./data', train=True,
                                          download=False, transform=transform_test)

# Split the training images into train/validation (the validation part is
# used for model selection / early stopping). A single shared permutation
# keeps the two subsets disjoint.
n_train = 40000
perm = torch.randperm(len(dataset)).tolist()
trainset = torch.utils.data.Subset(dataset, perm[:n_train])
valset = torch.utils.data.Subset(val_source, perm[n_train:])

# drop_last=True discards the final partial training batch each epoch.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2,
                                          drop_last=True)
valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size,
                                        shuffle=False, num_workers=2,
                                        drop_last=False)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2,
                                         drop_last=False)


## Load a pre-defined network with pretrained weights



In [None]:
# ResNet-18 initialized with ImageNet-pretrained weights.
net = models.resnet18(pretrained=True)
# Replace the 1000-way ImageNet classifier with a fresh 10-way head for
# CIFAR-10, reading the backbone's feature width instead of hard-coding it.
net.fc = nn.Linear(net.fc.in_features, 10)
net.to(device)

In [None]:
def count_trainable_parameters(model):
    """Return how many scalar weights of ``model`` currently have requires_grad set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
count_trainable_parameters(model=net)  # every parameter is trainable at this point

In [None]:
# Freeze every parameter of the network, then re-enable gradients for the
# classifier head (net.fc) only, so training updates just the new layer.
net.requires_grad_(False)
net.fc.requires_grad_(True)
count_trainable_parameters(net)

In [None]:
# define train and test function
def train(model, device, train_loader, optimizer, epoch, loss_fn=None, log_interval=500):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: network to train; it is switched to train mode.
        device: device each batch is moved to (should match the model's).
        train_loader: DataLoader yielding ``(data, target)`` batches.
        optimizer: optimizer stepping ``model``'s parameters.
        epoch: epoch number, used only in the progress messages.
        loss_fn: callable ``loss_fn(output, target)`` returning a scalar
            loss. Defaults to the notebook-level ``criterion`` so existing
            calls keep working.
        log_interval: print progress every ``log_interval`` batches.

    Returns:
        Mean of the per-batch losses over the epoch, or NaN if the loader
        produced no batches.
    """
    if loss_fn is None:
        loss_fn = criterion  # notebook-level loss, defined in a later cell
    model.train()
    losses = []
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        loss_value = loss.item()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss_value))
        losses.append(loss_value)
    # np.mean([]) would emit a RuntimeWarning; make the empty case explicit.
    return float(np.mean(losses)) if losses else float('nan')

def test(model, device, test_loader, val=False, loss_fn=None):
    """Evaluate ``model`` on ``test_loader``, print and return loss and accuracy.

    Args:
        model: network to evaluate; it is switched to eval mode.
        device: device each batch is moved to.
        test_loader: DataLoader yielding ``(data, target)`` batches; the
            length of its ``dataset`` is used as the sample count.
        val: only changes the printed label ("Val" vs "Test").
        loss_fn: callable ``loss_fn(output, target)`` returning a scalar
            loss. Defaults to the notebook-level ``criterion`` so existing
            calls keep working.

    Returns:
        ``(test_loss, test_acc)``: the per-sample average loss and the
        fraction of correctly classified samples (both NaN if the dataset
        is empty).
    """
    if loss_fn is None:
        loss_fn = criterion  # notebook-level loss, defined in a later cell
    model.eval()
    n_samples = len(test_loader.dataset)
    loss_sum = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Assumes loss_fn averages over the batch (reduction='mean', the
            # nn.CrossEntropyLoss default). Re-weighting by batch size makes the
            # final division an exact per-sample mean, even for a short last batch.
            loss_sum += loss_fn(output, target).item() * target.size(0)
            pred = output.argmax(dim=1)  # predicted class = index of the largest logit
            correct += pred.eq(target).sum().item()

    if n_samples:
        test_loss = loss_sum / n_samples
        test_acc = correct / n_samples
    else:
        test_loss = test_acc = float('nan')

    mode = "Val" if val else "Test"
    print('\n{} set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        mode, test_loss, correct, n_samples, 100. * test_acc))
    return test_loss, test_acc


# Stop pytest from collecting this evaluation helper as a test case (its
# name starts with "test") if this code is ever imported into a test module.
test.__test__ = False

In [None]:
# Classification loss on the network's logits.
criterion = nn.CrossEntropyLoss()
# SGD with momentum and weight decay over all parameters; frozen ones never
# receive a gradient, so the optimizer leaves them untouched.
optimizer = optim.SGD(
    net.parameters(),
    lr=lr,
    momentum=0.9,
    weight_decay=1e-04,
)

In [None]:
# Main loop: train for `epochs` epochs and evaluate on the validation split
# after each one, recording the curves.
train_losses = []
val_losses = []
val_accuracies = []
model_state_dict = None  # placeholder for Exercise 3 (model selection)

for epoch in range(1, epochs + 1):
    train_loss = train(net, device, trainloader, optimizer, epoch)
    train_losses.append(train_loss)
    # val=True so the report is labelled "Val" rather than "Test".
    val_loss, val_acc = test(net, device, valloader, val=True)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)

In [None]:
test_loss, test_acc = test(net, device, testloader, val=False)  # held-out CIFAR-10 test set

## Add an additional layer to the pre-trained model


In [None]:
# New hidden layer: 512 ResNet-18 features -> 128 units.
fc1 = nn.Linear(512, 128)

# Replace the classifier with: hidden layer -> ReLU -> fresh 128 -> 10 output layer.
net.fc = nn.Sequential(fc1, nn.ReLU(), nn.Linear(128, 10))
net.to(device)

## Fine-tuning some part of the CNN (not only the classifier)

In [None]:
# Unfreeze the last residual stage so it is fine-tuned together with the head.
for param in net.layer4.parameters():
    param.requires_grad = True

# Make every parameter of the classifier head trainable. requires_grad_()
# updates the parameters themselves; assigning `net.fc.requires_grad = True`
# would merely set an unused attribute on the module.
net.fc.requires_grad_(True)

# Per-group learning rates: small for the pretrained layer4, larger for the head.
layer4_params = {'params': net.layer4.parameters(), 'lr': 0.001}
fc_params = {'params': net.fc.parameters(), 'lr': 0.1}

# SGD with momentum; each parameter group uses its own learning rate.
optimizer = torch.optim.SGD([layer4_params, fc_params], momentum=0.9, weight_decay=1e-04)

## Exercise 1

How many layers is it better to fine-tune?

Is it better to update all the weights of the model?

## Exercise 2

Try changing the fine-tuning hyper-parameters (e.g. the learning rate of the CNN layers and that of the fc layers) and/or the network architecture.

## Exercise 3

Try implementing a model-selection strategy (also known as early stopping) based on the validation accuracy on CIFAR-10: keep the weights from the epoch with the best validation accuracy.

Consider using the following two commands to save and to load, respectively, the state of all the model's parameters at a given moment.

In [None]:
import copy

# Save a snapshot of all the model's parameters and buffers. state_dict()
# returns references to the live tensors, so deep-copy it; otherwise further
# training would silently overwrite the "saved" weights.
model_state_dict = copy.deepcopy(net.state_dict())

# Load the saved weights back into the model.
net.load_state_dict(model_state_dict)
