# 🧩 HybridPoHHRM Sudoku Benchmark

Train a **master-level Sudoku AI** using **HybridPoHHRM**, which combines HRM's two-timescale reasoning with PoT head routing.

## Architecture
- **z_H, z_L**: Persistent hidden states (HRM-style)
- **L_level**: Fast reasoning (inner loop, 8 cycles)
- **H_level**: Slow reasoning (outer loop, 2 cycles)
- **PoT head routing**: Dynamic attention head selection in both levels
- **ACT**: Adaptive Computation Time with Q-learning halting
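- **Async batching**: HRM-style sample replacement for GPU efficiency (samples that halt are swapped for new puzzles while the rest of the batch keeps its carried state)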
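The two-timescale loop and ACT halting listed above can be sketched without any framework. This is a minimal illustration under stated assumptions, not the repository's implementation: `hrm_segment`, `act_solve`, and the toy update functions are hypothetical names, simple averaging stands in for the transformer blocks with PoH head routing, and HRM's one-step gradient approximation (`--hrm-grad-style`) is omitted. The halting rule (stop when Q(halt) > Q(continue) or after `halt_max_steps` segments) follows the rule described in the HRM paper, which additionally samples a random minimum number of segments for exploration (not shown here).

```python
from typing import Callable, List, Optional, Tuple

Vec = List[float]
LFn = Callable[[Vec, Vec, Vec], Vec]  # f_L(z_L, z_H, x) -> new z_L
HFn = Callable[[Vec, Vec], Vec]       # f_H(z_H, z_L) -> new z_H


def hrm_segment(z_H: Vec, z_L: Vec, x: Vec, f_L: LFn, f_H: HFn,
                H_cycles: int = 2, L_cycles: int = 8,
                trace: Optional[List[str]] = None) -> Tuple[Vec, Vec]:
    """One segment: L_cycles fast L-updates before each slow H-update."""
    for _ in range(H_cycles):
        for _ in range(L_cycles):
            z_L = f_L(z_L, z_H, x)  # fast state sees slow state + input
            if trace is not None:
                trace.append("L")
        z_H = f_H(z_H, z_L)  # slow state reads the fast state
        if trace is not None:
            trace.append("H")
    return z_H, z_L


def act_solve(x: Vec, z_H: Vec, z_L: Vec, f_L: LFn, f_H: HFn,
              q_head: Callable[[Vec], Tuple[float, float]],
              halt_max_steps: int = 4,
              trace: Optional[List[str]] = None) -> Tuple[Vec, Vec, int]:
    """ACT outer loop: z_H and z_L persist across segments. Stop when the
    (hypothetical) Q-head prefers halting or the step budget runs out."""
    steps = 0
    while steps < halt_max_steps:
        z_H, z_L = hrm_segment(z_H, z_L, x, f_L, f_H, trace=trace)
        steps += 1
        q_halt, q_continue = q_head(z_H)
        if q_halt > q_continue:
            break
    return z_H, z_L, steps


# Toy stand-ins (hypothetical): averaging instead of transformer blocks.
def toy_f_L(z_L: Vec, z_H: Vec, x: Vec) -> Vec:
    return [0.5 * (a + b + c) for a, b, c in zip(z_L, z_H, x)]


def toy_f_H(z_H: Vec, z_L: Vec) -> Vec:
    return [0.5 * (a + b) for a, b in zip(z_H, z_L)]


trace: List[str] = []
z_H, z_L, steps = act_solve([1.0, 2.0], [0.0, 0.0], [0.0, 0.0],
                            toy_f_L, toy_f_H,
                            q_head=lambda z: (0.0, 1.0),  # never halts early
                            trace=trace)
print(steps, trace.count("L"), trace.count("H"))  # 4 64 8
```

With a Q-head that never prefers halting, the sketch runs all 4 segments: 4 × 2 = 8 H-updates and 4 × 2 × 8 = 64 L-updates, which is why the L-level is the "fast" one.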

## Expected Results
| Model | Parameters | Grid Accuracy | Note |
|-------|------------|---------------|------|
| HRM (paper) | 27M | ~55% | On 1K training puzzles |
| **HybridPoHHRM** | ~26M | TBD | Matching HRM config |

## Hardware Requirements
- **GPU**: H100/A100 recommended (40-80GB VRAM)
- **Runtime**: ~12-24 hours for full training (20K epochs)


## 1. Setup


```python
# Clone PoT repository (or update it if already cloned)
!git clone https://github.com/Eran-BA/PoT.git /content/PoT 2>/dev/null || (cd /content/PoT && git pull)
%cd /content/PoT

# Install dependencies
!pip install -q torch torchvision torchaudio
!pip install -q tqdm numpy huggingface_hub
```


```python
# Verify GPU
import torch
print(f"GPU Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
```


## 2. Download Sudoku Dataset

Downloads the **Sudoku-Extreme** dataset from Hugging Face (`sapientinc/sudoku-extreme`):
- 1000 extreme-difficulty puzzles
- 1000 augmentations per puzzle (validity-preserving transforms)
- Total: ~1,000,000 training samples
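Sudoku constraints are symmetric under several grid operations, so each puzzle can be expanded into many distinct but equally valid samples. The sketch below shows standard validity-preserving transforms (digit relabeling, band/stack permutations, row/column shuffles within a band/stack, and transposition) and checks that the result is still a valid Sudoku. It is a generic illustration: `augment` and `is_valid_solution` are hypothetical helpers, and the exact transforms used to build the downloaded dataset may differ.

```python
import random
from typing import List, Tuple

Grid = List[List[int]]  # 9x9 grid, 0 marks a blank cell


def is_valid_solution(grid: Grid) -> bool:
    """True if every row, column, and 3x3 box holds the digits 1..9 exactly once."""
    digits = set(range(1, 10))
    rows_ok = all(set(row) == digits for row in grid)
    cols_ok = all({grid[r][c] for r in range(9)} == digits for c in range(9))
    boxes_ok = all(
        {grid[br + i][bc + j] for i in range(3) for j in range(3)} == digits
        for br in (0, 3, 6) for bc in (0, 3, 6)
    )
    return rows_ok and cols_ok and boxes_ok


def augment(puzzle: Grid, solution: Grid, rng: random.Random) -> Tuple[Grid, Grid]:
    """Apply one random validity-preserving transform to a puzzle/solution pair."""
    # Relabel digits 1..9 with a random bijection; blanks (0) stay blank.
    relabel = [0] + rng.sample(range(1, 10), 9)
    # Permute the three bands, then the three rows inside each band.
    band_order = rng.sample(range(3), 3)
    row_order = [3 * b + r for b in band_order for r in rng.sample(range(3), 3)]
    # Same for stacks (groups of three columns).
    stack_order = rng.sample(range(3), 3)
    col_order = [3 * s + c for s in stack_order for c in rng.sample(range(3), 3)]
    transpose = rng.random() < 0.5

    def apply(g: Grid) -> Grid:
        out = [[relabel[g[r][c]] for c in col_order] for r in row_order]
        return [list(col) for col in zip(*out)] if transpose else out

    # The same transform must be applied to puzzle and solution.
    return apply(puzzle), apply(solution)


# A known-valid base grid and a puzzle with half of its cells blanked.
solution = [[(3 * (r % 3) + r // 3 + c) % 9 + 1 for c in range(9)] for r in range(9)]
puzzle = [[v if (r + c) % 2 == 0 else 0 for c, v in enumerate(row)]
          for r, row in enumerate(solution)]
aug_puzzle, aug_solution = augment(puzzle, solution, random.Random(0))
print(is_valid_solution(solution), is_valid_solution(aug_solution))  # True True
```

Row shuffles stay inside a band and column shuffles inside a stack, so every 3×3 box maps onto another 3×3 box; that is what keeps the box constraint intact.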


```python
# Optional: Pre-download dataset (--download in training command also works)
!python experiments/sudoku_poh_benchmark.py --download --epochs 0
```


## 3. Train HybridPoHHRM Sudoku Solver

### HRM-Aligned Configuration
```yaml
# Model
d_model: 512
n_heads: 8
H_cycles: 2          # Two-timescale reasoning: slow outer loop
L_cycles: 8          # Fast inner loop
H_layers: 2
L_layers: 2
halt_max_steps: 4    # ACT outer steps

# Training (HRM-aligned)
lr: 1.0e-4
weight_decay: 1.0
betas: [0.9, 0.95]   # Llama-style
warmup_steps: 2000
lr_min_ratio: 0.1    # Cosine decay floor
batch_size: 768      # Large batch for H100

# Optimizers
optimizer: AdamW
puzzle_optimizer: SignSGD  # HRM uses SignSGD for puzzle embeddings
```
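The `warmup_steps` and `lr_min_ratio` settings describe a linear warmup followed by cosine decay to a floor of `lr_min_ratio * lr`. A minimal sketch of such a schedule, assuming a hypothetical `total_steps` (the real value depends on the number of epochs, the dataset size, and the batch size); `lr_at_step` is an illustrative name, not a function from the repository:

```python
import math


def lr_at_step(step: int, base_lr: float = 1e-4, warmup_steps: int = 2000,
               total_steps: int = 100_000, lr_min_ratio: float = 0.1) -> float:
    """Linear warmup to base_lr, then cosine decay down to lr_min_ratio * base_lr.

    total_steps is an assumed placeholder, not a value from the benchmark.
    """
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)  # hold the floor after total_steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))  # 1 -> 0
    return base_lr * (lr_min_ratio + (1.0 - lr_min_ratio) * cosine)


for s in (0, 1000, 2000, 51_000, 100_000):
    print(s, f"{lr_at_step(s):.2e}")
```

To drive a PyTorch optimizer with it, `torch.optim.lr_scheduler.LambdaLR` expects a multiplier of the optimizer's initial learning rate, e.g. `LambdaLR(optimizer, lambda s: lr_at_step(s) / 1e-4)`.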


```python
# Full training (H100/A100) - HRM-aligned configuration
!python experiments/sudoku_poh_benchmark.py \
    --download \
    --model hybrid \
    --hrm-grad-style \
    --halt-max-steps 4 \
    --async-batch \
    --lr 1e-4 \
    --batch-size 768 \
    --weight-decay 1.0 \
    --puzzle-weight-decay 1.0 \
    --puzzle-lr-multiplier 1.0 \
    --puzzle-optimizer signsgd \
    --beta2 0.95 \
    --warmup-steps 2000 \
    --lr-min-ratio 0.1 \
    --constraint-weight 0 \
    --dropout 0.0
```
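The `--puzzle-optimizer signsgd` and `--puzzle-weight-decay` flags select a sign-based update for the puzzle embeddings instead of AdamW. Below is a minimal sketch of SignSGD with decoupled (AdamW-style) weight decay, assuming the same form HRM's reference code applies to its puzzle embeddings; the repository's optimizer class may differ, and `signsgd_step` is a hypothetical helper operating on plain Python lists.

```python
from typing import List


def signsgd_step(params: List[float], grads: List[float], lr: float,
                 weight_decay: float) -> List[float]:
    """SignSGD with decoupled weight decay:
    p <- p * (1 - lr * weight_decay) - lr * sign(g)
    """
    def sign(g: float) -> int:
        return (g > 0) - (g < 0)  # +1, 0, or -1

    return [p * (1.0 - lr * weight_decay) - lr * sign(g)
            for p, g in zip(params, grads)]


# A huge gradient (-100.0) and a small one (0.3) produce steps of the same size.
new_params = signsgd_step([1.0, -2.0, 0.5], [0.3, -100.0, 0.0],
                          lr=0.1, weight_decay=1.0)
print(new_params)
```

Only the sign of each gradient component is used, so every embedding entry moves by exactly `lr` per update (plus decay), independent of gradient scale.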


## 4. Quick Test (~1-2 hours)

For a quick sanity check, run with fewer epochs (2000) and a smaller batch size (128):


```python
# Quick test (~1-2 hours on T4/V100)
!python experiments/sudoku_poh_benchmark.py \
    --download \
    --model hybrid \
    --hrm-grad-style \
    --halt-max-steps 4 \
    --lr 1e-4 \
    --batch-size 128 \
    --weight-decay 1.0 \
    --puzzle-optimizer signsgd \
    --beta2 0.95 \
    --warmup-steps 500 \
    --constraint-weight 0 \
    --epochs 2000 \
    --eval-interval 100
```


## 5. View Results


```python
# Evaluate best model on full test set (422k puzzles)
import torch
from torch.utils.data import DataLoader
from src.data import SudokuDataset
from src.pot.models import HybridPoHHRMSolver
from src.training import evaluate

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load test set
test_dataset = SudokuDataset('data/sudoku-extreme-10k-aug-100', 'test')
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4)
print(f"Test set: {len(test_dataset)} puzzles")

# Load best model (hyperparameters must match the training config!)
model = HybridPoHHRMSolver(
    d_model=512, n_heads=8,
    H_layers=2, L_layers=2, d_ff=2048,
    dropout=0.0,
    H_cycles=2, L_cycles=8, T=4,
    hrm_grad_style=True,
    halt_max_steps=4,
    num_puzzles=1,
).to(device)

# map_location lets a GPU-trained checkpoint load on any device.
# weights_only=False is needed on PyTorch >= 2.6 if the checkpoint stores
# non-tensor metadata; only use it with checkpoints you trust.
checkpoint = torch.load(
    'experiments/results/sudoku_poh/hybrid_best.pt',
    map_location=device,
    weights_only=False,
)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
print(f"Loaded checkpoint from epoch {checkpoint['epoch']}")
print(f"Grid accuracy at checkpoint: {checkpoint['test_grid_acc']:.2f}%")

# Evaluate
print(f"\nEvaluating on {len(test_dataset):,} test puzzles...")
test_metrics = evaluate(model, test_loader, device, use_poh=True)
print(f"\n{'='*50}")
print("FINAL TEST RESULTS")
print(f"{'='*50}")
print(f"  Cell Accuracy: {test_metrics['cell_acc']:.2f}%")
print(f"  Grid Accuracy: {test_metrics['grid_acc']:.2f}%")
```
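Grid accuracy is the stricter of the two metrics: a puzzle counts only if all 81 cells are correct, so a model can reach high cell accuracy while still solving few complete grids. A minimal sketch of both metrics, assuming flattened 81-digit predictions and labels; `sudoku_accuracy` is a hypothetical helper, and the repository's `evaluate` may differ in details (for instance, whether given cells count toward cell accuracy).

```python
from typing import Dict, List, Sequence


def sudoku_accuracy(preds: List[Sequence[int]],
                    labels: List[Sequence[int]]) -> Dict[str, float]:
    """Cell accuracy over all cells and grid (exact-match) accuracy, in percent."""
    cell_hits = cell_total = grid_hits = 0
    for pred, label in zip(preds, labels):
        matches = [p == y for p, y in zip(pred, label)]
        cell_hits += sum(matches)
        cell_total += len(matches)
        grid_hits += all(matches)  # a grid counts only if every cell matches
    return {
        "cell_acc": 100.0 * cell_hits / max(1, cell_total),
        "grid_acc": 100.0 * grid_hits / max(1, len(labels)),
    }


label = [1 + i % 9 for i in range(81)]
wrong = list(label)
wrong[0] = 2  # a single wrong cell fails the whole grid
metrics = sudoku_accuracy([label, wrong], [label, label])
print(f"{metrics['cell_acc']:.2f} {metrics['grid_acc']:.1f}")  # 99.38 50.0
```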

## References

- [HRM Paper](https://arxiv.org/abs/2506.21734): Hierarchical Reasoning Model
- [HRM GitHub](https://github.com/sapientinc/HRM): Official implementation  
- [PoT GitHub](https://github.com/Eran-BA/PoT): Pointer-over-Heads Transformer
