# 🔬 Feature Injection Ablation Study

Benchmark all **6 Feature Injection modes** on the Sudoku-Extreme dataset.

## Experiment 1: Injection Mode Comparison
| Mode | Description |
|------|-------------|
| `none` | Routing-only (baseline) |
| `broadcast` | Gated broadcast to all tokens |
| `film` | FiLM modulation (γ*x + β) |
| `depth_token` | Prepend depth token |
| `cross_attn` | Cross-attention to memory bank |
| `alpha_gated` | Alpha-modulated broadcast |

## Experiment 2: Alpha Aggregation Deep Dive
| Aggregation | Formula |
|-------------|--------|
| `mean` | Average routing weight |
| `max` | Maximum routing weight |
| `entropy` | 1 - H(α)/H_max (confident → stronger) |

## Configuration (Paper-Aligned)
- **Dataset**: 9k Sudoku-Extreme puzzles
- **Epochs**: 200 per experiment
- **Batch size**: 768 (A100/H100)
- **Warmup**: 2,000 steps
- **H/L cycles**: 2/6
- **Logging**: Weights & Biases
- **GPU**: A100/H100 required (768 batch size)


## 1. Setup


In [None]:
# Clone PoT repository
# If the clone exits non-zero (its stderr is discarded), the `||` fallback
# pulls the latest commits into the existing /content/PoT checkout instead.
!git clone https://github.com/Eran-BA/PoT.git /content/PoT 2>/dev/null || (cd /content/PoT && git pull)
# Work from the repository root; presumably this is what lets the `src.*`
# imports and the relative `data/...` paths in later cells resolve.
%cd /content/PoT

# Install dependencies
# (-q keeps pip output quiet)
!pip install -q torch torchvision torchaudio
!pip install -q tqdm numpy huggingface_hub wandb


In [None]:
# Report which accelerator this session will train on.
import torch

has_cuda = torch.cuda.is_available()
device = torch.device('cuda' if has_cuda else 'cpu')
print(f"Using device: {device}")
if has_cuda:
    # Device index 0 is the only GPU inspected here.
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"  GPU: {torch.cuda.get_device_name(0)}")
    print(f"  Memory: {gpu_props.total_memory / 1e9:.1f} GB")


In [None]:
# Login to Weights & Biases
# The module-level `wandb` name imported here is the one train_model() uses
# for wandb.init / wandb.log / wandb.save / wandb.finish.
import wandb
wandb.login()


## 2. Download Dataset


In [None]:
from src.data import download_sudoku_dataset

# Download full 9k puzzles with 100 augmentations each (900k total samples).
# train_model() later reads train.pt and val.pt from this same output_dir.
download_sudoku_dataset(
    output_dir='data/sudoku-extreme-9k',
    subsample_size=9000,
    num_aug=100,  # 100 augmentations per puzzle
)
# The check mark was stored as mis-decoded bytes; emit the intended character.
print("✓ Dataset ready")


## 3. Training Infrastructure


In [None]:
import os
import math
from datetime import datetime
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from src.data import SudokuDataset
from src.pot.models import HybridPoHHRMSolver


def get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps, min_lr_ratio=0.1):
    """Build a LambdaLR schedule: linear warmup followed by cosine decay.

    The learning-rate multiplier rises linearly from 0 to 1 over
    ``warmup_steps`` and then follows a half cosine from 1 down to
    ``min_lr_ratio``, reaching that floor at ``total_steps``. For any step
    beyond ``total_steps`` the multiplier stays at ``min_lr_ratio`` rather
    than climbing back up the cosine curve.

    Args:
        optimizer: Optimizer whose base learning rate is scaled.
        warmup_steps: Number of scheduler steps spent in linear warmup.
        total_steps: Step at which the cosine decay reaches its floor.
        min_lr_ratio: Floor of the multiplier, as a fraction of the base LR.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` wrapping the multiplier.
    """
    # Guard against a zero/negative decay span (total_steps <= warmup_steps).
    decay_steps = max(1, total_steps - warmup_steps)

    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        # Clamp so overshooting total_steps cannot push the cosine past pi,
        # which would make the learning rate rise again.
        progress = min(1.0, (step - warmup_steps) / decay_steps)
        return min_lr_ratio + (1 - min_lr_ratio) * 0.5 * (1 + math.cos(math.pi * progress))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


