# EfficientNetV2 Fine-Tuning on the real dataset

### Set all the necessary parameters in the next cell:

In [1]:
# EfficientNetV2 variant to fine-tune: one of 's', 'm', 'l'
model_size = 's'

# Dataset to fine-tune on; get_real_dataset reads it from ../datasets/<dataset_name>
dataset_name = 'tiny'

# Destination of the training log and best checkpoint, one folder per (dataset, model size)
output_dir = '../storage/trained_efficientnet/' + dataset_name + '/' + model_size

### Import the necessary libraries

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, TensorDataset, default_collate
from torch.optim.lr_scheduler import LambdaLR, SequentialLR, CosineAnnealingLR

from torchvision import datasets as dset
from torchvision.transforms import v2
from torchvision.models import efficientnet_v2_s, efficientnet_v2_m, efficientnet_v2_l, EfficientNet_V2_S_Weights, EfficientNet_V2_M_Weights, EfficientNet_V2_L_Weights

from datasets import load_dataset, load_from_disk

from medmnist import DermaMNIST, BloodMNIST

import os
import numpy as np
import shutil
from PIL import Image
import pandas as pd
import logging
from datetime import datetime
import random
import time
import copy
import json
import gc
import psutil
import scipy
import traceback
import csv
import sys
import pickle

import IPython

### Set the seed

In [3]:
seed = 420

