# 🧠 Part 1: Training Source Models

This notebook trains two neural networks to solve **4-bit binary addition**:

| Model | Architecture | Description |
|-------|-------------|-------------|
| **Monolithic MLP** | 8 → 64 → 64 → 5 | Single dense network |
| **Compositional Network** | Bit-wise modules | One module per bit position, chained by a learned carry signal |

Both reach **100% accuracy** on the exhaustive 256-example dataset (there is no held-out split) but learn fundamentally different internal representations.

📄 **Paper:** [OSF MetaArXiv](https://doi.org/10.17605/OSF.IO/CNJTP)  
🔗 **Code:** [github.com/EntroMorphic/delta-observer](https://github.com/EntroMorphic/delta-observer)

---

**Note:** For the complete Online Delta Observer pipeline (recommended), use **`99_full_reproduction.ipynb`** which trains all models concurrently. This notebook is for understanding the source models in detail.

---

## 📦 Setup

In [None]:
# Install dependencies if needed (Colab)
import subprocess
import sys

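def install_if_needed(package, import_name=None):
    """Install `package` with pip unless its module already imports."""
    try:
        __import__(import_name or package.replace('-', '_'))
    except ImportError:
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', package])

install_if_needed('torch')
install_if_needed('matplotlib')
install_if_needed('tqdm')  # used for the training progress bar
install_if_needed('scikit-learn', import_name='sklearn')  # used for PCA and R²

print('✅ Dependencies ready!')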

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import os

# Device selection
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'🖥️ Using device: {device}')

# Plotting style
plt.style.use('default')
plt.rcParams['figure.facecolor'] = 'white'
plt.rcParams['axes.facecolor'] = 'white'
plt.rcParams['figure.figsize'] = (12, 6)
plt.rcParams['font.size'] = 11

# Colors
COLORS = {
    'mono': '#e74c3c',    # Red
    'comp': '#3498db',    # Blue
    'accent': '#2ecc71'   # Green
}

In [None]:
# Create necessary directories
os.makedirs('../models', exist_ok=True)
os.makedirs('../data', exist_ok=True)
os.makedirs('../figures', exist_ok=True)
print('✅ Directories created')

---

## 📊 Generate 4-bit Addition Dataset

All possible 4-bit + 4-bit additions: **256 examples** (16 × 16 combinations).

**Key semantic variable:** the number of **carry operations** required (0 to 4).
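
As a quick sanity check on the dataset size and the carry variable, the sketch below enumerates every operand pair and tallies the carry counts. It restates the carry logic under a hypothetical name (`count_carries_ref`) so the block runs on its own, and cross-checks it against the base-2 popcount identity `carries(a, b) = popcount(a) + popcount(b) - popcount(a + b)`. Both helper names are illustrative and not part of the notebook.

```python
def count_carries_ref(a: int, b: int, width: int = 4) -> int:
    """Count carries produced when adding two `width`-bit integers."""
    carries, carry = 0, 0
    for i in range(width):
        total = ((a >> i) & 1) + ((b >> i) & 1) + carry
        carry = 1 if total >= 2 else 0
        carries += carry
    return carries


def popcount(n: int) -> int:
    """Number of 1 bits in n."""
    return bin(n).count("1")


# Every (a, b) pair of 4-bit operands.
pairs = [(a, b) for a in range(16) for b in range(16)]

# Histogram of carry counts: index k holds the number of pairs with k carries.
histogram = [0] * 5
for a, b in pairs:
    histogram[count_carries_ref(a, b)] += 1

# Cross-check against the popcount identity for base-2 addition.
identity_holds = all(
    count_carries_ref(a, b) == popcount(a) + popcount(b) - popcount(a + b)
    for a, b in pairs
)

print(len(pairs), histogram, identity_holds)
```

The histogram is far from uniform (81 pairs need no carry at all, while only 27 need all four), which is worth keeping in mind when reading any regression score computed against carry count.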

In [None]:
def count_carries(a, b):
    """Count the number of carry operations in binary addition."""
    carries = 0
    carry = 0
    for i in range(4):
        a_bit = (a >> i) & 1
        b_bit = (b >> i) & 1
        total = a_bit + b_bit + carry
        if total >= 2:
            carries += 1
            carry = 1
        else:
            carry = 0
    return carries

def generate_4bit_addition_dataset():
    """Generate all 256 possible 4-bit + 4-bit additions."""
    inputs = []
    outputs = []
    carry_counts = []
    
    for a in range(16):  # 4-bit: 0-15
        for b in range(16):
            # Convert to binary (4 bits each)
            a_bits = [(a >> i) & 1 for i in range(4)]
            b_bits = [(b >> i) & 1 for i in range(4)]
            
            # Concatenate: [a0, a1, a2, a3, b0, b1, b2, b3]
            input_bits = a_bits + b_bits
            
            # Output: 5-bit sum (0-30)
            sum_val = a + b
            output_bits = [(sum_val >> i) & 1 for i in range(5)]
            
            inputs.append(input_bits)
            outputs.append(output_bits)
            carry_counts.append(count_carries(a, b))
    
    return (np.array(inputs, dtype=np.float32), 
            np.array(outputs, dtype=np.float32),
            np.array(carry_counts))

# Generate dataset
X, y, carry_counts = generate_4bit_addition_dataset()
print(f'📊 Dataset shape: X={X.shape}, y={y.shape}')
print(f'📊 Carry count distribution: {np.bincount(carry_counts)}')

In [None]:
# 🎨 Visualize the dataset
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# 1. Carry count distribution
ax1 = axes[0]
counts = np.bincount(carry_counts)
bars = ax1.bar(range(len(counts)), counts, color=plt.cm.viridis(np.linspace(0, 1, len(counts))),
               edgecolor='black', linewidth=1.5)
ax1.set_xlabel('Carry Count', fontsize=11)
ax1.set_ylabel('Number of Examples', fontsize=11)
ax1.set_title('📊 Carry Count Distribution', fontsize=12, fontweight='bold')
ax1.set_xticks(range(5))
for bar, count in zip(bars, counts):
    ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 2,
             str(count), ha='center', fontsize=10, fontweight='bold')

# 2. Addition heatmap
ax2 = axes[1]
carry_matrix = np.zeros((16, 16))
for a in range(16):
    for b in range(16):
        carry_matrix[a, b] = count_carries(a, b)
im = ax2.imshow(carry_matrix, cmap='viridis', aspect='equal')
ax2.set_xlabel('Second Operand (b)', fontsize=11)
ax2.set_ylabel('First Operand (a)', fontsize=11)
ax2.set_title('🎯 Carries by Operand Pair', fontsize=12, fontweight='bold')
plt.colorbar(im, ax=ax2, label='Carry Count')

# 3. Example additions
ax3 = axes[2]
ax3.axis('off')
examples = [
    (1, 1, '0001 + 0001 = 00010', 1),   # bit 0: 1 + 1 produces one carry
    (7, 1, '0111 + 0001 = 01000', 3),
    (15, 15, '1111 + 1111 = 11110', 4),
    (8, 4, '1000 + 0100 = 01100', 0),
]
text = '🔢 Example Additions:\n\n'
for a, b, binary, carries in examples:
    text += f'{a} + {b} = {a+b}\n'
    text += f'  {binary}\n'
    text += f'  Carries: {carries}\n\n'
ax3.text(0.1, 0.9, text, transform=ax3.transAxes, fontsize=10, 
         verticalalignment='top', fontfamily='monospace',
         bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))
ax3.set_title('📝 Examples', fontsize=12, fontweight='bold')

plt.tight_layout()
plt.savefig('../figures/dataset_visualization.png', dpi=150, bbox_inches='tight')
plt.show()

print('\n💡 The carry count is the key semantic variable we study!')

In [None]:
class AdditionDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)
    
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

# Create dataloader
dataset = AdditionDataset(X, y)
train_loader = DataLoader(dataset, batch_size=64, shuffle=True)
print('✅ DataLoader created')

---

## 🧱 Model 1: Monolithic MLP

A simple feed-forward network that processes all input bits together.

```
Input (8) → Dense (64) → ReLU → Dense (64) → ReLU → Dense (5) → Sigmoid
```

In [None]:
class MonolithicMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 5)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        hidden = torch.relu(self.fc2(x))  # Extract this layer
        x = torch.sigmoid(self.fc3(hidden))
        return x, hidden

mono_model = MonolithicMLP().to(device)
print(f'🧱 Monolithic MLP parameters: {sum(p.numel() for p in mono_model.parameters()):,}')

In [None]:
# 🎨 Visualize architecture
fig, ax = plt.subplots(figsize=(12, 6))
ax.axis('off')

# Draw layers
layer_sizes = [8, 64, 64, 5]
layer_names = ['Input\n(8 bits)', 'Hidden 1\n(64)', 'Hidden 2\n(64)', 'Output\n(5 bits)']
layer_colors = ['#3498db', '#e74c3c', '#e74c3c', '#2ecc71']
x_positions = [0.1, 0.35, 0.6, 0.85]

for i, (x, size, name, color) in enumerate(zip(x_positions, layer_sizes, layer_names, layer_colors)):
    # Draw rectangle
    height = size / 100
    rect = plt.Rectangle((x - 0.08, 0.5 - height/2), 0.16, height,
                          facecolor=color, edgecolor='black', linewidth=2, alpha=0.7)
    ax.add_patch(rect)
    ax.text(x, 0.5 - height/2 - 0.08, name, ha='center', fontsize=11, fontweight='bold')
    
    # Draw arrows; the final transition ends in a sigmoid, not a ReLU
    if i < len(x_positions) - 1:
        ax.annotate('', xy=(x_positions[i+1] - 0.08, 0.5), xytext=(x + 0.08, 0.5),
                    arrowprops=dict(arrowstyle='->', color='black', lw=2))
        activation = 'Sigmoid' if i == len(x_positions) - 2 else 'ReLU'
        ax.text((x + x_positions[i+1]) / 2, 0.55, f'Dense\n+{activation}', ha='center', fontsize=9)

ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_title('🧱 Monolithic MLP Architecture', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.show()

---

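## 🔧 Model 2: Compositional Modular Network

Processes each bit position with its own small module. The modules run in sequence from the least significant bit, each passing a learned carry signal to the next, and their hidden outputs are concatenated for the final output layer.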

```
Bit 0: (a0, b0, carry_in) → Module 0 → (out0, carry_out)
Bit 1: (a1, b1, carry_in) → Module 1 → (out1, carry_out)
...and so on...
```
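
The diagram above mirrors a classic ripple-carry adder. For reference, here is a minimal hand-written version of that circuit, which is what the learned modules are loosely analogous to. This is only a sketch: the network's modules emit 16-dimensional features and a soft carry rather than exact bits, and the names `full_adder`, `ripple_carry_add`, and `bits_to_int` are hypothetical helpers that do not appear in the notebook.

```python
def full_adder(a_bit: int, b_bit: int, carry_in: int) -> tuple:
    """Return (sum_bit, carry_out) for one bit position."""
    total = a_bit + b_bit + carry_in
    return total & 1, total >> 1


def ripple_carry_add(a: int, b: int, width: int = 4) -> list:
    """Add two `width`-bit integers; return width + 1 output bits, LSB first."""
    carry = 0
    out_bits = []
    for i in range(width):
        sum_bit, carry = full_adder((a >> i) & 1, (b >> i) & 1, carry)
        out_bits.append(sum_bit)
    out_bits.append(carry)  # the final carry-out becomes the 5th output bit
    return out_bits


def bits_to_int(bits: list) -> int:
    """Decode an LSB-first bit list back into an integer."""
    return sum(bit << i for i, bit in enumerate(bits))


# The chain reproduces ordinary addition for every operand pair.
all_correct = all(
    bits_to_int(ripple_carry_add(a, b)) == a + b
    for a in range(16)
    for b in range(16)
)

print(ripple_carry_add(7, 1), all_correct)  # [0, 0, 0, 1, 0] True
```

The output bits use the same LSB-first order as the dataset's target vectors.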

In [None]:
class CompositionalNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        # 4 independent bit-processing modules
        self.bit_modules = nn.ModuleList([
            nn.Sequential(
                nn.Linear(3, 16),  # 2 input bits + 1 carry-in
                nn.ReLU(),
                nn.Linear(16, 16),
                nn.ReLU()
            ) for _ in range(4)
        ])
        
        # Output layer
        self.output = nn.Linear(64, 5)  # 4 modules × 16D = 64D
    
    def forward(self, x):
        batch_size = x.size(0)
        
        bit_outputs = []
        carry = torch.zeros(batch_size, 1, device=x.device)
        
        for i in range(4):
            # Get bits for this position
            a_bit = x[:, i:i+1]
            b_bit = x[:, i+4:i+5]
            
            # Process with module
            module_input = torch.cat([a_bit, b_bit, carry], dim=1)
            module_output = self.bit_modules[i](module_input)
            bit_outputs.append(module_output)
            
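            # Update carry: a learned soft signal, not an exact carry bit.
            # module_output is post-ReLU (>= 0), so this value is always at least 0.5.
            carry = torch.sigmoid(module_output[:, :1])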
        
        # Concatenate all bit module outputs
        hidden = torch.cat(bit_outputs, dim=1)  # [batch, 64]
        
        # Final output
        output = torch.sigmoid(self.output(hidden))
        
        return output, hidden

comp_model = CompositionalNetwork().to(device)
print(f'🔧 Compositional Network parameters: {sum(p.numel() for p in comp_model.parameters()):,}')

In [None]:
# 🎨 Visualize compositional architecture
fig, ax = plt.subplots(figsize=(14, 8))
ax.axis('off')

# Draw modules
module_y = [0.8, 0.6, 0.4, 0.2]
module_colors = plt.cm.Blues(np.linspace(0.4, 0.8, 4))

for i, (y, color) in enumerate(zip(module_y, module_colors)):
    # Input
    ax.text(0.05, y, f'a{i}, b{i}', ha='center', fontsize=10, fontweight='bold')
    ax.annotate('', xy=(0.15, y), xytext=(0.08, y),
                arrowprops=dict(arrowstyle='->', color='gray', lw=1.5))
    
    # Module box
    rect = plt.Rectangle((0.15, y - 0.06), 0.25, 0.12,
                          facecolor=color, edgecolor='black', linewidth=2, alpha=0.8)
    ax.add_patch(rect)
    ax.text(0.275, y, f'Module {i}\n(3→16→16)', ha='center', va='center', fontsize=9, fontweight='bold')
    
    # Carry arrow (if not first)
    if i > 0:
        ax.annotate('', xy=(0.15, y + 0.03), xytext=(0.15, module_y[i-1] - 0.08),
                    arrowprops=dict(arrowstyle='->', color='orange', lw=1.5, ls='--'))
    
    # Output arrow
    ax.annotate('', xy=(0.5, y), xytext=(0.4, y),
                arrowprops=dict(arrowstyle='->', color='gray', lw=1.5))

# Concatenation
concat_rect = plt.Rectangle((0.5, 0.15), 0.1, 0.7,
                              facecolor='#9b59b6', edgecolor='black', linewidth=2, alpha=0.7)
ax.add_patch(concat_rect)
ax.text(0.55, 0.5, 'Concat\n(64D)', ha='center', va='center', fontsize=10, fontweight='bold', color='white')

# Output layer
ax.annotate('', xy=(0.7, 0.5), xytext=(0.6, 0.5),
            arrowprops=dict(arrowstyle='->', color='black', lw=2))
output_rect = plt.Rectangle((0.7, 0.35), 0.15, 0.3,
                              facecolor='#2ecc71', edgecolor='black', linewidth=2, alpha=0.7)
ax.add_patch(output_rect)
ax.text(0.775, 0.5, 'Output\n(5 bits)', ha='center', va='center', fontsize=10, fontweight='bold')

ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_title('🔧 Compositional Network Architecture', fontsize=14, fontweight='bold')

# Legend
ax.text(0.1, 0.05, 'Orange dashed: Carry propagation', fontsize=9, color='orange')

plt.tight_layout()
plt.show()

---

## 🏋️ Training
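
The training loop below counts a prediction as correct only when all 5 output bits match the target after thresholding at 0.5. The following sketch contrasts that exact-match metric with a more lenient per-bit accuracy on two made-up predictions. The function names and the toy probabilities are illustrative assumptions, not values produced by the models.

```python
def exact_match_accuracy(probs, targets, threshold=0.5):
    """Percentage of examples whose thresholded bits all equal the target bits."""
    correct = 0
    for p_row, t_row in zip(probs, targets):
        pred = [1 if p > threshold else 0 for p in p_row]
        correct += int(pred == list(t_row))
    return 100 * correct / len(targets)


def per_bit_accuracy(probs, targets, threshold=0.5):
    """Percentage of individual bits predicted correctly."""
    correct = 0
    total = 0
    for p_row, t_row in zip(probs, targets):
        for p, t in zip(p_row, t_row):
            correct += int((1 if p > threshold else 0) == t)
            total += 1
    return 100 * correct / total


# Toy sigmoid outputs for two examples that share the same target.
probs = [
    [0.9, 0.1, 0.8, 0.2, 0.1],  # all five bits correct
    [0.9, 0.6, 0.8, 0.2, 0.1],  # bit 1 is wrong
]
targets = [
    [1, 0, 1, 0, 0],
    [1, 0, 1, 0, 0],
]

print(exact_match_accuracy(probs, targets))  # 50.0
print(per_bit_accuracy(probs, targets))      # 90.0
```

A single wrong bit costs the whole example under exact match, so the accuracy curve can sit well below the per-bit rate early in training.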

In [None]:
def train_model(model, train_loader, epochs=100, lr=0.001, name='Model'):
    """Train a model and return loss/accuracy history."""
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.BCELoss()
    
    losses = []
    accuracies = []
    
    pbar = tqdm(range(epochs), desc=f'Training {name}')
    for epoch in pbar:
        model.train()
        epoch_loss = 0
        correct = 0
        total = 0
        
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            
            optimizer.zero_grad()
            outputs, _ = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
            
            # Accuracy: all 5 bits must match
            pred_bits = (outputs > 0.5).float()
            correct += (pred_bits == targets).all(dim=1).sum().item()
            total += inputs.size(0)
        
        losses.append(epoch_loss / len(train_loader))
        accuracies.append(100 * correct / total)
        
        pbar.set_postfix({'Loss': f'{losses[-1]:.4f}', 'Acc': f'{accuracies[-1]:.1f}%'})
    
    return losses, accuracies

In [None]:
# Train monolithic model
print('🧱 Training Monolithic MLP...')
mono_losses, mono_accs = train_model(mono_model, train_loader, epochs=100, name='Monolithic')

# Save model
torch.save(mono_model.state_dict(), '../models/monolithic_4bit.pth')
print(f'\n✅ Final accuracy: {mono_accs[-1]:.2f}%')

In [None]:
# Train compositional model
print('\n🔧 Training Compositional Network...')
comp_losses, comp_accs = train_model(comp_model, train_loader, epochs=100, name='Compositional')

# Save model
torch.save(comp_model.state_dict(), '../models/compositional_4bit.pth')
print(f'\n✅ Final accuracy: {comp_accs[-1]:.2f}%')

---

## 📈 Compare Training Curves

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Loss curves
ax1 = axes[0]
ax1.plot(mono_losses, label='Monolithic', color=COLORS['mono'], linewidth=2.5)
ax1.plot(comp_losses, label='Compositional', color=COLORS['comp'], linewidth=2.5)
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('Loss', fontsize=12)
ax1.set_title('📉 Training Loss', fontsize=13, fontweight='bold')
ax1.legend(fontsize=11)
ax1.grid(True, alpha=0.3)

# Accuracy curves
ax2 = axes[1]
ax2.plot(mono_accs, label='Monolithic', color=COLORS['mono'], linewidth=2.5)
ax2.plot(comp_accs, label='Compositional', color=COLORS['comp'], linewidth=2.5)
ax2.set_xlabel('Epoch', fontsize=12)
ax2.set_ylabel('Accuracy (%)', fontsize=12)
ax2.set_title('📈 Training Accuracy', fontsize=13, fontweight='bold')
ax2.axhline(100, color='green', linestyle='--', alpha=0.5, label='Perfect')
ax2.legend(fontsize=11)  # called after axhline so 'Perfect' appears in the legend
ax2.grid(True, alpha=0.3)
ax2.set_ylim([0, 105])

plt.tight_layout()
plt.savefig('../figures/source_model_training.png', dpi=150, bbox_inches='tight')
plt.show()

print('\n' + '='*60)
print(f'✅ Final accuracy: monolithic {mono_accs[-1]:.1f}%, compositional {comp_accs[-1]:.1f}%')
print('❓ But do they learn the same internal representations?')
print('='*60)

---

## 🔍 Extract and Analyze Activations

Let's extract the hidden layer activations and see if they encode the **carry count**.
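
The cells below score "encoding" with the R² of a linear regression from activations to carry count, fitted and evaluated on the same 256 examples. To make that metric concrete, here is a toy one-feature version in plain Python. The feature (the number of bit positions where both operands are 1) is a hypothetical stand-in chosen for illustration, not a network activation, and all function names are assumptions.

```python
def r_squared(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_y = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_y) ** 2 for t in y_true)
    return 1 - ss_res / ss_tot


def fit_line(x, y):
    """Ordinary least squares for y ~ slope * x + intercept."""
    n = len(x)
    mean_x, mean_y = sum(x) / n, sum(y) / n
    sxx = sum((xi - mean_x) ** 2 for xi in x)
    sxy = sum((xi - mean_x) * (yi - mean_y) for xi, yi in zip(x, y))
    slope = sxy / sxx
    return slope, mean_y - slope * mean_x


def count_carries_ref(a, b, width=4):
    """Carry count for a + b, restated so this block runs on its own."""
    carries, carry = 0, 0
    for i in range(width):
        carry = 1 if ((a >> i) & 1) + ((b >> i) & 1) + carry >= 2 else 0
        carries += carry
    return carries


pairs = [(a, b) for a in range(16) for b in range(16)]
carries = [count_carries_ref(a, b) for a, b in pairs]

# Hypothetical feature: positions where both operand bits are 1.
feature = [bin(a & b).count("1") for a, b in pairs]

slope, intercept = fit_line(feature, carries)
predicted = [slope * x + intercept for x in feature]
r2 = r_squared(carries, predicted)

print(f"slope={slope:.3f}, intercept={intercept:.3f}, R^2={r2:.3f}")
```

An R² near 1 means carry count is almost a linear function of the features, while a lower value means it is only partly linearly readable. The notebook applies the same idea with 64 activation dimensions instead of one feature.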

In [None]:
# Extract activations
mono_model.eval()
comp_model.eval()

with torch.no_grad():
    X_tensor = torch.tensor(X, dtype=torch.float32).to(device)
    
    _, mono_activations = mono_model(X_tensor)
    _, comp_activations = comp_model(X_tensor)
    
    mono_activations = mono_activations.cpu().numpy()
    comp_activations = comp_activations.cpu().numpy()

print(f'📊 Monolithic activations: {mono_activations.shape}')
print(f'📊 Compositional activations: {comp_activations.shape}')

In [None]:
# Visualize activations with PCA
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

for ax, activations, name, color in [
    (axes[0], mono_activations, 'Monolithic', COLORS['mono']),
    (axes[1], comp_activations, 'Compositional', COLORS['comp'])
]:
    # PCA projection
    pca = PCA(n_components=2)
    act_2d = pca.fit_transform(activations)
    
    # Scatter plot colored by carry count
    scatter = ax.scatter(act_2d[:, 0], act_2d[:, 1], c=carry_counts, 
                         cmap='viridis', s=30, alpha=0.7, edgecolors='white', linewidth=0.3)
    
    # Compute linear accessibility
    reg = LinearRegression().fit(activations, carry_counts)
    r2 = r2_score(carry_counts, reg.predict(activations))
    
    ax.set_xlabel('PC1', fontsize=11)
    ax.set_ylabel('PC2', fontsize=11)
    ax.set_title(f'{name} Activations\nR² (carry) = {r2:.4f}', fontsize=12, fontweight='bold')
    
plt.colorbar(scatter, ax=axes, label='Carry Count', shrink=0.8)

plt.tight_layout()
plt.savefig('../figures/activation_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

print('\n💡 Both models encode carry count information, but with different geometries!')

In [None]:
# Save activations for Delta Observer training
np.savez('../data/monolithic_activations.npz', 
         activations=mono_activations, inputs=X, carry_counts=carry_counts)
np.savez('../data/compositional_activations.npz', 
         activations=comp_activations, inputs=X, carry_counts=carry_counts)

print('✅ Activations saved to ../data/')
print('   - monolithic_activations.npz')
print('   - compositional_activations.npz')

---

## 📝 Summary

| Model | Architecture | Parameters | Final Accuracy | R² (Carry) |
|-------|-------------|------------|----------------|------------|
| **Monolithic** | 8 → 64 → 64 → 5 | 5,061 | 100% | ~0.85 |
| **Compositional** | 4 × (3 → 16 → 16) + output | 1,669 | 100% | ~0.90 |

**Key Observations:**
1. Both models achieve perfect accuracy on the task
2. Both encode carry count information in their activations
3. The internal representations have different geometric structures
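
The parameter counts for both architectures can be checked by hand, assuming every `nn.Linear` layer keeps its default bias term: a dense layer with `n_in` inputs and `n_out` outputs has `n_in * n_out + n_out` parameters. The helper name `dense_params` below is hypothetical.

```python
def dense_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out


# Monolithic MLP: 8 -> 64 -> 64 -> 5
mono_params = dense_params(8, 64) + dense_params(64, 64) + dense_params(64, 5)

# Compositional network: four (3 -> 16 -> 16) modules plus a 64 -> 5 output layer
module_params = dense_params(3, 16) + dense_params(16, 16)
comp_params = 4 * module_params + dense_params(64, 5)

print(mono_params, comp_params)  # 5061 1669
```

Under that assumption the compositional network is roughly a third the size of the monolithic one, even though both expose a 64-dimensional hidden representation.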

**Next:** The Delta Observer will learn to map between these different representations, discovering shared semantic structure!

---

## 🚀 Next Steps

Continue to **`02_delta_observer_training.ipynb`** to train the Delta Observer that maps between these representations.

| Notebook | Description | Colab |
|----------|-------------|-------|
| **02_delta_observer_training** | Train the Delta Observer | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/EntroMorphic/delta-observer/blob/main/notebooks/02_delta_observer_training.ipynb) |
| **03_analysis_visualization** | Geometric analysis | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/EntroMorphic/delta-observer/blob/main/notebooks/03_analysis_visualization.ipynb) |
| **99_full_reproduction** | Complete pipeline | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/EntroMorphic/delta-observer/blob/main/notebooks/99_full_reproduction.ipynb) |

---

**For Science!** 🔬🌊