def evaluate(model, val_loader, device):
    """Evaluate model on validation set.

    Only positions whose input token is 0 are scored (presumably the blank
    cells of the puzzle -- confirm against the dataset encoding). The model
    is switched to eval mode for the pass and returned to whatever mode it
    was in before the call.

    Args:
        model: Callable as ``model(inputs, puzzle_ids)`` returning a 4-tuple
            whose first element is the logits tensor.
        val_loader: Iterable of dict batches with ``'input'``, ``'label'``
            and ``'puzzle_id'`` tensors.
        device: Device the batch tensors are moved to.

    Returns:
        dict with
        ``'loss'``: mean of the per-batch cross-entropy over the batches that
        had at least one scored cell (0 if there were none),
        ``'cell_acc'``: percentage of scored cells predicted correctly,
        ``'grid_acc'``: percentage of samples with every scored cell correct.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    loss_batches = 0
    correct_cells = 0
    total_cells = 0
    correct_grids = 0
    total_grids = 0

    try:
        with torch.no_grad():
            for batch in val_loader:
                inputs = batch['input'].to(device)
                targets = batch['label'].to(device)
                puzzle_ids = batch['puzzle_id'].to(device)

                logits, _, _, _ = model(inputs, puzzle_ids)
                mask = (inputs == 0)

                # cross_entropy over an empty selection is undefined, so a
                # batch without scored cells contributes no loss term.
                if mask.any():
                    loss = F.cross_entropy(logits[mask], targets[mask])
                    total_loss += loss.item()
                    loss_batches += 1

                preds = logits.argmax(dim=-1)
                correct_cells += ((preds == targets) & mask).sum().item()
                total_cells += mask.sum().item()

                # A sample counts as solved when every scored cell matches;
                # unscored positions are treated as correct.
                grid_correct = ((preds == targets) | ~mask).all(dim=1)
                correct_grids += grid_correct.sum().item()
                total_grids += inputs.size(0)
    finally:
        # Restore the caller's mode instead of unconditionally enabling
        # training mode.
        model.train(was_training)

    return {
        # Average over batches that actually produced a loss; max(1, ...)
        # keeps an empty loader from raising ZeroDivisionError.
        'loss': total_loss / max(1, loss_batches),
        'cell_acc': 100 * correct_cells / max(1, total_cells),
        'grid_acc': 100 * correct_grids / max(1, total_grids),
    }


def train_model(
    injection_mode='none',
    injection_kwargs=None,
    epochs=200,
    batch_size=768,
    lr=3e-4,
    warmup_steps=2000,
    d_model=512,
    n_heads=8,
    H_cycles=2,
    L_cycles=6,
    H_layers=2,
    L_layers=2,
    halt_max_steps=4,
    use_wandb=True,
    run_name=None,
    project='feature-injection-ablation',
):
    """Train HybridPoHHRMSolver with specified injection mode.

    Loads ``train.pt`` / ``val.pt`` from ``data/sudoku-extreme-9k``, trains
    with AdamW and a warmup + cosine schedule, validates after every epoch
    and tracks the best validation grid accuracy.

    Args:
        injection_mode: Forwarded to ``HybridPoHHRMSolver(injection_mode=...)``.
        injection_kwargs: Optional dict forwarded to the model; also encoded
            into the auto-generated run name.
        epochs: Number of passes over the training loader.
        batch_size: Batch size for both the train and validation loaders.
        lr: Peak learning rate for AdamW.
        warmup_steps: Linear warmup length, in optimizer steps.
        d_model, n_heads, H_cycles, L_cycles, H_layers, L_layers,
        halt_max_steps: Forwarded unchanged to the model constructor
            (``d_ff`` is set to ``4 * d_model``).
        use_wandb: When true, logs to Weights & Biases and saves/uploads the
            best checkpoint; when false, no checkpoint is written.
        run_name: W&B run name; defaults to
            ``<mode>[-<k=v>...]-<timestamp>``.
        project: W&B project name.

    Returns:
        dict with ``injection_mode``, ``injection_kwargs``, ``best_grid_acc``
        (best validation grid accuracy seen) and ``final_cell_acc`` /
        ``final_grid_acc`` (validation metrics of the last epoch; 0.0 if no
        epoch ran).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Generate run name
    if run_name is None:
        timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
        if injection_kwargs:
            extra = '-'.join(f"{k}={v}" for k, v in injection_kwargs.items())
            run_name = f"{injection_mode}-{extra}-{timestamp}"
        else:
            run_name = f"{injection_mode}-{timestamp}"

    # Initialize W&B
    if use_wandb:
        wandb.init(
            project=project,
            name=run_name,
            config={
                'injection_mode': injection_mode,
                'injection_kwargs': injection_kwargs,
                'epochs': epochs,
                'batch_size': batch_size,
                'lr': lr,
                'd_model': d_model,
                'n_heads': n_heads,
                'H_cycles': H_cycles,
                'L_cycles': L_cycles,
                'H_layers': H_layers,
                'L_layers': L_layers,
                'halt_max_steps': halt_max_steps,
            },
            reinit=True,
        )

    # Defined before the loop so the returned summary is well-formed even
    # when epochs == 0 and no validation pass ever runs.
    best_grid_acc = 0
    val_metrics = {'loss': float('nan'), 'cell_acc': 0.0, 'grid_acc': 0.0}

    try:
        # Load datasets
        train_dataset = SudokuDataset('data/sudoku-extreme-9k/train.pt', augment=True)
        val_dataset = SudokuDataset('data/sudoku-extreme-9k/val.pt', augment=False)

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

        print(f"Train: {len(train_dataset)} samples, Val: {len(val_dataset)} samples")

        # Create model
        model = HybridPoHHRMSolver(
            vocab_size=10,
            d_model=d_model,
            n_heads=n_heads,
            H_layers=H_layers,
            L_layers=L_layers,
            d_ff=d_model * 4,
            H_cycles=H_cycles,
            L_cycles=L_cycles,
            halt_max_steps=halt_max_steps,
            injection_mode=injection_mode,
            injection_kwargs=injection_kwargs,
        ).to(device)

        n_params = sum(p.numel() for p in model.parameters())
        print(f"Model parameters: {n_params:,} ({n_params/1e6:.2f}M)")
        print(f"Injection mode: {injection_mode}")
        if injection_kwargs:
            print(f"Injection kwargs: {injection_kwargs}")

        # Optimizer and scheduler
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
        total_steps = epochs * len(train_loader)
        scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)

        # Training loop
        for epoch in range(1, epochs + 1):
            model.train()
            epoch_loss = 0
            epoch_batches = 0    # batches that produced an optimizer step
            skipped_batches = 0  # batches dropped (nothing to score, or non-finite loss)
            epoch_cells = 0
            epoch_total = 0

            pbar = tqdm(train_loader, desc=f"Epoch {epoch}")
            for batch in pbar:
                inputs = batch['input'].to(device)
                targets = batch['label'].to(device)
                puzzle_ids = batch['puzzle_id'].to(device)

                logits, _, _, _ = model(inputs, puzzle_ids)
                # Only positions whose input token is 0 are trained on.
                mask = (inputs == 0)

                if not mask.any():
                    skipped_batches += 1
                    continue

                loss = F.cross_entropy(logits[mask], targets[mask])

                # Best-effort: drop a diverged batch rather than poison the
                # weights, but count it so the drop is visible in the logs.
                if torch.isnan(loss) or torch.isinf(loss):
                    skipped_batches += 1
                    continue

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()

                with torch.no_grad():
                    preds = logits.argmax(dim=-1)
                    correct = ((preds == targets) & mask).sum().item()
                    epoch_cells += correct
                    epoch_total += mask.sum().item()

                epoch_loss += loss.item()
                epoch_batches += 1
                cell_acc = 100 * epoch_cells / max(1, epoch_total)
                pbar.set_postfix({'loss': f'{loss.item():.3f}', 'cell': f'{cell_acc:.1f}%'})

            # Evaluate
            val_metrics = evaluate(model, val_loader, device)

            # Log to W&B
            if use_wandb:
                wandb.log({
                    'epoch': epoch,
                    # Average over the batches that contributed a loss, so
                    # skipped batches do not drag the mean towards zero.
                    'train/loss': epoch_loss / max(1, epoch_batches),
                    'train/cell_acc': 100 * epoch_cells / max(1, epoch_total),
                    'train/skipped_batches': skipped_batches,
                    'val/loss': val_metrics['loss'],
                    'val/cell_acc': val_metrics['cell_acc'],
                    'val/grid_acc': val_metrics['grid_acc'],
                    'lr': scheduler.get_last_lr()[0],
                })

            # Print progress every 10 epochs
            if epoch % 10 == 0:
                print(f"Epoch {epoch}: val_cell={val_metrics['cell_acc']:.2f}%, val_grid={val_metrics['grid_acc']:.2f}%")

            # Save best model (strict improvement only)
            if val_metrics['grid_acc'] > best_grid_acc:
                best_grid_acc = val_metrics['grid_acc']
                if use_wandb:
                    torch.save(model.state_dict(), f'/content/{run_name}_best.pt')
                    wandb.save(f'/content/{run_name}_best.pt')

        if use_wandb:
            wandb.log({'best_grid_acc': best_grid_acc})
    finally:
        # Always close the run, even if loading or training raised, so the
        # next experiment in a sweep does not inherit a dangling W&B run.
        if use_wandb:
            wandb.finish()

    return {
        'injection_mode': injection_mode,
        'injection_kwargs': injection_kwargs,
        'best_grid_acc': best_grid_acc,
        'final_cell_acc': val_metrics['cell_acc'],
        'final_grid_acc': val_metrics['grid_acc'],
    }

