# Complete End-to-End Reproduction

This notebook reproduces all results from the paper:

> **"Delta Observer: Learning Continuous Semantic Manifolds Between Neural Network Representations"**  
> Aaron (Tripp) Josserand-Austin | EntroMorphic Research Team  
> [OSF MetaArXiv](https://doi.org/10.17605/OSF.IO/CNJTP)

---

## Key Discovery: Transient Clustering

**Clustering is scaffolding, not structure.** Networks build geometric organization to *learn* semantic concepts, then discard that organization once the concepts are encoded in the weights.

| Training Phase | R² | Silhouette | Interpretation |
|----------------|-----|-----------|----------------|
| Early (epoch 0) | 0.36 | -0.02 | Near initialization (after the first epoch) |
| Learning (epoch 20) | 0.94 | **0.33** | Clustering emerges |
| Final (epoch 200) | 0.99 | -0.02 | Clustering dissolves |

---

## Pipeline Overview

1. **Generate Dataset** - All 256 possible 4-bit + 4-bit additions (16 × 16 operand pairs)
2. **Online Training** - Train all three models concurrently (Monolithic, Compositional, Delta Observer)
3. **Trajectory Analysis** - Track R² and Silhouette during training
4. **Discover Transient Clustering** - Observe clustering emerge then dissolve
5. **Generate Figures** - Reproduce paper visualizations

**Estimated runtime:** ~15 minutes on CPU, ~5 minutes on GPU

---

## Setup & Configuration

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.gridspec import GridSpec
from sklearn.metrics import silhouette_score, r2_score
from sklearn.linear_model import LinearRegression
from sklearn.decomposition import PCA
from tqdm import tqdm
import os

# Try UMAP, fall back to PCA
try:
    from umap import UMAP
    HAS_UMAP = True
except ImportError:
    HAS_UMAP = False
    print("UMAP not available, using PCA for visualization")

# Configuration
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
RANDOM_SEED = 42
EPOCHS = 200
BATCH_SIZE = 64
LEARNING_RATE = 0.001
LATENT_DIM = 16
SNAPSHOT_INTERVAL = 5  # Save latent space every N epochs

# Set seeds. Model weight initialisation and DataLoader shuffling both draw
# from the torch RNG, so this must run before any model/loader is created.
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# Create output directories (relative to the notebook's working directory)
os.makedirs('../data', exist_ok=True)
os.makedirs('../figures', exist_ok=True)

# Plot style
plt.rcParams['figure.facecolor'] = 'white'
plt.rcParams['axes.facecolor'] = 'white'
plt.rcParams['font.size'] = 11

print(f"🖥️  Device: {DEVICE}")
print(f"🔄 Training for {EPOCHS} epochs")
print(f"📸 Snapshot interval: every {SNAPSHOT_INTERVAL} epochs")
print("✅ Configuration complete!")

## Step 1: Generate 4-bit Addition Dataset

We generate all 256 possible 4-bit + 4-bit additions (16 × 16 operand pairs). The key semantic variable is **carry count** (0-4): the number of bit positions that produce a carry.

In [None]:
def generate_4bit_addition_dataset(n_bits=4):
    """Enumerate every ``n_bits``-bit + ``n_bits``-bit addition problem.

    With the default ``n_bits=4`` this yields all 16 * 16 = 256 operand pairs.

    Args:
        n_bits: width of each operand in bits (default 4).

    Returns:
        Tuple ``(inputs, outputs, carry_counts, a_values, b_values)``:

        - inputs: float32 array of shape (N, 2 * n_bits). Bits are LSB-first,
          laid out as [a0, ..., a{n-1}, b0, ..., b{n-1}].
        - outputs: float32 array of shape (N, n_bits + 1), LSB-first bits of a + b.
        - carry_counts: int64 array of shape (N,), the number of bit positions
          that emit a carry during ripple-carry addition of a and b.
        - a_values, b_values: integer operand values, shape (N,).

        Here N = 4 ** n_bits. Rows iterate ``a`` in the outer loop, so the
        pair (a, b) is at row ``a * 2 ** n_bits + b``.
    """
    n_values = 1 << n_bits
    inputs = []
    outputs = []
    carry_counts = []
    a_values = []
    b_values = []

    for a in range(n_values):
        for b in range(n_values):
            a_bits = [(a >> i) & 1 for i in range(n_bits)]
            b_bits = [(b >> i) & 1 for i in range(n_bits)]

            # The sum needs one extra bit for the final carry-out.
            sum_val = a + b
            output_bits = [(sum_val >> i) & 1 for i in range(n_bits + 1)]

            # Ripple-carry simulation: a position carries when its column sum,
            # including the incoming carry, is at least 2.
            carry = 0
            count = 0
            for a_bit, b_bit in zip(a_bits, b_bits):
                carry = 1 if a_bit + b_bit + carry >= 2 else 0
                count += carry

            inputs.append(a_bits + b_bits)
            outputs.append(output_bits)
            carry_counts.append(count)
            a_values.append(a)
            b_values.append(b)

    return (np.array(inputs, dtype=np.float32),
            np.array(outputs, dtype=np.float32),
            np.array(carry_counts, dtype=np.int64),
            np.array(a_values),
            np.array(b_values))

# Build the full dataset and report its basic shape and carry-count histogram.
X, y, carry_counts, a_vals, b_vals = generate_4bit_addition_dataset()
print(f"📊 Dataset: {X.shape[0]} examples")
print(f"📥 Input: {X.shape[1]} bits, Output: {y.shape[1]} bits")
print(f"🔢 Carry count distribution: {np.bincount(carry_counts)}")

### Visualize the Dataset

In [None]:
fig = plt.figure(figsize=(16, 5))
gs = GridSpec(1, 3, width_ratios=[1, 1.2, 1])

# 1. Carry count distribution (minlength=5 guarantees one bar per count 0-4)
ax1 = fig.add_subplot(gs[0])
counts = np.bincount(carry_counts, minlength=5)
colors = plt.cm.viridis(np.linspace(0, 1, 5))
bars = ax1.bar(range(5), counts, color=colors, edgecolor='black', linewidth=1.5)
ax1.set_xlabel('Carry Count', fontsize=12)
ax1.set_ylabel('Number of Examples', fontsize=12)
ax1.set_title('Carry Count Distribution', fontsize=14, fontweight='bold')
ax1.set_xticks(range(5))
for bar, count in zip(bars, counts):
    ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 3,
             str(count), ha='center', fontsize=11, fontweight='bold')

# 2. Carry count heatmap (a vs b). The dataset iterates a in the outer loop
# and b in the inner loop, so row index = a and column index = b.
ax2 = fig.add_subplot(gs[1])
carry_matrix = carry_counts.reshape(16, 16)
im = ax2.imshow(carry_matrix, cmap='viridis', origin='lower')
ax2.set_xlabel('b (second operand)', fontsize=12)
ax2.set_ylabel('a (first operand)', fontsize=12)
ax2.set_title('Carry Count by Operands', fontsize=14, fontweight='bold')
ax2.set_xticks([0, 5, 10, 15])
ax2.set_yticks([0, 5, 10, 15])
cbar = plt.colorbar(im, ax=ax2, label='Carry Count')
cbar.set_ticks([0, 1, 2, 3, 4])

