In [None]:
import torchvision
import torch
import torchvision.transforms as transforms
import random
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision import transforms
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.nn as nn
import time
import os
import copy
import wandb
import json
import pandas as pd
import seaborn as sns


# Seeds for reproducibility
def set_seed(seed: int = 123):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch's CPU and CUDA
    generators, and stores the seed in the ``PYTHONHASHSEED`` environment variable.

    Args:
        seed: Value handed to every generator.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)
# Seed all generators before any model or data loader is created.
set_seed(seed=123)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda


In [None]:
""" PLOT SETTINGS """

plt.style.use("seaborn-v0_8")

# Global matplotlib defaults used by every figure created after this cell.
plt.rcParams["font.size"] = 18          # base font size
plt.rcParams["axes.titlesize"] = 24     # axis titles
plt.rcParams["axes.labelsize"] = 22     # axis labels
plt.rcParams["xtick.labelsize"] = 18    # X axis numbers
plt.rcParams["ytick.labelsize"] = 18    # Y axis numbers
plt.rcParams["legend.fontsize"] = 20    # legend text
plt.rcParams["lines.linewidth"] = 3.0   # line thickness


In [None]:
""" CLASS TO CREATE THE HEAD with 100 CLASSES """
class DINOWithHead(nn.Module):
    """Feature-extracting backbone followed by a linear classification head.

    The head is an ``nn.Sequential`` holding an optional ``nn.Dropout`` and one ``nn.Linear``
    layer that maps the backbone features to class logits.

    Args:
        backbone: Module whose output is fed to the head; it must produce tensors whose last
            dimension equals ``in_features``.
        num_classes: Number of output logits.
        p: Dropout probability applied to the features before the linear layer. ``None``
            (default) adds no dropout layer at all.
        in_features: Width of the backbone output. Defaults to 384, the value that used to be
            hard-coded here, so existing callers are unaffected.
    """

    def __init__(self, backbone, num_classes=100, p=None, in_features=384):
        super().__init__()
        self.backbone = backbone
        layers = []
        if p is not None:
            layers.append(nn.Dropout(p=p))
        layers.append(nn.Linear(in_features, num_classes))
        self.head = nn.Sequential(*layers)

    def forward(self, x):
        """Return the class logits for the input batch ``x``."""
        features = self.backbone(x)
        return self.head(features)

### Dataset Download and Transformations

In [None]:
""" DATASET DOWNLOAD """

ROOT = './data'
BATCH_SIZE = 64
#BATCH_SIZE = 128
# os.cpu_count() returns None when the number of CPUs cannot be determined, and DataLoader
# requires an integer; fall back to 0 (batches are then loaded in the main process).
NUM_WORKERS = os.cpu_count() or 0

# Only ToTensor() is applied at this stage: the raw [0, 1] tensors are what the normalisation
# statistics are computed from before the real transform pipelines are attached.
tot_train_data = torchvision.datasets.CIFAR100(root=ROOT, train=True, download=True, transform=torchvision.transforms.ToTensor())
test_data = torchvision.datasets.CIFAR100(root=ROOT, train=False, download=True, transform=torchvision.transforms.ToTensor())

In [None]:
""" SPLIT TOT_TRAININ in VALIDATION and TRAIN """

def split_dataset(tot_train_data, valid_ratio=0.8):
    """
    Randomly partition a dataset into a training subset and a validation subset.

    Despite its name, ``valid_ratio`` is the fraction of samples assigned to the TRAINING
    subset (rounded down); the validation subset receives all remaining samples.

    Returns:
        The pair ``(train_data, val_data)`` produced by ``random_split``.
    """
    n_total = len(tot_train_data)
    n_train = int(valid_ratio * n_total)
    train_subset, val_subset = random_split(tot_train_data, [n_train, n_total - n_train])
    return train_subset, val_subset

# 80% of the samples go to the training subset, the remaining 20% to the validation subset.
train_data, val_data = split_dataset(tot_train_data, 0.8)

In [None]:
""" DATA TRANSFORMATION """

