In [None]:
# TEST FILE HERE

In [None]:
"""
MNIST Cluster Size Sweep Experiment (v2)
=========================================
Paste this into Google Colab with A100 runtime.

Improvements:
- Uses all 60k MNIST examples
- Clusters are strictly disjoint (no overlapping examples)
- Batched training of many models for small cluster sizes (within each epoch the models are updated one after another, not concurrently)
"""

#%% Cell 1: Imports and Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import torchvision
import torchvision.transforms as transforms
import numpy as np
from tqdm.auto import tqdm
import time
import matplotlib.pyplot as plt

# Prefer the first visible CUDA GPU; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"Memory: {total_gb:.1f} GB")

#%% Cell 2: Configuration
class Config:
    """Experiment hyper-parameters, read as plain class attributes."""

    # --- Sweep ---------------------------------------------------------
    # Examples per cluster; one full train/evaluate round per entry.
    cluster_sizes = [1, 2, 3, 5, 7, 10, 15, 20, 30, 50, 70, 100]

    # --- Optimization --------------------------------------------------
    batch_size = 128
    epochs = 20
    lr = 1e-3
    weight_decay = 1e-4

    # Upper bound on the number of clusters (one model each) trained per size.
    max_clusters_to_train = 200

    # Chunk size (in models) for the small-cluster training path; lower it
    # if the GPU runs out of memory.
    parallel_models = 50

    # --- Reproducibility -----------------------------------------------
    seed = 123


config = Config()
# Seed both RNGs used by the experiment (torch: init/dropout/shuffle; numpy: cluster draws).
torch.manual_seed(config.seed)
np.random.seed(config.seed)

#%% Cell 3: Model Definition
class SmallCNN(nn.Module):
    """Small 3-block CNN for 1x28x28 MNIST images (~391k trainable parameters).

    Each block is conv3x3 -> BatchNorm -> ReLU -> 2x2 max-pool, shrinking the
    spatial size 28 -> 14 -> 7 -> 3 (floor). A 1152 -> 256 -> 10 MLP head with
    dropout follows. ``forward`` returns unnormalized logits of shape (N, 10).
    """
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool = nn.MaxPool2d(2, 2)
        # 128 channels x 3x3 spatial remain after three poolings of a 28x28 input
        self.fc1 = nn.Linear(128 * 3 * 3, 256)
        self.fc2 = nn.Linear(256, 10)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        # Flatten per example. Unlike view(-1, 1152), this never folds extra
        # spatial positions into the batch dimension: a wrongly sized input
        # now fails loudly at fc1 instead of producing extra "samples".
        x = torch.flatten(x, 1)
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.fc2(x)
        return x

def create_model():
    """Return a freshly initialized SmallCNN placed on the global ``device``."""
    model = SmallCNN()
    return model.to(device)


num_params = sum(param.numel() for param in create_model().parameters())
print(f"Model parameters: {num_params:,}")

#%% Cell 4: Load Data
print("Loading MNIST...")
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

train_dataset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)

# The transform above only runs through the dataset's __getitem__. Here the raw
# uint8 tensor is read directly and normalized by hand with the same constants,
# so the whole training set sits on the device as one (N, 1, 28, 28) tensor.
all_images = ((train_dataset.data.float().unsqueeze(1) / 255.0 - 0.1307) / 0.3081).to(device)
all_labels = train_dataset.targets.to(device)

num_total = len(all_labels)
print(f"Total examples: {num_total}")

#%% Cell 5: Create Disjoint Clusters for All Sizes
print("\nCreating disjoint clusters for all cluster sizes...")

# One shared random permutation from which every cluster size is carved
master_permutation = np.random.permutation(num_total)

def create_disjoint_clusters(cluster_size, permutation):
    """Split ``permutation`` into consecutive, non-overlapping chunks.

    Args:
        cluster_size: number of elements per cluster; must be positive.
        permutation: sliceable sequence of example indices (e.g. a numpy array).

    Returns:
        A list of ``len(permutation) // cluster_size`` slices of ``permutation``,
        each of length ``cluster_size``. Trailing elements that do not fill a
        whole cluster are dropped.

    Raises:
        ValueError: if ``cluster_size`` is not positive. Zero would otherwise
            divide by zero, and negative sizes would silently yield no clusters.
    """
    if cluster_size <= 0:
        raise ValueError(f"cluster_size must be positive, got {cluster_size}")
    num_clusters = len(permutation) // cluster_size
    return [
        permutation[i * cluster_size:(i + 1) * cluster_size]
        for i in range(num_clusters)
    ]

