# Digit Correlation Experiment

Train 50 models, each on its own balanced sample of 20 training images (2 per digit), then analyze how per-digit test accuracies correlate across models.

Exports data for an interactive D3.js dashboard.

In [None]:
# Clone the project repository and make it the working directory.
# Later cells import from `utils`, which is presumably provided by this repo.
!git clone https://github.com/Caleb-Briggs/MNIST_AI.git
%cd MNIST_AI

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import json
import base64
from io import BytesIO
from PIL import Image

from utils.data import load_mnist, get_device
from utils.models import SmallCNN
from utils.evaluation import get_predictions

# Pick the compute device through the project helper and report it.
device = get_device()
print(f"Device: {device}")

# Only query the CUDA device name when CUDA is actually present:
# torch.cuda.get_device_name raises on runtimes without a CUDA device,
# which would abort the notebook before any work is done.
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
else:
    print("GPU: none (CUDA not available)")

In [None]:
# --- Experiment configuration -------------------------------------------
NUM_MODELS = 50           # number of independently trained models
SAMPLES_PER_DIGIT = 2     # training images drawn per digit for each model
TARGET_TRAIN_ACC = 0.99   # training stops once this training accuracy is reached
MAX_EPOCHS = 200          # upper bound on training epochs per model
LR = 0.001                # learning rate handed to the optimizer
SEED = 42                 # seed shared by every random number generator below

# Seed the global NumPy and PyTorch generators.
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Load both splits; load_mnist is handed the device chosen earlier.
images, labels = load_mnist(device, train=True)
test_images, test_labels = load_mnist(device, train=False)

# For every digit 0-9, the positions of its training examples as a NumPy array.
digit_indices = {}
for d in range(10):
    positions = (labels == d).nonzero(as_tuple=True)[0]
    digit_indices[d] = positions.cpu().numpy()

In [None]:
def create_balanced_sample(digit_indices, samples_per_digit, rng):
    """Draw the same number of example indices for each digit 0-9.

    Args:
        digit_indices: mapping from digit (0-9) to an array of candidate indices.
        samples_per_digit: how many indices to draw, without replacement, per digit.
        rng: object with a NumPy-Generator-style ``choice`` method.

    Returns:
        A NumPy array holding the drawn indices, grouped by digit in
        ascending digit order.
    """
    picked = [
        idx
        for digit in range(10)
        for idx in rng.choice(digit_indices[digit], size=samples_per_digit, replace=False)
    ]
    return np.array(picked)