# Marker printed once the definitions in this cell have been executed
# (the check mark was stored as mis-decoded bytes).
print("✓ Training infrastructure ready")


## 4. Experiment 1: Injection Mode Comparison

Compare all 6 injection modes with identical hyperparameters.


In [None]:
# Experiment 1 grid: one (mode, constructor kwargs) pair per run.
# A kwargs entry of None means the mode runs with its defaults.
INJECTION_MODES = [
    # Baseline: routing only
    ('none', None),
    # Gated broadcast
    ('broadcast', None),
    # FiLM modulation
    ('film', None),
    # Depth token
    ('depth_token', None),
    # Cross-attention
    ('cross_attn', {'memory_size': 16, 'n_heads': 4}),
    # Alpha-gated (mean)
    ('alpha_gated', {'alpha_aggregation': 'mean'}),
]

n_modes = len(INJECTION_MODES)
print(f"Will run {n_modes} experiments:")
for mode, kwargs in INJECTION_MODES:
    shown = kwargs or 'default'
    print(f"  - {mode}: {shown}")


In [None]:
# Run all injection mode experiments
# Every mode is trained with the same hyperparameters; only injection_mode
# and injection_kwargs vary between runs.
results_exp1 = []

for injection_mode, injection_kwargs in INJECTION_MODES:
    print(f"\n{'='*60}")
    print(f"Running: {injection_mode}")
    print(f"{'='*60}")

    result = train_model(
        injection_mode=injection_mode,
        injection_kwargs=injection_kwargs,
        epochs=200,
        batch_size=768,
        lr=3e-4,
        warmup_steps=2000,
        d_model=512,
        n_heads=8,
        H_cycles=2,
        L_cycles=6,
        halt_max_steps=4,
        use_wandb=True,
        project='feature-injection-ablation',
    )
    results_exp1.append(result)
    # The check mark was stored as mis-decoded bytes; emit the intended character.
    print(f"\n✓ {injection_mode}: grid_acc={result['best_grid_acc']:.2f}%")