# Pre-create clusters for all sizes
# all_clusters maps cluster size -> list of disjoint index arrays
all_clusters = {}
for size in config.cluster_sizes:
    all_clusters[size] = clusters = create_disjoint_clusters(size, master_permutation)
    print(f"  Size {size:>3}: {len(clusters)} disjoint clusters")

#%% Cell 6: Sequential Training (for larger cluster sizes)
def train_single_model(cluster_indices, epochs=config.epochs):
    """Fit one freshly initialized model on a single cluster.

    Args:
        cluster_indices: indices into the module-level ``all_images`` / ``all_labels``.
        epochs: number of passes over the cluster; the cosine LR schedule
            anneals over exactly this many steps.

    Returns:
        The trained model (still in train mode).
    """
    model = create_model()
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

    # Mini-batches never exceed the cluster itself
    loader = DataLoader(
        TensorDataset(all_images[cluster_indices], all_labels[cluster_indices]),
        batch_size=min(config.batch_size, len(cluster_indices)),
        shuffle=True,
        drop_last=False,
    )

    model.train()
    for _ in range(epochs):
        for inputs, targets in loader:
            optimizer.zero_grad()
            loss = F.cross_entropy(model(inputs), targets)
            loss.backward()
            optimizer.step()
        scheduler.step()  # one LR step per epoch

    return model

#%% Cell 7: Parallel Training (for small cluster sizes)
def train_models_parallel(list_of_cluster_indices, epochs=config.epochs):
    """
    Train one fresh model per cluster, interleaving their updates epoch by epoch.

    Each epoch, every model takes exactly one full-batch step on its whole
    cluster. The models are stepped one after another in list order, not
    concurrently. This path targets tiny clusters (a handful of examples),
    where a single batch already covers the cluster.

    Args:
        list_of_cluster_indices: one index array (into ``all_images`` /
            ``all_labels``) per model to train.
        epochs: number of update rounds; each model's cosine LR schedule
            anneals over this many steps.

    Returns:
        List of trained models, in the same order as ``list_of_cluster_indices``.
    """
    models = [create_model() for _ in list_of_cluster_indices]
    optimizers = [
        torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
        for model in models
    ]
    schedulers = [torch.optim.lr_scheduler.CosineAnnealingLR(opt, epochs) for opt in optimizers]

    # Gather each cluster's full batch once up front
    batches = [(all_images[idx], all_labels[idx]) for idx in list_of_cluster_indices]

    for _ in range(epochs):
        for model in models:
            model.train()

        for model, optimizer, (inputs, targets) in zip(models, optimizers, batches):
            optimizer.zero_grad()
            F.cross_entropy(model(inputs), targets).backward()
            optimizer.step()

        for scheduler in schedulers:
            scheduler.step()

    return models

#%% Cell 8: Evaluation Functions
@torch.no_grad()
def evaluate_model(model, indices, batch_size=2048):
    """Return the accuracy of ``model`` on the examples at ``indices``.

    Args:
        model: classifier returning per-class scores; it is switched to eval
            mode and left there.
        indices: indices into the module-level ``all_images`` / ``all_labels``.
        batch_size: maximum number of examples per forward pass.

    Returns:
        Fraction of correct argmax predictions as a Python float, or NaN when
        ``indices`` selects no examples (instead of dividing by zero).
    """
    model.eval()
    X = all_images[indices]
    y = all_labels[indices]

    # Count the selected examples, not len(indices); the two differ for boolean masks
    total = len(y)
    if total == 0:
        return float('nan')

    correct = 0
    for start in range(0, total, batch_size):
        stop = start + batch_size
        preds = model(X[start:stop]).argmax(dim=1)
        correct += (preds == y[start:stop]).sum().item()

    return correct / total