# 3. Example additions
ax3 = fig.add_subplot(gs[2])
ax3.axis('off')
ax3.set_title('Example Additions', fontsize=14, fontweight='bold')

# Operand pairs to display. Carry counts and binary strings are derived from
# the dataset instead of being hard-coded, so the table always matches the
# data. For example, 3 + 2 = 0011 + 0010 carries once, from bit 1 into bit 2.
examples = [(3, 2), (7, 1), (15, 15), (8, 4), (9, 7)]

text = "   a  +  b  =  sum   carries\n" + "─"*35 + "\n"
for a, b in examples:
    c = int(carry_matrix[a, b])
    text += f"  {a:2d} + {b:2d} = {a+b:2d}    ({c} carries)\n"
    # Binary strings are MSB-first for readability: 4-bit operands, 5-bit sum.
    text += f"  {a:04b} + {b:04b} = {a + b:05b}\n\n"

ax3.text(0.1, 0.95, text, transform=ax3.transAxes, fontsize=10,
         verticalalignment='top', fontfamily='monospace',
         bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

plt.tight_layout()
plt.savefig('../figures/dataset_visualization.png', dpi=150, bbox_inches='tight')
plt.show()
print("✅ Dataset visualization saved")

In [None]:
class AdditionDataset(Dataset):
    """Map-style dataset yielding ``(input_bits, output_bits, carry_count)``.

    The arrays passed in are copied into tensors once, at construction time.
    Inputs and targets become float32; carry counts become ``torch.long``.
    Each item is a tuple of the three tensors indexed at the same row.
    """

    def __init__(self, X, y, carry_counts):
        self.X, self.y = (torch.tensor(arr, dtype=torch.float32) for arr in (X, y))
        self.carry_counts = torch.tensor(carry_counts, dtype=torch.long)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        columns = (self.X, self.y, self.carry_counts)
        return tuple(column[idx] for column in columns)

dataset = AdditionDataset(X, y, carry_counts)
# Shuffled mini-batches for training.
train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
# Whole dataset as a single, unshuffled batch for evaluation snapshots.
full_loader = DataLoader(dataset, batch_size=len(dataset), shuffle=False)
print("✅ Dataset ready")

## Step 2: Define Model Architectures

We use two contrasting architectures:
- **Monolithic MLP**: Processes all 8 input bits together
- **Compositional Network**: Processes each bit position separately with carry propagation

In [None]:
class MonolithicMLP(nn.Module):
    """Standard MLP that processes all bits at once.

    Maps the 8 input bits through two ReLU hidden layers of width
    ``hidden_dim`` to 5 sigmoid outputs, one per sum bit.

    NOTE: the layers are constructed in a fixed order (fc1, fc2, fc3). Each
    ``nn.Linear`` draws its initial weights from the seeded torch RNG, so
    reordering them changes the reproduced results.
    """
    def __init__(self, hidden_dim=64):
        super().__init__()
        self.fc1 = nn.Linear(8, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 5)
        # Width of the `hidden` activation returned by forward().
        self.hidden_dim = hidden_dim
    
    def forward(self, x):
        """Run the MLP.

        Args:
            x: float tensor whose last dimension is 8. It is presumably
               (batch, 8) with bits ordered [a0..a3, b0..b3], as built by the
               dataset generator (confirm for other callers).

        Returns:
            (out, hidden):
            - out: sigmoid probabilities for the 5 output bits.
            - hidden: post-ReLU activations of the second hidden layer, which
              the Delta Observer consumes.
        """
        x = torch.relu(self.fc1(x))
        hidden = torch.relu(self.fc2(x))
        out = torch.sigmoid(self.fc3(hidden))
        return out, hidden


class CompositionalNetwork(nn.Module):
    """Modular network with separate per-bit processing.

    There is one small MLP per bit position. Module ``i`` receives
    ``[a_i, b_i, carry_in]``, and a learned carry derived from its output is
    fed into module ``i + 1`` (ripple-carry style). The concatenated module
    outputs form the ``hidden`` representation.

    NOTE: submodules are constructed in a fixed order. Each ``nn.Linear``
    draws its initial weights from the seeded torch RNG, so reordering
    construction changes the reproduced results.
    """
    def __init__(self, module_dim=16):
        super().__init__()
        # Four per-bit modules, each mapping (a_i, b_i, carry_in) -> module_dim.
        self.bit_modules = nn.ModuleList([
            nn.Sequential(
                nn.Linear(3, module_dim),
                nn.ReLU(),
                nn.Linear(module_dim, module_dim),
                nn.ReLU()
            ) for _ in range(4)
        ])
        # Reads all four module outputs to predict the 5 sum bits.
        self.output = nn.Linear(4 * module_dim, 5)
        # Width of the concatenated `hidden` returned by forward().
        self.hidden_dim = 4 * module_dim
    
    def forward(self, x):
        """Run the per-bit modules with carry propagation.

        Args:
            x: float tensor indexed as x[:, i] (bit i of a) and x[:, i + 4]
               (bit i of b). It is presumably (batch, 8) ordered
               [a0..a3, b0..b3] (confirm for other callers).

        Returns:
            (out, hidden):
            - out: (batch, 5) sigmoid probabilities for the sum bits.
            - hidden: (batch, 4 * module_dim) concatenation of the per-bit
              module outputs.
        """
        batch_size = x.size(0)
        bit_outputs = []
        # No carry flows into the least-significant bit.
        carry = torch.zeros(batch_size, 1, device=x.device)
        
        for i in range(4):
            a_bit = x[:, i:i+1]
            b_bit = x[:, i+4:i+5]
            module_input = torch.cat([a_bit, b_bit, carry], dim=1)
            module_output = self.bit_modules[i](module_input)
            bit_outputs.append(module_output)
            # Learned carry: the first unit of this module's output, squashed.
            # NOTE(review): module_output ends in ReLU (>= 0), so this sigmoid
            # is always >= 0.5 and the carry signal can never approach 0.
            # Confirm this is intended before interpreting it as a binary carry.
            carry = torch.sigmoid(module_output[:, :1])
        
        hidden = torch.cat(bit_outputs, dim=1)
        out = torch.sigmoid(self.output(hidden))
        return out, hidden


class DeltaObserver(nn.Module):
    """Learns shared latent space between two architectures.

    Each model's hidden activations are encoded to 32 dims. The two encodings
    are concatenated (64 dims) and compressed into a ``latent_dim`` latent.
    From that latent the observer:

    - reconstructs both activation vectors, and
    - regresses a single scalar through ``carry_head``.

    Dropout layers are active only in train mode.

    NOTE: submodules are constructed in a fixed order. Each ``nn.Linear``
    draws its initial weights from the seeded torch RNG, so reordering
    construction changes the reproduced results.
    """
    def __init__(self, mono_dim=64, comp_dim=64, latent_dim=16):
        super().__init__()
        # Per-source encoders: each maps its model's activations to 32 dims.
        self.mono_encoder = nn.Sequential(
            nn.Linear(mono_dim, 32), nn.ReLU(), nn.Dropout(0.1)
        )
        self.comp_encoder = nn.Sequential(
            nn.Linear(comp_dim, 32), nn.ReLU(), nn.Dropout(0.1)
        )
        # Input width 64 = 32 (mono_encoder) + 32 (comp_encoder).
        # It is independent of mono_dim and comp_dim.
        self.shared_encoder = nn.Sequential(
            nn.Linear(64, 32), nn.ReLU(), nn.Dropout(0.1),
            nn.Linear(32, latent_dim)
        )
        # Decoders reconstruct each model's activations from the shared latent.
        self.mono_decoder = nn.Sequential(
            nn.Linear(latent_dim, 32), nn.ReLU(),
            nn.Linear(32, mono_dim)
        )
        self.comp_decoder = nn.Sequential(
            nn.Linear(latent_dim, 32), nn.ReLU(),
            nn.Linear(32, comp_dim)
        )
        # Small head producing one scalar per sample from the latent.
        self.carry_head = nn.Sequential(
            nn.Linear(latent_dim, 8), nn.ReLU(),
            nn.Linear(8, 1)
        )
        self.latent_dim = latent_dim
    
    def encode(self, mono_act, comp_act):
        """Map a pair of activation tensors to the shared latent.

        Args:
            mono_act: activations whose last dimension is ``mono_dim``.
            comp_act: activations whose last dimension is ``comp_dim``. Its
                leading dimensions must match ``mono_act`` (they are
                concatenated along the last axis).

        Returns:
            Tensor whose last dimension is ``latent_dim``.
        """
        mono_enc = self.mono_encoder(mono_act)
        comp_enc = self.comp_encoder(comp_act)
        joint = torch.cat([mono_enc, comp_enc], dim=-1)
        return self.shared_encoder(joint)
    
    def forward(self, mono_act, comp_act):
        """Encode both activations, then decode and predict from the latent.

        Returns:
            dict with keys:
            - 'latent': the shared latent.
            - 'mono_recon': reconstruction of ``mono_act``.
            - 'comp_recon': reconstruction of ``comp_act``.
            - 'carry_pred': scalar prediction per sample, with a trailing
              dimension of size 1.
        """
        latent = self.encode(mono_act, comp_act)
        return {
            'latent': latent,
            'mono_recon': self.mono_decoder(latent),
            'comp_recon': self.comp_decoder(latent),
            'carry_pred': self.carry_head(latent)
        }

print("✅ Model architectures defined")

### Visualize Architecture Comparison

In [None]:
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 6))

