# 🎮 Sokoban Supervised Learning with PoT

This notebook trains a **Pondering over Thoughts (PoT)** model on Sokoban puzzles using **supervised learning**, with the same training pipeline as our Sudoku experiments.

## What is Sokoban?

**Sokoban** (倉庫番, "warehouse keeper") is a classic puzzle game where you push boxes onto target locations.

<img src="https://upload.wikimedia.org/wikipedia/commons/4/4b/Sokoban_ani.gif" width="300" alt="Sokoban gameplay animation">

**Rules:**
- 🧑 Player can move in 4 directions (up/down/left/right)
- 📦 Player can push ONE box at a time (not pull)
- 🎯 Goal: Push ALL boxes onto target squares
- ⚠️ Boxes can get stuck in corners (deadlock = game over!)
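
The movement and push rules above fit in a few lines. This is a minimal illustrative sketch, not the environment used by the dataset or `gym-sokoban`. It assumes positions are `(row, col)` tuples and walls and boxes are frozensets. The `MOVES` action names are hypothetical.

```python
from typing import FrozenSet, Tuple

Pos = Tuple[int, int]

# Hypothetical action names mapped to (d_row, d_col) offsets
MOVES = {"up": (-1, 0), "down": (1, 0), "left": (0, -1), "right": (0, 1)}


def step(walls: FrozenSet[Pos], boxes: FrozenSet[Pos], player: Pos, action: str):
    """Apply one move under the Sokoban rules; return (boxes, player, moved)."""
    dr, dc = MOVES[action]
    nxt = (player[0] + dr, player[1] + dc)
    if nxt in walls:
        return boxes, player, False  # walking into a wall does nothing
    if nxt in boxes:
        beyond = (nxt[0] + dr, nxt[1] + dc)
        if beyond in walls or beyond in boxes:
            return boxes, player, False  # a box can't be pushed into a wall or a second box
        boxes = (boxes - {nxt}) | {beyond}  # push exactly ONE box forward
    # Walking away from a box never moves it, so pulling is impossible by construction
    return boxes, nxt, True


def is_solved(boxes: FrozenSet[Pos], targets: FrozenSet[Pos]) -> bool:
    """Goal: every box sits on a target square."""
    return all(b in targets for b in boxes)
```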

▶️ **Watch gameplay:** [Sokoban Tutorial on YouTube](https://www.youtube.com/watch?v=4SjXQ_bHTxU)

**Why is it hard for AI?**
- PSPACE-complete (exponential state space)
- Sparse rewards (only get reward when solved)
- Long-horizon planning required
- Easy to create unsolvable states
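
Corner traps are the simplest kind of unsolvable state, and they can be detected locally. This minimal sketch assumes an illustrative representation, not the dataset's: `(row, col)` tuples, with walls and targets as frozensets. It only catches corner deadlocks. Frozen box pairs, or a box pushed along a wall that has no target, are missed.

```python
from typing import FrozenSet, Tuple

Pos = Tuple[int, int]


def is_corner_deadlock(box: Pos, walls: FrozenSet[Pos], targets: FrozenSet[Pos]) -> bool:
    """True if `box` is off-target and wedged against a vertical and a horizontal wall."""
    if box in targets:
        return False  # a cornered box that already sits on a target is fine
    r, c = box
    blocked_vertically = (r - 1, c) in walls or (r + 1, c) in walls
    blocked_horizontally = (r, c - 1) in walls or (r, c + 1) in walls
    # Every push out of a corner either moves the box into a wall or
    # needs the player to stand on a wall square, so the box can never move again.
    return blocked_vertically and blocked_horizontally
```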

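## Benchmark Comparison

Values are approximate single-step action accuracy (`action_acc`). Solve rates for the same methods appear in the final comparison table.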

| Method | Simple (6×6, 1 box) | Complex (10×10, 2 boxes) | Notes |
|--------|---------------------|--------------------------|-------|
| SFT (paper) | ~50% | ~15% | Supervised fine-tuning |
| GPT-4 + LangGraph | varies | varies | [Blog](https://blog.gopenai.com/using-llms-and-langgraph-to-tackle-sokoban-puzzles-5f50b43b9515) |
| RL (PPO) | ~20% | <5% | Very hard to train |
| Random | 25% | 25% | 4 actions |
| **PoT (this notebook)** | TBD | TBD | Adaptive depth |

## Dataset

We use the [Xiaofeng77/sokoban](https://huggingface.co/datasets/Xiaofeng77/sokoban) HuggingFace dataset with:
- ~3,000 (board, optimal_action) pairs
- On-the-fly augmentation (8x via rotations/flips)
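- Optional procedurally generated puzzles (`N_GENERATED`, `GEN_DIFFICULTY`)
- Cross-entropy loss plus the ACT Q-halt and Q-continue losses (identical to Sudoku)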
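
The 8× factor comes from the eight symmetries of a square: four rotations, each optionally mirrored. The subtle part is that the action label must be transformed together with the board. This minimal NumPy sketch assumes boards are `[H, W, C]` arrays. It also assumes a hypothetical action encoding `0=up, 1=down, 2=left, 3=right`, which the real dataset may not use. The dataset's own augmentation may be implemented differently.

```python
import numpy as np

# Hypothetical action encoding: index -> (d_row, d_col)
ACTIONS = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}  # up, down, left, right
DELTA_TO_ACTION = {d: a for a, d in ACTIONS.items()}


def augment(board: np.ndarray, action: int, k: int, flip: bool):
    """Apply one of the 8 dihedral symmetries to a [H, W, C] board and remap its action."""
    d_row, d_col = ACTIONS[action]
    if flip:
        board = board[:, ::-1]  # mirror left <-> right
        d_col = -d_col
    for _ in range(k):
        # np.rot90 (counter-clockwise) sends cell (r, c) to (W-1-c, r),
        # so a move (dr, dc) becomes (-dc, dr)
        board = np.rot90(board, k=1, axes=(0, 1))
        d_row, d_col = -d_col, d_row
    return np.ascontiguousarray(board), DELTA_TO_ACTION[(d_row, d_col)]


def all_augmentations(board: np.ndarray, action: int):
    """All 8 (board, action) variants: 4 rotations x {plain, mirrored}."""
    return [augment(board, action, k, flip) for flip in (False, True) for k in range(4)]
```

A useful sanity check is that, in every variant, the remapped action still moves the player onto the same tile it reached in the original board.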


In [None]:
# @title 🔧 Setup (Run First)
# @markdown Install dependencies and clone repository

!pip install -q torch datasets tqdm wandb gym-sokoban

# Clone PoT repository
!git clone -q https://github.com/ebenartzy/PoT.git 2>/dev/null || (cd PoT && git pull -q)

import sys
sys.path.insert(0, 'PoT')

print("✅ Setup complete!")


In [None]:
# @title 🔑 Weights & Biases Login (Optional)
# @markdown Enable USE_WANDB in Configuration to track experiments

import wandb

# Only needed if USE_WANDB is enabled in Configuration; skip this cell otherwise.
# wandb.login() prompts for an API key on first run.
wandb.login()
print("✅ Logged in to Weights & Biases!")


In [None]:
# @title 📊 Configuration
# @markdown ### Model Type
MODEL_TYPE = "hybrid_pot"  # @param ["pot", "hybrid_pot", "baseline"]
CONTROLLER_TYPE = "transformer"  # @param ["transformer", "gru", "lstm", "diffusion", "swin", "mamba"]

# @markdown ### Architecture
D_MODEL = 512  # @param {type:"slider", min:64, max:512, step:64}
D_FF = 1024  # @param {type:"slider", min:128, max:2048, step:128}
N_HEADS = 4  # @param {type:"slider", min:2, max:16, step:2}
N_LAYERS = 2  # @param {type:"slider", min:1, max:8, step:1}
DROPOUT = 0.0  # @param {type:"slider", min:0.0, max:0.5, step:0.1}

# @markdown ### PoT Iteration Parameters
R = 4  # @param {type:"slider", min:1, max:16, step:1}
T = 4  # @param {type:"slider", min:1, max:8, step:1}

# @markdown ### Hybrid PoT Parameters (H/L cycles)
H_LAYERS = 2  # @param {type:"slider", min:1, max:4, step:1}
L_LAYERS = 2  # @param {type:"slider", min:1, max:4, step:1}
H_CYCLES = 2  # @param {type:"slider", min:1, max:8, step:1}
L_CYCLES = 6  # @param {type:"slider", min:1, max:16, step:1}

# @markdown ### Controller Parameters
D_CTRL = 128  # @param {type:"slider", min:32, max:256, step:32}
MAX_DEPTH = 128  # @param {type:"slider", min:32, max:256, step:32}

# @markdown ### Feature Injection
INJECTION_MODE = "broadcast"  # @param ["none", "broadcast", "film", "depth_token", "cross_attn", "alpha_gated"]
INJECTION_MEMORY_SIZE = 8  # @param {type:"slider", min:4, max:32, step:4}
INJECTION_N_HEADS = 4  # @param {type:"slider", min:2, max:8, step:2}

# @markdown ### ACT (Adaptive Computation Time) Parameters
HALT_MAX_STEPS = 4  # @param {type:"slider", min:1, max:16, step:1}
HALT_EXPLORATION_PROB = 0.1  # @param {type:"slider", min:0.0, max:1.0, step:0.05}
ALLOW_EARLY_HALT_EVAL = True  # @param {type:"boolean"}

# @markdown ### Training Hyperparameters
EPOCHS = 100  # @param {type:"slider", min:10, max:500, step:10}
BATCH_SIZE = 64  # @param {type:"slider", min:16, max:256, step:16}
LEARNING_RATE = 3e-4  # @param {type:"number"}
WEIGHT_DECAY = 0.01  # @param {type:"number"}
GRAD_CLIP = 1.0  # @param {type:"slider", min:0.1, max:5.0, step:0.1}
WARMUP_STEPS = 100  # @param {type:"slider", min:0, max:1000, step:50}
LR_MIN_RATIO = 0.1  # @param {type:"slider", min:0.01, max:1.0, step:0.01}
BETA1 = 0.9  # @param {type:"number"}
BETA2 = 0.95  # @param {type:"number"}

# @markdown ### HRM Gradient Style
HRM_GRAD_STYLE = True  # @param {type:"boolean"}

# @markdown ### Data
AUGMENT = True  # @param {type:"boolean"}
SEED = 42  # @param {type:"integer"}
N_GENERATED = 1000  # @param {type:"slider", min:0, max:100000, step:1000}
GEN_DIFFICULTY = "simple"  # @param ["simple", "larger", "two_boxes", "complex"]

# @markdown ### Curriculum Learning
CURRICULUM = False  # @param {type:"boolean"}
CURRICULUM_WARMUP = 0.3  # @param {type:"slider", min:0.1, max:0.5, step:0.1}

# @markdown ### Size Generalization (Padding)
PAD_TO_SIZE = 10  # @param {type:"slider", min:0, max:16, step:2}
# 0 = no padding; N > 0 = wall-pad every board to NxN, so a model trained on small boards can be evaluated on larger ones

# @markdown ### Logging
USE_WANDB = False  # @param {type:"boolean"}
WANDB_PROJECT = "sokoban-pot"  # @param {type:"string"}

# @markdown ### Evaluation
EVAL_DIFFICULTIES = ["simple", "complex"]  # Easy (6x6,1box) and Hard (10x10,2boxes)
EVAL_SAMPLES = 200  # @param {type:"slider", min:50, max:500, step:50}

# Build config dict for easy access
CONFIG = {
    # Model
    'model_type': MODEL_TYPE,
    'controller_type': CONTROLLER_TYPE,
    'd_model': D_MODEL,
    'd_ff': D_FF,
    'n_heads': N_HEADS,
    'n_layers': N_LAYERS,
    'dropout': DROPOUT,
    # PoT iterations
    'R': R,
    'T': T,
    # Hybrid H/L cycles
    'H_layers': H_LAYERS,
    'L_layers': L_LAYERS,
    'H_cycles': H_CYCLES,
    'L_cycles': L_CYCLES,
    # Controller
    'd_ctrl': D_CTRL,
    'max_depth': MAX_DEPTH,
    # Feature Injection
    'injection_mode': INJECTION_MODE,
    'injection_memory_size': INJECTION_MEMORY_SIZE,
    'injection_n_heads': INJECTION_N_HEADS,
    # ACT
    'halt_max_steps': HALT_MAX_STEPS,
    'halt_exploration_prob': HALT_EXPLORATION_PROB,
    'allow_early_halt_eval': ALLOW_EARLY_HALT_EVAL,
    # Training
    'epochs': EPOCHS,
    'batch_size': BATCH_SIZE,
    'lr': LEARNING_RATE,
    'weight_decay': WEIGHT_DECAY,
    'grad_clip': GRAD_CLIP,
    'warmup_steps': WARMUP_STEPS,
    'lr_min_ratio': LR_MIN_RATIO,
    'beta1': BETA1,
    'beta2': BETA2,
    # HRM
    'hrm_grad_style': HRM_GRAD_STYLE,
    # Data
    'augment': AUGMENT,
    'seed': SEED,
    'n_generated': N_GENERATED,
    'gen_difficulty': GEN_DIFFICULTY,
    # Curriculum
    'curriculum': CURRICULUM,
    'curriculum_warmup': CURRICULUM_WARMUP,
    # Size generalization
    'pad_to_size': PAD_TO_SIZE,
}

print(f"Config: {MODEL_TYPE} ({CONTROLLER_TYPE})")
print(f"  Architecture: d={D_MODEL}, ff={D_FF}, heads={N_HEADS}, layers={N_LAYERS}")
print(f"  PoT: R={R}, T={T}")
print(f"  Hybrid: H_cycles={H_CYCLES}, L_cycles={L_CYCLES}")
print(f"  Controller: d_ctrl={D_CTRL}, max_depth={MAX_DEPTH}")
print(f"  Injection: mode={INJECTION_MODE}")
print(f"  ACT: halt_max={HALT_MAX_STEPS}")
print(f"  Training: epochs={EPOCHS}, lr={LEARNING_RATE}, batch={BATCH_SIZE}")

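The loop below shows how `H_CYCLES`, `L_CYCLES` and `HALT_MAX_STEPS` multiply into recurrent depth by enumerating the update order. It assumes the hybrid model follows the HRM convention: in each ACT step, every H cycle runs `L_CYCLES` low-level updates followed by one high-level update. This is a counting sketch, not the model's code, and `hrm_schedule` is a hypothetical helper.

```python
from collections import Counter


def hrm_schedule(halt_max_steps: int, h_cycles: int, l_cycles: int):
    """Hypothetical HRM-style update order as (module, act_step, h_cycle) tuples."""
    schedule = []
    for act_step in range(halt_max_steps):  # ACT may halt before reaching this cap
        for h in range(h_cycles):
            for _ in range(l_cycles):
                schedule.append(("L", act_step, h))  # fast, low-level update
            schedule.append(("H", act_step, h))  # slow, high-level update
    return schedule


# Upper bound for the defaults HALT_MAX_STEPS=4, H_CYCLES=2, L_CYCLES=6
counts = Counter(module for module, _, _ in hrm_schedule(4, 2, 6))
print(counts)
```

Under this assumption the defaults allow at most 4 × 2 × 6 = 48 low-level and 4 × 2 = 8 high-level updates per example.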

In [None]:
# @title 📥 Load Dataset
# @markdown Downloads Sokoban dataset from HuggingFace + optional generated data
# @markdown **Important:** Train/Val/Test are kept PURE (no augmentation leakage)

import torch
from torch.utils.data import DataLoader, Subset, Dataset
import numpy as np

from src.data.sokoban_hf import SokobanHFDataset, SokobanCombinedDataset

# Padding wrapper for size generalization
class PaddedDataset(Dataset):
    """Wraps a dataset and pads all boards to target size with walls."""
    TILE_WALL = 0  # Wall tile for padding
    
    def __init__(self, dataset, target_size: int):
        self.dataset = dataset
        self.target_size = target_size
        self._orig_shape = dataset.board_shape
        
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        sample = self.dataset[idx]
        
        # Get original board
        board_input = sample['input']  # [H, W, 7] one-hot
        orig_h, orig_w = board_input.shape[:2]
        
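        if orig_h >= self.target_size and orig_w >= self.target_size:
            return sample  # Already big enough
        if orig_h > self.target_size or orig_w > self.target_size:
            # One side already exceeds the target, so centred padding would need negative offsets
            raise ValueError(f"Board {orig_h}x{orig_w} does not fit in {self.target_size}x{self.target_size}")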
        
        # Calculate padding (center the original board)
        pad_h = self.target_size - orig_h
        pad_w = self.target_size - orig_w
        pad_top = pad_h // 2
        pad_left = pad_w // 2
        
        # Create padded one-hot board (all walls)
        padded = torch.zeros(self.target_size, self.target_size, 7)
        padded[:, :, self.TILE_WALL] = 1.0  # Fill with walls
        
        # Place original board in center
        padded[pad_top:pad_top+orig_h, pad_left:pad_left+orig_w] = board_input
        
        # Also pad board_indices if present
        if 'board_indices' in sample:
            board_idx = sample['board_indices']
            padded_idx = torch.zeros(self.target_size, self.target_size, dtype=torch.long)
            padded_idx[pad_top:pad_top+orig_h, pad_left:pad_left+orig_w] = board_idx
            sample['board_indices'] = padded_idx
        
        sample['input'] = padded
        return sample
    
    @property
    def board_shape(self):
        return (self.target_size, self.target_size)

# IMPORTANT: Load WITHOUT augmentation first to split cleanly
# Then apply augmentation only to training samples
print("Loading datasets...")

if N_GENERATED > 0:
    print(f"  HuggingFace + generating {N_GENERATED} additional {GEN_DIFFICULTY} puzzles...")
    # Load without augmentation for clean split
    full_ds_no_aug = SokobanCombinedDataset(
        hf_split="train",
        n_generated=N_GENERATED,
        difficulty=GEN_DIFFICULTY,
        augment=False,  # No augmentation for splitting
        seed=SEED,
    )
    # Load with augmentation for training
    full_ds_aug = SokobanCombinedDataset(
        hf_split="train",
        n_generated=N_GENERATED,
        difficulty=GEN_DIFFICULTY,
        augment=AUGMENT,
        seed=SEED,
    )
else:
    # Load without augmentation for clean split
    full_ds_no_aug = SokobanHFDataset(split="train", augment=False)
    # Load with augmentation for training
    full_ds_aug = SokobanHFDataset(split="train", augment=AUGMENT)

# Test set: completely separate (from HuggingFace 'test' split)
test_ds = SokobanHFDataset(split="test", augment=False)

# Apply padding for size generalization
if PAD_TO_SIZE > 0:
    orig_shape = full_ds_no_aug.board_shape
    print(f"  🔲 Padding {orig_shape} → ({PAD_TO_SIZE}×{PAD_TO_SIZE}) for size generalization")
    full_ds_no_aug = PaddedDataset(full_ds_no_aug, PAD_TO_SIZE)
    full_ds_aug = PaddedDataset(full_ds_aug, PAD_TO_SIZE)
    # Note: test_ds is NOT padded - we want to evaluate on native sizes

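# Split indices (not datasets!) to keep train/val PURE.
# Shared indices are only valid if both datasets have the same length and ordering.
assert len(full_ds_aug) == len(full_ds_no_aug), "augmented and plain datasets must align index-for-index"
n_total = len(full_ds_no_aug)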
val_size = min(500, n_total // 5)
train_size = n_total - val_size

# Deterministic shuffle
rng = np.random.default_rng(SEED)
indices = rng.permutation(n_total)
train_indices = indices[:train_size].tolist()
val_indices = indices[train_size:].tolist()

# Train: uses augmented dataset (on-the-fly augmentation)
# Val: uses non-augmented dataset (PURE - no augmentation)
train_subset = Subset(full_ds_aug, train_indices)
val_subset = Subset(full_ds_no_aug, val_indices)

# Create data loaders (curriculum handled in training loop)
train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

# For curriculum: get solution lengths if available (generated data only)
# Unwrap PaddedDataset so the generated-example metadata is reachable when PAD_TO_SIZE > 0
base_ds = full_ds_no_aug.dataset if isinstance(full_ds_no_aug, PaddedDataset) else full_ds_no_aug
curriculum_order = None
if CURRICULUM and N_GENERATED > 0 and hasattr(base_ds, 'generated_examples'):
    # Sort by solution length (easiest first)
    sol_lengths = []
    for idx in train_indices:
        if idx >= base_ds.n_hf:
            gen_idx = idx - base_ds.n_hf
            sol_len = base_ds.generated_examples[gen_idx].get('solution_length', 0)
        else:
            sol_len = 0  # HF data has no solution length; treat as easy
        sol_lengths.append((idx, sol_len))
    
    # Sort by solution length
    sol_lengths.sort(key=lambda x: x[1])
    curriculum_order = [idx for idx, _ in sol_lengths]
    known_lengths = [s for _, s in sol_lengths if s > 0]
    avg_str = f"{np.mean(known_lengths):.1f}" if known_lengths else "n/a"
    print(f"   📚 Curriculum: sorted by solution length (avg={avg_str})")
elif CURRICULUM:
    print(f"   ⚠️ Curriculum requires N_GENERATED > 0 (need solution lengths); using standard training")

print(f"\n✅ Data splits (PURE - no leakage):")
print(f"   Train: {len(train_subset)} samples (augment={'ON' if AUGMENT else 'OFF'})")
print(f"   Val:   {len(val_subset)} samples (augment=OFF, pure)")
print(f"   Test:  {len(test_ds)} samples (augment=OFF, separate HF split)")
print(f"   Board shape: {full_ds_no_aug.board_shape}")
if N_GENERATED > 0:
    print(f"   +{N_GENERATED} generated {GEN_DIFFICULTY} puzzles")


In [None]:
# @title 🏗️ Create Model

from src.pot.models.sokoban_solver import PoTSokobanSolver, HybridPoTSokobanSolver, BaselineSokobanSolver

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(SEED)

H, W = full_ds_no_aug.board_shape
seq_len = H * W

# Controller kwargs for advanced controllers
# Note: n_heads is passed separately to the model, not in controller_kwargs
controller_kwargs = {
    'd_ctrl': D_CTRL,
    'max_depth': MAX_DEPTH,
}

# Injection kwargs for cross_attn mode
injection_kwargs = {
    'memory_size': INJECTION_MEMORY_SIZE,
    'n_heads': INJECTION_N_HEADS,
} if INJECTION_MODE == 'cross_attn' else None

if MODEL_TYPE == "pot":
    model = PoTSokobanSolver(
        d_model=D_MODEL,
        n_heads=N_HEADS,
        n_layers=N_LAYERS,
        d_ff=D_FF,
        dropout=DROPOUT,
        R=R,
        controller_type=CONTROLLER_TYPE,
        controller_kwargs=controller_kwargs,
        max_depth=MAX_DEPTH,
        board_height=H,
        board_width=W,
    )
elif MODEL_TYPE == "hybrid_pot":
    model = HybridPoTSokobanSolver(
        d_model=D_MODEL,
        n_heads=N_HEADS,
        H_layers=H_LAYERS,
        L_layers=L_LAYERS,
        d_ff=D_FF,
        dropout=DROPOUT,
        H_cycles=H_CYCLES,
        L_cycles=L_CYCLES,
        T=T,
        halt_max_steps=HALT_MAX_STEPS,
        halt_exploration_prob=HALT_EXPLORATION_PROB,
        allow_early_halt_eval=ALLOW_EARLY_HALT_EVAL,
        hrm_grad_style=HRM_GRAD_STYLE,
        controller_type=CONTROLLER_TYPE,
        controller_kwargs=controller_kwargs,
        injection_mode=INJECTION_MODE,
        injection_kwargs=injection_kwargs,
        board_height=H,
        board_width=W,
    )
else:
    # Baseline uses different signature (pure CNN, no transformer)
    model = BaselineSokobanSolver(
        n_filters=64,
        n_layers=N_LAYERS,
        d_hidden=D_MODEL,
        dropout=DROPOUT,
    )

model = model.to(device)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"✅ {MODEL_TYPE} ({CONTROLLER_TYPE}) model created")
print(f"   Parameters: {n_params:,}")
print(f"   Device: {device}")
print(f"   Board: {H}x{W} = {seq_len} tokens")


In [None]:
# @title 🚀 Train Model
# @markdown Supervised training with cross-entropy + Q-halt loss (identical to Sudoku)

from src.training.sokoban_supervised import train_supervised
import time

if USE_WANDB:
    import wandb
    run_name = f"{MODEL_TYPE}-{CONTROLLER_TYPE}-R{R}-H{H_CYCLES}L{L_CYCLES}"
    wandb.init(
        project=WANDB_PROJECT,
        name=run_name,
        config=CONFIG,
    )

print(f"Training {MODEL_TYPE} ({CONTROLLER_TYPE})...")
print(f"  Epochs: {EPOCHS}, LR: {LEARNING_RATE}, Batch: {BATCH_SIZE}")
print(f"  PoT: R={R}, T={T}, H_cycles={H_CYCLES}, L_cycles={L_CYCLES}")
if CURRICULUM and curriculum_order is not None:
    warmup_epochs = int(EPOCHS * CURRICULUM_WARMUP)
    print(f"  📚 Curriculum: {warmup_epochs} warmup epochs ({CURRICULUM_WARMUP:.0%})")
print()

start_time = time.time()

# Curriculum learning: gradually reveal harder samples
if CURRICULUM and curriculum_order is not None:
    from src.training.sokoban_supervised import train_epoch, evaluate
    
    # Note: this branch trains at a constant LR; WARMUP_STEPS and LR_MIN_RATIO
    # are only passed to train_supervised in the non-curriculum branch.
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        betas=(BETA1, BETA2),
    )
    
    best_val_acc = 0.0
    warmup_epochs = int(EPOCHS * CURRICULUM_WARMUP)
    
    for epoch in range(EPOCHS):
        # Calculate how much of the curriculum to use
        if epoch < warmup_epochs:
            # Gradually increase from 30% to 100% during warmup
            frac = 0.3 + 0.7 * (epoch / warmup_epochs)
        else:
            frac = 1.0
        
        n_samples = int(len(curriculum_order) * frac)
        epoch_indices = curriculum_order[:n_samples]
        
        # Create epoch-specific loader with curriculum subset
        epoch_subset = Subset(full_ds_aug, epoch_indices)
        epoch_loader = DataLoader(epoch_subset, batch_size=BATCH_SIZE, shuffle=True)
        
        # Train epoch
        train_metrics = train_epoch(model, epoch_loader, optimizer, device, epoch, use_pot=(MODEL_TYPE != "baseline"), grad_clip=GRAD_CLIP)
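        # Also compute the solve rate on the last epoch so final_solve_rate is populated
        val_metrics = evaluate(model, val_loader, device, compute_solve=(epoch % 10 == 0 or epoch == EPOCHS - 1))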
        
        if val_metrics['action_acc'] > best_val_acc:
            best_val_acc = val_metrics['action_acc']
        
        if epoch % 10 == 0 or epoch == EPOCHS - 1:
            solve_str = f", solve={val_metrics.get('solve_rate', 0):.2%}" if 'solve_rate' in val_metrics else ""
            print(f"Epoch {epoch+1}/{EPOCHS}: loss={train_metrics['loss']:.4f}, act_acc={val_metrics['action_acc']:.2%}{solve_str}, samples={n_samples}")
        
        if USE_WANDB:
            log_dict = {
                'epoch': epoch,
                'train_loss': train_metrics['loss'],
                'train_action_acc': train_metrics['action_acc'],  # Like cell_acc
                'val_action_acc': val_metrics['action_acc'],      # Like cell_acc
                'curriculum_frac': frac,
            }
            if 'solve_rate' in val_metrics:
                log_dict['val_solve_rate'] = val_metrics['solve_rate']  # Like grid_acc
            wandb.log(log_dict)
    
    results = {'best_val_acc': best_val_acc, 'final_solve_rate': val_metrics.get('solve_rate')}
else:
    # Standard training (no curriculum)
    results = train_supervised(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        epochs=EPOCHS,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        grad_clip=GRAD_CLIP,
        warmup_steps=WARMUP_STEPS,
        use_pot=(MODEL_TYPE != "baseline"),
        wandb_log=USE_WANDB,
        betas=(BETA1, BETA2),
        lr_min_ratio=LR_MIN_RATIO,
    )

train_time = time.time() - start_time

print(f"\n✅ Training complete!")
print(f"   Best val accuracy: {results['best_val_acc']:.2%}")
print(f"   Training time: {train_time / 60:.1f} min")

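The curriculum branch grows the training pool linearly. It starts from the easiest 30% of puzzles and reaches the full set after the first `CURRICULUM_WARMUP` fraction of epochs. Below, the same schedule is restated as a standalone function so it can be inspected without training. The names `curriculum_fraction` and `curriculum_size` are hypothetical helpers, not part of the repository.

```python
def curriculum_fraction(epoch: int, epochs: int, warmup: float, start: float = 0.3) -> float:
    """Fraction of the easiest-first ordering used at `epoch` (mirrors the training loop)."""
    warmup_epochs = int(epochs * warmup)
    if epoch < warmup_epochs:
        return start + (1.0 - start) * (epoch / warmup_epochs)
    return 1.0  # after warmup, train on everything


def curriculum_size(n_train: int, epoch: int, epochs: int, warmup: float) -> int:
    """Number of samples taken from the front of curriculum_order."""
    return int(n_train * curriculum_fraction(epoch, epochs, warmup))


for epoch in (0, 10, 20, 30, 99):
    print(epoch, round(curriculum_fraction(epoch, epochs=100, warmup=0.3), 3))
```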

In [None]:
# @title 📈 Evaluate on HuggingFace Test Set
# @markdown Computes action_acc (like cell_acc) and solve_rate (like grid_acc)

from src.training.sokoban_supervised import evaluate

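# Check for board size mismatch
train_shape = full_ds_no_aug.board_shape
test_eval_ds = test_ds  # test_ds itself stays unpadded; a padded wrapper is used only if needed

# With PAD_TO_SIZE, wall-pad smaller test boards to the training size (as in the multi-difficulty cell)
if PAD_TO_SIZE > 0 and test_ds.board_shape != train_shape:
    test_h, test_w = test_ds.board_shape
    if test_h <= PAD_TO_SIZE and test_w <= PAD_TO_SIZE:
        print(f"🔲 Padding HF test boards {test_ds.board_shape} → {train_shape}")
        test_eval_ds = PaddedDataset(test_ds, PAD_TO_SIZE)

test_shape = test_eval_ds.board_shape

if train_shape != test_shape:
    print(f"⚠️ WARNING: Board size mismatch!")
    print(f"   Train: {train_shape}, Test: {test_shape}")
    print(f"   HF test evaluation may not work correctly.")
    print(f"   Use multi-difficulty evaluation with generated boards instead.")
    test_metrics = {'action_acc': 0.0, 'solve_rate': 0.0, 'loss': float('inf'), 'note': 'size_mismatch'}
else:
    # Compute both action_acc and solve_rate (like Sudoku's cell_acc and grid_acc)
    test_eval_loader = DataLoader(test_eval_ds, batch_size=BATCH_SIZE, shuffle=False)
    test_metrics = evaluate(model, test_eval_loader, device, compute_solve=True, solve_samples=200)
    print(f"\n📊 HuggingFace Test Set Results:")
    print(f"   Action Acc: {test_metrics['action_acc']:.2%} (like cell_acc)")
    print(f"   Solve Rate: {test_metrics.get('solve_rate', 0):.2%} (like grid_acc)")
    print(f"   Loss: {test_metrics['loss']:.4f}")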


In [None]:
# @title 🎯 Multi-Difficulty Evaluation
# @markdown Evaluate on Simple (6×6, 1 box) and Complex (10×10, 2 boxes)
# @markdown Reports both action_acc (like cell_acc) and solve_rate (like grid_acc)

from src.data.sokoban_generator import SokobanGeneratedDataset

difficulty_results = {}
model_size = full_ds_no_aug.board_shape  # Size model was trained on

for difficulty in EVAL_DIFFICULTIES:
    print(f"\nGenerating {difficulty} test boards...")
    
    eval_ds = SokobanGeneratedDataset(
        difficulty=difficulty,
        n_samples=EVAL_SAMPLES,
        seed=1042,
        augment=False,
    )
    
    eval_h, eval_w = eval_ds.board_shape
    
    # If using PAD_TO_SIZE and eval is smaller, pad it
    if PAD_TO_SIZE > 0 and (eval_h < model_size[0] or eval_w < model_size[1]):
        print(f"   🔲 Padding eval from {eval_ds.board_shape} → {model_size}")
        eval_ds = PaddedDataset(eval_ds, PAD_TO_SIZE)
        eval_h, eval_w = eval_ds.board_shape
    
    eval_loader = DataLoader(eval_ds, batch_size=BATCH_SIZE, shuffle=False)
    
    # Handle different board sizes
    if (eval_h, eval_w) != model_size:
        print(f"   ⚠️ Board size mismatch: model={model_size}, eval={eval_ds.board_shape}")
        print(f"   Skipping (model trained on fixed size)")
        difficulty_results[difficulty] = {'action_acc': None, 'solve_rate': None, 'note': 'size_mismatch'}
        continue
    
    # Compute both action_acc and solve_rate (like Sudoku's cell_acc and grid_acc)
    eval_metrics = evaluate(model, eval_loader, device, compute_solve=True, solve_samples=min(100, EVAL_SAMPLES))
    difficulty_results[difficulty] = {
        'action_acc': eval_metrics['action_acc'],  # Like cell_acc
        'solve_rate': eval_metrics.get('solve_rate', 0),  # Like grid_acc
        'loss': eval_metrics['loss'],
    }
    print(f"   {difficulty}: action_acc={eval_metrics['action_acc']:.2%}, solve_rate={eval_metrics.get('solve_rate', 0):.2%}")

print("\n" + "=" * 60)
print("MULTI-DIFFICULTY RESULTS (like Sudoku's cell_acc / grid_acc)")
print("=" * 60)
for diff, res in difficulty_results.items():
    if res.get('action_acc') is not None:
        print(f"  {diff}: action_acc={res['action_acc']:.2%}, solve_rate={res['solve_rate']:.2%}")
    else:
        print(f"  {diff}: N/A ({res.get('note', 'error')})")


In [None]:
# @title 📋 Final Comparison Table

# Get both action_acc (like cell_acc) and solve_rate (like grid_acc)
simple_action_acc = difficulty_results.get('simple', {}).get('action_acc', 0) or 0
simple_solve_rate = difficulty_results.get('simple', {}).get('solve_rate', 0) or 0
complex_action_acc = difficulty_results.get('complex', {}).get('action_acc', 0) or 0
complex_solve_rate = difficulty_results.get('complex', {}).get('solve_rate', 0) or 0

print("""
╔════════════════════════════════════════════════════════════════════════════════════════╗
║                         SOKOBAN BENCHMARK COMPARISON                                   ║
║                  (action_acc = like cell_acc, solve_rate = like grid_acc)              ║
╠════════════════════════════════════════════════════════════════════════════════════════╣
║ Method                   │ Simple (6×6,1)      │ Complex (10×10,2)     │ Notes          ║
║                          │ act_acc / solve     │ act_acc / solve       │                ║
╠══════════════════════════╪═════════════════════╪═══════════════════════╪════════════════╣
║ SFT (paper)              │   ~50%  /  ~30%     │   ~15%  /  ~5%        │ Baseline       ║
║ GPT-4 + LangGraph        │  varies / varies    │  varies / varies      │ Zero-shot      ║
║ RL (PPO)                 │   ~20%  /  ~10%     │   <5%   /  <1%        │ Very hard      ║
║ Random                   │   25%   /   ~1%     │   25%   /  <0.1%      │ 4 actions      ║
╠══════════════════════════╪═════════════════════╪═══════════════════════╪════════════════╣""")

print(f"║ PoT (this run)           │  {simple_action_acc:5.1%} / {simple_solve_rate:5.1%}    │  {complex_action_acc:5.1%} / {complex_solve_rate:5.1%}     │ R={R}            ║")
print("╚════════════════════════════════════════════════════════════════════════════════════════╝")

# Log final results to W&B
if USE_WANDB:
    import wandb
    wandb.log({
        'final/val_action_acc': results['best_val_acc'],
        'final/val_solve_rate': results.get('final_solve_rate', 0) or 0,
        'final/hf_test_action_acc': test_metrics['action_acc'],
        'final/hf_test_solve_rate': test_metrics.get('solve_rate', 0) or 0,
        'final/simple_action_acc': simple_action_acc,
        'final/simple_solve_rate': simple_solve_rate,
        'final/complex_action_acc': complex_action_acc,
        'final/complex_solve_rate': complex_solve_rate,
        'final/train_time_min': train_time / 60,
        'final/n_params': n_params,
    })
    # Log summary metrics
    wandb.run.summary['best_val_action_acc'] = results['best_val_acc']
    wandb.run.summary['simple_action_acc'] = simple_action_acc
    wandb.run.summary['simple_solve_rate'] = simple_solve_rate
    wandb.run.summary['complex_action_acc'] = complex_action_acc
    wandb.run.summary['complex_solve_rate'] = complex_solve_rate
    wandb.run.summary['n_params'] = n_params
    print(f"\n📊 Results logged to W&B: {wandb.run.url}")

# Store ALL config + results for Optuna/hyperparameter search
COLAB_RESULTS = {
    # Full configuration (for Optuna)
    **CONFIG,
    # Results (using Sudoku-style naming)
    'best_val_action_acc': results['best_val_acc'],  # Like cell_acc
    'best_val_solve_rate': results.get('final_solve_rate'),  # Like grid_acc
    'hf_test_action_acc': test_metrics['action_acc'],
    'hf_test_solve_rate': test_metrics.get('solve_rate'),
    'simple_action_acc': simple_action_acc,
    'simple_solve_rate': simple_solve_rate,
    'complex_action_acc': complex_action_acc,
    'complex_solve_rate': complex_solve_rate,
    'train_time_min': train_time / 60,
    'n_params': n_params,
    'difficulty_results': difficulty_results,
}

print(f"\n📊 Full Results (for Optuna search):")
print(f"   Model: {MODEL_TYPE} ({CONTROLLER_TYPE})")
print(f"   Architecture: d={D_MODEL}, ff={D_FF}, heads={N_HEADS}")
print(f"   PoT: R={R}, T={T}, H_cycles={H_CYCLES}, L_cycles={L_CYCLES}")
print(f"   ACT: halt_max={HALT_MAX_STEPS}, max_depth={MAX_DEPTH}")
print(f"   Val Action Acc: {results['best_val_acc']:.2%} (like cell_acc)")
print(f"   Val Solve Rate: {results.get('final_solve_rate', 0) or 0:.2%} (like grid_acc)")
print(f"   Test Action Acc: {test_metrics['action_acc']:.2%}")
print(f"   Test Solve Rate: {test_metrics.get('solve_rate', 0) or 0:.2%}")
print(f"   Simple: action={simple_action_acc:.2%}, solve={simple_solve_rate:.2%}")
print(f"   Complex: action={complex_action_acc:.2%}, solve={complex_solve_rate:.2%}")
print(f"   Params: {n_params:,}")
print(f"   Time: {train_time / 60:.1f} min")


In [None]:
# @title 💾 Save Model & Results (Optional)

import os
import json

os.makedirs("checkpoints", exist_ok=True)

# Save model checkpoint
model_name = f"sokoban_{MODEL_TYPE}_{CONTROLLER_TYPE}_R{R}_H{H_CYCLES}L{L_CYCLES}"
save_path = f"checkpoints/{model_name}.pt"

torch.save({
    'model_state_dict': model.state_dict(),
    'config': CONFIG,
    'results': COLAB_RESULTS,
}, save_path)

print(f"✅ Model saved to {save_path}")

# Save results JSON (for Optuna aggregation)
results_path = f"checkpoints/{model_name}_results.json"
with open(results_path, 'w') as f:
    # Drop nested dicts other than difficulty_results; default=str stringifies any remaining non-JSON values
    results_to_save = {k: v for k, v in COLAB_RESULTS.items() 
                       if not isinstance(v, dict) or k == 'difficulty_results'}
    json.dump(results_to_save, f, indent=2, default=str)

print(f"✅ Results saved to {results_path}")

# Print command to download
print(f"\n📥 Download command:")
print(f"   from google.colab import files")
print(f"   files.download('{save_path}')")
print(f"   files.download('{results_path}')")

# Finish W&B run
if USE_WANDB:
    import wandb
    wandb.save(save_path)  # Save model to W&B
    wandb.save(results_path)  # Save results JSON to W&B
    wandb.finish()
    print(f"\n✅ W&B run finished and artifacts uploaded!")


## 📝 Notes

### Metrics (Identical to Sudoku)

| Sudoku | Sokoban | Description |
|--------|---------|-------------|
| `cell_acc` | `action_acc` | % of single predictions correct |
| `grid_acc` | `solve_rate` | % of puzzles fully solved |

**Note:** `solve_rate` runs full rollouts (model plays until solved/stuck), so it's slower to compute.
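
A rollout-based `solve_rate` can be sketched generically. The policy acts until the puzzle is solved, until it makes an invalid move, or until a step cap is hit; the result is then averaged over puzzles. Treating an invalid move as "stuck" is an assumption. The `policy`, `step_fn` and `is_solved` callables are hypothetical stand-ins, not the signatures used by `evaluate`.

```python
from typing import Any, Callable, Sequence, Tuple

State = Any


def rollout_solved(policy: Callable[[State], int],
                   step_fn: Callable[[State, int], Tuple[State, bool]],
                   is_solved: Callable[[State], bool],
                   state: State,
                   max_steps: int = 50) -> bool:
    """Play the policy from `state`; True if the puzzle is solved within max_steps."""
    for _ in range(max_steps):
        if is_solved(state):
            return True
        state, moved = step_fn(state, policy(state))
        if not moved:
            return False  # assumption: an illegal/blocked move counts as stuck
    return is_solved(state)


def solve_rate(policy, step_fn, is_solved, start_states: Sequence[State], max_steps: int = 50) -> float:
    """Fraction of puzzles fully solved (the Sokoban analogue of grid_acc)."""
    solved = sum(rollout_solved(policy, step_fn, is_solved, s, max_steps) for s in start_states)
    return solved / len(start_states)
```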

### Loss Function (Identical to Sudoku - 3 losses)
```python
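import torch
import torch.nn.functional as F

def sokoban_loss(action_logits, action_label, q_halt, q_continue, target_q_continue):
    # LOSS 1: Main task (cross-entropy on action)
    ce_loss = F.cross_entropy(action_logits, action_label)

    # LOSS 2: Q-halt (should I stop iterating?); q_halt is a logit, target = predicted action is correct
    is_correct = (action_logits.argmax(dim=-1) == action_label).float()
    q_halt_loss = F.binary_cross_entropy_with_logits(q_halt, is_correct)

    # LOSS 3: Q-continue (ACT Q-learning)
    q_continue_loss = F.mse_loss(torch.sigmoid(q_continue), target_q_continue)

    # Combined
    return ce_loss + 0.5 * q_halt_loss + 0.5 * q_continue_loss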
```

### Configuration Parameters for Optuna Search

| Parameter | Description | Typical Range |
|-----------|-------------|---------------|
| **Architecture** | | |
| `D_MODEL` | Hidden dimension | 64-512 |
| `D_FF` | Feedforward dimension | 128-2048 |
| `N_HEADS` | Attention heads | 2-16 |
| `N_LAYERS` | Transformer layers | 1-8 |
| **PoT Iterations** | | |
| `R` | Refinement iterations | 1-16 |
| `T` | HRM period | 1-8 |
| **Hybrid H/L Cycles** | | |
| `H_CYCLES` | High-level (slow) cycles | 1-8 |
| `L_CYCLES` | Low-level (fast) cycles | 1-16 |
| `H_LAYERS` | Layers in H-level | 1-4 |
| `L_LAYERS` | Layers in L-level | 1-4 |
| **Controller** | | |
| `D_CTRL` | Controller hidden dimension | 32-256 |
| `MAX_DEPTH` | Max controller depth | 32-256 |
| **Feature Injection** | | |
| `INJECTION_MODE` | Injection mode | none/broadcast/film/etc |
| `INJECTION_MEMORY_SIZE` | Memory size (cross_attn) | 4-32 |
| `INJECTION_N_HEADS` | Heads (cross_attn) | 2-8 |
| **ACT (Adaptive Computation)** | | |
| `HALT_MAX_STEPS` | Max halting steps | 1-16 |
| `HALT_EXPLORATION_PROB` | Exploration probability | 0.0-1.0 |
| **Controller Types** | | |
| `transformer` | CausalDepthTransformerRouter | Default |
| `gru` | GRU-based controller | Fast |
| `lstm` | LSTM-based controller | |
| `diffusion` | Diffusion denoising | DiT-style |
| `swin` | Swin Transformer | Vision |
| `mamba` | Mamba SSM | State-space |
| **Data Generation** | | |
| `N_GENERATED` | Extra puzzles to generate | 0-100000 |
| `GEN_DIFFICULTY` | Difficulty of generated | simple/larger/two_boxes/complex |
| **Curriculum Learning** | | |
| `CURRICULUM` | Enable curriculum (easy→hard) | True/False |
| `CURRICULUM_WARMUP` | Fraction of epochs for warmup | 0.1-0.5 |
| **Size Generalization** | | |
| `PAD_TO_SIZE` | Pad boards to NxN (0=off) | 0, 10, 12, 16 |

### Difficulty Levels
| Difficulty | Size | Boxes | Avg Solution Length |
|------------|------|-------|---------------------|
| simple | 6×6 | 1 | ~4 moves |
| larger | 10×10 | 1 | ~15 moves |
| two_boxes | 6×6 | 2 | ~20 moves |
| complex | 10×10 | 2 | ~18 moves |

### Example Optuna Study
```python
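import optuna

def run_experiment(config: dict) -> dict:
    """Hypothetical helper: re-run the cells above with `config`, return COLAB_RESULTS."""
    raise NotImplementedError("wrap the notebook cells here")

def objective(trial):
    config = {
        'R': trial.suggest_int('R', 1, 16),
        'H_cycles': trial.suggest_int('H_cycles', 1, 8),
        'L_cycles': trial.suggest_int('L_cycles', 1, 16),
        'd_model': trial.suggest_categorical('d_model', [128, 256, 512]),
    }
    results = run_experiment(config)
    return results['complex_action_acc']  # key as stored in COLAB_RESULTS

study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=50)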
```

### References
- [Debunking SFT Generalization](https://arxiv.org/pdf/2510.00237)
- [LLMs + LangGraph for Sokoban](https://blog.gopenai.com/using-llms-and-langgraph-to-tackle-sokoban-puzzles-5f50b43b9515)
- [HuggingFace Dataset](https://huggingface.co/datasets/Xiaofeng77/sokoban)