@torch.no_grad()
def evaluate_models_on_clusters(models, clusters, train_cluster_indices, eval_batch_size=4096):
    """
    Cross-evaluate models on the clusters they were trained on (not all clusters).

    Instead of one tiny forward pass per (model, cluster) pair (M^2 passes),
    all trained clusters are concatenated once. Each model then scores that
    concatenation in chunks of ``eval_batch_size``, and per-example hits are
    summed back into per-cluster accuracies. In eval mode every prediction
    depends only on its own example, so the result matches pairwise
    evaluation up to floating-point differences between kernels.

    Args:
        models: list of M models; each is switched to eval mode.
        clusters: list of all clusters (index arrays into all_images/all_labels)
        train_cluster_indices: list of M cluster indices; entry i is the
            cluster models[i] was trained on
        eval_batch_size: maximum examples per forward pass (bounds activation memory)

    Returns:
        accuracy_matrix: (M, M) float array where entry [i, j] is the accuracy of
            model i on cluster train_cluster_indices[j]'s data. The diagonal
            is NaN because self-evaluation is excluded.

    Raises:
        ValueError: if ``models`` and ``train_cluster_indices`` differ in length.
    """
    num_models = len(models)
    if len(train_cluster_indices) != num_models:
        raise ValueError(
            f"expected one training-cluster index per model, got "
            f"{len(train_cluster_indices)} indices for {num_models} models"
        )
    accuracy_matrix = np.full((num_models, num_models), np.nan)
    if num_models == 0:
        return accuracy_matrix

    # Set all models to eval mode
    for m in models:
        m.eval()

    # Concatenate the data of every trained cluster once (column order = j)
    member_indices = [np.asarray(clusters[c_idx]) for c_idx in train_cluster_indices]
    cluster_lengths = np.array([len(idx) for idx in member_indices], dtype=np.int64)
    flat_indices = np.concatenate(member_indices)
    X_all = all_images[flat_indices]
    y_all = all_labels[flat_indices]

    # owner[k] is the column j whose cluster example k of X_all came from
    owner = torch.repeat_interleave(
        torch.arange(num_models, device=y_all.device),
        torch.as_tensor(cluster_lengths, device=y_all.device),
    )

    for i, model in enumerate(models):
        hits_per_cluster = torch.zeros(num_models, device=y_all.device)
        for start in range(0, len(y_all), eval_batch_size):
            stop = start + eval_batch_size
            preds = model(X_all[start:stop]).argmax(dim=1)
            hits = (preds == y_all[start:stop]).float()
            hits_per_cluster.index_add_(0, owner[start:stop], hits)
        accuracy_matrix[i] = hits_per_cluster.cpu().numpy() / cluster_lengths
        accuracy_matrix[i, i] = np.nan  # Skip self-evaluation

    return accuracy_matrix

#%% Cell 9: Main Sweep Loop
print("\n" + "="*60)
print("RUNNING CLUSTER SIZE SWEEP")
print("="*60)

# cluster_size -> summary dict; read by the plotting, table and save cells below.
sweep_results = {}