# Small drawing helpers shared by the three diagram panels below.
def draw_box(ax, x, y, w, h, text, color='lightblue', fontsize=9):
    """Draw a rounded box with lower-left corner (x, y) and a bold centred label."""
    box = mpatches.FancyBboxPatch(
        (x, y), w, h,
        boxstyle="round,pad=0.02",
        facecolor=color,
        edgecolor='black',
        linewidth=1.5,
    )
    ax.add_patch(box)
    label_x = x + w / 2
    label_y = y + h / 2
    ax.text(label_x, label_y, text,
            ha='center', va='center', fontsize=fontsize, fontweight='bold')

def draw_arrow(ax, x1, y1, x2, y2):
    """Draw a grey arrow from (x1, y1) to (x2, y2) in data coordinates."""
    arrow_style = {'arrowstyle': '->', 'color': 'gray', 'lw': 1.5}
    ax.annotate('', xy=(x2, y2), xytext=(x1, y1), arrowprops=arrow_style)

# 1. Monolithic MLP
ax1 = axes[0]
ax1.set_xlim(0, 10)
ax1.set_ylim(0, 10)
ax1.axis('off')
ax1.set_title('Monolithic MLP', fontsize=14, fontweight='bold')

draw_box(ax1, 3, 8, 4, 1, 'Input (8 bits)', 'lightgreen')
draw_arrow(ax1, 5, 8, 5, 7)
draw_box(ax1, 2.5, 5.5, 5, 1.2, 'Hidden (64 dim)', 'lightblue')
draw_arrow(ax1, 5, 5.5, 5, 4.5)
draw_box(ax1, 2.5, 3, 5, 1.2, 'Hidden (64 dim)', 'lightblue')
draw_arrow(ax1, 5, 3, 5, 2)
draw_box(ax1, 3, 0.5, 4, 1, 'Output (5 bits)', 'lightyellow')
ax1.text(5, -0.5, 'All bits processed\ntogether', ha='center', fontsize=10, style='italic')

# 2. Compositional Network
ax2 = axes[1]
ax2.set_xlim(0, 10)
ax2.set_ylim(0, 10)
ax2.axis('off')
ax2.set_title('Compositional Network', fontsize=14, fontweight='bold')

# Input bits
for i in range(4):
    draw_box(ax2, 0.5 + i*2.3, 8, 2, 0.8, f'Bit {i}', 'lightgreen', fontsize=8)

# Bit modules
for i in range(4):
    draw_arrow(ax2, 1.5 + i*2.3, 8, 1.5 + i*2.3, 7)
    draw_box(ax2, 0.5 + i*2.3, 5.5, 2, 1.2, f'Module\n{i}', 'lightcoral', fontsize=8)

# Carry arrows between neighbouring modules (red, to stand out from data flow)
for i in range(3):
    ax2.annotate('', xy=(2.7 + i*2.3, 6.1), xytext=(2.5 + i*2.3, 6.1),
                arrowprops=dict(arrowstyle='->', color='red', lw=2))
ax2.text(5, 6.9, 'carry propagation →', ha='center', fontsize=9, color='red')

# Concatenate
for i in range(4):
    draw_arrow(ax2, 1.5 + i*2.3, 5.5, 5, 4)
draw_box(ax2, 2.5, 2.5, 5, 1.2, 'Concat (64 dim)', 'lightblue')
draw_arrow(ax2, 5, 2.5, 5, 1.5)
draw_box(ax2, 3, 0.3, 4, 0.9, 'Output (5 bits)', 'lightyellow')
ax2.text(5, -0.5, 'Each bit position\nprocessed separately', ha='center', fontsize=10, style='italic')

# 3. Delta Observer
ax3 = axes[2]
ax3.set_xlim(0, 10)
ax3.set_ylim(0, 10)
ax3.axis('off')
ax3.set_title('Delta Observer', fontsize=14, fontweight='bold')

