# Lab 02: Training a Custom Model


**Objective of this lab**: training a small custom model on the Tiny-ImageNet dataset.

## Dataset preparation

In [None]:
# Download the Tiny-ImageNet archive.
# -nc: skip the download when the archive is already present, so re-running the cell is cheap.
!wget -nc http://cs231n.stanford.edu/tiny-imagenet-200.zip
# Extract into ./tiny-imagenet.
# -n: never overwrite (or prompt about) files that were already extracted.
# -q: suppress the per-file listing, which would otherwise flood the cell output.
!unzip -n -q tiny-imagenet-200.zip -d tiny-imagenet

We need to adjust the format of the val split of the dataset to be used with ImageFolder.

Reformatting the validation set: this Python code reorganizes the __val__ folder. It reads _val_annotations.txt_, which maps each image of the validation set to its class. It then creates one subfolder for each of the 200 classes and moves the images into them, producing the root/{cls}/{fn} structure required by _ImageFolder_.

In [None]:
import os
import shutil

# Reorganize the validation split from the flat layout
#     val/images/{fn}  +  val/val_annotations.txt
# into the one-folder-per-class layout
#     val/{cls}/{fn}
val_dir = 'tiny-imagenet/tiny-imagenet-200/val'
val_images_dir = os.path.join(val_dir, 'images')

# Only run while the flat images/ folder still exists. This makes the cell
# idempotent: a second execution is a no-op instead of failing on a folder
# that the first execution already removed.
if os.path.isdir(val_images_dir):
    with open(os.path.join(val_dir, 'val_annotations.txt')) as f:
        for line in f:
            line = line.rstrip('\n')
            if not line:
                continue  # tolerate blank lines in the annotation file

            # Tab-separated record: the first field is the file name, the
            # second the class id; any remaining fields are not needed here.
            fn, cls, *_ = line.split('\t')

            class_dir = os.path.join(val_dir, cls)
            os.makedirs(class_dir, exist_ok=True)

            dst = os.path.join(class_dir, fn)
            if os.path.exists(dst):
                continue  # already placed by an earlier, interrupted run

            # Move rather than copy: the source folder is deleted afterwards anyway.
            shutil.move(os.path.join(val_images_dir, fn), dst)

    # Remove the now-redundant flat folder.
    shutil.rmtree(val_images_dir)

In [None]:
from torchvision.datasets import ImageFolder
import torchvision.transforms as T

