[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/IvanNece/Detection-of-Anomalies-with-Localization/blob/main/notebooks/10_global_model_clean.ipynb)

# Global Model Training and Evaluation

This notebook implements the "Global Model" strategy: for each method (PatchCore and PaDiM), a single anomaly detection model is trained on **all 3 classes** (Hazelnut, Carpet, Zipper) at once, instead of one model per class.

The goal is to measure the performance degradation that arises when one model must capture a heterogeneous (multi-class) distribution of normal samples, as discussed in the UniAD paper (You et al., 2022).

In [None]:
# Imports
import os
import sys
import json
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.notebook import tqdm
from PIL import Image
from torchvision import transforms
from torch.utils.data import DataLoader

# Add src to path
sys.path.append('..')

from src.data.dataset import MVTecDataset
from src.models.patchcore import PatchCore
from src.models.padim_wrapper import PadimWrapper
from src.utils.paths import ProjectPaths
from src.utils.reproducibility import set_seed
from src.evaluation.evaluator import Evaluator, evaluate_model_on_dataloader

In [None]:
# Reproducibility and hardware setup
set_seed(42)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
paths = ProjectPaths()

In [None]:
# Classes whose training data is merged into the global model
CLASSES = ['hazelnut', 'carpet', 'zipper']

# Per-class clean train/val/test splits (JSON keyed by class name)
splits_path = paths.data_processed / 'clean_splits.json'
with open(splits_path) as splits_file:
    splits = json.load(splits_file)

print("Splits loaded for classes:", list(splits))

## 1. Prepare Global Training Data

We merge the training sets of all three classes into a single Global Training Set.

In [None]:
# Concatenate the per-class training splits into one heterogeneous training set.
# Training samples are expected to be normal only, so there are no ground-truth
# masks for them: every entry gets a None mask.
global_train_images, global_train_labels, global_train_masks = [], [], []

for class_name in CLASSES:
    train_split = splits[class_name]['train']
    global_train_images += train_split['images']
    global_train_labels += train_split['labels']
    global_train_masks += [None] * len(train_split['images'])

print(f"Global Training Set Size: {len(global_train_images)}")

In [None]:
# Preprocessing shared by training and evaluation: match the shorter side to
# 256 px, center-crop to 224x224, convert to a tensor and apply ImageNet
# mean/std normalization (the statistics of the pretrained backbones).
preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(preprocessing_steps)

# Global dataset built from the merged training lists of all classes
global_train_dataset = MVTecDataset(
    phase='train',
    transform=transform,
    images=global_train_images,
    labels=global_train_labels,
    masks=global_train_masks,
)

# Sample order is irrelevant for feature extraction / statistics, so no shuffling
global_train_loader = DataLoader(
    global_train_dataset, batch_size=32, shuffle=False, num_workers=2
)

## 2. Train Global PatchCore

We train a single PatchCore model on the combined dataset.

In [None]:
# One PatchCore model fitted on the merged three-class training set
print("Initializing Global PatchCore...")
patchcore = PatchCore(
    device=device,
    n_neighbors=9,
    coreset_ratio=0.01,  # coreset subsampling keeps ~1% of all extracted patches
    patch_size=3,
    backbone_layers=['layer2', 'layer3'],
)

print("Training Global PatchCore...")
patchcore.fit(global_train_loader)

print("Saving Global PatchCore...")
patchcore.save(paths.models, domain='clean', class_name='global')

## 3. Train Global PaDiM

We train a single PaDiM model on the combined dataset.

In [None]:
# One PaDiM model fitted on the merged three-class training set
print("Initializing Global PaDiM...")
padim = PadimWrapper(
    device=device,
    n_features=100,  # presumably the number of (reduced) feature channels — confirm in PadimWrapper
    layers=['layer1', 'layer2', 'layer3'],
    backbone='wide_resnet50_2',
)

print("Training Global PaDiM...")
padim.fit(global_train_loader)

print("Saving Global PaDiM...")
padim_checkpoint = paths.models / 'padim_global_clean.pt'
padim.save(padim_checkpoint)

## 4. Evaluate Global Models

We evaluate each Global Model separately on every class's Test Set.
For each class, the decision threshold is calibrated on that class's own Validation Set and then applied to its Test Set to compute the metrics.