draw_box(ax3, 0.5, 8, 3.5, 1, 'Mono Hidden', 'lightblue')
draw_box(ax3, 6, 8, 3.5, 1, 'Comp Hidden', 'lightcoral')
draw_arrow(ax3, 2.25, 8, 3.5, 6.5)
draw_arrow(ax3, 7.75, 8, 6.5, 6.5)
draw_box(ax3, 2, 5.5, 3, 0.8, 'Encoder', 'plum', fontsize=9)
draw_box(ax3, 5, 5.5, 3, 0.8, 'Encoder', 'plum', fontsize=9)
draw_arrow(ax3, 3.5, 5.5, 5, 4.5)
draw_arrow(ax3, 6.5, 5.5, 5, 4.5)
draw_box(ax3, 3, 3.5, 4, 0.9, 'Shared Latent\n(16 dim)', 'gold', fontsize=9)
draw_arrow(ax3, 5, 3.5, 3, 2)
draw_arrow(ax3, 5, 3.5, 7, 2)
draw_arrow(ax3, 5, 3.5, 5, 1.5)
draw_box(ax3, 1.5, 1, 2.5, 0.7, 'Mono\nRecon', 'lightblue', fontsize=8)
draw_box(ax3, 6, 1, 2.5, 0.7, 'Comp\nRecon', 'lightcoral', fontsize=8)
draw_box(ax3, 3.75, 0.3, 2.5, 0.7, 'Carry\nPred', 'lightgreen', fontsize=8)
ax3.text(5, -0.5, 'Maps between\nrepresentation spaces', ha='center', fontsize=10, style='italic')

plt.tight_layout()
plt.savefig('../figures/architecture_comparison.png', dpi=150, bbox_inches='tight')
plt.show()
print("✅ Architecture visualization saved")

## Step 3: Online Training (All Models Concurrently)

**This is the key innovation.** The Delta Observer watches training as it happens, capturing temporal dynamics invisible to post-hoc analysis.

In [None]:
# Initialize models. Each constructor draws its initial weights from the
# torch RNG seeded in the configuration cell, so keep this order fixed.
mono_model = MonolithicMLP(hidden_dim=64).to(DEVICE)
comp_model = CompositionalNetwork(module_dim=16).to(DEVICE)
# Size the observer's inputs from the observed models so the dimensions stay
# in sync if either architecture's hidden width changes.
delta_model = DeltaObserver(
    mono_dim=mono_model.hidden_dim,
    comp_dim=comp_model.hidden_dim,
    latent_dim=LATENT_DIM,
).to(DEVICE)

# Optimizers (one per model, so each trains independently)
mono_opt = optim.Adam(mono_model.parameters(), lr=LEARNING_RATE)
comp_opt = optim.Adam(comp_model.parameters(), lr=LEARNING_RATE)
delta_opt = optim.Adam(delta_model.parameters(), lr=LEARNING_RATE)

# Both task models output per-bit sigmoid probabilities, hence BCE.
criterion = nn.BCELoss()

# Trajectory storage: one entry per snapshot epoch in every list.
trajectory = {
    'epochs': [],
    'latents': [],
    'carry_counts': [],
    'mono_acc': [],
    'comp_acc': [],
    'r2': [],
    'silhouette': [],
    'mono_loss': [],
    'comp_loss': [],
    'delta_loss': []
}

print("🧠 Models initialized")
print(f"   Monolithic: {sum(p.numel() for p in mono_model.parameters()):,} params")
print(f"   Compositional: {sum(p.numel() for p in comp_model.parameters()):,} params")
print(f"   Delta Observer: {sum(p.numel() for p in delta_model.parameters()):,} params")

In [None]:
def compute_metrics(latents, carry_counts):
    """Compute R² and silhouette score of ``latents`` with respect to carry count.

    Args:
        latents: array of shape (n_samples, n_features).
        carry_counts: array of shape (n_samples,). It serves both as the
            regression target and as the cluster labels.

    Returns:
        (r2, sil):
        - r2: in-sample R² of an ordinary least-squares fit from latents to
          carry count.
        - sil: silhouette score with carry counts as labels, or 0.0 when it is
          undefined for these labels.
    """
    reg = LinearRegression()
    reg.fit(latents, carry_counts)
    r2 = r2_score(carry_counts, reg.predict(latents))

    try:
        sil = silhouette_score(latents, carry_counts)
    except ValueError:
        # silhouette_score raises ValueError when the number of distinct labels
        # is outside [2, n_samples - 1]; report "no clustering" in that case.
        # Only ValueError is caught so that interrupts and real bugs surface.
        sil = 0.0

    return r2, sil