# Preprocessing applied to every image of both splits.
preprocessing_steps = [
    # Resize to fit the input dimensions of the network
    T.Resize((224, 224)),
    # PIL image -> float tensor
    T.ToTensor(),
    # Per-channel standardisation (these values are the commonly used ImageNet statistics)
    T.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = T.Compose(preprocessing_steps)

# ImageFolder expects one sub-folder per class: root/{classX}/x001.jpg
tiny_imagenet_dataset_train = ImageFolder(
    root='tiny-imagenet/tiny-imagenet-200/train',
    transform=transform,
)
tiny_imagenet_dataset_val = ImageFolder(
    root='tiny-imagenet/tiny-imagenet-200/val',
    transform=transform,
)

In [None]:
# Report how many samples each split contains.
for split_name, split_dataset in (
    ("train", tiny_imagenet_dataset_train),
    ("val", tiny_imagenet_dataset_val),
):
    print(f"Length of {split_name} dataset: {len(split_dataset)}")

# Optional sanity check (left disabled): count the samples per class in the
# validation split. NOTE: iterating the dataset presumably loads and
# transforms every image, so this can take a while.
# from collections import Counter

# class_counts = Counter([target for _, target in tiny_imagenet_dataset_val])
# for class_label, count in class_counts.items():
#     print(f"Class {class_label}: {count} entries")


In [None]:
import torch
# Loader settings shared by both splits, kept in one place so they stay in sync.
BATCH_SIZE = 32
NUM_WORKERS = 8  # worker processes used to load and transform images in the background

# Training batches are reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    tiny_imagenet_dataset_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

# Validation keeps a fixed order. It uses the same number of workers as
# training; without num_workers every validation image would be decoded and
# resized in the main process, making each validation pass needlessly slow.
val_loader = torch.utils.data.DataLoader(
    tiny_imagenet_dataset_val,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

## Custom model definition

In [None]:
import torch
from torch import nn

# Define the custom neural network
class CustomNet(nn.Module):
    """Small CNN classifier: two Conv -> ReLU -> MaxPool blocks and a two-layer MLP head.

    Args:
        num_classes: Number of output logits. Defaults to 200, the number of
            classes in TinyImageNet.
        input_size: Height and width, in pixels, of the (square) input images.
            Defaults to 224. It determines the input width of the first fully
            connected layer.

    Raises:
        ValueError: If ``num_classes`` is smaller than 1, or if ``input_size``
            is smaller than 4 (the two 2x2 max-pools would leave no pixels).
    """

    def __init__(self, num_classes=200, input_size=224):
        super().__init__()

        if num_classes < 1:
            raise ValueError(f"num_classes must be >= 1, got {num_classes}")
        if input_size < 4:
            raise ValueError(f"input_size must be >= 4, got {input_size}")

        # Convolutional sequence (Conv -> ReLU -> MaxPool).
        # The 3x3 convolutions use padding=1, so only the pooling layers
        # change the spatial size (each one halves it).
        self.features = nn.Sequential(
            # Block 1: B x 3 x S x S -> B x 64 x S/2 x S/2   (S = 224: 224 -> 112)
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # Block 2: B x 64 x S/2 x S/2 -> B x 128 x S/4 x S/4   (S = 224: 112 -> 56)
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # Spatial side after the two max-pools: S -> S // 2 -> S // 4.
        # Flattened size PER SAMPLE (independent of the batch size):
        # 128 channels * side * side, i.e. 128 * 56 * 56 = 401408 for S = 224.
        pooled_side = input_size // 4
        flattened_features = 128 * pooled_side * pooled_side

        self.classifier = nn.Sequential(
            nn.Flatten(),                        # B x 128 x side x side -> B x (128*side*side)
            nn.Linear(flattened_features, 512),  # B x (128*side*side) -> B x 512
            # Non-linearity between the two fully connected layers: without it
            # the two Linear layers compose into a single linear map.
            nn.ReLU(),
            nn.Linear(512, num_classes),         # B x 512 -> B x num_classes
        )

    def forward(self, x):
        """Return the raw class scores (logits), one row per input image."""
        x = self.features(x)
        logits = self.classifier(x)
        return logits

In [None]:
def train(epoch, model, train_loader, criterion, optimizer):
    """Train ``model`` for one epoch and print the epoch's loss and accuracy.

    Args:
        epoch: Epoch number, used only in the printed summary.
        model: Network to optimise; it is switched to training mode.
        train_loader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss function called as ``criterion(outputs, targets)``.
        optimizer: Optimizer that updates the parameters of ``model``.

    Returns:
        A ``(train_loss, train_accuracy)`` tuple: the mean of the per-batch
        losses and the accuracy in percent.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    model.train()

    # Run on whatever device the model lives on instead of hard-coding CUDA.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0  # counted explicitly so loaders without __len__ also work

    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        # Forward pass: reset the gradients, then compute outputs and loss
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Backward pass: compute the new gradients and update the weights
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)  # index of the highest logit = predicted class
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError('train_loader did not yield any batch')

    train_loss = running_loss / num_batches
    train_accuracy = 100. * correct / total
    print(f'Train Epoch: {epoch} Loss: {train_loss:.6f} Acc: {train_accuracy:.2f}%')
    return train_loss, train_accuracy

In [None]:
# Validation loop
def validate(model, val_loader, criterion):
    """Evaluate ``model`` on ``val_loader`` and print the loss and accuracy.

    Args:
        model: Network to evaluate; it is switched to evaluation mode.
        val_loader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss function called as ``criterion(outputs, targets)``.

    Returns:
        The validation accuracy in percent.

    Raises:
        ValueError: If ``val_loader`` yields no batches.
    """
    model.eval()

    # Run on whatever device the model lives on instead of hard-coding CUDA.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    val_loss = 0
    correct, total = 0, 0
    num_batches = 0  # counted explicitly so loaders without __len__ also work

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            # Compute the outputs and the loss
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            val_loss += loss.item()
            _, predicted = outputs.max(1)  # index of the highest logit = predicted class
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError('val_loader did not yield any batch')

    val_loss = val_loss / num_batches
    val_accuracy = 100. * correct / total

    print(f'Validation Loss: {val_loss:.6f} Acc: {val_accuracy:.2f}%')
    return val_accuracy

## Putting everything together

In [None]:
# Model, loss function and optimizer (SGD with momentum)
model = CustomNet().cuda()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

num_epochs = 10
best_acc = 0  # highest validation accuracy seen so far

# Alternate one training epoch with one validation pass, num_epochs times
for epoch in range(1, num_epochs + 1):
    train(epoch, model, train_loader, criterion, optimizer)
    val_accuracy = validate(model, val_loader, criterion)

    # Keep the running maximum of the validation accuracy
    if val_accuracy > best_acc:
        best_acc = val_accuracy


print(f'Best validation accuracy: {best_acc:.2f}%')