In [None]:
def _make_eval_loader(split, transform, batch_size):
    """Build a non-shuffled DataLoader over one evaluation split (val or test).

    Args:
        split: dict with 'images', 'masks' and 'labels' lists, as stored per
            class in clean_splits.json.
        transform: image transform applied by MVTecDataset.
        batch_size: DataLoader batch size.
    """
    dataset = MVTecDataset(
        images=split['images'],
        masks=split['masks'],
        labels=split['labels'],
        transform=transform,
        phase='test',  # val is loaded exactly like test (labels + masks available)
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=False)


def evaluate_global_model(model, model_name, classes, splits, transform, device,
                          results_dir=None, batch_size=32):
    """Evaluate one global model separately on each class.

    For every class: calibrate the decision threshold on the class's validation
    split, compute image- and pixel-level metrics on its test split, and save
    the results with the 'global' prefix.

    Args:
        model: fitted global model passed to evaluate_model_on_dataloader.
        model_name: name forwarded to Evaluator (e.g. 'patchcore').
        classes: iterable of class names to evaluate.
        splits: dict class name -> {'val': ..., 'test': ...} split dicts.
        transform: image transform for the evaluation datasets.
        device: torch device used for inference.
        results_dir: directory passed to Evaluator.save_results; defaults to
            the notebook-level ``paths.results``.
        batch_size: batch size for the val/test DataLoaders.

    Returns:
        dict class name -> Evaluator.get_results() augmented with
        'roc_curve', 'pr_curve', 'test_scores' and 'test_labels'.
    """
    if results_dir is None:
        results_dir = paths.results

    print(f"\nEvaluating Global {model_name}...")

    results = {}
    for class_name in classes:
        print(f"\n--- Evaluating on {class_name} ---")
        class_split = splits[class_name]

        val_loader = _make_eval_loader(class_split['val'], transform, batch_size)
        test_loader = _make_eval_loader(class_split['test'], transform, batch_size)

        evaluator = Evaluator(model_name, class_name, domain='clean')

        # Threshold is calibrated on validation only, never on test data
        print("Running Validation...")
        val_results = evaluate_model_on_dataloader(
            model, val_loader, device, return_heatmaps=False, verbose=False
        )
        evaluator.calibrate_threshold(val_results['scores'], val_results['labels'])

        print("Running Test Evaluation...")
        test_results = evaluate_model_on_dataloader(
            model, test_loader, device, return_heatmaps=True, verbose=True
        )
        evaluator.evaluate_image_level(test_results['scores'], test_results['labels'])
        evaluator.evaluate_pixel_level(test_results['masks'], test_results['heatmaps'])

        evaluator.save_results(results_dir, prefix='global')

        # Keep curves and raw test scores for the plotting helpers
        class_res = evaluator.get_results()
        class_res['roc_curve'] = evaluator.roc_curve
        class_res['pr_curve'] = evaluator.pr_curve
        class_res['test_scores'] = test_results['scores']
        class_res['test_labels'] = test_results['labels']

        results[class_name] = class_res

    return results

In [None]:
# Per-class evaluation of the global PatchCore model
patchcore_results = evaluate_global_model(
    model=patchcore, model_name='patchcore', classes=CLASSES,
    splits=splits, transform=transform, device=device,
)

In [None]:
# Per-class evaluation of the global PaDiM model
padim_results = evaluate_global_model(
    model=padim, model_name='padim', classes=CLASSES,
    splits=splits, transform=transform, device=device,
)

## 5. Performance Summary & Visualization

For each Global Model we report a per-class metrics table, ROC/PR curves and confusion matrices, followed by a qualitative visualization of the predicted anomaly heatmaps.

In [None]:
import seaborn as sns
from sklearn.metrics import confusion_matrix

