## Training a CNN for Image Classification on your Custom Dataset
## Training from Scratch vs Transfer Learning

This notebook combines two approaches: **From Scratch** where we build and train a custom CNN from the ground up (identical to [notebook 03](python/02-classification/03_train_custom_classifier_CNN.ipynb)), and **Transfer Learning with Pre-Trained Models** where we leverage pre-trained models (ResNet18) for faster and potentially better convergence.

Reference: [PyTorch Transfer Learning Tutorial](https://docs.pytorch.org/tutorials/beginner/transfer_learning_tutorial.html)

Before you proceed with this notebook, you need a custom dataset to train your model on. You may use one of the methods suggested in [04_collect_data.ipynb](04_collect_data.ipynb). Make sure your custom dataset is under your `./datasets` folder, with one subfolder per class (the layout `ImageFolder` expects), and set `data_dir` below to point at it.

## Part 1: Training from Scratch

### Imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pathlib
from PIL import Image
import time
from copy import deepcopy

import torch
from torch import nn
import torchvision as tv
import torchvision.models as models
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
import torchvision.datasets as datasets 
from torchvision.transforms import v2

In [None]:
# Get cpu, gpu or mps device for training

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

In [None]:
# Create your directory structure for your datasets and models

data_dir = pathlib.Path("datasets/gallery_dl_dataset")
data_dir.mkdir(parents=True, exist_ok=True)  # parents=True also creates ./datasets if it is missing

models_dir = pathlib.Path("models")
models_dir.mkdir(exist_ok=True)

model_name = "image_classifier" # change when working with other datasets

model_dir = models_dir / model_name
model_dir.mkdir(exist_ok=True)

### Data Processing ~ Image Transformations

Could you augment your training data by adding more transformations to them?

You could randomly change their brightness, contrast, saturation, and hue.

You could flip them horizontally or vertically with a 0.5 probability.

You could randomly rotate them.

See the [transforms documentation](https://pytorch.org/vision/stable/transforms.html) and the [transforms illustrations gallery](https://pytorch.org/vision/stable/auto_examples/transforms/plot_transforms_illustrations.html#sphx-glr-auto-examples-transforms-plot-transforms-illustrations-py) for references and examples.

Do you need to also add the above transformations to your validation set? Or are the existing ones enough? You need to consider what the purpose of each dataset is.

In [None]:
num_classes = 3 # your number of classes

train_transform = v2.Compose([
        # v2.Resize(size=(64, 64), antialias=True),
        v2.RandomResizedCrop(size=(64, 64), antialias=True),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True), 
    ])

val_transform = v2.Compose([
        v2.Resize(size=(64, 64), antialias=True),
        v2.CenterCrop(size=(64, 64)),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True), 
    ])

# create train, validation and test datasets with separate transforms (the test set reuses the validation transform)
train_dataset = datasets.ImageFolder(data_dir, transform=train_transform)
val_dataset = datasets.ImageFolder(data_dir, transform=val_transform)
test_dataset = datasets.ImageFolder(data_dir, transform=val_transform)

print("\n".join(train_dataset.classes)) # should show the folder names

Here we create our train, validation and test datasets by splitting the full input dataset into three subsets. A 70-20-10 split is quite common.

Setting `random_state` makes the split random but reproducible: as long as the same `random_state` is used, `train_test_split` always returns the same indices.
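The two-step split used in the next cells can be sketched on a stand-in list of 100 indices (an assumption in place of `len(train_dataset)`). The second call rescales `test_size` because it only sees the 80% left after removing the validation set: 0.1 / (1 - 0.2) = 0.125 of 80 samples is 10, i.e. 10% of the total, leaving 70 for training.

```python
from sklearn.model_selection import train_test_split

# Stand-in for a dataset of 100 images (the notebook uses len(train_dataset))
idx = list(range(100))
val_size, test_size = 0.2, 0.1

# First split: hold out 20% of all samples for validation
train_idx, val_idx = train_test_split(idx, test_size=val_size, random_state=42)

# Second split: rescale the test fraction to 0.1 / 0.8 = 0.125 of the remaining samples
train_idx, test_idx = train_test_split(train_idx, test_size=test_size / (1 - val_size), random_state=42)

print(len(train_idx), len(val_idx), len(test_idx))

# Re-running with the same random_state reproduces exactly the same split
train_again, val_again = train_test_split(idx, test_size=val_size, random_state=42)
print(val_again == val_idx)
```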

In [None]:
# get the length of the full dataset before the split
num_train = len(train_dataset)

# define the fractions of the full dataset used for validation and testing
val_size = 0.2
test_size = 0.1

# create a list of indices, one for each element of the full dataset
idx = list(range(num_train))
print(num_train, idx)

In [None]:
# perform the train / val / test split on the indices
train_indices, val_indices = train_test_split(idx, test_size=val_size, random_state=42)
# rescale test_size, since this second split only sees the (1 - val_size) remaining samples
train_indices, test_indices = train_test_split(train_indices, test_size=test_size/(1 - val_size), random_state=42)

# override datasets to only be samples for each split
train_dataset = Subset(train_dataset, train_indices)
val_dataset = Subset(val_dataset, val_indices)
test_dataset = Subset(test_dataset, test_indices)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

### Observing our Data

In [None]:
# Check dataset sizes and sample shape
print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

img_num = 92 # change this number to view a different sample (must be smaller than len(train_dataset))

# Get a sample to check shape
sample_img, sample_label = train_dataset[img_num]
print(f"\nSample image shape: {sample_img.shape}")
print(f"Sample label: {sample_label}")
print(f"Classes: {train_dataset.dataset.classes}")

### Visualising Data

In [None]:
# Get a sample image and its label
sample_img, sample_label = train_dataset[img_num]
label_name = train_dataset.dataset.classes[sample_label]

# Plot it
plt.figure(figsize=(6, 6))
plt.title(f"Label: {label_name} (index: {sample_label})")
plt.imshow(sample_img.permute(1, 2, 0))  # Convert from (C, H, W) to (H, W, C) for display
plt.axis("off")
plt.show()

In [None]:
# plotting for multiple images, randomly selected
figure = plt.figure(figsize=(12, 10))
cols, rows = 5, 5
for i in range(1, cols * rows + 1):
    # generate a random index
    sample_idx = torch.randint(len(train_dataset), size=(1,)).item()
    # retrieve the image and the respective label for that index
    img, label = train_dataset[sample_idx]
    label_name = train_dataset.dataset.classes[label]
    
    # create the grid of subplots
    figure.add_subplot(rows, cols, i)
    plt.title(label_name, fontsize=8)
    plt.axis("off")
    # Convert from (C, H, W) to (H, W, C) for display
    plt.imshow(img.permute(1, 2, 0))
plt.tight_layout()
plt.show()

### Dataloaders

In [None]:
batch_size = 6

# create data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

for X, y in val_loader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

### Defining our Convolutional Neural Network

In [None]:
class ConvNetwork(nn.Module):
    def __init__(self):
        super(ConvNetwork, self).__init__()
        # Input shape: [batch, 3, 64, 64]
        # Breaking down the first conv layer:
        #   > 3 input channels for RGB images
        #   > 64 different filters to output
        #   > 3x3 kernel size
        #   > 1 padding (keeps height and width unchanged)
        # output shape: [batch, 64, 64, 64]
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        # 2x2 maxpooling, output shape: [batch, 64, 32, 32]
        self.pool = nn.MaxPool2d(2, 2)
        # and so on: conv2 + pool -> [batch, 128, 16, 16], conv3 + pool -> [batch, 128, 8, 8]
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(128 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, num_classes) # output size matches your number of classes (num_classes is set above)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.pool(x)
        x = self.relu(self.conv2(x))
        x = self.pool(x)
        x = self.relu(self.conv3(x))
        x = self.pool(x)
        x = torch.flatten(x, 1) 
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x
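To see where `128 * 8 * 8` in `fc1` comes from: each 3x3 convolution with `padding=1` keeps height and width unchanged, and each 2x2 max-pooling halves them, so a 64x64 input becomes 32x32, then 16x16, then 8x8 with 128 channels. A quick sketch that checks this by pushing a dummy batch through the same convolution/pooling stack, rebuilt standalone here with layer sizes copied from `ConvNetwork`:

```python
import torch
from torch import nn

# Same convolution/pooling stack as ConvNetwork, without the fully connected layers
features = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),     # 64x64 -> 32x32
    nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),   # 32x32 -> 16x16
    nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),  # 16x16 -> 8x8
)

dummy_batch = torch.zeros(2, 3, 64, 64)  # [batch, channels, height, width]
feature_maps = features(dummy_batch)
flat = torch.flatten(feature_maps, 1)

print(feature_maps.shape)       # torch.Size([2, 128, 8, 8])
print(flat.shape[1], 128 * 8 * 8)
```

If you change the input size in the transforms (for example to 128x128), the flattened size changes too, and `fc1` has to be updated to match.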

In [None]:
model = ConvNetwork().to(device)
print(model)

In [None]:
print("Layers and their initial weights/bias shapes:")
for name, param in model.named_parameters():
    print(f" - {name} | Shape: {param.shape} | Sample values: {param.data.flatten()[:5]}...")

print()
print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")

### Optimizer and Loss Function

In [None]:
learning_rate = 0.001

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

### Implementing our Training Loop

In [None]:
epochs = 25
train_losses = []
val_losses = []

for epoch in range(epochs):
    train_loss = 0.0

    # training loop
    model.train()  # training mode (matters for layers such as dropout or batch norm)
    for batch_idx, (data, target) in enumerate(train_loader):
        # get data
        inputs = data.to(device)
        labels = target.to(device)
        
        # zero the gradients
        optimizer.zero_grad()
        # forward pass
        predictions = model(inputs)
        # compute the loss
        loss = loss_fn(predictions, labels)
        # backpropagate
        loss.backward()
        # update the parameters, i.e. weights
        optimizer.step()

        # save statistics to plot later
        train_loss += loss.item()
    
    # validation loop
    model.eval()  # evaluation mode
    with torch.no_grad():
        val_loss = 0.0
        for batch_idx, (data, target) in enumerate(val_loader):
            # get data
            inputs = data.to(device)
            labels = target.to(device)
            # forward pass, no backpropagation and optimisation
            predictions = model(inputs)
            # compute the loss
            loss = loss_fn(predictions, labels)
            # save statistics to plot later
            val_loss += loss.item()
    
    # average the summed losses over the number of batches
    train_loss = train_loss / len(train_loader)
    val_loss = val_loss / len(val_loader)

    # append the average losses to lists to plot later
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    
    print(f'Epoch {epoch + 1}, train loss: {train_loss:.3f}, val loss: {val_loss:.3f}')

### Testing ~ Evaluating the Performance of our Model

In [None]:
plt.figure(figsize=(10,5))
plt.title("Train vs validation loss - From Scratch")
plt.plot(train_losses,label="train")
plt.plot(val_losses,label="val")
plt.xlabel("epochs")
plt.ylabel("average loss per batch")
plt.legend()
plt.show()

In [None]:
def test(dataloader, model, loss_fn, device=device):
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for X, y in dataloader:
            X = X.to(device)
            y = y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    # average loss across batches and accuracy across samples
    test_loss = test_loss / num_batches
    accuracy = correct / size
    print(f"Test Error: \n Accuracy: {(100*accuracy):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return test_loss, accuracy

# Run test on the test loader
test_loss, test_acc = test(test_loader, model, loss_fn)

### Using our Model on an Input Image

In [None]:
img = Image.open('images/colorful-carpet-sample.png').convert("RGB") # try also images/4.png; convert("RGB") ensures 3 channels

transforms = v2.Compose([
    v2.Resize(size=(64,64), antialias=True),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
])

input = transforms(img).unsqueeze(0)  # add batch dimension: [1, 3, 64, 64]
input = input.to(device)

print(f"Input shape: {input.shape}")

In [None]:
model.eval()
with torch.no_grad():
    predictions = nn.Softmax(dim=-1)(model(input)).cpu().numpy()
print(f"Our predictions (shape: {predictions.shape})")
print(predictions)

In [None]:
predicted = np.argmax(predictions[0]) # argmax: the *index* of the highest prediction

plt.figure()
plt.title(f'Predicted class: {train_dataset.dataset.classes[predicted]}') # use the predicted category in the title
plt.imshow(img)
plt.axis("off")
plt.show()

We can plot our predictions for all classes using a [bar chart](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.bar.html).

In [None]:
plt.figure(figsize=(14,5))
plt.title("Predictions - From Scratch")
xs = train_dataset.dataset.classes     # class names on the x-axis; the bar heights are the predicted probabilities
plt.bar(xs, predictions[0]) # a bar chart
plt.xticks(xs)
plt.show()

### Save our Model

In [None]:
torch.jit.save(torch.jit.script(model), model_dir / f"my_{model_name}_from_scratch.pt")

---

## Part 2: Transfer Learning & Fine-tuning with Pre-trained Models

In this section, we leverage transfer learning with pre-trained models from ImageNet. This approach is typically faster to train and can achieve better accuracy with limited data.

**Key Concepts:**
- **Fine-tuning**: Train all layers of a pre-trained model on your custom dataset
- **Feature Extraction**: Freeze pre-trained layers and only train the final classifier

### ImageNet Normalization for Pre-trained Models

The pre-trained ResNet18 was trained on ImageNet images normalized with the per-channel mean `[0.485, 0.456, 0.406]` and standard deviation `[0.229, 0.224, 0.225]` (one value each for the R, G and B channels). Without the same normalization step, the model receives inputs whose distribution does not match the one it learned from.

In [None]:
# For transfer learning, we use ImageNet statistics for normalization
# These are the mean and std values that ResNet and other pre-trained models were trained on

transfer_train_transform = v2.Compose([
    v2.RandomResizedCrop(224),
    v2.RandomHorizontalFlip(),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406],
                  std=[0.229, 0.224, 0.225]),
])

transfer_val_transform = v2.Compose([
    v2.Resize(256),
    v2.CenterCrop(224),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406],
                  std=[0.229, 0.224, 0.225]),
])

# Create transfer learning datasets
transfer_train_dataset = datasets.ImageFolder(data_dir, transform=transfer_train_transform)
transfer_val_dataset = datasets.ImageFolder(data_dir, transform=transfer_val_transform)
transfer_test_dataset = datasets.ImageFolder(data_dir, transform=transfer_val_transform)

In [None]:
# Split data for transfer learning (same indices as before for consistency)
num_train_tl = len(transfer_train_dataset)
idx_tl = list(range(num_train_tl))

train_indices_tl, val_indices_tl = train_test_split(idx_tl, test_size=val_size, random_state=42)
train_indices_tl, test_indices_tl = train_test_split(train_indices_tl, test_size=test_size/(1 - val_size), random_state=42)

transfer_train_dataset = Subset(transfer_train_dataset, train_indices_tl)
transfer_val_dataset = Subset(transfer_val_dataset, val_indices_tl)
transfer_test_dataset = Subset(transfer_test_dataset, test_indices_tl)

# Create dataloaders for transfer learning
transfer_batch_size = 6  # a larger batch size (e.g. 32) may work if your dataset size and GPU memory allow it

transfer_train_loader = DataLoader(transfer_train_dataset, batch_size=transfer_batch_size, shuffle=True)
transfer_val_loader = DataLoader(transfer_val_dataset, batch_size=transfer_batch_size, shuffle=False)
transfer_test_loader = DataLoader(transfer_test_dataset, batch_size=transfer_batch_size, shuffle=False)

print(f"Transfer Learning DataLoaders created")
print(f"  Training samples: {len(transfer_train_dataset)}")
print(f"  Validation samples: {len(transfer_val_dataset)}")
print(f"  Test samples: {len(transfer_test_dataset)}")

In [None]:
# Get a sample image and its label
sample_img, sample_label = transfer_train_dataset[img_num]
label_name = transfer_train_dataset.dataset.classes[sample_label]

# Plot it
plt.figure(figsize=(6, 6))
plt.title(f"Label: {label_name} (index: {sample_label})")
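# Undo the ImageNet normalization so colours display correctly (imshow expects values in [0, 1])
imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
display_img = (sample_img * imagenet_std + imagenet_mean).clamp(0, 1)
plt.imshow(display_img.permute(1, 2, 0))  # Convert from (C, H, W) to (H, W, C) for display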
plt.axis("off")
plt.show()

### Approach 1: Fine-tuning the Pre-trained Model

In [this approach](https://docs.pytorch.org/tutorials/beginner/transfer_learning_tutorial.html#finetuning-the-convnet), we:
1. Load a pre-trained ResNet18 model
2. Replace the final fully connected layer to match our number of classes
3. Train all layers (updating all weights)

In [None]:
# Load a pre-trained ResNet18 model
model_ft = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

# Get the number of input features for the final layer
num_ftrs = model_ft.fc.in_features

# Replace the final fully connected layer
model_ft.fc = nn.Linear(num_ftrs, num_classes)

# Move model to device
model_ft = model_ft.to(device)

print("Pre-trained ResNet18 model loaded and final layer replaced")
print(f"Final layer: {model_ft.fc}")

In [None]:
print("Layers and their initial weights/bias shapes:")
for name, param in model_ft.named_parameters():
    print(f" - {name} | Shape: {param.shape} | Sample values: {param.data.flatten()[:5]}...")

print()
print(f"Total parameters: {sum(p.numel() for p in model_ft.parameters())}")

In [None]:
# Training function for transfer learning
def train_transfer_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, model_name="model"):
    """
    Train a model with validation loop and learning rate scheduling.
    Saves the best model based on validation accuracy.
    """
    since = time.time()
    best_acc = 0.0
    best_model_wts = deepcopy(model.state_dict())
    
    train_losses = []
    val_losses = []
    val_accs = []
    
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)
        
        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
                dataloader = train_loader
            else:
                model.eval()   # Set model to evaluate mode
                dataloader = val_loader
            
            running_loss = 0.0
            running_corrects = 0
            total_samples = 0
            
            # Iterate over data
            for inputs, labels in dataloader:
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                # Zero the parameter gradients
                optimizer.zero_grad()
                
                # Forward pass
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    
                    # Backward pass + optimize only in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                
                # Statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
                total_samples += inputs.size(0)
            
            if phase == 'train':
                scheduler.step()
            
            epoch_loss = running_loss / total_samples
            epoch_acc = running_corrects.float() / total_samples
            
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
            
            # Deep copy the model if it has the best validation accuracy
            if phase == 'val':
                val_losses.append(epoch_loss)
                epoch_acc_value = epoch_acc.item()
                val_accs.append(epoch_acc_value)
                if epoch_acc_value > best_acc:
                    best_acc = epoch_acc_value
                    best_model_wts = deepcopy(model.state_dict())
            else:
                train_losses.append(epoch_loss)
        
        print()
    
    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')
    
    # Load best model weights
    model.load_state_dict(best_model_wts)
    
    return model, train_losses, val_losses, val_accs

In [None]:
# Setup for fine-tuning
criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer_ft = torch.optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

# Train and evaluate
num_epochs_ft = 25

print("Starting Fine-tuning...")
model_ft, train_losses_ft, val_losses_ft, val_accs_ft = train_transfer_model(
    model_ft, transfer_train_loader, transfer_val_loader,
    criterion, optimizer_ft, exp_lr_scheduler,
    num_epochs_ft, "fine_tuned"
)
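With `step_size=7` and `gamma=0.1`, `StepLR` multiplies the learning rate by 0.1 every 7 calls to `scheduler.step()`, and `train_transfer_model` calls it once per epoch after the training phase. A sketch that records the learning rate used in each of the 25 epochs, with a single dummy parameter standing in for the ResNet's parameters:

```python
import torch

# A dummy parameter stands in for model_ft.parameters()
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.001, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

lrs = []
for epoch in range(25):
    # learning rate used during this epoch's training phase
    lrs.append(optimizer.param_groups[0]["lr"])
    optimizer.step()   # normally preceded by a forward/backward pass
    scheduler.step()   # once per epoch, as in train_transfer_model

print(lrs)
```

Training therefore runs at 0.001 for epochs 0-6, 1e-4 for epochs 7-13, 1e-5 for epochs 14-20 and 1e-6 for epochs 21-24 (0-indexed, matching the `Epoch {epoch}` printout).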

In [None]:
# Plot fine-tuning results
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.title("Fine-tuning - Loss")
plt.plot(train_losses_ft, label="train")
plt.plot(val_losses_ft, label="val")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()

plt.subplot(1, 2, 2)
plt.title("Fine-tuning - Validation Accuracy")
plt.plot(val_accs_ft, label="val accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()

plt.tight_layout()
plt.show()

In [None]:
# Test the fine-tuned model
def test_transfer(dataloader, model, loss_fn, device):
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for X, y in dataloader:
            X = X.to(device)
            y = y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss = test_loss / num_batches
    accuracy = correct / size
    print(f"Test Error (Fine-tuned): \n Accuracy: {(100*accuracy):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return test_loss, accuracy

test_loss_ft, test_acc_ft = test_transfer(transfer_test_loader, model_ft, criterion, device)

### Approach 2: ConvNet as Fixed Feature Extractor

In [this approach](https://docs.pytorch.org/tutorials/beginner/transfer_learning_tutorial.html#convnet-as-fixed-feature-extractor), we:
1. Load a pre-trained ResNet18 model
2. Freeze all layers except the final fully connected layer
3. Train only the final classifier layer

In [None]:
# Load another copy of pre-trained ResNet18
model_conv = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

# Freeze all the network parameters
for param in model_conv.parameters():
    param.requires_grad = False

# Replace the final fully connected layer
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, num_classes)

# Move model to device
model_conv = model_conv.to(device)

print("Pre-trained ResNet18 loaded as feature extractor")
print("All parameters frozen except final layer")

In [None]:
# Setup for feature extraction
criterion_conv = nn.CrossEntropyLoss()

# Observe that only the final layer parameters are being optimized
optimizer_conv = torch.optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler_conv = torch.optim.lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

# Train and evaluate
num_epochs_conv = 25

print("Starting Feature Extraction training...")
model_conv, train_losses_conv, val_losses_conv, val_accs_conv = train_transfer_model(
    model_conv, transfer_train_loader, transfer_val_loader,
    criterion_conv, optimizer_conv, exp_lr_scheduler_conv,
    num_epochs_conv, "feature_extractor"
)

In [None]:
# Plot feature extraction results
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.title("Feature Extraction - Loss")
plt.plot(train_losses_conv, label="train")
plt.plot(val_losses_conv, label="val")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()

plt.subplot(1, 2, 2)
plt.title("Feature Extraction - Validation Accuracy")
plt.plot(val_accs_conv, label="val accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()

plt.tight_layout()
plt.show()

In [None]:
# Test the feature extractor model
test_loss_conv, test_acc_conv = test_transfer(transfer_test_loader, model_conv, criterion_conv, device)

### Comparison: From Scratch vs Transfer Learning

In [None]:
# Create a comparison plot
plt.figure(figsize=(14, 5))

plt.title("Test Accuracy Comparison")
approaches = ['From Scratch\n(Custom CNN)', 'Fine-tuning\n(ResNet18)', 'Feature Extraction\n(ResNet18)']
accuracies = [test_acc, test_acc_ft, test_acc_conv]
colors = ['#1f77b4', '#ff7f0e', '#2ca02c']
bars = plt.bar(approaches, accuracies, color=colors)
plt.ylabel("Test Accuracy")
plt.ylim([0, 1.1])  # leave headroom above 100% for the value labels
# Add value labels on bars
for bar, acc in zip(bars, accuracies):
    plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
             f'{acc:.1%}', ha='center', va='bottom')

plt.show()

In [None]:
# Print summary
print("\n" + "="*50)
print("TRAINING SUMMARY")
print("="*50)
print(f"\nFrom Scratch (Custom CNN):")
print(f"  Test Accuracy: {test_acc:.4f} ({100*test_acc:.2f}%)")
print(f"  Test Loss: {test_loss:.4f}")

print(f"\nTransfer Learning - Fine-tuning (ResNet18):")
print(f"  Test Accuracy: {test_acc_ft:.4f} ({100*test_acc_ft:.2f}%)")
print(f"  Test Loss: {test_loss_ft:.4f}")
print(f"  Improvement over from-scratch: {100*(test_acc_ft - test_acc):.2f} percentage points")

print(f"\nTransfer Learning - Feature Extraction (ResNet18):")
print(f"  Test Accuracy: {test_acc_conv:.4f} ({100*test_acc_conv:.2f}%)")
print(f"  Test Loss: {test_loss_conv:.4f}")
print(f"  Improvement over from-scratch: {100*(test_acc_conv - test_acc):.2f} percentage points")
print("="*50)

In [None]:
img = Image.open('images/colorful-carpet-sample.png').convert("RGB")  # ensure 3 channels (PNGs may be RGBA or grayscale)

# transforms for the CNN from scratch (must match the preprocessing it was validated with; unsqueeze(0) below adds the batch dimension)
small_transform = v2.Compose([  
        v2.Resize(size=(64, 64), antialias=True),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
    ])

input_img_small = small_transform(img).unsqueeze(0)
input_img_small = input_img_small.to(device)

# transforms for transfer learning (must match what the model was trained on, including normalization)
transfer_transforms = v2.Compose([
    v2.Resize(256),
    v2.CenterCrop(224),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406],
                  std=[0.229, 0.224, 0.225]),
])

input_img = transfer_transforms(img).unsqueeze(0)
input_img = input_img.to(device)

In [None]:
# Compare all three models on the same image
print("\n" + "="*60)
print("COMPARISON: ALL THREE MODELS")
print("="*60)

fig, axes = plt.subplots(1, 4, figsize=(18, 4))

# Original image
axes[0].imshow(img)
axes[0].set_title("Original Image")
axes[0].axis("off")

# From Scratch Model
model.eval()
with torch.no_grad():
    predictions_scratch = nn.Softmax(dim=-1)(model(input_img_small)).cpu().numpy()

predicted_scratch = np.argmax(predictions_scratch[0])
axes[1].bar(train_dataset.dataset.classes, predictions_scratch[0])
axes[1].set_title(f"From Scratch\nPred: {train_dataset.dataset.classes[predicted_scratch]}")
axes[1].set_ylabel("Probability")
axes[1].tick_params(axis='x', rotation=45)

# Fine-tuned Model
model_ft.eval()
with torch.no_grad():
    predictions_ft = nn.Softmax(dim=-1)(model_ft(input_img)).cpu().numpy()

predicted_ft = np.argmax(predictions_ft[0])
axes[2].bar(transfer_train_dataset.dataset.classes, predictions_ft[0])
axes[2].set_title(f"Fine-tuned (ResNet18)\nPred: {transfer_train_dataset.dataset.classes[predicted_ft]}")
axes[2].set_ylabel("Probability")
axes[2].tick_params(axis='x', rotation=45)

# Feature Extractor Model
model_conv.eval()
with torch.no_grad():
    predictions_conv = nn.Softmax(dim=-1)(model_conv(input_img)).cpu().numpy()

predicted_conv = np.argmax(predictions_conv[0])
axes[3].bar(transfer_train_dataset.dataset.classes, predictions_conv[0])
axes[3].set_title(f"Feature Extractor (ResNet18)\nPred: {transfer_train_dataset.dataset.classes[predicted_conv]}")
axes[3].set_ylabel("Probability")
axes[3].tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.show()

print(f"\nFrom Scratch prediction: {train_dataset.dataset.classes[predicted_scratch]} (conf: {predictions_scratch[0][predicted_scratch]:.4f})")
print(f"Fine-tuned prediction: {transfer_train_dataset.dataset.classes[predicted_ft]} (conf: {predictions_ft[0][predicted_ft]:.4f})")
print(f"Feature Extractor prediction: {transfer_train_dataset.dataset.classes[predicted_conv]} (conf: {predictions_conv[0][predicted_conv]:.4f})")

### Save the Transfer Learning Models

In [None]:
# Save the fine-tuned model
torch.save(model_ft.state_dict(), model_dir / f"my_{model_name}_fine_tuned.pt")
print(f"Fine-tuned model saved to {model_dir / f'my_{model_name}_fine_tuned.pt'}")

# Save the feature extractor model
torch.save(model_conv.state_dict(), model_dir / f"my_{model_name}_feature_extractor.pt")
print(f"Feature extractor model saved to {model_dir / f'my_{model_name}_feature_extractor.pt'}")

## Key Takeaways

### From Scratch vs Transfer Learning:

1. **From Scratch (Custom CNN)**:
   - Requires more data and training time
   - Good for learning fundamentals
   - Typically achieves lower accuracy when data is limited

2. **Transfer Learning - Fine-Tuning**:
   - Leverages pre-trained features (from ImageNet in this case)
   - Updates all weights during training
   - Usually converges faster and reaches higher accuracy
   - Best when you have moderate-sized datasets

3. **Transfer Learning - Feature Extraction**:
   - Freezes pre-trained layers, trains only final classifier
   - Fastest training time
   - Good for small datasets
   - May not adapt as well to your specific task