# Latent Space Clustering Evolution

Track how network representations evolve during training:
1. Cluster raw pixels by cosine similarity
2. Extract activations at each layer during training
3. Visualize clustering quality over time
4. See how the network learns to separate digits in latent space

In [1]:
# Fetch the MNIST_AI repository and make the checkout the working directory.
# Later cells import the `shared.utils` modules from this checkout.
!git clone https://github.com/Caleb-Briggs/MNIST_AI.git
%cd MNIST_AI

Cloning into 'MNIST_AI'...
remote: Enumerating objects: 54, done.[K
remote: Counting objects: 100% (54/54), done.[K
remote: Compressing objects: 100% (38/38), done.[K
remote: Total 54 (delta 11), reused 51 (delta 8), pack-reused 0 (from 0)[K
Receiving objects: 100% (54/54), 18.82 MiB | 14.02 MiB/s, done.
Resolving deltas: 100% (11/11), done.
/content/MNIST_AI


In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score, adjusted_rand_score
from tqdm.auto import tqdm
import seaborn as sns

import sys
sys.path.append('/content/MNIST_AI')

from shared.utils.data import load_mnist, get_device
from shared.utils.models import SmallCNN

# Resolve the compute device once; it is used below when loading the data
# and when placing the model.
device = get_device()
print(f"Device: {device}")
# torch.cuda.get_device_name raises when no CUDA device is present, so only
# query it when CUDA is actually available (e.g. not on CPU-only runtimes).
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
else:
    print("GPU: none (CUDA not available)")

In [None]:
# Config
TRAIN_SIZE = 5000  # Train on more data than correlation experiment
EVAL_SIZE = 1000   # Subset for visualization
MAX_EPOCHS = 100
LR = 1e-3
# Epochs at which layer activations are snapshotted: every 2 epochs for a
# smooth animation. Derived from MAX_EPOCHS so the schedule always spans the
# actual training length instead of a separately hard-coded upper bound.
CHECKPOINT_EPOCHS = list(range(0, MAX_EPOCHS + 1, 2))
SEED = 17  # Different seed each experiment (not 42!)

# Seed the global torch and NumPy RNGs so sampling and initialization repeat.
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Load the MNIST training and test splits (load_mnist receives the target device)
images, labels = load_mnist(device, train=True)
test_images, test_labels = load_mnist(device, train=False)


def _pick_balanced(label_tensor, per_digit):
    """Return indices of up to ``per_digit`` randomly chosen samples per digit 0-9.

    Digits are visited in ascending order and one random permutation is drawn
    per digit, so the result is grouped by digit.
    """
    chosen = []
    for digit in range(10):
        candidates = torch.where(label_tensor == digit)[0]
        shuffled = torch.randperm(len(candidates))
        chosen.extend(candidates[shuffled[:per_digit]].tolist())
    return torch.tensor(chosen)


# Training subset, balanced across the ten digits
samples_per_digit = TRAIN_SIZE // 10
train_indices = _pick_balanced(labels, samples_per_digit)
X_train, y_train = images[train_indices], labels[train_indices]

# Evaluation subset, drawn from the test split and used for visualization
eval_per_digit = EVAL_SIZE // 10
eval_indices = _pick_balanced(test_labels, eval_per_digit)
X_eval, y_eval = test_images[eval_indices], test_labels[eval_indices]

print(f"Train: {len(X_train)} images")
print(f"Eval: {len(X_eval)} images")

## 1. Raw Pixel Clustering

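First, cluster the raw pixel vectors with k-means after L2-normalizing them (on unit vectors, Euclidean distance is a monotonic function of cosine similarity). This gives a baseline for how much digit structure exists before any learning.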

In [None]:
# Move the eval labels and flattened eval images to NumPy (one row per image)
y_np = y_eval.cpu().numpy()
X_flat = X_eval.view(len(X_eval), -1).cpu().numpy()

# Scale every row to unit L2 norm (for cosine similarity); the small epsilon
# keeps all-zero rows from dividing by zero
row_norms = np.linalg.norm(X_flat, axis=1, keepdims=True)
X_norm = X_flat / (row_norms + 1e-8)

# Fit k-means with one cluster per digit class and read off the assignments
kmeans = KMeans(n_clusters=10, random_state=SEED, n_init=10)
kmeans.fit(X_norm)
pixel_clusters = kmeans.labels_

# Score the clustering: silhouette uses geometry only, ARI compares to labels
silhouette = silhouette_score(X_norm, pixel_clusters)
ari = adjusted_rand_score(y_np, pixel_clusters)

print(f"Raw Pixels - Silhouette: {silhouette:.3f}, ARI: {ari:.3f}")

In [None]:
# Project the normalized pixel vectors to 2-D with t-SNE
print("Computing t-SNE for raw pixels...")
tsne = TSNE(n_components=2, random_state=SEED, perplexity=30)
X_tsne = tsne.fit_transform(X_norm)

fig, axes = plt.subplots(1, 2, figsize=(16, 7))

# The same embedding is drawn twice: left colored by the true digit,
# right colored by the k-means cluster assignment.
panels = [
    (y_np, 'Raw Pixels (colored by true digit)'),
    (pixel_clusters, f'K-Means Clusters (ARI={ari:.3f})'),
]
for ax, (colors, title) in zip(axes, panels):
    scatter = ax.scatter(X_tsne[:, 0], X_tsne[:, 1], c=colors, cmap='tab10', s=20, alpha=0.6)
    ax.set_title(title, fontsize=14)
    ax.set_xlabel('t-SNE 1')
    ax.set_ylabel('t-SNE 2')
    plt.colorbar(scatter, ax=ax, ticks=range(10))

plt.tight_layout()
plt.show()

## 2. Layer Activation Extraction

Extract activations at each layer during training to see how representations evolve.

In [None]:
# Hook to capture layer activations.
# Module-level store written by the forward hooks: layer name -> output tensor.
activations = {}

def get_activation(name):
    """Return a forward hook that stores the module's detached output.

    The hook writes into the module-level ``activations`` dict under ``name``.
    """
    def hook(model, input, output):
        activations[name] = output.detach()
    return hook

def extract_layer_activations(model, x,
                              layer_names=('conv1', 'conv2', 'conv3', 'fc1', 'fc2')):
    """Extract activations at each layer.

    Runs one forward pass of ``model`` on ``x`` in eval mode without gradients
    and captures the output of every submodule named in ``layer_names``.

    Args:
        model: Module exposing each name in ``layer_names`` as an attribute.
        x: Input batch; ``len(x)`` is used as the number of rows of each result.
        layer_names: Attribute names of the submodules to hook.

    Returns:
        Dict mapping layer name to a NumPy array of shape ``(len(x), -1)``,
        in the order the layers produced their outputs.

    The model's train/eval mode is restored before returning, and the hooks
    are removed even if the forward pass raises.
    """
    global activations
    activations = {}

    was_training = model.training
    hooks = []
    try:
        # Register inside the try so a missing attribute still cleans up the
        # hooks that were already attached.
        for name in layer_names:
            layer = getattr(model, name)
            hooks.append(layer.register_forward_hook(get_activation(name)))

        # Forward pass
        model.eval()
        with torch.no_grad():
            _ = model(x)
    finally:
        # Remove hooks and undo the eval() switch
        for h in hooks:
            h.remove()
        model.train(was_training)

    # Flatten activations; reshape (unlike view) also accepts non-contiguous
    # tensors.
    return {name: act.reshape(len(x), -1).cpu().numpy()
            for name, act in activations.items()}

In [None]:
# Train model and save layer activations at checkpoints
model = SmallCNN().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)

layer_snapshots = []  # (epoch, layer_name, activations)
training_history = []  # (epoch, train_acc, test_acc)
checkpoint_set = set(CHECKPOINT_EPOCHS)  # O(1) membership test inside the loop


def _evaluate(net):
    """Return (train_acc, test_acc) of ``net`` on X_train and on X_eval."""
    net.eval()
    with torch.no_grad():
        train_hits = net(X_train).argmax(dim=1) == y_train
        eval_hits = net(X_eval).argmax(dim=1) == y_eval
    return train_hits.float().mean().item(), eval_hits.float().mean().item()


def _snapshot(epoch_number, train_acc, test_acc):
    """Store the activations of every hooked layer on X_eval for this epoch."""
    layer_acts = extract_layer_activations(model, X_eval)
    for layer_name, acts in layer_acts.items():
        layer_snapshots.append((epoch_number, layer_name, acts))
    print(f"  Epoch {epoch_number}: Train={train_acc:.3f}, Test={test_acc:.3f}")


print("Training model and capturing layer activations...")

# Epoch 0 is the untrained network. The loop below only records epochs
# 1..MAX_EPOCHS, so a requested epoch-0 checkpoint has to be captured here.
if 0 in checkpoint_set:
    train_acc, test_acc = _evaluate(model)
    training_history.append((0, train_acc, test_acc))
    _snapshot(0, train_acc, test_acc)

for epoch in tqdm(range(MAX_EPOCHS)):
    # Training step: one full-batch gradient update per epoch
    model.train()
    optimizer.zero_grad()
    output = model(X_train)
    loss = F.cross_entropy(output, y_train)
    loss.backward()
    optimizer.step()

    # Eval (after the update)
    train_acc, test_acc = _evaluate(model)
    training_history.append((epoch + 1, train_acc, test_acc))

    # Save layer activations at checkpoints
    if (epoch + 1) in checkpoint_set:
        _snapshot(epoch + 1, train_acc, test_acc)

print(f"\nFinal: Train={train_acc:.3f}, Test={test_acc:.3f}")

In [None]:
# Plot training curves: split the (epoch, train_acc, test_acc) history into series
epochs, train_accs, test_accs = zip(*training_history)

plt.figure(figsize=(10, 5))
for series, series_name in ((train_accs, 'Train'), (test_accs, 'Test')):
    plt.plot(epochs, series, label=series_name, linewidth=2)
plt.title('Training Progress')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.grid(alpha=0.3)
plt.legend()
plt.tight_layout()
plt.show()

## 3. Layer-wise Clustering Over Time

Visualize how each layer learns to separate digits.

In [None]:
# Compute clustering metrics for all snapshots
layer_names = ['conv1', 'conv2', 'conv3', 'fc1', 'fc2']
checkpoint_epochs_actual = sorted({e for e, _, _ in layer_snapshots})

# Per layer: parallel lists, one entry per snapshot in chronological order
clustering_metrics = {layer: {'epoch': [], 'silhouette': [], 'ari': []}
                      for layer in layer_names}

print("Computing clustering metrics...")
for epoch, layer_name, acts in tqdm(layer_snapshots):
    # Normalize activations to unit L2 norm (epsilon guards all-zero rows)
    acts_norm = acts / (np.linalg.norm(acts, axis=1, keepdims=True) + 1e-8)

    # K-means clustering. The loop uses its own variable names so that it
    # does not overwrite the raw-pixel `kmeans`, `silhouette` and `ari`
    # values computed earlier in the notebook.
    layer_kmeans = KMeans(n_clusters=10, random_state=SEED, n_init=10)
    clusters = layer_kmeans.fit_predict(acts_norm)

    # Metrics
    layer_silhouette = silhouette_score(acts_norm, clusters)
    layer_ari = adjusted_rand_score(y_np, clusters)

    record = clustering_metrics[layer_name]
    record['epoch'].append(epoch)
    record['silhouette'].append(layer_silhouette)
    record['ari'].append(layer_ari)

print("Done!")

In [None]:
# Plot clustering quality over time: one panel per metric, one line per layer
fig, axes = plt.subplots(1, 2, figsize=(16, 5))

# (key in clustering_metrics, y-axis label, panel title)
panel_specs = [
    ('silhouette', 'Silhouette Score', 'Cluster Cohesion Over Training'),
    ('ari', 'Adjusted Rand Index', 'Alignment with True Labels Over Training'),
]
for ax, (metric_key, y_label, panel_title) in zip(axes, panel_specs):
    for layer in layer_names:
        history = clustering_metrics[layer]
        ax.plot(history['epoch'], history[metric_key],
                marker='o', label=layer, linewidth=2)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.set_title(panel_title, fontsize=14)
    ax.legend()
    ax.grid(alpha=0.3)

plt.tight_layout()
plt.show()

## 4. Layer Visualization Grid

Show t-SNE of each layer at different epochs.

In [None]:
# Visualize ALL layers at ALL epochs - create frames for animation
vis_layers = ['conv1', 'conv2', 'conv3', 'fc1', 'fc2']
vis_epochs = sorted({snap[0] for snap in layer_snapshots})

# One 2-D embedding per stored snapshot, computed up front
print("Pre-computing t-SNE embeddings for all layers and epochs...")
embeddings = {}  # (epoch, layer) -> embedding

for epoch, layer, acts in tqdm(layer_snapshots):
    # Unit-normalize each activation row
    unit_acts = acts / (np.linalg.norm(acts, axis=1, keepdims=True) + 1e-8)

    # Wide layers are first reduced to 50 PCA components; narrow ones pass through
    reduced = unit_acts
    if unit_acts.shape[1] > 50:
        reduced = PCA(n_components=50, random_state=SEED).fit_transform(unit_acts)

    # t-SNE down to two dimensions
    projector = TSNE(n_components=2, random_state=SEED, perplexity=30)
    embeddings[(epoch, layer)] = projector.fit_transform(reduced)

print(f"Computed {len(embeddings)} embeddings")

In [None]:
# Also create static grid visualization of ALL layers
print("Creating static grid of all layers...")

# Epochs we would like to show. Not every one of them has a stored snapshot,
# so each is snapped to the nearest epoch that actually has an embedding;
# otherwise whole rows of the grid would read "No data".
requested_epochs = [0, 2, 5, 10, 20, 50, 100]
grid_layers = ['conv1', 'conv2', 'conv3', 'fc1', 'fc2']

available_epochs = sorted({e for e, _ in embeddings})
if available_epochs:
    grid_epochs = sorted({min(available_epochs, key=lambda a: abs(a - e))
                          for e in requested_epochs})
else:
    grid_epochs = requested_epochs

# squeeze=False keeps `axes` two-dimensional even for a single row or column
fig, axes = plt.subplots(len(grid_epochs), len(grid_layers),
                         figsize=(4 * len(grid_layers), 4 * len(grid_epochs)),
                         squeeze=False)

for i, epoch in enumerate(grid_epochs):
    for j, layer in enumerate(grid_layers):
        ax = axes[i, j]

        # Get embedding
        embedding = embeddings.get((epoch, layer))
        if embedding is None:
            ax.text(0.5, 0.5, 'No data', ha='center', va='center')
            ax.set_xticks([])
            ax.set_yticks([])
            continue

        # Get ARI. A dedicated name is used so the notebook-level `ari`
        # (raw-pixel baseline) is not overwritten by this cell.
        layer_metrics = clustering_metrics[layer]
        if epoch in layer_metrics['epoch']:
            panel_ari = layer_metrics['ari'][layer_metrics['epoch'].index(epoch)]
            ari_text = f'{panel_ari:.2f}'
        else:
            # No metric was computed for this panel; do not invent a value
            ari_text = 'n/a'

        # Plot
        scatter = ax.scatter(embedding[:, 0], embedding[:, 1],
                             c=y_np, cmap='tab10', s=8, alpha=0.6)
        ax.set_title(f'{layer} @ Epoch {epoch} (ARI={ari_text})', fontsize=10)
        ax.set_xticks([])
        ax.set_yticks([])

# Add row labels
for i, epoch in enumerate(grid_epochs):
    axes[i, 0].set_ylabel(f'Epoch {epoch}', fontsize=12, rotation=0,
                          ha='right', va='center', labelpad=40)

plt.tight_layout()
plt.show()

In [None]:
# Download GIF through an auto-clicked <a download> link carrying the file as
# a base64 data: URI (author's note: only method that works in VS Code Colab).
# The same data URI is also used to show an inline preview.
from IPython.display import HTML, display
import base64
import os

gif_path = '/tmp/latent_evolution.gif'

# The GIF is written by the "Create animation frames" cell. Fail with an
# actionable message instead of a bare open() error when that cell has not
# been run yet.
if not os.path.exists(gif_path):
    raise FileNotFoundError(
        f"{gif_path} not found - run the 'Create animation frames' cell "
        "first to generate the GIF."
    )

# Read the GIF file
with open(gif_path, 'rb') as f:
    gif_data = f.read()

# Encode as base64
b64 = base64.b64encode(gif_data).decode()

# Create download link
html = f'''
<a download="latent_evolution.gif" 
   href="data:image/gif;base64,{b64}" 
   id="download_link">
   Click to download latent_evolution.gif
</a>
<script>
// Auto-trigger download
document.getElementById('download_link').click();
</script>
<div style="margin-top: 1rem;">
<img src="data:image/gif;base64,{b64}" style="max-width: 100%;"/>
</div>
'''
display(HTML(html))
print("GIF download should start automatically. Preview shown above.")

In [None]:
# Create animation frames
import matplotlib
from matplotlib.animation import FuncAnimation, PillowWriter
import io


def update_frame(epoch):
    """Redraw every layer panel for ``epoch``; called once per animation frame.

    Reads the module-level ``fig``, ``axes``, ``embeddings``,
    ``clustering_metrics`` and ``y_np``. Returns the axes array.
    """
    for j, layer in enumerate(vis_layers):
        ax = axes[j]
        ax.clear()

        # Get embedding
        embedding = embeddings.get((epoch, layer))
        if embedding is None:
            ax.text(0.5, 0.5, f'Epoch {epoch}\nNo data',
                    ha='center', va='center', transform=ax.transAxes)
            continue

        # Get ARI recorded for this layer at this epoch
        ari_idx = clustering_metrics[layer]['epoch'].index(epoch)
        ari = clustering_metrics[layer]['ari'][ari_idx]

        # Plot
        scatter = ax.scatter(embedding[:, 0], embedding[:, 1],
                             c=y_np, cmap='tab10', s=10, alpha=0.7)
        ax.set_title(f'{layer} (Epoch {epoch}, ARI={ari:.3f})', fontsize=11)
        ax.set_xticks([])
        ax.set_yticks([])

    fig.tight_layout()
    return axes


print("Creating animation frames...")

# Render off-screen with the non-interactive Agg backend, but remember the
# backend that was active so it can be restored afterwards. Leaving Agg
# selected would make every later plt.show() in the notebook display nothing.
previous_backend = matplotlib.get_backend()
matplotlib.use('Agg')  # Non-interactive backend for rendering
try:
    fig, axes = plt.subplots(1, len(vis_layers), figsize=(4 * len(vis_layers), 4))

    # Create animation: one frame per checkpoint epoch
    anim = FuncAnimation(fig, update_frame, frames=vis_epochs, interval=200, repeat=True)

    # Save as GIF
    print("Saving animation as GIF (this may take a while)...")
    writer = PillowWriter(fps=5)
    anim.save('/tmp/latent_evolution.gif', writer=writer, dpi=80)

    plt.close(fig)
finally:
    matplotlib.use(previous_backend)

print("Animation saved!")

## 5. Summary

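Compare the clustering quality of every layer at the final checkpoint against the raw-pixel baseline, and report the layer with the highest ARI.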

In [None]:
print("="*60)
print("CLUSTERING QUALITY SUMMARY")
print("="*60)

# Recompute the raw-pixel baseline from the stored unit vectors and cluster
# assignments. The generic names `silhouette` and `ari` are not reliable here
# because the per-layer analysis cells can reassign them, which would make
# this section report a layer's score as the raw-pixel score.
raw_silhouette = silhouette_score(X_norm, pixel_clusters)
raw_ari = adjusted_rand_score(y_np, pixel_clusters)

print(f"\nRaw Pixels:")
print(f"  Silhouette: {raw_silhouette:.3f}")
print(f"  ARI: {raw_ari:.3f}")

# The metric lists are in chronological order, so index -1 is the last
# checkpoint; label it with the epoch that was actually recorded.
final_epoch = clustering_metrics[layer_names[0]]['epoch'][-1]
print(f"\nFinal Epoch ({final_epoch}):")
for layer in layer_names:
    final_sil = clustering_metrics[layer]['silhouette'][-1]
    final_ari = clustering_metrics[layer]['ari'][-1]
    print(f"  {layer:6s}: Silhouette={final_sil:.3f}, ARI={final_ari:.3f}")

# Best layer = highest ARI at the last checkpoint
best_layer = max(layer_names, key=lambda l: clustering_metrics[l]['ari'][-1])
best_ari = clustering_metrics[best_layer]['ari'][-1]
print(f"\nBest layer: {best_layer} (ARI={best_ari:.3f})")