# Memory Injection Experiment

This notebook compares three injection modes to evaluate whether preserving
injection memory across ACT steps improves Sudoku solving accuracy:

1. **Baseline (`broadcast`)**: no memory; the current controller state is broadcast to all tokens through a gate
2. **Cross-Attention with Memory (`cross_attn`)**: a memory bank stores past controller states, and each token cross-attends to that bank
3. **Broadcast with Memory (`broadcast_memory`)**: a memory bank plus a learned summary, injected via gated broadcast
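
The difference between the three modes can be shown with a toy, pure-Python sketch. It is not the repository's implementation. The function names, the single scalar sigmoid gate, the identity projections, and the assumption that mode 3 pools the memory with one learned query vector are all illustrative choices. The point is structural: in `broadcast` and `broadcast_memory` every token receives the same update, while in `cross_attn` each token computes its own attention weights over the memory bank.

```python
import math

# Toy illustration of the three injection modes (hypothetical, not the repo's code).
# Vectors are plain lists of floats; tokens and memory entries share one width.

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def add_scaled(h, v, scale):
    return [x + scale * y for x, y in zip(h, v)]

def attend(query, memory):
    """Scaled dot-product attention of one query over a list of memory vectors."""
    d = len(query)
    weights = softmax([dot(query, m) / math.sqrt(d) for m in memory])
    return [sum(w * m[k] for w, m in zip(weights, memory)) for k in range(d)]

def broadcast(tokens, ctrl, gate_w):
    """Mode 1: every token gets the same gated copy of the controller state."""
    g = sigmoid(dot(gate_w, ctrl))  # assumed: one scalar gate from the controller state
    return [add_scaled(h, ctrl, g) for h in tokens]

def cross_attn(tokens, memory):
    """Mode 2: each token reads its own attention-weighted mix of past controller states."""
    return [add_scaled(h, attend(h, memory), 1.0) for h in tokens]

def broadcast_memory(tokens, memory, summary_query, gate_w):
    """Mode 3: pool the memory into one summary (assumed: learned query), then broadcast it."""
    summary = attend(summary_query, memory)
    return broadcast(tokens, summary, gate_w)

tokens = [[1.0, 0.0], [0.0, 1.0]]   # two toy token states
memory = [[0.5, 0.5], [1.0, -1.0]]  # two stored controller states, newest last
gate_w = [0.0, 0.0]                 # sigmoid(0) = 0.5

print(broadcast(tokens, memory[-1], gate_w))
print(cross_attn(tokens, memory))
print(broadcast_memory(tokens, memory, [1.0, 1.0], gate_w))
```

In the first and third outputs, both tokens shift by the same vector; in the second, each token's shift depends on its own content.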

All runs share the same core hyperparameters (d_model=512, H_cycles=2, L_cycles=6, halt_max_steps=4) and differ only in their injection settings.

**Key change:** `injection_memory` is now preserved across ACT steps via `ACTCarry`,
so modes 2 and 3 can accumulate reasoning context over multiple outer iterations.
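
A toy model of why preserving `injection_memory` in the carry matters. `ToyCarry` is a hypothetical stand-in for `ACTCarry` (its real fields are not shown in this notebook), and the number of controller states written per ACT step is an assumption (here one per inner cycle, `H_cycles * L_cycles = 12`). With preservation, the bank keeps filling across ACT steps up to `memory_size` and then evicts its oldest entries; without it, every ACT step starts from an empty bank.

```python
from dataclasses import dataclass, field

@dataclass
class ToyCarry:
    """Hypothetical stand-in for ACTCarry, holding only the injection memory."""
    injection_memory: list = field(default_factory=list)

def push_memory(memory, ctrl_state, memory_size):
    """Append the newest controller state and keep only the last `memory_size` (FIFO)."""
    return (memory + [ctrl_state])[-memory_size:]

def run_act(n_steps, states_per_step, memory_size, preserve):
    """Simulate outer ACT steps; return bank size after each step and the final bank."""
    carry = ToyCarry()
    sizes = []
    for step in range(n_steps):
        # With preserve=False the bank is reset at the start of every ACT step
        memory = list(carry.injection_memory) if preserve else []
        for i in range(states_per_step):
            memory = push_memory(memory, (step, i), memory_size)
        carry = ToyCarry(injection_memory=memory)  # what the next ACT step receives
        sizes.append(len(memory))
    return sizes, carry.injection_memory

# Assumed: 12 controller states per ACT step, memory_size=16, halt_max_steps=4
print(run_act(4, 12, 16, preserve=True)[0])   # [12, 16, 16, 16]
print(run_act(4, 12, 16, preserve=False)[0])  # [12, 12, 12, 12]
```

With preservation, the final bank spans the end of step 3 and all of step 4, so later steps can attend to context that the reset variant has already discarded.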

In [None]:
# Setup: clone repo and install deps
!git clone https://github.com/Eran-BA/PoT.git
%cd PoT
!pip install -q torch torchvision torchaudio
!pip install -q tqdm numpy matplotlib huggingface_hub wandb

## Shared Hyperparameters

| Parameter | Value |
|-----------|-------|
| d_model | 512 |
| d_ff | 2048 |
| n_heads | 8 |
| H_cycles | 2 |
| L_cycles | 6 |
| H_layers | 2 |
| L_layers | 2 |
| halt_max_steps | 4 |
| controller | transformer |
| d_ctrl | 256 |
| epochs | 500 |
| batch_size | 512 |
| lr | 3e-4 |
| warmup_steps | 2000 |
| dropout | 0.039 |
| max_depth | 32 |
| eval_interval | 25 |
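
The three training commands below differ only in their injection flags. One way to keep them from drifting apart is to generate all three from a single shared config. This is a sketch: the flags are copied from the commands in this notebook, and `build_argv` is a hypothetical helper, not part of the repository.

```python
import shlex

# Flags copied from the three commands in this notebook; build_argv is hypothetical.
SHARED = {
    "--d-model": 512, "--d-ff": 2048, "--model": "hybrid",
    "--controller": "transformer", "--d-ctrl": 256, "--max-depth": 32,
    "--epochs": 500, "--batch-size": 512, "--lr": "3e-4",
    "--warmup-steps": 2000, "--n-heads": 8, "--H-cycles": 2, "--L-cycles": 6,
    "--H-layers": 2, "--L-layers": 2, "--halt-max-steps": 4,
    "--eval-interval": 25, "--dropout": 0.039,
    "--project": "memory-injection-experiment",
}
SWITCHES = ["--hrm-grad-style", "--wandb", "--download"]  # boolean flags, no value
MEMORY = {"--injection-memory-size": 16, "--injection-n-heads": 4}
RUNS = {
    "baseline-broadcast": {"--injection-mode": "broadcast"},
    "cross-attn-memory": {"--injection-mode": "cross_attn", **MEMORY},
    "broadcast-memory": {"--injection-mode": "broadcast_memory", **MEMORY},
}

def build_argv(run_name):
    """Return the argument list for one run: shared + run-specific flags + switches."""
    argv = ["python", "experiments/sudoku_poh_benchmark.py"]
    flags = {**SHARED, **RUNS[run_name], "--run-name": run_name}
    for flag, value in flags.items():
        argv += [flag, str(value)]
    return argv + SWITCHES

for name in RUNS:
    print(shlex.join(build_argv(name)))
```

From the repository root, `subprocess.run(build_argv("cross-attn-memory"), check=True)` would launch a single run with the same settings as the corresponding cell.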

## Run 1: Baseline Broadcast (no memory)

In [None]:
!python experiments/sudoku_poh_benchmark.py \
    --d-model 512 \
    --d-ff 2048 \
    --model hybrid \
    --controller transformer \
    --d-ctrl 256 \
    --max-depth 32 \
    --injection-mode broadcast \
    --epochs 500 \
    --batch-size 512 \
    --lr 3e-4 \
    --warmup-steps 2000 \
    --n-heads 8 \
    --H-cycles 2 \
    --L-cycles 6 \
    --H-layers 2 \
    --L-layers 2 \
    --hrm-grad-style \
    --halt-max-steps 4 \
    --eval-interval 25 \
    --dropout 0.039 \
    --wandb \
    --project memory-injection-experiment \
    --run-name baseline-broadcast \
    --download

## Run 2: Cross-Attention with Memory Preservation

In [None]:
!python experiments/sudoku_poh_benchmark.py \
    --d-model 512 \
    --d-ff 2048 \
    --model hybrid \
    --controller transformer \
    --d-ctrl 256 \
    --max-depth 32 \
    --injection-mode cross_attn \
    --injection-memory-size 16 \
    --injection-n-heads 4 \
    --epochs 500 \
    --batch-size 512 \
    --lr 3e-4 \
    --warmup-steps 2000 \
    --n-heads 8 \
    --H-cycles 2 \
    --L-cycles 6 \
    --H-layers 2 \
    --L-layers 2 \
    --hrm-grad-style \
    --halt-max-steps 4 \
    --eval-interval 25 \
    --dropout 0.039 \
    --wandb \
    --project memory-injection-experiment \
    --run-name cross-attn-memory \
    --download

## Run 3: Broadcast with Memory

In [None]:
!python experiments/sudoku_poh_benchmark.py \
    --d-model 512 \
    --d-ff 2048 \
    --model hybrid \
    --controller transformer \
    --d-ctrl 256 \
    --max-depth 32 \
    --injection-mode broadcast_memory \
    --injection-memory-size 16 \
    --injection-n-heads 4 \
    --epochs 500 \
    --batch-size 512 \
    --lr 3e-4 \
    --warmup-steps 2000 \
    --n-heads 8 \
    --H-cycles 2 \
    --L-cycles 6 \
    --H-layers 2 \
    --L-layers 2 \
    --hrm-grad-style \
    --halt-max-steps 4 \
    --eval-interval 25 \
    --dropout 0.039 \
    --wandb \
    --project memory-injection-experiment \
    --run-name broadcast-memory \
    --download

## Per-Step Accuracy Analysis

Load the best checkpoint from each run and evaluate it with intermediate outputs collected at every ACT step, to measure how cell and grid accuracy change from one step to the next.

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import os
import sys
sys.path.insert(0, '.')

from src.pot.models.sudoku_solver import HybridPoHHRMSolver
from src.data.sudoku_dataset import SudokuDataset
from torch.utils.data import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

In [None]:
def evaluate_per_step(model, dataloader, device, max_batches=50):
    """
    Evaluate model and compute per-ACT-step accuracy.
    
    Returns:
        Dict with per_step_cell_acc and per_step_grid_acc lists.
    """
    model.eval()
    
    # Initialize per-step counters
    n_steps = model.halt_max_steps
    step_correct_cells = [0] * n_steps
    step_total_cells = [0] * n_steps
    step_correct_grids = [0] * n_steps
    step_total_grids = [0] * n_steps
    
    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader):
            if batch_idx >= max_batches:
                break
            
            inp, label, puzzle_ids = batch
            inp = inp.to(device)
            label = label.to(device)
            puzzle_ids = puzzle_ids.to(device)
            
            # Get intermediate outputs
            result = model.forward_with_intermediate(inp, puzzle_ids)
            
            for step_idx, step_logits in enumerate(result['intermediate_logits']):
                preds = step_logits.argmax(dim=-1)
                
                # Cell accuracy over every label position (given clues included)
                step_correct_cells[step_idx] += (preds == label).sum().item()
                step_total_cells[step_idx] += label.numel()
                
                # Grid accuracy
                grid_correct = (preds == label).all(dim=1)
                step_correct_grids[step_idx] += grid_correct.sum().item()
                step_total_grids[step_idx] += label.size(0)
    
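    # Compute percentages. Count the steps that at least one batch reached
    # instead of reading `result` from the last batch: that would fail on an
    # empty loader and could drop steps reached by earlier batches.
    n_actual = sum(1 for total in step_total_cells if total > 0)
    per_step_cell_acc = [
        100.0 * step_correct_cells[i] / max(step_total_cells[i], 1)
        for i in range(n_actual)
    ]
    per_step_grid_acc = [
        100.0 * step_correct_grids[i] / max(step_total_grids[i], 1)
        for i in range(n_actual)
    ]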
    
    return {
        'per_step_cell_acc': per_step_cell_acc,
        'per_step_grid_acc': per_step_grid_acc,
    }

print('evaluate_per_step defined.')
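
Cell accuracy and grid accuracy can move very differently across ACT steps, because grid accuracy is all-or-nothing: a single wrong cell marks the whole puzzle as unsolved. The toy example below (4-cell "puzzles" with made-up predictions, not experimental data) mirrors the counting in `evaluate_per_step` above using plain Python lists; `step_accuracies` is a hypothetical helper.

```python
def step_accuracies(preds_per_step, labels):
    """
    preds_per_step: one entry per ACT step, each a list of per-puzzle predictions.
    labels: list of per-puzzle target sequences.
    Returns a list of (cell_acc_%, grid_acc_%) tuples, one per step.
    """
    total_cells = sum(len(lab) for lab in labels)
    out = []
    for preds in preds_per_step:
        # Every position counts toward cell accuracy, as in evaluate_per_step
        correct_cells = sum(
            p == t for pred, lab in zip(preds, labels) for p, t in zip(pred, lab)
        )
        # A puzzle counts as solved only if all of its cells match
        solved = sum(pred == lab for pred, lab in zip(preds, labels))
        out.append((100.0 * correct_cells / total_cells, 100.0 * solved / len(labels)))
    return out

toy_labels = [[1, 2, 3, 4], [4, 3, 2, 1]]
toy_preds = [
    [[1, 2, 3, 3], [4, 3, 2, 2]],  # step 1: one wrong cell in each puzzle
    [[1, 2, 3, 4], [4, 3, 2, 2]],  # step 2: first puzzle fixed
]
print(step_accuracies(toy_preds, toy_labels))  # [(75.0, 0.0), (87.5, 50.0)]
```

Fixing one cell moves cell accuracy by 12.5 points but grid accuracy by 50, which is why the grid curves in the plots below can look much steeper than the cell curves.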

In [None]:
def load_model(checkpoint_path, injection_mode, device):
    """
    Load a trained model from checkpoint.
    
    Uses the same hyperparameters as the experiment runs.
    """
    # Build injection kwargs
    injection_kwargs = None
    if injection_mode in ('cross_attn', 'broadcast_memory'):
        injection_kwargs = {
            'memory_size': 16,
            'n_heads': 4,
        }
    
    model = HybridPoHHRMSolver(
        d_model=512,
        n_heads=8,
        H_layers=2,
        L_layers=2,
        d_ff=2048,
        dropout=0.039,
        H_cycles=2,
        L_cycles=6,
        T=32,
        num_puzzles=1,
        hrm_grad_style=True,
        halt_max_steps=4,
        controller_type='transformer',
        controller_kwargs={'d_ctrl': 256},
        injection_mode=injection_mode,
        injection_kwargs=injection_kwargs,
    ).to(device)
    
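    # Checkpoints may store non-tensor metadata; PyTorch >= 2.6 defaults to
    # weights_only=True, which can refuse to load it. Only load trusted files.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    # Accept either a training checkpoint dict or a bare state_dict
    state_dict = checkpoint.get('model_state_dict', checkpoint)
    model.load_state_dict(state_dict)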
    model.eval()
    
    epoch = checkpoint.get('epoch', '?')
    grid_acc = checkpoint.get('test_grid_acc', '?')
    print(f'Loaded {injection_mode} model from epoch {epoch}, grid_acc={grid_acc}%')
    
    return model

print('load_model defined.')

In [None]:
# Load the held-out test split used for evaluation
val_dataset = SudokuDataset(split='test', download=True)
val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False, num_workers=4)
print(f'Evaluation (test) set: {len(val_dataset)} puzzles')

In [None]:
# ---- Update these paths to your best checkpoints ----
CHECKPOINTS = {
    'broadcast':        'outputs/baseline-broadcast/best_model.pt',
    'cross_attn':       'outputs/cross-attn-memory/best_model.pt',
    'broadcast_memory': 'outputs/broadcast-memory/best_model.pt',
}

results = {}
for mode, ckpt_path in CHECKPOINTS.items():
    if not os.path.exists(ckpt_path):
        print(f'Skipping {mode}: checkpoint not found at {ckpt_path}')
        continue
    
    model = load_model(ckpt_path, mode, device)
    metrics = evaluate_per_step(model, val_loader, device)
    results[mode] = metrics
    
    print(f'\n{mode}:')
    for step_idx, (cell, grid) in enumerate(zip(
        metrics['per_step_cell_acc'], metrics['per_step_grid_acc']
    )):
        print(f'  Step {step_idx+1}: cell={cell:.2f}%, grid={grid:.2f}%')
    
    del model  # Free GPU memory
    torch.cuda.empty_cache()

print('\nDone evaluating all models.')
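
The paths in `CHECKPOINTS` are placeholders, and the real locations depend on how `sudoku_poh_benchmark.py` names its output directories. If the checkpoints end up elsewhere under `outputs/`, a small search can fill in the dictionary. The layout this sketch matches (some directory component starting with the run name, containing `best_model.pt`) is an assumption; adjust `RUN_NAMES` and `filename` to whatever the script actually writes.

```python
from pathlib import Path

# Hypothetical mapping from injection mode to the --run-name used for that run
RUN_NAMES = {
    "broadcast": "baseline-broadcast",
    "cross_attn": "cross-attn-memory",
    "broadcast_memory": "broadcast-memory",
}

def find_checkpoints(root="outputs", filename="best_model.pt"):
    """Map each mode to the newest `filename` under a directory starting with its run name."""
    root = Path(root)
    if not root.is_dir():
        return {}
    all_ckpts = list(root.rglob(filename))
    found = {}
    for mode, run_name in RUN_NAMES.items():
        matches = [
            p for p in all_ckpts
            # check only directory components, not the file name itself
            if any(part.startswith(run_name) for part in p.relative_to(root).parts[:-1])
        ]
        if matches:
            newest = max(matches, key=lambda p: p.stat().st_mtime)
            found[mode] = str(newest)
    return found

print(find_checkpoints())
```

In the notebook, `CHECKPOINTS.update(find_checkpoints())` followed by rerunning the evaluation cell would pick up whatever was found.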

## Visualization: Per-Step Accuracy Comparison

In [None]:
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

colors = {
    'broadcast': '#1f77b4',
    'cross_attn': '#ff7f0e',
    'broadcast_memory': '#2ca02c',
}
labels = {
    'broadcast': 'Broadcast (baseline)',
    'cross_attn': 'Cross-Attn + Memory',
    'broadcast_memory': 'Broadcast + Memory',
}

for mode, metrics in results.items():
    steps = list(range(1, len(metrics['per_step_cell_acc']) + 1))
    
    ax1.plot(steps, metrics['per_step_cell_acc'],
             marker='o', color=colors[mode], label=labels[mode], linewidth=2)
    ax2.plot(steps, metrics['per_step_grid_acc'],
             marker='o', color=colors[mode], label=labels[mode], linewidth=2)

ax1.set_xlabel('ACT Step')
ax1.set_ylabel('Cell Accuracy (%)')
ax1.set_title('Cell Accuracy per ACT Step')
ax1.legend()
ax1.grid(True, alpha=0.3)
ax1.set_xticks(range(1, 5))

ax2.set_xlabel('ACT Step')
ax2.set_ylabel('Grid Accuracy (%)')
ax2.set_title('Grid Accuracy per ACT Step')
ax2.legend()
ax2.grid(True, alpha=0.3)
ax2.set_xticks(range(1, 5))

plt.suptitle('Memory Injection Experiment: Per-Step Accuracy', fontsize=14, y=1.02)
plt.tight_layout()
plt.savefig('memory_injection_results.png', dpi=150, bbox_inches='tight')
plt.show()

print('Plot saved to memory_injection_results.png')

## Summary Table

In [None]:
print(f'{"Mode":<25} {"Final Cell Acc (%)":<20} {"Final Grid Acc (%)":<20} {"Grid Gain, Step 1 -> Final (pp)"}')
print('-' * 100)

for mode, metrics in results.items():
    cell_accs = metrics['per_step_cell_acc']
    grid_accs = metrics['per_step_grid_acc']
    if not grid_accs:
        print(f'{labels.get(mode, mode):<25} (no steps evaluated)')
        continue
    # Gain in percentage points between the first and last evaluated ACT step
    gain = grid_accs[-1] - grid_accs[0]
    print(f'{labels.get(mode, mode):<25} {cell_accs[-1]:<20.2f} {grid_accs[-1]:<20.2f} {gain:+.2f}')
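
The table reports only the first and final ACT steps. A slightly more defensive summary can also record which step reached the highest grid accuracy, which makes it easy to spot a mode whose accuracy peaks before the last step. This is a sketch over the `results` dictionary built earlier; `summarize` is a hypothetical helper.

```python
def summarize(results):
    """Per-mode summary: final accuracies, step-1 -> final grid gain, and best step."""
    rows = []
    for mode, metrics in results.items():
        cells = metrics["per_step_cell_acc"]
        grids = metrics["per_step_grid_acc"]
        if not grids:
            continue  # nothing was evaluated for this mode
        # max() returns the first index on ties, i.e. the earliest step reaching the peak
        best = max(range(len(grids)), key=lambda i: grids[i])
        rows.append({
            "mode": mode,
            "final_cell_acc": cells[-1],
            "final_grid_acc": grids[-1],
            "grid_gain_pp": grids[-1] - grids[0],
            "best_grid_step": best + 1,  # ACT steps are reported 1-indexed
            "best_grid_acc": grids[best],
        })
    return rows
```

In the notebook, `for row in summarize(results): print(row)` prints one dictionary per evaluated mode. If `best_grid_step` is smaller than the number of steps, later ACT steps made that mode's grid accuracy worse rather than better.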