In [None]:
"""
OOD Detection Notebook Template
Convert to .ipynb using: jupytext --to notebook 02_ood_detection.py
"""

# Q2: Out-of-Distribution Detection

Evaluate 6 OOD scoring methods across training epochs:
- Output-based: MSP, MaxLogit, Energy
- Distance-based: Mahalanobis  
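- Feature-based: ViM, NECO (NECO is evaluated only on checkpoints flagged in the config's `tpt_mask`, i.e. those in the terminal phase of training, TPT)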

In [None]:
# Mount Google Drive at /content/drive. The `google.colab` package is only
# importable inside a Google Colab runtime, so this cell is Colab-specific.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Clone the project repository only if it is not already present, then switch
# the working directory to it so that the `src.*` imports and the relative
# `configs/` path used by later cells resolve.
import os 
if not os.path.exists('/content/OOD-Detection-Project---CSC_5IA23'):
    !git clone https://github.com/DiegoFleury/OOD-Detection-Project---CSC_5IA23.git
%cd /content/OOD-Detection-Project---CSC_5IA23

In [None]:
import torch
import numpy as np
import yaml
import glob
import pickle
from tqdm import tqdm

from src.models import ResNet18
from src.data import get_cifar100_loaders, get_ood_loaders
from src.ood_scores import (
    MSPScorer, MaxLogitScorer, EnergyScorer,
    MahalanobisScorer, ViMScorer, NECOScorer
)
from src.utils.ood_metrics import compute_auroc, compute_fpr_at_tpr
from src.utils.visualization import plot_ood_scores

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Device: {device}")

In [None]:
# Read the experiment configuration from YAML and echo it back in block style
# so the settings used for this run are visible in the notebook output.
with open('configs/config.yaml', 'r') as config_file:
    config = yaml.safe_load(config_file)

config_dump = yaml.dump(config, default_flow_style=False)
print(config_dump)

In [None]:
# Keyword arguments shared by the ID and OOD loader factories.
shared_loader_kwargs = {
    'data_dir': config['data']['data_dir'],
    'batch_size': config['training']['batch_size'],
    'num_workers': config['data']['num_workers'],
}

# ID data (CIFAR-100 test): only the third returned loader is kept here.
_, _, id_test_loader = get_cifar100_loaders(**shared_loader_kwargs)

# OOD data (proportional sampling): a mapping from dataset name to loader.
ood_loaders = get_ood_loaders(
    ood_datasets=config['ood']['datasets'],
    sampling_ratio=config['ood']['sampling_ratio'],
    **shared_loader_kwargs,
)

# Report how many samples each loader exposes.
num_id_samples = len(id_test_loader.dataset)
print(f"ID test samples: {num_id_samples}")

print("OOD samples:")
for ood_name, ood_loader in ood_loaders.items():
    num_ood_samples = len(ood_loader.dataset)
    print(f"({ood_name}) : {num_ood_samples}")

In [None]:
def initialize_scorers(model):
    """Build one instance of every OOD scorer for the given model.

    Args:
        model: The network handed to each scorer's constructor, together with
            the notebook-level ``device``.

    Returns:
        dict: Scorer name -> scorer instance, in the order MSP, MaxLogit,
        Energy, Mahalanobis, ViM, NECO.
    """
    scorer_classes = (
        ('MSP', MSPScorer),
        ('MaxLogit', MaxLogitScorer),
        ('Energy', EnergyScorer),
        ('Mahalanobis', MahalanobisScorer),
        ('ViM', ViMScorer),
        ('NECO', NECOScorer),
    )
    return {label: scorer_cls(model, device) for label, scorer_cls in scorer_classes}

In [None]:
# Extract epoch numbers
import os
import re

# Matches the "epoch<N>" tag used in checkpoint file names.
_EPOCH_TAG = re.compile(r'epoch(\d+)')


def get_epoch_num(path):
    """Return the epoch number encoded in a checkpoint file name.

    Only the final path component is searched, so an ``epoch<N>`` fragment in
    a parent directory name cannot be mistaken for the checkpoint's epoch.

    Args:
        path: Checkpoint path such as ``.../resnet18_cifar100_epoch10.pth``.

    Returns:
        int: The digits following the first ``epoch`` tag in the file name,
        or 0 when the file name carries no such tag.
    """
    match = _EPOCH_TAG.search(os.path.basename(path))
    return int(match.group(1)) if match else 0

# Discover checkpoints and order them by the epoch number in their file name
# (a plain lexicographic sort would place epoch100 before epoch20).
checkpoint_dir = config['paths']['checkpoints']
checkpoints = sorted(
    glob.glob(f"{checkpoint_dir}/resnet18_cifar100_epoch*.pth"),
    key=get_epoch_num,
)

checkpoint_epochs = [get_epoch_num(cp) for cp in checkpoints]
tpt_mask = config['ood']['tpt_mask']

# The evaluation loop indexes tpt_mask by checkpoint position. Fail fast with
# a clear message instead of an IndexError part-way through a long run.
if len(tpt_mask) < len(checkpoints):
    raise ValueError(
        f"config['ood']['tpt_mask'] has {len(tpt_mask)} entries but "
        f"{len(checkpoints)} checkpoints were found in {checkpoint_dir!r}; "
        "the mask needs one entry per checkpoint."
    )

# Results storage: results['scorers'][scorer][ood_dataset] holds one 'auroc'
# and one 'fpr95' list, each appended to once per evaluated checkpoint.
scorer_names = ['MSP', 'MaxLogit', 'Energy', 'Mahalanobis', 'ViM', 'NECO']
results = {
    'config': {
        'epochs': checkpoint_epochs,
        'ood_datasets': config['ood']['datasets'],
        'sampling_ratio': config['ood']['sampling_ratio'],
        'tpt_mask': tpt_mask
    },
    'scorers': {
        scorer_name: {
            # Fresh lists per (scorer, dataset) pair; nothing is shared.
            ood_dataset: {'auroc': [], 'fpr95': []}
            for ood_dataset in config['ood']['datasets']
        }
        for scorer_name in scorer_names
    }
}

print(f"\nEvaluating {len(checkpoints)} checkpoints...")
print(f"TPT mask: {tpt_mask}")

In [None]:
# The training loader is only used to fit the statistics-based scorers and
# does not depend on the checkpoint, so build it once instead of once per
# checkpoint.
train_loader, _, _ = get_cifar100_loaders(
    data_dir=config['data']['data_dir'],
    batch_size=config['training']['batch_size'],
    num_workers=config['data']['num_workers']
)

# First OOD dataset name; used to report a skipped NECO once per checkpoint.
first_ood_dataset = next(iter(ood_loaders), None)

for epoch_idx, (checkpoint_path, epoch) in enumerate(zip(checkpoints, checkpoint_epochs)):
    print(f"\n{'='*60}")
    print(f"Checkpoint: Epoch {epoch} ({epoch_idx+1}/{len(checkpoints)})")
    print('='*60)

    # NECO is fitted and evaluated only on checkpoints flagged in the TPT mask.
    neco_skipped = not tpt_mask[epoch_idx]

    # Load model weights for this checkpoint and switch to inference mode
    model = ResNet18(num_classes=config['model']['num_classes'])
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    # Initialize scorers (fresh instances bound to this checkpoint's model)
    scorers = initialize_scorers(model)

    # Fit statistics-based scorers (once per checkpoint)
    print("\nFitting statistics-based scorers...")
    for name in ['Mahalanobis', 'ViM', 'NECO']:
        if name == 'NECO' and neco_skipped:
            continue  # Skip NECO if not in TPT
        print(f"  Fitting {name}...")
        scorers[name].fit(train_loader, num_classes=config['model']['num_classes'])

    # Evaluate each scorer
    print("\nEvaluating scorers...")

    # Compute ID scores once (same for all OOD datasets)
    print("  Computing ID scores...")
    id_scores_dict = {}
    for scorer_name, scorer in scorers.items():
        if scorer_name == 'NECO' and neco_skipped:
            continue
        id_scores_dict[scorer_name] = scorer.score_loader(id_test_loader)

    # Evaluate against each OOD dataset separately
    for dataset_name, ood_loader_single in ood_loaders.items():
        print(f"\n  OOD Dataset: {dataset_name}")

        for scorer_name, scorer in scorers.items():
            # Skip NECO if not in TPT. Nothing is appended to its result
            # lists for this checkpoint, so they can be shorter than the
            # epoch list.
            if scorer_name == 'NECO' and neco_skipped:
                if dataset_name == first_ood_dataset:  # Print only once
                    print(f"    {scorer_name}: Skipped (not in TPT)")
                continue

            print(f"    {scorer_name}...", end=' ')

            # Get ID scores (already computed)
            id_scores = id_scores_dict[scorer_name]

            # Compute OOD scores for this dataset
            ood_scores = scorer.score_loader(ood_loader_single)

            # Compute metrics
            auroc = compute_auroc(id_scores, ood_scores)
            fpr95 = compute_fpr_at_tpr(id_scores, ood_scores, tpr_target=0.95)

            # Store results per dataset
            results['scorers'][scorer_name][dataset_name]['auroc'].append(auroc)
            results['scorers'][scorer_name][dataset_name]['fpr95'].append(fpr95)

            print(f"AUROC: {auroc:.3f}, FPR@95: {fpr95:.1f}%")

In [None]:
import os

# Make sure the output directory from the config exists.
output_dir = config['paths']['ood_detection']
os.makedirs(output_dir, exist_ok=True)

# Save pickle: the complete results dictionary, including per-epoch lists.
results_path = os.path.join(output_dir, 'ood_scores_results.pkl')
with open(results_path, 'wb') as pickle_file:
    pickle.dump(results, pickle_file)
print(f"\nResults saved: {results_path}")

# Save CSV summary: one row per (scorer, OOD dataset) pair that has at least
# one evaluated checkpoint, reporting the metrics of the last one.
import pandas as pd  # imported here, immediately before its first use

summary_data = [
    {
        'Scorer': scorer_name,
        'OOD Dataset': dataset_name,
        'Final AUROC': metrics['auroc'][-1],
        'Final FPR@95': metrics['fpr95'][-1],
        'Epochs Evaluated': len(metrics['auroc'])
    }
    for scorer_name, datasets_dict in results['scorers'].items()
    for dataset_name, metrics in datasets_dict.items()
    if len(metrics['auroc']) > 0
]

df = pd.DataFrame(summary_data)
csv_path = os.path.join(output_dir, 'ood_scores_summary.csv')
df.to_csv(csv_path, index=False)
print(f"\nCSV summary saved: {csv_path}")
print("\n", df)

# Also save aggregated summary (average across datasets)
banner = "="*60
print("\n" + banner)
print("Aggregated Results (Average across OOD datasets)")
print(banner)

agg_data = []
for scorer_name, datasets_dict in results['scorers'].items():
    # Keep only datasets on which this scorer was evaluated at least once.
    evaluated = [m for m in datasets_dict.values() if len(m['auroc']) > 0]
    if not evaluated:
        continue

    final_aurocs = [m['auroc'][-1] for m in evaluated]
    final_fprs = [m['fpr95'][-1] for m in evaluated]
    agg_data.append({
        'Scorer': scorer_name,
        'Avg AUROC': np.mean(final_aurocs),
        'Avg FPR@95': np.mean(final_fprs),
        'Std AUROC': np.std(final_aurocs),
        'Std FPR@95': np.std(final_fprs)
    })

df_agg = pd.DataFrame(agg_data)
agg_path = os.path.join(output_dir, 'ood_scores_aggregated.csv')
df_agg.to_csv(agg_path, index=False)
print(df_agg)

In [None]:
# `plot_ood_scores_per_dataset` is never imported by this notebook (only
# `plot_ood_scores` is), so calling it by bare name raises NameError on a
# fresh kernel. Resolve it from the visualization module instead and fall
# back to the imported plotter when the module does not provide it.
# NOTE(review): confirm which of the two names src.utils.visualization
# actually exports, and that both accept (results, save_dir=...).
import src.utils.visualization as visualization

plot_per_dataset = getattr(visualization, 'plot_ood_scores_per_dataset', plot_ood_scores)
plot_per_dataset(results, save_dir=output_dir)

print("\n" + "="*60)
print("OOD Detection Evaluation Complete!")
print("="*60)