In [None]:
# Tabulate Experiment 1, best validation grid accuracy first.
import pandas as pd

df1 = pd.DataFrame(results_exp1).sort_values('best_grid_acc', ascending=False)

rule = "=" * 60
print(f"\n{rule}")
print("EXPERIMENT 1: Injection Mode Comparison")
print(rule)

shown_columns = ['injection_mode', 'best_grid_acc', 'final_cell_acc']
table_text = df1[shown_columns].to_string(index=False)
print(table_text)


## 5. Experiment 2: Alpha Aggregation Deep Dive

Compare different alpha aggregation strategies for `alpha_gated` mode.


In [None]:
# Experiment 2 grid for `alpha_gated`: (aggregation, learned gate on/off).
# All three aggregations run with the learned gate; mean and entropy are
# additionally run without it.
ALPHA_AGGREGATIONS = [
    {'alpha_aggregation': aggregation, 'use_learned_gate': learned_gate}
    for aggregation, learned_gate in (
        ('mean', True),      # Mean + learned gate
        ('max', True),       # Max + learned gate
        ('entropy', True),   # Entropy + learned gate
        ('mean', False),     # Mean only (no learned gate)
        ('entropy', False),  # Entropy only (no learned gate)
    )
]

n_variants = len(ALPHA_AGGREGATIONS)
print(f"Will run {n_variants} alpha-gated experiments:")
for kwargs in ALPHA_AGGREGATIONS:
    print(f"  - {kwargs}")