# Seed every RNG the pipeline draws from: Python, NumPy, PyTorch CPU and the current CUDA device
for seed_fn in (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_fn(seed)

# Let cuDNN benchmark and pick the fastest kernels for the fixed input size;
# per the PyTorch reproducibility notes this may make runs not bit-for-bit repeatable
torch.backends.cudnn.benchmark = True

### Create the logger

In [4]:
# exist_ok=False: refuse to reuse an existing output folder so a previous run is never overwritten
os.makedirs(output_dir, exist_ok=False)

# Route INFO-and-above records of the root logger to <output_dir>/training.log.
# filemode='w' truncates the file each time this cell runs.
# NOTE: logging.basicConfig is a no-op if the root logger already has handlers.
log_filename = os.path.join(output_dir, 'training.log')
logging.basicConfig(
    filename=log_filename,
    filemode='w',
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger()

### Datasets setup

In [5]:
class OxfordPetsDataset(Dataset):
    """Oxford-IIIT Pets split stored as a single ``torch.save`` file.

    The 'train' split reads ``<root>/train85.pth``; any other split name reads
    ``<root>/test15.pth``. Each stored item is unpacked as ``(image, label)``.
    ``label.item()`` is called on every label, so labels must be 0-d tensors.
    """

    def __init__(self, root='../datasets/pets', split='train', transform=None):
        self.split = split
        self.transform = transform
        file_name = 'train85.pth' if split == 'train' else 'test15.pth'
        self.data_dir = f'{root}/{file_name}'
        # NOTE(review): torch.load unpickles the file; only load trusted local files.
        self.data = torch.load(self.data_dir)
        # Distinct integer labels present in the split, in ascending order
        self.classes = sorted({label.item() for _, label in self.data})

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        img, label = self.data[index]
        return (self.transform(img) if self.transform else img), label


class TinyImageNetDataset(Dataset): # from HuggingFace
    """Tiny ImageNet split saved to disk with HuggingFace ``datasets``.

    The 'train' split loads ``<root>/train``; any other split name loads
    ``<root>/valid``. Every image is returned as RGB, then passed through ``transform``.
    """

    def __init__(self, root='../datasets/tiny', split='train', transform=None):
        self.split = split
        self.transform = transform
        self.data_dir = f"{root}/{'train' if split == 'train' else 'valid'}"
        self.data = load_from_disk(self.data_dir)
        self.classes = self.data.features['label'].names
        self.class_to_idx = {name: position for position, name in enumerate(self.classes)}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        image = sample['image']
        # Non-RGB (e.g. grayscale) images are converted to 3 channels so every sample has the same layout
        image = image if image.mode == 'RGB' else image.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, sample['label']

In [6]:
def get_transform(dataset_name, image_size, huge_augment, horizontal_flip, random_crop, random_erasing, train=True):
    """Build the torchvision v2 preprocessing pipeline for one split of a dataset.

    Args:
        dataset_name: key selecting the per-channel normalization statistics below.
        image_size: target size given to ``v2.Resize``: an int or an ``(h, w)`` pair.
        huge_augment: 'trivial_augment', 'auto_augment', 'rand_augment' or 'aug_mix';
            any other value (e.g. None) adds no policy-based augmentation.
        horizontal_flip: add ``RandomHorizontalFlip`` (train only).
        random_crop: pad by 10% per axis and randomly crop back to ``image_size`` (train only).
        random_erasing: add ``RandomErasing`` after normalization (train only).
        train: when False, only tensor conversion, resizing and normalization are applied.

    Returns:
        A ``v2.Compose`` with the selected transforms.

    Raises:
        TypeError: if ``dataset_name`` has no registered statistics.
    """
    # Per-channel (mean, std) used by Normalize for each supported dataset
    normalization_stats = {
        'cifar10': ([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),
        'cifar100': ([0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]),
        'pets': ([0.4717, 0.4499, 0.3837], [0.2726, 0.2634, 0.2794]),
        'cars': ([0.4708, 0.4602, 0.4550], [0.2892, 0.2882, 0.2968]),
        'food': ([0.5450, 0.4435, 0.3436], [0.2695, 0.2719, 0.2766]),
        'tiny': ([0.4805, 0.4483, 0.3978], [0.2177, 0.2138, 0.2136]),
        'dermamnist': ([0.7632, 0.5381, 0.5615], [0.0872, 0.1204, 0.1360]),
        'bloodmnist': ([0.7961, 0.6596, 0.6964], [0.2139, 0.2464, 0.0903]),
    }
    if dataset_name not in normalization_stats:
        raise TypeError(f"Unknown dataset: {dataset_name}. Supported dataset names are cifar10, cifar100, pets, cars, food, tiny, dermamnist, bloodmnist.")
    mean, std = normalization_stats[dataset_name]

    # Shared by train and test: convert to a float tensor, then resize.
    # NOTE(review): v2.ToTensor is deprecated in favour of v2.ToImage() + v2.ToDtype(torch.float32, scale=True);
    # it is kept because the two behave differently for inputs that are already tensors.
    transformations = [
        v2.ToTensor(),
        v2.Resize(image_size, interpolation=v2.InterpolationMode.BICUBIC, antialias=True),
    ]

    if train:
        if horizontal_flip:
            transformations.append(v2.RandomHorizontalFlip())

        # Policy-based augmentation; unrecognised values add nothing
        augment_policies = {
            'trivial_augment': v2.TrivialAugmentWide,
            'auto_augment': v2.AutoAugment,
            'rand_augment': v2.RandAugment,
            'aug_mix': v2.AugMix,
        }
        policy = augment_policies.get(huge_augment)
        if policy is not None:
            transformations.append(policy())

        if random_crop:
            # Crop back to the full (h, w) target; using only image_size[0] would turn a
            # non-square target into a square crop that no longer matches the Resize output.
            crop_h, crop_w = (image_size, image_size) if isinstance(image_size, int) else image_size
            padding_fraction = 0.10  # 10% padding per axis
            # v2.RandomCrop padding of length 2 = (left/right, top/bottom)
            padding = [int(crop_w * padding_fraction), int(crop_h * padding_fraction)]
            transformations.append(v2.RandomCrop((crop_h, crop_w), padding=padding))

    transformations.append(v2.Normalize(mean=mean, std=std))  # applied to the test set too

    if train and random_erasing:
        # Placed after Normalize (so it erases normalized values) to match earlier experiments
        transformations.append(v2.RandomErasing(value='random'))

    return v2.Compose(transformations)

In [7]:
def get_real_dataset(dataset_name, image_size, huge_augment, horizontal_flip, random_crop, random_erasing):
    """Create the ``(train_data, test_data)`` pair for ``dataset_name``.

    The training split gets ``get_transform(..., train=True)`` and the test split
    ``get_transform(..., train=False)``. Every dataset is read from ``../datasets/<dataset_name>``.

    Raises:
        TypeError: for an unsupported ``dataset_name``.
    """
    logger.info(f'Dataset {dataset_name} loading and processing...')

    shared_args = (dataset_name, image_size, huge_augment, horizontal_flip, random_crop, random_erasing)
    train_tf = get_transform(*shared_args, train=True)
    test_tf = get_transform(*shared_args, train=False)

    def cifar_pair(cls, root):
        return (
            cls(root=root, train=True, transform=train_tf, download=True),
            cls(root=root, train=False, transform=test_tf, download=True),
        )

    def medmnist_pair(cls, root):
        # Images are requested at size 224 from MedMNIST; get_transform still resizes to image_size
        return (
            cls(root=root, split='train', size=224, as_rgb=True, transform=train_tf, download=True),
            cls(root=root, split='test', size=224, as_rgb=True, transform=test_tf, download=True),
        )

    builders = {
        'cifar10': lambda: cifar_pair(dset.CIFAR10, '../datasets/cifar10'),
        'cifar100': lambda: cifar_pair(dset.CIFAR100, '../datasets/cifar100'),
        'pets': lambda: (
            OxfordPetsDataset(root='../datasets/pets', split='train', transform=train_tf),
            OxfordPetsDataset(root='../datasets/pets', split='test', transform=test_tf),
        ),
        # download=False: downloading no longer works for StanfordCars; the flag is only kept for backward compatibility
        'cars': lambda: (
            dset.StanfordCars(root='../datasets/cars', split='train', transform=train_tf, download=False),
            dset.StanfordCars(root='../datasets/cars', split='test', transform=test_tf, download=False),
        ),
        'food': lambda: (
            dset.Food101(root='../datasets/food', split='train', transform=train_tf, download=True),
            dset.Food101(root='../datasets/food', split='test', transform=test_tf, download=True),
        ),
        # The HuggingFace 'valid' split serves as the Tiny ImageNet test set
        'tiny': lambda: (
            TinyImageNetDataset(root='../datasets/tiny', split='train', transform=train_tf),
            TinyImageNetDataset(root='../datasets/tiny', split='valid', transform=test_tf),
        ),
        'dermamnist': lambda: medmnist_pair(DermaMNIST, '../datasets/dermamnist'),
        'bloodmnist': lambda: medmnist_pair(BloodMNIST, '../datasets/bloodmnist'),
    }

    build = builders.get(dataset_name)
    if build is None:
        raise TypeError(f"Unknown dataset: {dataset_name}. Supported dataset names are cifar10, cifar100, pets, cars, food, tiny, dermamnist, bloodmnist.")
    train_data, test_data = build()

    logger.info(f'Dataset {dataset_name} loaded and processed!\n')
    return train_data, test_data

### EfficientNetV2 definition

In [8]:
class CustomEfficientNetV2(nn.Module):
    """torchvision EfficientNetV2 (S/M/L) whose final linear layer outputs ``num_classes`` logits.

    Args:
        size: model variant, one of 's', 'm', 'l'.
        num_classes: number of outputs of the replacement ``nn.Linear`` head.
        pretrained: if True, start from torchvision's ``IMAGENET1K_V1`` weights;
            otherwise build the network with ``weights=None``.

    Raises:
        ValueError: if ``size`` is not a supported variant.
    """

    _VALID_SIZES = ('s', 'm', 'l')

    def __init__(self, size, num_classes, pretrained=True):
        # Explicit check instead of ``assert``, which is stripped under ``python -O``
        if size not in self._VALID_SIZES:
            raise ValueError(f"Unknown EfficientNetV2 size: {size!r}. Supported sizes are 's', 'm', 'l'.")
        super().__init__()

        # Built after validation so an invalid size is rejected before touching torchvision
        variants = {
            's': (efficientnet_v2_s, EfficientNet_V2_S_Weights.IMAGENET1K_V1),
            'm': (efficientnet_v2_m, EfficientNet_V2_M_Weights.IMAGENET1K_V1),
            'l': (efficientnet_v2_l, EfficientNet_V2_L_Weights.IMAGENET1K_V1),
        }
        build_fn, imagenet_weights = variants[size]
        self.model = build_fn(weights=imagenet_weights if pretrained else None)

        # classifier[1] is the final nn.Linear of the torchvision head; replace it with a
        # freshly initialised layer of the same input width and num_classes outputs
        num_features = self.model.classifier[1].in_features
        self.model.classifier[1] = nn.Linear(num_features, num_classes)

    def forward(self, x):
        return self.model(x)

### Auxiliary functions for training

In [9]:
def learning_rate_scheduling(optimizer, scheduler_name=None, warmup=False, warmup_epochs=0, epochs=100):
    """Build the per-epoch LR scheduler: optional linear warmup followed by cosine annealing.

    Args:
        optimizer: optimizer whose learning rate is scheduled.
        scheduler_name: 'CosineAnnealingLR' or None. None disables scheduling entirely;
            any other name keeps only the warmup phase (if enabled).
        warmup: whether to prepend a linear warmup phase.
        warmup_epochs: warmup length; only used when ``warmup`` is True and it is positive.
        epochs: total number of scheduler steps, warmup included.

    Returns:
        A ``SequentialLR`` chaining the selected phases, or None if no phase was selected.
    """
    if scheduler_name is None:
        return None

    # Warmup only counts when enabled and non-empty (also avoids dividing by zero below)
    n_warmup = warmup_epochs if warmup and warmup_epochs > 0 else 0

    schedulers = []
    if n_warmup:
        # Linear ramp from ~0 up to the base LR; clamped at 1 so it never overshoots
        # if it keeps running past the warmup (warmup-only configuration).
        schedulers.append(LambdaLR(optimizer, lr_lambda=lambda epoch: min(1.0, epoch / n_warmup + 1e-5)))

    if scheduler_name == 'CosineAnnealingLR':
        # max(1, ...) keeps T_max positive; T_max == 0 would divide by zero in the cosine formula
        schedulers.append(CosineAnnealingLR(optimizer, T_max=max(1, epochs - n_warmup), eta_min=0))

    if not schedulers:
        return None

    # SequentialLR requires exactly len(schedulers) - 1 milestones
    milestones = [n_warmup] if len(schedulers) == 2 else []
    return SequentialLR(optimizer, schedulers=schedulers, milestones=milestones)

In [10]:
def train_epoch(
    model, 
    train_dataloader,
    dataset_name, 
    criterion, 
    optimizer, 
    scaler, 
    mixed_precision, 
    scheduler,
    device,
    cutmix_or_mixup
):
    """Run one training epoch, then step the per-epoch LR scheduler once.

    Args:
        model: network to train (put in train mode).
        train_dataloader: iterable of ``(inputs, labels)`` batches with a defined ``len()``.
        dataset_name: 'dermamnist'/'bloodmnist' labels of shape ``[batch, 1]`` are squeezed
            to ``[batch]`` unless CutMix/MixUp already did it in the collate function.
        criterion: loss applied to ``(outputs, labels)``.
        optimizer: optimizer stepped after every batch.
        scaler: gradient scaler; only used when mixed precision is active on CUDA.
        mixed_precision: enable CUDA autocast + gradient scaling (ignored on non-CUDA devices).
        scheduler: optional LR scheduler, stepped once at the end of the epoch.
        device: device the batches are moved to.
        cutmix_or_mixup: True when labels are soft ``[batch, num_classes]`` targets
            (CutMix/MixUp); accuracy is then measured against their argmax.

    Returns:
        ``(average_loss, accuracy, nan_encountered)``. On a NaN loss the epoch is aborted
        and ``(0, 0, True)`` is returned; that batch performs no weight update and the
        scheduler is not stepped.
    """
    model.train()
    # Autocast/GradScaler only act on CUDA; elsewhere run in full precision without the scaler
    use_amp = mixed_precision and torch.device(device).type == 'cuda'
    total_loss, correct_predictions, total_samples = 0.0, 0, 0

    for inputs, labels in train_dataloader:
        inputs, labels = inputs.float().to(device), labels.to(device)
        optimizer.zero_grad()

        # Squeeze (if not already done in collate_fn) MedMNIST labels from [batch, 1] to [batch]
        if not cutmix_or_mixup and dataset_name in ('dermamnist', 'bloodmnist') and labels.dim() == 2:
            labels = labels.squeeze(1)

        # Autocast wraps only the forward pass and loss; backward runs outside it (PyTorch AMP guidance)
        with torch.autocast(device_type='cuda', enabled=use_amp):
            outputs = model(inputs)
            loss = criterion(outputs, labels)

        if torch.isnan(loss).any():
            logger.info("==> Encountered NaN value in loss.")
            return 0, 0, True

        if use_amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        total_loss += loss.item()
        predictions = outputs.argmax(dim=1)
        # Soft CutMix/MixUp targets are compared against their dominant class
        targets = labels.argmax(dim=1) if cutmix_or_mixup else labels
        correct_predictions += (predictions == targets).sum().item()
        total_samples += targets.size(0)

    if scheduler is not None:
        scheduler.step()

    average_loss = total_loss / len(train_dataloader)
    accuracy = correct_predictions / total_samples
    return average_loss, accuracy, False

def test_epoch(
    model,
    test_dataloader,
    dataset_name,
    criterion,
    device
):
    model.eval()
    total_loss, correct_predictions, total_samples = 0.0, 0, 0
    
    with torch.no_grad():
        for inputs, labels in test_dataloader:
            inputs, labels = inputs.float().to(device), labels.to(device)

            # Squeeze only for dermamnist or bloodmnist: from shape [batch_size, 1] to shape [batch_size]
            if dataset_name in ['dermamnist', 'bloodmnist'] and labels.dim() == 2:
                labels = labels.squeeze(1)
            elif dataset_name in ['dermamnist', 'bloodmnist'] and labels.dim() != 2:
                raise ValueError(f"Medmnist labels have an unexpected shape: {labels.shape}")
            
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            total_loss += loss.item()
            _, int_predictions = torch.max(outputs, 1)
            correct_predictions += (int_predictions == labels).sum().item()
            total_samples += labels.size(0)
    
    average_loss = total_loss / len(test_dataloader)
    accuracy = correct_predictions / total_samples
    return average_loss, accuracy

### Training function

In [11]:
def train(
    model_size,
    dataset_name,
    image_size=(224, 224),
    epochs=50,
    batch_size=96,
    optimizer_name='AdamW',
    learning_rate=0.001,
    weight_decay=0.00005,
    warmup=False,
    warmup_epochs=0,
    mixed_precision=True,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    scheduler_name='CosineAnnealingLR',
    patience=30,
    horizontal_flip=True,
    huge_augment='aug_mix',
    random_crop=True,
    random_erasing=False,
    cutmix=False,
    mixup=True,
    label_smoothing=0.1,
):
    """Fine-tune an EfficientNetV2 on a real dataset and save the best checkpoint.

    Each epoch trains on the full training set and evaluates on the test set. The model
    with the highest test accuracy is kept, and its ``state_dict`` is written to
    ``<output_dir>/best_model.pth``. All progress is written to the module-level ``logger``.

    Notes:
        * Any ``optimizer_name`` other than 'AdamW' selects SGD with Nesterov momentum 0.9.
        * When ``warmup`` is True, ``warmup_epochs`` are added on top of ``epochs``.
        * Training stops early after ``patience`` consecutive epochs without a
          test-accuracy improvement.
        * ``cutmix``/``mixup`` are applied per batch inside the training collate function.
        * If an epoch hits a NaN loss it is retried once without mixed precision.

    Returns:
        dict with 'model' (best model), 'best_accuracy' (fraction in [0, 1]) and
        'history' (per-epoch metrics).
    """
    # Snapshot the arguments before any other local exists so that only parameters are logged
    training_params = dict(locals())
    logger.info("Training Parameters:\n")
    logger.info(f"seed: {seed}")
    for param_name, param_value in training_params.items():
        logger.info(f"{param_name}: {param_value}")
    logger.info("---------------------------------------------------------\n")

    start_time_script = time.time()

    train_dataset, test_dataset = get_real_dataset(
        dataset_name=dataset_name,
        image_size=image_size,
        horizontal_flip=horizontal_flip,
        huge_augment=huge_augment,
        random_crop=random_crop,
        random_erasing=random_erasing,
    )

    # The MedMNIST datasets have no "classes" attribute, so their class counts are hard-coded
    if dataset_name == 'dermamnist':
        num_classes = 7
    elif dataset_name == 'bloodmnist':
        num_classes = 8
    else:
        num_classes = len(train_dataset.classes)

    logger.info("Creating Train and Test Dataloaders...")
    loader_kwargs = dict(batch_size=batch_size, num_workers=4, pin_memory=True, prefetch_factor=2)
    cutmix_or_mixup = bool(cutmix or mixup)
    if cutmix_or_mixup:
        if cutmix and mixup:
            advanced_transform = v2.RandomChoice([v2.CutMix(num_classes=num_classes), v2.MixUp(num_classes=num_classes)])
        elif cutmix:
            advanced_transform = v2.CutMix(num_classes=num_classes)
        else:
            advanced_transform = v2.MixUp(num_classes=num_classes)

        # NOTE(review): a nested collate_fn cannot be pickled by DataLoader workers started
        # with the 'spawn' method (e.g. Windows/macOS); confirm if this ever runs there.
        def collate_fn(batch):
            # CutMix/MixUp expect 1-D integer labels, so MedMNIST's [batch, 1] labels are squeezed first
            images, labels = default_collate(batch)
            if dataset_name in ['dermamnist', 'bloodmnist'] and labels.dim() == 2:
                labels = labels.squeeze(1)
            elif dataset_name in ['dermamnist', 'bloodmnist'] and labels.dim() != 2:
                raise ValueError(f"Medmnist labels have an unexpected shape: {labels.shape}")
            return advanced_transform(images, labels)

        train_dataloader = DataLoader(train_dataset, shuffle=True, collate_fn=collate_fn, **loader_kwargs)
    else:
        train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)

    test_dataloader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)
    logger.info("Train and Test Dataloaders created!\n")

    logger.info("Creating the classifier...")
    model = CustomEfficientNetV2(size=model_size, num_classes=num_classes)
    model = model.to(device)
    logger.info("Classifier created!\n")

    # Scaler, loss, optimizer and learning-rate scheduler
    scaler = torch.cuda.amp.GradScaler()
    criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
    if optimizer_name == 'AdamW':
        optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    else:
        optimizer = optim.SGD(model.parameters(), lr=learning_rate, weight_decay=weight_decay, nesterov=True, momentum=0.9)
    if warmup:
        epochs += warmup_epochs
    scheduler = learning_rate_scheduling(optimizer, scheduler_name, warmup, warmup_epochs, epochs)

    best_accuracy = 0.0
    best_model = None
    no_improvement_epochs = 0
    history = {
        'epoch': [],
        'train_loss': [],
        'train_accuracy': [],
        'test_loss': [],
        'test_accuracy': [],
        'learning_rate': [],
        'epoch_time': []
    }

    logger.info("The training loop starts now!\n")
    for epoch in range(epochs):
        start_time = time.time()

        epoch_kwargs = dict(
            model=model,
            train_dataloader=train_dataloader,
            dataset_name=dataset_name,
            criterion=criterion,
            optimizer=optimizer,
            scaler=scaler,
            scheduler=scheduler,
            device=device,
            cutmix_or_mixup=cutmix_or_mixup,
        )
        train_loss, train_accuracy, nan_encountered = train_epoch(mixed_precision=mixed_precision, **epoch_kwargs)

        if nan_encountered:
            # Retry the epoch once in full precision (NaN losses are often an fp16 overflow symptom)
            logger.info("==> Retrying the epoch without mixed precision.")
            train_loss, train_accuracy, nan_encountered = train_epoch(mixed_precision=False, **epoch_kwargs)
            if nan_encountered:
                logger.info("==> NaN loss persisted without mixed precision; train metrics for this epoch are 0.")

        test_loss, test_accuracy = test_epoch(
            model=model,
            test_dataloader=test_dataloader,
            dataset_name=dataset_name,
            criterion=criterion,
            device=device
        )

        scheduler_last_lr = scheduler.get_last_lr()[0] if scheduler is not None else learning_rate
        epoch_time = time.time() - start_time

        history['epoch'].append(epoch + 1)
        history['train_loss'].append(train_loss)
        history['train_accuracy'].append(train_accuracy)
        history['test_loss'].append(test_loss)
        history['test_accuracy'].append(test_accuracy)
        history['learning_rate'].append(scheduler_last_lr)
        history['epoch_time'].append(epoch_time)

        logger.info(f"Epoch {epoch+1}/{epochs} - Time: {epoch_time:.2f}s - "
          f"Train Loss: {train_loss:.4f} - Train Acc: {train_accuracy:.4f} - "
          f"Test Loss: {test_loss:.4f} - Test Acc: {test_accuracy:.4f} - "
          f"LR: {scheduler_last_lr:.6f}")

        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            best_model = copy.deepcopy(model)
            no_improvement_epochs = 0
        else:
            no_improvement_epochs += 1
            if no_improvement_epochs >= patience:
                logger.info(f"Early stopping at epoch {epoch+1} due to no improvement in test accuracy for {patience} consecutive epochs.")
                break

    torch.cuda.empty_cache()

    # No epoch beat 0.0 accuracy (or epochs == 0): fall back to the current model so saving never hits None
    if best_model is None:
        best_model = model

    metadata = {
        'model': best_model,
        'best_accuracy': best_accuracy,
        'history': history
    }

    torch.save(best_model.state_dict(), os.path.join(output_dir, 'best_model.pth'))
    logger.info(f"Best accuracy: {best_accuracy * 100:.2f} %\n")

    elapsed_time_script = time.time() - start_time_script
    hours, minutes, seconds = int(elapsed_time_script // 3600), int((elapsed_time_script % 3600) // 60), int(elapsed_time_script % 60)
    logger.info(f"Execution completed in {hours}h {minutes}m {seconds}s.\n")

    return metadata


### Train!

In [None]:
try:
    train(
        model_size=model_size,
        dataset_name=dataset_name,
    )

except Exception as e:
    # logger.exception logs at ERROR level and appends the full traceback
    logger.exception(f"An error occurred during the training process: {e}")
    raise  # bare raise re-raises the active exception unchanged

finally:
    # Restart the Jupyter kernel (presumably to release GPU memory between runs).
    # Outside a kernel the application has no 'kernel' attribute; skip instead of
    # raising AttributeError here, which would mask the original error.
    kernel = getattr(IPython.Application.instance(), 'kernel', None)
    if kernel is not None:
        kernel.do_shutdown(restart=True)