# Federated PatchCore Experiments

This notebook runs federated learning experiments comparing:
1. **IID Partitioning**: Data uniformly distributed across clients
2. **Category-Based Partitioning**: Non-IID distribution simulating factory stations

## Setup

In [None]:
import sys
from pathlib import Path

# Treat the parent of the current working directory as the repository root
# and put it first on sys.path so that the `src.*` imports below resolve.
project_root = Path.cwd().parent
sys.path.insert(0, f"{project_root}")

import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from src.data.autovi_dataset import AutoVIDataset, CATEGORIES
from src.data.partitioner import IIDPartitioner, CategoryPartitioner, compute_partition_stats
from src.data.preprocessing import get_transforms
from src.federated import FederatedPatchCore

# Experiment-wide constants
SEED = 42
DATA_ROOT = "/path/to/autovi"  # <-- UPDATE THIS PATH
OUTPUT_DIR = project_root.joinpath("outputs", "federated")

# Seed PyTorch and NumPy's global generators with the shared SEED
torch.manual_seed(SEED)
np.random.seed(SEED)

# Echo the resolved locations so a wrong working directory is easy to spot
print(f"Project root: {project_root}")
print(f"Output directory: {OUTPUT_DIR}")

## 1. Load Dataset and Explore

In [None]:
# Load the training split for every category. No transform is attached here
# (transform=None); category-specific transforms are applied later through a
# wrapper dataset.
dataset = AutoVIDataset(
    root_dir=DATA_ROOT,
    categories=CATEGORIES,
    split="train",
    transform=None,
)

# Report the overall size, then the per-category count stored under 'good'.
print(f"Dataset size: {len(dataset)} samples")
stats = dataset.get_statistics()
print(f"\nSamples per category:")
for cat, counts in stats['by_category'].items():
    print(f"  {cat}: {counts['good']} (all good for training)")

## 2. Compare Partitioning Strategies

Let's visualize the data distribution for IID vs Category-based partitioning.

In [None]:
# Create partitioners
# IID: fixed at 5 clients. Category-based: no client count is passed, so the
# number of clients is whatever CategoryPartitioner decides.
iid_partitioner = IIDPartitioner(num_clients=5, seed=SEED)
category_partitioner = CategoryPartitioner(seed=SEED)

# Create partitions
# Each partition is later iterated as client_id -> sample indices.
iid_partition = iid_partitioner.partition(dataset)
category_partition = category_partitioner.partition(dataset)

# Compute statistics
# The result is read below as stats['clients'][client]['by_category'] when
# plotting the per-client category distribution.
iid_stats = compute_partition_stats(dataset, iid_partition)
category_stats = compute_partition_stats(dataset, category_partition)

In [None]:
def plot_partition_distribution(stats, title):
    """Plot, per client, how many samples of each category it holds.

    Draws one group of bars per client and, inside each group, one bar per
    entry of the module-level ``CATEGORIES``.

    Args:
        stats: Partition statistics, read as
            ``stats['clients'][client]['by_category']`` (a mapping from
            category to sample count). A category missing for a client is
            drawn with height 0.
        title: Title shown above the chart.

    Returns:
        The matplotlib figure containing the chart.
    """
    fig, ax = plt.subplots(figsize=(12, 6))

    clients = list(stats['clients'].keys())
    categories = list(CATEGORIES)
    n_categories = len(categories)

    x = np.arange(len(clients))
    # Keep the 0.12 bar width while a whole group fits into 0.8 of a client
    # slot (true for up to 6 categories); shrink it for longer category lists
    # so neighbouring client groups never overlap.
    width = min(0.12, 0.8 / max(n_categories, 1))

    colors = plt.cm.Set3(np.linspace(0, 1, n_categories))

    for i, cat in enumerate(categories):
        counts = [stats['clients'][c]['by_category'].get(cat, 0) for c in clients]
        ax.bar(x + i * width, counts, width, label=cat, color=colors[i])

    ax.set_xlabel('Client ID')
    ax.set_ylabel('Number of Samples')
    ax.set_title(title)
    # Centre each tick under its group: bars sit at offsets 0..(n-1)*width,
    # so the midpoint is (n-1)/2 bar widths (2.5 for six categories).
    ax.set_xticks(x + width * max(n_categories - 1, 0) / 2)
    ax.set_xticklabels([f'Client {c}' for c in clients])
    ax.legend(loc='upper right', bbox_to_anchor=(1.15, 1))

    plt.tight_layout()
    return fig

# Plot both distributions
# One chart per partitioning strategy, built from the statistics computed in
# the previous cell; each figure is shown before the next one is created.
fig1 = plot_partition_distribution(iid_stats, 'IID Partitioning')
plt.show()

fig2 = plot_partition_distribution(category_stats, 'Category-Based Partitioning')
plt.show()

## 3. Helper Functions

In [None]:
class CategoryTransformDataset:
    """Dataset wrapper that applies a per-category transform to each image.

    Samples of the wrapped dataset are mappings with at least the keys
    ``"image"`` and ``"category"``. On access, the transform registered for
    the sample's category (if any) is applied to its image.

    Args:
        dataset: Indexable dataset yielding sample mappings.
        transforms_dict: Mapping from category to a callable that takes the
            sample's image and returns the transformed image.
    """
    def __init__(self, dataset, transforms_dict):
        self.dataset = dataset
        self.transforms = transforms_dict

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return a copy of sample ``idx`` with its image transformed.

        Samples whose category has no registered transform are returned with
        the image untouched.
        """
        # Work on a shallow copy: writing the transformed image back into the
        # wrapped dataset's own dict would mutate its sample and cause the
        # transform to be applied again on the next access if the dataset
        # hands out stored (rather than freshly built) dicts.
        item = dict(self.dataset[idx])
        category = item["category"]
        if category in self.transforms:
            item["image"] = self.transforms[category](item["image"])
        return item


class TransformedSubset:
    """View of ``dataset`` restricted to the positions listed in ``indices``.

    Position ``i`` of the subset maps to ``dataset[indices[i]]``; the subset's
    length is the number of indices.
    """

    def __init__(self, dataset, indices):
        self.indices = indices
        self.dataset = dataset

    def __getitem__(self, idx):
        # Translate the subset-local position into the wrapped dataset's index.
        source_idx = self.indices[idx]
        return self.dataset[source_idx]

    def __len__(self):
        return len(self.indices)


def _collate_samples(batch):
    """Merge a list of sample dicts into one batch dict.

    Images are stacked into a single tensor, labels are packed into a tensor,
    and category names are kept as a plain list.
    """
    return {
        "image": torch.stack([item["image"] for item in batch]),
        "label": torch.tensor([item["label"] for item in batch]),
        "category": [item["category"] for item in batch],
    }


def create_dataloaders(transformed_dataset, partition, batch_size=32, num_workers=4):
    """Create one shuffled DataLoader per client.

    Args:
        transformed_dataset: Dataset whose samples are dicts with the keys
            ``"image"``, ``"label"`` and ``"category"``.
        partition: Mapping from client id to the sample indices of that client.
        batch_size: Batch size used by every loader.
        num_workers: Number of worker processes per loader.

    Returns:
        Dict mapping each client id to its DataLoader.

    Note:
        The collate function is a module-level function rather than a
        closure because DataLoader workers started with the ``spawn`` method
        must pickle it, and nested functions cannot be pickled.
        NOTE(review): under ``spawn``, objects defined only inside a notebook
        may still not be importable by workers; use ``num_workers=0`` there.
    """
    # Same value for every client; no need to query CUDA once per loader.
    pin_memory = torch.cuda.is_available()

    return {
        client_id: DataLoader(
            TransformedSubset(transformed_dataset, indices),
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            collate_fn=_collate_samples,
            pin_memory=pin_memory,
        )
        for client_id, indices in partition.items()
    }


# One transform per category, each requested with normalize=True and
# to_tensor=True
transforms_dict = {
    cat: get_transforms(cat, normalize=True, to_tensor=True)
    for cat in CATEGORIES
}

# Wrap the raw dataset so every sample's image goes through its category's
# transform on access
transformed_dataset = CategoryTransformDataset(dataset, transforms_dict)
print("Transforms ready for all categories")

## 4. Experiment 1: IID Partitioning

In [None]:
# Create dataloaders for IID partition
# One loader per client id in iid_partition, all reading from the shared
# transformed dataset.
iid_dataloaders = create_dataloaders(transformed_dataset, iid_partition, batch_size=32)

# Sanity check: per-client sample and batch counts.
print(f"Created {len(iid_dataloaders)} dataloaders for IID experiment")
for client_id, loader in iid_dataloaders.items():
    print(f"  Client {client_id}: {len(loader.dataset)} samples, {len(loader)} batches")

In [None]:
# Initialize IID federated model
# num_clients is derived from the dataloaders that are handed to train()
# instead of repeating the literal 5 from the partitioner cell, so the two
# cannot drift apart when the client count is changed.
iid_federated = FederatedPatchCore(
    num_clients=len(iid_dataloaders),
    backbone_name="wide_resnet50_2",
    layers=["layer2", "layer3"],
    coreset_ratio=0.1,
    global_bank_size=10000,
    neighborhood_size=3,
    aggregation_strategy="federated_coreset",
    weighted_by_samples=True,
    use_faiss=True,
    device="auto",
)

# Store partition info on the model object alongside the training state
iid_federated.partition = iid_partition
iid_federated.partition_stats = iid_stats

print(iid_federated)

In [None]:
# Run IID federated training
# train() receives the per-client dataloaders and the shared SEED; its return
# value is kept as the global memory bank and only its shape is printed.
print("Starting IID federated training...")
iid_global_bank = iid_federated.train(iid_dataloaders, seed=SEED)

print(f"\nIID Global memory bank: {iid_global_bank.shape}")

In [None]:
# Save IID results
# Written under OUTPUT_DIR/iid; save() is given the path as a plain string.
# NOTE(review): whether save() creates missing directories is not visible
# here - confirm before running on a fresh checkout.
iid_output_dir = OUTPUT_DIR / "iid"
iid_federated.save(str(iid_output_dir))
print(f"Saved IID model to {iid_output_dir}")

## 5. Experiment 2: Category-Based Partitioning

In [None]:
# Create dataloaders for category partition
# One loader per client id in category_partition, all reading from the shared
# transformed dataset.
category_dataloaders = create_dataloaders(transformed_dataset, category_partition, batch_size=32)

# Sanity check: per-client sample and batch counts.
print(f"Created {len(category_dataloaders)} dataloaders for Category experiment")
for client_id, loader in category_dataloaders.items():
    print(f"  Client {client_id}: {len(loader.dataset)} samples, {len(loader)} batches")

In [None]:
# Initialize Category-based federated model
category_federated = FederatedPatchCore(
    num_clients=5,
    backbone_name="wide_resnet50_2",
    layers=["layer2", "layer3"],
    coreset_ratio=0.1,
    global_bank_size=10000,
    neighborhood_size=3,
    aggregation_strategy="federated_coreset",
    weighted_by_samples=True,
    use_faiss=True,
    device="auto",
)

# Store partition info
category_federated.partition = category_partition
category_federated.partition_stats = category_stats

print(category_federated)

In [None]:
# Run Category-based federated training
# train() receives the per-client dataloaders and the shared SEED; its return
# value is kept as the global memory bank and only its shape is printed.
print("Starting Category-based federated training...")
category_global_bank = category_federated.train(category_dataloaders, seed=SEED)

print(f"\nCategory Global memory bank: {category_global_bank.shape}")

In [None]:
# Save Category results
# Written under OUTPUT_DIR/category_based; save() is given the path as a
# plain string.
# NOTE(review): whether save() creates missing directories is not visible
# here - confirm before running on a fresh checkout.
category_output_dir = OUTPUT_DIR / "category_based"
category_federated.save(str(category_output_dir))
print(f"Saved Category model to {category_output_dir}")

## 6. Compare Results

In [None]:
# Compare statistics
print("=" * 60)
print("Comparison Summary")
print("=" * 60)

iid_stats_final = iid_federated.get_stats()
category_stats_final = category_federated.get_stats()

print(f"\n{'Metric':<30} {'IID':<15} {'Category':<15}")
print("-" * 60)
print(f"{'Global bank size':<30} {iid_stats_final['actual_global_bank_size']:<15} {category_stats_final['actual_global_bank_size']:<15}")
print(f"{'Feature dimension':<30} {iid_stats_final['feature_dim']:<15} {category_stats_final['feature_dim']:<15}")


def _format_seconds(value):
    """Format ``value`` with two decimals, or return 'N/A' if it is not numeric."""
    try:
        return f"{value:.2f}"
    except (TypeError, ValueError):
        # Raised for the 'N/A' placeholder, None, or any other non-number.
        return "N/A"


# Either run may lack training stats or the elapsed-time entry. Look both up
# defensively and format each value on its own, so a missing one shows as
# 'N/A' instead of raising KeyError or a '.2f'-on-str ValueError.
if 'training' in iid_stats_final or 'training' in category_stats_final:
    iid_time = iid_stats_final.get('training', {}).get('elapsed_time_seconds', 'N/A')
    category_time = category_stats_final.get('training', {}).get('elapsed_time_seconds', 'N/A')
    print(f"{'Training time (s)':<30} {_format_seconds(iid_time):<15} {_format_seconds(category_time):<15}")

In [None]:
# Client contribution analysis
def plot_client_contributions(stats, title):
    """Draw two bar charts of per-client sample counts and coreset sizes.

    Args:
        stats: Model statistics. Per-client entries are read from
            ``stats['training']['client_stats']``; each entry provides
            ``client_id``, ``num_samples`` and ``coreset_size``.
        title: Prefix used in both panel titles.

    Returns:
        The matplotlib figure, or ``None`` (after printing a notice) when the
        training stats or the per-client stats are missing or empty.
    """
    if 'training' not in stats:
        print("No training stats available")
        return

    per_client = stats['training'].get('client_stats', [])
    if not per_client:
        print("No client stats available")
        return

    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    # Collect every series up front, then describe each panel declaratively.
    client_ids = [entry['client_id'] for entry in per_client]
    num_samples = [entry['num_samples'] for entry in per_client]
    coreset_sizes = [entry['coreset_size'] for entry in per_client]

    panels = (
        (axes[0], num_samples, 'steelblue', 'Number of Samples', f'{title} - Samples per Client'),
        (axes[1], coreset_sizes, 'coral', 'Coreset Size', f'{title} - Coreset per Client'),
    )
    for ax, heights, color, y_label, panel_title in panels:
        ax.bar(client_ids, heights, color=color)
        ax.set_xlabel('Client ID')
        ax.set_ylabel(y_label)
        ax.set_title(panel_title)

    plt.tight_layout()
    return fig

# One two-panel figure per experiment, built from the stats gathered in the
# comparison cell; each is shown before the next one is created.
plot_client_contributions(iid_stats_final, 'IID')
plt.show()

plot_client_contributions(category_stats_final, 'Category-Based')
plt.show()

## 7. Quick Inference Test

In [None]:
# Load test dataset
test_dataset = AutoVIDataset(
    root_dir=DATA_ROOT,
    categories=CATEGORIES,
    split="test",
    transform=None,
)

print(f"Test dataset: {len(test_dataset)} samples")
test_stats = test_dataset.get_statistics()
print(f"Good: {test_stats['by_label'][0]}, Defective: {test_stats['by_label'][1]}")

In [None]:
# Test inference on a few samples
def test_inference(model, dataset, num_samples=5):
    """Score the first ``num_samples`` items of ``dataset`` with ``model``.

    Args:
        model: Object exposing ``predict_single(image_tensor)``, which returns
            an ``(anomaly_map, image_score)`` pair.
        dataset: Indexable dataset whose samples are dicts with the keys
            ``'image'``, ``'category'`` and ``'label'``.
        num_samples: Upper bound on the number of samples to score; fewer are
            scored when the dataset is smaller.

    Returns:
        A list with one dict per scored sample, holding its ``'index'``,
        ``'category'``, ``'label'`` (``'defective'`` when the sample's label
        equals 1, otherwise ``'good'``) and ``'score'`` (the image score
        returned by the model).
    """
    results = []
    # Build each category's transform once instead of once per sample, as the
    # training path does with its per-category transform dict.
    transforms_by_category = {}

    # Inference only: no autograd graph is needed while scoring.
    with torch.no_grad():
        for i in range(min(num_samples, len(dataset))):
            sample = dataset[i]
            category = sample['category']

            if category not in transforms_by_category:
                transforms_by_category[category] = get_transforms(
                    category, normalize=True, to_tensor=True
                )
            # unsqueeze(0) adds a leading dimension before the model call.
            image_tensor = transforms_by_category[category](sample['image']).unsqueeze(0)

            # Only the image-level score is reported; the anomaly map is dropped.
            _, image_score = model.predict_single(image_tensor)

            results.append({
                'index': i,
                'category': category,
                'label': 'defective' if sample['label'] == 1 else 'good',
                'score': image_score,
            })

    return results

# Score the same first 10 test samples with each trained model and print one
# line per sample.
def _run_and_report(heading, model):
    """Print ``heading``, score 10 test samples with ``model``, print and return them."""
    print(heading)
    results = test_inference(model, test_dataset, num_samples=10)
    for r in results:
        print(f"  Sample {r['index']}: {r['category']}, {r['label']}, score={r['score']:.4f}")
    return results


iid_results = _run_and_report("IID Model Predictions:", iid_federated)
category_results = _run_and_report("\nCategory Model Predictions:", category_federated)

## 8. Summary

This notebook demonstrated:

1. **Data Partitioning**: Comparing IID (uniform) vs Category-based (non-IID) distributions
2. **Federated Training**: Running PatchCore in a federated setting with 5 clients
3. **Memory Bank Aggregation**: Using the federated coreset strategy
4. **Results Comparison**: Analyzing client contributions and model statistics

### Key Observations:
- IID partitioning results in uniform data distribution across clients
- Category-based partitioning simulates realistic factory scenarios
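- Both approaches produce comparable global memory banks in terms of the bank size and feature dimension reported in Section 6; detection quality is not compared in this notebook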
- Full evaluation requires running the evaluation pipeline on test data