# Digit Correlation Experiment

Train 50 models, each on a balanced sample of 20 training images (2 per digit), then analyze how per-digit test accuracies correlate across models.

In [None]:
# Notebook setup: clone the project repository and switch into it
# (presumably so the `utils` package imported by later cells is importable).
!git clone https://github.com/Caleb-Briggs/MNIST_AI.git
%cd MNIST_AI

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

from utils.data import load_mnist, get_device
from utils.models import SmallCNN
from utils.evaluation import get_predictions

# Pick the compute device through the project helper and report it.
device = get_device()
print(f"Device: {device}")

# Only ask CUDA for a device name when CUDA is actually usable; on a
# machine without CUDA, torch.cuda.get_device_name raises instead of printing.
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
else:
    print("GPU: not available")

In [None]:
# --- Experiment configuration ---
SEED = 42                # seed shared by the torch and NumPy generators
NUM_MODELS = 50          # number of models trained in the experiment
SAMPLES_PER_DIGIT = 2    # training images drawn for each digit class
TARGET_TRAIN_ACC = 0.99  # training-accuracy level at which training stops
MAX_EPOCHS = 200         # cap on training epochs per model
LR = 0.001               # learning rate handed to the optimizer

# Seed the global RNGs of both libraries for repeatable runs.
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Load the train and test splits through the project helper, which is
# given the selected device (presumably the tensors end up on it).
images, labels = load_mnist(device, train=True)
test_images, test_labels = load_mnist(device, train=False)

# For each digit 0-9, the positions of its training examples as a NumPy
# index array (moved to the CPU first so NumPy can hold it).
digit_indices = {
    digit: torch.nonzero(labels == digit, as_tuple=True)[0].cpu().numpy()
    for digit in range(10)
}

In [None]:
def create_balanced_sample(digit_indices, samples_per_digit, rng):
    """Draw the same number of example indices for each digit 0-9.

    Args:
        digit_indices: Container indexable by the digits ``0..9``; each entry
            holds the candidate dataset indices for that digit.
        samples_per_digit: Number of indices to draw, without replacement,
            for every digit.
        rng: Random source exposing ``choice(a, size=..., replace=...)``,
            e.g. a ``numpy.random.Generator``.

    Returns:
        A 1-D NumPy array of ``10 * samples_per_digit`` indices, grouped by
        digit in ascending digit order.
    """
    draws = [
        np.asarray(rng.choice(digit_indices[d], size=samples_per_digit, replace=False))
        for d in range(10)
    ]
    # Concatenating the per-digit arrays keeps their integer dtype even when
    # nothing is drawn; np.array([]) would yield a float array, which cannot
    # be used as an index.
    return np.concatenate(draws)