for cluster_size in config.cluster_sizes:
    clusters = all_clusters[cluster_size]
    num_clusters = len(clusters)

    print(f"\n{'='*50}")
    print(f"Cluster size: {cluster_size}, Total clusters: {num_clusters}")
    print(f"{'='*50}")

    # Decide how many clusters to train
    num_to_train = min(num_clusters, config.max_clusters_to_train)

    # Randomly select which clusters to train on
    # NOTE(review): only numpy's global RNG is reseeded here. torch's RNG (model
    # init, dropout, DataLoader shuffling) is seeded once in Cell 2 and carries
    # over, so a size's trained models depend on which sizes ran before it.
    np.random.seed(config.seed + cluster_size)  # Reproducible selection per size
    train_cluster_indices = np.random.choice(num_clusters, size=num_to_train, replace=False)

    print(f"Training {num_to_train} models...")
    # NOTE(review): wall-clock timing without torch.cuda.synchronize(); GPU work
    # still queued when training ends is billed to the evaluation timer.
    start_time = time.time()

    # Storage
    # Invariant: all_models[k] was trained on clusters[all_train_cluster_idx[k]].
    all_models = []
    all_train_cluster_idx = []

    # Decide whether to use parallel or sequential training
    # Clusters of <= 10 examples go through train_models_parallel in chunks of
    # config.parallel_models models; larger clusters are trained one model at a time.
    if cluster_size <= 10:
        # Use parallel training in batches
        # (batch_size here counts models per chunk, not training examples)
        batch_size = config.parallel_models
        num_batches = (num_to_train + batch_size - 1) // batch_size  # ceil division

        for batch_idx in tqdm(range(num_batches), desc=f"Size {cluster_size} (parallel)"):
            batch_start = batch_idx * batch_size
            batch_end = min(batch_start + batch_size, num_to_train)
            batch_train_indices = train_cluster_indices[batch_start:batch_end]

            # Get cluster data for this batch
            batch_clusters = [clusters[i] for i in batch_train_indices]

            # Train models in parallel
            batch_models = train_models_parallel(batch_clusters, epochs=config.epochs)

            all_models.extend(batch_models)
            all_train_cluster_idx.extend(batch_train_indices)

            # Clear some memory
            torch.cuda.empty_cache()
    else:
        # Use sequential training for larger clusters
        for i, train_idx in enumerate(tqdm(train_cluster_indices, desc=f"Size {cluster_size} (sequential)")):
            model = train_single_model(clusters[train_idx], epochs=config.epochs)
            all_models.append(model)
            all_train_cluster_idx.append(train_idx)

            # Periodic memory cleanup
            if (i + 1) % 50 == 0:
                torch.cuda.empty_cache()

    train_time = time.time() - start_time
    print(f"Training completed in {train_time:.1f}s")

    # Evaluate all models on all trained clusters
    # (model i vs. cluster j for i != j; the diagonal comes back NaN)
    print("Evaluating models on trained clusters...")
    eval_start = time.time()

    accuracy_matrix = evaluate_models_on_clusters(all_models, clusters, all_train_cluster_idx)

    # eval_time covers only the off-diagonal evaluation above, not the diagonal pass below
    eval_time = time.time() - eval_start
    print(f"Evaluation completed in {eval_time:.1f}s")

    # Compute diagonal (training set) accuracies - batch this too
    # i.e. each model's accuracy on the cluster it was trained on.
    diag_accs = []
    for model in all_models:
        model.eval()

    for m_idx, train_idx in enumerate(all_train_cluster_idx):
        X = all_images[clusters[train_idx]]
        y = all_labels[clusters[train_idx]]
        with torch.no_grad():
            output = all_models[m_idx](X)
            preds = output.argmax(dim=1)
            acc = (preds == y).float().mean().item()
        diag_accs.append(acc)
    diag_accs = np.array(diag_accs)

    # Compute off-diagonal statistics
    # Flattens every non-NaN entry, i.e. all (model i, cluster j != i) pairs.
    # NOTE(review): with a single trained model this array is empty and the
    # .min()/.max() calls below raise ValueError.
    off_diag_accs = accuracy_matrix[~np.isnan(accuracy_matrix)]

    results = {
        'cluster_size': cluster_size,
        'num_clusters': num_clusters,
        'num_trained': num_to_train,
        'off_diag_mean': off_diag_accs.mean(),
        'off_diag_std': off_diag_accs.std(),
        'off_diag_median': np.median(off_diag_accs),
        'off_diag_min': off_diag_accs.min(),
        'off_diag_max': off_diag_accs.max(),
        'off_diag_25': np.percentile(off_diag_accs, 25),
        'off_diag_75': np.percentile(off_diag_accs, 75),
        'diag_mean': diag_accs.mean(),
        'diag_std': diag_accs.std(),
        'all_off_diag': off_diag_accs,
        'all_diag': diag_accs,
        'matrix': accuracy_matrix,
        'train_cluster_indices': np.array(all_train_cluster_idx),
        'train_time': train_time,
        'eval_time': eval_time
    }

    sweep_results[cluster_size] = results

    print(f"  Off-diagonal accuracy: {results['off_diag_mean']:.4f} ± {results['off_diag_std']:.4f}")
    print(f"  Diagonal (train) accuracy: {results['diag_mean']:.4f} ± {results['diag_std']:.4f}")

    # Cleanup
    # (the loop variable `model` still references the last trained model until
    # it is rebound in the next iteration)
    del all_models
    torch.cuda.empty_cache()

#%% Cell 10: Summary Plot - Accuracy vs Cluster Size
print("\n" + "="*60)
print("VISUALIZATIONS")
print("="*60)

