# **Project 9: Advanced Out-of-Distribution Detection for Multi-Class Classification**   

## Project Overview

The goal of this project is to address the challenge of out-of-distribution (OOD) detection in deep learning, particularly in the context of multi-class image classification. In real-world scenarios, models often encounter data that differ significantly from their training distribution, which can lead to unreliable or unsafe predictions. Therefore, detecting and properly handling OOD inputs is crucial for building robust and trustworthy AI systems.

The first step involves training a deep neural network classifier on the **Food-101** dataset, which serves as the in-distribution (ID) data. This dataset includes high-dimensional images with complex visual features, making it a suitable benchmark for evaluating model performance in realistic settings.

Once the classifier is trained, we will implement one or more OOD detection methods. The **SVHN (Street View House Numbers)** dataset will be used as the primary OOD source, providing visually distinct samples that test the model’s ability to separate known from unknown inputs. Techniques such as energy-based or gradient-based scoring can be applied to detect OOD samples effectively.

Additionally, the project can be extended by incorporating other OOD datasets to further assess the generalizability of the detection framework. The final objective is to evaluate the model’s capability to distinguish between in-distribution and out-of-distribution examples using appropriate metrics, such as AUROC, AUPR, and FPR@95.

# Imports

In this section, we import all the necessary libraries required for building, training, and evaluating the model, as well as for implementing OOD detection. These include:

- **Torch / torchvision** for model definition, training, and dataset loading
- **NumPy / Pandas** for data manipulation
- **Matplotlib** for visualization
- **Scikit-learn** for computing evaluation metrics
- **Custom utility functions** (if any) for training loops, loss functions, and OOD scoring

In [None]:
!pip install tqdm

# Core
import os
import time
import copy
import random
from collections import defaultdict
from itertools import cycle, islice


import numpy as np
from scipy.io import loadmat
from PIL import Image

# Visualization
import matplotlib.pyplot as plt

# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, ConcatDataset, Dataset, random_split
from torch.optim.lr_scheduler import (
    ReduceLROnPlateau, StepLR, CosineAnnealingLR, CosineAnnealingWarmRestarts
)

# torchvision
from torchvision import datasets, transforms, models
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.datasets import ImageFolder, VisionDataset

# Progress bar
from tqdm.notebook import tqdm

# Scikit-learn
from sklearn.metrics import (
    roc_auc_score, roc_curve, precision_recall_curve, auc,
    confusion_matrix, average_precision_score
)
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.preprocessing import StandardScaler

# Hardware Check

Before starting the training process, we check whether a GPU is available for computation. If CUDA is available, the code will use the GPU to accelerate training. Otherwise, it will fall back to the CPU. Basic system information, such as GPU model and memory or CPU details, is also displayed to help monitor the hardware configuration.

In [None]:
def hardware_check():
    if torch.cuda.is_available():
        device = torch.device("cuda")
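        print("GPU is available!")
        print(f"  -> GPU - {torch.cuda.get_device_name()}")
        # Pass the device index explicitly; older PyTorch versions require it
        props = torch.cuda.get_device_properties(torch.cuda.current_device())
        print(f"  -> Total Memory: {props.total_memory / 1024**3:.2f} GB")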
    else:
        device = torch.device("cpu")
        print("GPU is not available, using CPU.")
        print("\nCPU Information:")
        cpu_model = os.popen("cat /proc/cpuinfo | grep \"model name\" | uniq").read().strip()
        print(f"CPU Model: {cpu_model}")
        print(f"Number of CPU cores: {os.cpu_count()}")
    return device

device = hardware_check()
print(f"\nUsing {device} for computation")

# Globals

We use a batch size of 64 throughout, and two separate initial learning rates:
* LR_ID_INIT for the base Food-101 fine-tuning phase
* LR_OE_INIT for the Outlier Exposure stage

A small weight decay (WEIGHT_DECAY=1e-4) helps regularize training in both phases. Each phase runs for up to 15 epochs (EPOCHS_ID / EPOCHS_OE); during OE, training stops early after 7 epochs without validation improvement (PATIENCE=7).

For Outlier Exposure specifically, we introduce:
* energy margins M_IN / M_OUT to define hinge losses on in- and out-of-distribution energies
* a maximum OE weight LAMBDA_MAX and balance term ALPHA_ID
* a warm-up (3 epochs) and linear ramp (2 epochs) schedule to gradually apply the OE penalty
* a softmax temperature TEMPERATURE for energy scoring

All images are normalized with the standard ImageNet channel means and standard deviations (MEAN, STD) for consistency with the pretrained backbone.
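The warm-up and ramp described above can be sketched as a small function of the epoch index. This is a framework-free illustration that reuses the constants from this section; the helper name `oe_lambda` is hypothetical and the epoch index is assumed to be 0-based.

```python
# Hypothetical helper: OE weight as a function of the 0-based epoch index.
LAMBDA_MAX = 0.06
WARMUP_EP = 3   # epochs with the OE penalty switched off
RAMP_EP = 2     # epochs over which the weight grows linearly


def oe_lambda(epoch: int) -> float:
    """Return the OE weight for a 0-based epoch index."""
    if epoch < WARMUP_EP:
        return 0.0
    if epoch < WARMUP_EP + RAMP_EP:
        # Linear ramp that reaches LAMBDA_MAX on the last ramp epoch.
        return LAMBDA_MAX * (epoch - WARMUP_EP + 1) / RAMP_EP
    return LAMBDA_MAX


schedule = [round(oe_lambda(e), 3) for e in range(7)]
print(schedule)
```

With these settings the penalty is absent for three epochs, applied at half strength in the fourth, and at full strength from the fifth onward.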

In [None]:
BATCH_SIZE      = 64
LR_ID_INIT      = 2e-4      # first training phase (ID only)
LR_OE_INIT      = 3e-4      # second training phase (Outlier Exposure)
WEIGHT_DECAY    = 1e-4
EPOCHS_ID       = 15
EPOCHS_OE       = 15 

# Parameters for OE training
M_IN, M_OUT     = -6.0, -0.5
LAMBDA_MAX      = 0.06
ALPHA_ID        = 0.05
WARMUP_EP       = 3
RAMP_EP         = 2
PATIENCE        = 7
TEMPERATURE     = 1.0

MEAN = [0.485, 0.456, 0.406]
STD  = [0.229, 0.224, 0.225]

# Data and Transformations

Below we define the image‐level preprocessing pipelines for our in-distribution (ID) Food-101 data and out-of-distribution (OOD) SVHN data. During training we apply a rich suite of random augmentations—resizing, random crops, flips, rotations, color jitter, affine and perspective warps, plus random erasing—to help the model learn robust, invariant features. At test time (both for ID and OOD) we use a single deterministic resize, center crop and normalize flow so that evaluation is stable and reproducible.
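To make the deterministic test-time geometry concrete, the sketch below computes the image size after `Resize(256)` and the window kept by `CenterCrop(224)`. It is plain Python rather than torchvision code: the helper names are hypothetical, and the rounding rules are an assumption about how torchvision behaves, not a copy of its implementation.

```python
def resize_shorter_side(width: int, height: int, size: int = 256):
    """Scale so the shorter side equals `size`, keeping the aspect ratio."""
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def center_crop_box(width: int, height: int, crop: int = 224):
    """Return (left, top, right, bottom) of a centered crop x crop window."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop


# A 32x32 SVHN digit and a hypothetical 512x384 photo.
for w, h in [(32, 32), (512, 384)]:
    rw, rh = resize_shorter_side(w, h)
    print((w, h), "->", (rw, rh), "->", center_crop_box(rw, rh))
```

One consequence worth keeping in mind: a 32×32 SVHN digit is upsampled eightfold before cropping, so the OOD images are much blurrier than the food photos.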

In [None]:
train_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.RandomPerspective(distortion_scale=0.2, p=0.2),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
    transforms.RandomErasing(p=0.3, scale=(0.02, 0.2), ratio=(0.3, 3.3), value='random')
])

test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD)
])

ood_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(MEAN,STD)
])

**Food-101**

We load Food-101 via ImageFolder, pointing it at the top-level image directories, and then apply the official train.txt / test.txt metadata files to split exactly as the dataset creators intended. From the resulting training set, we carve off a stratified 10% sample, meaning we randomly select examples in proportion to each of the 101 classes, to serve as our validation set. This simple hold-out preserves the original class balance, so we can tune hyperparameters and monitor for overfitting without disturbing the integrity of the published train/test split.
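The idea behind a stratified hold-out can be illustrated without scikit-learn. The sketch below is not `StratifiedShuffleSplit` itself, only a plain-Python approximation of the same principle; the function name and the toy labels are hypothetical.

```python
import random
from collections import defaultdict


def stratified_holdout(labels, val_fraction=0.1, seed=42):
    """Split positions 0..len(labels)-1 into (train, val), class by class."""
    by_class = defaultdict(list)
    for pos, label in enumerate(labels):
        by_class[label].append(pos)

    rng = random.Random(seed)
    train, val = [], []
    for positions in by_class.values():
        rng.shuffle(positions)
        # Take the same fraction from every class.
        n_val = round(val_fraction * len(positions))
        val.extend(positions[:n_val])
        train.extend(positions[n_val:])
    return sorted(train), sorted(val)


# Toy data: 3 classes with 20 samples each.
toy_labels = [c for c in range(3) for _ in range(20)]
train_pos, val_pos = stratified_holdout(toy_labels)
print(len(train_pos), len(val_pos))
```

With Food-101's 750 training images per class, a 10% hold-out of this kind corresponds to 75 validation images per class.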

In [None]:
val_split = 0.1  # 10% for validation

# Import from Kaggle
dataset_root = "/kaggle/input/food101/food-101"
images_root = os.path.join(dataset_root, "images")
train_txt = os.path.join(dataset_root, "meta", "train.txt")
test_txt = os.path.join(dataset_root, "meta", "test.txt")

full_dataset_train = ImageFolder(root=images_root, transform=train_transform)
full_dataset_eval  = ImageFolder(root=images_root, transform=test_transform)

print(f"Complete dataset loaded: {len(full_dataset_train)} total images")
print(f"Number of classes: {len(full_dataset_train.classes)}")

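# Split into train and test using the official metadata files
def load_indices(txt_file, dataset):
    # Index every sample once by its "class_name/image_id" key so that each
    # lookup is exact and O(1); substring matching can hit the wrong file.
    key_to_idx = {}
    for i, (img_path, _) in enumerate(dataset.samples):
        class_dir, fname = img_path.replace("\\", "/").split("/")[-2:]
        key_to_idx[f"{class_dir}/{os.path.splitext(fname)[0]}"] = i
    with open(txt_file, "r") as f:
        lines = [ln.strip() for ln in f if ln.strip()]
    # Entries missing from the image folder are skipped, as before
    return [key_to_idx[rel] for rel in lines if rel in key_to_idx]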

train_indices = load_indices(train_txt, full_dataset_train)
test_indices = load_indices(test_txt, full_dataset_train)

# Balanced Validation Set 
train_labels = [full_dataset_train.samples[i][1] for i in train_indices]
splitter = StratifiedShuffleSplit(n_splits=1, test_size=val_split, random_state=42)
train_idx, val_idx = next(splitter.split(train_indices, train_labels))

train_id = Subset(full_dataset_train, [train_indices[i] for i in train_idx])
val_id   = Subset(full_dataset_eval,  [train_indices[i] for i in val_idx])
test_id  = Subset(full_dataset_eval,  test_indices)

print(f"Train: {len(train_id)}, Validation: {len(val_id)}, Test: {len(test_id)}")

**SVHN**

For Outlier Exposure, we can’t simply use ImageFolder because SVHN is stored in MATLAB .mat files.  So we wrap those with a tiny VisionDataset subclass that (a) loads the .mat, (b) transposes the arrays into standard H×W×C form, and then (c) returns only images, no labels, so that our OE loss sees “unlabeled” outliers.  From the SVHN training pool we then draw a random 30% sample (via torch.randperm) to use during model fitting, while reserving the full SVHN test set as a clean benchmark for OOD detection after training.

In [None]:
# Return only img (without label)
class SVHNFromMat(VisionDataset):
    def __init__(self, mat_path, transform=None):
        super().__init__(root="", transform=transform)
        data = loadmat(mat_path)
        self.images = np.transpose(data['X'], (3, 0, 1, 2))  

    def __getitem__(self, index):
        img = Image.fromarray(self.images[index])
        if self.transform:
            img = self.transform(img)
        return img  # image only, no label

    def __len__(self):
        return len(self.images)

# Import from Kaggle
svhn_train_path = "/kaggle/input/svhndataset/train_32x32.mat"
svhn_test_path = "/kaggle/input/svhndataset/test_32x32.mat"

svhn_train = SVHNFromMat(svhn_train_path, transform=ood_transform)
svhn_test = SVHNFromMat(svhn_test_path, transform=ood_transform)

# Subset for training OOD (30%)
ood_fraction = 0.30
n_ood        = int(ood_fraction * len(svhn_train))
subset_idx   = torch.randperm(len(svhn_train))[:n_ood].tolist()  # plain ints for Subset
svhn_subset  = Subset(svhn_train, subset_idx)

print(f"SVHN subset created: {len(svhn_subset)} images "
      f"({ood_fraction*100:.0f}% of {len(svhn_train)})")

After loading the datasets, we wrap each split in a PyTorch DataLoader to iterate efficiently over mini-batches during training and evaluation. We shuffle only the training streams and keep the validation/test loaders deterministic. The SVHN training loader uses half the ID batch size, so each Outlier Exposure step combines 64 Food-101 images with 32 SVHN images and the outliers never outnumber the ID samples.
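Because the SVHN subset is smaller than the Food-101 training set, the OOD stream runs out before an ID epoch is over. One way to pair the two streams is to restart the shorter one whenever it is exhausted. The sketch below shows that pattern on plain lists; `paired_batches` is a hypothetical helper, and the strings stand in for real batches.

```python
def paired_batches(id_batches, ood_batches):
    """Yield (id_batch, ood_batch) pairs, restarting the OOD stream when it runs out."""
    ood_iter = iter(ood_batches)
    for id_batch in id_batches:
        try:
            ood_batch = next(ood_iter)
        except StopIteration:
            # A shuffling DataLoader would draw a fresh permutation here.
            ood_iter = iter(ood_batches)
            ood_batch = next(ood_iter)
        yield id_batch, ood_batch


id_stream = [f"id{i}" for i in range(5)]
ood_stream = ["oodA", "oodB"]
pairs = list(paired_batches(id_stream, ood_stream))
print(pairs)
```

Restarting the iterator is preferable to `itertools.cycle` for a shuffled loader, since `cycle` replays the batches it cached on the first pass instead of reshuffling.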

In [None]:
train_id_loader = DataLoader(train_id , batch_size=BATCH_SIZE, shuffle=True ,
                             num_workers=2)
val_id_loader   = DataLoader(val_id   , batch_size=BATCH_SIZE, shuffle=False,
                             num_workers=2)
svhn_loader     = DataLoader(svhn_subset, batch_size=BATCH_SIZE//2, shuffle=True,
                             num_workers=2)
test_loader     = DataLoader(test_id, batch_size=BATCH_SIZE, shuffle=False, 
                             num_workers=2)
ood_loader      = DataLoader(svhn_test, batch_size=BATCH_SIZE, shuffle=False, 
                             num_workers=2)

**Oxford 102 Flower Dataset**

To further evaluate the model's ability to distinguish in-distribution (Food-101) from out-of-distribution (OOD) samples, we use the Flower Dataset as an additional OOD source.

This dataset consists of flower images from various categories and lies outside the food domain, making it a suitable benchmark to test the generalization of OOD detection.



In [None]:
flower_ood_dataset = ImageFolder(
    root='/kaggle/input/pytorch-challange-flower-dataset/dataset/train',
    transform=ood_transform
)

flower_ood_loader = DataLoader(flower_ood_dataset, batch_size=64, shuffle=False, num_workers=0)

# Network

As the starting point for our project, we chose ResNet-50 as our model because it offers a great trade-off between accuracy and speed. Its residual blocks let us go deep enough to learn complex visual patterns (textures, shapes, colors) without training instabilities, and starting from ImageNet–pretrained weights means we need fewer epochs and less data to reach strong performance on 101 food categories.

**How `build_resnet50_food101` works**  
1. **Load the backbone**  
   We instantiate the standard ResNet-50 model, optionally loading ImageNet weights to give us rich, general-purpose feature detectors from the very first iteration.  
2. **Swap in a new head**  
   The original 1000-class head is replaced with a small block that first applies dropout (to prevent overfitting) and then a linear layer that outputs exactly 101 logits—one per food category.  

By wrapping these steps in a single function, we guarantee that both our baseline and Outlier Exposure experiments use the exact same architecture and initialization.  

In [None]:
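def build_resnet50_food101(
        num_classes: int = 101,
        pretrained: bool = True,
        dropout_p: float = 0.3,
        device: torch.device | str | None = None
) -> nn.Module:
    # Backbone: ResNet-50, optionally initialised with ImageNet weights
    weights = ResNet50_Weights.DEFAULT if pretrained else None
    model = resnet50(weights=weights)

    # Head: dropout followed by a linear layer with one logit per class
    model.fc = nn.Sequential(
        nn.Dropout(dropout_p),
        nn.Linear(model.fc.in_features, num_classes)
    )

    # Move to the requested device, if one was given
    if device is not None:
        model = model.to(device)
    return model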

#  Standard In-Distribution Training

We begin by training a standard ResNet-50 model on Food-101 only, no out-of-distribution data, no auxiliary losses.  Our objectives are:  
1. Teach the network to discriminate among the 101 food categories as effectively as possible.  
2. Produce a “vanilla” classifier that we can later probe on truly out-of-distribution images (SVHN and beyond) to quantify its native OOD rejection behavior.

During each epoch we:

1. **Forward + Backward**  
   For each mini-batch, we compute the network logits and convert them to probabilities via softmax:  
   $$
   p_{i,c} \;=\; \frac{\exp(z_{i,c})}{\sum_{k=1}^{101}\exp(z_{i,k})}\,.
   $$  
   We then apply label smoothing ($\epsilon=0.1$), which replaces the one-hot target with  
   $$
   q_{i,c} \;=\; (1 - \epsilon)\,\mathbb{1}[c = y_i] \;+\; \frac{\epsilon}{101}\,.
   $$  
   The resulting cross-entropy loss is  
   $$
   \mathcal{L}_{\mathrm{CE}}
   \;=\;
   -\frac{1}{B}\sum_{i=1}^{B}\sum_{c=1}^{101} q_{i,c}\,\log p_{i,c}\,.
   $$  
   Calling `loss.backward()` back-propagates the gradient $\nabla \mathcal{L}_{\mathrm{CE}}$ through the network, and `optimizer.step()` (Adam) updates every parameter accordingly. This pair of operations (a forward pass to compute $\mathcal{L}_{\mathrm{CE}}$, then a backward pass to adjust the parameters) drives the model to better classify each food category.  
2. **Learning-rate schedule**  
   Every 5 epochs we halve the learning rate. This lets the optimizer make big updates early, then finer adjustments later.

3. **Progress monitoring**  
   A `tqdm` progress bar shows the running loss and top-1 accuracy for each mini-batch.

This run produces our baseline Food-101 classifier. In the next phase we’ll evaluate its OOD detection performance (energy and MSP scores on SVHN) before adding any Outlier Exposure penalties.  
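As a sanity check on the loss and the schedule described above, the sketch below evaluates a label-smoothed cross-entropy for a single sample and the step-wise learning rate in plain Python. It assumes the smoothing convention used by PyTorch's `label_smoothing` option (the one-hot target is mixed with a uniform distribution over all classes); the function names are hypothetical.

```python
import math


def smoothed_cross_entropy(logits, target, eps=0.1):
    """Cross-entropy of one sample against a label-smoothed target."""
    k = len(logits)
    peak = max(logits)  # subtract the max for numerical stability
    log_z = peak + math.log(sum(math.exp(z - peak) for z in logits))
    log_probs = [z - log_z for z in logits]
    # Target distribution: eps/k everywhere, plus (1 - eps) on the true class.
    return -sum(
        ((1.0 - eps) * (c == target) + eps / k) * lp
        for c, lp in enumerate(log_probs)
    )


def lr_at_epoch(epoch, base_lr=2e-4, step=5, factor=0.5):
    """Learning rate after halving every `step` epochs (0-based epoch)."""
    return base_lr * factor ** (epoch // step)


print(smoothed_cross_entropy([2.0, 0.5, -1.0], target=0))
print([lr_at_epoch(e) for e in (0, 5, 10)])
```

Note that smoothing puts a floor under the loss: even a perfectly confident prediction is penalized for the small probability mass that the target assigns to the other classes.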

In [None]:
model = build_resnet50_food101()    
model = model.to(device)

criterion  = nn.CrossEntropyLoss(label_smoothing=0.1)  

optimizer  = optim.Adam(model.parameters(),lr=LR_ID_INIT, weight_decay=WEIGHT_DECAY)


# Training Loop

print(f"\n Initial Learning-rate: {LR_ID_INIT:.1e}")

for epoch in range(EPOCHS_ID):                   
    # scheduler: halves the LR every 5 epochs
    if epoch > 0 and epoch % 5 == 0:    
        for g in optimizer.param_groups:
            g["lr"] *= 0.5            
        print(f"\n LR halved : {optimizer.param_groups[0]['lr']:.1e}")

    print(f"\nEpoch {epoch+1}/{EPOCHS_ID} — samples: {len(train_id_loader.dataset)}")
    model.train()
    running_loss = 0.0
    correct = total = 0

    loop = tqdm(train_id_loader, leave=False)
    for step, (images, labels) in enumerate(loop, 1):
        images, labels = images.to(device), labels.to(device)

        # forward + backward
        outputs = model(images)
        loss    = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        total   += labels.size(0)
        correct += (outputs.argmax(1) == labels).sum().item()

        # average over the batches seen so far (tqdm's .n is only refreshed periodically)
        loop.set_postfix(
            loss = running_loss / step,
            acc  = 100.0 * correct / total
        )

    print(f" Epoch {epoch+1} — Train ACC: {100.0 * correct / total:.2f}%")

# Evaluation

In the evaluation phase, we first measure the classifier’s accuracy on in-distribution (Food-101) test images to quantify standard recognition performance. Next, we assess the model’s ability to detect out-of-distribution (OOD) samples by feeding it SVHN images and computing energy and max-softmax scores. From these scores we derive:
	
* AUROC to capture overall separability between ID and OOD.
* FPR@95 TPR to report the false positive rate when correctly identifying 95 % of in-distribution samples.
* AUPR-In to summarize precision-recall trade-offs treating ID as the positive class.


Plotting ROC and PR curves alongside score histograms completes the evaluation, giving both quantitative metrics and visual insight into how well the model distinguishes familiar from unfamiliar images.
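For intuition, AUROC and FPR@95TPR can be computed by hand on a handful of scores. The sketch below is a plain-Python illustration, not the scikit-learn routines used in this notebook; the function names and toy scores are hypothetical, and scores are assumed to be oriented so that higher means more in-distribution (for example MSP, or negated energy).

```python
import math


def toy_auroc(id_scores, ood_scores):
    """Probability that a random ID score beats a random OOD score (ties count half)."""
    wins = 0.0
    for s_id in id_scores:
        for s_ood in ood_scores:
            if s_id > s_ood:
                wins += 1.0
            elif s_id == s_ood:
                wins += 0.5
    return wins / (len(id_scores) * len(ood_scores))


def toy_fpr_at_tpr(id_scores, ood_scores, target_tpr=0.95):
    """Fraction of OOD samples accepted when at least target_tpr of ID samples are accepted."""
    ranked = sorted(id_scores, reverse=True)
    k = math.ceil(target_tpr * len(ranked))
    tau = ranked[k - 1]  # lowest ID score that must still be accepted
    return sum(s >= tau for s in ood_scores) / len(ood_scores)


id_scores = [0.9, 0.8, 0.7, 0.6]
ood_scores = [0.1, 0.2, 0.65, 0.3]
print(toy_auroc(id_scores, ood_scores), toy_fpr_at_tpr(id_scores, ood_scores))
```

The pairwise loop is quadratic and only suitable for toy inputs; it is meant to show what the metrics measure, not to replace `roc_auc_score` or `roc_curve`.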

In [None]:
model.eval()
T = 0.2                          # Temperature found with grid-search

correct = 0
total = 0

# Accuracy

with torch.no_grad():
    test_loop = tqdm(test_loader, total=len(test_loader), desc="Testing")
    for images, labels in test_loop:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f"\nFinal Test Accuracy: {100 * correct / total:.2f}%")

# Useful functions

def fpr_at_given_tpr(fpr, tpr, target_tpr=0.95):
    if tpr[-1] < target_tpr:     
        return 1.0
    return float(np.interp(target_tpr, tpr, fpr))

def _images_from(batch):
    return batch[0] if isinstance(batch, (list, tuple)) else batch


# Scoring Functions

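@torch.no_grad()   # inference only; also required before calling .numpy()
def energy_scores(loader, T):
    out = []
    for batch in loader:
        x = _images_from(batch).to(device, non_blocking=True)
        # E(z) = -T * logsumexp(z / T); lower energy = more ID-like
        out.append((-T * torch.logsumexp(model(x) / T, 1)).cpu())
    return torch.cat(out).numpy()        # shape (N,)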

@torch.no_grad()
def msp_scores(loader):
    out = []
    for batch in tqdm(loader, desc="MSP", leave=False):
        x = _images_from(batch).to(device, non_blocking=True)
        out.append(torch.softmax(model(x), dim=1).max(1).values.cpu())
    return torch.cat(out).numpy()

id_energy  = energy_scores(test_loader, T)
ood_energy = energy_scores(ood_loader,  T)

id_msp  = msp_scores(test_loader)
ood_msp = msp_scores(ood_loader)


# Shift the energies by m, the 95th percentile of the ID energies

m            = np.percentile(id_energy, 95)
id_energy_s  = id_energy  - m
ood_energy_s = ood_energy - m
energy_cat   = -np.concatenate([id_energy_s, ood_energy_s])   # negated so that high = ID
labels       = np.concatenate([np.ones_like(id_energy_s),
                               np.zeros_like(ood_energy_s)])


# AUROC / AUPR  
auroc_energy = roc_auc_score(labels, energy_cat)
auroc_msp    = roc_auc_score(labels, np.concatenate([id_msp, ood_msp]))

fpr_e, tpr_e, _ = roc_curve(labels, energy_cat)
fpr_s, tpr_s, _ = roc_curve(labels, np.concatenate([id_msp, ood_msp]))

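# False positive rate at 95% TPR (ID is the positive class)
tau_energy   = 0.0                        # shifted energy <= 0 keeps 95% of ID samples
tau_msp      = np.percentile(id_msp, 5)   # MSP >= tau keeps 95% of ID samples

# An OOD sample is a false positive when it falls on the ID side of the threshold
fpr95_energy = (ood_energy_s <= tau_energy).mean()
fpr95_msp    = (ood_msp      >= tau_msp   ).mean()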

pr_e, rc_e, _ = precision_recall_curve(labels, energy_cat)
pr_s, rc_s, _ = precision_recall_curve(labels,
                                       np.concatenate([id_msp, ood_msp]))
aupr_e = average_precision_score(labels, energy_cat)
aupr_m = average_precision_score(labels, np.concatenate([id_msp, ood_msp]))

# Results

print("\n OOD METRICS ")
print(f"• AUROC  (Energy)      : {auroc_energy:.4f}")
print(f"• AUROC  (Soft-max)    : {auroc_msp:.4f}")
print(f"• FPR@95TPR (Energy)   : {fpr95_energy*100:.2f}%")
print(f"• FPR@95TPR (Soft-max) : {fpr95_msp*100:.2f}%")
print(f"• AUPR-In (Energy)     : {aupr_e:.4f}")
print(f"• AUPR-In (Soft-max)   : {aupr_m:.4f}")


# ROC / PR curves

plt.figure(figsize=(14, 5))
# ROC
plt.subplot(1, 2, 1)
plt.plot(fpr_e, tpr_e, label=f"Energy (AUROC={auroc_energy:.4f})")
plt.plot(fpr_s, tpr_s, label=f"MSP   (AUROC={auroc_msp:.4f})")
plt.plot([0, 1], [0, 1], "--", color="gray")
plt.xlabel("False Positive Rate"); plt.ylabel("True Positive Rate")
plt.title("ROC — OOD Detection"); plt.legend(); plt.grid(True)

# PR
plt.subplot(1, 2, 2)
plt.plot(rc_e, pr_e, label=f"Energy (AUPR={aupr_e:.4f})")
plt.plot(rc_s, pr_s, label=f"MSP   (AUPR={aupr_m:.4f})")
plt.xlabel("Recall"); plt.ylabel("Precision")
plt.title("Precision-Recall (ID positive)"); plt.legend(); plt.grid(True)
plt.tight_layout(); plt.show()

# Histograms
fig, axes = plt.subplots(1, 2, figsize=(14, 4))

# Energy
axes[0].hist(id_energy_s,  bins=100, alpha=0.6, label="ID (Food101)")
axes[0].hist(ood_energy_s, bins=100, alpha=0.6, label="OOD (SVHN)")
axes[0].axvline(0, color="k", ls="--")
axes[0].set_title("Energy distribution (shifted)")
axes[0].set_xlabel("Shifted energy  (low = ID)")
axes[0].legend()

# MSP
axes[1].hist(id_msp,  bins=100, alpha=0.6, label="ID (Food101)")
axes[1].hist(ood_msp, bins=100, alpha=0.6, label="OOD (SVHN)")
axes[1].axvline(tau_msp, color="k", ls="--")
axes[1].set_title("Max-softmax distribution")
axes[1].set_xlabel("Max softmax  (high = ID)")
axes[1].legend()

plt.tight_layout(); plt.show()

# Outlier Exposure: Training ID + OOD 
In this phase, we augment the standard Food-101 fine-tuning with SVHN images and an auxiliary energy-based penalty to push OOD samples away from the in-distribution manifold.  The network still learns the 101 food classes, but now also learns to assign higher “energy” to outliers.

During each epoch:

1. **Outlier Exposure Loss**

    We process mixed mini-batches of ID and OOD images. Let
    $$
    E(\mathbf{z}) = -T\,\log\sum_c\exp\bigl(z_c/T\bigr)
    $$
    be the energy of the logits $\mathbf{z}$ produced by the 101-way head. We impose two hinge-style penalties, each averaged over the mini-batch:
    $$
    \mathcal{L}_{\text{ID}}  = \bigl[\max(E_{\text{ID}} - m_{\text{in}},0)\bigr]^2,\quad
    \mathcal{L}_{\text{OOD}} = \bigl[\max(m_{\text{out}} - E_{\text{OOD}},0)\bigr]^2,
    $$
    and combine them into
    $$
    \mathcal{L}_{\text{OE}} = \lambda\,\bigl(\mathcal{L}_{\text{OOD}} + \alpha\,\mathcal{L}_{\text{ID}}\bigr).
    $$
    The full training objective is
    $$
    \mathcal{L} = \mathcal{L}_{\text{CE}} + \mathcal{L}_{\text{OE}}.
    $$
    After computing the combined loss, we perform back-propagation of its gradient through the network and then update every model parameter via the optimizer’s update rule.
    	
2. **Warm-up & Ramp for OE Strength**

   We keep $\lambda$ at zero for the first WARMUP_EP epochs, ramp it linearly over the next RAMP_EP epochs, then hold it constant at LAMBDA_MAX. This avoids overwhelming the classifier before it has learned good food-class features.

3. **Label-smoothing Schedule**

   After the OE ramp completes, we reduce the label-smoothing factor by 0.05 per epoch (from 0.1 down to 0) to sharpen the ID classification head over time.

4. **Cosine-Restarts Scheduler**

   We use CosineAnnealingWarmRestarts (T_0=5, T_mult=2): the learning rate decays along a cosine curve and is reset to its initial value when a cycle ends, enabling fresh exploration without extending the total number of epochs.

5. **Monitoring & Early-Stop**

   A tqdm bar tracks progress, and the CE loss, OE loss, and ID/OOD energy gap are printed periodically alongside the training accuracy. Whenever the validation loss improves we snapshot the model, and training stops automatically if no improvement occurs for PATIENCE epochs.

6. **Batch-Norm re-calibration**

   Once we’ve picked the best model, we feed all Food-101 training images through it one last time (still in training mode but without changing any weights). This lets each BatchNorm layer recompute its running averages so that they match the training data, making the model’s behavior more reliable when we finally test it.
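The cosine schedule with warm restarts named in the list above can be written out in closed form. The sketch below follows the formula given in the PyTorch documentation as we understand it, assuming the scheduler is stepped once per epoch; `cosine_restart_lr` is a hypothetical helper, not part of PyTorch.

```python
import math


def cosine_restart_lr(epoch, base_lr=3e-4, eta_min=1e-6, t_0=5, t_mult=2):
    """Learning rate at a 0-based epoch under cosine annealing with warm restarts."""
    t_i, t_cur = t_0, epoch
    while t_cur >= t_i:      # skip over completed cycles
        t_cur -= t_i
        t_i *= t_mult        # each new cycle is t_mult times longer
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * t_cur / t_i))


for e in range(15):
    print(e, f"{cosine_restart_lr(e):.2e}")
```

With T_0=5 and T_mult=2, a 15-epoch run contains one 5-epoch cycle followed by one 10-epoch cycle, so the learning rate returns to its initial value exactly once, at the sixth epoch.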

This procedure teaches the model not only to recognise food categories, but also to push SVHN samples into a high-energy “outlier” regime, laying the groundwork for robust OOD detection.
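To see how the two margins act, the sketch below evaluates the energy and both hinge penalties for two hand-made logit vectors in plain Python. The margins are the values from the Globals section; the helper names and the example logits are hypothetical, and "confident" or "flat" logits are only stand-ins for typical ID-like and OOD-like outputs.

```python
import math

M_IN, M_OUT = -6.0, -0.5   # margins from the Globals section


def energy_of(logits, T=1.0):
    """Energy E(z) = -T * log(sum_c exp(z_c / T)) of one logit vector."""
    scaled = [z / T for z in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    return -T * (peak + math.log(sum(math.exp(s - peak) for s in scaled)))


def hinge_id(e_id):
    """Penalize ID samples whose energy rises above M_IN."""
    return max(e_id - M_IN, 0.0) ** 2


def hinge_ood(e_ood):
    """Penalize OOD samples whose energy falls below M_OUT."""
    return max(M_OUT - e_ood, 0.0) ** 2


confident = [10.0] + [0.0] * 100   # one dominant logit
flat = [0.0] * 101                 # no preferred class
for name, z in [("confident", confident), ("flat", flat)]:
    e = energy_of(z)
    print(name, round(e, 3), round(hinge_id(e), 3), round(hinge_ood(e), 3))
```

The ID hinge is silent once the energy is below M_IN, and the OOD hinge is silent once the energy is above M_OUT, so only samples that sit on the wrong side of their margin contribute gradient.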

In [None]:
model = build_resnet50_food101()    
model = model.to(device)

optimizer  = optim.Adam(model.parameters(), lr=LR_OE_INIT,weight_decay=WEIGHT_DECAY)

scheduler = CosineAnnealingWarmRestarts(optimizer,T_0=5, T_mult=2, eta_min=1e-6)

label_smooth = 0.1                     
criterion_id = nn.CrossEntropyLoss(label_smoothing=label_smooth)
criterion_val = nn.CrossEntropyLoss()


def energy(logits, T=TEMPERATURE):
    return -T * torch.logsumexp(logits / T, dim=1)


# Training loop

best_val_loss = float("inf")
best_wts      = copy.deepcopy(model.state_dict())
no_improve    = 0

for epoch in range(EPOCHS_OE):

    # warm-up + ramp for lambda
    if   epoch < WARMUP_EP:
        lam = 0.0
    elif epoch < WARMUP_EP + RAMP_EP:
        lam = LAMBDA_MAX * (epoch - WARMUP_EP + 1) / RAMP_EP
    else:
        lam = LAMBDA_MAX

    # progressive label-smoothing 
    if epoch >= WARMUP_EP + RAMP_EP and label_smooth > 0:
        label_smooth = max(0.0, label_smooth - 0.05)
        criterion_id.label_smoothing = label_smooth

    print(f"\nEpoch {epoch+1}/{EPOCHS_OE}  |  λ={lam:.3f}  |  "
          f"LR={optimizer.param_groups[0]['lr']:.2e}  |  LS={label_smooth:.2f}")

    model.train()
    ood_iter = iter(svhn_loader)
    run_gap = correct = total = 0

    train_loop = tqdm(
        train_id_loader,
        total=len(train_id_loader),
        desc=f"Train {epoch+1}",
        leave=False
    )

    for i, (x_id, y_id) in enumerate(train_loop, 1):
        try:
            x_ood = next(ood_iter)
        except StopIteration:
            ood_iter = iter(svhn_loader)
            x_ood    = next(ood_iter)

        x_id, y_id, x_ood = x_id.to(device), y_id.to(device), x_ood.to(device)

        # forward + losses 
        logits_id  = model(x_id)
        logits_ood = model(x_ood)
        loss_ce    = criterion_id(logits_id, y_id)

        e_id    = energy(logits_id)
        e_ood   = energy(logits_ood)
        gap     = (e_ood.mean() - e_id.mean()).item()
        loss_id  = torch.clamp(e_id  - M_IN , min=0).pow(2).mean()
        loss_ood = torch.clamp(M_OUT - e_ood, min=0).pow(2).mean()
        loss_oe  = lam * (loss_ood + ALPHA_ID * loss_id)
        loss     = loss_ce + loss_oe

        # backward 
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        run_gap += gap
        correct += (logits_id.argmax(1) == y_id).sum().item()
        total   += y_id.size(0)

        # i already starts at 1 (see enumerate above), so no further offset is needed
        if i % 100 == 0 or i == len(train_id_loader):
            print(f"batch {i:4d} | gap={gap:4.2f} | "
                  f"CE={loss_ce.item():.3f} | OE={loss_oe.item():.3f}")


    train_acc = 100 * correct / total
    print(f"Train Acc {train_acc:.2f}% — mean gap {run_gap/len(train_id_loader):.2f}")

    # Validation 
    model.eval()
    val_loss = val_correct = val_total = 0

    val_loop = tqdm(
        val_id_loader,
        total=len(val_id_loader),
        desc=f"Val   {epoch+1}",
        leave=False
    )

    with torch.no_grad():
        for step, (x, y) in enumerate(val_loop, 1):
            x, y = x.to(device), y.to(device)
            out = model(x)
            val_loss    += criterion_val(out, y).item()
            val_correct += (out.argmax(1) == y).sum().item()
            val_total   += y.size(0)
            # average over the batches seen so far (tqdm's .n can lag behind)
            val_loop.set_postfix({"val_loss": f"{val_loss/step:.4f}",
                                  "val_acc": f"{100*val_correct/val_total:5.2f}%"})

    val_loss /= len(val_id_loader)
    val_acc   = 100 * val_correct / val_total
    print(f"Val Loss {val_loss:.4f} — Acc {val_acc:.2f}%")

    # early-stop 
    if val_loss < best_val_loss - 1e-4:
        best_val_loss = val_loss
        best_wts      = copy.deepcopy(model.state_dict())
        torch.save(best_wts, "best_model_weights.pth")
        no_improve = 0
        print("Best model saved.")
    else:
        no_improve += 1
        print(f"⏳ no improvement ({no_improve}/{PATIENCE})")
        if no_improve >= PATIENCE:
            print("Early stop triggered.")
            break

    # scheduler 
    scheduler.step()

# Re-calibration BN 
model.load_state_dict(best_wts)
model.train()
with torch.no_grad():
    for x, _ in DataLoader(train_id, batch_size=256, shuffle=False):
        _ = model(x.to(device))
model.eval()
print("Best weights loaded & BN re-calibrated.")

# Evaluation

After training, we assess both classification accuracy on Food-101 (ID) and out-of-distribution performance on SVHN (OOD). First, we switch the model to evaluation mode and run a simple test loop over held-out Food-101 images to report accuracy.

Next, we compute two OOD detection scores for each image: energy and maximum softmax probability (MSP). The energy score uses temperature scaling: the logits are divided by a fixed temperature T before the log-sum-exp is taken. In the code below, MSP is computed from the softmax of the unscaled logits. The choice of T influences how well the ID and OOD energy distributions separate.
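The effect of the temperature is easiest to see on a single logit vector. The sketch below is a plain-Python illustration of how T enters both scores if it is applied to each of them; the function names and the example logits are hypothetical.

```python
import math


def softmax(logits, T=1.0):
    """Softmax of the logits after dividing them by the temperature T."""
    scaled = [z / T for z in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def msp(logits, T=1.0):
    """Maximum softmax probability after dividing the logits by T."""
    return max(softmax(logits, T))


def energy_score(logits, T=1.0):
    """Energy -T * logsumexp(z / T); lower values look more in-distribution."""
    scaled = [z / T for z in logits]
    peak = max(scaled)
    return -T * (peak + math.log(sum(math.exp(s - peak) for s in scaled)))


logits = [3.0, 1.0, 0.5]
for T in (0.35, 1.0, 2.0):
    print(T, round(msp(logits, T), 4), round(energy_score(logits, T), 4))
```

A temperature below 1 sharpens the softmax, and the energy moves toward the negated largest logit, so the score is increasingly dominated by the top class.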


In [None]:
T = 0.35                      

model.eval()

correct = 0
total = 0

# Accuracy

with torch.no_grad():
    test_loop = tqdm(test_loader, total=len(test_loader), desc="Testing")
    for images, labels in test_loop:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f"\nFinal Test Accuracy: {100 * correct / total:.2f}%")

# Scoring function

@torch.no_grad()
def get_scores(loader):
    energy, msp = [], []
    for batch in tqdm(loader, desc="Scoring", leave=False):
        x = batch[0] if isinstance(batch, (list, tuple)) else batch
        x = x.to(device, non_blocking=True)

        logits = model(x)
        energy.append((-T * torch.logsumexp(logits / T, 1)).cpu())      # Energy: -T * logsumexp(z / T)
        msp.append(torch.softmax(logits, 1).max(1).values.cpu())        # MSP on unscaled logits
    return torch.cat(energy).numpy(), torch.cat(msp).numpy()


id_energy,  id_msp  = get_scores(test_loader)      # Food-101 
ood_energy, ood_msp = get_scores(ood_loader)   # SVHN


# AUROC / AUPR  

labels = np.concatenate([np.ones_like(id_energy),
                         np.zeros_like(ood_energy)])

auroc_e = roc_auc_score(labels, -np.concatenate([id_energy, ood_energy]))
auroc_m = roc_auc_score(labels,  np.concatenate([id_msp,   ood_msp  ]))

pr_e, rc_e, _ = precision_recall_curve(labels, -np.concatenate([id_energy, ood_energy]))
pr_m, rc_m, _ = precision_recall_curve(labels,  np.concatenate([id_msp,    ood_msp ]))
aupr_e = average_precision_score(labels, -np.concatenate([id_energy, ood_energy]))
aupr_m = average_precision_score(labels,  np.concatenate([id_msp,    ood_msp ]))


# False-positive rate 

fpr_e, tpr_e, _ = roc_curve(labels, -np.concatenate([id_energy, ood_energy]))
fpr_m, tpr_m, _ = roc_curve(labels,  np.concatenate([id_msp,   ood_msp  ]))

def fpr_at_tpr(fpr, tpr, target=0.95):
    return np.interp(target, tpr, fpr) if tpr[-1] >= target else 1.0

fpr95_e = fpr_at_tpr(fpr_e, tpr_e)
fpr95_m = fpr_at_tpr(fpr_m, tpr_m)


# Results

print("\n  OOD METRICS")
print(f"• AUROC  (Energy)      : {auroc_e:.4f}")
print(f"• AUROC  (Soft-max)    : {auroc_m:.4f}")
print(f"• FPR@95TPR (Energy)   : {fpr95_e*100:.2f}%")
print(f"• FPR@95TPR (Soft-max) : {fpr95_m*100:.2f}%")
print(f"• AUPR-In (Energy)     : {aupr_e:.4f}")
print(f"• AUPR-In (Soft-max)   : {aupr_m:.4f}")


# ROC / PR curves

plt.figure(figsize=(14,5))
# ROC
plt.subplot(1,2,1)
plt.plot(fpr_e, tpr_e, label=f"Energy (AUROC={auroc_e:.4f})")
plt.plot(fpr_m, tpr_m, label=f"MSP   (AUROC={auroc_m:.4f})")
plt.plot([0,1],[0,1],'--',color='gray')
plt.xlabel("False Positive Rate"); plt.ylabel("True Positive Rate")
plt.title("ROC - OOD Detection (SVHN)"); plt.legend(); plt.grid(True)
# PR
plt.subplot(1,2,2)
plt.plot(rc_e, pr_e, label=f"Energy (AUPR={aupr_e:.4f})")
plt.plot(rc_m, pr_m, label=f"MSP   (AUPR={aupr_m:.4f})")
plt.xlabel("Recall"); plt.ylabel("Precision")
plt.title("Precision-Recall (ID positive)"); plt.legend(); plt.grid(True)
plt.tight_layout(); plt.show()


# Histograms

fig, ax = plt.subplots(1,2, figsize=(14,4))
ax[0].hist(id_energy,  bins=100, alpha=0.6, label="ID (Food-101)")
ax[0].hist(ood_energy, bins=100, alpha=0.6, label="OOD (SVHN)")
ax[0].set_title("Energy distribution"); ax[0].legend()

ax[1].hist(id_msp,  bins=100, alpha=0.6, label="ID (Food-101)")
ax[1].hist(ood_msp, bins=100, alpha=0.6, label="OOD (SVHN)")
ax[1].set_title("MSP distribution"); ax[1].legend()
plt.tight_layout(); plt.show()
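
For reference, the two scores computed by `get_scores` can be written as below, where $f_i(x)$ denotes the $i$-th logit and $K$ the number of classes. The energy definition with the leading factor $T$ is the usual energy-based OOD formulation and is stated here as an assumption about the intended method. The code omits that factor, which rescales the histogram axis but should leave AUROC, AUPR, and FPR@95 unchanged, since the ranking of samples is the same for any $T > 0$.

$$
E(x; T) = -T \log \sum_{i=1}^{K} \exp\!\left(\frac{f_i(x)}{T}\right),
\qquad
s_{\text{code}}(x) = -\log \sum_{i=1}^{K} \exp\!\left(\frac{f_i(x)}{T}\right) = \frac{E(x; T)}{T}
$$

$$
\mathrm{MSP}(x) = \max_{i} \frac{\exp\big(f_i(x)\big)}{\sum_{j=1}^{K} \exp\big(f_j(x)\big)}
$$

Low energy and high MSP both indicate an in-distribution sample, which is why the energy values are negated before being passed to the scikit-learn metric functions.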

# Evaluation with Flower Dataset
We use the `train/` split of the Flower Dataset as an additional OOD test set to check whether the detector also handles a visual domain it never saw during training. The classifier was trained only on the food categories of Food-101, so a robust OOD detector should flag unrelated categories such as flowers as unfamiliar.
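
Because the same metrics are recomputed for every OOD dataset (SVHN above, flowers here), the evaluation could be wrapped in a small reusable helper. The following is a minimal NumPy-only sketch, not the notebook's actual implementation: the function names `auroc_id_vs_ood` and `fpr_at_tpr_quantile` are hypothetical, and the FPR is computed from a quantile threshold on the ID scores instead of interpolating the output of `roc_curve`, so small differences from the values printed by the cells are to be expected.

```python
import numpy as np


def auroc_id_vs_ood(id_scores, ood_scores):
    """AUROC with ID as the positive class (higher score = more ID-like).

    Computed as the fraction of (ID, OOD) pairs in which the ID sample
    scores higher, counting ties as one half.
    """
    id_scores = np.asarray(id_scores, dtype=float)
    ood_sorted = np.sort(np.asarray(ood_scores, dtype=float))
    # For each ID score: number of OOD scores strictly below it / below or equal
    lower = np.searchsorted(ood_sorted, id_scores, side="left")
    lower_or_equal = np.searchsorted(ood_sorted, id_scores, side="right")
    wins = lower.sum() + 0.5 * (lower_or_equal - lower).sum()
    return float(wins / (id_scores.size * ood_sorted.size))


def fpr_at_tpr_quantile(id_scores, ood_scores, target_tpr=0.95):
    """Fraction of OOD samples accepted as ID at the threshold that
    keeps roughly `target_tpr` of the ID samples."""
    id_scores = np.asarray(id_scores, dtype=float)
    ood_scores = np.asarray(ood_scores, dtype=float)
    # Threshold below which only (1 - target_tpr) of the ID scores fall
    threshold = np.quantile(id_scores, 1.0 - target_tpr)
    return float((ood_scores >= threshold).mean())
```

With the arrays returned by `get_scores`, the energy scores would be passed negated, for example `auroc_id_vs_ood(-id_energy, -ood_energy)`, to match the sign convention used with `roc_auc_score`, while the MSP scores can be passed as they are.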

In [None]:
T = 0.2    

model.eval()

# Scoring function

@torch.no_grad()
def get_scores(loader):
    energy, msp = [], []
    for batch in tqdm(loader, desc="Scoring", leave=False):
        x = batch[0] if isinstance(batch, (list, tuple)) else batch
        x = x.to(device, non_blocking=True)

        logits = model(x)
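        # Energy score up to a factor of T: -logsumexp(logits / T); lower = more ID-like
        energy.append((-torch.logsumexp(logits / T, dim=1)).cpu())
        # Maximum softmax probability (temperature 1); higher = more ID-like
        msp.append(torch.softmax(logits, dim=1).max(dim=1).values.cpu())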
    return torch.cat(energy).numpy(), torch.cat(msp).numpy()


# Compute scores
id_energy,  id_msp  = get_scores(test_loader)        # Food-101
ood_energy, ood_msp = get_scores(flower_ood_loader)  # Flower Dataset


# AUROC / AUPR (ID = positive class; energy is negated so that higher = more ID-like)
labels = np.concatenate([np.ones_like(id_energy),
                         np.zeros_like(ood_energy)])

auroc_e = roc_auc_score(labels, -np.concatenate([id_energy, ood_energy]))
auroc_m = roc_auc_score(labels,  np.concatenate([id_msp,   ood_msp  ]))

pr_e, rc_e, _ = precision_recall_curve(labels, -np.concatenate([id_energy, ood_energy]))
pr_m, rc_m, _ = precision_recall_curve(labels,  np.concatenate([id_msp,    ood_msp ]))
aupr_e = average_precision_score(labels, -np.concatenate([id_energy, ood_energy]))
aupr_m = average_precision_score(labels,  np.concatenate([id_msp,    ood_msp ]))


# False-positive rate at 95% true-positive rate (FPR@95)

fpr_e, tpr_e, _ = roc_curve(labels, -np.concatenate([id_energy, ood_energy]))
fpr_m, tpr_m, _ = roc_curve(labels,  np.concatenate([id_msp,   ood_msp  ]))

def fpr_at_tpr(fpr, tpr, target=0.95):
    return np.interp(target, tpr, fpr) if tpr[-1] >= target else 1.0

fpr95_e = fpr_at_tpr(fpr_e, tpr_e)
fpr95_m = fpr_at_tpr(fpr_m, tpr_m)


# Results

print("\n  OOD METRICS")
print(f"• AUROC  (Energy)      : {auroc_e:.4f}")
print(f"• AUROC  (Soft-max)    : {auroc_m:.4f}")
print(f"• FPR@95TPR (Energy)   : {fpr95_e*100:.2f}%")
print(f"• FPR@95TPR (Soft-max) : {fpr95_m*100:.2f}%")
print(f"• AUPR-In (Energy)     : {aupr_e:.4f}")
print(f"• AUPR-In (Soft-max)   : {aupr_m:.4f}")


# ROC / PR curves

plt.figure(figsize=(14,5))
# ROC
plt.subplot(1,2,1)
plt.plot(fpr_e, tpr_e, label=f"Energy (AUROC={auroc_e:.4f})")
plt.plot(fpr_m, tpr_m, label=f"MSP   (AUROC={auroc_m:.4f})")
plt.plot([0,1],[0,1],'--',color='gray')
plt.xlabel("False Positive Rate"); plt.ylabel("True Positive Rate")
plt.title("ROC - OOD Detection (Flowers)"); plt.legend(); plt.grid(True)

# PR
plt.subplot(1,2,2)
plt.plot(rc_e, pr_e, label=f"Energy (AUPR={aupr_e:.4f})")
plt.plot(rc_m, pr_m, label=f"MSP   (AUPR={aupr_m:.4f})")
plt.xlabel("Recall"); plt.ylabel("Precision")
plt.title("Precision-Recall (ID positive)"); plt.legend(); plt.grid(True)
plt.tight_layout(); plt.show()


# Histograms

fig, ax = plt.subplots(1,2, figsize=(14,4))
ax[0].hist(id_energy,  bins=100, alpha=0.6, label="ID (Food-101)")
ax[0].hist(ood_energy, bins=100, alpha=0.6, label="OOD (Flowers)")
ax[0].set_title("Energy distribution"); ax[0].legend()

ax[1].hist(id_msp,  bins=100, alpha=0.6, label="ID (Food-101)")
ax[1].hist(ood_msp, bins=100, alpha=0.6, label="OOD (Flowers)")
ax[1].set_title("MSP distribution"); ax[1].legend()
plt.tight_layout(); plt.show()