# Part 2: Training the Delta Observer

This notebook trains the **Delta Observer** network that learns to map between monolithic and compositional representations.

## Architecture

```
Monolithic (64D) ──→ Encoder ──┐
                                ├──→ Shared Latent (16D) ──→ Decoders
Compositional (16D) ─→ Encoder ──┘
```

The 16D latent space learns the **semantic primitive** that distinguishes these representations.

---

## Setup

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import sys
sys.path.append('..')

# Prefer the GPU when CUDA is available; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## Load Activations and Prepare Dataset

Load activations from both models and compute semantic labels.

In [None]:
# Load activations with smart path detection
import os

# Candidate data directories, in priority order (Colab vs local)
paths = ['../data', 'data', 'delta-observer/data']

# Keep the first candidate that actually contains the monolithic activations.
data_dir = None
for candidate in paths:
    if os.path.exists(os.path.join(candidate, 'monolithic_activations.npz')):
        data_dir = candidate
        break

if data_dir is None:
    raise FileNotFoundError('Activation data not found. Please run notebook 01 first to generate the data files.')

print(f'✓ Loading data from: {data_dir}')

# Both archives are read from the directory selected above.
mono_data = np.load(os.path.join(data_dir, 'monolithic_activations.npz'))
comp_data = np.load(os.path.join(data_dir, 'compositional_activations.npz'))

# Use final layer activations of the monolithic model
mono_activations = mono_data['layer3_post']
# Use bit 3 final layer (16D) of the compositional model
comp_activations = comp_data['bit3_layer2_post']
# The input rows are taken from the monolithic archive
inputs = mono_data['inputs']

print(f'Monolithic activations: {mono_activations.shape}')
print(f'Compositional activations: {comp_activations.shape}')
print(f'Inputs: {inputs.shape}')

In [None]:
# Compute semantic labels
def compute_carry_count(input_bits):
    """Return how many of the four bit positions produce a carry-out.

    Position ``i`` of ``input_bits`` is added to position ``i + 4`` together
    with the carry coming out of position ``i - 1`` (no carry enters
    position 0). A position carries when that sum is at least 2.
    """
    total_carries = 0
    carry_in = 0
    for pos in range(4):
        column_sum = int(input_bits[pos]) + int(input_bits[pos + 4]) + carry_in
        # The carry leaving this position is the carry entering the next one.
        carry_in = int(column_sum >= 2)
        total_carries += carry_in
    return total_carries

def compute_bit_position(input_bits):
    """Return the index (0-3) of the first bit position that produces a carry.

    Positions are scanned in order 0..3, pairing ``input_bits[i]`` with
    ``input_bits[i + 4]``. No carry can be pending before the first carry
    occurs, so the first carrying position is the first one whose two
    operand bits sum to at least 2; no carry-in needs to be tracked.

    Returns 0 when no position carries, so a result of 0 does not by itself
    distinguish "first carry at position 0" from "no carry at all".
    """
    for i in range(4):
        if int(input_bits[i]) + int(input_bits[i + 4]) >= 2:
            return i
    return 0

# Compute labels: one carry count and one first-carry position per input row
carry_counts = np.array([compute_carry_count(inp) for inp in inputs])
bit_positions = np.array([compute_bit_position(inp) for inp in inputs])

print(f"\nCarry count distribution: {np.bincount(carry_counts)}")
print(f"Bit position distribution: {np.bincount(bit_positions)}")

# Save complete dataset.
# The activations may have been found in a directory other than '../data'
# (the path detection above accepts several layouts), and np.savez does not
# create missing directories, so make sure the output directory exists first.
os.makedirs('../data', exist_ok=True)
np.savez('../data/delta_observer_dataset.npz',
         mono_activations=mono_activations,
         comp_activations=comp_activations,
         inputs=inputs,
         carry_counts=carry_counts,
         bit_positions=bit_positions)

print("\nDataset saved!")

## Dataset Class