fig, axes = plt.subplots(2, 2, figsize=(14, 12))

sizes = config.cluster_sizes
means = [sweep_results[s]['off_diag_mean'] for s in sizes]
stds = [sweep_results[s]['off_diag_std'] for s in sizes]
q25 = [sweep_results[s]['off_diag_25'] for s in sizes]
q75 = [sweep_results[s]['off_diag_75'] for s in sizes]

# Plot 1: Mean accuracy vs cluster size
ax = axes[0, 0]
ax.errorbar(sizes, means, yerr=stds, fmt='o-', capsize=5, capthick=2, linewidth=2, markersize=8, label='Mean ± Std')
ax.fill_between(sizes, q25, q75, alpha=0.3, label='25th-75th percentile')
# The chance line must be added before ax.legend(); the legend only collects
# artists that already exist, so a later axhline's label would be dropped.
ax.axhline(y=0.1, color='red', linestyle='--', alpha=0.5, label='Random (10%)')
ax.set_xlabel('Cluster Size (training examples)', fontsize=12)
ax.set_ylabel('Off-Diagonal Accuracy', fontsize=12)
ax.set_title('Generalization Accuracy vs Training Cluster Size', fontsize=14)
ax.set_xscale('log')
ax.set_xticks(sizes)
ax.set_xticklabels(sizes)
ax.legend()
ax.grid(True, alpha=0.3)

# Plot 2: Box plots
ax = axes[0, 1]
box_data = [sweep_results[s]['all_off_diag'] for s in sizes]
# Boxes sit at x = 1..N. Their tick labels are set explicitly instead of via
# boxplot(labels=...), which Matplotlib 3.9 renamed to tick_labels.
bp = ax.boxplot(box_data, patch_artist=True)
ax.set_xticks(range(1, len(sizes) + 1))
ax.set_xticklabels(sizes)
for patch in bp['boxes']:
    patch.set_facecolor('lightblue')
ax.set_xlabel('Cluster Size', fontsize=12)
ax.set_ylabel('Off-Diagonal Accuracy', fontsize=12)
ax.set_title('Distribution of Generalization Accuracies', fontsize=14)
ax.axhline(y=0.1, color='red', linestyle='--', alpha=0.5)
ax.grid(True, alpha=0.3, axis='y')

# Plot 3: Train vs Test accuracy
ax = axes[1, 0]
train_means = [sweep_results[s]['diag_mean'] for s in sizes]
test_means = [sweep_results[s]['off_diag_mean'] for s in sizes]
ax.plot(sizes, train_means, 'o-', linewidth=2, markersize=8, label='Training (diagonal)')
ax.plot(sizes, test_means, 's-', linewidth=2, markersize=8, label='Test (off-diagonal)')
ax.set_xlabel('Cluster Size', fontsize=12)
ax.set_ylabel('Accuracy', fontsize=12)
ax.set_title('Train vs Test Accuracy', fontsize=14)
ax.set_xscale('log')
ax.set_xticks(sizes)
ax.set_xticklabels(sizes)
ax.legend()
ax.grid(True, alpha=0.3)

# Plot 4: Generalization gap (train accuracy minus off-diagonal accuracy)
ax = axes[1, 1]
gaps = [t - e for t, e in zip(train_means, test_means)]
ax.bar(range(len(sizes)), gaps, tick_label=sizes, color='coral', edgecolor='black')
ax.set_xlabel('Cluster Size', fontsize=12)
ax.set_ylabel('Train - Test Accuracy Gap', fontsize=12)
ax.set_title('Generalization Gap', fontsize=14)
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig('cluster_size_sweep_summary.png', dpi=150)
plt.show()

#%% Cell 11: Violin Plots
fig, ax = plt.subplots(figsize=(14, 6))

# One categorical slot per cluster size
positions = range(len(sizes))
violin_data = [sweep_results[s]['all_off_diag'] for s in sizes]
parts = ax.violinplot(violin_data, positions=positions, showmeans=True, showmedians=True)

# Uniform translucent fill for every violin body
for body in parts['bodies']:
    body.set_facecolor('steelblue')
    body.set_alpha(0.7)

