## EXERCISE 0

Download and extract the Python version of the CIFAR dataset from the CIFAR website.  
The dataset is structured as follows:

- **5 training batches** of 10,000 images each (50,000 images total)  
- **1 test batch** of 10,000 images  

Each batch file contains a **dictionary** with the following elements:

- **`data`** – a 10,000 × 3,072 NumPy array of `uint8`.  
  Each row represents a 32×32 color image:  
  - The first 1,024 entries are the red channel,  
  - The next 1,024 entries are the green channel,  
  - The final 1,024 entries are the blue channel.  
  Images are stored in **row-major order**, i.e., the first 32 entries correspond to the red values of the first row of the image.

- **`labels`** – a list of 10,000 integers in the range 0–9.  
  The number at index `i` indicates the label of the `i`th image in `data`.

The dataset also contains another file called **`batches.meta`**, which is a Python dictionary with the following entry:

- **`label_names`** – a 10-element list giving meaningful names to the numeric labels in the `labels` array described above.  
  For example:  
  ```python
  label_names[0] == "airplane"
  label_names[1] == "automobile"
  # etc.


## EXERCISE 1
Create a Dataset class to read the data. When initialized, this class should
take as arguments the path to the data, the transformation to be applied to each
image and if the dataset is train or test. If train you should load all
the 5 batches that composed the whole CIFAR training set. [2.0 pts]

The resulting class, `CIFAR10Dataset`, was designed to integrate seamlessly with the PyTorch framework, allowing full use of PyTorch tools for subsequent data processing and augmentation. 

In [None]:
import os
import pickle
import time
import json
import csv
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

from sklearn.metrics import (roc_auc_score, roc_curve, auc,
                             confusion_matrix, classification_report, accuracy_score)
from sklearn.preprocessing import label_binarize


class CIFAR10Dataset(Dataset):
    """
    CIFAR-10 dataset class compatible with PyTorch.

    Parameters
    ----------
    path : str
        Path to the CIFAR-10 batch files.
    data_type : str, default='train'
        Type of dataset to load. Must be 'train' or 'test'.
    transform : callable, optional
        A function/transform to apply to each image.

    Methods
    -------
    __len__()
        Returns the number of samples in the dataset.
    __getitem__(idx)
        Returns the image and label at the given index, applying the transform if specified.
    visualize(img, mean=(0.5,0.5,0.5), std=(0.5,0.5,0.5))
        Display a single image (tensor or NumPy array), un-normalizing if needed.
    """
    def __init__(self, path, data_type='train', transform=None):

        self.path = path
        self.transform = transform

        # Load training data
        if data_type == 'train':
            data_list = []
            labels_list = []
            #Iterate for each of the 5 training batches
            for i_batch in range(1, 6):
                file = os.path.join(path, f"data_batch_{i_batch}")
                with open(file, 'rb') as fo:
                    # Deserialize binary batch files
                    batch = pickle.load(fo, encoding='bytes')
                    # Reshape data to (num_images, channels, height, width)
                    imgs = np.reshape(batch[b'data'], (10000, 3, 32, 32))
                    data_list.append(imgs)
                    #Creates one single label list for training
                    labels_list.extend(batch[b'labels'])
            self.data = np.vstack(data_list)  # combines all batches of train in a 4D Numpy array (50000,3,32,32)
            self.labels = np.array(labels_list) # converts all labels of train in a 1D Numpy array (50000,)

        # Load test data
        elif data_type == 'test':
            file = os.path.join(path, "test_batch")
            with open(file, 'rb') as fo:
                # Deserialize binary batch file
                batch = pickle.load(fo, encoding='bytes')
                # Reshape images to (num_images, channels, height, width)
                self.data = np.reshape(batch[b'data'], (10000, 3, 32, 32))
                self.labels = np.array(batch[b'labels']) # converts all labels of test in a 1D Numpy array (50000,)
        else:
            raise ValueError("data_type must be 'train' or 'test'")

    # Return number of samples
    def __len__(self):
        return len(self.data)


    #Function overrides __getitem__ standard python method called when an object is indexed.
    #The function will now return one sample (image and label) by index when object is indexed.
    # Any specified transformation is applied on-the-fly for efficiency and to allow online augmentation.

    def __getitem__(self, idx):
        img = self.data[idx]       # single image in CHW format (3,32,32)
        label = self.labels[idx]   # corresponding label
        img = np.transpose(img, (1, 2, 0))  # convert to HWC format for transforms/visualization
        
        if self.transform:
            img = self.transform(img)    # apply optional transform;
        return img, label

    # Display a single image (tensor or NumPy array)
    def visualize(self, img, mean=(0.5,0.5,0.5), std=(0.5,0.5,0.5)):
        #Assuming data is transformed
        if isinstance(img, torch.Tensor):
            # Check if the image is already normalized otherwise un-normalize
            if img.min() < 0 or img.max() > 1:
                mean_t = torch.tensor(mean).view(3,1,1)
                std_t  = torch.tensor(std).view(3,1,1)
                img = img * std_t + mean_t  # Un-normalize
            npimg = img.numpy()
            plt.imshow(npimg)     
        else:
            plt.imshow(img)          # For NumPy arrays
        plt.axis('off')
        plt.show()



## Exercise 2
Build a CNN model to predict a class from the input image (you can use
the Conv2D module and one of the plenty pooling layers already implemented).
Which are the main hyperparameters you should set to build the main model?
Good practice is to build the model class as general as possible, and specify the
hyperparaemeters when the class is called. [2.0 pts]

The resulting class, `GeneralCNN`, was designed to integrate seamlessly with the PyTorch framework, allowing full use of PyTorch tools. Specifically, the class inherits from `nn.Module`, which enables automatic parameter management, supports custom forward passes, allows gradient computation via autograd, and ensures compatibility with optimizers and model saving/loading.


In [2]:
"""
MODULES in PyTorch are the fundamental building blocks of neural networks.
A module is any class that inherits from `nn.Module` and can encapsulate both
layers with learnable parameters (like `nn.Linear` or `nn.Conv2d` or personalized Moudles like GeneralCNN) 
and other submodules, forming a hierarchical structure.
Each module defines a `forward` method specifying how input data is transformed
into output, and all parameters registered within the module are automatically
tracked for gradient computation, optimization, and serialization.
"""

class GeneralCNN(nn.Module):
    """
    General CNN class compatible with PyTorch.

    Parameters
    ----------
    in_channels : int, default=3
        Number of input channels (e.g., 3 for RGB images).
    num_classes : int, default=10
        Number of output classes for classification.
    conv_layers : list of tuples, default=[(6,5), (16,5)]
        Each tuple defines a convolutional layer as (out_channels, kernel_size).
    pool_type : str, default='max'
        Type of pooling layer, either 'max' or 'avg'.
    pool_kernel : int, default=3
        Kernel size of the pooling layer.
    pool_stride : int, default=2
        Stride of the pooling layer.
    fc_layers : list of int, default=[120, 84]  
        Specifies the number of neurons in each fully connected layer 
        before the output layer and after the flattening.
    activation : str, default='relu'
        Activation function: 'relu', 'leaky_relu', or 'elu'.
    input_size : tuple, default=(32,32)
        Input image size as (height, width).

    Methods
    -------
    forward(x)
        Computes the forward pass of the network.
    """
    def __init__(self,
                 in_channels=3,
                 num_classes=10,
                 conv_layers=[(6, 5), (16, 5)],
                 pool_type='max',
                 pool_kernel=3,
                 pool_stride=2,
                 fc_layers=[120, 84],
                 activation='relu',
                 input_size=(32, 32)):
        
        #Call the init method of the nn.Module class
        super(GeneralCNN, self).__init__()

        # Select activation function
        if activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'leaky_relu':
            self.activation = nn.LeakyReLU()
        elif activation == 'elu':
            self.activation = nn.ELU()
        else:
            raise ValueError("Unsupported activation")

        # Build convolutional layers
        self.layers = nn.ModuleList() #Create an object list of modules for the layers
        current_channels = in_channels
        for out_channels, kernel_size in conv_layers:
            conv = nn.Conv2d(current_channels, out_channels, kernel_size=kernel_size)
            self.layers.append(conv) #use the append of the nn.Modulelist class, to append the module
            current_channels = out_channels

        # Define pooling layer
        if pool_type == 'max':
            self.pool = nn.MaxPool2d(kernel_size=pool_kernel, stride=pool_stride)
        elif pool_type == 'avg':
            self.pool = nn.AvgPool2d(kernel_size=pool_kernel, stride=pool_stride)
        else:
            raise ValueError("pool_type must be 'max' or 'avg'")

        self.fc_config = fc_layers
        self.num_classes = num_classes

        # Compute flattened size after conv + pool layers using a dummy input
        # Specifically, `torch.no_grad()` is used to avoid tracking gradients since no
        # backward pass or parameter update is needed during this computation.
        with torch.no_grad():
            # Creates a zero tensor of size [1, C, H, W] (1 image, C channels, H height, W width)
            # This dummy input is used to compute the flattened size after conv + pool layers;
            # the first value represents the batch size and can be any number selected
            dummy = torch.zeros(1, in_channels, input_size[0], input_size[1])
            x = dummy
            for conv in self.layers:
                x = conv(x)
                x = self.activation(x)
                if self.pool is not None:
                    x = self.pool(x)
            # This will flatten all dimensions except the batch and return the total number of features per image
            flattened_size = torch.flatten(x, start_dim=1).shape[1]

        # Build fully connected layers
        self.fc_layers = nn.ModuleList() #Create an object list of modules for the fully connected
        input_size_fc = flattened_size
        for units in self.fc_config:
            self.fc_layers.append(nn.Linear(input_size_fc, units))
            input_size_fc = units
        self.fc_layers.append(nn.Linear(input_size_fc, self.num_classes))

    def forward(self, x):
        """
        Forward pass through the network.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, in_channels, height, width).

        Returns
        -------
        torch.Tensor
            Output logits of shape (batch_size, num_classes). 
            Logits are the raw, unnormalized scores produced by the network for each class.
            They are not probabilities yet. To convert logits to probabilities, the torch.softmax() 
            function can be used.
        """
        for conv in self.layers:
            x = conv(x)
            x = self.activation(x)
            if self.pool is not None:
                x = self.pool(x)
        x = torch.flatten(x, 1)  # Flatten for fully connected layers
        for i, fc in enumerate(self.fc_layers):
            if i < len(self.fc_layers) - 1:
                x = self.activation(fc(x))
            else:
                x = fc(x)
        return x


In [3]:
"""
    # --- ADDED METRICS ---

What changed / added:
 - A helper function `metrics_from_confusion_matrix(cm, class_names=None)` that
   computes TP, FP, FN, TN and derived metrics (precision, recall/sensitivity,
   specificity, F1, support) per class and macro averages.
   (Marked "# --- ADDED METRICS ---" where defined and where used.)
 - The run_experiment() routine now calls this helper for both train and test
   confusion matrices, prints the per-class metrics and saves them as CSV.
   (Marked "# --- ADDED METRICS ---" where results are printed/saved.)
 - The script preserves previous behavior: timing, plotting and saving confusions,
   train-vs-test confusion difference, and two experiments (normalized vs plain).

Run: python train_cifar_confusion_metrics.py
"""


# ----------------------------
# --- ADDED METRICS ---
# Helper: compute per-class metrics from confusion matrix (one-vs-all)
# This follows the approach this requested article:
# https://towardsdatascience.com/multi-class-classification-extracting-performance-metrics-from-the-confusion-matrix-b379b427a872/
# ----------------------------
def metrics_from_confusion_matrix(cm, class_names=None):
    """
    Given a square confusion matrix cm (actual rows, predicted cols),
    compute per-class TP, FP, FN, TN and derived metrics:
      precision, recall (sensitivity), specificity, f1, support

    Returns:
      metrics_per_class: dict[class_name or index] -> dict of metrics
      summary: dict with macro-averages and overall accuracy
    """
    cm = np.array(cm, dtype=np.int64)
    n_classes = cm.shape[0]
    total = cm.sum()
    diag = np.diag(cm)
    metrics_per_class = {}
    eps = 1e-12

    for i in range(n_classes):
        TP = int(cm[i, i])
        FP = int(cm[:, i].sum() - TP)
        FN = int(cm[i, :].sum() - TP)
        TN = int(total - TP - FP - FN)

        precision = TP / (TP + FP + eps)
        recall = TP / (TP + FN + eps)  # sensitivity
        specificity = TN / (TN + FP + eps)
        f1 = 2 * precision * recall / (precision + recall + eps)
        support = int(cm[i, :].sum())

        name = class_names[i] if class_names is not None else str(i)
        metrics_per_class[name] = {
            'TP': TP, 'FP': FP, 'FN': FN, 'TN': TN,
            'precision': precision, 'recall': recall,
            'specificity': specificity, 'f1': f1, 'support': support
        }

    # Overall metrics
    accuracy = diag.sum() / (total + eps)
    # macro averages
    macro_precision = np.mean([m['precision'] for m in metrics_per_class.values()])
    macro_recall = np.mean([m['recall'] for m in metrics_per_class.values()])
    macro_f1 = np.mean([m['f1'] for m in metrics_per_class.values()])

    summary = {
        'accuracy': accuracy,
        'macro_precision': macro_precision,
        'macro_recall': macro_recall,
        'macro_f1': macro_f1,
        'total': int(total)
    }

    return metrics_per_class, summary
# ----------------------------
# --- END ADDED METRICS ---
# ----------------------------

In [8]:
'''
This class is reponsable for performing to expose the method for 
performing the training of the model, to evaluate it and to save the metrics
'''


class Experiment:
    def __init__(self, exp_name, transform, data_path='data', batch_size=128, num_epochs=50,
                 device=None, num_workers=0, pin_memory=False, class_names=None):
        self.exp_name = exp_name
        self.transform = transform
        self.data_path = data_path
        self.batch_size = batch_size
        self.num_epochs = num_epochs
        self.device = device
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.class_names = class_names


        # Instanciate datasets & loaders
        self.train_dataset = CIFAR10Dataset(path=data_path, data_type='train', transform=transform)
        self.test_dataset  = CIFAR10Dataset(path=data_path, data_type='test', transform=transform)
        self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True,
                                       num_workers=num_workers, pin_memory=pin_memory)
        self.eval_train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=False,
                                   num_workers=num_workers, pin_memory=pin_memory)
        self.test_loader = DataLoader(self.test_dataset, batch_size=batch_size, shuffle=False,
                                      num_workers=num_workers, pin_memory=pin_memory)
        
        # Visualize first image, to see if data was loaded correctly 
        try:
            img0, label0 = self.train_dataset[0]
            print(f"Sample visualization (exp={exp_name}) - Label: {label0}")
            # If the transform includes Normalize((0.5,...),(0.5,...)) this will unnormalize inside visualize
            # We cannot automatically detect the Normalize params reliably here, so we assume the common .5/.5 default.
            self.train_dataset.visualize(img0, mean=(0.5,0.5,0.5), std=(0.5,0.5,0.5))
        except Exception as e:
            print("Could not visualize sample image:", e)

        # Instanciate model, criterion and optimizer
        self.model = GeneralCNN()
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)
        # model.parameters() returns an iterator over all learnable parameters (weights and biases).
        # The optimizer updates these parameters by exploiting the iteratior inside the class during training based on the gradients

    def train_epoch(self):
        """Train the model for a single epoch and return the average epoch loss."""
        self.model.train()                              # Enable training mode (dropout, BN updates)
        running_loss = 0.0                              # Accumulate batch losses over the epoch

        for images, labels in self.train_loader:        # Iterate over training batches
            images = images.to(self.device, non_blocking=(self.num_workers > 0))   # Move images to device
            labels = labels.to(self.device, non_blocking=(self.num_workers > 0))   # Move labels to device

            self.optimizer.zero_grad()                  # Clear gradients from previous iteration
            outputs = self.model(images)                # Forward pass: compute predictions
            loss = self.criterion(outputs, labels)      # Compute loss
            loss.backward()                             # Backward pass: compute gradients
            self.optimizer.step()                       # Update model parameters

            running_loss += loss.item()                 # Accumulate loss for this batch

        return running_loss / len(self.train_loader)    # Return average epoch loss



    def train(self):
        """Train the model for all epochs, tracking loss and timing."""
        print(f"\n=== Starting training for experiment: {self.exp_name} ===")
        print()
        total_params = sum(p.numel() for p in self.model.parameters())
        opt_params = sum(p.numel() for g in self.optimizer.param_groups for p in g['params'])
        print(f"Total model params: {total_params}, Optimizer params: {opt_params}")
        assert total_params == opt_params, "Parameter count mismatch: optimizer may not include all model parameters"
        train_losses = []                               # Store average loss per epoch
        epoch_times = []                                # Store duration of each epoch

        total_start = time.perf_counter()               # Start total training timer

        for epoch in range(self.num_epochs):            # Loop over epochs
            epoch_start = time.perf_counter()           # Start timer for this epoch

            epoch_loss = self.train_epoch()             # Train for one epoch

            epoch_end = time.perf_counter()             # End timer
            epoch_duration = epoch_end - epoch_start    # Compute epoch duration

            train_losses.append(epoch_loss)             # Save epoch loss
            epoch_times.append(epoch_duration)          # Save epoch time

            print(f"Epoch [{epoch+1}/{self.num_epochs}] - "
                f"Avg Loss: {epoch_loss:.4f} - Time: {epoch_duration:.2f}s")

        total_end = time.perf_counter()                 # End total training timer
        total_training_time = total_end - total_start   # Compute total training time

        print(f"Total training time for {self.exp_name}: {total_training_time:.2f}s")

        # Store metrics for later use (plots, logs, etc.)
        self.train_losses = train_losses
        self.epoch_times = epoch_times

        return train_losses, epoch_times



    def eval_train(self):
        self.model.eval()  # Set model to evaluation mode (disables dropout, batchnorm updates)
        all_probs, all_labels = [], []  # Lists to store batch-wise probabilities and labels

        with torch.no_grad():  # Disable gradient computation for efficiency
            for images, labels in self.eval_train_loader:
                images = images.to(self.device)  # Move batch to device (CPU/GPU)
                outputs = self.model(images)     # Forward pass, outputs are logits of shape (batch_size, num_classes)
                probs = torch.softmax(outputs, dim=1)  # Convert logits to probabilities
                all_probs.append(probs.cpu())    # Move to CPU and collect
                all_labels.append(labels.cpu())  # Move to CPU and collect

        # Concatenate all batches along the first dimension
        self.all_probs_train = torch.cat(all_probs, dim=0).numpy()  # Shape: (N_train, C)
        self.all_labels_train = torch.cat(all_labels, dim=0).numpy() # Shape: (N_train,)

        # Return full probability and label tensors
        return self.all_probs_train, self.all_labels_train
        # For example, CIFAR-10: N_train=50000, C=10 -> all_probs_train: (50000, 10), all_labels_train: (50000,)

    def eval_test(self):
        self.model.eval()  # Set model to evaluation mode
        all_probs, all_labels = [], []  # Lists to store batch-wise probabilities and labels

        with torch.no_grad():
            for images, labels in self.test_loader:
                images = images.to(self.device)  # Move batch to device
                outputs = self.model(images)      # Forward pass, logits (batch_size, num_classes)
                probs = torch.softmax(outputs, dim=1)  # Convert logits to probabilities
                all_probs.append(probs.cpu())     # Collect probabilities
                all_labels.append(labels.cpu())   # Collect labels

        # Concatenate all batches
        self.all_probs_test = torch.cat(all_probs, dim=0).numpy()   # Shape: (N_test, C)
        self.all_labels_test = torch.cat(all_labels, dim=0).numpy() # Shape: (N_test,)

        # Return full probability and label tensors
        return self.all_probs_test, self.all_labels_test
        # For example, CIFAR-10: N_test=10000, C=10 -> all_probs_test: (10000, 10), all_labels_test: (10000,)



    def compute_metrics(self, prob, labels):
        #---------------------
        # ACCURACY computation
        #---------------------
        pred = prob.argmax(axis=1)  # Predicted class indices
        accuracy = accuracy_score(labels, pred)
        report = classification_report(labels, pred, digits=4)

        #---------------------
        # ROC & AUC computation
        #---------------------
        n_classes = len(self.class_names) if self.class_names is not None else 10
        all_labels_bin = label_binarize(labels, classes=np.arange(n_classes))

        roc_curves = {}  # Store FPR, TPR, AUC for each class
        try:
            for i in range(n_classes):
                fpr, tpr, _ = roc_curve(all_labels_bin[:, i], prob[:, i])
                auc_value = auc(fpr, tpr)
                roc_curves[i] = {'fpr': fpr, 'tpr': tpr, 'auc': auc_value}
            roc_auc_overall = roc_auc_score(all_labels_bin, prob, multi_class='ovr')
        except Exception:
            roc_curves = None
            roc_auc_overall = float('nan')

        #---------------------
        # CONFUSION MATRICES
        #---------------------
        cm = confusion_matrix(labels, pred)

        # --- PER-CLASS METRICS ---
        metrics_per_class, metrics_summary = metrics_from_confusion_matrix(cm, self.class_names)

        # Return all metrics including ROC curves
        return accuracy, report, roc_auc_overall, roc_curves, cm, metrics_per_class, metrics_summary

            

In [14]:

# Function to plot and save confusion matrix
def plot_cm(cm, title, path, class_names=None):
    plt.figure(figsize=(8, 6))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title)
    plt.colorbar()
    plt.xlabel('Predicted')
    plt.ylabel('True')
    thresh = cm.max() / 2. if cm.max() > 0 else 0
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(j, i, format(int(cm[i, j]), 'd'),
                        horizontalalignment="center",
                        color="white" if cm[i, j] > thresh else "black")
    if class_names is not None:
        plt.xticks(np.arange(len(class_names)), class_names, rotation=45, ha='right')
        plt.yticks(np.arange(len(class_names)), class_names)
    plt.tight_layout()
    plt.savefig(path)
    plt.close()
    print(f"Saved confusion matrix to {path}")


# Save per-class metrics as CSVs for easy inspection
def save_metrics_csv(metrics_per_class, summary, csv_path):
    # metrics_per_class is dict[class] -> dict(metrics)
    with open(csv_path, 'w', newline='') as csvfile:
        fieldnames = ['class', 'TP', 'FP', 'FN', 'TN', 'precision', 'recall', 'specificity', 'f1', 'support']
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        for cls, metrics in metrics_per_class.items():
            row = {'class': cls}
            row.update({k: metrics[k] for k in ['TP', 'FP', 'FN', 'TN', 'precision', 'recall', 'specificity', 'f1', 'support']})
            writer.writerow(row)
        # write summary as final rows
        writer.writerow({})
        writer.writerow({'class': 'SUMMARY', 'TP': '', 'FP': '', 'FN': '', 'TN': '', 'precision': summary['macro_precision'],
                            'recall': summary['macro_recall'], 'specificity': '', 'f1': summary['macro_f1'], 'support': summary['total']})
        

import matplotlib.pyplot as plt
import os

def plot_roc_curves(roc_curves, class_names=None, title="ROC Curves", save_path=None):
    """
    Plot multi-class ROC curves from a dict of {class_idx: {'fpr':..., 'tpr':..., 'auc':...}}.
    """
    if roc_curves is None:
        print("No ROC curves to plot.")
        return

    plt.figure(figsize=(10, 8))
    
    for class_idx, data in roc_curves.items():
        fpr = data['fpr']
        tpr = data['tpr']
        auc_value = data['auc']
        name = class_names[class_idx] if class_names is not None else str(class_idx)
        plt.plot(fpr, tpr, label=f'{name} (AUC = {auc_value:.2f})')

    plt.plot([0, 1], [0, 1], 'k--', label='Random')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(title)
    plt.legend(loc='lower right')
    plt.grid(True)
    plt.tight_layout()
    
    if save_path is not None:
        plt.savefig(save_path)
        plt.close()
        print(f"Saved ROC curves to {save_path}")
    else:
        plt.show()


def  plot_loss(exp_name,num_epochs,train_losses, plots_dir):
    plt.figure(figsize=(8, 6))
    plt.plot(range(1, num_epochs + 1), train_losses, marker='o')
    plt.xlabel('Epoch')
    plt.ylabel('Training Loss')
    plt.title(f'Training Loss over Epochs ({exp_name})')
    plt.grid(True)
    plt.tight_layout()

    # Save inside the experiment's plot directory
    loss_path = os.path.join(plots_dir, "loss.png")
    plt.savefig(loss_path)
    plt.close()

    print(f"Saved loss plot to {loss_path}")



#----------------
#     MAIN
#-----------------
if __name__ == "__main__":
    import torch, os, pickle, csv
    import numpy as np
    import matplotlib.pyplot as plt
    from torchvision import transforms
    from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, roc_curve, auc
    from sklearn.preprocessing import label_binarize

    # ----------------------------
    # Seeds and device
    # ----------------------------
    torch.manual_seed(42)
    np.random.seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)

    # ----------------------------
    # Data paths and parameters
    # ----------------------------
    data_path = "data"
    num_workers = 0
    pin_memory = True if device.type == 'cuda' else False
    batch_size = 128
    num_epochs = 2

    transform_norm = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    transform_plain = transforms.Compose([transforms.ToTensor()])

    # ----------------------------
    # Load class names
    # ----------------------------
    # label_names_file = os.path.join(data_path, 'batches.meta')
    # with open(label_names_file, 'rb') as fo:
    #     labels_name = pickle.load(fo, encoding='bytes')
    #     labels_name=list(labels_name.keys())


    labels_name = ['airplane', 'automobile', 'bird', 'cat',
                   'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

    n_classes = len(labels_name)

    # ----------------------------
    # Instantiate experiments
    # ----------------------------
    res_norm = Experiment("normalized", transform_norm,
                          data_path=data_path,
                          batch_size=batch_size,
                          num_epochs=num_epochs,
                          device=device,
                          num_workers=num_workers,
                          pin_memory=pin_memory,
                          class_names=labels_name)

    res_plain = Experiment("not_normalized", transform_plain,
                           data_path=data_path,
                           batch_size=batch_size,
                           num_epochs=num_epochs,
                           device=device,
                           num_workers=num_workers,
                           pin_memory=pin_memory,
                           class_names=labels_name)

    results = []

    # ----------------------------
    # Run experiments
    # ----------------------------
    for ex in [res_norm, res_plain]:
        print(f"\n=== Running experiment: {ex.exp_name} ===")

        plots_dir = os.path.join("plots", ex.exp_name)
        os.makedirs(plots_dir, exist_ok=True)

        # Train and plot loss
        ex.train()
        plot_loss(ex.exp_name,ex.num_epochs,ex.train_losses, plots_dir)


        # Evaluate
        prob_train, labels_train = ex.eval_train()
        prob_test, labels_test = ex.eval_test()

        # Compute metrics
        acc_train, report_train, roc_auc_train, roc_curves_train, cm_train, metrics_train_per_class, metrics_train_summary = ex.compute_metrics(prob_train, labels_train)
        acc_test, report_test, roc_auc_test, roc_curves_test, cm_test, metrics_test_per_class, metrics_test_summary = ex.compute_metrics(prob_test, labels_test)


        # Plot ROC curves for train
        plot_roc_curves(
            roc_curves_train,
            class_names=ex.class_names,
            title=f'Multi-class ROC Curves (Train) - {ex.exp_name}',
            save_path=os.path.join(plots_dir, "roc_curves_train.png")
        )

        # Plot ROC curves for test
        plot_roc_curves(
            roc_curves_test,
            class_names=ex.class_names,
            title=f'Multi-class ROC Curves (Test) - {ex.exp_name}',
            save_path=os.path.join(plots_dir, "roc_curves_test.png")
        )

        # Plot confusion matrices
        plot_cm(cm_train, title=f'Confusion Matrix (Train) - {ex.exp_name}',
                path=os.path.join(plots_dir, "confusion_train.png"), class_names=ex.class_names)
        plot_cm(cm_test, title=f'Confusion Matrix (Test) - {ex.exp_name}',
                path=os.path.join(plots_dir, "confusion_test.png"), class_names=ex.class_names)
        plot_cm(cm_train - cm_test, title=f'Confusion Matrix (Train - Test) - {ex.exp_name}',
                path=os.path.join(plots_dir, "confusion_diff.png"), class_names=ex.class_names)

        # Save per-class metrics
        save_metrics_csv(metrics_train_per_class, metrics_train_summary, os.path.join(plots_dir, "metrics_train.csv"))
        save_metrics_csv(metrics_test_per_class, metrics_test_summary, os.path.join(plots_dir, "metrics_test.csv"))

        # Print summaries
        print(f"\n--- Summary for {ex.exp_name} ---")
        print(f"Train accuracy: {acc_train:.4f}, Test accuracy: {acc_test:.4f}, Test ROC AUC: {roc_auc_test:.4f}")
        print(f"Total training time: {sum(ex.epoch_times):.2f}s, Epoch times: {[round(t,2) for t in ex.epoch_times]}")

        # Store results
        results.append({
            'exp_name': ex.exp_name,
            'train_losses': ex.train_losses,
            'epoch_times': ex.epoch_times,
            'total_time': sum(ex.epoch_times),
            'accuracy_train': acc_train,
            'accuracy_test': acc_test,
            'report_train': report_train,
            'report_test': report_test,
            'roc_auc_test': roc_auc_test,
            'confusion_train': cm_train,
            'confusion_test': cm_test,
            'confusion_diff': cm_train - cm_test,
            'metrics_test_per_class': metrics_test_per_class,
            'metrics_train_per_class': metrics_train_per_class,
            'plots_dir': plots_dir
        })

    # ----------------------------
    # High-level comparison summary
    # ----------------------------
    summary_path = os.path.join("plots", "comparison_summary.txt")
    os.makedirs("plots", exist_ok=True)
    with open(summary_path, 'w') as f:
        f.write("Comparison: normalized vs not-normalized\n\n")
        for r in results:
            f.write(f"{r['exp_name']} results:\n")
            f.write(f"  Train acc: {r['accuracy_train']:.4f}\n")
            f.write(f"  Test  acc: {r['accuracy_test']:.4f}\n")
            f.write(f"  Total training time: {r['total_time']:.2f}s\n\n")

    print(f"\nSaved comparison summary to {summary_path}")
    print("Done.")


Device: cpu
Sample visualization (exp=normalized) - Label: 6
Could not visualize sample image: cannot access local variable 't' where it is not associated with a value
Sample visualization (exp=not_normalized) - Label: 6
Could not visualize sample image: cannot access local variable 't' where it is not associated with a value

=== Running experiment: normalized ===

=== Starting training for experiment: normalized ===

Total model params: 44726, Optimizer params: 44726
Epoch [1/2] - Avg Loss: 2.2986 - Time: 9.08s
Epoch [2/2] - Avg Loss: 2.2694 - Time: 8.40s
Total training time for normalized: 17.48s
Saved loss plot to plots/normalized/loss.png


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Saved ROC curves to plots/normalized/roc_curves_train.png
Saved ROC curves to plots/normalized/roc_curves_test.png
Saved confusion matrix to plots/normalized/confusion_train.png
Saved confusion matrix to plots/normalized/confusion_test.png
Saved confusion matrix to plots/normalized/confusion_diff.png

--- Summary for normalized ---
Train accuracy: 0.1852, Test accuracy: 0.1860, Test ROC AUC: 0.7046
Total training time: 17.48s, Epoch times: [9.08, 8.4]

=== Running experiment: not_normalized ===

=== Starting training for experiment: not_normalized ===

Total model params: 44726, Optimizer params: 44726
Epoch [1/2] - Avg Loss: 2.3030 - Time: 10.10s
Epoch [2/2] - Avg Loss: 2.2991 - Time: 6.35s
Total training time for not_normalized: 16.44s
Saved loss plot to plots/not_normalized/loss.png


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Saved ROC curves to plots/not_normalized/roc_curves_train.png
Saved ROC curves to plots/not_normalized/roc_curves_test.png
Saved confusion matrix to plots/not_normalized/confusion_train.png
Saved confusion matrix to plots/not_normalized/confusion_test.png
Saved confusion matrix to plots/not_normalized/confusion_diff.png

--- Summary for not_normalized ---
Train accuracy: 0.1373, Test accuracy: 0.1328, Test ROC AUC: 0.6387
Total training time: 16.44s, Epoch times: [10.1, 6.35]

Saved comparison summary to plots/comparison_summary.txt
Done.