# Plotting Functions
def plot_curves(results_dict, model_name):
    """Plot image-level ROC curves (top row) and PR curves (bottom row) per class.

    Args:
        results_dict: mapping class name -> result dict holding 'roc_curve'
            as (fpr, tpr, thresholds), 'pr_curve' as (precision, recall,
            thresholds) and 'image_level' with 'auroc' and 'auprc'.
        model_name: display name used in the figure title.
    """
    classes = list(results_dict.keys())
    if not classes:
        return

    n_cols = len(classes)
    # One column per class; squeeze=False keeps `axes` 2-D even for one class
    fig, axes = plt.subplots(2, n_cols, figsize=(20 * n_cols / 3, 12), squeeze=False)
    fig.suptitle(f'Global {model_name} - Evaluation Curves', fontsize=16)

    for i, class_name in enumerate(classes):
        res = results_dict[class_name]
        ax_roc, ax_pr = axes[0, i], axes[1, i]

        # ROC curve with the chance diagonal for reference
        fpr, tpr, _ = res['roc_curve']
        auroc = res['image_level']['auroc']
        ax_roc.plot(fpr, tpr, label=f'AUROC = {auroc:.4f}')
        ax_roc.plot([0, 1], [0, 1], 'k--', alpha=0.5)
        ax_roc.set_title(f'{class_name.title()} - ROC Curve')
        ax_roc.set_xlabel('False Positive Rate')
        ax_roc.set_ylabel('True Positive Rate')
        ax_roc.legend()
        ax_roc.grid(True, alpha=0.3)

        # Precision-recall curve (recall on x-axis)
        precision, recall, _ = res['pr_curve']
        auprc = res['image_level']['auprc']
        ax_pr.plot(recall, precision, label=f'AUPRC = {auprc:.4f}')
        ax_pr.set_title(f'{class_name.title()} - PR Curve')
        ax_pr.set_xlabel('Recall')
        ax_pr.set_ylabel('Precision')
        ax_pr.legend()
        ax_pr.grid(True, alpha=0.3)

    # Leave room at the top for the suptitle
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

def plot_confusion_matrices(results_dict, model_name):
    """Plot the test-set confusion matrix of each class at its calibrated threshold.

    Args:
        results_dict: mapping class name -> result dict holding 'test_scores',
            'test_labels' (0 = normal, 1 = anomaly) and 'threshold'.
        model_name: display name used in the figure title.
    """
    classes = list(results_dict.keys())
    if not classes:
        return

    n_cols = len(classes)
    # squeeze=False keeps `axes` indexable as axes[0][i] even for one class
    fig, axes = plt.subplots(1, n_cols, figsize=(20 * n_cols / 3, 5), squeeze=False)
    fig.suptitle(f'Global {model_name} - Confusion Matrices', fontsize=16)

    for ax, class_name in zip(axes[0], classes):
        res = results_dict[class_name]

        # Apply the threshold calibrated on the validation split to the test scores
        scores = np.asarray(res['test_scores'])
        labels = np.asarray(res['test_labels'])
        preds = (scores >= res['threshold']).astype(int)

        # labels=[0, 1] forces a 2x2 matrix that matches the tick labels,
        # even when only one class occurs in both labels and predictions
        cm = confusion_matrix(labels, preds, labels=[0, 1])

        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax,
                    xticklabels=['Normal', 'Anomaly'],
                    yticklabels=['Normal', 'Anomaly'])
        ax.set_title(f'{class_name.title()}')
        ax.set_xlabel('Predicted')
        ax.set_ylabel('True')

    # Leave room at the top for the suptitle
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

def visualize_global_results(results_dict, model_name):
    """Print a per-class metrics table for a global model, then plot its curves
    and confusion matrices.

    Args:
        results_dict: mapping class name -> result dict with 'image_level'
            ('auroc', 'f1') and 'pixel_level' (optional 'pro').
        model_name: display name used in the table header and figure titles.
    """
    print(f"\nGlobal {model_name} Performance Summary:")
    print(f"{'Class':<15} {'AUROC':<10} {'F1':<10} {'PRO':<10}")
    print("-" * 45)

    if not results_dict:
        print("No results to summarize.")
        return

    sum_auroc, sum_f1, sum_pro = 0.0, 0.0, 0.0

    for class_name, res in results_dict.items():
        auroc = res['image_level']['auroc']
        f1 = res['image_level']['f1']
        # A missing or None PRO is shown (and averaged) as 0
        pro = res['pixel_level'].get('pro', 0)
        if pro is None:
            pro = 0

        print(f"{class_name:<15} {auroc:.4f}     {f1:.4f}     {pro:.4f}")

        sum_auroc += auroc
        sum_f1 += f1
        sum_pro += pro

    # Average over the classes actually present, not a hard-coded count
    n_classes = len(results_dict)
    print("-" * 45)
    print(f"{'AVERAGE':<15} {sum_auroc/n_classes:.4f}     "
          f"{sum_f1/n_classes:.4f}     {sum_pro/n_classes:.4f}")

    plot_curves(results_dict, model_name)
    plot_confusion_matrices(results_dict, model_name)