ax.set_xticks(positions)
ax.set_xticklabels(sizes)
ax.set_xlabel('Cluster Size (training examples)', fontsize=12)
ax.set_ylabel('Off-Diagonal Accuracy', fontsize=12)
ax.set_title('Distribution of Generalization Accuracies by Cluster Size', fontsize=14)

# Chance level for 10-class MNIST, added before the legend so it is listed
ax.axhline(y=0.1, color='red', linestyle='--', alpha=0.5, label='Random baseline')
ax.legend()
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig('cluster_size_violins.png', dpi=150)
plt.show()

#%% Cell 12: Heatmaps for Selected Sizes
key_sizes = [s for s in (5, 20, 50, 100) if s in sweep_results]

if key_sizes:
    n_panels = len(key_sizes)
    # squeeze=False always returns a 2-D grid of Axes, so one panel needs no special case
    fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5), squeeze=False)

    for ax, size in zip(axes[0], key_sizes):
        matrix = sweep_results[size]['matrix']
        # Show at most the top-left 50x50 corner (rows: models, cols: evaluated clusters)
        show_size = min(50, *matrix.shape)
        matrix_show = matrix[:show_size, :show_size]

        # Diagonal entries are NaN (self-evaluation excluded) and render in the
        # colormap's "bad" color
        im = ax.imshow(matrix_show, cmap='viridis', aspect='auto')
        ax.set_title(f'Cluster Size = {size}\n({show_size}x{show_size}, diagonal excluded)')
        ax.set_xlabel('Eval Cluster (index into trained)')
        ax.set_ylabel('Train Model')
        plt.colorbar(im, ax=ax, label='Accuracy')

    plt.tight_layout()
    plt.savefig('cluster_size_heatmaps.png', dpi=150)
    plt.show()

#%% Cell 13: Learning Curve Fit
fig, ax = plt.subplots(figsize=(10, 6))

sizes_arr = np.array(sizes)
means_arr = np.array(means)

from scipy.optimize import curve_fit

def learning_curve(n, a, b, c):
    """Saturating exponential: accuracy(n) = a - b * exp(-c * n).

    a is the asymptotic accuracy, b the distance below it at n = 0, c the rate.
    """
    return a - b * np.exp(-c * n)

# Fit, annotate and frame the plot over the cluster sizes actually swept
n_min, n_max = float(sizes_arr.min()), float(sizes_arr.max())

# Plot the observations exactly once, whether or not the fit succeeds
ax.scatter(sizes, means, s=100, zorder=5, label='Observed')

try:
    popt, pcov = curve_fit(learning_curve, sizes_arr, means_arr, p0=[0.9, 0.8, 0.1], maxfev=10000)
except (RuntimeError, ValueError, TypeError) as e:
    # RuntimeError: no convergence within maxfev; ValueError: NaN/inf in the data;
    # TypeError: fewer data points than the 3 fitted parameters.
    print(f"Curve fitting failed: {e}")
else:
    a_fit, b_fit, c_fit = popt
    fitted_sizes = np.linspace(n_min, n_max, 200)
    fitted_acc = learning_curve(fitted_sizes, *popt)
    ax.plot(fitted_sizes, fitted_acc, 'r-', linewidth=2,
            label=f'Fit: {a_fit:.3f} - {b_fit:.3f}·exp(-{c_fit:.3f}·n)')

    # Estimate the cluster size needed for each accuracy threshold by inverting
    # the fit: a - b*exp(-c*n) = t  =>  n = -ln((a - t) / b) / c
    thresholds = [0.5, 0.6, 0.7, 0.8]
    for thresh in thresholds:
        if a_fit <= thresh or b_fit <= 0:
            continue  # only invert when the asymptote exceeds thresh and b > 0
        with np.errstate(divide='ignore', invalid='ignore'):
            # c_fit == 0 gives inf/nan, which the range check below rejects
            n_thresh = -np.log((a_fit - thresh) / b_fit) / c_fit
        if n_min <= n_thresh <= n_max:
            ax.axhline(y=thresh, color='gray', linestyle=':', alpha=0.5)
            ax.axvline(x=n_thresh, color='gray', linestyle=':', alpha=0.5)
            ax.annotate(f'{thresh*100:.0f}% @ n≈{n_thresh:.0f}',
                        xy=(n_thresh, thresh), xytext=(n_thresh + 0.05 * n_max, thresh - 0.03),
                        fontsize=10)