def compute_accuracy(model, loader, device=None):
    """Return the percentage of samples whose output bits are all correct.

    Args:
        model: module whose forward returns ``(outputs, hidden)``. The outputs
            are per-bit probabilities, thresholded at 0.5.
        loader: iterable of ``(inputs, targets, _)`` batches.
        device: device to move each batch to. When None, it defaults to the
            device of the model's parameters (CPU if it has none), matching
            the explicit ``device`` argument style of ``snapshot_latents``.

    Returns:
        Exact-match accuracy in percent (0-100). ``model`` is left in eval mode.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets, _ in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs, _ = model(inputs)
            pred_bits = (outputs > 0.5).float()
            # A sample counts only if every one of its output bits matches.
            correct += (pred_bits == targets).all(dim=1).sum().item()
            total += inputs.size(0)
    return 100 * correct / total


def snapshot_latents(mono_model, comp_model, delta_model, loader, device):
    """Encode every sample in ``loader`` into the observer's latent space.

    All three models are switched to eval mode, and the pass runs without
    gradient tracking.

    Returns:
        (latents, carry): numpy arrays concatenated across batches, in loader
        order.
    """
    for net in (mono_model, comp_model, delta_model):
        net.eval()

    latent_chunks, carry_chunks = [], []

    with torch.no_grad():
        for batch_inputs, _targets, batch_carry in loader:
            batch_inputs = batch_inputs.to(device)
            mono_hidden = mono_model(batch_inputs)[1]
            comp_hidden = comp_model(batch_inputs)[1]
            z = delta_model.encode(mono_hidden, comp_hidden)
            latent_chunks.append(z.cpu().numpy())
            carry_chunks.append(batch_carry.numpy())

    return np.concatenate(latent_chunks), np.concatenate(carry_chunks)

In [None]:
print("="*70)
print("🚀 ONLINE TRAINING - All models train concurrently")
print("="*70)
print("\nThe Delta Observer watches training as it happens...\n")

epoch_losses = {'mono': [], 'comp': [], 'delta': []}

for epoch in tqdm(range(EPOCHS), desc="Training"):
    mono_model.train()
    comp_model.train()
    delta_model.train()

    batch_losses = {'mono': [], 'comp': [], 'delta': []}

    for inputs, targets, carry in train_loader:
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)
        carry = carry.to(DEVICE).float()

        # --- Train Monolithic ---
        mono_opt.zero_grad()
        mono_out, mono_h = mono_model(inputs)
        mono_loss = criterion(mono_out, targets)
        mono_loss.backward()
        mono_opt.step()
        batch_losses['mono'].append(mono_loss.item())

        # --- Train Compositional ---
        comp_opt.zero_grad()
        comp_out, comp_h = comp_model(inputs)
        comp_loss = criterion(comp_out, targets)
        comp_loss.backward()
        comp_opt.step()
        batch_losses['comp'].append(comp_loss.item())

        # --- Train Delta Observer ---
        # Re-run both models without gradient tracking, *after* their optimizer
        # steps above. The observer therefore sees the updated representations,
        # and its loss cannot push gradients into the observed models.
        with torch.no_grad():
            _, mono_h_det = mono_model(inputs)
            _, comp_h_det = comp_model(inputs)

        delta_opt.zero_grad()
        delta_out = delta_model(mono_h_det, comp_h_det)

        recon_loss = (nn.functional.mse_loss(delta_out['mono_recon'], mono_h_det) +
                      nn.functional.mse_loss(delta_out['comp_recon'], comp_h_det))
        # squeeze(-1) rather than squeeze(): a batch of size 1 keeps shape (1,)
        # and lines up with `carry` instead of collapsing to a scalar.
        carry_loss = nn.functional.mse_loss(delta_out['carry_pred'].squeeze(-1), carry)
        delta_loss = recon_loss + 0.1 * carry_loss
        delta_loss.backward()
        delta_opt.step()
        batch_losses['delta'].append(delta_loss.item())

    # Store epoch losses (mean over mini-batches)
    epoch_losses['mono'].append(np.mean(batch_losses['mono']))
    epoch_losses['comp'].append(np.mean(batch_losses['comp']))
    epoch_losses['delta'].append(np.mean(batch_losses['delta']))

    # --- Snapshot at intervals (and always at the final epoch) ---
    if epoch % SNAPSHOT_INTERVAL == 0 or epoch == EPOCHS - 1:
        latents, carries = snapshot_latents(mono_model, comp_model, delta_model, full_loader, DEVICE)
        r2, sil = compute_metrics(latents, carries)
        mono_acc = compute_accuracy(mono_model, full_loader)
        comp_acc = compute_accuracy(comp_model, full_loader)

        trajectory['epochs'].append(epoch)
        trajectory['latents'].append(latents.copy())
        trajectory['carry_counts'].append(carries.copy())
        trajectory['r2'].append(r2)
        trajectory['silhouette'].append(sil)
        trajectory['mono_acc'].append(mono_acc)
        trajectory['comp_acc'].append(comp_acc)
        trajectory['mono_loss'].append(epoch_losses['mono'][-1])
        trajectory['comp_loss'].append(epoch_losses['comp'][-1])
        trajectory['delta_loss'].append(epoch_losses['delta'][-1])

        if epoch % 20 == 0:
            print(f"\nEpoch {epoch:3d}: R²={r2:.4f}, Sil={sil:.4f}, Mono={mono_acc:.1f}%, Comp={comp_acc:.1f}%")

print("\n" + "="*70)
print("✅ Online training complete!")
print("="*70)

### Visualize Training Progress

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

epochs_arr = np.array(trajectory['epochs'])
sil_arr = np.array(trajectory['silhouette'])
# One loss entry was recorded per trained epoch.
loss_epochs = np.arange(len(epoch_losses['mono']))

# 1. Model Accuracy
ax1 = axes[0, 0]
ax1.plot(epochs_arr, trajectory['mono_acc'], 'b-', linewidth=2, marker='o', markersize=3, label='Monolithic')
ax1.plot(epochs_arr, trajectory['comp_acc'], 'r-', linewidth=2, marker='s', markersize=3, label='Compositional')
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('Accuracy (%)', fontsize=12)
ax1.set_title('Model Accuracy During Training', fontsize=14, fontweight='bold')
ax1.legend(loc='lower right')
ax1.set_ylim(0, 105)
ax1.axhline(y=100, color='green', linestyle='--', alpha=0.3)
ax1.grid(True, alpha=0.3)

# 2. Training Loss
ax2 = axes[0, 1]
ax2.semilogy(loss_epochs, epoch_losses['mono'], 'b-', alpha=0.7, label='Monolithic')
ax2.semilogy(loss_epochs, epoch_losses['comp'], 'r-', alpha=0.7, label='Compositional')
ax2.semilogy(loss_epochs, epoch_losses['delta'], 'g-', alpha=0.7, label='Delta Observer')
ax2.set_xlabel('Epoch', fontsize=12)
ax2.set_ylabel('Loss (log scale)', fontsize=12)
ax2.set_title('Training Loss', fontsize=14, fontweight='bold')
ax2.legend(loc='upper right')
ax2.grid(True, alpha=0.3)

# 3. R² Evolution
ax3 = axes[1, 0]
ax3.plot(epochs_arr, trajectory['r2'], 'g-', linewidth=2.5, marker='o', markersize=4)
ax3.fill_between(epochs_arr, 0, trajectory['r2'], alpha=0.2, color='green')
ax3.set_xlabel('Epoch', fontsize=12)
ax3.set_ylabel('R² (Linear Accessibility)', fontsize=12)
ax3.set_title('Semantic Accessibility During Training', fontsize=14, fontweight='bold')
ax3.set_ylim(0, 1.05)
ax3.axhline(y=0.9, color='green', linestyle='--', alpha=0.5, label='90% threshold')
ax3.grid(True, alpha=0.3)
ax3.legend()

# 4. Silhouette Evolution (Transient Clustering)
ax4 = axes[1, 1]
ax4.plot(epochs_arr, sil_arr, 'r-', linewidth=2.5, marker='s', markersize=4)
ax4.fill_between(epochs_arr, 0, sil_arr, where=sil_arr > 0,
                 alpha=0.3, color='red', label='Positive clustering')
ax4.fill_between(epochs_arr, 0, sil_arr, where=sil_arr <= 0,
                 alpha=0.3, color='blue', label='No clustering')
peak_idx = int(np.argmax(sil_arr))
ax4.annotate(f'Peak: {sil_arr[peak_idx]:.2f}',
             xy=(epochs_arr[peak_idx], sil_arr[peak_idx]),
             xytext=(epochs_arr[peak_idx]+30, sil_arr[peak_idx]+0.05),
             arrowprops=dict(arrowstyle='->', color='red'),
             fontsize=10, color='red')
ax4.set_xlabel('Epoch', fontsize=12)
ax4.set_ylabel('Silhouette Score', fontsize=12)
ax4.set_title('Geometric Clustering (Transient!)', fontsize=14, fontweight='bold')
ax4.axhline(y=0, color='gray', linestyle='-', alpha=0.5)
ax4.grid(True, alpha=0.3)
ax4.legend(loc='upper right')

plt.tight_layout()
plt.savefig('../figures/training_progress.png', dpi=150, bbox_inches='tight')
plt.show()
print("✅ Training progress visualization saved")

## Step 4: Discover Transient Clustering

**The key finding:** Clustering peaks during training then dissolves.

In [None]:
epochs = np.array(trajectory['epochs'])
r2_values = np.array(trajectory['r2'])
sil_values = np.array(trajectory['silhouette'])

# Find peak clustering
peak_idx = np.argmax(sil_values)
peak_epoch = epochs[peak_idx]
peak_sil = sil_values[peak_idx]

print("="*70)
print("🎯 TRANSIENT CLUSTERING DISCOVERY")
print("="*70)
print(f"\n📈 Peak clustering: Silhouette = {peak_sil:.4f} at epoch {peak_epoch}")
print(f"📉 Final state:     Silhouette = {sil_values[-1]:.4f} at epoch {epochs[-1]}")
print(f"\n✨ Final R² (accessibility): {r2_values[-1]:.4f}")
# The interpretation below assumes clustering peaked before the end of
# training; flag runs where that did not happen instead of silently asserting it.
if peak_idx == len(sil_values) - 1:
    print("\nNOTE: peak clustering occurs at the final snapshot in this run - "
          "the transient pattern was not reproduced.")
print("\n" + "-"*70)
print("💡 INTERPRETATION")
print("-"*70)
print("\n  • Clustering EMERGES during learning (scaffolding)")
print("  • Clustering DISSOLVES after convergence (scaffolding removed)")
print("\n  → The semantic primitive is in the TRAJECTORY, not the final state.")
print("="*70)

### Visualize Latent Space Evolution (Before → During → After Clustering)

In [None]:
# Select three key epochs: early, peak clustering, final
early_idx = 0
peak_idx = np.argmax(sil_values)
final_idx = -1

key_epochs = [
    (early_idx, 'Early (Random)', epochs[early_idx]),
    (peak_idx, 'Peak Clustering', epochs[peak_idx]),
    (final_idx, 'Final (Converged)', epochs[final_idx])
]

fig, axes = plt.subplots(1, 3, figsize=(16, 5))

for ax, (idx, title, ep) in zip(axes, key_epochs):
    latents = trajectory['latents'][idx]
    carries = trajectory['carry_counts'][idx]

    # Use PCA for consistent comparison (deterministic, unlike UMAP)
    pca = PCA(n_components=2, random_state=RANDOM_SEED)
    latents_2d = pca.fit_transform(latents)

    r2, sil = compute_metrics(latents, carries)

    # Fixed color limits so every panel shares the colorbar's 0-4 scale.
    scatter = ax.scatter(latents_2d[:, 0], latents_2d[:, 1],
                        c=carries, cmap='viridis', vmin=0, vmax=4,
                        s=30, alpha=0.7, edgecolors='white', linewidth=0.3)

    ax.set_xlabel('PC1', fontsize=11)
    ax.set_ylabel('PC2', fontsize=11)
    ax.set_title(f'{title}\nEpoch {ep}', fontsize=13, fontweight='bold')

    # Metrics box
    textstr = f'R² = {r2:.3f}\nSil = {sil:.3f}'
    props = dict(boxstyle='round', facecolor='white', alpha=0.8)
    ax.text(0.02, 0.98, textstr, transform=ax.transAxes, fontsize=10,
            verticalalignment='top', bbox=props)

# Add colorbar to last plot
cbar = plt.colorbar(scatter, ax=axes[-1], label='Carry Count')
cbar.set_ticks([0, 1, 2, 3, 4])

plt.suptitle('Latent Space Evolution: Clustering is Transient', fontsize=15, fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig('../figures/latent_evolution.png', dpi=150, bbox_inches='tight')
plt.show()
print("✅ Latent evolution visualization saved")

### The Key Figure: Transient Clustering

In [None]:
# Plot: Transient Clustering
fig, ax1 = plt.subplots(figsize=(12, 6))

color1 = '#2ecc71'  # Green for R²
color2 = '#e74c3c'  # Red for Silhouette

# Phase boundaries (in epochs) for the shaded background bands.
init_end, learn_end = 10, 50

ax1.set_xlabel('Training Epoch', fontsize=12)
ax1.set_ylabel('R² (Linear Accessibility)', color=color1, fontsize=12)
line1, = ax1.plot(epochs, r2_values, color=color1, linewidth=2.5, marker='o', markersize=4, label='R²')
ax1.tick_params(axis='y', labelcolor=color1)
ax1.set_ylim(0, 1.05)
ax1.axhline(y=0.9, color=color1, linestyle='--', alpha=0.3)

ax2 = ax1.twinx()
ax2.set_ylabel('Silhouette Score (Clustering)', color=color2, fontsize=12)
line2, = ax2.plot(epochs, sil_values, color=color2, linewidth=2.5, marker='s', markersize=4, label='Silhouette')
ax2.tick_params(axis='y', labelcolor=color2)
ax2.set_ylim(-0.1, 0.5)
ax2.axhline(y=0, color=color2, linestyle='--', alpha=0.3)

# Highlight phases
ax1.axvspan(0, init_end, alpha=0.1, color='blue', label='Init')
ax1.axvspan(init_end, learn_end, alpha=0.1, color='green', label='Learning')
ax1.axvspan(learn_end, EPOCHS, alpha=0.1, color='orange', label='Converged')

# Annotate peak
ax2.annotate(f'Peak: {peak_sil:.2f}\n(epoch {peak_epoch})',
             xy=(peak_epoch, peak_sil),
             xytext=(peak_epoch + 30, peak_sil + 0.08),
             fontsize=10,
             arrowprops=dict(arrowstyle='->', color=color2, alpha=0.7),
             color=color2)

ax1.set_title('Transient Clustering: Scaffolding Emerges Then Dissolves', fontsize=14, fontweight='bold')
ax1.legend([line1, line2], ['R² (Accessibility)', 'Silhouette (Clustering)'], loc='center right')

# Phase labels, centred in each band so they follow EPOCHS instead of
# assuming a fixed 200-epoch run.
ax1.text(init_end / 2, 0.15, 'Init', ha='center', fontsize=10, color='blue', alpha=0.7)
ax1.text((init_end + learn_end) / 2, 0.15, 'Learning', ha='center', fontsize=10, color='green', alpha=0.7)
ax1.text((learn_end + EPOCHS) / 2, 0.15, 'Converged', ha='center', fontsize=10, color='orange', alpha=0.7)

plt.tight_layout()
plt.savefig('../figures/figure5_training_curves.png', dpi=150, bbox_inches='tight')
plt.show()
print("✅ Figure 5 (Training Curves) saved")

## Step 5: Visualize Final Latent Space

In [None]:
# Get final latent space
final_latents = trajectory['latents'][-1]
final_carry = trajectory['carry_counts'][-1]

# Dimensionality reduction
if HAS_UMAP:
    reducer = UMAP(n_components=2, random_state=RANDOM_SEED, n_neighbors=15, min_dist=0.1)
    latents_2d = reducer.fit_transform(final_latents)
    method = "UMAP"
else:
    reducer = PCA(n_components=2, random_state=RANDOM_SEED)
    latents_2d = reducer.fit_transform(final_latents)
    method = "PCA"

# Final metrics
final_r2, final_sil = compute_metrics(final_latents, final_carry)

# Plot
fig, ax = plt.subplots(figsize=(10, 8))
scatter = ax.scatter(latents_2d[:, 0], latents_2d[:, 1],
                     c=final_carry, cmap='viridis',
                     s=50, alpha=0.7, edgecolors='white', linewidth=0.5)

cbar = plt.colorbar(scatter, ax=ax, label='Carry Count')
cbar.set_ticks([0, 1, 2, 3, 4])

ax.set_xlabel(f'{method} Dimension 1', fontsize=12)
ax.set_ylabel(f'{method} Dimension 2', fontsize=12)
ax.set_title('Online Delta Observer Latent Space\n(Final State)', fontsize=14, fontweight='bold')

# Add metrics
textstr = f'R¬≤ = {final_r2:.4f}\nSilhouette = {final_sil:.4f}'
props = dict(boxstyle='round', facecolor='white', alpha=0.8)
ax.text(0.02, 0.98, textstr, transform=ax.transAxes, fontsize=11,
        verticalalignment='top', bbox=props)

plt.tight_layout()
plt.savefig('../figures/figure2_delta_latent_space.png', dpi=150, bbox_inches='tight')
plt.show()
print("‚úÖ Figure 2 (Latent Space) saved")

### Latent Dimension Analysis

In [None]:
# Analyze which latent dimensions correlate with carry count
correlations = []
for i in range(LATENT_DIM):
    corr = np.corrcoef(final_latents[:, i], final_carry)[0, 1]
    correlations.append(corr)

correlations = np.array(correlations)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# 1. Correlation bar chart
ax1 = axes[0]
colors = ['#e74c3c' if c < 0 else '#2ecc71' for c in correlations]
ax1.bar(range(LATENT_DIM), correlations, color=colors, edgecolor='black', linewidth=0.5)
ax1.axhline(y=0, color='black', linewidth=0.5)
ax1.set_xlabel('Latent Dimension', fontsize=12)
ax1.set_ylabel('Correlation with Carry Count', fontsize=12)
ax1.set_title('Latent Dimension Correlations', fontsize=14, fontweight='bold')
ax1.set_xticks(range(LATENT_DIM))

# 2. Top dimensions scatter plots
ax2 = axes[1]
top_dims = np.argsort(np.abs(correlations))[-2:]  # Top 2 correlated dims
scatter = ax2.scatter(final_latents[:, top_dims[0]], final_latents[:, top_dims[1]],
                      c=final_carry, cmap='viridis', s=30, alpha=0.7)
ax2.set_xlabel(f'Dim {top_dims[0]} (corr={correlations[top_dims[0]]:.2f})', fontsize=12)
ax2.set_ylabel(f'Dim {top_dims[1]} (corr={correlations[top_dims[1]]:.2f})', fontsize=12)
ax2.set_title('Top Correlated Dimensions', fontsize=14, fontweight='bold')
plt.colorbar(scatter, ax=ax2, label='Carry Count')

plt.tight_layout()
plt.savefig('../figures/latent_dimensions.png', dpi=150, bbox_inches='tight')
plt.show()
print("‚úÖ Latent dimension analysis saved")

## Step 6: Compare Methods

In [None]:
# Get activations for PCA baseline comparison
mono_model.eval()
comp_model.eval()

with torch.no_grad():
    all_inputs = torch.tensor(X, dtype=torch.float32).to(DEVICE)
    _, mono_h = mono_model(all_inputs)
    _, comp_h = comp_model(all_inputs)
    mono_act = mono_h.cpu().numpy()
    comp_act = comp_h.cpu().numpy()

# PCA baseline
combined = np.concatenate([mono_act, comp_act], axis=1)
pca = PCA(n_components=LATENT_DIM, random_state=RANDOM_SEED)
pca_latents = pca.fit_transform(combined)
pca_r2, pca_sil = compute_metrics(pca_latents, carry_counts)

print("="*70)
print("üìä METHOD COMPARISON")
print("="*70)
print(f"\n{'Method':<25} {'R¬≤':>10} {'Silhouette':>12} {'Œî vs PCA':>10}")
print("-"*60)
print(f"{'Online Observer':<25} {final_r2:>10.4f} {final_sil:>12.4f} {(final_r2-pca_r2)*100:>+9.1f}%")
print(f"{'PCA Baseline':<25} {pca_r2:>10.4f} {pca_sil:>12.4f} {'---':>10}")
print("="*70)

In [None]:
# Comprehensive comparison visualization: R² bars, silhouette bars, and an
# accessibility-vs-clustering scatter for the Online Observer vs. PCA baseline.
fig = plt.figure(figsize=(16, 5))
gs = GridSpec(1, 3, width_ratios=[1, 1, 1.2])

# Axis limits keep the original fixed ranges but widen them when a value would
# otherwise fall outside the visible area (e.g. a weak PCA R² or a strongly
# positive silhouette), so no bar or point is silently clipped.
r2_floor = min(0.85, min(final_r2, pca_r2) - 0.05)
sil_low = min(final_sil, pca_sil)
sil_high = max(final_sil, pca_sil)

# 1. Bar chart comparison
ax1 = fig.add_subplot(gs[0])
methods = ['Online\nObserver', 'PCA\nBaseline']
r2_vals = [final_r2, pca_r2]
colors = ['#3498db', '#95a5a6']

bars = ax1.bar(methods, r2_vals, color=colors, edgecolor='black', linewidth=1.5)
for bar, val in zip(bars, r2_vals):
    ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.005,
            f'{val:.4f}', ha='center', va='bottom', fontsize=12, fontweight='bold')
ax1.set_ylabel('R² (Linear Accessibility)', fontsize=12)
ax1.set_title('R² Comparison', fontsize=14, fontweight='bold')
ax1.set_ylim(r2_floor, 1.01)
# R² difference in percentage points. The sign comes from the format spec so a
# regression renders as "-x.x%" instead of "+-x.x%", and is coloured red.
delta_pct = (final_r2 - pca_r2) * 100
ax1.annotate(f'{delta_pct:+.1f}%', xy=(0.5, 0.96), fontsize=14, fontweight='bold',
             color='green' if delta_pct >= 0 else 'red', ha='center')

# 2. Silhouette comparison
ax2 = fig.add_subplot(gs[1])
sil_vals = [final_sil, pca_sil]
bars = ax2.bar(methods, sil_vals, color=colors, edgecolor='black', linewidth=1.5)
for bar, val in zip(bars, sil_vals):
    # Negative bars have negative height; keep their labels just above zero.
    ax2.text(bar.get_x() + bar.get_width()/2, max(0.01, bar.get_height() + 0.01),
            f'{val:.4f}', ha='center', va='bottom', fontsize=12, fontweight='bold')
ax2.set_ylabel('Silhouette Score', fontsize=12)
ax2.set_title('Clustering Comparison', fontsize=14, fontweight='bold')
ax2.axhline(y=0, color='gray', linestyle='--')
ax2.set_ylim(min(-0.1, sil_low - 0.05), max(0.15, sil_high + 0.05))

# 3. Scatter plot comparison
ax3 = fig.add_subplot(gs[2])
ax3.scatter(pca_sil, pca_r2, s=200, c='#95a5a6', edgecolors='black', linewidth=2, 
            label='PCA Baseline', zorder=5)
ax3.scatter(final_sil, final_r2, s=200, c='#3498db', edgecolors='black', linewidth=2,
            label='Online Observer', zorder=5)
ax3.set_xlabel('Silhouette (Clustering)', fontsize=12)
ax3.set_ylabel('R² (Accessibility)', fontsize=12)
ax3.set_title('Accessibility vs Clustering', fontsize=14, fontweight='bold')
ax3.axhline(y=0.95, color='green', linestyle='--', alpha=0.5)
ax3.axvline(x=0.1, color='red', linestyle='--', alpha=0.5)
ax3.legend(loc='lower right')
x_min = min(-0.15, sil_low - 0.05)
ax3.set_xlim(x_min, max(0.2, sil_high + 0.05))
ax3.set_ylim(r2_floor, 1.01)

# Highlight "good" region: R² >= 0.95 and silhouette < 0.1, from the left axis edge.
ax3.fill_between([x_min, 0.1], 0.95, 1.01, alpha=0.1, color='green')
ax3.text(-0.02, 0.98, 'High Access\nLow Cluster', fontsize=9, color='green', ha='center')

plt.tight_layout()
plt.savefig('../figures/figure_method_comparison.png', dpi=150, bbox_inches='tight')
plt.show()
print("✅ Method comparison visualization saved")

## Step 7: Save Results

In [None]:
# Save trajectory data.
# np.savez does not create missing parent directories, so ensure ../data exists.
os.makedirs('../data', exist_ok=True)

# Observer latent snapshots collected during training, plus the epoch list
# stored alongside them in `trajectory`.
np.savez('../data/online_observer_trajectory.npz',
         snapshots=np.array(trajectory['latents']),
         epochs=np.array(trajectory['epochs']))

# Save final latents: observer latents, their carry-count labels, and the
# hidden activations of both source models (from the method-comparison cell).
np.savez('../data/online_observer_latents.npz',
         latents=final_latents,
         carry_counts=final_carry,
         mono_activations=mono_act,
         comp_activations=comp_act,
         # NOTE: placeholder -- an all-zeros array shaped like final_carry,
         # not real bit positions. Downstream code should not rely on it.
         bit_positions=np.zeros_like(final_carry))

print("✅ Data saved to ../data/")

## Final Summary

In [None]:
# Create summary visualization.
# Top row: R² and silhouette across training epochs.
# Bottom row: 2-D PCA projections of the observer latents at three snapshots
# (first, peak-silhouette, last), each coloured by carry count and annotated
# with its own R²/silhouette.
fig = plt.figure(figsize=(16, 10))
gs = GridSpec(2, 3, height_ratios=[1, 1])

# 1. Key finding: Transient Clustering
ax1 = fig.add_subplot(gs[0, :])
ax1.plot(epochs, r2_values, 'g-', linewidth=3, marker='o', markersize=5, label='R² (Accessibility)')
ax1.plot(epochs, sil_values, 'r-', linewidth=3, marker='s', markersize=5, label='Silhouette (Clustering)')
ax1.axhline(y=0, color='gray', linestyle='-', alpha=0.3)
# Shade the epochs where silhouette is positive. interpolate=True closes the
# shaded region at the exact zero crossings instead of at the nearest sample.
ax1.fill_between(epochs, 0, sil_values, where=np.array(sil_values) > 0,
                 interpolate=True, alpha=0.2, color='red')
ax1.set_xlabel('Training Epoch', fontsize=12)
ax1.set_ylabel('Metric Value', fontsize=12)
ax1.set_title('KEY FINDING: Clustering is Transient (Scaffolding, Not Structure)', 
              fontsize=16, fontweight='bold')
ax1.legend(loc='center right', fontsize=11)
ax1.set_ylim(-0.1, 1.05)

# Add phase annotations: arrow and label at the peak-silhouette point.
ax1.annotate('', xy=(peak_epoch, peak_sil), xytext=(peak_epoch, peak_sil+0.15),
            arrowprops=dict(arrowstyle='->', color='red', lw=2))
ax1.text(peak_epoch, peak_sil+0.18, f'PEAK\n{peak_sil:.2f}', ha='center', fontsize=10, 
         color='red', fontweight='bold')

# 2-4. Three phase snapshots (first snapshot, peak index, last snapshot).
for i, (idx, title) in enumerate([(0, 'Phase 1: Init'), (peak_idx, 'Phase 2: Scaffolding'), (-1, 'Phase 3: Final')]):
    ax = fig.add_subplot(gs[1, i])
    latents = trajectory['latents'][idx]
    carries = trajectory['carry_counts'][idx]
    # A fresh 2-D PCA per snapshot, for display only; metrics use full latents.
    pca_temp = PCA(n_components=2, random_state=RANDOM_SEED)
    latents_2d = pca_temp.fit_transform(latents)
    r2_temp, sil_temp = compute_metrics(latents, carries)
    
    scatter = ax.scatter(latents_2d[:, 0], latents_2d[:, 1], c=carries, cmap='viridis',
                        s=25, alpha=0.7, edgecolors='white', linewidth=0.3)
    ax.set_title(f'{title}\nEpoch {epochs[idx]}', fontsize=12, fontweight='bold')
    ax.text(0.02, 0.98, f'R²={r2_temp:.2f}\nSil={sil_temp:.2f}', transform=ax.transAxes,
           fontsize=10, verticalalignment='top', 
           bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    ax.set_xticks([])
    ax.set_yticks([])

plt.tight_layout()
plt.savefig('../figures/summary_visualization.png', dpi=150, bbox_inches='tight')
plt.show()
print("✅ Summary visualization saved")

In [None]:
# Final console report summarising models, the transient-clustering finding,
# the baseline comparison, and the artifacts written by this notebook.
print("\n" + "="*70)
print("🎉 REPRODUCTION COMPLETE")
print("="*70)

print("\n📊 MODELS TRAINED (Online, Concurrently)")
print(f"   Monolithic MLP: {trajectory['mono_acc'][-1]:.1f}% accuracy")
print(f"   Compositional Network: {trajectory['comp_acc'][-1]:.1f}% accuracy")

print("\n🎯 KEY DISCOVERY: TRANSIENT CLUSTERING")
print(f"   Peak clustering: Silhouette = {peak_sil:.4f} at epoch {peak_epoch}")
print(f"   Final state:     Silhouette = {final_sil:.4f}")
print(f"   Final R²:        {final_r2:.4f}")

print("\n📈 METHOD COMPARISON")
print(f"   Online Observer: R² = {final_r2:.4f}")
print(f"   PCA Baseline:    R² = {pca_r2:.4f}")
# R² difference in percentage points; the sign comes from the format spec so a
# negative delta prints as "-x.x%" rather than "+-x.x%".
print(f"   Improvement:     {(final_r2-pca_r2)*100:+.1f}%")

# Paths below are listed without the "../" prefix used when writing them --
# presumably relative to the repository root rather than this notebook's directory.
print("\n📁 FILES GENERATED")
print("   Data:")
print("     • data/online_observer_trajectory.npz")
print("     • data/online_observer_latents.npz")
print("   Figures:")
print("     • figures/dataset_visualization.png")
print("     • figures/architecture_comparison.png")
print("     • figures/training_progress.png")
print("     • figures/latent_evolution.png")
print("     • figures/figure5_training_curves.png")
print("     • figures/figure2_delta_latent_space.png")
print("     • figures/latent_dimensions.png")
print("     • figures/figure_method_comparison.png")
print("     • figures/summary_visualization.png")

print("\n" + "="*70)
print("💡 KEY INSIGHT")
print("="*70)
print("\n   CLUSTERING IS SCAFFOLDING, NOT STRUCTURE.")
print("\n   Networks build geometric organization to LEARN,")
print("   then DISCARD it once concepts are encoded in weights.")
print("\n   The semantic primitive is in the TRAJECTORY,")
print("   not the final representation.")
print("\n" + "="*70)
print("\n✅ All results successfully reproduced!")