def train_until_accuracy(model, images, labels, indices, target_acc, max_epochs, lr):
    """Train ``model`` full-batch on a subset until a training accuracy is reached.

    Every epoch is a single AdamW step on the whole selected subset. The
    accuracy that is checked comes from the logits computed *before* that
    epoch's optimizer step. The model is left in training mode.

    Args:
        model: Module to optimize; called as ``model(X)`` and expected to
            return per-class logits.
        images: Tensor of inputs, indexed by ``indices``.
        labels: Tensor of integer class labels, indexed by ``indices``.
        indices: Positions of the training examples to use.
        target_acc: Training accuracy at which to stop early.
        max_epochs: Maximum number of epochs (optimizer steps).
        lr: Learning rate for AdamW.

    Returns:
        Tuple ``(epochs_run, accuracy)``. ``epochs_run`` is the 1-based epoch
        at which ``target_acc`` was reached, or ``max_epochs`` if it never
        was; ``accuracy`` is the last measured training accuracy
        (``0.0`` when ``max_epochs`` is 0 and nothing was measured).
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    X = images[indices]
    y = labels[indices]

    # Bound before the loop so the fall-through return below is valid even
    # when max_epochs is 0 and the loop body never runs.
    acc = 0.0

    model.train()
    for epoch in range(max_epochs):
        optimizer.zero_grad()
        output = model(X)
        loss = F.cross_entropy(output, y)
        loss.backward()
        optimizer.step()

        # Accuracy of the pre-step forward pass; no extra forward is done.
        with torch.no_grad():
            preds = output.argmax(dim=1)
            acc = (preds == y).float().mean().item()

        if acc >= target_acc:
            return epoch + 1, acc

    return max_epochs, acc

In [None]:
print(f"Training {NUM_MODELS} models on {SAMPLES_PER_DIGIT * 10} samples each...")

# One dedicated generator drives all subset draws.
rng = np.random.default_rng(SEED)
models, training_indices, epochs_to_converge = [], [], []

for _ in tqdm(range(NUM_MODELS)):
    # Each model gets its own freshly drawn balanced training subset.
    indices = create_balanced_sample(digit_indices, SAMPLES_PER_DIGIT, rng)
    training_indices.append(indices)

    model = SmallCNN().to(device)
    epochs, _final_acc = train_until_accuracy(
        model,
        images,
        labels,
        indices,
        target_acc=TARGET_TRAIN_ACC,
        max_epochs=MAX_EPOCHS,
        lr=LR,
    )

    models.append(model)
    epochs_to_converge.append(epochs)

print(f"Epochs to converge: {np.mean(epochs_to_converge):.1f} ± {np.std(epochs_to_converge):.1f}")

In [None]:
print("Evaluating on test set...")

# One row per model, one column per test sample; filled from the 'correct'
# entry returned by get_predictions (presumably a per-sample hit mask).
all_correct = np.zeros((NUM_MODELS, len(test_labels)), dtype=bool)

for model_idx, model in enumerate(tqdm(models)):
    all_correct[model_idx] = get_predictions(model, test_images, test_labels)['correct']

# Row means give each model's overall test accuracy.
model_accuracies = np.mean(all_correct, axis=1)
print(f"Test accuracy: {model_accuracies.mean():.4f} ± {model_accuracies.std():.4f}")

In [None]:
test_labels_np = test_labels.cpu().numpy()

# Column d holds every model's accuracy restricted to test samples whose
# label is d, giving a (NUM_MODELS, 10) table.
digit_accuracies = np.stack(
    [all_correct[:, test_labels_np == digit].mean(axis=1) for digit in range(10)],
    axis=1,
)

print("Per-digit accuracy (mean ± std):")
for d, per_model in enumerate(digit_accuracies.T):
    print(f"  {d}: {per_model.mean():.4f} ± {per_model.std():.4f}")

In [None]:
# Correlate digits with each other, using each model as one observation
# (rows of the transposed table are the per-digit accuracy series).
digit_corr = np.corrcoef(np.transpose(digit_accuracies))

fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(digit_corr, cmap='RdBu_r', vmin=-1, vmax=1)
ax.set(
    xticks=range(10),
    yticks=range(10),
    xlabel='Digit',
    ylabel='Digit',
    title='Digit-Digit Accuracy Correlation Across Models',
)

# Annotate every cell with its value; strong correlations sit on dark
# colours of the diverging map, so those get white text.
for (i, j), r in np.ndenumerate(digit_corr):
    if abs(r) > 0.5:
        color = 'white'
    else:
        color = 'black'
    ax.text(j, i, f'{r:.2f}', ha='center', va='center', fontsize=8, color=color)

fig.colorbar(im, label='Correlation')
plt.tight_layout()
plt.show()

In [None]:
# Per test sample: the fraction of models that got it right
# (low values mean a hard sample).
sample_difficulty = np.mean(all_correct, axis=0)

print("Sample difficulty:")
print(f"  Always correct: {np.count_nonzero(sample_difficulty == 1.0)}")
print(f"  >80% correct: {np.count_nonzero(sample_difficulty > 0.8)}")
print(f"  <50% correct: {np.count_nonzero(sample_difficulty < 0.5)}")
print(f"  Always wrong: {np.count_nonzero(sample_difficulty == 0.0)}")

In [None]:
# Correlate models with each other over their per-sample correctness rows.
model_corr = np.corrcoef(all_correct)

# Keep each unordered model pair once: the strict upper triangle.
upper_rows, upper_cols = np.triu_indices(NUM_MODELS, k=1)
off_diag = model_corr[upper_rows, upper_cols]
mean_corr = off_diag.mean()

fig, axes = plt.subplots(1, 2, figsize=(12, 4))
heat_ax, hist_ax = axes

# Left panel: the full correlation matrix.
im = heat_ax.imshow(model_corr, cmap='viridis')
heat_ax.set_xlabel('Model')
heat_ax.set_ylabel('Model')
heat_ax.set_title('Model-Model Correlation')
fig.colorbar(im, ax=heat_ax)

# Right panel: distribution of pairwise correlations with the mean marked.
hist_ax.hist(off_diag, bins=30, edgecolor='black', alpha=0.7)
hist_ax.axvline(mean_corr, color='red', linestyle='--', label=f'Mean: {mean_corr:.3f}')
hist_ax.set_xlabel('Correlation')
hist_ax.set_ylabel('Count')
hist_ax.legend()

plt.tight_layout()
plt.show()

print(f"Model-model correlation: {mean_corr:.4f} ± {off_diag.std():.4f}")

In [None]:
# Grid of example test images: one column per digit; rows 0-2 show the three
# samples the fewest models classified correctly, row 3 the sample the most
# models classified correctly. Each title is that sample's correct-rate.
fig, axes = plt.subplots(4, 10, figsize=(15, 6))

for d in range(10):
    digit_mask = test_labels_np == d
    digit_idx = np.where(digit_mask)[0]
    digit_diff = sample_difficulty[digit_mask]
    # Test-set indices of this digit, ordered from hardest to easiest.
    sorted_idx = digit_idx[np.argsort(digit_diff)]

    easiest = sorted_idx[-1]
    for row, idx in enumerate([*sorted_idx[:3], easiest]):
        ax = axes[row, d]
        # Index 0 on the second axis is taken as the image plane
        # (assumes test_images is (N, C, H, W) — TODO confirm).
        ax.imshow(test_images[idx, 0].cpu().numpy(), cmap='gray')
        ax.set_title(f'{sample_difficulty[idx]:.0%}', fontsize=9)
        # Hide ticks and the frame individually. axis('off') would also
        # suppress the axis labels, so the row/digit labels set below
        # would never be drawn.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

# Digit label under each column.
for d in range(10):
    axes[3, d].set_xlabel(f'{d}', fontsize=10)

# Row label on the left-most column.
axes[0, 0].set_ylabel('Hardest', fontsize=10)
axes[1, 0].set_ylabel('2nd', fontsize=10)
axes[2, 0].set_ylabel('3rd', fontsize=10)
axes[3, 0].set_ylabel('Easiest', fontsize=10)

plt.suptitle('Hardest/Easiest Samples by Digit')
plt.tight_layout()
plt.show()

In [None]:
# Distinct digit pairs only: strict upper triangle of the 10x10 matrix.
digit_corr_off = digit_corr[np.triu_indices(10, k=1)]

rule = "=" * 50
print(rule, "SUMMARY", rule, sep="\n")
print(f"Models: {NUM_MODELS}, trained on {SAMPLES_PER_DIGIT * 10} samples each")
print(f"Test accuracy: {model_accuracies.mean():.4f} ± {model_accuracies.std():.4f}")
print(f"Model-model correlation: {off_diag.mean():.4f}")
print(f"Inter-digit correlation: {digit_corr_off.mean():.4f}")