# Imports

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve
import os
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import optuna
from torch.utils.data import Subset
import random

# Globals

In [17]:
class Config:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    data_root = "./data"
    model_save_path = "./models"
    plot_save_path = "./plots"
    os.makedirs(model_save_path, exist_ok=True)
    os.makedirs(plot_save_path, exist_ok=True)

    model_name = 'mobilenetv3' # resnet50, mobilenetv3, repnext
    use_pretrained = False

    ood_datasets = ['SVHN', 'CIFAR10', 'FashionMNIST', 'Flowers102', 'DTD',
                    'FGVCAircraft', 'OxfordIIITPet', 'EuroSAT']
    ood_dataset = 'DTD'

    # data
    batch_size = 256
    num_workers = 4
    image_size = 224
    num_classes = 101

    # standard training
    learning_rate_std = 1e-3
    epochs_std = 30

    # energy-based fine tuning
    learning_rate_energy = 1e-4
    epochs_energy = 10
    lambda_energy = 0.1
    m_in = -11
    m_out = -5

    # GReg fine-tuning
    use_grad_reg = False
    lambda_grad = 1.0

# Utils

In [4]:
def get_energy_score(logits, T=1.0):
    return -T * torch.logsumexp(logits / T, dim=1)

def calculate_ood_metrics(id_scores, ood_scores):
    scores = np.concatenate([id_scores, ood_scores])
    scores = -scores # higher score = more confident (ID)

    labels = np.concatenate(
        [np.ones_like(id_scores), np.zeros_like(ood_scores)]
    )

    auroc = roc_auc_score(labels, scores)
    aupr = average_precision_score(labels, scores)

    # calculate FPR at 95% TPR
    fpr, tpr, _ = roc_curve(labels, scores)
    idx = np.searchsorted(tpr, 0.95)
    fpr_at_95_tpr = fpr[idx] if idx < len(fpr) else 1.0

    return auroc, aupr, fpr_at_95_tpr

def plot_distributions(id_scores, ood_scores, ood_name, title, save_path):
    plt.figure(figsize=(12, 7))
    sns.kdeplot(data=-id_scores, label=f'ID (Food-101) mean: {-np.mean(id_scores):.2f}', fill=True)
    sns.kdeplot(data=-ood_scores, label=f'OOD ({ood_name}) mean: {-np.mean(ood_scores):.2f}', fill=True)
    plt.title(title, fontsize=16)
    plt.xlabel("Score", fontsize=12)
    plt.ylabel("Frequency", fontsize=12)
    plt.legend()
    plt.grid(True, which='both', linestyle='--', linewidth=0.5)
    plt.savefig(save_path)
    print(f"Distribution plot saved to {save_path}")
    plt.close()

def get_ood_loader(dataset_name, transform, config, purpose):
    print(f"Loading OOD dataset: {dataset_name} ({purpose})")

    # for FashionMNIST we need to convert grayscale -> 3 channels
    if dataset_name == 'FashionMNIST':
        ood_transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=3),
            *transform.transforms
        ])
    else:
        ood_transform = transform

    if dataset_name == 'SVHN':
        ood_dataset = torchvision.datasets.SVHN(root=config.data_root, split='test', download=True, transform=ood_transform)
    elif dataset_name == 'CIFAR10':
        ood_dataset = torchvision.datasets.CIFAR10(root=config.data_root, train=False, download=True, transform=ood_transform)
    elif dataset_name == 'FashionMNIST':
        ood_dataset = torchvision.datasets.FashionMNIST(root=config.data_root, train=False, download=True, transform=ood_transform)
    elif dataset_name == 'Flowers102':
        ood_dataset = torchvision.datasets.Flowers102(root=config.data_root, split='test', download=True, transform=ood_transform)
    elif dataset_name == 'DTD':
        ood_dataset = torchvision.datasets.DTD(root=config.data_root, split='test', download=True, transform=ood_transform)
    elif dataset_name == 'FGVCAircraft':
        ood_dataset = torchvision.datasets.FGVCAircraft(root=config.data_root, split='test', download=True, transform=ood_transform)
    elif dataset_name == 'OxfordIIITPet':
        ood_dataset = torchvision.datasets.OxfordIIITPet(root=config.data_root, split='test', download=True, transform=ood_transform)
    elif dataset_name == 'EuroSAT':
        ood_dataset = torchvision.datasets.EuroSAT(root=config.data_root, download=True, transform=ood_transform)
    else:
        raise ValueError(f"Unknown OOD dataset name: {dataset_name}")

    ood_loader = DataLoader(ood_dataset, batch_size=config.batch_size, shuffle=False, num_workers=config.num_workers, pin_memory=True)
    print(f"{dataset_name} ({purpose}) loader ready.")
    return ood_loader

# for hyperparameter search
def run_ood_evaluation(model, id_test_loader, ood_test_loader):
    model.eval()
    id_scores, ood_scores = [], []
    with torch.no_grad(), torch.amp.autocast(device_type='cuda'):
        for inputs, _ in id_test_loader:
            inputs = inputs.to(Config.device)
            logits = model(inputs)
            id_scores.extend(get_energy_score(logits).cpu().numpy())
        for inputs, _ in ood_test_loader:
            inputs = inputs.to(Config.device)
            logits = model(inputs)
            ood_scores.extend(get_energy_score(logits).cpu().numpy())
    id_scores = np.array(id_scores)
    ood_scores = np.array(ood_scores)
    auroc, _, _ = calculate_ood_metrics(id_scores, ood_scores)
    return auroc