# Summary table + curves + confusion matrices for each global model
visualize_global_results(results_dict=patchcore_results, model_name='PatchCore')
visualize_global_results(results_dict=padim_results, model_name='PaDiM')

In [None]:
# Visualization Function
def show_sample_predictions(model, class_name, splits, transform, device, n_samples=3):
    """Plot image / ground-truth mask / anomaly heatmap for a few random test samples.

    One randomly shuffled test batch is drawn. Anomalous samples from it are
    shown first; if it has fewer than ``n_samples`` anomalies, normal samples
    from the same batch fill the remaining rows.

    Args:
        model: fitted model exposing ``predict(images, return_heatmaps=True)``
            that returns (scores, heatmaps).
        class_name: class whose test split is visualized.
        splits: dict class name -> split dicts (uses the 'test' split).
        transform: image transform (ImageNet-normalized tensors expected,
            since the images are de-normalized with those statistics).
        device: torch device used for inference.
        n_samples: maximum number of rows to plot.
    """
    print(f"Visualizing anomalies for {class_name}...")
    class_split = splits[class_name]
    dataset = MVTecDataset(
        images=class_split['test']['images'],
        masks=class_split['test']['masks'],
        labels=class_split['test']['labels'],
        transform=transform,
        phase='test'
    )
    if len(dataset) == 0:
        print(f"No test samples for {class_name}; skipping.")
        return

    # Shuffle so each call shows a different random batch
    loader = DataLoader(dataset, batch_size=32, shuffle=True)

    model.eval()
    images, masks, labels, _ = next(iter(loader))

    # Prefer anomalous samples, then pad with normal ones
    anomalous = [i for i, lbl in enumerate(labels) if lbl == 1]
    normal = [i for i, lbl in enumerate(labels) if lbl != 1]
    indices = (anomalous + normal)[:n_samples]
    if not indices:
        return

    batch_images = images[indices].to(device)
    with torch.no_grad():
        _, heatmaps = model.predict(batch_images, return_heatmaps=True)

    n_rows = len(indices)
    # squeeze=False keeps `axes` 2-D (n_rows x 3) even for a single row
    fig, axes = plt.subplots(n_rows, 3, figsize=(15, 5 * n_rows), squeeze=False)

    # ImageNet statistics, used to undo the Normalize transform for display
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    for row, idx in enumerate(indices):
        ax_img, ax_mask, ax_hm = axes[row]

        # CHW tensor -> HWC array, de-normalized and clipped to [0, 1]
        img = images[idx].permute(1, 2, 0).numpy()
        img = np.clip(img * std + mean, 0, 1)

        ax_img.imshow(img)
        ax_img.set_title(f"{class_name} Image")
        ax_img.axis('off')

        mask = masks[idx]
        if mask is not None:
            ax_mask.imshow(mask.numpy().squeeze(), cmap='gray')
        ax_mask.set_title("Ground Truth")
        ax_mask.axis('off')

        # Heatmaps are in batch order, so row i matches indices[i].
        # Move tensors (possibly on GPU) to numpy and drop singleton dims for imshow.
        hm = heatmaps[row]
        if torch.is_tensor(hm):
            hm = hm.detach().cpu().numpy()
        hm = np.squeeze(hm)

        im = ax_hm.imshow(hm, cmap='jet')
        ax_hm.set_title("Global Model Heatmap")
        ax_hm.axis('off')
        plt.colorbar(im, ax=ax_hm)

    plt.tight_layout()
    plt.show()

# Qualitative check: one figure per class, using the global PatchCore model
print("\n--- Visualizing Global PatchCore Predictions ---")
for target_class in CLASSES:
    show_sample_predictions(patchcore, target_class, splits, transform, device)