def train_until_accuracy(model, images, labels, indices, target_acc, max_epochs, lr):
    """Train ``model`` full-batch on a subset until it hits a target training accuracy.

    Each epoch is a single AdamW step (weight decay 1e-4) on the whole subset
    ``images[indices]`` / ``labels[indices]`` using cross-entropy loss.

    Args:
        model: module to train in place; it is left in training mode.
        images: tensor of inputs, indexable by ``indices``.
        labels: tensor of integer class labels aligned with ``images``.
        indices: positions of the training subset.
        target_acc: training accuracy at which to stop early.
        max_epochs: maximum number of optimisation steps.
        lr: learning rate for AdamW.

    Returns:
        ``(epochs_run, acc)``, where ``acc`` is the subset accuracy measured on
        the most recent forward pass. If ``max_epochs <= 0`` no step is taken
        and ``acc`` is reported as ``0.0``.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    X = images[indices]
    y = labels[indices]

    # Bind acc before the loop so the final return stays valid when the loop
    # body never executes (max_epochs <= 0); previously that path raised
    # UnboundLocalError.
    acc = 0.0

    model.train()
    for epoch in range(max_epochs):
        optimizer.zero_grad()
        output = model(X)
        loss = F.cross_entropy(output, y)
        loss.backward()
        optimizer.step()

        # Accuracy comes from the forward pass made *before* this step's
        # weight update, so no second forward pass is needed.
        with torch.no_grad():
            preds = output.argmax(dim=1)
            acc = (preds == y).float().mean().item()

        if acc >= target_acc:
            return epoch + 1, acc

    return max_epochs, acc

In [None]:
print(f"Training {NUM_MODELS} models on {SAMPLES_PER_DIGIT * 10} samples each...")

# A single generator drives every draw, so the sequence of training subsets
# is determined by SEED.
rng = np.random.default_rng(SEED)
models, training_indices, epochs_to_converge = [], [], []

for _ in tqdm(range(NUM_MODELS)):
    # Fresh balanced subset for this model; remembered for the dashboard export.
    sample_idx = create_balanced_sample(digit_indices, SAMPLES_PER_DIGIT, rng)
    training_indices.append(sample_idx)

    # New network per subset; only the epoch count is kept from training.
    net = SmallCNN().to(device)
    n_epochs, _final_acc = train_until_accuracy(
        net, images, labels, sample_idx, TARGET_TRAIN_ACC, MAX_EPOCHS, LR
    )

    models.append(net)
    epochs_to_converge.append(n_epochs)

print(f"Epochs to converge: {np.mean(epochs_to_converge):.1f} ± {np.std(epochs_to_converge):.1f}")

In [None]:
print("Evaluating on test set...")

# One row per model, one column per test sample. Each row is filled from the
# 'correct' entry returned by get_predictions (presumably a per-sample
# correctness mask - confirm against utils.evaluation).
n_test = len(test_labels)
all_correct = np.zeros((NUM_MODELS, n_test), dtype=bool)

for row, net in enumerate(tqdm(models)):
    all_correct[row] = get_predictions(net, test_images, test_labels)['correct']

# Mean over the sample axis gives each model's overall test accuracy.
model_accuracies = all_correct.mean(axis=1)
print(f"Test accuracy: {model_accuracies.mean():.4f} ± {model_accuracies.std():.4f}")

In [None]:
test_labels_np = test_labels.cpu().numpy()

# Fraction of models that get each test sample right (mean over the model axis).
sample_difficulty = all_correct.mean(axis=0)

# digit_accuracies[m, d]: accuracy of model m on the test samples labelled d.
per_digit_rows = [
    all_correct[:, test_labels_np == d].mean(axis=1)
    for d in range(10)
]
digit_accuracies = np.array(per_digit_rows).T

# Digits are correlated across models; models are correlated across test samples.
model_corr = np.corrcoef(all_correct)
digit_corr = np.corrcoef(digit_accuracies.T)

In [None]:
# Visualizations: 2x2 overview (two correlation heatmaps on top, two histograms below)
fig, axes = plt.subplots(2, 2, figsize=(14, 12))
(corr_ax, model_ax), (difficulty_ax, accuracy_ax) = axes

# Digit correlation, on a fixed symmetric colour scale centred at zero
digit_ticks = range(10)
im = corr_ax.imshow(digit_corr, cmap='RdBu_r', vmin=-1, vmax=1)
corr_ax.set_xticks(digit_ticks)
corr_ax.set_yticks(digit_ticks)
corr_ax.set_title('Digit-Digit Accuracy Correlation')
plt.colorbar(im, ax=corr_ax)

# Model correlation, colour scale left to the data range
im = model_ax.imshow(model_corr, cmap='viridis')
model_ax.set_title('Model-Model Correlation')
plt.colorbar(im, ax=model_ax)

# Sample difficulty and model accuracy distributions share the same styling
histograms = [
    (difficulty_ax, sample_difficulty, 50, 'Accuracy across models', 'Sample Difficulty Distribution'),
    (accuracy_ax, model_accuracies, 20, 'Test accuracy', 'Model Accuracy Distribution'),
]
for ax, values, n_bins, x_label, title in histograms:
    ax.hist(values, bins=n_bins, edgecolor='black', alpha=0.7)
    ax.set_xlabel(x_label)
    ax.set_ylabel('Count')
    ax.set_title(title)

plt.tight_layout()
plt.show()

In [None]:
# Per-digit difficulty: one histogram per digit, laid out row-major on a 2x5 grid
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
for d, ax in enumerate(axes.flat):
    digit_diff = sample_difficulty[test_labels_np == d]
    ax.hist(digit_diff, bins=20, edgecolor='black', alpha=0.7)
    # Dashed red line marks the mean for this digit
    ax.axvline(digit_diff.mean(), color='red', linestyle='--')
    ax.set_title(f'Digit {d} (mean={digit_diff.mean():.2f})')
    # Same x-range on every panel so digits can be compared directly
    ax.set_xlim(0, 1)
plt.suptitle('Sample Difficulty by Digit')
plt.tight_layout()
plt.show()

In [None]:
# Hardest/easiest samples: for each digit (column), rows 0-2 show the three
# test images the fewest models got right, row 3 the one the most models got right.
def _show_sample(ax, idx):
    """Draw test image ``idx`` on ``ax``, titled with the share of models that got it right."""
    ax.imshow(test_images[idx, 0].cpu().numpy(), cmap='gray')
    ax.set_title(f'{sample_difficulty[idx]:.0%}', fontsize=9)
    # Hide ticks and frame but leave the axes switched on: axis('off') also
    # suppresses axis labels, which made the row/column labels set below invisible.
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)


fig, axes = plt.subplots(4, 10, figsize=(15, 6))
for d in range(10):
    digit_mask = test_labels_np == d
    digit_idx = np.where(digit_mask)[0]
    digit_diff = sample_difficulty[digit_mask]
    # Ascending sort: lowest cross-model accuracy (hardest) first.
    sorted_idx = digit_idx[np.argsort(digit_diff)]

    for row, idx in enumerate(sorted_idx[:3]):
        _show_sample(axes[row, d], idx)

    _show_sample(axes[3, d], sorted_idx[-1])
    axes[3, d].set_xlabel(f'{d}')

axes[0, 0].set_ylabel('Hardest')
axes[1, 0].set_ylabel('2nd')
axes[2, 0].set_ylabel('3rd')
axes[3, 0].set_ylabel('Easiest')
plt.suptitle('Hardest/Easiest Samples by Digit')
plt.tight_layout()
plt.show()

In [None]:
banner = "=" * 50
print(banner)
print("SUMMARY")
print(banner)
print(f"Models: {NUM_MODELS}, trained on {SAMPLES_PER_DIGIT * 10} samples each")
print(f"Test accuracy: {model_accuracies.mean():.4f} ± {model_accuracies.std():.4f}")

# Average only the strict upper triangle of each correlation matrix: this
# skips the diagonal and counts every pair once.
off_diag = model_corr[np.triu_indices(NUM_MODELS, k=1)]
print(f"Model-model correlation: {off_diag.mean():.4f}")

digit_corr_off = digit_corr[np.triu_indices(10, k=1)]
print(f"Inter-digit correlation: {digit_corr_off.mean():.4f}")

## Export for Dashboard

In [None]:
def tensor_to_base64(tensor):
    """Convert a single-channel image tensor to a base64 PNG data URI.

    Args:
        tensor: 2-D image tensor. It is mapped back to pixel values with the
            hard-coded constants below, i.e. it is assumed to have been
            normalized as ``(x - 0.1307) / 0.3081`` - confirm against the loader.

    Returns:
        A ``data:image/png;base64,...`` string holding the 8-bit grayscale PNG.
    """
    # Denormalize back to the 0-255 pixel range
    img = tensor.detach().cpu().numpy()
    img = (img * 0.3081 + 0.1307) * 255
    # Round before the integer cast: astype(np.uint8) truncates toward zero, so
    # a value such as 127.9999 (floating-point error from the arithmetic above)
    # would otherwise come out one grey level too dark.
    img = np.clip(np.rint(img), 0, 255).astype(np.uint8)

    # Convert to PIL and encode as PNG in memory
    pil_img = Image.fromarray(img, mode='L')
    buffer = BytesIO()
    pil_img.save(buffer, format='PNG')
    b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return f"data:image/png;base64,{b64}"

print("Converting images to base64...")
# One data URI per test image, in test-set order. Index 0 on the second axis
# selects the single image plane (presumably the channel axis).
test_images_b64 = []
for i in tqdm(range(len(test_images))):
    test_images_b64.append(tensor_to_base64(test_images[i, 0]))

In [None]:
print("Building dashboard data...")

# One record per test image: label, share of models that got it right, the
# ids of those models, and the inline PNG.
test_images_data = [
    {
        "id": i,
        "digit": int(test_labels_np[i]),
        "difficulty": float(sample_difficulty[i]),
        "correct_by": np.where(all_correct[:, i])[0].tolist(),
        "image": test_images_b64[i],
    }
    for i in range(len(test_labels))
]

# One record per model: the training subset it saw and its test accuracy.
models_data = [
    {
        "id": i,
        "training_indices": training_indices[i].tolist(),
        "test_accuracy": float(model_accuracies[i]),
    }
    for i in range(NUM_MODELS)
]

# NumPy values are converted to plain Python types above and below so the
# whole structure is JSON-serializable.
config_data = {
    "num_models": NUM_MODELS,
    "samples_per_digit": SAMPLES_PER_DIGIT,
    "target_train_acc": TARGET_TRAIN_ACC,
    "seed": SEED,
}
dashboard_data = {
    "config": config_data,
    "models": models_data,
    "test_images": test_images_data,
    "digit_correlation": digit_corr.tolist(),
    "model_correlation": model_corr.tolist(),
}

print(f"Data size: {len(json.dumps(dashboard_data)) / 1e6:.1f} MB")

In [None]:
# Save to file
import os

# os.makedirs replaces the `!mkdir -p` shell escape: the same
# create-if-missing behaviour, but it does not depend on a POSIX shell and
# keeps the cell valid Python outside IPython.
os.makedirs('dashboard/data', exist_ok=True)

with open('dashboard/data/dashboard_data.json', 'w', encoding='utf-8') as f:
    json.dump(dashboard_data, f)

print("Saved to dashboard/data/dashboard_data.json")
print("\nTo view the dashboard:")
print("1. Download the dashboard/ folder")
print("2. Run: python -m http.server 8000")
print("3. Open: http://localhost:8000/dashboard/")