def run_energy_tuning_trial(model, id_loader, ood_loader, val_loader, ood_val_loader, epochs, lr, m_in, m_out):
    criterion_ce = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    scaler = torch.amp.GradScaler('cuda')

    best_auroc = 0.0

    for epoch in range(epochs):
        model.train()
        ood_iter = iter(ood_loader)
        pbar = tqdm(id_loader, desc=f"Trial epoch {epoch+1}/{epochs}", leave=False)
        for id_inputs, id_labels in pbar:
            try:
                ood_inputs, _ = next(ood_iter)
            except StopIteration:
                ood_iter = iter(ood_loader)
                ood_inputs, _ = next(ood_iter)

            id_inputs, id_labels = id_inputs.to(Config.device), id_labels.to(Config.device)
            ood_inputs = ood_inputs.to(Config.device)
            optimizer.zero_grad()

            with torch.amp.autocast(device_type='cuda'):
                id_logits = model(id_inputs)
                ood_logits = model(ood_inputs)
                loss_ce = criterion_ce(id_logits, id_labels)
                id_energy = get_energy_score(id_logits)
                ood_energy = get_energy_score(ood_logits)
                loss_energy = (torch.pow(torch.relu(id_energy - m_in), 2).mean() +
                               torch.pow(torch.relu(m_out - ood_energy), 2).mean())
                loss = loss_ce + Config.lambda_energy * loss_energy

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        auroc = run_ood_evaluation(model, val_loader, ood_val_loader)
        if auroc > best_auroc:
            best_auroc = auroc

    return best_auroc

def objective(trial):
    lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
    m_in = trial.suggest_float("m_in", -27.0, -8.0)
    m_out = trial.suggest_float("m_out", -7.0, -1.0)

    print(f"\n--- Starting trial {trial.number} ---")
    print(f"  Params: lr={lr:.6f}, m_in={m_in:.2f}, m_out={m_out:.2f}")

    model_search = create_model()
    model_std_path = os.path.join(Config.model_save_path, f"{Config.model_name}_standard.pth")
    model_search.load_state_dict(torch.load(model_std_path, map_location=Config.device))

    epochs_per_trial = 3
    best_auroc = run_energy_tuning_trial(
        model=model_search,
        id_loader=id_train_loader,
        ood_loader=ood_train_loader_energy,
        val_loader=id_val_loader_subset,
        ood_val_loader=ood_val_loader_subset,
        epochs=epochs_per_trial,
        lr=lr,
        m_in=m_in,
        m_out=m_out
    )

    print(f"--- Trial {trial.number} finished --- AUROC: {best_auroc:.4f}")
    return best_auroc

# GReg
def get_gradient_norm(model, inputs):
    inputs.requires_grad_(True)
    logits = model(inputs)
    energy_score = get_energy_score(logits)

    grad_output = torch.ones_like(energy_score)
    gradient = torch.autograd.grad(
        outputs=energy_score,
        inputs=inputs,
        grad_outputs=grad_output,
        create_graph=True,
        retain_graph=True
    )[0]

    grad_norm = torch.linalg.norm(gradient.view(gradient.size(0), -1), dim=1)
    inputs.requires_grad_(False)
    return grad_norm

# CORES
# consider both logits and internal feature maps
# IDs produce stronger and more frequent responses for important kernels than OODs
# => internal filters fire strongly and frequently for ID features => high RM and RF
class CORESScorer:
    def __init__(self, model, k_percent=0.2):
        self.model = model
        self.k = k_percent          # % kernels to select
        self.feature_maps = {}      # store output of the layer we are interested in
        self.hooks = []             # store listeners

        # get key layers
        if isinstance(model, MobileNetV3):
            self.final_conv_name = 'final_conv'
            self.fc_layer = model.head[1]
        elif isinstance(model, RepNeXt):
            self.final_conv_name = 'layer3'
            self.fc_layer = model.linear

        # listen to the final conv layer
        self._register_hook(self.final_conv_name)

    # find the layer
    def _register_hook(self, layer_name):
        for name, module in self.model.named_modules():
            if name == layer_name:
                target_layer = module
                break
        else:
            raise NameError(f"Layer {layer_name} not found")

        handle = target_layer.register_forward_hook(self._hook_fn(layer_name))
        self.hooks.append(handle)

    # listener fn
    def _hook_fn(self, layer_name):
        def hook(module, input, output):
            self.feature_maps[layer_name] = output
        return hook

    # sample-relevant kernel selection
    def _get_selected_kernels(self, logits):
        # find the most and least likely class predictions
        # take logits and find class index with highest and lowest score for each img in batch
        c_max = torch.argmax(logits, dim=1)
        c_min = torch.argmin(logits, dim=1)

        if self.fc_layer.weight.dim() == 4:
            fc_weights = self.fc_layer.weight.data.squeeze()
        else:
            fc_weights = self.fc_layer.weight.data

        num_kernels = self.feature_maps[self.final_conv_name].shape[1]
        k_num = max(1, int(num_kernels * self.k))

        selected_indices = []
        for i in range(logits.shape[0]):
            out_channels = fc_weights.shape[0]
            weights_max = fc_weights[c_max[i] % out_channels]
            weights_min = fc_weights[c_min[i] % out_channels]

            # backtracking
            # find k_num channels (kernels) with strongest weights (most influence) for both c_max and c_min predictions
            _, top_indices = torch.topk(weights_max, k_num)
            _, bot_indices = torch.topk(weights_min, k_num)

            # combine them to get sample-relevant kernels for this image
            indices = torch.cat([top_indices, bot_indices])
            selected_indices.append(torch.unique(indices))

        return selected_indices

    def _calculate_score(self, feature_map_batch, selected_indices_batch):
        batch_scores = []
        for i in range(feature_map_batch.shape[0]):
            feature_map = feature_map_batch[i]
            selected_indices = selected_indices_batch[i]

            # keep only selected kernels
            selected_features = feature_map[selected_indices]

            # find peak positive and negative response for each kernel
            max_responses = torch.amax(selected_features, dim=(1, 2))
            min_responses = torch.amin(selected_features, dim=(1, 2))

            # score components computation
            rm_pos = torch.mean(max_responses.clamp(min=1e-6))                  # response magnitude positive: avg of highest values
            rm_neg = torch.mean((-min_responses).clamp(min=1e-6))               # response magnitude negative: avg of negated lowest values
            rf_pos = torch.mean((max_responses > 0).float()).clamp(min=1e-6)    # response frequency positive: fraction of selected kernels that had a positive peak response
            rf_neg = torch.mean((min_responses < 0).float()).clamp(min=1e-6)    # response frequency negative: fraction of selected kernels that had a negative peak response

            # combine them to get the final score (sum of logs = product)
            score = torch.log(rm_pos) + torch.log(rm_neg) + torch.log(rf_pos) + torch.log(rf_neg)
            batch_scores.append(score.item())

        return np.array(batch_scores)

    def __call__(self, x):
        self.feature_maps.clear()
        logits = self.model(x)

        kernel_indices = self._get_selected_kernels(logits)
        final_conv_maps = self.feature_maps[self.final_conv_name]

        scores = self._calculate_score(final_conv_maps, kernel_indices)
        return scores

    def close(self):
        for handle in self.hooks:
            handle.remove()

