# Part B - Fine-tuning a Pretrained model (GoogLeNet)

### Importing required libraries

In [7]:
from tqdm.auto import tqdm
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data_utils
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import pathlib
import wandb

### Logging in Wandb

In [8]:
wandb.login()

True

## Data generation

#### Data generation and transformation for better training input for the model
- From the dataset path provided in main, we are generating the dataset and performing transformation to get better training data for the model
- Applying the normalization using mean and standard deviation (calculated using "get_mean_and_std" on training dataset)
- Splitting 20% of the training dataset such that 80% is in training dataset and 20% is in testing dataset
- Validation dataset is generated such that every class contains an equal amount of data (200 images in each class)

In [15]:
def data_generation(dataset_path, num_classes=10, data_augmentation=False, batch_size=32):
    
    # Mean and standard deviation values calculated from function get_mean_and_std on training dataset

    mean = [0.4708, 0.4596, 0.3891]
    std = [0.1951, 0.1892, 0.1859]


    # Define transformations for training and testing data
    
    augment_transform = transforms.Compose([
        transforms.Resize((224, 224)), 
        transforms.RandomHorizontalFlip(), 
        transforms.RandomRotation(30), 
        transforms.ToTensor(),
        transforms.Normalize(torch.Tensor(mean), torch.Tensor(std))
    ])

    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(torch.Tensor(mean), torch.Tensor(std))
        ])
    
    test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(torch.Tensor(mean), torch.Tensor(std))
    ])


    # Data augmentation (if data_augmentation = True) 

    train_dataset = datasets.ImageFolder(root = dataset_path + "train", transform=train_transform)
    test_dataset = datasets.ImageFolder(root = dataset_path + "val", transform=test_transform)
    
    
    # Split train dataset into train and validation sets

    train_data_class = dict()
    for c in range(num_classes):
        train_data_class[c] = [i for i, label in enumerate(train_dataset.targets) if label == c]

    val_data_indices = []
    val_ratio = 0.2  # 20% for validation
    for class_indices in train_data_class.values():
        num_val = int(len(class_indices) * val_ratio)
        val_data_indices.extend(random.sample(class_indices, num_val))


    # Create training and validation datasets

    train_data = torch.utils.data.Subset(train_dataset, [i for i in range(len(train_dataset)) if i not in val_data_indices])
    val_data = torch.utils.data.Subset(train_dataset, val_data_indices)


    # Create data loaders

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    if data_augmentation:
      augmented_dataset = datasets.ImageFolder(root = dataset_path + "train", transform=augment_transform)
      augmented_loader = DataLoader(augmented_dataset, batch_size=batch_size, shuffle=True)
      train_loader = torch.utils.data.ConcatDataset([train_loader.dataset, augmented_loader.dataset])
      train_loader = DataLoader(train_loader, batch_size=batch_size, shuffle=True)


    # Get class names
    classpath = pathlib.Path(dataset_path + "train")
    class_names = sorted([j.name.split('/')[-1] for j in classpath.iterdir() if j.name != ".DS_Store"])

    return train_loader, val_loader, test_loader, class_names

## Evaluation

- trainCNN evaluates training, validation and testing accuracies and losses on model (GoogLeNet)

