# 🧠 HybridPoHHRM ARC-2 Benchmark

Train an **Abstract Reasoning AI** on the **ARC-AGI-2** dataset using **HybridPoHHRM**, which combines HRM's two-timescale reasoning with PoT head routing.

## Architecture
- **z_H, z_L**: Persistent hidden states (HRM-style)
- **L_level**: Fast reasoning (inner loop, 8 cycles)
- **H_level**: Slow reasoning (outer loop, 2 cycles)
- **PoT head routing**: Dynamic attention head selection in both levels
- **2D Positional Embeddings**: For spatial reasoning on 30x30 grids
- **On-the-fly Augmentation**: Color permutation, dihedral transforms, translation
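
The two-timescale schedule listed above can be sketched as a nested loop. This is a toy illustration, not the repository's implementation: `l_step` and `h_step` are hypothetical stand-ins for the L-level and H-level blocks, scalar values stand in for the `z_L` and `z_H` tensors, and the exact update order in the real model may differ.

```python
# Toy sketch of a two-timescale (HRM-style) reasoning schedule.
# l_step and h_step are hypothetical stand-ins for the real L_level / H_level
# modules; the scalar states stand in for the z_L / z_H tensors.

H_CYCLES = 2  # slow outer loop
L_CYCLES = 8  # fast inner loop


def l_step(z_l, z_h, x):
    """Fast update: refine z_L using the slow state and the input."""
    return 0.5 * z_l + 0.25 * z_h + 0.25 * x


def h_step(z_h, z_l):
    """Slow update: fold the refined fast state back into z_H."""
    return 0.5 * z_h + 0.5 * z_l


def two_timescale_forward(x, z_h=0.0, z_l=0.0):
    """Run H_CYCLES outer cycles, each containing L_CYCLES inner steps."""
    l_updates = h_updates = 0
    for _ in range(H_CYCLES):
        for _ in range(L_CYCLES):
            z_l = l_step(z_l, z_h, x)
            l_updates += 1
        z_h = h_step(z_h, z_l)
        h_updates += 1
    return z_h, z_l, l_updates, h_updates


z_h, z_l, l_updates, h_updates = two_timescale_forward(x=1.0)
print(f"L updates: {l_updates}, H updates: {h_updates}")
```

With 2 outer and 8 inner cycles, one pass performs 16 fast updates and 2 slow updates, which is why the L-level dominates compute.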

## ARC-AGI-2 Dataset
| Split / Reference | Value | Description |
|-------|---------|-------------|
| Training | 1,000 | Public training tasks |
| Evaluation | 120 | Public evaluation tasks |
| Human Performance | 66% | Average human accuracy |

## Hardware Requirements
- **GPU**: H100/A100 recommended (40-80GB VRAM)
- **Memory**: Each ARC grid is a 900-token sequence (30x30), versus 81 tokens for Sudoku (9x9)
- **Runtime**: ~24-48 hours for full training (20K epochs)
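
To make the memory comparison concrete, the arithmetic below compares sequence lengths and the size of a dense self-attention score matrix, which grows quadratically with sequence length. It assumes standard full attention; actual memory use also depends on batch size, layer count, and the number of reasoning cycles.

```python
# Sequence-length arithmetic for flattened grids.
# Assumes dense self-attention, where the score matrix has seq_len**2 entries.

arc_tokens = 30 * 30    # 900 tokens per ARC grid
sudoku_tokens = 9 * 9   # 81 tokens per Sudoku grid

length_ratio = arc_tokens / sudoku_tokens
attention_ratio = (arc_tokens ** 2) / (sudoku_tokens ** 2)

print(f"Sequence length ratio: {length_ratio:.1f}x")      # 11.1x
print(f"Attention matrix ratio: {attention_ratio:.1f}x")  # 123.5x
```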


## 1. Setup


In [None]:
# Clone PoT repository
!git clone https://github.com/Eran-BA/PoT.git /content/PoT 2>/dev/null || (cd /content/PoT && git pull)
%cd /content/PoT

# Install dependencies
!pip install -q torch torchvision torchaudio
!pip install -q tqdm numpy


In [None]:
# Verify GPU
import torch
print(f"GPU Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")


## 2. Download ARC-AGI-2 Dataset

Downloads from [arcprize/ARC-AGI-2](https://github.com/arcprize/ARC-AGI-2):
- **Training**: ~1,000 puzzles with train examples
- **Evaluation**: ~120 puzzles with test examples
- **On-the-fly augmentation** is applied during training rather than stored on disk (as in the Sudoku benchmark):
  - Color permutation (1-9 shuffled, 0 fixed)
  - Dihedral transforms (8 rotations/flips)
  - Translational augmentation (random position in 30x30)
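
The three augmentations above can be sketched with plain Python lists. This is a minimal illustration under stated assumptions, not the code in `src/data/arc`: the function names are hypothetical, grids hold raw colors 0-9, and 0 doubles as the background when placing a grid on the 30x30 canvas.

```python
import random

GRID = 30  # canvas size; 0 is treated as the fixed background color


def permute_colors(grid, rng):
    """Shuffle colors 1-9 while keeping 0 fixed."""
    shuffled = list(range(1, 10))
    rng.shuffle(shuffled)
    mapping = [0] + shuffled  # mapping[c] is the new color for c
    return [[mapping[c] for c in row] for row in grid]


def dihedral(grid, k):
    """Apply one of 8 dihedral transforms: k % 4 quarter-turns, flipped first if k >= 4."""
    out = [row[:] for row in grid]
    if k >= 4:
        out = [row[::-1] for row in out]  # horizontal flip
    for _ in range(k % 4):
        out = [list(row) for row in zip(*out[::-1])]  # rotate 90 degrees clockwise
    return out


def translate(grid, rng):
    """Place the grid at a random offset inside a GRID x GRID canvas of zeros."""
    h, w = len(grid), len(grid[0])
    top = rng.randint(0, GRID - h)
    left = rng.randint(0, GRID - w)
    canvas = [[0] * GRID for _ in range(GRID)]
    for r, row in enumerate(grid):
        canvas[top + r][left:left + w] = row
    return canvas


def augment(grid, rng):
    """Compose color permutation, a random dihedral transform, and translation."""
    return translate(dihedral(permute_colors(grid, rng), rng.randrange(8)), rng)


rng = random.Random(0)
small = [[1, 2, 0], [0, 3, 4]]
augmented = augment(small, rng)
print(len(augmented), len(augmented[0]))
```

This sketch transforms a single grid. For a training pair, the same color mapping, dihedral index, and offset would need to be applied to both the input and the label so the puzzle's rule is preserved.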


In [None]:
# Download and build ARC-2 dataset (minimal, uses on-the-fly augmentation like Sudoku)
!python scripts/build_arc_dataset.py \
    --version arc-2 \
    --num-aug 0 \
    --output-dir data/arc-2


## 3. Train HybridPoHHRM ARC Solver

### Configuration (HRM-Aligned)
```yaml
# Model
d_model: 512
n_heads: 8
H_cycles: 2               # Slow outer loop
L_cycles: 8               # Fast inner loop
H_layers: 2
L_layers: 2
halt_max_steps: 4         # ACT outer steps

# Training (HRM-aligned)
lr: 1.0e-4
weight_decay: 1.0
betas: [0.9, 0.95]        # Llama-style
warmup_steps: 2000
lr_min_ratio: 0.1         # Cosine decay floor
batch_size: 32            # Smaller than Sudoku (larger seq_len)

# On-the-fly augmentation (like Sudoku):
#   Color permutation: shuffle colors 1-9, keep 0 fixed
#   Dihedral transforms: 8 rotations/flips
#   Translation: random position in 30x30 grid
```
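
The warmup and cosine-floor settings above imply a learning-rate curve like the one sketched here. This is an assumption-laden illustration: the configuration only names `warmup_steps` and `lr_min_ratio`, so the linear warmup shape, the function name `lr_at`, and the value of `TOTAL_STEPS` are hypothetical.

```python
import math

BASE_LR = 1e-4
WARMUP_STEPS = 2000
LR_MIN_RATIO = 0.1


def lr_at(step, total_steps):
    """Linear warmup to BASE_LR, then cosine decay to BASE_LR * LR_MIN_RATIO."""
    if step < WARMUP_STEPS:
        return BASE_LR * step / WARMUP_STEPS
    progress = (step - WARMUP_STEPS) / max(1, total_steps - WARMUP_STEPS)
    progress = min(1.0, progress)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return BASE_LR * (LR_MIN_RATIO + (1.0 - LR_MIN_RATIO) * cosine)


TOTAL_STEPS = 100_000  # hypothetical; the real value depends on dataset size and epochs
for step in (0, 1000, 2000, 51_000, 100_000):
    print(step, f"{lr_at(step, TOTAL_STEPS):.2e}")
```

Under these assumptions the rate peaks at 1e-4 when warmup ends and never drops below 1e-5, one tenth of the peak.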


In [None]:
# Full training (H100/A100) - HRM-aligned configuration
!python experiments/arc_poh_benchmark.py \
    --data-dir data/arc-2 \
    --model hybrid \
    --hrm-grad-style \
    --halt-max-steps 4 \
    --async-batch \
    --lr 1e-4 \
    --batch-size 32 \
    --weight-decay 1.0 \
    --puzzle-weight-decay 1.0 \
    --puzzle-lr-multiplier 1.0 \
    --puzzle-optimizer signsgd \
    --beta2 0.95 \
    --warmup-steps 2000 \
    --lr-min-ratio 0.1 \
    --dropout 0.0 \
    --epochs 20000 \
    --eval-interval 100


## 4. Quick Test (~2-4 hours)

For a quick sanity check on T4/V100, run with fewer epochs, a smaller batch size, and a shorter warmup:


In [None]:
# Quick test (~2-4 hours on T4/V100)
!python experiments/arc_poh_benchmark.py \
    --data-dir data/arc-2 \
    --model hybrid \
    --hrm-grad-style \
    --halt-max-steps 4 \
    --lr 1e-4 \
    --batch-size 16 \
    --weight-decay 1.0 \
    --puzzle-optimizer signsgd \
    --beta2 0.95 \
    --warmup-steps 500 \
    --epochs 2000 \
    --eval-interval 100


## 5. Evaluate Model


In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
import sys
sys.path.insert(0, '.')

from src.data.arc import ARCDataset
from src.pot.models.arc_solver import HybridPoHARCSolver

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load test set (use same data dir as training)
test_dataset = ARCDataset('data/arc-2', 'test', augment=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)
print(f"Test set: {len(test_dataset)} examples")

# Load best model (match training config!)
model = HybridPoHARCSolver(
    d_model=512, n_heads=8,
    H_layers=2, L_layers=2, d_ff=2048,
    dropout=0.0,
    H_cycles=2, L_cycles=8, T=4,
    hrm_grad_style=True,
    halt_max_steps=4,
    num_puzzles=1,
).to(device)

checkpoint = torch.load('experiments/results/arc_poh/hybrid_best.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Loaded checkpoint from epoch {checkpoint['epoch']}")
print(f"Best test grid accuracy: {checkpoint['test_grid_acc']:.2f}%")


In [None]:
# Evaluate on test set
@torch.no_grad()
def evaluate_arc(model, dataloader, device):
    model.eval()
    total_loss = 0
    correct_cells = 0
    total_cells = 0
    correct_grids = 0
    total_grids = 0
    pad_id = 0
    
    for batch in tqdm(dataloader, desc="Evaluating"):
        inp = batch['input'].to(device)
        label = batch['label'].to(device)
        puzzle_ids = batch['puzzle_id'].to(device)
        
        logits = model(inp, puzzle_ids)[0]
        
        loss = F.cross_entropy(
            logits.view(-1, model.vocab_size),
            label.view(-1),
            ignore_index=pad_id,
        )
        
        total_loss += loss.item()
        preds = logits.argmax(dim=-1)
        mask = label != pad_id
        correct_cells += ((preds == label) & mask).sum().item()
        total_cells += mask.sum().item()
        
        grid_correct = ((preds == label) | ~mask).all(dim=1)
        correct_grids += grid_correct.sum().item()
        total_grids += inp.size(0)
    
    return {
        'loss': total_loss / len(dataloader),
        'cell_acc': 100 * correct_cells / max(1, total_cells),
        'grid_acc': 100 * correct_grids / max(1, total_grids),
    }

print("\nEvaluating on ARC-2 test set...")
test_metrics = evaluate_arc(model, test_loader, device)
print(f"\n{'='*50}")
print(f"FINAL TEST RESULTS")
print(f"{'='*50}")
print(f"  Loss: {test_metrics['loss']:.4f}")
print(f"  Cell Accuracy: {test_metrics['cell_acc']:.2f}%")
print(f"  Grid Accuracy: {test_metrics['grid_acc']:.2f}%")
print(f"\n  Human Performance: ~66%")


## 6. Visualize Predictions


In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np

# ARC color palette
ARC_COLORS = [
    '#000000',  # 0: black
    '#0074D9',  # 1: blue
    '#FF4136',  # 2: red
    '#2ECC40',  # 3: green
    '#FFDC00',  # 4: yellow
    '#AAAAAA',  # 5: grey
    '#F012BE',  # 6: magenta
    '#FF851B',  # 7: orange
    '#7FDBFF',  # 8: cyan
    '#870C25',  # 9: brown
]
arc_cmap = mcolors.ListedColormap(ARC_COLORS)

def plot_arc_grid(grid, title="", ax=None):
    """Plot a single ARC grid with proper colors."""
    if ax is None:
        fig, ax = plt.subplots(figsize=(4, 4))
    
    # Convert from token space (2-11) to color space (0-9)
    grid_colors = np.clip(grid - 2, 0, 9)
    mask = grid < 2
    grid_colors[mask] = 0
    
    # Find content bounds
    content = grid >= 2
    if content.any():
        rows = np.any(content, axis=1)
        cols = np.any(content, axis=0)
        r_min, r_max = np.where(rows)[0][[0, -1]]
        c_min, c_max = np.where(cols)[0][[0, -1]]
        grid_colors = grid_colors[r_min:r_max+1, c_min:c_max+1]
    
    ax.imshow(grid_colors, cmap=arc_cmap, vmin=0, vmax=9)
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])
    for i in range(grid_colors.shape[0] + 1):
        ax.axhline(i - 0.5, color='gray', linewidth=0.5)
    for j in range(grid_colors.shape[1] + 1):
        ax.axvline(j - 0.5, color='gray', linewidth=0.5)
    return ax

# Visualize some predictions
model.eval()
batch = next(iter(test_loader))
inp = batch['input'][:4].to(device)
label = batch['label'][:4].to(device)
puzzle_ids = batch['puzzle_id'][:4].to(device)

with torch.no_grad():
    preds = model(inp, puzzle_ids)[0].argmax(dim=-1)

fig, axes = plt.subplots(4, 3, figsize=(12, 16))
for i in range(4):
    inp_grid = inp[i].cpu().numpy().reshape(30, 30)
    label_grid = label[i].cpu().numpy().reshape(30, 30)
    pred_grid = preds[i].cpu().numpy().reshape(30, 30)
    
    mask = label_grid != 0
    is_correct = ((pred_grid == label_grid) | ~mask).all()
    
    plot_arc_grid(inp_grid, "Input", axes[i, 0])
    plot_arc_grid(label_grid, "Ground Truth", axes[i, 1])
    plot_arc_grid(pred_grid, f"Prediction {'✓' if is_correct else '✗'}", axes[i, 2])

plt.tight_layout()
plt.show()


## References

- ARC-AGI-2 Repository: https://github.com/arcprize/ARC-AGI-2
- ARC Prize: https://arcprize.org/
- HRM Paper: https://arxiv.org/abs/2506.21734
- HRM GitHub: https://github.com/sapientinc/HRM
- PoT GitHub: https://github.com/Eran-BA/PoT
