# PaDiM Baseline Training (Clean Domain)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/IvanNece/Detection-of-Anomalies-with-Localization/blob/main/notebooks/05_padim_clean.ipynb)

**Phase 4: Baseline Model Implementation**

---

## PaDiM Method

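**PaDiM** = Patch Distribution Modeling (Defard et al., 2020)

- Extracts multi-scale features from several layers of a pretrained ResNet backbone
- Models normal appearance with one multivariate Gaussian per patch position, fitted on normal training images only
- Uses the Mahalanobis distance to those Gaussians for anomaly scoring (patch-level heatmap and image-level score)
- **No gradient-based training** (statistical approach)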

---

## Setup - Mount Drive & Configure Paths

In [None]:
# ============================================================
# SETUP - Mount Google Drive & Clone Repository
# ============================================================
# Colab-only cell: mounts Google Drive (used by the last cell to persist
# outputs), makes a fresh clone of the project repository and prepends it to
# sys.path so that `from src...` imports resolve. Lines starting with `!` are
# IPython shell escapes, not Python statements.

from google.colab import drive
from pathlib import Path
import os
import sys

# Mount Google Drive
print("Mounting Google Drive...")
drive.mount('/content/drive')
print("Done!\n")

# Clone repository on main branch
# NOTE(review): no `-b main` is passed to `git clone`, so this checks out the
# remote's default branch; that is "main" only if main is the default — confirm.
print("Cloning repository (branch: main)...")
repo_dir = '/content/Detection-of-Anomalies-with-Localization'

# Remove if exists, so every run starts from a clean checkout.
# This also deletes anything written under the old clone (e.g. outputs/).
if os.path.exists(repo_dir):
    print("Removing existing repository...")
    !rm -rf {repo_dir}

# Clone from main branch
!git clone https://github.com/IvanNece/Detection-of-Anomalies-with-Localization.git {repo_dir}
print("Done!\n")

# Set project root and put it first on sys.path so `src.*` imports use this clone
PROJECT_ROOT = Path(repo_dir)
sys.path.insert(0, str(PROJECT_ROOT))
print(f"[OK] Path configured")

## Install Dependencies

In [None]:
# ============================================================
# INSTALL ANOMALIB - REQUIRED FOR PADIM
# ============================================================
# anomalib provides PaDiM's underlying implementation
# NOTE(review): the install is unpinned, so the resolved anomalib version may
# change between runs, while the import paths below
# (`anomalib.models.image.padim...`) depend on anomalib's package layout.
# Consider pinning a known-good version.

print("Installing anomalib...")
!pip install anomalib --quiet

# Verify installation by importing the PaDiM Lightning module and its torch
# model. An ImportError is reported but NOT re-raised, so the notebook keeps
# running; code depending on anomalib (presumably src.models.padim_wrapper,
# given the comment above) would then fail later instead.
try:
    import anomalib
    from anomalib.models.image.padim import Padim
    from anomalib.models.image.padim.torch_model import PadimModel
    print(f"Success! anomalib {anomalib.__version__} installed")
    print("PaDiM components available")
except ImportError as e:
    print(f"Error: {e}")
    print("Retry: !pip install anomalib")

## Imports

In [None]:
# ============================================================
# IMPORTS - All required modules
# ============================================================

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import json
import zipfile
from datetime import datetime

# Project imports
from src.models.padim_wrapper import PadimWrapper
from src.data.dataset import MVTecDataset
from src.data.transforms import get_clean_transforms
from src.utils.config import load_config
from src.utils.reproducibility import set_seed
from src.utils.paths import get_paths

print("[OK] All modules imported successfully")

In [None]:
# ============================================================
# LOAD DATASET SPLITS
# ============================================================

# Load clean splits (generated in notebook 02).
# As read below: splits[class]['train'|'val'|'test'] holds parallel
# 'images' / 'labels' lists, with label 0 = normal and 1 = anomalous.
splits_path = PROJECT_ROOT / 'data' / 'processed' / 'clean_splits.json'

with open(splits_path, 'r') as f:
    splits = json.load(f)

# CLASSES is defined by the Configuration cell, which runs AFTER this one.
# Referencing it unconditionally raised NameError on a top-to-bottom run, so
# fall back to every class present in the splits file when it is not defined.
try:
    summary_classes = CLASSES
except NameError:
    summary_classes = list(splits)

print("[OK] Loaded clean splits")
print("\nSplit summary:")
for class_name in summary_classes:
    split_data = splits[class_name]
    val_labels = split_data['val']['labels']
    test_labels = split_data['test']['labels']
    print(f"\n{class_name}:")
    print(f"  Train:  {len(split_data['train']['images'])} images (normal only)")
    print(f"  Val:    {len(split_data['val']['images'])} images")
    print(f"    - Normal: {val_labels.count(0)}")
    print(f"    - Anomalous: {val_labels.count(1)}")
    print(f"  Test:   {len(split_data['test']['images'])} images")
    print(f"    - Normal: {test_labels.count(0)}")
    print(f"    - Anomalous: {test_labels.count(1)}")

## Configuration

In [None]:
# ============================================================
# CONFIGURATION - Load experiment config
# ============================================================
# Reads the shared experiment YAML and defines the notebook-wide constants
# (device, batch size, worker count, class list) used by the later cells.

config_file = PROJECT_ROOT / 'configs' / 'experiment_config.yaml'
config = load_config(config_file)
paths = get_paths(config)

# Runtime settings
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
BATCH_SIZE, NUM_WORKERS = 32, 2
CLASSES = config.dataset.classes  # e.g. ['hazelnut', 'carpet', 'zipper']

# Seed through the project helper for reproducibility
set_seed(config.seed)

padim_cfg = config.padim
print("\n[INFO] PaDiM Configuration:")
for label, value in (
    ("Backbone", padim_cfg.backbone),
    ("Layers", padim_cfg.layers),
    ("Feature dimension", padim_cfg.n_features),
    ("Device", DEVICE),
    ("Classes", CLASSES),
):
    print(f"  {label}: {value}")

## Data Pipeline (Transform & DataLoader Helper)

In [None]:
# Build the preprocessing transform. PatchCore uses exactly the same one, so
# both baselines are compared on identical inputs.
img_size = config.dataset.image_size
norm_cfg = config.dataset.normalize
transform = get_clean_transforms(
    image_size=img_size,
    normalize_mean=norm_cfg.mean,
    normalize_std=norm_cfg.std,
)

# NOTE: the "ImageNet statistics" label is hard-coded text; it is accurate only
# if config.dataset.normalize holds the ImageNet mean/std.
for line in (
    "[OK] Transform initialized:",
    f"  Size: {img_size}x{img_size}",
    "  Normalization: ImageNet statistics",
    "  Augmentation: None (deterministic preprocessing)",
):
    print(line)

# Helper function to create DataLoader (ALIGNED WITH PATCHCORE)
def create_dataloader(split_dict, batch_size=32, shuffle=False, *, phase=None):
    """Build a DataLoader over one split of the clean dataset.

    Uses the module-level ``transform``, ``NUM_WORKERS`` and ``DEVICE`` defined
    in earlier cells, so loaders match the PatchCore notebook's setup.

    Args:
        split_dict: One split entry from ``clean_splits.json`` with parallel
            ``'images'``, ``'masks'`` and ``'labels'`` lists.
        batch_size: Samples per batch.
        shuffle: Whether the DataLoader reshuffles each epoch.
        phase: Value forwarded to ``MVTecDataset(phase=...)``. When ``None``
            (the default), the historical rule ``'train' if shuffle else 'val'``
            is kept. Pass it explicitly when that coupling is wrong: e.g. the
            training loop uses ``shuffle=False``, so without an override the
            train split is built with ``phase='val'``.

    Returns:
        A ``torch.utils.data.DataLoader`` over an ``MVTecDataset``; later cells
        unpack each batch as ``images, masks, labels, paths``.
    """
    if phase is None:
        phase = 'train' if shuffle else 'val'

    dataset = MVTecDataset(
        images=split_dict['images'],
        masks=split_dict['masks'],
        labels=split_dict['labels'],
        transform=transform,
        phase=phase,
    )

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        # Pinned host memory only helps host->GPU transfers.
        pin_memory=(DEVICE == 'cuda'),
    )


print("[OK] DataLoader helper function ready")

## Training Loop

In [None]:
# Initialize results tracking: checkpoints go to outputs/models, and per-class
# training statistics are accumulated column-wise (one list per column).
MODEL_DIR = PROJECT_ROOT / 'outputs' / 'models'
MODEL_DIR.mkdir(parents=True, exist_ok=True)

# Each training iteration appends exactly one entry to every column.
training_results = {
    column: []
    for column in (
        'classes',
        'num_train_samples',
        'training_time_seconds',
        'memory_bank_size_mb',
        'model_path',
    )
}

# Trained models keyed by class name; reused by the test cells below.
trained_models = dict()

print(f" Model directory: {MODEL_DIR}")

In [None]:
# Fit one PaDiM model per class on that class's normal training images, save
# it to disk, and record its training statistics.
rule = "=" * 80

print(f"\n{rule}")
print("[START] TRAINING PaDiM MODELS")
print(f"{rule}\n")

for class_name in CLASSES:
    print(f"\n{rule}")
    print(f"[TRAIN] PaDiM: {class_name.upper()}")
    print(f"{rule}\n")

    # 1. Data: train split, fixed order so fitting is deterministic
    class_splits = splits[class_name]
    train_loader = create_dataloader(
        split_dict=class_splits['train'],
        batch_size=BATCH_SIZE,
        shuffle=False,
    )
    print(f"[OK] Loaded {len(train_loader.dataset)} training samples\n")

    # 2-3. Build the wrapper from the config and fit it (no gradient steps)
    model = PadimWrapper(
        backbone=config.padim.backbone,
        layers=config.padim.layers,
        n_features=config.padim.n_features,
        image_size=config.dataset.image_size,
        device=DEVICE,
    )
    model.fit(train_loader, verbose=True)

    # 4. Persist checkpoint (plus stats, per include_stats=True)
    model_path = MODEL_DIR / f"padim_{class_name}_clean.pt"
    model.save(model_path, include_stats=True)

    # 5. Record one row of statistics for this class
    stats = model.training_stats
    training_results['classes'].append(class_name)
    training_results['num_train_samples'].append(stats['num_samples'])
    training_results['training_time_seconds'].append(stats['training_time_seconds'])
    training_results['memory_bank_size_mb'].append(stats['memory_bank_size_mb'])
    training_results['model_path'].append(str(model_path))

    trained_models[class_name] = model

    print(f"\n[COMPLETE] {class_name}")
    print(f"   Model: {model_path.name}")
    print(f"   Time: {stats['training_time_seconds']:.2f}s")

print(f"\n{rule}")
print("[SUCCESS] ALL CLASSES TRAINED SUCCESSFULLY!")
print(f"{rule}\n")

## Training Results

In [None]:
# Persist the per-class training statistics both as JSON (machine-readable)
# and as CSV (easy viewing), under outputs/results/.
results_dir = PROJECT_ROOT / 'outputs' / 'results'
results_path = results_dir / 'padim_training_results_clean.json'
results_dir.mkdir(parents=True, exist_ok=True)

results_path.write_text(json.dumps(training_results, indent=2))

import pandas as pd

results_df = pd.DataFrame(training_results)
csv_path = results_dir / 'padim_training_stats_clean.csv'
results_df.to_csv(csv_path, index=False)

print(f"\n[OK] Results saved: {results_path.name}")

In [None]:
# Side-by-side bar charts of per-class training time and memory-bank size,
# each bar annotated with its value.
panel_specs = (
    # (column, bar color, y label, title, label offset, value suffix)
    ('training_time_seconds', 'steelblue', 'Training Time (seconds)',
     'PaDiM Training Time per Class', 0.5, 's'),
    ('memory_bank_size_mb', 'coral', 'Memory Bank Size (MB)',
     'PaDiM Memory Usage per Class', 1, 'MB'),
)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

for ax, (column, color, ylabel, title, offset, suffix) in zip(axes, panel_specs):
    values = results_df[column]
    ax.bar(results_df['classes'], values, color=color, alpha=0.8)
    ax.set_xlabel('Class', fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.grid(axis='y', alpha=0.3)
    for idx, val in enumerate(values):
        ax.text(idx, val + offset, f"{val:.1f}{suffix}", ha='center', fontweight='bold')

plt.tight_layout()
plt.show()

## Testing and Validation

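Quick sanity checks: (1) every saved model reloads from disk, (2) for each class, an anomalous validation image scores higher than a normal one, and (3) heatmaps render for one class.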

In [None]:
# TEST 1: every saved checkpoint can be loaded into a fresh PadimWrapper.
banner = "=" * 80
print(banner)
print("[TEST 1] Model Loading")
print(banner + "\n")

load_failures = []
for class_name in CLASSES:
    checkpoint = MODEL_DIR / f"padim_{class_name}_clean.pt"
    try:
        test_model = PadimWrapper(device=DEVICE)
        test_model.load(checkpoint)
    except Exception as e:  # report per class instead of aborting the whole check
        load_failures.append(class_name)
        print(f"[FAIL] {class_name}: Failed - {e}")
    else:
        print(f"[OK] {class_name}: Loaded successfully")

all_load_success = not load_failures
verdict = (
    "[PASS] TEST 1 PASSED: All models load correctly"
    if all_load_success
    else "[FAIL] TEST 1 FAILED: Some models failed to load"
)
print(f"\n{banner}")
print(verdict)
print(banner)

In [None]:
# TEST 2: Verify predictions work and anomalous > normal
# For each class, score the FIRST normal and the FIRST anomalous image of the
# validation split and check that the anomalous one scores higher. This is a
# single-pair smoke test, not an evaluation metric.
print(f"\n\n{'='*80}")
print("[TEST 2] Prediction Validation")
print("="*80 + "\n")


def _score_to_float(score):
    """Return a single-element score (float, NumPy scalar/array, or tensor) as a float.

    ``.item()`` is preferred when present because ``float()`` on an ndarray with
    ndim > 0 is deprecated in recent NumPy; ``.item()`` raises if the score has
    more than one element.
    """
    return float(score.item()) if hasattr(score, 'item') else float(score)


def _format_score(score):
    """Format a score with 4 decimals, or 'N/A' when no such sample was found."""
    return 'N/A' if score is None else f"{score:.4f}"


test_results = []

for class_name in CLASSES:
    print(f"Testing {class_name}...")
    model = trained_models[class_name]
    split_data = splits[class_name]

    # Create val loader (batch_size=1, so labels.item() is valid)
    val_loader = create_dataloader(
        split_dict=split_data['val'],
        batch_size=1,
        shuffle=False
    )

    # First score seen per label: 0 = normal, 1 = anomalous
    first_scores = {0: None, 1: None}

    for images, masks, labels, paths in val_loader:
        label = labels.item()

        if label in first_scores and first_scores[label] is None:
            score, _ = model.predict(images, return_heatmaps=False)
            first_scores[label] = _score_to_float(score)

        if None not in first_scores.values():
            break

    normal_score = first_scores[0]
    anomalous_score = first_scores[1]

    # Compare against None explicitly. Previously a genuine score of 0.0 (falsy)
    # was treated as missing, and a truly missing score crashed the ':.4f'
    # formatting below.
    passed = (
        normal_score is not None
        and anomalous_score is not None
        and anomalous_score > normal_score
    )
    status = "[OK]" if passed else "[FAIL]"

    if normal_score is None or anomalous_score is None:
        print("  [WARN] Validation split lacks a normal and/or anomalous sample")

    print(f"  {status} Normal score: {_format_score(normal_score)}")
    print(f"  {status} Anomalous score: {_format_score(anomalous_score)}")
    print(f"  {status} Test: {'PASSED' if passed else 'FAILED'}\n")

    test_results.append(passed)

if all(test_results):
    print(f"{'='*80}")
    print("[PASS] TEST 2 PASSED: Anomalous scores > Normal scores for all classes")
    print("="*80)
else:
    print(f"{'='*80}")
    print("[WARNING] TEST 2 WARNING: Some anomalous scores not higher than normal")
    print("="*80)

In [None]:
# TEST 3: Visual validation
# Predict score + heatmap for the first normal and first anomalous validation
# image of one class; the plotting code below visualizes them.
print(f"\n\n{'='*80}")
print("[TEST 3] Visual Validation")
print("="*80 + "\n")

# Select one class for visualization (the first configured class)
TEST_CLASS = CLASSES[0]
model = trained_models[TEST_CLASS]
split_data = splits[TEST_CLASS]

# Create val loader (batch_size=1, so labels.item() is valid)
val_loader = create_dataloader(
    split_dict=split_data['val'],
    batch_size=1,
    shuffle=False
)

# Get one normal and one anomalous sample
normal_images = None
anomalous_images = None

for images, masks, labels, paths in val_loader:
    label = labels.item()
    if label == 0 and normal_images is None:
        normal_images = images
    elif label == 1 and anomalous_images is None:
        anomalous_images = images

    if normal_images is not None and anomalous_images is not None:
        break

# Fail with a clear message rather than an opaque error from predict(None)
if normal_images is None or anomalous_images is None:
    raise RuntimeError(
        "TEST 3 needs at least one normal and one anomalous image in the "
        f"'{TEST_CLASS}' validation split"
    )

# Predict
normal_score, normal_heatmap = model.predict(normal_images, return_heatmaps=True)
anomalous_score, anomalous_heatmap = model.predict(anomalous_images, return_heatmaps=True)

# Collapse single-element scores (NumPy array/scalar or torch tensor) to a
# Python number so ':.4f' formatting works; previously only ndarray was handled,
# and formatting a 1-element tensor raises TypeError.
if hasattr(normal_score, 'item'):
    normal_score = normal_score.item()
if hasattr(anomalous_score, 'item'):
    anomalous_score = anomalous_score.item()

print(f"[OK] Predictions generated:")
print(f"  Normal score: {normal_score:.4f}")
print(f"  Anomalous score: {anomalous_score:.4f}\n")

# Denormalize for visualization
def denormalize(tensor, mean, std):
    """Undo per-channel ``(x - mean) / std`` normalization and clamp to [0, 1].

    Args:
        tensor: Image tensor whose channel axis is third from last, e.g.
            ``(C, H, W)`` or ``(B, C, H, W)``.
        mean: Per-channel means (sequence of length C).
        std: Per-channel standard deviations (sequence of length C).

    Returns:
        Tensor with the same shape as ``tensor``, values clamped to ``[0, 1]``.
    """
    # Build mean/std on the input's device (so CUDA inputs work) and in its
    # floating dtype; integer inputs fall back to float32 like before.
    # view(-1, 1, 1) supports any channel count instead of exactly 3.
    dtype = tensor.dtype if tensor.is_floating_point() else torch.float32
    mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=dtype, device=tensor.device).view(-1, 1, 1)
    return torch.clamp(tensor * std + mean, 0, 1)

# Convert the first (only) image of each batch back to an (H, W, C) array in
# [0, 1] for display.
mean_cfg = config.dataset.normalize.mean
std_cfg = config.dataset.normalize.std
normal_img = denormalize(normal_images[0].cpu(), mean_cfg, std_cfg).permute(1, 2, 0).numpy()
anomalous_img = denormalize(anomalous_images[0].cpu(), mean_cfg, std_cfg).permute(1, 2, 0).numpy()

# Plot: one row per sample -> [image | heatmap | overlay]
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

plot_rows = (
    # (row index, image, heatmap, title, extra title kwargs)
    (0, normal_img, normal_heatmap, f'Normal Image\nScore: {normal_score:.4f}', {}),
    (1, anomalous_img, anomalous_heatmap, f'Anomalous Image\nScore: {anomalous_score:.4f}', {'color': 'red'}),
)

for row, img, heatmap, title, title_kwargs in plot_rows:
    axes[row, 0].imshow(img)
    axes[row, 0].set_title(title, fontsize=12, fontweight='bold', **title_kwargs)
    axes[row, 0].axis('off')

    if heatmap is None:
        continue

    # imshow needs a 2-D array. Drop singleton dims such as a leading batch
    # axis (possible since predict() ran on a batch of one — unverified; this
    # is a no-op if the wrapper already returns (H, W)).
    heatmap = np.squeeze(np.asarray(heatmap))

    im = axes[row, 1].imshow(heatmap, cmap='jet')
    axes[row, 1].set_title('Anomaly Heatmap', fontsize=12, fontweight='bold')
    axes[row, 1].axis('off')
    plt.colorbar(im, ax=axes[row, 1], fraction=0.046)

    axes[row, 2].imshow(img)
    axes[row, 2].imshow(heatmap, cmap='jet', alpha=0.5)
    axes[row, 2].set_title('Overlay', fontsize=12, fontweight='bold')
    axes[row, 2].axis('off')

plt.suptitle(f'PaDiM Predictions - {TEST_CLASS.upper()}', fontsize=16, fontweight='bold', y=0.98)
plt.tight_layout()
plt.show()

print(f"{'='*80}")
print("[PASS] TEST 3 PASSED: Visualizations generated successfully")
print("="*80)

## Download Results

Create a ZIP package with all outputs for download

In [None]:
from google.colab import files

# Bundle model checkpoints, their stats files and the training results into
# one ZIP and trigger a browser download from Colab.
print("="*80)
print("[PACKAGE] Preparing Download Package")
print("="*80 + "\n")

# Create ZIP
zip_path = PROJECT_ROOT / 'padim_outputs.zip'

# Same results files the Google Drive backup cell copies. The CSV was
# previously missing from the ZIP.
results_dir = PROJECT_ROOT / 'outputs' / 'results'
result_files = [
    results_dir / 'padim_training_results_clean.json',
    results_dir / 'padim_training_stats_clean.csv',
]

with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
    # Add models (checkpoint + stats file per class, when present)
    for class_name in CLASSES:
        for artifact in (
            MODEL_DIR / f"padim_{class_name}_clean.pt",
            MODEL_DIR / f"padim_{class_name}_clean.json",
        ):
            if artifact.exists():
                zipf.write(artifact, f"models/{artifact.name}")
                print(f"[OK] Added: {artifact.name}")

    # Add results (JSON + CSV)
    for result_file in result_files:
        if result_file.exists():
            zipf.write(result_file, f"results/{result_file.name}")
            print(f"[OK] Added: {result_file.name}")

print(f"\n[OK] Package created: {zip_path.name}")
print(f"   Size: {zip_path.stat().st_size / (1024**2):.2f} MB")

# Download
print(f"\n{'='*80}")
print("[DOWNLOAD] Starting download...")
print("="*80)
files.download(str(zip_path))
print("\n[OK] Download complete!")

## Save to Google Drive

Copy all outputs to Google Drive for persistent storage

In [None]:
import shutil

# Copy checkpoints and results to Google Drive for persistent storage.
DRIVE_OUTPUT = Path('/content/drive/MyDrive/anomaly_detection_project/05_padim_clean_outputs')
DRIVE_OUTPUT.mkdir(parents=True, exist_ok=True)

divider = "=" * 70
print(divider)
print("[SAVE] Copying outputs to Google Drive...")
print(divider + "\n")

# 1. Per-class model checkpoint (.pt) followed by its stats file (.json)
models_saved = []
for class_name in CLASSES:
    for suffix in ('.pt', '.json'):
        artifact = MODEL_DIR / f"padim_{class_name}_clean{suffix}"
        if not artifact.exists():
            continue
        shutil.copy2(artifact, DRIVE_OUTPUT / artifact.name)
        if suffix == '.pt':
            models_saved.append(artifact.name)
        print(f"[OK] Saved: {artifact.name}")

# 2. Results files (kept under the same names)
results_dir = PROJECT_ROOT / 'outputs' / 'results'
for file_name in ('padim_training_stats_clean.csv', 'padim_training_results_clean.json'):
    source = results_dir / file_name
    if source.exists():
        shutil.copy2(source, DRIVE_OUTPUT / file_name)
        print(f"[OK] Saved: {file_name}")

# Summary over everything currently in the Drive folder
drive_entries = list(DRIVE_OUTPUT.iterdir())
total_mb = sum(entry.stat().st_size for entry in drive_entries if entry.is_file()) / (1024 * 1024)

print("\n" + divider)
print("[SUCCESS] OUTPUTS SAVED TO GOOGLE DRIVE!")
print(divider)
print(f"Location: {DRIVE_OUTPUT}")
print(f"Total files: {len(drive_entries)}")
print(f"Total size: {total_mb:.2f} MB")
print(divider)

---
### Outputs

```
outputs/models/
├── padim_hazelnut_clean.pt      (~50-100 MB each)
├── padim_hazelnut_clean.json
├── padim_carpet_clean.pt
├── padim_carpet_clean.json
├── padim_zipper_clean.pt
└── padim_zipper_clean.json

outputs/results/
├── padim_training_results_clean.json
└── padim_training_stats_clean.csv
```

