[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/IvanNece/Detection-of-Anomalies-with-Localization/blob/main/notebooks/10_global_model_clean.ipynb)


# PHASE 8: Global Model - Unified Training

**Objective**: Train a **single** anomaly detection model on **all 3 classes** simultaneously.

## Key Differences from Per-Class Models:
1. **Single Model**: Train ONE PatchCore and ONE PaDiM on merged training data
2. **Per-Class Thresholds**: Calibrate separate thresholds for each class on validation
3. **Identical Shortcut Problem**: Can normals from one class be confused with anomalies from another?
4. **Performance Gap Analysis**: Quantify degradation vs per-class models

---

## Expected Outcome:
Global models should perform **worse** than per-class models due to:
- **Distribution heterogeneity**: Mixing different textures/objects
- **Identical shortcut**: Normal patterns of Class A may appear anomalous for Class B
- **Feature space contamination**: Shared representation struggles with diverse nominal distributions
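
One way the feature-space contamination point can play out with a nearest-neighbour memory bank is sketched below. This is a simplified toy, not the project's `PatchCore` class: the 2-D "features", the `nn_score` helper, and all numbers are hypothetical, and the real scoring also involves neighbour re-weighting. It shows that a defective patch of class A can receive a low score once class B normals share the same memory bank.

```python
import numpy as np

def nn_score(patch, memory_bank):
    """Simplified anomaly score: Euclidean distance to the nearest memory-bank entry."""
    return float(np.min(np.linalg.norm(memory_bank - patch, axis=1)))

# Hypothetical 2-D patch features for the normal data of two classes.
bank_a = np.array([[0.0, 0.0], [0.2, 0.1], [0.1, 0.2]])  # class A normals
bank_b = np.array([[3.0, 3.0], [3.1, 2.9], [2.9, 3.1]])  # class B normals
bank_global = np.vstack([bank_a, bank_b])                # merged (global) memory bank

# A defective class-A patch that happens to resemble class-B normal texture.
defect_a = np.array([2.9, 3.0])

score_per_class = nn_score(defect_a, bank_a)    # far from every class-A normal
score_global = nn_score(defect_a, bank_global)  # close to a class-B normal

print(f"per-class bank score: {score_per_class:.3f}")
print(f"global bank score:    {score_global:.3f}")
```

In this toy the global score is much smaller than the per-class score, so the defect becomes harder to separate from normals. Whether this happens on real features depends on how far apart the classes sit in the backbone's feature space.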

## Terminological Note (Model-Unified vs Absolute-Unified)

Following the taxonomy from recent literature [CADA, Guo et al. 2024; HierCore, Heo & Kang 2025]:

| Setting | Training | Inference | Threshold | Our Experiment |
|---------|----------|-----------|-----------|----------------|
| **Per-Class** | Separate model per class | Class known | Per-class | ❌ |
| **Model-Unified** | Single model for all classes | Class known | Per-class | ✅ **This notebook** |
| **Absolute-Unified** | Single model for all classes | Class UNKNOWN | Single global | ❌ |

**Our setting**: We train ONE global model but calibrate **per-class thresholds** at inference 
(class is known). This is the "model-unified" setting, NOT "absolute-unified".
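
The difference between the two unified settings comes down to which threshold the decision rule may look up. The sketch below is illustrative only; the function names and threshold values are hypothetical and are not part of the project code.

```python
from typing import Dict

def predict_model_unified(score: float, class_name: str, thresholds: Dict[str, float]) -> int:
    """Class is known at inference, so the class-specific threshold is used."""
    return int(score >= thresholds[class_name])

def predict_absolute_unified(score: float, global_threshold: float) -> int:
    """Class is unknown at inference, so one threshold must serve every class."""
    return int(score >= global_threshold)

# Hypothetical thresholds, chosen only to make the contrast visible.
thresholds = {"hazelnut": 2.0, "carpet": 1.0, "zipper": 1.5}
score = 1.2  # anomaly score from the single global model

print(predict_model_unified(score, "carpet", thresholds))    # compared against 1.0
print(predict_model_unified(score, "hazelnut", thresholds))  # compared against 2.0
print(predict_absolute_unified(score, global_threshold=1.5)) # compared against 1.5
```

The same score can therefore be flagged for one class and accepted for another in the model-unified setting, which is exactly the freedom the absolute-unified setting removes.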

## 0. Setup and Imports

In [None]:
# ============================================================
# SETUP - Mount Google Drive & Clone Repository
# ============================================================

from google.colab import drive
from pathlib import Path
import os
import sys

# Mount Google Drive
print("Mounting Google Drive...")
drive.mount('/content/drive')
print("Done!\n")

# Clone repository on main branch
print("Cloning repository (branch: main)...")
repo_dir = '/content/Detection-of-Anomalies-with-Localization'

# Remove if exists
if os.path.exists(repo_dir):
    print("Removing existing repository...")
    !rm -rf {repo_dir}

# Clone from main branch
!git clone https://github.com/IvanNece/Detection-of-Anomalies-with-Localization.git {repo_dir}
print("Done!\n")

# Setup paths
PROJECT_ROOT = Path(repo_dir)

# Dataset location (only clean for this notebook)
CLEAN_DATASET_PATH = Path('/content/drive/MyDrive/mvtec_ad')

# Output directories
MODELS_DIR = PROJECT_ROOT / 'outputs' / 'models'
RESULTS_DIR = PROJECT_ROOT / 'outputs' / 'results'
THRESHOLDS_DIR = PROJECT_ROOT / 'outputs' / 'thresholds'
VIZ_DIR = PROJECT_ROOT / 'outputs' / 'visualizations' / 'global_model'

MODELS_DIR.mkdir(parents=True, exist_ok=True)
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
THRESHOLDS_DIR.mkdir(parents=True, exist_ok=True)
VIZ_DIR.mkdir(parents=True, exist_ok=True)

# Verify dataset exists
if not CLEAN_DATASET_PATH.exists():
    raise FileNotFoundError(
        f"Clean dataset not found at {CLEAN_DATASET_PATH}\n"
        f"Please ensure mvtec_ad folder is in your Google Drive."
    )

# Add project root to Python path
sys.path.insert(0, str(PROJECT_ROOT))

print("\n" + "="*70)
print("SETUP COMPLETE - PHASE 8: GLOBAL MODEL (MODEL-UNIFIED)")
print("="*70)
print(f"Project:    {PROJECT_ROOT}")
print(f"Dataset:    {CLEAN_DATASET_PATH}")
print(f"Branch:     main")
print(f"Models:     {MODELS_DIR}")
print(f"Results:    {RESULTS_DIR}")
print(f"Thresholds: {THRESHOLDS_DIR}")
print(f"Viz:        {VIZ_DIR}")
print("="*70)

In [None]:
!pip install faiss-cpu --quiet
!pip install anomalib --quiet
!pip install umap-learn --quiet
import umap

In [None]:
# Standard imports
import os
import sys
import json
import time
from pathlib import Path
from typing import Dict, List, Tuple

# Scientific computing
import numpy as np
import pandas as pd

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm

# Deep Learning
import torch
from torch.utils.data import DataLoader

# Dimensionality reduction
from sklearn.manifold import TSNE

# Project imports
from src.data.dataset import MVTecDataset
from src.data.transforms import get_clean_transforms
from src.models.patchcore import PatchCore
from src.models.padim_wrapper import PadimWrapper
from src.metrics.threshold_selection import calibrate_threshold
from src.metrics.image_metrics import compute_auroc, compute_auprc, compute_f1_at_threshold, compute_classification_metrics
from src.metrics.pixel_metrics import compute_pixel_auroc, compute_pro
from src.utils.reproducibility import set_seed
from src.utils.paths import ProjectPaths

# Set matplotlib style
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

print("All imports successful")

In [None]:
# CRITICAL: Set seed for reproducibility
set_seed(42)

# Configuration
CLASSES = ['hazelnut', 'carpet', 'zipper']
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
paths = ProjectPaths()

print(f"Device: {device}")
print(f"Classes: {CLASSES}")
print(f"Seed: 42 (FIXED for reproducibility)")

## 1. Data Preparation: Merge Training Sets from All Classes

**CRITICAL**: We create a **SINGLE** global training set by merging Train-clean from all 3 classes.

In [None]:
# Load clean splits
splits_path = paths.data_processed / 'clean_splits.json'
with open(splits_path, 'r') as f:
    splits = json.load(f)

print(f"Loaded splits for classes: {list(splits.keys())}")

In [None]:
# Merge all train_clean images into a single global dataset
global_train_images = []
global_train_masks = []
global_train_labels = []
global_train_class_ids = []  # Track which class each image belongs to

for class_idx, class_name in enumerate(CLASSES):
    class_splits = splits[class_name]['train']
    
    # Train only contains normal images
    n_samples = len(class_splits['images'])
    
    global_train_images.extend(class_splits['images'])
    global_train_masks.extend([None] * n_samples)  # Only normals
    global_train_labels.extend([0] * n_samples)  # Label 0 = normal
    global_train_class_ids.extend([class_idx] * n_samples)
    
    print(f"  {class_name:10s}: {n_samples:4d} normal images")

print(f"\nGlobal Train Set: {len(global_train_images)} total images")
print(f"   Distribution: Hazelnut={global_train_class_ids.count(0)}, "
      f"Carpet={global_train_class_ids.count(1)}, Zipper={global_train_class_ids.count(2)}")

In [None]:
# Create global training dataset
transform_clean = get_clean_transforms()

global_train_dataset = MVTecDataset(
    images=global_train_images,
    masks=global_train_masks,
    labels=global_train_labels,
    transform=transform_clean,
    phase='train'
)

global_train_loader = DataLoader(
    global_train_dataset,
    batch_size=32,
    shuffle=False,  # Important for reproducibility
    num_workers=4,
    pin_memory=True
)

print(f"Global Train DataLoader: {len(global_train_loader)} batches")

## 2. Train Single PatchCore Global Model

**KEY POINT**: Training ONE model on the merged dataset (not 3 separate models).
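
The `coreset_ratio` used below controls how many patch features survive in the memory bank. PatchCore-style methods typically pick them with a greedy k-center (farthest-point) selection. The sketch below is a minimal NumPy version of that idea under stated assumptions: `greedy_coreset` is a hypothetical helper, the features are random, and the project's `PatchCore.fit` may differ (for example by projecting features to a lower dimension first).

```python
import numpy as np

def greedy_coreset(features: np.ndarray, ratio: float, seed: int = 42) -> np.ndarray:
    """Return indices of a k-center-greedy subset of `features` (illustrative sketch)."""
    n = features.shape[0]
    k = max(1, int(n * ratio))
    rng = np.random.default_rng(seed)
    selected = [int(rng.integers(n))]  # random starting point
    # Distance from every point to its nearest selected point so far.
    min_dist = np.linalg.norm(features - features[selected[0]], axis=1)
    for _ in range(k - 1):
        idx = int(np.argmax(min_dist))  # farthest point from the current coreset
        selected.append(idx)
        new_dist = np.linalg.norm(features - features[idx], axis=1)
        min_dist = np.minimum(min_dist, new_dist)
    return np.array(selected)

# Random stand-in for patch features: 200 patches, 8 dimensions.
features = np.random.default_rng(0).normal(size=(200, 8))
coreset_idx = greedy_coreset(features, ratio=0.05)
print(f"kept {len(coreset_idx)} of {features.shape[0]} patches")
```

With a merged training set the same ratio is applied to roughly three times as many patches, so the memory bank grows accordingly and has to cover three nominal distributions at once.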

In [None]:
print("="*70)
print("TRAINING GLOBAL PATCHCORE MODEL")
print("="*70)

# Initialize PatchCore with same hyperparameters as per-class models
patchcore_global = PatchCore(
    backbone_layers=['layer2', 'layer3'],
    patch_size=3,
    coreset_ratio=0.05,  # 5% as per PHASE 3.5
    n_neighbors=9,
    device=device
)

# Train on global dataset
start_time = time.time()
patchcore_global.fit(global_train_loader, apply_coreset=True)
training_time = time.time() - start_time

print(f"\nPatchCore Global trained in {training_time:.2f}s")
print(f"   Memory bank size: {patchcore_global.memory_bank.features.shape[0]:,} patches")

# Save model
save_path = paths.models / 'patchcore_global_clean.npy'
patchcore_global.save(paths.models, class_name='global', domain='clean')
print(f"   Model saved to: {save_path}")

## 3. Train Single PaDiM Global Model

In [None]:
print("="*70)
print("TRAINING GLOBAL PADIM MODEL")
print("="*70)

# Initialize PaDiM
padim_global = PadimWrapper(
    backbone='wide_resnet50_2',
    layers=['layer1', 'layer2', 'layer3'],
    n_features=100,
    device=device
)

# Train on global dataset
start_time = time.time()
padim_global.fit(global_train_loader, verbose=True)
training_time = time.time() - start_time

print(f"\nPaDiM Global trained in {training_time:.2f}s")

# Save model
save_path = paths.models / 'padim_global_clean.pt'
padim_global.save(save_path)
print(f"   Model saved to: {save_path}")

## 4. Per-Class Threshold Calibration (Using Global Models)

**CRITICAL**: Although we have ONE model, we calibrate **SEPARATE thresholds** for each class on their validation sets.

This allows fair comparison: each class gets an optimal threshold despite using a shared model.
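
The calibration cell below relies on the project's `calibrate_threshold`, whose internals are not shown here. As a rough illustration of F1-maximizing calibration, the sketch below scans every observed score as a candidate threshold; `f1_optimal_threshold` and the toy scores are hypothetical, and the project function may use a different search.

```python
import numpy as np

def f1_optimal_threshold(scores, labels):
    """Return (threshold, F1) maximizing F1 for the rule `score >= threshold`."""
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=int)
    best_thr, best_f1 = float(scores.max()), -1.0
    for thr in np.unique(scores):  # candidates in ascending order
        pred = scores >= thr
        tp = int(np.sum(pred & (labels == 1)))
        fp = int(np.sum(pred & (labels == 0)))
        fn = int(np.sum(~pred & (labels == 1)))
        denom = 2 * tp + fp + fn
        f1 = 2 * tp / denom if denom > 0 else 0.0
        if f1 > best_f1:  # keep the lowest threshold on ties
            best_thr, best_f1 = float(thr), f1
    return best_thr, best_f1

# Toy validation scores: three normals (label 0) and two anomalies (label 1).
val_scores = np.array([0.1, 0.2, 0.3, 0.8, 0.9])
val_labels_toy = np.array([0, 0, 0, 1, 1])
thr, f1 = f1_optimal_threshold(val_scores, val_labels_toy)
print(f"threshold={thr:.2f}, F1={f1:.2f}")
```

Running this once per class on that class's validation scores yields one threshold per class, which is the per-class calibration described above.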

In [None]:
print("="*70)
print("PER-CLASS THRESHOLD CALIBRATION (GLOBAL MODELS)")
print("="*70)

thresholds_global = {
    'patchcore': {},
    'padim': {}
}

for class_name in CLASSES:
    print(f"\nCalibrating thresholds for {class_name}...")
    
    # Load validation split for this class
    val_split = splits[class_name]['val']
    
    val_dataset = MVTecDataset(
        images=val_split['images'],
        masks=val_split['masks'],
        labels=val_split['labels'],
        transform=transform_clean,
        phase='val'
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=32,
        shuffle=False,
        num_workers=4
    )
    
    # Get predictions from GLOBAL models.
    # Stack the validation images once and reuse the tensor for both models.
    val_images = torch.cat([batch[0] for batch in val_loader]).to(device)

    patchcore_scores, _ = patchcore_global.predict(val_images, return_heatmaps=False)
    padim_scores, _ = padim_global.predict(val_images, return_heatmaps=False)
    
    val_labels = np.array(val_split['labels'])
    
    # Calibrate thresholds to maximize F1
    threshold_pc = calibrate_threshold(patchcore_scores, val_labels)
    threshold_pd = calibrate_threshold(padim_scores, val_labels)
    
    thresholds_global['patchcore'][class_name] = float(threshold_pc)
    thresholds_global['padim'][class_name] = float(threshold_pd)
    
    # Compute F1 at calibrated thresholds
    f1_pc = compute_f1_at_threshold(val_labels, patchcore_scores, threshold_pc)
    f1_pd = compute_f1_at_threshold(val_labels, padim_scores, threshold_pd)
    
    print(f"   PatchCore: threshold={threshold_pc:.4f}, val_F1={f1_pc:.4f}")
    print(f"   PaDiM:     threshold={threshold_pd:.4f}, val_F1={f1_pd:.4f}")

# Save thresholds
thresholds_path = paths.thresholds / 'global_thresholds.json'
with open(thresholds_path, 'w') as f:
    json.dump(thresholds_global, f, indent=2)

print(f"\nThresholds saved to: {thresholds_path}")

## 5. Evaluate Global Models on Test-Clean (Per-Class)

Test the global models on each class separately using per-class thresholds.

In [None]:
print("="*70)
print("EVALUATING GLOBAL MODELS ON TEST-CLEAN")
print("="*70)

results_global = {
    'patchcore': {},
    'padim': {}
}

for class_name in CLASSES:
    print(f"\nEvaluating {class_name}...")
    
    # Load test split
    test_split = splits[class_name]['test']
    
    test_dataset = MVTecDataset(
        images=test_split['images'],
        masks=test_split['masks'],
        labels=test_split['labels'],
        transform=transform_clean,
        phase='test'
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=32,
        shuffle=False,
        num_workers=4
    )
    
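    # Get all test images, masks, and labels in a single pass over the loader.
    # Assumes MVTecDataset yields a mask tensor for every sample (all zeros for
    # normal images) so that the batches can be concatenated.
    test_batches = list(test_loader)
    test_images = torch.cat([batch[0] for batch in test_batches]).to(device)
    # One mask per image, so anomaly_indices below index images rather than batches
    test_masks_list = list(torch.cat([batch[1] for batch in test_batches]).cpu().numpy())
    test_labels = np.array(test_split['labels'])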
    
    # PatchCore predictions
    pc_scores, pc_heatmaps = patchcore_global.predict(test_images, return_heatmaps=True)
    pc_predictions = (pc_scores >= thresholds_global['patchcore'][class_name]).astype(int)
    
    # PaDiM predictions
    pd_scores, pd_heatmaps = padim_global.predict(test_images, return_heatmaps=True)
    pd_predictions = (pd_scores >= thresholds_global['padim'][class_name]).astype(int)
    
    # Compute image-level metrics
    pc_auroc = compute_auroc(test_labels, pc_scores)
    pc_auprc = compute_auprc(test_labels, pc_scores)
    pc_metrics = compute_classification_metrics(test_labels, pc_predictions)
    
    pd_auroc = compute_auroc(test_labels, pd_scores)
    pd_auprc = compute_auprc(test_labels, pd_scores)
    pd_metrics = compute_classification_metrics(test_labels, pd_predictions)
    
    # Compute pixel-level metrics (for anomalous images with masks)
    anomaly_indices = np.where(test_labels == 1)[0]
    if len(anomaly_indices) > 0:
        test_masks_anomalies = [test_masks_list[i] for i in anomaly_indices if test_masks_list[i] is not None]
        pc_heatmaps_anomalies = pc_heatmaps[anomaly_indices]
        pd_heatmaps_anomalies = pd_heatmaps[anomaly_indices]
        
        if len(test_masks_anomalies) > 0:
            pc_pixel_auroc = compute_pixel_auroc(test_masks_anomalies, pc_heatmaps_anomalies)
            pc_pro = compute_pro(test_masks_anomalies, pc_heatmaps_anomalies)
            
            pd_pixel_auroc = compute_pixel_auroc(test_masks_anomalies, pd_heatmaps_anomalies)
            pd_pro = compute_pro(test_masks_anomalies, pd_heatmaps_anomalies)
        else:
            pc_pixel_auroc = pc_pro = pd_pixel_auroc = pd_pro = None
    else:
        pc_pixel_auroc = pc_pro = pd_pixel_auroc = pd_pro = None
    
    # Store results
    results_global['patchcore'][class_name] = {
        'auroc': pc_auroc,
        'auprc': pc_auprc,
        'f1': pc_metrics['f1'],
        'accuracy': pc_metrics['accuracy'],
        'precision': pc_metrics['precision'],
        'recall': pc_metrics['recall'],
        'pixel_auroc': pc_pixel_auroc,
        'pro': pc_pro
    }
    
    results_global['padim'][class_name] = {
        'auroc': pd_auroc,
        'auprc': pd_auprc,
        'f1': pd_metrics['f1'],
        'accuracy': pd_metrics['accuracy'],
        'precision': pd_metrics['precision'],
        'recall': pd_metrics['recall'],
        'pixel_auroc': pd_pixel_auroc,
        'pro': pd_pro
    }
    
    pc_pixel_str = f"{pc_pixel_auroc:.4f}" if pc_pixel_auroc is not None else "N/A"  # format before printing
    pd_pixel_str = f"{pd_pixel_auroc:.4f}" if pd_pixel_auroc is not None else "N/A"
    print(f"   PatchCore: AUROC={pc_auroc:.4f}, F1={pc_metrics['f1']:.4f}, Pixel AUROC={pc_pixel_str}")
    print(f"   PaDiM:     AUROC={pd_auroc:.4f}, F1={pd_metrics['f1']:.4f}, Pixel AUROC={pd_pixel_str}")

print("\nGlobal model evaluation complete")

## 6. Load Per-Class Models for Comparison

Load the per-class trained models to compute the performance gap.

In [None]:
print("="*70)
print("LOADING PER-CLASS MODELS FOR COMPARISON")
print("="*70)

# Load per-class results from PHASE 5 (clean domain evaluation)
clean_results_path = paths.results / 'clean_results.json'
with open(clean_results_path, 'r') as f:
    clean_results = json.load(f)

# Extract per-class model results
# clean_results structure: {"metadata": {...}, "patchcore": {...}, "padim": {...}}
results_per_class = {
    'patchcore': clean_results['patchcore'],
    'padim': clean_results['padim']
}

print("Per-class model results loaded")
print(f"   Available classes: {list(results_per_class['patchcore'].keys())}")

## 7. Performance Gap Analysis: Global vs Per-Class

**Key Question**: How much performance do we lose by using a single global model?

In [None]:
print("="*70)
print("PERFORMANCE GAP ANALYSIS: Per-Class vs Global Model")
print("="*70)

gaps = {
    'patchcore': {},
    'padim': {}
}

comparison_data = []

for method in ['patchcore', 'padim']:
    print(f"\n{method.upper()}:")
    print(f"{'Class':<12} {'Per-Class AUROC':<18} {'Global AUROC':<15} {'Gap':<10}")
    print("-" * 60)
    
    for class_name in CLASSES:
        auroc_per_class = results_per_class[method][class_name]['image_level']['auroc']
        auroc_global = results_global[method][class_name]['auroc']
        gap = auroc_per_class - auroc_global
        
        gaps[method][class_name] = gap
        
        comparison_data.append({
            'Method': method.upper(),
            'Class': class_name,
            'Per-Class AUROC': auroc_per_class,
            'Global AUROC': auroc_global,
            'Gap': gap
        })
        
        print(f"{class_name:<12} {auroc_per_class:>16.4f}   {auroc_global:>13.4f}   {gap:>+8.4f}")
    
    # Macro average
    avg_per_class = np.mean([results_per_class[method][c]['image_level']['auroc'] for c in CLASSES])
    avg_global = np.mean([results_global[method][c]['auroc'] for c in CLASSES])
    avg_gap = avg_per_class - avg_global
    
    print("-" * 60)
    print(f"{'MACRO AVG':<12} {avg_per_class:>16.4f}   {avg_global:>13.4f}   {avg_gap:>+8.4f}")
    
    comparison_data.append({
        'Method': method.upper(),
        'Class': 'MACRO_AVG',
        'Per-Class AUROC': avg_per_class,
        'Global AUROC': avg_global,
        'Gap': avg_gap
    })

# Create DataFrame for export
df_comparison = pd.DataFrame(comparison_data)
comparison_csv_path = paths.results / 'global_vs_per_class_comparison.csv'
df_comparison.to_csv(comparison_csv_path, index=False)
print(f"\nComparison saved to: {comparison_csv_path}")

In [None]:
# Visualization: Bar chart of performance gaps
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

for idx, method in enumerate(['patchcore', 'padim']):
    ax = axes[idx]
    
    gap_values = [gaps[method][c] for c in CLASSES]
    colors = ['red' if g < 0 else 'green' for g in gap_values]
    
    bars = ax.bar(CLASSES, gap_values, color=colors, alpha=0.7, edgecolor='black', linewidth=1.5)
    ax.axhline(0, color='black', linestyle='--', linewidth=1)
    
    # Add value labels on bars
    for bar, val in zip(bars, gap_values):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{val:+.3f}',
                ha='center', va='bottom' if height > 0 else 'top',
                fontsize=10, fontweight='bold')
    
    ax.set_title(f'{method.upper()}: Performance Gap\n(Per-Class - Global)', fontsize=12, fontweight='bold')
    ax.set_ylabel('AUROC Gap', fontsize=11)
    ax.set_xlabel('Class', fontsize=11)
    ax.grid(axis='y', alpha=0.3, linestyle=':')
    ax.set_ylim([min(gap_values) - 0.05, max(gap_values) + 0.05])

plt.tight_layout()
gap_plot_path = VIZ_DIR / 'global_model_performance_gap.png'  # folder the Drive copy step reads from
plt.savefig(gap_plot_path, dpi=150, bbox_inches='tight')
plt.show()

print(f"Plot saved to: {gap_plot_path}")

## 8. Identical Shortcut Problem Analysis

**Context**: In model-unified anomaly detection [CADA, Guo et al. 2024], a single model 
is trained on multiple classes but evaluated with per-class thresholds. This differs from 
absolute-unified settings where class information is unavailable at inference.

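**The "Identical Shortcut" Problem** [UniAD, You et al. 2022]: UniAD uses this term for reconstruction-based models that learn to copy their input, so anomalies are reconstructed as well as normals; the effect worsens when one model must cover many classes. PatchCore and PaDiM are not reconstruction-based, so this notebook uses the term loosely for a related cross-class effect: in a shared feature space, anomaly scores calibrated for one class do not transfer to another.

**Hypothesis**: Normal images from Class A may exceed the anomaly threshold calibrated for Class B. 
Per-class thresholds hide this at inference because the class is known, so the analysis measures what would happen if a threshold were applied to the wrong class.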
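
To make the layout of the cross-class matrix concrete, here is a toy version with made-up scores. The `exceedance_matrix` helper, the scores, and the thresholds are all hypothetical; only the row/column convention (rows are target thresholds, columns are source normals) mirrors the analysis in this section.

```python
import numpy as np

def exceedance_matrix(normal_scores, thresholds, classes):
    """m[t, s] = fraction of source-class normals scoring at or above the target threshold."""
    m = np.zeros((len(classes), len(classes)))
    for t, target in enumerate(classes):
        for s, source in enumerate(classes):
            m[t, s] = np.mean(np.asarray(normal_scores[source]) >= thresholds[target])
    return m

toy_classes = ["hazelnut", "carpet"]
toy_normal_scores = {
    "hazelnut": [1.0, 1.2, 1.4, 2.2],  # scores of hazelnut normals under the global model
    "carpet": [0.4, 0.5, 0.6, 0.7],    # scores of carpet normals under the global model
}
toy_thresholds = {"hazelnut": 2.0, "carpet": 1.1}

toy_matrix = exceedance_matrix(toy_normal_scores, toy_thresholds, toy_classes)
print(toy_matrix)
```

The diagonal holds the ordinary false positive rate of each class under its own threshold. The off-diagonal entries show what happens when a threshold is applied to the wrong class: here most hazelnut normals would be flagged under the lower carpet threshold.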

In [None]:
print("="*70)
print("IDENTICAL SHORTCUT PROBLEM ANALYSIS")
print("="*70)
print("Testing: Do normals from Class A trigger anomaly threshold for Class B?\n")

# Focus on PatchCore for this analysis (same applies to PaDiM)
confusion_matrix_cross = np.zeros((len(CLASSES), len(CLASSES)))

for target_idx, target_class in enumerate(CLASSES):
    threshold = thresholds_global['patchcore'][target_class]
    print(f"\nTarget Class: {target_class} (threshold={threshold:.4f})")
    
    for source_idx, source_class in enumerate(CLASSES):
        # Load NORMAL test images from source class
        source_test = splits[source_class]['test']
        normal_indices = [i for i, lbl in enumerate(source_test['labels']) if lbl == 0]
        
        if len(normal_indices) == 0:
            confusion_matrix_cross[target_idx, source_idx] = 0.0
            continue
        
        normal_images = [source_test['images'][i] for i in normal_indices]
        
        # Create dataset
        normal_dataset = MVTecDataset(
            images=normal_images,
            masks=[None] * len(normal_images),
            labels=[0] * len(normal_images),
            transform=transform_clean,
            phase='test'
        )
        
        normal_loader = DataLoader(normal_dataset, batch_size=32, shuffle=False, num_workers=4)
        normal_images_tensor = torch.cat([batch[0] for batch in normal_loader]).to(device)
        
        # Predict with global model
        scores, _ = patchcore_global.predict(normal_images_tensor, return_heatmaps=False)
        
        # Confusion rate: fraction of source-class normals at or above the target threshold
        # (>= matches the decision rule used in the test evaluation)
        false_positive_rate = (scores >= threshold).mean()
        confusion_matrix_cross[target_idx, source_idx] = false_positive_rate
        
        if source_class == target_class:
            print(f"   {source_class:10s} (same class): {false_positive_rate:.2%} FP rate")
        else:
            print(f"   {source_class:10s} → confusion: {false_positive_rate:.2%}")

print("\nCross-class confusion analysis complete")

In [None]:
# Visualization: Confusion heatmap
fig, ax = plt.subplots(figsize=(10, 8))

sns.heatmap(
    confusion_matrix_cross * 100,  # Convert to percentage
    annot=True,
    fmt='.1f',
    cmap='RdYlGn_r',  # Red = high confusion, Green = low confusion
    xticklabels=CLASSES,
    yticklabels=CLASSES,
    cbar_kws={'label': 'False Positive Rate (%)'},
    vmin=0,
    vmax=100,
    linewidths=1,
    linecolor='gray',
    ax=ax
)

ax.set_title('Identical Shortcut Problem\n'
             'Cross-Class Confusion Matrix (PatchCore Global Model)',
             fontsize=14, fontweight='bold', pad=20)
ax.set_xlabel('Source Class (Normal Images)', fontsize=12, fontweight='bold')
ax.set_ylabel('Target Class (Threshold)', fontsize=12, fontweight='bold')

# Add explanation text
fig.text(0.5, 0.02,
         'Higher values (red) indicate normals from Source Class are confused as anomalies for Target Class',
         ha='center', fontsize=10, style='italic')

plt.tight_layout()
confusion_plot_path = VIZ_DIR / 'identical_shortcut_confusion.png'  # folder the Drive copy step reads from
plt.savefig(confusion_plot_path, dpi=150, bbox_inches='tight')
plt.show()

print(f"Confusion heatmap saved to: {confusion_plot_path}")

## 9. Feature Space Visualization: T-SNE

Visualize how the global model represents normal and anomalous samples from all classes in feature space.

In [None]:
print("="*70)
print("FEATURE SPACE VISUALIZATION (T-SNE)")
print("="*70)

# Collect features from all classes
all_features = []
all_labels = []
all_class_names = []

n_samples_per_class = 30  # Limit for visualization

for class_name in CLASSES:
    test_split = splits[class_name]['test']
    
    # Sample normal and anomalous images
    normal_indices = [i for i, lbl in enumerate(test_split['labels']) if lbl == 0][:n_samples_per_class]
    anomaly_indices = [i for i, lbl in enumerate(test_split['labels']) if lbl == 1][:n_samples_per_class]
    
    selected_indices = normal_indices + anomaly_indices
    selected_images = [test_split['images'][i] for i in selected_indices]
    selected_labels = [test_split['labels'][i] for i in selected_indices]
    
    # Create dataset
    sample_dataset = MVTecDataset(
        images=selected_images,
        masks=[None] * len(selected_images),
        labels=selected_labels,
        transform=transform_clean,
        phase='test'
    )
    
    sample_loader = DataLoader(sample_dataset, batch_size=len(selected_images), shuffle=False)
    sample_images_tensor = next(iter(sample_loader))[0].to(device)
    
    # Extract features using PatchCore backbone
    with torch.no_grad():
        features = patchcore_global.backbone(sample_images_tensor)  # (B, C, H, W)
        # Global average pooling to get image-level features
        features_pooled = features.mean(dim=[2, 3])  # (B, C)
        features_numpy = features_pooled.cpu().numpy()
    
    all_features.append(features_numpy)
    all_labels.extend(selected_labels)
    all_class_names.extend([class_name] * len(selected_images))
    
    print(f"   {class_name}: {len(normal_indices)} normals, {len(anomaly_indices)} anomalies")

# Concatenate all features
all_features = np.vstack(all_features)
print(f"\nCollected {all_features.shape[0]} samples with {all_features.shape[1]} features")

In [None]:
# Run T-SNE
print("Running T-SNE (this may take a minute)...")
# n_iter is omitted: the default is already 1000, and newer scikit-learn releases rename it to max_iter
tsne = TSNE(n_components=2, random_state=42, perplexity=30)
features_2d = tsne.fit_transform(all_features)
print("T-SNE complete")

In [None]:
# Visualization
fig, ax = plt.subplots(figsize=(14, 10))

# Define colors and markers
class_colors = {'hazelnut': 'blue', 'carpet': 'green', 'zipper': 'purple'}
label_markers = {0: 'o', 1: 'X'}  # o = normal, X = anomaly
label_sizes = {0: 50, 1: 100}

# Plot each combination
for class_name in CLASSES:
    for label_type in [0, 1]:
        mask = [(c == class_name and l == label_type) 
                for c, l in zip(all_class_names, all_labels)]
        
        if not any(mask):
            continue
        
        label_str = 'Normal' if label_type == 0 else 'Anomaly'
        
        ax.scatter(
            features_2d[mask, 0],
            features_2d[mask, 1],
            c=class_colors[class_name],
            marker=label_markers[label_type],
            s=label_sizes[label_type],
            alpha=0.7,
            edgecolors='black',
            linewidth=0.5,
            label=f'{class_name} - {label_str}'
        )

ax.set_title('T-SNE Visualization of Feature Space\n'
             'Global PatchCore Model (All Classes)',
             fontsize=14, fontweight='bold', pad=20)
ax.set_xlabel('T-SNE Dimension 1', fontsize=12)
ax.set_ylabel('T-SNE Dimension 2', fontsize=12)
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10, frameon=True, shadow=True)
ax.grid(alpha=0.3, linestyle=':')

plt.tight_layout()
tsne_plot_path = VIZ_DIR / 'tsne_global_model.png'  # folder the Drive copy step reads from
plt.savefig(tsne_plot_path, dpi=150, bbox_inches='tight')
plt.show()

print(f"T-SNE plot saved to: {tsne_plot_path}")

## 10. Summary Table and Final Results

In [None]:
# Create comprehensive summary table
summary_data = []

for method in ['patchcore', 'padim']:
    for class_name in CLASSES:
        # Per-class model
        per_class_res = results_per_class[method][class_name]['image_level']
        
        # Global model
        global_res = results_global[method][class_name]
        
        summary_data.append({
            'Method': method.upper(),
            'Class': class_name,
            'Model Type': 'Per-Class',
            'AUROC': per_class_res['auroc'],
            'AUPRC': per_class_res['auprc'],
            'F1': per_class_res['f1'],
            'Accuracy': per_class_res['accuracy']
        })
        
        summary_data.append({
            'Method': method.upper(),
            'Class': class_name,
            'Model Type': 'Global',
            'AUROC': global_res['auroc'],
            'AUPRC': global_res['auprc'],
            'F1': global_res['f1'],
            'Accuracy': global_res['accuracy']
        })

df_summary = pd.DataFrame(summary_data)

# Display
print("="*70)
print("FINAL SUMMARY: Per-Class vs Global Model")
print("="*70)
print(df_summary.to_string(index=False))

# Save to CSV
summary_path = paths.results / 'global_model_summary.csv'
df_summary.to_csv(summary_path, index=False)
print(f"\nSummary saved to: {summary_path}")

In [None]:
# Save complete results dictionary
final_results = {
    'global_model_results': results_global,
    'per_class_model_results': results_per_class,
    'performance_gaps': gaps,
    'cross_class_confusion_matrix': confusion_matrix_cross.tolist(),
    'thresholds_global': thresholds_global,
    'metadata': {
        'classes': CLASSES,
        'global_train_size': len(global_train_images),
        'patchcore_coreset_ratio': 0.05,
        'seed': 42
    }
}

results_json_path = paths.results / 'global_model_analysis.json'
with open(results_json_path, 'w') as f:
    json.dump(final_results, f, indent=2)

print(f"Complete results saved to: {results_json_path}")

## Save Results to Google Drive

In [None]:
# ============================================================
# COPY ALL RESULTS TO GOOGLE DRIVE FOR PERSISTENCE
# ============================================================

import shutil

# Create destination folder in Drive
DRIVE_ROOT = Path('/content/drive/MyDrive/anomaly_detection_project')
PHASE8_OUTPUTS = DRIVE_ROOT / '10_global_model_outputs'
PHASE8_OUTPUTS.mkdir(parents=True, exist_ok=True)

print("\n" + "="*70)
print("COPYING FILES TO GOOGLE DRIVE")
print("="*70)
print(f"\nDestination: {PHASE8_OUTPUTS}")

# List of all generated files
generated_files = []

# Models (Global models)
print("\nCopying models...")
model_files = [
    MODELS_DIR / 'patchcore_global_clean.npy',
    MODELS_DIR / 'patchcore_global_clean_config.pth',
    MODELS_DIR / 'padim_global_clean.pt',
    MODELS_DIR / 'padim_global_clean.json'
]
generated_files.extend(model_files)

# Results
print("Copying results...")
result_files = [
    RESULTS_DIR / 'global_model_analysis.json',
    RESULTS_DIR / 'global_model_summary.csv',
    RESULTS_DIR / 'global_vs_per_class_comparison.csv'
]
generated_files.extend(result_files)

# Thresholds
print("Copying thresholds...")
threshold_files = [
    THRESHOLDS_DIR / 'global_thresholds.json'
]
generated_files.extend(threshold_files)

# Visualizations
print("Copying visualizations...")
viz_files = [
    VIZ_DIR / 'global_model_performance_gap.png',
    VIZ_DIR / 'identical_shortcut_confusion.png',
    VIZ_DIR / 'tsne_global_model.png'
]
generated_files.extend(viz_files)

# Copy all files
copied_count = 0
missing_count = 0

for src_path in generated_files:
    if src_path.exists():
        # Preserve directory structure
        if 'models' in str(src_path):
            dst_dir = PHASE8_OUTPUTS / 'models'
        elif 'results' in str(src_path):
            dst_dir = PHASE8_OUTPUTS / 'results'
        elif 'thresholds' in str(src_path):
            dst_dir = PHASE8_OUTPUTS / 'thresholds'
        elif 'visualizations' in str(src_path):
            dst_dir = PHASE8_OUTPUTS / 'visualizations'
        else:
            dst_dir = PHASE8_OUTPUTS

        dst_dir.mkdir(parents=True, exist_ok=True)
        dst_path = dst_dir / src_path.name

        shutil.copy2(src_path, dst_path)
        print(f"  ✓ {src_path.name}")
        copied_count += 1
    else:
        print(f"  ✗ MISSING: {src_path.name}")
        missing_count += 1

print("\n" + "="*70)
print(f"✓ Copy complete: {copied_count} files copied, {missing_count} missing")
print(f"✓ All results saved to: {PHASE8_OUTPUTS}")
print("="*70)

## 11. Key Findings and Interpretation

### Setting Clarification:
This experiment uses the **Model-Unified** setting [CADA, 2024; HierCore, 2025]:
- Single model trained on all classes
- **Per-class thresholds** at inference (class is known)
- This is NOT "absolute-unified" (which would require class-agnostic thresholds)

### Observations:
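1. **Performance Gap**: The model-unified approach is expected to show degraded AUROC vs per-class models; confirm this against the gap table in Section 7
   - A positive gap is consistent with CADA's observation that shared representations struggle with heterogeneous distributions
   
2. **Cross-Class Confusion**: Non-zero off-diagonal entries in the confusion matrix mean that normal images from one class exceed another class's threshold
   - Normal textures from one class can trigger another class's threshold
   - This notebook labels the effect "identical shortcut", borrowing the term from UniAD [You et al. 2022], where it originally describes reconstruction models that copy their input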

### Implications:
- For industrial deployment with **single-category** quality control: **per-class models remain optimal**
- Model-unified approaches are useful when:
  - Storage/training efficiency is critical
  - Class categories are related (e.g., similar textures)
- Absolute-unified remains an open research challenge (see CADA, HierCore for solutions)

---


### Outputs Generated:
1. **Models**: `patchcore_global_clean.npy`, `padim_global_clean.pt`
2. **Thresholds**: `global_thresholds.json` (per-class thresholds for global models)
3. **Results**: `global_model_analysis.json`, `global_model_summary.csv`, `global_vs_per_class_comparison.csv`
4. **Visualizations**:
   - `global_model_performance_gap.png` (bar chart)
   - `identical_shortcut_confusion.png` (heatmap)
   - `tsne_global_model.png` (feature space)

### Open Points:
- Compare with [You et al., 2022] findings on unified anomaly detection
- Discuss implications for industrial deployment