# Part 1: Training Source Models

This notebook trains two neural networks to solve 4-bit binary addition:
1. **Monolithic MLP** - Single dense network
2. **Compositional Modular Network** - Bit-wise processing modules

Both achieve 100% accuracy but learn fundamentally different internal representations.

---

## Setup

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import sys
sys.path.append('..')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Seed the RNGs so weight initialisation and DataLoader shuffling are
# reproducible across runs; the checkpoints and activations saved by this
# notebook feed the later notebooks. (torch.manual_seed also seeds CUDA, but
# some GPU kernels may still be nondeterministic.)
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Setup paths for Colab/cloud environments
import os

# Make sure every output folder used by later cells exists. Paths are relative
# to the current working directory; exist_ok makes re-running the cell a no-op.
for path in ('../models', '../data', '../figures'):
    os.makedirs(path, exist_ok=True)

print("✓ Directories created")

## Generate 4-bit Addition Dataset

All possible 4-bit + 4-bit additions (16 × 16 = 256 examples), with every bit vector stored least-significant bit first.

In [None]:
def generate_4bit_addition_dataset():
    """Generate all 512 possible 4-bit + 4-bit additions."""
    inputs = []
    outputs = []
    
    for a in range(16):  # 4-bit: 0-15
        for b in range(16):
            # Convert to binary (4 bits each)
            a_bits = [(a >> i) & 1 for i in range(4)]
            b_bits = [(b >> i) & 1 for i in range(4)]
            
            # Concatenate: [a0, a1, a2, a3, b0, b1, b2, b3]
            input_bits = a_bits + b_bits
            
            # Output: 5-bit sum (0-30)
            sum_val = a + b
            output_bits = [(sum_val >> i) & 1 for i in range(5)]
            
            inputs.append(input_bits)
            outputs.append(output_bits)
    
    return np.array(inputs, dtype=np.float32), np.array(outputs, dtype=np.float32)

# Generate dataset
X, y = generate_4bit_addition_dataset()
print(f"Dataset shape: X={X.shape}, y={y.shape}")

# Show one non-trivial example. Row 16*a + b encodes a + b; the first four
# columns are operand a and the next four are operand b (all LSB first).
example_idx = 16 * 3 + 5  # 3 + 5 = 8
print(f"Example (LSB first): {X[example_idx, :4]} + {X[example_idx, 4:]} = {y[example_idx]}")

In [None]:
class AdditionDataset(Dataset):
    """Map-style dataset serving (input, target) pairs as float32 tensors."""

    def __init__(self, X, y):
        # torch.tensor copies its argument, so later edits to X / y do not leak in.
        self.X, self.y = (torch.tensor(arr, dtype=torch.float32) for arr in (X, y))

    def __len__(self):
        # One sample per leading-dimension entry of X.
        return len(self.X)

    def __getitem__(self, idx):
        sample, target = self.X[idx], self.y[idx]
        return sample, target

# Wrap the full truth table and serve it in shuffled mini-batches of 64.
dataset = AdditionDataset(X, y)
train_loader = DataLoader(dataset=dataset, batch_size=64, shuffle=True)

## Model 1: Monolithic MLP

Simple feed-forward network: 8 → 64 → 64 → 5

In [None]:
class MonolithicMLP(nn.Module):
    """Fully connected 8 -> 64 -> 64 -> 5 network for 4-bit addition.

    ``forward`` returns ``(sum_bit_probabilities, hidden)``, where ``hidden`` is
    the post-ReLU output of the second layer (the representation extracted
    for later analysis).
    """

    def __init__(self):
        super().__init__()
        # Attribute names become the state_dict keys of saved checkpoints.
        self.fc1 = nn.Linear(8, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 5)

    def forward(self, x):
        first = self.fc1(x).relu()
        hidden = self.fc2(first).relu()
        probs = self.fc3(hidden).sigmoid()
        return probs, hidden

# Instantiate on the selected device and report the total parameter count.
mono_model = MonolithicMLP().to(device)
print(f"Monolithic parameters: {sum(map(torch.numel, mono_model.parameters()))}")

In [None]:
# Training function
def train_model(model, train_loader, epochs=100, lr=0.001, *, device=None, log_every=20):
    """Train ``model`` with Adam and binary cross-entropy, tracking per-epoch metrics.

    The model must return a ``(probabilities, hidden)`` tuple whose first
    element is already sigmoid-activated, because ``nn.BCELoss`` expects
    values in [0, 1].

    Args:
        model: network to train (updated in place).
        train_loader: iterable of ``(inputs, targets)`` batches.
        epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.
        device: device to move batches to. Defaults to the device of the
            model's first parameter.
        log_every: print a progress line every this many epochs; a falsy
            value disables logging.

    Returns:
        ``(losses, accuracies)``: per-epoch mean loss and exact-match accuracy
        in percent. Both are accumulated on the fly during the epoch, i.e.
        while the weights are still being updated.

    Raises:
        ValueError: if ``train_loader`` yields no samples in an epoch.
    """
    if device is None:
        # Follow the model's own placement instead of a notebook-global device.
        device = next(model.parameters()).device

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.BCELoss()

    losses = []
    accuracies = []

    for epoch in tqdm(range(epochs), desc="Training"):
        model.train()
        loss_sum = 0.0
        correct = 0
        total = 0

        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs, _ = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            batch_size = inputs.size(0)
            # BCELoss averages over the batch; re-weight by batch size so a
            # smaller final batch does not skew the epoch mean.
            loss_sum += loss.item() * batch_size

            # Accuracy: all 5 bits must match
            pred_bits = (outputs > 0.5).float()
            correct += (pred_bits == targets).all(dim=1).sum().item()
            total += batch_size

        if total == 0:
            raise ValueError("train_loader yielded no samples")

        losses.append(loss_sum / total)
        accuracies.append(100 * correct / total)

        if log_every and (epoch + 1) % log_every == 0:
            # tqdm.write keeps the progress bar from being broken by the log line.
            tqdm.write(f"Epoch {epoch+1}: Loss={losses[-1]:.4f}, Acc={accuracies[-1]:.2f}%")

    return losses, accuracies

In [None]:
# Train the monolithic model, then checkpoint its weights for later notebooks.
print("Training Monolithic MLP...")
mono_losses, mono_accs = train_model(mono_model, train_loader=train_loader, epochs=100)

torch.save(mono_model.state_dict(), f='../models/monolithic_4bit.pth')
# Reported value is the running training accuracy of the last epoch.
print(f"\nFinal accuracy: {mono_accs[-1]:.2f}%")

## Model 2: Compositional Modular Network

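Processes each bit position with its own small module, passing a learned carry signal from one position to the next, then maps the concatenated module outputs to the sum bits with a single linear layer.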

In [None]:
class CompositionalNetwork(nn.Module):
    """Network with one small MLP module per bit position, chained by a carry.

    The input layout is ``[a0, a1, a2, a3, b0, b1, b2, b3]`` (least-significant
    bit first, as produced by ``generate_4bit_addition_dataset``). Module ``i``
    receives ``(a_i, b_i, carry)``, where ``carry`` is derived from module
    ``i-1``'s output (loosely ripple-carry style), so positions are processed
    sequentially rather than independently. The four 16-D module outputs are
    concatenated into a 64-D hidden vector. A single linear layer then maps
    that vector to the 5 sum bits.
    """

    def __init__(self):
        super().__init__()
        # 4 bit-processing modules (one per bit position)
        self.bit_modules = nn.ModuleList([
            nn.Sequential(
                nn.Linear(3, 16),  # 2 input bits + 1 carry-in
                nn.ReLU(),
                nn.Linear(16, 16),
                nn.ReLU()
            ) for _ in range(4)
        ])

        # Output layer
        self.output = nn.Linear(64, 5)  # 4 modules × 16D = 64D

    def forward(self, x):
        """Return ``(sum_bit_probabilities [batch, 5], hidden [batch, 64])``.

        ``x`` is expected to be ``[batch, 8]`` with the layout described in the
        class docstring.
        """
        bit_outputs = []
        # Zero carry-in. new_zeros matches x's device and dtype without first
        # allocating on the CPU and copying.
        carry = x.new_zeros(x.size(0), 1)

        for i, module in enumerate(self.bit_modules):
            a_bit = x[:, i:i + 1]
            b_bit = x[:, i + 4:i + 5]

            module_output = module(torch.cat([a_bit, b_bit, carry], dim=1))
            bit_outputs.append(module_output)

            # Learned carry signal taken from the first feature of this module.
            # NOTE(review): module_output is post-ReLU (>= 0), so sigmoid maps it
            # into [0.5, 1); the carry can never represent values below 0.5.
            # Kept unchanged to preserve the architecture's behaviour.
            carry = torch.sigmoid(module_output[:, :1])

        # Concatenate all bit module outputs
        hidden = torch.cat(bit_outputs, dim=1)  # [batch, 64]
        output = torch.sigmoid(self.output(hidden))
        return output, hidden

# Instantiate on the selected device and report the total parameter count.
comp_model = CompositionalNetwork().to(device)
print(f"Compositional parameters: {sum(map(torch.numel, comp_model.parameters()))}")

In [None]:
# Train the compositional model, then checkpoint its weights for later notebooks.
print("Training Compositional Network...")
comp_losses, comp_accs = train_model(comp_model, train_loader=train_loader, epochs=100)

torch.save(comp_model.state_dict(), f='../models/compositional_4bit.pth')
# Reported value is the running training accuracy of the last epoch.
print(f"\nFinal accuracy: {comp_accs[-1]:.2f}%")

## Compare Training Curves

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# (axis, monolithic curve, compositional curve, y-label, title)
panels = [
    (ax1, mono_losses, comp_losses, 'Loss', 'Training Loss'),
    (ax2, mono_accs, comp_accs, 'Accuracy (%)', 'Training Accuracy'),
]
for ax, mono_curve, comp_curve, ylabel, title in panels:
    # Epochs are 1-indexed to match the "Epoch N" lines logged during training.
    ax.plot(range(1, len(mono_curve) + 1), mono_curve, label='Monolithic', linewidth=2)
    ax.plot(range(1, len(comp_curve) + 1), comp_curve, label='Compositional', linewidth=2)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)
ax2.set_ylim([0, 105])

plt.tight_layout()
plt.savefig('../figures/training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

# Report what actually happened instead of asserting 100% unconditionally.
final_mono, final_comp = mono_accs[-1], comp_accs[-1]
print("\n" + "="*50)
if final_mono == 100 and final_comp == 100:
    print("Both models achieve 100% accuracy!")
else:
    print(f"Final training accuracy - Monolithic: {final_mono:.2f}%, "
          f"Compositional: {final_comp:.2f}%")
print("But do they learn the same representations?")
print("="*50)

## Extract and Save Activations

Extract hidden layer activations from both models for Delta Observer training.

In [None]:
# Extract activations: run the whole dataset through both trained models in
# eval mode without autograd and keep each model's hidden representation.
for trained in (mono_model, comp_model):
    trained.eval()

with torch.no_grad():
    X_tensor = torch.tensor(X, dtype=torch.float32).to(device)
    # forward() returns (output, hidden); only the hidden part is kept.
    mono_activations = mono_model(X_tensor)[1].cpu().numpy()
    comp_activations = comp_model(X_tensor)[1].cpu().numpy()

print(f"Monolithic activations: {mono_activations.shape}")
print(f"Compositional activations: {comp_activations.shape}")

# Save each model's activations together with the inputs X
for path, acts in (('../data/monolithic_activations.npz', mono_activations),
                   ('../data/compositional_activations.npz', comp_activations)):
    np.savez(path, activations=acts, inputs=X)

print("\nActivations saved!")

## Next Steps

Continue to **`02_delta_observer_training.ipynb`** to train the Delta Observer that maps between these representations.