def data_trasform(dataset, data_augmentation=False):   ### train_data or tot_train_data
    """
    Build the train and val/test transform pipelines, normalised with statistics of ``dataset``.

    ``dataset`` is iterated once to compute the per-channel mean and standard deviation over
    all of its pixels. It must therefore yield ``(image, label)`` pairs whose image is a
    3-channel tensor that has NOT been normalised yet (e.g. a dataset whose only transform is
    ``ToTensor()``); passing an already-normalised dataset yields meaningless statistics.

    The standard deviation is the true per-channel deviation over the whole dataset, obtained
    from the accumulated sum and sum of squares. (Averaging per-image standard deviations, as
    done previously, underestimates it because it ignores the variation between images.)

    Args:
        dataset: Dataset used for the statistics (training split only, or train+validation).
        data_augmentation: If True the training pipeline also applies random crop, colour
            jitter, horizontal flip, rotation and RandAugment; otherwise it only resizes and
            normalises.

    Returns:
        ``(train_transforms, val_test_transforms)``.

    Raises:
        ValueError: If ``dataset`` contains no samples.
    """

    # Per-channel running sums in float64 so that accumulating many pixels does not lose precision.
    channel_sum = torch.zeros(3, dtype=torch.float64)
    channel_sq_sum = torch.zeros(3, dtype=torch.float64)
    n_pixels = 0

    for img, _ in dataset:
        pixels = img.reshape(3, -1).to(torch.float64)  # flatten H*W into the second dimension
        channel_sum += pixels.sum(dim=1)
        channel_sq_sum += (pixels ** 2).sum(dim=1)
        n_pixels += pixels.shape[1]

    if n_pixels == 0:
        raise ValueError("Cannot compute normalisation statistics from an empty dataset.")

    mean64 = channel_sum / n_pixels
    # Var[x] = E[x^2] - E[x]^2; clamp guards against tiny negative values caused by rounding.
    std64 = (channel_sq_sum / n_pixels - mean64 ** 2).clamp_min(0).sqrt()
    mean, std = mean64.float(), std64.float()

    if data_augmentation:
        train_transforms = transforms.Compose([
            transforms.Resize(64, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.RandomCrop(64, padding=4),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(degrees=15),
            transforms.RandAugment(num_ops=2, magnitude=9),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)
        ])
    else:
        train_transforms = transforms.Compose([
            transforms.Resize(64, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)
        ])

    ### NO DATA AUGMENTATION for val/test!
    val_test_transforms = transforms.Compose([
        transforms.Resize(64),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)   # Normalization using the training statistics
    ])

    return train_transforms, val_test_transforms

In [None]:
""" DATA TRANSFORMATION and LOADERS """

### ===== For hyperparameter tuning considering train_data and val_data =====
# train_data and val_data are Subsets of ONE shared dataset object, so assigning
# `<subset>.dataset.transform` for one split silently changes the other split as well, and it
# would also make the statistics computed further down run on already-normalised images.
# Every split therefore gets its own dataset instance, and the statistics are always computed
# from a raw (ToTensor-only) instance.
train_indices = list(train_data.indices)
val_indices = list(val_data.indices)

raw_train_data = torchvision.datasets.CIFAR100(root=ROOT, train=True, download=False, transform=torchvision.transforms.ToTensor())

# Statistics from the training split only, so nothing leaks from the validation split.
train_transforms, val_test_transforms = data_trasform(torch.utils.data.Subset(raw_train_data, train_indices))

train_data = torch.utils.data.Subset(
    torchvision.datasets.CIFAR100(root=ROOT, train=True, download=False, transform=train_transforms),
    train_indices,
)
val_data = torch.utils.data.Subset(
    torchvision.datasets.CIFAR100(root=ROOT, train=True, download=False, transform=val_test_transforms),
    val_indices,
)

train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader   = DataLoader(val_data,   batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)



### ===== For model testing considering tot_train_data and test_data =====
# Statistics from the full, un-normalised training set.
train_transforms, val_test_transforms = data_trasform(raw_train_data)

tot_train_data = torchvision.datasets.CIFAR100(root=ROOT, train=True, download=False, transform=train_transforms)
test_data = torchvision.datasets.CIFAR100(root=ROOT, train=False, download=False, transform=val_test_transforms)