In [16]:
def trainCNN(device, train_loader, val_loader, test_loader, model, num_epochs=10, optimizer="Adam"):    
    criterion = nn.CrossEntropyLoss()
    if optimizer == "Adam":
        opt_func = optim.Adam(model.parameters(), lr=0.001)

    total_correct = 0
    total_samples = 0

    for epoch in tqdm(range(num_epochs)):
        model.train()  # Set the model to training mode
        running_loss = 0.0
        total_correct = 0
        total_samples = 0
        for inputs, labels in tqdm(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            opt_func.zero_grad()  # Zero the gradients
            outputs = model(inputs)  # Forward pass
            loss = criterion(outputs, labels)  # Compute the loss
            loss.backward()  # Backward pass
            opt_func.step()  # Update the parameters

            _, predicted = torch.max(outputs, 1)
            total_correct += (predicted == labels).sum().item()
            total_samples += labels.size(0)

            running_loss += loss.item() * inputs.size(0)
        loss = running_loss / len(train_loader.dataset)
        accuracy = total_correct / total_samples
        print(f"Epoch [{epoch+1}/{num_epochs}], Accuracy: {accuracy * 100:.2f}%, Loss: {loss:.4f}")
        wandb.log({'accuracy': accuracy, 'loss': loss})


        # Validation
        model.eval()
        with torch.no_grad():
            val_total_correct = 0
            val_total_samples = 0
            val_running_loss = 0.0
            for val_inputs, val_labels in tqdm(val_loader):
                val_inputs, val_labels = val_inputs.to(device), val_labels.to(device)
                val_outputs = model(val_inputs)
                val_loss = criterion(val_outputs, val_labels)

                _, val_predicted = torch.max(val_outputs, 1)
                val_total_correct += (val_predicted == val_labels).sum().item()
                val_total_samples += val_labels.size(0)

                val_running_loss += val_loss.item() * val_inputs.size(0)

            val_loss = val_running_loss / len(val_loader.dataset)
            val_accuracy = val_total_correct / val_total_samples
            print(f"Epoch [{epoch+1}/{num_epochs}], Validation Accuracy: {val_accuracy * 100:.2f}%, Validation Loss: {val_loss:.4f}")
            wandb.log({'val_accuracy': val_accuracy, 'val_loss': val_loss})

        # Test accuracy evaluation
        if epoch==num_epochs-1:
            model.eval()
            with torch.no_grad():
                test_total_correct = 0
                test_total_samples = 0
                test_running_loss = 0.0
                for test_inputs, test_labels in tqdm(test_loader):
                    test_inputs, test_labels = test_inputs.to(device), test_labels.to(device)
                    test_outputs = model(test_inputs)
                    test_loss = criterion(test_outputs, test_labels)
    
                    _, test_predicted = torch.max(test_outputs, 1)
                    test_total_correct += (test_predicted == test_labels).sum().item()
                    test_total_samples += test_labels.size(0)
    
                    test_running_loss += test_loss.item() * test_inputs.size(0)
    
                test_loss = test_running_loss / len(test_loader.dataset)
                test_accuracy = test_total_correct / test_total_samples
                print(f"Test Accuracy: {test_accuracy * 100:.2f}%, Test Loss: {test_loss:.4f}")

## Fine Tuning of the model

#### Three types of fine-tuning used:
Freezing - Basically fixing the parameters (weights and biases) of certain layers, preventing them from being updated. 
1. **Feature Extraction:** It freezes all the layers and then we are applying a fully connected layer which takes the input from the last layer and gives one the 10 outputs of class.
2. **Freeze K layers:** In this, we freeze first K layers, while keeping other layers unfreeze. That way we can keep updating the parameters after k layers.
3. **Full fine-tuning:** Not fixing the parameters of any layer, therefore the parameters will be updated in every layer.

In [17]:
def feature_extraction(model, device):
    for params in model.parameters():
        params.requires_grad = False

def freeze_till_k(model, device, k):
    for params in model.parameters():
        for idx, child in enumerate(model.children()):
            # Freeze layers up to the k-th layer
            if idx < k:
                for param in child.parameters():
                    param.requires_grad = False
            else:
                # Stop iterating after k-th layer
                break

def no_freezing(model, device):
    for params in model.parameters():
        params.requires_grad = True

In [18]:
def main():
    dataset_path = '../inaturalist_12K/'  

    data_augmentation = True
    batch_size = 32
    num_classes = 10
    fine_tuning_method = 2
    k = 10

    def train():
        with wandb.init(project="CS6910_Assignment_2_Part_B") as run:
            config = wandb.config
            run_name = "aug_" + str(data_augmentation) + "_bs_" + str(batch_size) + "_fine_tune_" + str(fine_tuning_method) + "_num_freeze_layer_all"
            if fine_tuning_method != 1:
                run_name = "aug_" + str(data_augmentation) + "_bs_" + str(batch_size) + "_fine_tune_" + str(fine_tuning_method) + "_num_freeze_layer_" + str(k)
            elif fine_tuning_method == 3:
                run_name = "aug_" + str(data_augmentation) + "_bs_" + str(batch_size) + "_fine_tune_" + str(fine_tuning_method) + "_num_freeze_layer_none"

            wandb.run.name = run_name
            train_loader, val_loader, test_loader, class_names = data_generation(dataset_path, 
                                                                                     num_classes=10, 
                                                                                     data_augmentation=data_augmentation, 
                                                                                     batch_size=batch_size)
            print("Train: ", len(train_loader))
            print("Val: ", len(val_loader))
            print("Test: ", len(test_loader))
    
            device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
            print("Device: ", device)
        
            model = models.googlenet(pretrained=True)

            if fine_tuning_method == 1:
                feature_extraction(model, device)
                model.fc = nn.Linear(model.fc.in_features, num_classes)
                model.to(device)
                trainCNN(device, train_loader, val_loader, test_loader, model, num_epochs=10, optimizer="Adam")
            
            elif fine_tuning_method == 2:
                freeze_till_k(model, device, k)
                model.fc = nn.Linear(model.fc.in_features, num_classes)
                model.to(device)
                trainCNN(device, train_loader, val_loader, test_loader, model, num_epochs=10, optimizer="Adam")

            else:
                feature_extraction(model, device)
                model.fc = nn.Linear(model.fc.in_features, num_classes)
                model.to(device)
                trainCNN(device, train_loader, val_loader, test_loader, model, num_epochs=10, optimizer="Adam")
    train()
    wandb.finish()
    
main()

Train:  563
Val:  63
Test:  63
Device:  mps




  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/563 [00:00<?, ?it/s]

Epoch [1/10], Accuracy: 63.06%, Loss: 1.0889


  0%|          | 0/63 [00:00<?, ?it/s]

Epoch [1/10], Validation Accuracy: 73.44%, Validation Loss: 0.7959


  0%|          | 0/563 [00:00<?, ?it/s]

Epoch [2/10], Accuracy: 74.55%, Loss: 0.7585


  0%|          | 0/63 [00:00<?, ?it/s]

Epoch [2/10], Validation Accuracy: 77.94%, Validation Loss: 0.6485


  0%|          | 0/563 [00:00<?, ?it/s]

Epoch [3/10], Accuracy: 80.06%, Loss: 0.5994


  0%|          | 0/63 [00:00<?, ?it/s]

Epoch [3/10], Validation Accuracy: 82.39%, Validation Loss: 0.5314


  0%|          | 0/563 [00:00<?, ?it/s]

Epoch [4/10], Accuracy: 83.93%, Loss: 0.4777


  0%|          | 0/63 [00:00<?, ?it/s]

Epoch [4/10], Validation Accuracy: 85.04%, Validation Loss: 0.4638


  0%|          | 0/563 [00:00<?, ?it/s]

Epoch [5/10], Accuracy: 86.71%, Loss: 0.3997


  0%|          | 0/63 [00:00<?, ?it/s]

Epoch [5/10], Validation Accuracy: 85.84%, Validation Loss: 0.4361


  0%|          | 0/563 [00:00<?, ?it/s]

Traceback (most recent call last):
  File "/var/folders/3m/7slbgqtx4rlff38t_2cx53700000gp/T/ipykernel_51426/4226319062.py", line 43, in train
    trainCNN(device, train_loader, val_loader, test_loader, model, num_epochs=10, optimizer="Adam")
  File "/var/folders/3m/7slbgqtx4rlff38t_2cx53700000gp/T/ipykernel_51426/964984367.py", line 23, in trainCNN
    total_correct += (predicted == labels).sum().item()
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
KeyboardInterrupt


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy,▁▄▆▇█
loss,█▅▃▂▁
val_accuracy,▁▄▆██
val_loss,█▅▃▂▁

0,1
accuracy,0.8671
loss,0.39967
val_accuracy,0.85843
val_loss,0.43613


KeyboardInterrupt: 