# 02 - PatchCore Baseline Training

This notebook trains the centralized (non-federated) PatchCore baseline on the AutoVI dataset: one model per category, saved under `outputs/baseline/`.

## Overview

PatchCore is a state-of-the-art anomaly detection method that:
1. Extracts patch features with a pre-trained CNN backbone (WideResNet-50-2 in the original PatchCore paper; here the backbone and layers are read from the YAML config)
2. Builds a memory bank of representative normal patch features
3. Detects anomalies by computing distances to nearest neighbors in the memory bank

In [None]:
import json
import random
import sys
from pathlib import Path

# Add project root to path
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root))

import yaml
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
from tqdm import tqdm

from src.data.autovi_dataset import AutoVIDataset, CATEGORIES, get_resize_shape
from src.models.patchcore import PatchCore

# Report the runtime environment so saved outputs record which torch/GPU produced them.
cuda_available = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_available}")
if cuda_available:
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## Configuration

In [None]:
# Dataset input and experiment output locations (relative to the project root)
DATA_DIR = project_root / "dataset"
OUTPUT_DIR = project_root / "outputs" / "baseline"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Load the low-memory baseline configuration (yaml.safe_load: plain data only, no object construction)
CONFIG_PATH = project_root / "experiments" / "configs" / "baseline" / "patchcore_config_lowmem.yaml"
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

print(f"Loaded config from: {CONFIG_PATH.name}")

# Model hyperparameters passed to every PatchCore constructor in this notebook.
# NOTE(review): "coreset_percentage" is forwarded as PatchCore's `coreset_ratio`;
# confirm the YAML stores a fraction (e.g. 0.01) rather than a percentage (e.g. 1).
CONFIG = {
    "backbone": config["model"]["backbone"],
    "layers": config["model"]["layers"],
    "coreset_percentage": config["model"]["coreset_percentage"],
    "neighborhood_size": config["model"]["neighborhood_size"],
    "use_faiss": config["model"]["use_faiss"],
}

# Data-loading settings and the global seed
BATCH_SIZE = config["training"]["batch_size"]
NUM_WORKERS = config["training"]["num_workers"]
SEED = config["seed"]

print(f"Model config: {CONFIG}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Seed: {SEED}")

# Seed every RNG the notebook may rely on (Python, NumPy, PyTorch) for reproducibility.
# torch.manual_seed also seeds the CUDA generators.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

## Helper Functions

In [None]:
def get_transforms(category: str) -> transforms.Compose:
    """Build the preprocessing pipeline for ``category``.

    Images are resized to the category-specific shape returned by
    ``get_resize_shape``, converted to tensors, and normalized with the
    standard ImageNet per-channel mean/std.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize(get_resize_shape(category)),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return transforms.Compose(steps)


def visualize_anomaly_map(image, anomaly_map, title="Anomaly Map"):
    """Plot an image, its anomaly map, and the two overlaid, side by side.

    Args:
        image: Normalized image tensor in (C, H, W) layout, normalized with the
            ImageNet mean/std used by ``get_transforms``; it is de-normalized
            before display.
        anomaly_map: 2-D anomaly map as a NumPy array or torch tensor (a tensor
            may live on the GPU). Singleton dimensions are squeezed away. If its
            spatial size differs from ``image``, the overlay stretches it over
            the full image so the two stay aligned.
        title: Title for the overlay panel.

    Returns:
        The matplotlib ``Figure``. The caller is responsible for showing and
        closing it.
    """
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Undo the ImageNet normalization so the image renders with its real colors.
    img = image.detach().cpu().numpy().transpose(1, 2, 0)
    img = np.clip(img * std + mean, 0, 1)

    # Accept torch tensors as well as arrays, and drop singleton dims (e.g. a
    # leading batch/channel axis) that imshow would otherwise reject.
    if hasattr(anomaly_map, "detach"):
        anomaly_map = anomaly_map.detach().cpu().numpy()
    anomaly_map = np.squeeze(np.asarray(anomaly_map))

    # imshow's default extent depends on each array's own shape. Pinning the
    # overlay to the image's pixel extent keeps it aligned even if the map is
    # at a lower resolution than the image.
    height, width = img.shape[:2]
    image_extent = (-0.5, width - 0.5, height - 0.5, -0.5)

    # Original image
    axes[0].imshow(img)
    axes[0].set_title("Original Image")
    axes[0].axis("off")

    # Anomaly map on its own, with a colorbar for the score scale
    im = axes[1].imshow(anomaly_map, cmap="jet")
    axes[1].set_title("Anomaly Map")
    axes[1].axis("off")
    fig.colorbar(im, ax=axes[1])

    # Overlay
    axes[2].imshow(img)
    axes[2].imshow(anomaly_map, cmap="jet", alpha=0.5, extent=image_extent)
    axes[2].set_title(title)
    axes[2].axis("off")

    fig.tight_layout()
    return fig

## Train PatchCore for a Single Category

Let's start by training on a single category to verify the implementation.

In [None]:
# Select one category for a quick end-to-end sanity check before training all of them.
DEMO_CATEGORY = "engine_wiring"  # Choose from CATEGORIES

# Fail fast on a typo here rather than later, inside dataset loading or fitting.
assert DEMO_CATEGORY in CATEGORIES, (
    f"Unknown category {DEMO_CATEGORY!r}; choose from {CATEGORIES}"
)

print(f"Training PatchCore for: {DEMO_CATEGORY}")
print(f"Available categories: {CATEGORIES}")

In [None]:
# Build the training split for the demo category; the memory bank is fitted on it.
transform = get_transforms(DEMO_CATEGORY)

train_dataset = AutoVIDataset(
    root_dir=DATA_DIR,
    categories=[DEMO_CATEGORY],
    split="train",
    transform=transform,
)

# An empty split (e.g. a wrong DATA_DIR) would otherwise only fail later, inside model.fit.
assert len(train_dataset) > 0, (
    f"No training images found for {DEMO_CATEGORY!r} under {DATA_DIR}"
)

print(f"Training samples: {len(train_dataset)}")
print(f"Dataset statistics: {train_dataset.get_statistics()}")

In [None]:
# Create the training dataloader; shuffle=False keeps the iteration order deterministic.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    # Pinned host memory only speeds up host->GPU copies; requesting it without
    # CUDA just triggers a PyTorch warning.
    pin_memory=torch.cuda.is_available(),
)

In [None]:
# Instantiate PatchCore from the YAML-derived hyperparameters.
# device="auto" presumably lets the model choose CUDA/CPU itself — see model.device below.
patchcore_kwargs = dict(
    backbone_name=CONFIG["backbone"],
    layers=CONFIG["layers"],
    coreset_ratio=CONFIG["coreset_percentage"],
    neighborhood_size=CONFIG["neighborhood_size"],
    device="auto",
    use_faiss=CONFIG["use_faiss"],
)
model = PatchCore(**patchcore_kwargs)

device_msg = f"Model initialized on: {model.device}"
print(device_msg)
print(f"Feature dimension: {model.feature_dim}")

In [None]:
# Fit the model: this builds PatchCore's memory bank from the training images.
model.fit(train_loader, seed=SEED)

# Report whatever statistics the fitted model exposes.
stats = model.get_stats()
print("\nModel Statistics:")
for stat_name, stat_value in stats.items():
    line = f"  {stat_name}: {stat_value}"
    print(line)

In [None]:
# Persist the demo model as OUTPUT_DIR/models/patchcore_<category>.
models_dir = OUTPUT_DIR / "models"
models_dir.mkdir(parents=True, exist_ok=True)
model_path = models_dir / f"patchcore_{DEMO_CATEGORY}"
model.save(str(model_path))

## Test Inference

Let's test the trained model on some test images.

In [None]:
# Held-out test split for the demo category, using the same preprocessing as training.
test_dataset = AutoVIDataset(
    root_dir=DATA_DIR,
    split="test",
    categories=[DEMO_CATEGORY],
    transform=transform,
)

n_test = len(test_dataset)
print(f"Test samples: {n_test}")
print(f"Test statistics: {test_dataset.get_statistics()}")

In [None]:
# Qualitative check: show the anomaly map and score for the first few test images.
N_VIS_SAMPLES = 5  # number of test images to visualize
n_samples = min(N_VIS_SAMPLES, len(test_dataset))

for i in range(n_samples):
    sample = test_dataset[i]
    image = sample["image"]
    label = sample["label"]
    # Fall back to "good" when the sample carries no defect_type entry.
    defect_type = sample.get("defect_type", "good")

    # Pixel-level anomaly map plus image-level anomaly score
    anomaly_map, score = model.predict_single(image)

    status = "Defective" if label else "Good"
    title = f"Label: {status} ({defect_type}) | Score: {score:.2f}"
    fig = visualize_anomaly_map(image, anomaly_map, title)
    plt.show()
    # Close explicitly: non-inline backends keep every figure alive otherwise,
    # so memory grows with the number of samples.
    plt.close(fig)

## Train All Categories

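Now let's train one model per category in `CATEGORIES`. Errors are printed and recorded per category rather than stopping the loop. The demo category is retrained here, and its saved model is overwritten.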

In [None]:
def train_category(category: str) -> dict:
    """Fit a PatchCore model for one category and save it to disk.

    Uses the notebook-level ``CONFIG``, ``BATCH_SIZE``, ``NUM_WORKERS`` and
    ``SEED`` settings. The model is saved to
    ``OUTPUT_DIR / "models" / f"patchcore_{category}"``.

    Args:
        category: AutoVI category name (one of ``CATEGORIES``).

    Returns:
        The model's ``get_stats()`` dict, extended with ``"category"`` and
        ``"num_training_samples"``.

    Raises:
        ValueError: If the category's training split contains no images.
    """
    banner = "=" * 60
    print(f"\n{banner}")
    print(f"Training: {category}")
    print(banner)

    # Training split for this category, with its category-specific resize
    transform = get_transforms(category)
    dataset = AutoVIDataset(
        root_dir=DATA_DIR,
        categories=[category],
        split="train",
        transform=transform,
    )
    if len(dataset) == 0:
        # Report a clear, recordable error instead of failing somewhere inside fit().
        raise ValueError(f"No training images found for {category!r} under {DATA_DIR}")

    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        # Pinning only benefits host->GPU transfers and warns when no CUDA device exists.
        pin_memory=torch.cuda.is_available(),
    )

    # Create the output directory before the expensive fit. This lets the function
    # work even when the single-category demo cells (which create it) were skipped.
    model_path = OUTPUT_DIR / "models" / f"patchcore_{category}"
    model_path.parent.mkdir(parents=True, exist_ok=True)

    model = PatchCore(
        backbone_name=CONFIG["backbone"],
        layers=CONFIG["layers"],
        coreset_ratio=CONFIG["coreset_percentage"],
        neighborhood_size=CONFIG["neighborhood_size"],
        device="auto",
        use_faiss=CONFIG["use_faiss"],
    )
    model.fit(dataloader, seed=SEED)
    model.save(str(model_path))

    stats = model.get_stats()
    stats["category"] = category
    stats["num_training_samples"] = len(dataset)
    return stats

In [None]:
# Train every category. A failure is printed and recorded, not raised, so that one
# bad category does not abort the whole run.
all_stats = {}

for category in CATEGORIES:
    try:
        category_stats = train_category(category)
    except Exception as err:
        print(f"Error training {category}: {err}")
        category_stats = {"error": str(err)}
    all_stats[category] = category_stats

In [None]:
# Print summary
print("\n" + "="*60)
print("Training Summary")
print("="*60)

for category, stats in all_stats.items():
    if "error" in stats:
        print(f"{category}: ERROR - {stats['error']}")
    else:
        print(f"{category}:")
        print(f"  Training samples: {stats.get('num_training_samples', 'N/A')}")
        print(f"  Memory bank size: {stats.get('memory_bank_size', 'N/A')}")

In [None]:
# Save summary
import json

summary_path = OUTPUT_DIR / "training_summary.json"
with open(summary_path, "w") as f:
    json.dump(all_stats, f, indent=2, default=str)

print(f"\nTraining summary saved to: {summary_path}")

## Next Steps

After training the baseline models, proceed to:
1. **Evaluation**: Run evaluation on test sets to compute AUC-sPRO and AUC-ROC metrics
2. **Phase 3**: Set up federated learning infrastructure
3. **Comparison**: Compare centralized baseline with federated approaches