In [None]:
class DeltaObserverDataset(Dataset):
    """Paired activations and semantic labels loaded from an ``.npz`` archive.

    The archive at ``data_path`` must contain the arrays ``mono_activations``,
    ``comp_activations``, ``carry_counts``, ``bit_positions`` and ``inputs``.
    All arrays are copied into tensors during construction, so the archive is
    only open while ``__init__`` runs.
    """

    def __init__(self, data_path):
        # np.load returns an NpzFile that holds an open file handle; the
        # context manager closes it once every array has been copied
        # (torch.tensor always copies its input).
        with np.load(data_path) as data:
            self.mono_act = torch.tensor(data['mono_activations'], dtype=torch.float32)
            self.comp_act = torch.tensor(data['comp_activations'], dtype=torch.float32)
            self.carry_counts = torch.tensor(data['carry_counts'], dtype=torch.long)
            self.bit_positions = torch.tensor(data['bit_positions'], dtype=torch.long)
            self.inputs = torch.tensor(data['inputs'], dtype=torch.float32)

    def __len__(self):
        # The sample count is taken from the monolithic activations.
        return len(self.mono_act)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict of tensors."""
        return {
            'mono_act': self.mono_act[idx],
            'comp_act': self.comp_act[idx],
            'carry_count': self.carry_counts[idx],
            'bit_position': self.bit_positions[idx],
            'input': self.inputs[idx],
        }

# Build the dataset from the archive written earlier in this notebook
dataset = DeltaObserverDataset('../data/delta_observer_dataset.npz')
print(f"Dataset size: {len(dataset)}")

# 80% of the samples go to training; whatever remains is held out for validation
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Same batch size for both loaders; only the training loader shuffles
train_loader, val_loader = (
    DataLoader(subset, batch_size=32, shuffle=reshuffle)
    for subset, reshuffle in ((train_dataset, True), (val_dataset, False))
)

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")

## Delta Observer Architecture

In [None]:
class DeltaObserver(nn.Module):
    """Joint autoencoder over a monolithic and a compositional activation vector.

    Each input goes through its own encoder; the two encodings are
    concatenated and projected to a shared latent vector. Four heads read
    that latent: two reconstruction decoders, a 4-way classifier
    (``bit_logits``) and a scalar regressor (``carry_pred``).

    Args:
        mono_dim: Width of the monolithic activation vectors.
        comp_dim: Width of the compositional activation vectors.
        latent_dim: Width of the shared latent vector.
        hidden_dim: Width of the hidden layers in the encoders and decoders.
            The default of 32 reproduces the original layer shapes, so
            existing checkpoints still load.
    """

    def __init__(self, mono_dim=64, comp_dim=64, latent_dim=16, hidden_dim=32):
        super().__init__()

        # Dual encoders
        self.mono_encoder = nn.Sequential(
            nn.Linear(mono_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
        )

        self.comp_encoder = nn.Sequential(
            nn.Linear(comp_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
        )

        # Shared latent encoder. Its input is the concatenation of the two
        # encoder outputs, hence 2 * hidden_dim features.
        self.shared_encoder = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, latent_dim),
        )

        # Decoders
        self.mono_decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, mono_dim),
        )

        self.comp_decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, comp_dim),
        )

        # Classifiers
        self.bit_classifier = nn.Sequential(
            nn.Linear(latent_dim, 8),
            nn.ReLU(),
            nn.Linear(8, 4),
        )

        self.carry_regressor = nn.Sequential(
            nn.Linear(latent_dim, 8),
            nn.ReLU(),
            nn.Linear(8, 1),
        )

        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim

    def encode(self, mono_act, comp_act):
        """Map a pair of activation batches to the shared latent vectors."""
        mono_enc = self.mono_encoder(mono_act)
        comp_enc = self.comp_encoder(comp_act)
        # Concatenate along the feature axis before the shared projection.
        joint = torch.cat([mono_enc, comp_enc], dim=-1)
        return self.shared_encoder(joint)

    def forward(self, mono_act, comp_act):
        """Encode the inputs and run every head on the resulting latent.

        Returns:
            A dict with keys ``latent``, ``mono_recon``, ``comp_recon``,
            ``bit_logits`` and ``carry_pred``. ``carry_pred`` keeps a
            trailing feature axis of size 1.
        """
        latent = self.encode(mono_act, comp_act)
        mono_recon = self.mono_decoder(latent)
        comp_recon = self.comp_decoder(latent)
        bit_logits = self.bit_classifier(latent)
        carry_pred = self.carry_regressor(latent)

        return {
            'latent': latent,
            'mono_recon': mono_recon,
            'comp_recon': comp_recon,
            'bit_logits': bit_logits,
            'carry_pred': carry_pred,
        }

# comp_dim=16 for compositional (overrides the class default)
model = DeltaObserver(mono_dim=64, comp_dim=16, latent_dim=16)
model = model.to(device)
print(f"Parameters: {sum(weight.numel() for weight in model.parameters())}")

## Training

In [None]:
epochs = 100
lr = 0.001

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
best_val_loss = float('inf')

# torch.save does not create missing directories, so make sure the checkpoint
# directory exists before the first save.
os.makedirs('../models', exist_ok=True)


def _batch_to_device(batch):
    """Move the tensors needed for one step onto the target device."""
    return (
        batch['mono_act'].to(device),
        batch['comp_act'].to(device),
        batch['bit_position'].to(device),
        # Cast to float so the carry count can be an MSE regression target.
        batch['carry_count'].to(device).float(),
    )


def _observer_loss(outputs, mono_act, comp_act, bit_position, carry_count):
    """Combined objective shared by the training and validation passes."""
    recon_loss = F.mse_loss(outputs['mono_recon'], mono_act) + F.mse_loss(outputs['comp_recon'], comp_act)
    class_loss = F.cross_entropy(outputs['bit_logits'], bit_position)
    # squeeze(-1) removes only the trailing size-1 feature axis. A bare
    # squeeze() would also remove the batch axis for a batch of one, leaving
    # prediction and target with different shapes.
    carry_loss = F.mse_loss(outputs['carry_pred'].squeeze(-1), carry_count)
    # The carry regression is down-weighted relative to the other two terms.
    return recon_loss + class_loss + 0.1 * carry_loss


print("Training Delta Observer...\n")

for epoch in tqdm(range(epochs)):
    # Training
    model.train()
    train_loss = 0
    train_correct = 0
    train_total = 0

    for batch in train_loader:
        mono_act, comp_act, bit_position, carry_count = _batch_to_device(batch)

        optimizer.zero_grad()
        outputs = model(mono_act, comp_act)

        loss = _observer_loss(outputs, mono_act, comp_act, bit_position, carry_count)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        pred = outputs['bit_logits'].argmax(dim=1)
        train_correct += (pred == bit_position).sum().item()
        train_total += bit_position.size(0)

    # Loss is averaged over batches; accuracy is averaged over samples.
    train_loss /= len(train_loader)
    train_acc = train_correct / train_total

    # Validation
    model.eval()
    val_loss = 0
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for batch in val_loader:
            mono_act, comp_act, bit_position, carry_count = _batch_to_device(batch)

            outputs = model(mono_act, comp_act)

            loss = _observer_loss(outputs, mono_act, comp_act, bit_position, carry_count)
            val_loss += loss.item()

            pred = outputs['bit_logits'].argmax(dim=1)
            val_correct += (pred == bit_position).sum().item()
            val_total += bit_position.size(0)

    val_loss /= len(val_loader)
    val_acc = val_correct / val_total

    scheduler.step()

    # Checkpoint whenever the validation loss improves on the best seen so far.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), '../models/delta_observer_best.pt')

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_acc'].append(train_acc)
    history['val_acc'].append(val_acc)

    if (epoch + 1) % 20 == 0:
        print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}, Val Acc={val_acc:.4f}")

print(f"\nBest Val Loss: {best_val_loss:.4f}")

## Extract Latent Space

In [None]:
# Restore the best checkpoint. map_location keeps the load working when the
# checkpoint was written on a different device than the current one
# (e.g. saved on GPU, reloaded in a CPU-only session).
model.load_state_dict(torch.load('../models/delta_observer_best.pt', map_location=device))
model.eval()

# Iterate the full dataset in order (no shuffling) so the rows of every
# saved array line up with each other.
full_loader = DataLoader(dataset, batch_size=64, shuffle=False)
all_latents = []
all_carry = []
all_bits = []
all_inputs = []

with torch.no_grad():
    for batch in full_loader:
        latent = model.encode(batch['mono_act'].to(device), batch['comp_act'].to(device))
        all_latents.append(latent.cpu().numpy())
        all_carry.append(batch['carry_count'].numpy())
        all_bits.append(batch['bit_position'].numpy())
        all_inputs.append(batch['input'].numpy())

# NOTE: carry_counts, bit_positions and inputs are rebound here to the values
# read back through the dataset (which stores inputs as float32).
latent_space = np.concatenate(all_latents)
carry_counts = np.concatenate(all_carry)
bit_positions = np.concatenate(all_bits)
inputs = np.concatenate(all_inputs)

np.savez('../data/delta_latent_umap.npz',
         latent_space=latent_space,
         carry_counts=carry_counts,
         bit_positions=bit_positions,
         inputs=inputs)

print(f"Latent space: {latent_space.shape}")
print("Saved to ../data/delta_latent_umap.npz")

## Next Steps

Continue to **`03_analysis_visualization.ipynb`** for geometric analysis.