tot_train_loader = DataLoader(tot_train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
# Evaluation metrics do not depend on sample order, so the test loader is not shuffled.
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

### Test and Training Function

In [None]:
""" TRAINING and TESTING """

def evaluate_model(model, data_loader, criterion):
    """
    Compute the average loss and the accuracy of ``model`` over ``data_loader``.

    The model is moved to the GPU when one is available (CPU otherwise) and is left in
    evaluation mode when the function returns. Averages are taken over the number of samples
    actually yielded by the loader.

    Args:
        model: Module mapping a batch of inputs to class logits of shape (batch, classes).
        data_loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss returning the mean loss of a batch (it is re-weighted by batch size).

    Returns:
        ``(avg_loss, avg_acc)`` as Python floats.

    Raises:
        ValueError: If ``data_loader`` yields no samples.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    total_loss = 0.0
    total_corrects = 0
    total_samples = 0

    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            batch_size = inputs.size(0)
            # criterion returns a batch mean; weight it by batch size to get a per-sample average.
            total_loss += criterion(outputs, labels).item() * batch_size
            total_corrects += (preds == labels).sum().item()
            total_samples += batch_size

    if total_samples == 0:
        raise ValueError("evaluate_model received an empty data_loader.")

    return total_loss / total_samples, total_corrects / total_samples




def save_checkpoint(model, optimizer, scheduler, epoch, train_losses, train_accuracies,
                    val_test_losses, val_test_accuracies, best_acc, best_loss, best_model_wts, path):
    """
    Save the full training state to ``path``.

    The checkpoint stores the model, optimizer and (optional) scheduler state dicts, the metric
    histories, the best accuracy/loss so far and the best model weights. The parent directory is
    created when missing.

    The file is first written to ``<path>.tmp`` and then moved over ``path`` with ``os.replace``,
    so an interruption while writing cannot leave a truncated file in place of the previous
    checkpoint.

    Args:
        epoch: Epoch number stored in the checkpoint (the epoch training should resume from).
        scheduler: Learning-rate scheduler, or a falsy value to store ``None``.
        best_model_wts: State dict of the best model found so far.
        path: Destination file path.
    """
    dir_name = os.path.dirname(path)
    if dir_name:
        os.makedirs(dir_name, exist_ok=True)
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
        'train_losses': train_losses,
        'train_accuracies': train_accuracies,
        'val_test_losses': val_test_losses,
        'val_test_accuracies': val_test_accuracies,
        'best_acc': best_acc,
        'best_loss': best_loss,
        'best_model_state_dict': best_model_wts
    }

    tmp_path = f"{path}.tmp"
    try:
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        # Do not leave a partial temp file behind (also on KeyboardInterrupt), then re-raise.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise




def init_checkpoint(model, optimizer, scheduler, path=None, device='cpu'):
    """
    Create a fresh checkpoint or restore training state from an existing one.

    * ``path is None``: a default checkpoint (epoch 1, empty histories, current model/optimizer/
      scheduler state) is written to ``checkpoints/latest.pth`` through ``save_checkpoint``,
      overwriting any file already there.
    * ``path`` given: the file is loaded and its state dicts are restored in place into
      ``model``, ``optimizer`` and, when both exist, ``scheduler``.

    Args:
        path: Checkpoint file to resume from, or ``None`` to start from scratch.
        device: ``map_location`` used when loading the checkpoint.

    Returns:
        ``(start_epoch, best_acc, best_loss, train_losses, train_accuracies, val_test_losses,
        val_test_accuracies, path, best_model_wts)``.

    Raises:
        FileNotFoundError: If ``path`` is given but is not an existing file.
    """
    if path is None:  # default path, training starts from scratch
        path = "checkpoints/latest.pth"
        print(f"Initializing new checkpoint at {path}")
        # Reuse the regular writer so every checkpoint has the same layout.
        save_checkpoint(model, optimizer, scheduler, 1, [], [], [], [], 0.0, 1e10,
                        copy.deepcopy(model.state_dict()), path)
        return 1, 0.0, 1e10, [], [], [], [], path, copy.deepcopy(model.state_dict())

    # load existing checkpoint
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Checkpoint file {path} does not exist.")

    print(f"Loading checkpoint from {path}")
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if scheduler and checkpoint.get('scheduler_state_dict'):
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    # Only copy the current weights when the checkpoint carries no best-model entry.
    best_model_wts = checkpoint.get('best_model_state_dict')
    if best_model_wts is None:
        best_model_wts = copy.deepcopy(model.state_dict())

    return (checkpoint['epoch'],
            checkpoint.get('best_acc', 0.0),
            checkpoint.get('best_loss', 1e10),
            checkpoint.get('train_losses', []),
            checkpoint.get('train_accuracies', []),
            checkpoint.get('val_test_losses', []),
            checkpoint.get('val_test_accuracies', []),
            path,
            best_model_wts)




def _store_epoch_value(history, epoch, value):
    """Write ``value`` as the entry of 1-based ``epoch`` in ``history``, overwriting a stale entry or appending."""
    if len(history) >= epoch:   # entry left over from an epoch that was started but is being redone
        history[epoch - 1] = value
    else:
        history.append(value)


def train_test_model(model, criterion, optimizer, scheduler, train_loader, val_test_loader,
                          num_epochs=10, checkpoint_path=None, checkpoints = True, verbose = 1):
                                        # If checkpoint_path = None, a path is created and training starts from scratch
                                        # If checkpoints = False, we don't save anything (used for the calibration part)
    """
    Train ``model``, evaluate it after every epoch and return the best model with its metrics.

    Each epoch runs one pass over ``train_loader``, steps ``scheduler`` (if given) and evaluates
    on ``val_test_loader`` with ``evaluate_model``. The weights with the highest evaluation
    accuracy are kept and loaded back into ``model`` before returning.

    Args:
        criterion: Loss returning the mean loss of a batch (it is re-weighted by batch size).
        scheduler: Learning-rate scheduler stepped once per epoch, or a falsy value.
        num_epochs: Last epoch to run (epochs are numbered from 1).
        checkpoint_path: With ``checkpoints=True``: ``None`` starts from scratch and writes to a
            default path; an existing file resumes training from the state stored in it.
        checkpoints: If False nothing is read from or written to disk.
        verbose: ``True``/``1`` prints a header and the metrics every epoch; ``'mid'`` prints
            the metrics at epoch 1 and every 5th epoch; any other value prints only the final
            summary.

    Returns:
        ``(model, train_losses, val_test_losses, train_accuracies, val_test_accuracies)``.

    Raises:
        ValueError: If ``train_loader`` yields no samples.
    """

    since = time.time()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    log_every_epoch = (verbose == True)   # deliberately ==, so that verbose=1 also matches
    log_occasionally = (verbose == 'mid')

    # === Checkpoints (Initialize if None, Load if it exists already) ===
    if checkpoints: # Training with checkpoints
        (start_epoch, best_acc, best_loss, train_losses, train_accuracies,
         val_test_losses, val_test_accuracies, checkpoint_path, best_model_wts) = \
            init_checkpoint(model, optimizer, scheduler, path=checkpoint_path, device=device)

    else: # No checkpoint
        start_epoch = 1
        best_acc = 0.0
        best_loss = 1e10
        train_losses, train_accuracies, val_test_losses, val_test_accuracies = [], [], [], []
        best_model_wts = copy.deepcopy(model.state_dict())

    # ===== Epoch loop =====
    for epoch in range(start_epoch, num_epochs+1):
        if log_every_epoch:
            print(f'\nEpoch {epoch}/{num_epochs}')
            print('-' * 30)

        # ===== Training =====
        # NOTE(review): model.train() applies to the whole model, including a backbone the caller
        # may have put in eval mode beforehand — confirm this is intended for frozen backbones.
        model.train()
        train_loss = 0.0
        train_corrects = 0
        train_samples = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            batch_size = inputs.size(0)
            train_loss += loss.item() * batch_size
            train_corrects += (preds == labels).sum().item()
            train_samples += batch_size

        if train_samples == 0:
            raise ValueError("train_test_model received an empty train_loader.")

        if scheduler:
            scheduler.step()

        # Average over the samples actually seen in this epoch.
        epoch_train_loss = train_loss / train_samples
        epoch_train_acc = train_corrects / train_samples

        # Overwrites this epoch's entry when resuming an epoch that was never completed.
        _store_epoch_value(train_losses, epoch, epoch_train_loss)
        _store_epoch_value(train_accuracies, epoch, epoch_train_acc)

        if log_every_epoch:  # Print each round
            print(f'Train Loss: {epoch_train_loss:.4f} Train Acc: {epoch_train_acc:.4f}')

        # ===== Validation/Test =====
        epoch_val_test_loss, epoch_val_test_acc = evaluate_model(model, val_test_loader, criterion)

        _store_epoch_value(val_test_losses, epoch, epoch_val_test_loss)
        _store_epoch_value(val_test_accuracies, epoch, epoch_val_test_acc)

        if log_every_epoch:
            print(f'Val/Test Loss: {epoch_val_test_loss:.4f}, Val/Test Acc: {epoch_val_test_acc:.4f}')

        if log_occasionally and (epoch == 1 or epoch % 5 == 0):   # Print occasionally
            print(f"Epoch {epoch}")
            print(f'Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.4f}')
            print(f'Val/Test Loss: {epoch_val_test_loss:.4f}, Val/Test Acc: {epoch_val_test_acc:.4f}')

        # ===== Update and Save the best model =====
        if epoch_val_test_acc > best_acc:
            best_acc = epoch_val_test_acc
            best_loss = epoch_val_test_loss
            best_model_wts = copy.deepcopy(model.state_dict())

        # ===== Save checkpoint =====
        if checkpoints:
            save_checkpoint(
                model, optimizer, scheduler,
                epoch + 1,   # the epoch to resume from
                train_losses, train_accuracies,
                val_test_losses, val_test_accuracies,
                best_acc, best_loss, best_model_wts,
                checkpoint_path
            )

    # Training completed
    time_elapsed = time.time() - since
    print(f'\nTraining complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best Acc: {best_acc:.4f}, Best Loss: {best_loss:.4f}')

    # Load best model weights
    model.load_state_dict(best_model_wts)

    return model, train_losses, val_test_losses, train_accuracies, val_test_accuracies

# (1) CENTRALIZED MODEL - HEAD ONLY

### (1.I) Hyperparameter Selection

In [None]:
''' WANDB TRAINING '''

def wandb_train(config=None):
    """
    Run one hyperparameter trial: train the head on ``train_loader``, validate on ``val_loader``
    and log the per-epoch metrics to Weights & Biases.

    Reads ``lr``, ``momentum``, ``weight_decay``, ``t_max`` and ``epochs`` from the run config,
    plus ``dropout`` when the config defines it (no dropout layer is added otherwise).

    Args:
        config: Optional default config forwarded to ``wandb.init``.
    """
    with wandb.init(config=config):
        config = wandb.config

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # === MODEL ===
        dino_vits16 = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
        dino_vits16.eval()  # evaluation mode for the backbone
        dino_vits16 = dino_vits16.to(device)

        # Freeze backbone params
        for param in dino_vits16.parameters():
            param.requires_grad = False

        # Create the final model with trainable Head.
        # The swept dropout value is passed to the head so the parameter actually affects the trial.
        dropout_p = config.get("dropout")
        model = DINOWithHead(dino_vits16, num_classes=100, p=dropout_p).to(device)
        for param in model.head.parameters():
            param.requires_grad = True
        model.head.train()    # training mode for the head

        # === LOSS, OPTIMIZER, SCHEDULER ===
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(
              model.head.parameters(),   # only the head is trainable, same as in the final run
              lr=config.lr,
              momentum=config.momentum,
              weight_decay=config.weight_decay
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.t_max)

        # === TRAINING ===
        # checkpoints=False: a trial is never resumed, and writing checkpoints here would
        # overwrite the default checkpoint file used by the final training run.
        model, train_losses, val_losses, train_accs, val_accs = train_test_model(
            model,
            criterion,
            optimizer,
            scheduler,
            train_loader,
            val_loader,
            num_epochs=config.epochs,
            checkpoints=False,
            verbose='mid'
        )

        # === LOG ===
        # Iterate over what was actually recorded rather than over config.epochs.
        for epoch_idx in range(len(train_losses)):
            wandb.log({
                "epoch": epoch_idx + 1,
                "train_loss": train_losses[epoch_idx],
                "val_loss": val_losses[epoch_idx] if val_losses else None,
                "train_accuracy": train_accs[epoch_idx],
                "val_accuracy": val_accs[epoch_idx] if val_accs else None,
            })

        # === BEST LOG ===
        if val_accs:
            best_idx = val_accs.index(max(val_accs))
            wandb.run.summary["best_val_accuracy"] = val_accs[best_idx]
            wandb.run.summary["best_val_loss"] = val_losses[best_idx]
            wandb.run.summary["best_train_accuracy"] = train_accs[best_idx]
            wandb.run.summary["best_train_loss"] = train_losses[best_idx]

In [None]:
wandb.login()

# Search space of the sweep. Ranges spanning an order of magnitude are sampled on a log scale.
sweep_parameters = {
    'lr': dict(distribution='log_uniform_values', min=0.0001, max=0.001),
    'batch_size': dict(value=64),
    'momentum': dict(distribution='uniform', min=0.7, max=0.9),
    'dropout': dict(distribution='uniform', min=0.0, max=0.1),
    'weight_decay': dict(distribution='log_uniform_values', min=1e-6, max=1e-5),
    'epochs': dict(value=10),
    't_max': dict(value=10),   # same as n_epochs
}

sweep_config = {
    # Bayesian search: the next combination is proposed from the results of the previous trials.
    'method': 'bayes',
    'metric': {'name': 'val_loss', 'goal': 'minimize'},
    'parameters': sweep_parameters,
}

sweep_id = wandb.sweep(sweep_config, project="Project_Central_Grid")

[34m[1mwandb[0m: Currently logged in as: [33mgabriele_[0m ([33mgabriele-politecnico-di-torino[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Create sweep with ID: 0wz9r8de
Sweep URL: https://wandb.ai/gabriele-politecnico-di-torino/Project_Central_Grid/sweeps/0wz9r8de


In [None]:
# Run 10 trials of the sweep; each trial executes wandb_train.
wandb.agent(
    sweep_id,
    function=wandb_train,
    count=10,
)

### (1.II) Final Model (with best hyperparameters)


In [None]:
set_seed(seed=123)

# Hyperparameters of the final run.
N_EP = 40
T_MAX = N_EP            # cosine schedule spans the whole training
LR = 0.001
MOMENTUM = 0.8
WEIGHT_DECAY = 5e-6

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# === MODEL ===
dino_vits16 = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
dino_vits16.eval()                  # evaluation mode for the backbone
dino_vits16 = dino_vits16.to(device)
dino_vits16.requires_grad_(False)   # freeze every backbone parameter

# Final model: frozen backbone + trainable head
model_central = DINOWithHead(dino_vits16, num_classes=100).to(device)
model_central.head.requires_grad_(True)
model_central.head.train()          # training mode for the head

# === LOSS, OPTIMIZER, SCHEDULER ===
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    model_central.head.parameters(),   # only the head is optimized
    lr=LR,
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_MAX)

In [None]:
set_seed(seed=123)

# === MODEL TRAINING and TESTING with logging & checkpointing ===
# checkpoint_path = None                       -> start from scratch
# checkpoint_path = "checkpoints/latest.pth"   -> resume where training stopped and redo the
#                                                 last epoch if it was incomplete
checkpoint_path = None

start_time = time.time()
final_model, train_losses, test_losses, train_accuracies, test_accuracies = train_test_model(
    model_central, criterion, optimizer, scheduler,
    tot_train_loader, test_loader,
    num_epochs=N_EP, checkpoint_path=checkpoint_path,
)
end_time = time.time()
elapsed_time = end_time - start_time

In [None]:
# === SAVING RESULTS ===
# Per-epoch metrics and wall-clock time of the final run.
results = dict(
    epochs=N_EP,
    train_losses=train_losses,
    train_accuracies=train_accuracies,
    test_losses=test_losses,
    test_accuracies=test_accuracies,
    time_sec=round(elapsed_time, 2),
)
print(results)

json_filename = "results_central_BEST.json"

# Metrics go to JSON, the trained weights to a separate .pth file.
with open(json_filename, 'w') as results_file:
    results_file.write(json.dumps(results, indent=2))
torch.save(final_model.state_dict(), "final_model_weights_BEST.pth")

In [None]:
# === PLOTTING RESULTS ===
# The x axis follows the recorded history instead of N_EP, so the plot still works when the
# number of recorded epochs differs from N_EP (e.g. after resuming from a checkpoint).
epochs = range(1, len(train_losses) + 1)

# One figure per metric: (axis/title name, training curve, test curve).
curve_specs = [
    ("Loss", train_losses, test_losses),
    ("Accuracy", train_accuracies, test_accuracies),
]

for metric_name, train_curve, test_curve in curve_specs:
    plt.figure(figsize=(9,5))
    plt.plot(range(1, len(train_curve) + 1), train_curve, label=f"Training {metric_name}", marker="o")
    if test_curve is not None:
        plt.plot(range(1, len(test_curve) + 1), test_curve, label=f"Test {metric_name}", marker="s")
    plt.xlabel("Epochs")
    plt.ylabel(metric_name)
    plt.title(f"Training vs Test {metric_name}", fontsize=26, fontweight="bold")
    plt.legend()
    plt.grid(True, linestyle="--", alpha=0.6)
    plt.tight_layout()
    plt.show()