# Networks

## RepNeXt

In [5]:
class RepNeXtBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv_3x3 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv_1x1 = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.act = nn.ReLU()

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
    def forward(self, x):
        out = self.act(self.bn1(self.conv_3x3(x)))
        out = self.bn2(self.conv_1x1(out))
        out += self.shortcut(x)
        return self.act(out)

class RepNeXt(nn.Module):
    def __init__(self, num_blocks, num_classes=101):
        super().__init__()
        self.in_planes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.act = nn.ReLU()
        self.layer1 = self._make_layer(64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(256, num_blocks[2], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(256, num_classes)

    def _make_layer(self, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for s in strides:
            layers.append(RepNeXtBlock(self.in_planes, planes, s))
            self.in_planes = planes
        return nn.Sequential(*layers)

    def get_features(self, x):
        out = self.act(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.avgpool(out)
        return out.view(out.size(0), -1)

    def forward(self, x):
        features = self.get_features(x)
        logits = self.linear(features)
        return logits, features

## MobileNetV3

In [5]:
class SqueezeExcite(nn.Module):
    def __init__(self, in_channels, reduced_dim):
        super().__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, reduced_dim, 1),
            nn.SiLU(),
            nn.Conv2d(reduced_dim, in_channels, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return x * self.se(x)

class InvertedResidual(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, expansion_factor, use_se=True):
        super().__init__()
        self.stride = stride
        hidden_dim = in_channels * expansion_factor
        self.use_res_connect = self.stride == 1 and in_channels == out_channels

        layers = []
        # expansion phase with 1x1 pointwise conv
        if expansion_factor != 1:
            layers.append(nn.Conv2d(in_channels, hidden_dim, 1, 1, 0, bias=False))
            layers.append(nn.BatchNorm2d(hidden_dim))
            layers.append(nn.SiLU())

        # depthwise conv
        layers.extend([
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size, stride, kernel_size//2, groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.SiLU(),
        ])

        # squeeze-and-excite layer
        if use_se:
            layers.append(SqueezeExcite(hidden_dim, in_channels // 4))

        # projection phase with 1x1 pointwise conv
        layers.append(nn.Conv2d(hidden_dim, out_channels, 1, 1, 0, bias=False))
        layers.append(nn.BatchNorm2d(out_channels))

        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)

class MobileNetV3(nn.Module):
    def __init__(self, num_classes=101):
        super().__init__()
        config = [
            # expansion, out_channels, num_repeats, kernel_size, stride, use_se
            [1, 16, 1, 3, 1, True],
            [4, 24, 2, 3, 2, False],
            [3, 40, 3, 5, 2, True],
            [6, 80, 3, 3, 2, False],
            [6, 112, 2, 3, 1, True],
            [6, 160, 3, 5, 2, True],
        ]

        self.stem = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.SiLU(),
        )

        in_channels = 16
        blocks = []
        for expansion, out_channels, num_repeats, kernel_size, stride, use_se in config:
            for i in range(num_repeats):
                block_stride = stride if i == 0 else 1
                blocks.append(InvertedResidual(in_channels, out_channels, kernel_size, block_stride, expansion, use_se))
                in_channels = out_channels

        self.blocks = nn.Sequential(*blocks)

        self.final_conv = nn.Sequential(
            nn.Conv2d(in_channels, 960, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(960),
            nn.SiLU()
        )

        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(960, 1280, 1),
            nn.SiLU(),
            nn.Dropout(0.5),
            nn.Conv2d(1280, num_classes, 1),
        )

    def forward(self, x, return_features=False):
        x = self.stem(x)
        features = self.blocks(x)
        x = self.final_conv(features)
        x = self.head(x)
        logits = x.view(x.size(0), -1)

        if return_features:
            return logits, features
        return logits

## Model creation

In [6]:
def create_model():
    if Config.model_name == 'resnet50':
        print(f"Creating ResNet50 model {f'(pretrained)' if Config.use_pretrained else None}")
        weights = models.ResNet50_Weights.IMAGENET1K_V2 if Config.use_pretrained else None
        model = models.resnet50(weights=weights)
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, Config.num_classes)
    elif Config.model_name == 'mobilenetv3':
        print("Creating MobileNetV3 model")
        model = MobileNetV3(num_classes=Config.num_classes)
    elif Config.model_name == 'repnext':
        print('Creating RepNeXt model')
        model = RepNeXt(num_blocks=[2, 2, 2], num_classes=Config.num_classes)
    else:
        raise ValueError(f"Unknown model {Config.model_name}; choose among resnet50, mobilnetv3, repnext")
    return model.to(Config.device)

# Train


In [7]:
def train(model, train_loader, val_loader, epochs, lr, save_path):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    best_acc = 0.0

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Training]")
        for inputs, labels in pbar:
            inputs, labels = inputs.to(Config.device), labels.to(Config.device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            pbar.set_postfix(loss=f'{running_loss/len(pbar):.4f}')

        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} [Validation]"):
                inputs, labels = inputs.to(Config.device), labels.to(Config.device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        acc = 100 * correct / total
        print(f"Epoch {epoch+1} | Validation accuracy: {acc:.2f}% | LR: {scheduler.get_last_lr()[0]:.6f}")

        if acc > best_acc:
            best_acc = acc
            print(f"New best accuracy, saving model to {save_path}")
            torch.save(model.state_dict(), save_path)

        scheduler.step()


def train_energy(model, id_loader, ood_loader, val_loader, ood_val_loader, epochs, lr, save_path, use_grad_reg=False):
    criterion_ce = nn.CrossEntropyLoss()
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    scaler = torch.amp.GradScaler('cuda')
    best_auroc = 0.0

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0

        ood_iter = iter(ood_loader)
        pbar = tqdm(id_loader, desc=f"Epoch {epoch+1}/{epochs} [Energy training]")
        for id_inputs, id_labels in pbar:
            try:
                ood_inputs, _ = next(ood_iter)
            except StopIteration:
                ood_iter = iter(ood_loader)
                ood_inputs, _ = next(ood_iter)

            id_inputs, id_labels = id_inputs.to(Config.device), id_labels.to(Config.device)
            ood_inputs = ood_inputs.to(Config.device)

            optimizer.zero_grad()
            with torch.amp.autocast(device_type='cuda'):
                id_logits = model(id_inputs)
                ood_logits = model(ood_inputs)

                loss_ce = criterion_ce(id_logits, id_labels)
                id_energy = get_energy_score(id_logits)
                ood_energy = get_energy_score(ood_logits)

                loss_energy = (torch.pow(torch.relu(id_energy - Config.m_in), 2).mean() +
                               torch.pow(torch.relu(Config.m_out - ood_energy), 2).mean())

                loss = loss_ce + Config.lambda_energy * loss_energy

                if use_grad_reg:
                    loss_grad_id = torch.tensor(0.0, device=Config.device)
                    loss_grad_ood = torch.tensor(0.0, device=Config.device)

                    id_grad_norm = get_gradient_norm(model, id_inputs)
                    ood_grad_norm = get_gradient_norm(model, ood_inputs)

                    # apply loss only to well-behaved samples
                    id_mask = id_energy < Config.m_in
                    if id_mask.any():
                        loss_grad_id = (id_grad_norm[id_mask]**2).mean()
                    ood_mask = ood_energy > Config.m_out
                    if ood_mask.any():
                        loss_grad_ood = (ood_grad_norm[ood_mask]**2).mean()
                    loss_grad = loss_grad_id + loss_grad_ood

                    loss += Config.lambda_grad * loss_grad

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()
            postfix_dict = {
                'loss': f'{running_loss/len(pbar):.4f}',
                'ce': f'{loss_ce.item():.4f}',
                'en': f'{loss_energy.item():.4f}'
            }
            if use_grad_reg:
                postfix_dict['grad'] = f'{loss_grad.item():.4f}'
            pbar.set_postfix(postfix_dict)

        # validation on OOD metrics
        print(f"\n--- Evaluating after epoch {epoch+1} ---")
        plot_path_epoch = os.path.join(Config.plot_save_path, f"energy_dist_epoch_{epoch+1}.png")
        val_acc, auroc, _, _ = evaluate_model(model, val_loader, ood_val_loader, "Energy validation", plot_save_path=plot_path_epoch)
        print(f"--- Current val acc: {val_acc:.2f}%, AUROC: {auroc:.4f} ---")
        if auroc > best_auroc:
            best_auroc = auroc
            print(f"New best AUROC, saving model to {save_path}")
            torch.save(model.state_dict(), save_path)

        scheduler.step()

# Evaluation

In [8]:
def evaluate_model_old(model, id_test_loader, ood_test_loader, eval_name="Evaluation", plot_save_path=None, score_type='energy'):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad(), torch.amp.autocast(device_type='cuda'):
        for inputs, labels in tqdm(id_test_loader, desc=f"{eval_name} [Accuracy]"):
            inputs, labels = inputs.to(Config.device), labels.to(Config.device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    acc = 100 * correct / total
    print(f"\n{eval_name} - Classification accuracy on ID test set: {acc:.2f}%")

    id_scores, ood_scores = [], []
    with torch.no_grad(), torch.amp.autocast(device_type='cuda'):
        for inputs, _ in tqdm(id_test_loader, desc=f"{eval_name} [OOD scores - ID]"):
            inputs = inputs.to(Config.device)
            logits = model(inputs)
            id_scores.extend(get_energy_score(logits).cpu().numpy())

        for inputs, _ in tqdm(ood_test_loader, desc=f"{eval_name} [OOD scores - OOD]"):
            inputs = inputs.to(Config.device)
            logits = model(inputs)
            ood_scores.extend(get_energy_score(logits).cpu().numpy())

    id_scores = np.array(id_scores)
    ood_scores = np.array(ood_scores)
    auroc, aupr, fpr_at_95_tpr = calculate_ood_metrics(id_scores, ood_scores)

    print(f"{eval_name} - OOD detection performance:")
    print(f"  AUROC: {auroc:.4f}")
    print(f"  AUPR: {aupr:.4f}")
    print(f"  FPR @ 95% TPR: {fpr_at_95_tpr:.4f}")

    if plot_save_path:
        plot_distributions(id_scores, ood_scores, Config.ood_dataset, f"Energy score distribution ({eval_name})", plot_save_path)

    return acc, auroc, aupr, fpr_at_95_tpr

In [10]:
def evaluate_model(model, id_test_loader, ood_test_loader, eval_name="Evaluation", plot_save_path=None, score_type='energy'):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad(), torch.amp.autocast(device_type='cuda'):
        for inputs, labels in tqdm(id_test_loader, desc=f"{eval_name} [Accuracy]"):
            inputs, labels = inputs.to(Config.device), labels.to(Config.device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    acc = 100 * correct / total
    print(f"\n{eval_name} - Classification accuracy on ID test set: {acc:.2f}%")

    model.eval()
    correct, total = 0, 0
    with torch.no_grad(), torch.amp.autocast(device_type='cuda'):
        for inputs, labels in tqdm(id_test_loader, desc=f"{eval_name} [Accuracy]"):
            inputs, labels = inputs.to(Config.device), labels.to(Config.device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    acc = 100 * correct / total
    print(f"\n{eval_name} - Classification accuracy on ID test set: {acc:.2f}%")


    id_scores, ood_scores = [], []

    if score_type == 'cores':
        scorer = CORESScorer(model)

    with torch.no_grad(), torch.amp.autocast(device_type='cuda'):
        for loader, score_list in [(id_test_loader, id_scores), (ood_test_loader, ood_scores)]:
            desc = f"{eval_name} [Scores - {'ID' if loader==id_test_loader else 'OOD'}]"
            for inputs, _ in tqdm(loader, desc=desc):
                inputs = inputs.to(Config.device)

                if score_type == 'energy':
                    logits = model(inputs)
                    scores = get_energy_score(logits).cpu().numpy()
                elif score_type == 'cores':
                    scores = scorer(inputs)
                else:
                    raise ValueError("score_type must be 'energy' or 'cores'")

                score_list.extend(scores)

    if score_type == 'cores':
        scorer.close()

    id_scores, ood_scores = np.array(id_scores), np.array(ood_scores)

    if score_type == 'cores':
        id_scores = -id_scores
        ood_scores = -ood_scores

    auroc, aupr, fpr95 = calculate_ood_metrics(id_scores, ood_scores)

    print(f"{eval_name} - OOD detection performance ({score_type.upper()}):")
    print(f"  AUROC: {auroc:.4f}")
    print(f"  AUPR: {aupr:.4f}")
    print(f"  FPR @ 95% TPR: {fpr95:.4f}")

    if plot_save_path:
        plot_title = f"{score_type.capitalize()} Score Distribution ({eval_name})"
        plot_distributions(id_scores, ood_scores, Config.ood_dataset, plot_title, plot_save_path)

    return acc, auroc, aupr, fpr95

# Data

In [11]:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

transform_train = transforms.Compose([
    transforms.RandomResizedCrop(Config.image_size),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    normalize,
])

transform_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(Config.image_size),
    transforms.ToTensor(),
    normalize,
])

food_train_dataset = torchvision.datasets.Food101(root=Config.data_root, split='train', download=True, transform=transform_train)
food_test_dataset = torchvision.datasets.Food101(root=Config.data_root, split='test', download=True, transform=transform_test)

id_train_loader = DataLoader(food_train_dataset, batch_size=Config.batch_size, shuffle=True, num_workers=Config.num_workers, pin_memory=True)
id_test_loader = DataLoader(food_test_dataset, batch_size=Config.batch_size, shuffle=False, num_workers=Config.num_workers, pin_memory=True)

ood_train_loader_energy = get_ood_loader(Config.ood_dataset, transform_train, Config, 'train')
ood_test_loader = get_ood_loader(Config.ood_dataset, transform_test, Config, 'test')

Loading OOD dataset: DTD (train)
DTD (train) loader ready.
Loading OOD dataset: DTD (test)
DTD (test) loader ready.


# Main

## Training

In [None]:
model_std = create_model()
model_std_path = os.path.join(Config.model_save_path, f"{Config.model_name}_standard.pth")

if os.path.exists(model_std_path):
    print(f"Loading pre-trained standard model from {model_std_path}")
    model_std.load_state_dict(torch.load(model_std_path, map_location=Config.device))
else:
    print("No pre-trained model found, starting training...")
    train(
        model=model_std,
        train_loader=id_train_loader,
        val_loader=id_test_loader,
        epochs=Config.epochs_std,
        lr=Config.learning_rate_std,
        save_path=model_std_path
    )
    model_std.load_state_dict(torch.load(model_std_path, map_location=Config.device))

print("\nEvaluating baseline model")
acc, auroc, aupr, fpr95 = evaluate_model(
    model=model_std,
    id_test_loader=id_test_loader,
    ood_test_loader=ood_test_loader,
    eval_name="Baseline model",
    plot_save_path=os.path.join(Config.plot_save_path, "baseline_distribution.png")
)

## Energy-based fine-tuning

In [None]:
model_std_path = os.path.join(Config.model_save_path, f"{Config.model_name}_standard.pth")
model_energy_path = os.path.join(Config.model_save_path, f"{Config.model_name}_energy.pth")

model_energy = create_model()
print(f"Loading weights from baseline model: {model_std_path}")
model_energy.load_state_dict(torch.load(model_std_path, map_location=Config.device))

# freeze all layers
for param in model_energy.parameters():
    param.requires_grad = False

# unfreeze the head
for param in model_energy.head.parameters():
    param.requires_grad = True

print("Starting energy-based fine-tuning against {Config.ood_dataset}")
train_energy(
    model=model_energy,
    id_loader=id_train_loader,
    ood_loader=ood_train_loader_energy,
    val_loader=id_test_loader,
    ood_val_loader=ood_test_loader,
    epochs=Config.epochs_energy,
    lr=Config.learning_rate_energy,
    save_path=model_energy_path,
    use_grad_reg=Config.use_grad_reg
)

print(f"Loading best energy model from: {model_energy_path}")
model_energy.load_state_dict(torch.load(model_energy_path, map_location=Config.device))

print("\nEvaluating final energy model")
evaluate_model(
    model=model_energy,
    id_test_loader=id_test_loader,
    ood_test_loader=ood_test_loader,
    eval_name="Energy model",
    plot_save_path=os.path.join(Config.plot_save_path, "energy_tuned_distribution.png")
)

## Results

In [19]:
model_path = os.path.join(Config.model_save_path, f"{Config.model_name}_standard.pth")
#model_path = os.path.join(Config.model_save_path, f"{Config.model_name}_energy.pth")

print(f"--- Evaluating model: {model_path} ---")

model_to_eval = create_model()
model_to_eval.load_state_dict(torch.load(model_path, map_location=Config.device))
model_to_eval.eval()

# eval acc
overall_accuracy = 0
correct, total = 0, 0
with torch.no_grad():
    for inputs, labels in tqdm(id_test_loader, desc="Calculating accuracy"):
        outputs = model_to_eval(inputs.to(Config.device))
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels.to(Config.device)).sum().item()
overall_accuracy = 100 * correct / total

# compute OOD metrics
results = []
model_short_name = os.path.basename(model_path).replace('.pth', '')

id_scores = []
with torch.no_grad():
    for inputs, _ in tqdm(id_test_loader, desc="Calculating ID scores"):
        id_scores.extend(get_energy_score(model_to_eval(inputs.to(Config.device))).cpu().numpy())
id_scores = np.array(id_scores)

for ood_name in Config.ood_datasets:
    ood_loader_eval = get_ood_loader(ood_name, transform_test, Config, 'test')

    ood_scores = []
    with torch.no_grad():
        for inputs, _ in tqdm(ood_loader_eval, desc=f"OOD scores for {ood_name}"):
            ood_scores.extend(get_energy_score(model_to_eval(inputs.to(Config.device))).cpu().numpy())

    ood_scores = np.array(ood_scores)
    auroc, aupr, fpr95 = calculate_ood_metrics(id_scores, ood_scores)
    results.append({"OOD dataset": ood_name, "AUROC": auroc, "AUPR": aupr, "FPR@95TPR": fpr95})

    # plot
    plot_save_path = os.path.join(Config.plot_save_path, f"{model_short_name}_{ood_name}.png")
    plot_title = f"Energy distribution ({model_short_name}): Food-101 vs {ood_name}"
    plot_distributions(id_scores, ood_scores, ood_name, plot_title, plot_save_path)


# summary table
print(f"\n--- Performance summary: {model_path} ---")
print(f"Classification accuracy (Food-101): {overall_accuracy:.2f}%")
results_df = pd.DataFrame(results)
print(results_df.to_string(index=False, float_format="%.4f"))

--- Evaluating model: ./models/mobilenetv3_standard.pth ---
Creating MobileNetV3 model


Calculating accuracy: 100%|███████████████████████████████| 99/99 [00:27<00:00,  3.65it/s]
Calculating ID scores (once): 100%|███████████████████████| 99/99 [00:26<00:00,  3.69it/s]


Loading OOD dataset: SVHN (test)
SVHN (test) loader ready.


OOD scores for SVHN: 100%|██████████████████████████████| 102/102 [00:11<00:00,  8.63it/s]


Distribution plot saved to ./plots/mobilenetv3_standard_SVHN.png
Loading OOD dataset: CIFAR10 (test)
CIFAR10 (test) loader ready.


OOD scores for CIFAR10: 100%|█████████████████████████████| 40/40 [00:04<00:00,  8.77it/s]


Distribution plot saved to ./plots/mobilenetv3_standard_CIFAR10.png
Loading OOD dataset: FashionMNIST (test)
FashionMNIST (test) loader ready.


OOD scores for FashionMNIST: 100%|████████████████████████| 40/40 [00:05<00:00,  7.74it/s]


Distribution plot saved to ./plots/mobilenetv3_standard_FashionMNIST.png
Loading OOD dataset: Flowers102 (test)
Flowers102 (test) loader ready.


OOD scores for Flowers102: 100%|██████████████████████████| 25/25 [00:08<00:00,  3.08it/s]


Distribution plot saved to ./plots/mobilenetv3_standard_Flowers102.png
Loading OOD dataset: DTD (test)
DTD (test) loader ready.


OOD scores for DTD: 100%|███████████████████████████████████| 8/8 [00:02<00:00,  2.80it/s]


Distribution plot saved to ./plots/mobilenetv3_standard_DTD.png
Loading OOD dataset: FGVCAircraft (test)
FGVCAircraft (test) loader ready.


OOD scores for FGVCAircraft: 100%|████████████████████████| 14/14 [00:12<00:00,  1.16it/s]


Distribution plot saved to ./plots/mobilenetv3_standard_FGVCAircraft.png
Loading OOD dataset: OxfordIIITPet (test)
OxfordIIITPet (test) loader ready.


OOD scores for OxfordIIITPet: 100%|███████████████████████| 15/15 [00:05<00:00,  2.89it/s]


Distribution plot saved to ./plots/mobilenetv3_standard_OxfordIIITPet.png
Loading OOD dataset: EuroSAT (test)
EuroSAT (test) loader ready.


OOD scores for EuroSAT: 100%|███████████████████████████| 106/106 [00:12<00:00,  8.19it/s]


Distribution plot saved to ./plots/mobilenetv3_standard_EuroSAT.png

--- Performance summary: ./models/mobilenetv3_standard.pth ---
Classification accuracy (Food-101): 67.56%
  OOD dataset  AUROC   AUPR  FPR@95TPR
         SVHN 0.9992 0.9992     0.0020
      CIFAR10 0.9948 0.9978     0.0217
 FashionMNIST 1.0000 1.0000     0.0000
   Flowers102 0.8825 0.9673     0.5097
          DTD 0.9353 0.9935     0.2532
 FGVCAircraft 0.9961 0.9995     0.0177
OxfordIIITPet 0.9239 0.9878     0.4064
      EuroSAT 0.9997 0.9997     0.0005


In [24]:
#model_path = os.path.join(Config.model_save_path, f"{Config.model_name}_standard.pth")
model_path = os.path.join(Config.model_save_path, f"{Config.model_name}_energy.pth")

print(f"--- Evaluating model: {os.path.basename(model_path)} with CORES scores ---")
model_to_eval = create_model()
model_to_eval.load_state_dict(torch.load(model_path, map_location=Config.device))
model_to_eval.eval()
model_short_name = os.path.basename(model_path).replace('.pth', '')

correct, total = 0, 0
with torch.no_grad():
    for inputs, labels in tqdm(id_test_loader, desc="Calculating accuracy"):
        outputs = model_to_eval(inputs.to(Config.device))
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels.to(Config.device)).sum().item()
overall_accuracy = 100 * correct / total

id_scores = []
scorer = CORESScorer(model_to_eval)
with torch.no_grad():
    for inputs, _ in tqdm(id_test_loader, desc="Calculating ID scores (CORES)"):
        inputs = inputs.to(Config.device)
        id_scores.extend(scorer(inputs))
id_scores = np.array(id_scores)

results = []
for ood_name in Config.ood_datasets:
    ood_loader_eval = get_ood_loader(ood_name, transform_test, Config, 'test')
    ood_scores = []
    with torch.no_grad():
        for inputs, _ in tqdm(ood_loader_eval, desc=f"OOD scores for {ood_name}"):
            inputs = inputs.to(Config.device)
            ood_scores.extend(scorer(inputs))
    ood_scores = np.array(ood_scores)

    metric_id_scores = -id_scores
    metric_ood_scores = -ood_scores
    auroc, aupr, fpr95 = calculate_ood_metrics(metric_id_scores, metric_ood_scores)
    results.append({"OOD dataset": ood_name, "AUROC": auroc, "AUPR": aupr, "FPR@95TPR": fpr95})

    plot_save_path = os.path.join(Config.plot_save_path, f"{model_short_name}_{ood_name}_cores.png")
    plot_title = f"CORES distribution ({model_short_name}): Food-101 vs {ood_name}"
    plot_distributions(metric_id_scores, metric_ood_scores, ood_name, plot_title, plot_save_path)

scorer.close()

print(f"\n--- Performance summary: {os.path.basename(model_path)} (CORES) ---")
print(f"Classification accuracy (Food-101): {overall_accuracy:.2f}%")
results_df = pd.DataFrame(results)
print(results_df.to_string(index=False, float_format="%.4f"))

--- Evaluating model: mobilenetv3_energy.pth with CORES scores ---
Creating MobileNetV3 model


Calculating accuracy: 100%|███████████████████████████████| 99/99 [00:26<00:00,  3.78it/s]
Calculating ID scores (CORES): 100%|██████████████████████| 99/99 [00:25<00:00,  3.82it/s]


Loading OOD dataset: SVHN (test)
SVHN (test) loader ready.


OOD scores for SVHN: 100%|██████████████████████████████| 102/102 [00:16<00:00,  6.14it/s]


Distribution plot saved to ./plots/mobilenetv3_energy_SVHN_cores.png
Loading OOD dataset: CIFAR10 (test)
CIFAR10 (test) loader ready.


OOD scores for CIFAR10: 100%|█████████████████████████████| 40/40 [00:07<00:00,  5.62it/s]


Distribution plot saved to ./plots/mobilenetv3_energy_CIFAR10_cores.png
Loading OOD dataset: FashionMNIST (test)
FashionMNIST (test) loader ready.


OOD scores for FashionMNIST: 100%|████████████████████████| 40/40 [00:06<00:00,  5.75it/s]


Distribution plot saved to ./plots/mobilenetv3_energy_FashionMNIST_cores.png
Loading OOD dataset: Flowers102 (test)
Flowers102 (test) loader ready.


OOD scores for Flowers102: 100%|██████████████████████████| 25/25 [00:08<00:00,  2.93it/s]


Distribution plot saved to ./plots/mobilenetv3_energy_Flowers102_cores.png
Loading OOD dataset: DTD (test)
DTD (test) loader ready.


OOD scores for DTD: 100%|███████████████████████████████████| 8/8 [00:03<00:00,  2.57it/s]


Distribution plot saved to ./plots/mobilenetv3_energy_DTD_cores.png
Loading OOD dataset: FGVCAircraft (test)
FGVCAircraft (test) loader ready.


OOD scores for FGVCAircraft: 100%|████████████████████████| 14/14 [00:12<00:00,  1.14it/s]


Distribution plot saved to ./plots/mobilenetv3_energy_FGVCAircraft_cores.png
Loading OOD dataset: OxfordIIITPet (test)
OxfordIIITPet (test) loader ready.


OOD scores for OxfordIIITPet: 100%|███████████████████████| 15/15 [00:05<00:00,  2.85it/s]


Distribution plot saved to ./plots/mobilenetv3_energy_OxfordIIITPet_cores.png
Loading OOD dataset: EuroSAT (test)
EuroSAT (test) loader ready.


OOD scores for EuroSAT: 100%|███████████████████████████| 106/106 [00:17<00:00,  6.12it/s]


Distribution plot saved to ./plots/mobilenetv3_energy_EuroSAT_cores.png

--- Performance summary: mobilenetv3_energy.pth (CORES) ---
Classification accuracy (Food-101): 76.00%
  OOD dataset  AUROC   AUPR  FPR@95TPR
         SVHN 0.9998 0.9998     0.0001
      CIFAR10 0.9949 0.9981     0.0123
 FashionMNIST 0.9999 1.0000     0.0003
   Flowers102 0.9204 0.9793     0.4329
          DTD 0.9836 0.9986     0.0803
 FGVCAircraft 0.9922 0.9989     0.0324
OxfordIIITPet 0.8052 0.9657     0.7296
      EuroSAT 0.9994 0.9994     0.0010


## Hyperparameter search

In [None]:
print("Creating validation subsets for Optuna search...")
ood_search_dataset = get_ood_loader(Config.ood_dataset, transform_test, Config, 'test').dataset

val_subset_fraction = 0.2
id_test_indices = random.sample(range(len(food_test_dataset)), int(len(food_test_dataset) * val_subset_fraction))
ood_test_indices = random.sample(range(len(ood_search_dataset)), int(len(ood_search_dataset) * val_subset_fraction))

id_val_subset = Subset(food_test_dataset, id_test_indices)
ood_val_subset = Subset(ood_search_dataset, ood_test_indices)

id_val_loader_subset = DataLoader(id_val_subset, batch_size=Config.batch_size, num_workers=Config.num_workers)
ood_val_loader_subset = DataLoader(ood_val_subset, batch_size=Config.batch_size, num_workers=Config.num_workers)

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=50)

print(f"Best trial number: {study.best_trial.number}")
print(f"Best AUROC: {study.best_value:.4f}")
print("Best hyperparameters found:", study.best_params)

## Debug

In [21]:
#model_path = os.path.join(Config.model_save_path, f"{Config.model_name}_standard.pth")
model_path = os.path.join(Config.model_save_path, f"{Config.model_name}_energy.pth")

model_to_diag = create_model()
model_to_diag.load_state_dict(torch.load(model_path, map_location=Config.device))
model_to_diag.eval()

id_scores, ood_scores = [], []
with torch.no_grad():
    for inputs, _ in tqdm(id_test_loader, desc="Calculating ID scores"):
        id_scores.extend(get_energy_score(model_to_diag(inputs.to(Config.device))).cpu().numpy())

    for inputs, _ in tqdm(ood_test_loader, desc=f"Calculating OOD scores for {Config.ood_dataset}"):
        ood_scores.extend(get_energy_score(model_to_diag(inputs.to(Config.device))).cpu().numpy())

print(f"\nAggregate results for model: {model_path}")
print(f"Overall mean ID energy: {np.mean(id_scores):.2f} (+/-: {np.std(id_scores):.2f})")
print(f"Overall mean OOD energy: {np.mean(ood_scores):.2f} (+/-: {np.std(ood_scores):.2f})")

Creating MobileNetV3 model


Calculating ID scores: 100%|██████████████████████████████| 99/99 [00:26<00:00,  3.78it/s]
Calculating OOD scores for DTD: 100%|███████████████████████| 8/8 [00:02<00:00,  2.81it/s]


Aggregate results for model: ./models/mobilenetv3_energy.pth
Overall mean ID energy: -12.27 (+/-: 4.36)
Overall mean OOD energy: -6.38 (+/-: 2.01)