In [None]:
# Run all alpha aggregation experiments
# The injection mode is fixed to 'alpha_gated'; only injection_kwargs varies
# between runs, with all other hyperparameters held constant.
results_exp2 = []

for injection_kwargs in ALPHA_AGGREGATIONS:
    print(f"\n{'='*60}")
    print(f"Running alpha_gated with: {injection_kwargs}")
    print(f"{'='*60}")

    result = train_model(
        injection_mode='alpha_gated',
        injection_kwargs=injection_kwargs,
        epochs=200,
        batch_size=768,
        lr=3e-4,
        warmup_steps=2000,
        d_model=512,
        n_heads=8,
        H_cycles=2,
        L_cycles=6,
        halt_max_steps=4,
        use_wandb=True,
        project='feature-injection-ablation',
    )
    results_exp2.append(result)
    # The check mark was stored as mis-decoded bytes; emit the intended character.
    print(f"\n✓ {injection_kwargs}: grid_acc={result['best_grid_acc']:.2f}%")


In [None]:
# Tabulate Experiment 2, best validation grid accuracy first.
df2 = pd.DataFrame(results_exp2).sort_values('best_grid_acc', ascending=False)

rule = "=" * 60
print(f"\n{rule}")
print("EXPERIMENT 2: Alpha Aggregation Comparison")
print(rule)

shown_columns = ['injection_kwargs', 'best_grid_acc', 'final_cell_acc']
table_text = df2[shown_columns].to_string(index=False)
print(table_text)


## 6. Summary & Conclusions


In [None]:
# Final summary
# The emoji and check-mark prefixes below were stored as mis-decoded bytes;
# they are restored to the intended characters.
print("\n" + "="*70)
print("FEATURE INJECTION ABLATION STUDY - FINAL RESULTS")
print("="*70)

print("\n📊 EXPERIMENT 1: Injection Mode Comparison")
print("-"*50)
for _, row in df1.iterrows():
    print(f"  {row['injection_mode']:15} | Grid: {row['best_grid_acc']:5.2f}% | Cell: {row['final_cell_acc']:5.2f}%")

print("\n📊 EXPERIMENT 2: Alpha Aggregation Comparison")
print("-"*50)
for _, row in df2.iterrows():
    kwargs = row['injection_kwargs']
    # Missing keys fall back to 'mean' aggregation with the learned gate on.
    agg = kwargs.get('alpha_aggregation', 'mean')
    gate = 'gate' if kwargs.get('use_learned_gate', True) else 'no-gate'
    print(f"  {agg:8} + {gate:7} | Grid: {row['best_grid_acc']:5.2f}% | Cell: {row['final_cell_acc']:5.2f}%")

print("\n🏆 BEST OVERALL:")
# Winner across both experiments, ranked by best validation grid accuracy.
all_results = results_exp1 + results_exp2
best = max(all_results, key=lambda x: x['best_grid_acc'])
print(f"   Mode: {best['injection_mode']}")
print(f"   Kwargs: {best['injection_kwargs']}")
print(f"   Grid Accuracy: {best['best_grid_acc']:.2f}%")

print("\n✓ All experiments complete! Check W&B for detailed charts.")


## 7. Quick Single Experiment (Optional)

Run a single experiment with custom settings.


In [None]:
# Quick single experiment (modify as needed)
# result = train_model(
#     injection_mode='alpha_gated',
#     injection_kwargs={'alpha_aggregation': 'entropy', 'use_learned_gate': True},
#     epochs=50,  # Quick test
#     batch_size=128,
#     use_wandb=True,
#     run_name='alpha_gated_entropy_quick_test',
# )
# print(f"Result: {result}")