ax.set_xlabel('Cluster Size (training examples)', fontsize=12)
ax.set_ylabel('Mean Off-Diagonal Accuracy', fontsize=12)
ax.set_title('Learning Curve: Accuracy vs Training Set Size', fontsize=14)
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_xlim(0, n_max * 1.05)

plt.tight_layout()
plt.savefig('learning_curve_fit.png', dpi=150)
plt.show()

#%% Cell 14: Timing Analysis
fig, ax = plt.subplots(figsize=(10, 5))

train_times, eval_times = (
    [sweep_results[s][key] for s in sizes] for key in ('train_time', 'eval_time')
)

x = np.arange(len(sizes))
width = 0.35
offset = width / 2

# Paired bars per cluster size: training left of the tick, evaluation right of it
bars1 = ax.bar(x - offset, train_times, width, label='Training', color='steelblue')
bars2 = ax.bar(x + offset, eval_times, width, label='Evaluation', color='coral')

ax.set_title('Compute Time by Cluster Size', fontsize=14)
ax.set_xlabel('Cluster Size', fontsize=12)
ax.set_ylabel('Time (seconds)', fontsize=12)
ax.set_xticks(x)
ax.set_xticklabels(sizes)
ax.legend()
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig('timing_analysis.png', dpi=150)
plt.show()

#%% Cell 15: Summary Table
print("\n" + "="*60)
print("SUMMARY TABLE")
print("="*60)

header = (f"{'Size':>6} | {'Clusters':>8} | {'Trained':>7} | {'Test Acc':>14} | "
          f"{'Train Acc':>14} | {'Gap':>8} | {'Time':>10}")
print("\n" + header)
# Separator sized to the header so the rule matches the table width
print("-" * len(header))
for s in sizes:
    r = sweep_results[s]
    gap = r['diag_mean'] - r['off_diag_mean']
    total_time = r['train_time'] + r['eval_time']
    # Build each "mean±std" cell first, then right-align it to the header width.
    # A bare 4-decimal pair is only 13 chars and would drift left of its column.
    test_acc = f"{r['off_diag_mean']:.4f}±{r['off_diag_std']:.4f}"
    train_acc = f"{r['diag_mean']:.4f}±{r['diag_std']:.4f}"
    time_cell = f"{total_time:.1f}s"
    print(f"{s:>6} | {r['num_clusters']:>8} | {r['num_trained']:>7} | "
          f"{test_acc:>14} | {train_acc:>14} | {gap:>8.4f} | {time_cell:>10}")

#%% Cell 16: Save Results
import json

def _scalar_summary(r):
    """Return the JSON-serializable scalar fields of one per-size sweep result."""
    summary = {key: r[key] for key in ('cluster_size', 'num_clusters', 'num_trained')}
    for key in ('off_diag_mean', 'off_diag_std', 'off_diag_median', 'diag_mean', 'diag_std'):
        summary[key] = float(r[key])  # numpy scalars -> plain floats for json
    summary['train_time'] = r['train_time']
    summary['eval_time'] = r['eval_time']
    return summary


results_to_save = {
    'config': {
        'cluster_sizes': config.cluster_sizes,
        'epochs': config.epochs,
        'max_clusters_to_train': config.max_clusters_to_train,
        'parallel_models': config.parallel_models,
        'seed': config.seed,
    },
    'sweep_results': {str(size): _scalar_summary(r) for size, r in sweep_results.items()},
}

with open('cluster_size_sweep_results.json', 'w') as f:
    json.dump(results_to_save, f, indent=2)

# Full per-model arrays, grouped by field first and then by cluster size
npz_fields = (
    ('off_diag', 'all_off_diag'),
    ('diag', 'all_diag'),
    ('matrix', 'matrix'),
    ('train_idx', 'train_cluster_indices'),
)
arrays_to_save = {
    f'size_{s}_{suffix}': sweep_results[s][field]
    for suffix, field in npz_fields
    for s in sizes
}
np.savez_compressed('cluster_size_sweep_full.npz', **arrays_to_save)

print("\nResults saved!")

print("\n" + "="*60)
print("EXPERIMENT COMPLETE")
print("="*60)

: 