# Sudoku-Extreme Training with Alternative Depth Controllers

This notebook trains the HybridPoHHRMSolver on Sudoku-Extreme with different depth controllers:
- **LSTM Controller** - Stronger gating than GRU
- **Transformer Controller** - Causal attention over depth history
- **PoT Transformer Controller** - **Nested PoT** with gated MHA internally
- **GRU Controller** - Original baseline
- **xLSTM Controller** - Exponential gating
- **minGRU Controller** - Simplified GRU
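
To compare the controllers listed above under otherwise identical settings, one option is to derive one config per controller from a shared base. The sketch below is an illustration and not part of the notebook: `make_sweep_configs` and the reduced `base` dict are hypothetical, and only the `controller_type` values mirror the notebook's config.

```python
import copy

# Controller identifiers accepted by config["controller_type"] in this notebook.
CONTROLLER_TYPES = ("lstm", "transformer", "pot_transformer", "gru", "xlstm", "mingru")


def make_sweep_configs(base, controller_types=CONTROLLER_TYPES):
    """Return one independent config dict per controller type (hypothetical helper)."""
    configs = []
    for name in controller_types:
        cfg = copy.deepcopy(base)  # avoid sharing mutable values between runs
        cfg["controller_type"] = name
        configs.append(cfg)
    return configs


# Reduced stand-in for the full config dict defined in section 3.
base = {"controller_type": "gru", "d_model": 512, "save_dir": "../checkpoints"}
sweep = make_sweep_configs(base)
print([cfg["controller_type"] for cfg in sweep])
```

Each resulting dict could then be passed through the model-creation and training cells in turn.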

## Features
- On-the-fly Sudoku augmentation (digit permutation, transpose, row/col shuffling)
- W&B logging (optional)
- Checkpoint saving
- Cosine LR schedule

## Quick Start (Colab)
1. Run "0. Colab Setup" to clone the repo and install dependencies
2. Run "1. Download Dataset" to fetch Sudoku-Extreme from HuggingFace
3. Run the remaining cells to train

## 0. Colab Setup (Run this first if on Colab)

In [None]:
# ============================================================================
# COLAB SETUP - Run this cell first if you're on Google Colab
# ============================================================================

# Check if running on Colab
import sys
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    # Clone the repository
    !git clone https://github.com/Eran-BA/PoT.git /content/PoT 2>/dev/null || (cd /content/PoT && git pull)
    %cd /content/PoT
    
    # Install dependencies
    !pip install -q wandb tqdm huggingface_hub
    
    # Add to path
    sys.path.insert(0, '/content/PoT')
    
    import os
    print("✓ Colab setup complete!")
    print(f"  Working directory: {os.getcwd()}")
else:
    print("Not running on Colab - skipping setup")

## 1. Download Sudoku-Extreme Dataset

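Downloads raw puzzles from HuggingFace (`sapientinc/sudoku-extreme`). Training and validation puzzles are both sampled from `train.csv`.
- 10,000 training puzzles (augmentation applied **on-the-fly** during training)
- 1,000 validation puzzles (no augmentation)
- Full test set (about 422k puzzles), downloaded by default (`download_test=True`)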
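
Each CSV row stores a puzzle and its solution as 81-character strings, with `.` (or `0`) marking a blank. A minimal sketch of the decoding step in plain Python follows. The notebook itself uses `np.frombuffer`; `parse_grid` is a hypothetical helper, and the string below is made up for illustration rather than taken from the dataset.

```python
def parse_grid(text):
    """Decode an 81-character Sudoku string into a list of ints (0 = blank)."""
    if len(text) != 81:
        raise ValueError(f"expected 81 characters, got {len(text)}")
    return [0 if ch in ".0" else int(ch) for ch in text]


# Made-up puzzle string for illustration only (not a real puzzle).
question = "53..7...." + "." * 72
cells = parse_grid(question)
print(cells[:9])  # first row of the 9x9 grid
print(sum(c != 0 for c in cells), "givens")
```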

In [None]:
# ============================================================================
# DOWNLOAD SUDOKU-EXTREME DATASET (raw puzzles - augmentation is on-the-fly)
# ============================================================================

import os
import csv
import numpy as np
import torch
from tqdm import tqdm
from huggingface_hub import hf_hub_download

def download_sudoku_dataset(output_dir, subsample_size=10000, val_size=1000, download_test=True):
    """
    Download Sudoku-Extreme from HuggingFace and save as .pt files.
    
    Augmentation is NOT applied here - it happens on-the-fly during training
    via the SudokuDataset class with shuffle_sudoku().
    
    Args:
        output_dir: Where to save train.pt, val.pt, and optionally test.pt
        subsample_size: Number of training puzzles to use
        val_size: Number of validation puzzles
        download_test: If True, also download the full test set (422k puzzles)
    """
    os.makedirs(output_dir, exist_ok=True)
    
    # Check if already downloaded
    has_train_val = os.path.exists(f"{output_dir}/train.pt") and os.path.exists(f"{output_dir}/val.pt")
    has_test = os.path.exists(f"{output_dir}/test.pt")
    
    if has_train_val and (not download_test or has_test):
        train_data = torch.load(f"{output_dir}/train.pt")
        val_data = torch.load(f"{output_dir}/val.pt")
        print(f"✓ Dataset already exists at {output_dir}")
        print(f"  Train: {len(train_data['inputs'])} puzzles")
        print(f"  Val: {len(val_data['inputs'])} puzzles")
        if has_test:
            test_data = torch.load(f"{output_dir}/test.pt")
            print(f"  Test: {len(test_data['inputs'])} puzzles")
        return
    
    print("Downloading Sudoku-Extreme from HuggingFace...")
    
    # Download train CSV from HuggingFace
    csv_path = hf_hub_download(
        repo_id="sapientinc/sudoku-extreme",
        filename="train.csv",
        repo_type="dataset"
    )
    
    # Parse CSV
    inputs, solutions = [], []
    with open(csv_path, newline="") as f:
        reader = csv.reader(f)
        next(reader)  # Skip header
        for row in tqdm(reader, desc="Parsing CSV"):
            source, q, a, rating = row
            # Convert puzzle string to numpy array
            # '.' or '0' = blank, '1'-'9' = digits
            inp = np.frombuffer(q.replace('.', '0').encode(), dtype=np.uint8) - ord('0')
            sol = np.frombuffer(a.encode(), dtype=np.uint8) - ord('0')
            inputs.append(inp)
            solutions.append(sol)
    
    print(f"  Total puzzles in dataset: {len(inputs)}")
    
    # Shuffle and split
    total = len(inputs)
    indices = np.random.permutation(total)
    
    train_size = min(subsample_size, total - val_size)
    train_idx = indices[:train_size]
    val_idx = indices[train_size:train_size + val_size]
    
    # Save train (raw - augmentation happens on-the-fly in DataLoader!)
    train_inputs = torch.tensor(np.array([inputs[i] for i in train_idx]), dtype=torch.long)
    train_solutions = torch.tensor(np.array([solutions[i] for i in train_idx]), dtype=torch.long)
    torch.save({
        'inputs': train_inputs,
        'solutions': train_solutions,
    }, f"{output_dir}/train.pt")
    print(f"✓ Train: {len(train_idx)} puzzles saved (augmentation: ON-THE-FLY)")
    
    # Save val (no augmentation ever)
    val_inputs = torch.tensor(np.array([inputs[i] for i in val_idx]), dtype=torch.long)
    val_solutions = torch.tensor(np.array([solutions[i] for i in val_idx]), dtype=torch.long)
    torch.save({
        'inputs': val_inputs,
        'solutions': val_solutions,
    }, f"{output_dir}/val.pt")
    print(f"✓ Val: {len(val_idx)} puzzles saved")
    
    # Download and save full test set (422k puzzles)
    if download_test:
        print("\nDownloading full test set (422k puzzles)...")
        test_csv_path = hf_hub_download(
            repo_id="sapientinc/sudoku-extreme",
            filename="test.csv",
            repo_type="dataset"
        )
        
        test_inputs, test_solutions = [], []
        with open(test_csv_path, newline="") as f:
            reader = csv.reader(f)
            next(reader)  # Skip header
            for row in tqdm(reader, desc="Parsing test CSV"):
                source, q, a, rating = row
                inp = np.frombuffer(q.replace('.', '0').encode(), dtype=np.uint8) - ord('0')
                sol = np.frombuffer(a.encode(), dtype=np.uint8) - ord('0')
                test_inputs.append(inp)
                test_solutions.append(sol)
        
        test_inputs_t = torch.tensor(np.array(test_inputs), dtype=torch.long)
        test_solutions_t = torch.tensor(np.array(test_solutions), dtype=torch.long)
        torch.save({
            'inputs': test_inputs_t,
            'solutions': test_solutions_t,
        }, f"{output_dir}/test.pt")
        print(f"✓ Test: {len(test_inputs)} puzzles saved")
    
    print(f"\n✓ Dataset saved to {output_dir}")

In [None]:
# Download the dataset
# Adjust paths for Colab vs local

if IN_COLAB:
    DATA_DIR = "/content/PoT/data/sudoku-extreme-10k"
else:
    DATA_DIR = "../data/sudoku-extreme-10k"

download_sudoku_dataset(
    output_dir=DATA_DIR,
    subsample_size=10000,  # 10k training puzzles
    val_size=1000,         # 1k validation puzzles
)
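
`download_sudoku_dataset` shuffles with NumPy's global RNG, so the train/validation partition can differ between fresh downloads unless a seed is set beforehand. A sketch of a reproducible split using only the standard library follows; `split_indices` and its `seed` argument are hypothetical and not part of the notebook.

```python
import random


def split_indices(total, train_size, val_size, seed=0):
    """Return disjoint (train_idx, val_idx) lists drawn from range(total)."""
    train_size = min(train_size, total - val_size)  # same clamp as the notebook
    order = list(range(total))
    random.Random(seed).shuffle(order)  # local RNG, global state untouched
    return order[:train_size], order[train_size:train_size + val_size]


train_idx, val_idx = split_indices(total=50, train_size=30, val_size=10, seed=123)
print(len(train_idx), len(val_idx))
```

In the notebook, calling `np.random.seed(...)` before the download cell would have a similar effect. Because the function returns early when `train.pt` and `val.pt` already exist, the split only changes when the files are regenerated.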

## 2. Setup & Imports

In [None]:
# Install dependencies if needed (uncomment if running on Colab)
# !pip install torch wandb tqdm

In [None]:
import os
import sys
from pathlib import Path
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm

# Add project root to path
PROJECT_ROOT = Path(".").resolve().parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

print(f"Project root: {PROJECT_ROOT}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

In [None]:
from src.pot.models.sudoku_solver import HybridPoHHRMSolver

## 3. Configuration

In [None]:
# =============================================================================
# CONFIGURATION - Modify these settings as needed
# =============================================================================

config = {
    # Controller type: "lstm", "transformer", "pot_transformer", "gru", "xlstm", "mingru"
    # pot_transformer = Nested PoT (gated MHA inside the depth controller itself)
    "controller_type": "lstm",  # <-- CHANGE THIS TO SWITCH CONTROLLERS
    
    # Data (uses DATA_DIR from download step)
    "data_dir": DATA_DIR,
    
    # Model architecture
    "d_model": 512,
    "d_ff": 2048,  # Feedforward hidden dimension (typically 4x d_model)
    "n_heads": 8,
    "h_layers": 2,
    "l_layers": 2,
    "h_cycles": 2,
    "l_cycles": 8,
    "dropout": 0.01,  # Dropout rate
    
    # HRM/ACT parameters
    "T": 4,  # HRM period for pointer controller
    "halt_max_steps": 2,  # Max ACT outer steps (1 = no ACT)
    "halt_exploration_prob": 0.1,  # Q-learning exploration probability
    "hrm_grad_style": True,  # Only last L+H cycles get gradients (HRM-style)
    
    # Transformer/PoT-Transformer controller specific 
    # (only used if controller_type="transformer" or "pot_transformer")
    "d_ctrl": 256,
    "n_ctrl_layers": 2,
    "n_ctrl_heads": 4,
    "max_depth": 32,
    
    # Training
    "epochs": 1000,
    "batch_size": 768,
    "lr": 1e-4,
    "weight_decay": 0.1,
    "beta2": 0.95,  # AdamW beta2 (Llama-style, also used in HRM)
    "warmup_steps": 1000,
    "lr_min_ratio": 0.1,  # Cosine decay floor (10% of peak LR)
    "augment": True,  # On-the-fly Sudoku augmentation
    "eval_interval": 10,  # Evaluate every N epochs
    "halt_histogram_interval": 50,  # Track halt histograms every N epochs
    "num_workers": 4,  # DataLoader workers (parallel data loading)
    
    # Async batching (HRM-style) - samples that halt early are replaced immediately
    "async_batch": True,  # HRM-style async batching for maximum GPU utilization
    
    # Puzzle embedding optimizer (HRM-style dual optimizer)
    "use_puzzle_optimizer": True,  # Separate optimizer for puzzle embeddings
    "puzzle_lr_multiplier": 100.0,  # Puzzle LR = lr * multiplier (HRM-style: puzzle embeds learn 100x faster)
    "puzzle_weight_decay": 0.1,
    "puzzle_optimizer": "adamw",  # "adamw" or "signsgd"
    
    # Device
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    
    # Logging
    "use_wandb": False,  # Set to True to enable W&B logging
    "wandb_project": "sudoku-controllers",
    
    # Checkpoints
    "save_dir": "../checkpoints",
    "save_every": 50,
    "resume_from": None,  # Path to checkpoint to resume from (e.g., "checkpoints/best_model.pt")
}

print("Configuration:")
for k, v in config.items():
    print(f"  {k}: {v}")
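
Several of these settings interact. The training DataLoader in section 5 uses `drop_last=True`, so the number of optimizer steps per epoch is the floor of the dataset size divided by the batch size, and `warmup_steps` is counted in optimizer steps rather than epochs. The arithmetic below is a sketch that assumes the 10k training subset from section 1.

```python
# Values copied from the config above; n_train assumes the 10k subset.
n_train = 10_000
batch_size = 768
epochs = 1000
warmup_steps = 1000
lr = 1e-4
puzzle_lr_multiplier = 100.0

steps_per_epoch = n_train // batch_size  # drop_last=True discards the partial batch
total_steps = epochs * steps_per_epoch
warmup_epochs = warmup_steps / steps_per_epoch
puzzle_lr = lr * puzzle_lr_multiplier

print(f"steps/epoch:   {steps_per_epoch}")
print(f"total steps:   {total_steps:,}")
print(f"warmup epochs: {warmup_epochs:.1f}")
print(f"puzzle LR:     {puzzle_lr:.0e}")
```

Under these assumptions the warmup spans roughly the first 77 of the 1000 epochs, which is worth keeping in mind when shortening `epochs` for a quick test run.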

## 4. Sudoku Augmentation (On-the-Fly)

In [None]:
def shuffle_sudoku(board: np.ndarray, solution: np.ndarray):
    """
    Apply validity-preserving augmentation to a Sudoku puzzle.
    
    Transforms include:
    - Digit permutation (1-9 -> random permutation)
    - Transpose (50% chance)
    - Row band shuffling (shuffle the 3 bands of 3 rows each)
    - Row shuffling within bands
    - Column stack shuffling (shuffle the 3 stacks of 3 columns each)  
    - Column shuffling within stacks
    
    Args:
        board: Input puzzle [81] with 0=blank, 1-9=digits
        solution: Solution [81] with 1-9
        
    Returns:
        Augmented (board, solution) tuple
    """
    # Create a random digit mapping: a permutation of 1..9, with zero (blank) unchanged
    digit_map = np.pad(np.random.permutation(np.arange(1, 10)), (1, 0))
    
    # Randomly decide whether to transpose
    transpose_flag = np.random.rand() < 0.5

    # Generate a valid row permutation:
    # - Shuffle the 3 bands (each band = 3 rows) and for each band, shuffle its 3 rows.
    bands = np.random.permutation(3)
    row_perm = np.concatenate([b * 3 + np.random.permutation(3) for b in bands])

    # Similarly for columns (stacks).
    stacks = np.random.permutation(3)
    col_perm = np.concatenate([s * 3 + np.random.permutation(3) for s in stacks])

    # Build an 81->81 mapping
    mapping = np.array([row_perm[i // 9] * 9 + col_perm[i % 9] for i in range(81)])

    def apply_transformation(x: np.ndarray) -> np.ndarray:
        # Reshape to 9x9 for transpose
        x_2d = x.reshape(9, 9)
        if transpose_flag:
            x_2d = x_2d.T
        x_flat = x_2d.flatten()
        # Apply row/col permutation
        new_board = x_flat[mapping]
        # Apply digit mapping
        return digit_map[new_board]

    return apply_transformation(board), apply_transformation(solution)


class SudokuDataset(Dataset):
    """Sudoku dataset with on-the-fly augmentation.
    
    Returns dict format compatible with src.training functions:
    {'input': tensor, 'label': tensor, 'puzzle_id': tensor}
    """
    
    def __init__(self, inputs, solutions, puzzle_ids=None, augment=True):
        self.inputs = inputs
        self.solutions = solutions
        self.puzzle_ids = puzzle_ids if puzzle_ids is not None else torch.zeros(len(inputs), dtype=torch.long)
        self.augment = augment
        self._epoch = 0
    
    def __len__(self):
        return len(self.inputs)
    
    def __getitem__(self, idx):
        inp = self.inputs[idx].numpy() if isinstance(self.inputs[idx], torch.Tensor) else self.inputs[idx]
        sol = self.solutions[idx].numpy() if isinstance(self.solutions[idx], torch.Tensor) else self.solutions[idx]
        pid = self.puzzle_ids[idx]
        
        if self.augment:
            inp, sol = shuffle_sudoku(inp, sol)
        
        # Return dict format for compatibility with src.training functions
        return {
            'input': torch.tensor(inp, dtype=torch.long),
            'label': torch.tensor(sol, dtype=torch.long),
            'puzzle_id': torch.tensor(pid, dtype=torch.long) if not isinstance(pid, torch.Tensor) else pid,
        }
    
    def on_epoch_end(self):
        """Called at end of each epoch (for compatibility with HRM training)."""
        self._epoch += 1

## 5. Data Loading

In [None]:
def load_sudoku_data(data_dir: str):
    """
    Load Sudoku-Extreme dataset.
    
    Expected format: .pt files with 'inputs' and 'solutions' tensors.
    """
    data_path = Path(data_dir)
    
    train_data = {}
    val_data = {}
    
    # Try to load preprocessed .pt files
    train_pt = data_path / "train.pt"
    val_pt = data_path / "val.pt"
    
    if train_pt.exists() and val_pt.exists():
        print(f"Loading preprocessed data from {data_dir}...")
        train_data = torch.load(train_pt, map_location="cpu")
        val_data = torch.load(val_pt, map_location="cpu")
        print(f"  Train samples: {len(train_data['inputs'])}")
        print(f"  Val samples: {len(val_data['inputs'])}")
    else:
        # Create synthetic data for testing
        print("⚠️ No preprocessed data found. Creating synthetic data for testing...")
        print("   Place train.pt and val.pt in the data directory for real training.")
        n_train, n_val = 1000, 200
        train_data = {
            'inputs': torch.randint(0, 10, (n_train, 81)),
            'solutions': torch.randint(1, 10, (n_train, 81)),
            'puzzle_ids': torch.zeros(n_train, dtype=torch.long),
        }
        val_data = {
            'inputs': torch.randint(0, 10, (n_val, 81)),
            'solutions': torch.randint(1, 10, (n_val, 81)),
            'puzzle_ids': torch.zeros(n_val, dtype=torch.long),
        }
    
    return train_data, val_data


def create_dataloaders(train_data, val_data, batch_size: int, augment: bool = True, num_workers: int = 4):
    """Create DataLoaders from data dicts with on-the-fly augmentation."""
    
    # Get or create puzzle IDs
    train_ids = train_data.get('puzzle_ids', torch.zeros(len(train_data['inputs']), dtype=torch.long))
    val_ids = val_data.get('puzzle_ids', torch.zeros(len(val_data['inputs']), dtype=torch.long))
    
    # Use SudokuDataset with augmentation for training
    train_dataset = SudokuDataset(
        train_data['inputs'],
        train_data['solutions'],
        train_ids,
        augment=augment,
    )
    
    # No augmentation for validation
    val_dataset = SudokuDataset(
        val_data['inputs'],
        val_data['solutions'],
        val_ids,
        augment=False,
    )
    
    # Use pin_memory for faster GPU transfer (only when using CUDA)
    pin_memory = torch.cuda.is_available()
    
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, drop_last=True,
        num_workers=num_workers, pin_memory=pin_memory
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=pin_memory
    )
    
    return train_loader, val_loader

In [None]:
# Load data
train_data, val_data = load_sudoku_data(config["data_dir"])
train_loader, val_loader = create_dataloaders(
    train_data, val_data, 
    batch_size=config["batch_size"],
    augment=config["augment"],
    num_workers=config.get("num_workers", 4),
)

print(f"\nDataLoaders created:")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")

## 6. Model Creation

In [None]:
def create_model(config):
    """Create HybridPoHHRMSolver with specified controller type."""
    
    controller_type = config["controller_type"]
    
    # Build controller kwargs based on type
    if controller_type in ("transformer", "pot_transformer"):
        # Transformer-based controllers (standard or nested PoT)
        controller_kwargs = {
            "d_ctrl": config["d_ctrl"],
            "n_ctrl_layers": config["n_ctrl_layers"],
            "n_ctrl_heads": config["n_ctrl_heads"],
            "max_depth": config["max_depth"],
            "token_conditioned": True,
        }
    else:
        # LSTM, GRU, xLSTM, minGRU
        controller_kwargs = {
            "d_ctrl": config["d_model"],
            "token_conditioned": True,
        }
    
    model = HybridPoHHRMSolver(
        d_model=config["d_model"],
        d_ff=config.get("d_ff", 2048),
        n_heads=config["n_heads"],
        H_layers=config["h_layers"],
        L_layers=config["l_layers"],
        H_cycles=config["h_cycles"],
        L_cycles=config["l_cycles"],
        dropout=config.get("dropout", 0.1),
        T=config.get("T", 4),
        halt_max_steps=config["halt_max_steps"],
        halt_exploration_prob=config.get("halt_exploration_prob", 0.1),
        hrm_grad_style=config.get("hrm_grad_style", True),
        controller_type=controller_type,
        controller_kwargs=controller_kwargs,
    )
    
    return model

In [None]:
# Create model
device = config["device"]
model = create_model(config).to(device)

param_count = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Model created on {device}")
print(f"  Controller: {config['controller_type'].upper()}")
print(f"  Total parameters: {param_count:,}")
print(f"  Trainable parameters: {trainable_params:,}")

## 7. Training Functions

In [None]:
# =============================================================================
# TRAINING FUNCTIONS
# Import from src.training for HRM-compatible training with dual optimizers
# =============================================================================

try:
    from src.training import train_epoch, train_epoch_async, evaluate, log_halt_histogram_to_wandb
    print("✓ Using src.training functions (HRM-compatible)")
    USE_SRC_TRAINING = True
except ImportError:
    print("⚠️ src.training not found, using built-in functions")
    USE_SRC_TRAINING = False

# Fallback halt histogram utility (only defined when src.training is unavailable,
# so it does not shadow the imported version)
if not USE_SRC_TRAINING:
    def log_halt_histogram_to_wandb(halt_histogram, prefix=""):
        """Convert halt histogram to W&B-friendly format."""
        if not halt_histogram:
            return {}

        # Build raw steps list for wandb.Histogram
        steps_raw = []
        for step, count in halt_histogram.items():
            steps_raw.extend([step] * count)

        # Calculate stats
        total = sum(halt_histogram.values())
        if total == 0:
            return {}

        avg_steps = sum(s * c for s, c in halt_histogram.items()) / total
        max_step = max(halt_histogram.keys())

        return {
            f"{prefix}halt_steps_raw": steps_raw,
            f"{prefix}halt_avg_steps": avg_steps,
            f"{prefix}halt_max_step": max_step,
        }

# Fallback training functions if src.training not available
if not USE_SRC_TRAINING:
    def train_epoch(model, dataloader, optimizer, puzzle_optimizer, device, epoch,
                    use_poh=True, debug=False, scheduler=None, puzzle_scheduler=None):
        """Train for one epoch with dual optimizer support."""
        model.train()
        base_model = model.module if hasattr(model, 'module') else model
        
        total_loss = 0
        correct_cells = 0
        total_cells = 0
        correct_grids = 0
        total_grids = 0
        total_steps = 0
        
        pbar = tqdm(dataloader, desc=f"Epoch {epoch}", leave=False)
        for batch in pbar:
            inp = batch['input'].to(device)
            label = batch['label'].to(device)
            puzzle_ids = batch['puzzle_id'].to(device)
            
            optimizer.zero_grad()
            if puzzle_optimizer:
                puzzle_optimizer.zero_grad()
            
            # Forward
            model_out = model(inp, puzzle_ids)
            if len(model_out) == 5:
                logits, q_halt, q_continue, steps, target_q_continue = model_out
            else:
                logits, q_halt, q_continue, steps = model_out
                target_q_continue = None
            
            # CE loss
            lm_loss = nn.functional.cross_entropy(
                logits.view(-1, base_model.vocab_size),
                label.view(-1)
            )
            
            # Q-halt loss (if PoH)
            if use_poh and q_halt is not None:
                with torch.no_grad():
                    preds = logits.argmax(dim=-1)
                    is_correct = (preds == label).all(dim=1).float()
                
                q_halt_loss = nn.functional.binary_cross_entropy_with_logits(q_halt, is_correct)
                loss = lm_loss + 0.5 * q_halt_loss
                
                if target_q_continue is not None:
                    q_continue_loss = nn.functional.mse_loss(torch.sigmoid(q_continue), target_q_continue)
                    loss = loss + 0.5 * q_continue_loss
            else:
                loss = lm_loss
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            if puzzle_optimizer:
                puzzle_optimizer.step()
            
            # Step schedulers (per-step, not per-epoch!)
            if scheduler:
                scheduler.step()
            if puzzle_scheduler:
                puzzle_scheduler.step()
            
            # Metrics
            total_loss += loss.item()
            preds = logits.argmax(dim=-1)
            correct_cells += (preds == label).sum().item()
            total_cells += label.numel()
            correct_grids += (preds == label).all(dim=1).sum().item()
            total_grids += label.size(0)
            total_steps += steps if isinstance(steps, int) else steps.float().mean().item()
            
            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'cell': f'{100*correct_cells/total_cells:.1f}%',
                'grid': f'{100*correct_grids/total_grids:.1f}%',
            })
        
        return {
            'loss': total_loss / len(dataloader),
            'cell_acc': 100 * correct_cells / total_cells,
            'grid_acc': 100 * correct_grids / total_grids,
            'avg_steps': total_steps / len(dataloader),
        }

    @torch.no_grad()
    def evaluate(model, dataloader, device, use_poh=True, track_halt_histogram=False):
        """Evaluate model."""
        model.eval()
        base_model = model.module if hasattr(model, 'module') else model
        
        total_loss = 0
        correct_cells = 0
        total_cells = 0
        correct_grids = 0
        total_grids = 0
        
        for batch in tqdm(dataloader, desc="Evaluating", leave=False):
            inp = batch['input'].to(device)
            label = batch['label'].to(device)
            puzzle_ids = batch['puzzle_id'].to(device)
            
            model_out = model(inp, puzzle_ids)
            logits = model_out[0]
            
            loss = nn.functional.cross_entropy(
                logits.view(-1, base_model.vocab_size),
                label.view(-1)
            )
            
            total_loss += loss.item()
            preds = logits.argmax(dim=-1)
            correct_cells += (preds == label).sum().item()
            total_cells += label.numel()
            correct_grids += (preds == label).all(dim=1).sum().item()
            total_grids += label.size(0)
        
        return {
            'loss': total_loss / len(dataloader),
            'cell_acc': 100 * correct_cells / total_cells,
            'grid_acc': 100 * correct_grids / total_grids,
        }
    
    # Placeholder for async training (requires ACT model)
    def train_epoch_async(*args, **kwargs):
        raise NotImplementedError("Async training requires src.training module")
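
The two accuracy metrics reported above measure different things: cell accuracy counts individual correct cells, while grid accuracy counts a puzzle only if every one of its cells is correct. A toy illustration in plain Python, using 3-cell "grids" instead of 81 cells and made-up predictions (`accuracies` is a hypothetical helper):

```python
def accuracies(preds, labels):
    """Return (cell_acc, grid_acc) in percent for equally shaped nested lists."""
    correct_cells = total_cells = correct_grids = 0
    for pred, label in zip(preds, labels):
        matches = [p == t for p, t in zip(pred, label)]
        correct_cells += sum(matches)
        total_cells += len(matches)
        correct_grids += all(matches)  # a grid counts only if every cell matches
    return 100 * correct_cells / total_cells, 100 * correct_grids / len(labels)


labels = [[1, 2, 3], [4, 5, 6]]
preds = [[1, 2, 3], [4, 5, 9]]  # second grid has one wrong cell
cell_acc, grid_acc = accuracies(preds, labels)
print(f"cell: {cell_acc:.1f}%  grid: {grid_acc:.1f}%")
```

A single wrong cell costs little cell accuracy but removes the whole puzzle from the grid-accuracy count, which is why grid accuracy tends to lag far behind early in training.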

## 8. Setup Optimizer & Scheduler

In [None]:
# =============================================================================
# DUAL OPTIMIZER SETUP (HRM-style)
# =============================================================================
import math

# Separate puzzle embedding parameters from model parameters
puzzle_params = []
model_params = []

for name, param in model.named_parameters():
    if 'puzzle' in name.lower() or 'embedding' in name.lower():
        puzzle_params.append(param)
    else:
        model_params.append(param)

print(f"Model parameters: {sum(p.numel() for p in model_params):,}")
print(f"Puzzle parameters: {sum(p.numel() for p in puzzle_params):,}")

# Main optimizer (AdamW)
beta2 = config.get("beta2", 0.95)  # Llama-style / HRM default
betas = (0.9, beta2)
optimizer = optim.AdamW(
    model_params,
    lr=config["lr"],
    weight_decay=config["weight_decay"],
    betas=betas,
)

# Puzzle optimizer (optional - HRM uses separate optimizer for puzzle embeddings)
puzzle_optimizer = None
if config["use_puzzle_optimizer"] and len(puzzle_params) > 0:
    puzzle_lr = config["lr"] * config["puzzle_lr_multiplier"]
    
    if config["puzzle_optimizer"] == "signsgd":
        # SignSGD for puzzle embeddings (as in HRM paper)
        class SignSGD(optim.Optimizer):
            def __init__(self, params, lr=1e-3, weight_decay=0):
                defaults = dict(lr=lr, weight_decay=weight_decay)
                super().__init__(params, defaults)
            
            @torch.no_grad()
            def step(self, closure=None):
                for group in self.param_groups:
                    for p in group['params']:
                        if p.grad is None:
                            continue
                        d_p = p.grad.sign()
                        if group['weight_decay'] != 0:
                            d_p = d_p.add(p, alpha=group['weight_decay'])
                        p.add_(d_p, alpha=-group['lr'])
        
        puzzle_optimizer = SignSGD(puzzle_params, lr=puzzle_lr, weight_decay=config["puzzle_weight_decay"])
        print(f"Puzzle optimizer: SignSGD (lr={puzzle_lr:.2e})")
    else:
        puzzle_optimizer = optim.AdamW(
            puzzle_params,
            lr=puzzle_lr,
            weight_decay=config["puzzle_weight_decay"],
            betas=betas,
        )
        print(f"Puzzle optimizer: AdamW (lr={puzzle_lr:.2e})")
else:
    print("Puzzle optimizer: disabled")

# Cosine LR schedule with warmup (per-step, not per-epoch)
total_steps = config["epochs"] * len(train_loader)
warmup_steps = config["warmup_steps"]
lr_min_ratio = config.get("lr_min_ratio", 0.1)

def lr_lambda(step):
    if step < warmup_steps:
        return step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return lr_min_ratio + (1 - lr_min_ratio) * 0.5 * (1 + math.cos(math.pi * progress))

scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
puzzle_scheduler = optim.lr_scheduler.LambdaLR(puzzle_optimizer, lr_lambda) if puzzle_optimizer else None

print(f"\nOptimizer: AdamW (lr={config['lr']}, weight_decay={config['weight_decay']})")
print(f"Scheduler: Cosine with warmup ({warmup_steps} steps), min_ratio={lr_min_ratio}")
print(f"Total training steps: {total_steps:,}")
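
The schedule has two phases: a linear ramp from 0 to the peak LR over `warmup_steps`, then a cosine decay down to `lr_min_ratio` times the peak. The standalone sketch below evaluates the multiplier at a few steps. The step counts are illustrative, and the `max`/`min` guards are an addition that the notebook's `lr_lambda` does not have.

```python
import math


def lr_multiplier(step, warmup_steps, total_steps, lr_min_ratio=0.1):
    """Multiplier applied to the peak LR at a given optimizer step."""
    if step < warmup_steps:
        return step / warmup_steps  # linear warmup from 0
    decay_steps = max(1, total_steps - warmup_steps)  # guard against division by zero
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    cosine = 0.5 * (1 + math.cos(math.pi * progress))  # goes from 1 to 0 during decay
    return lr_min_ratio + (1 - lr_min_ratio) * cosine


for step in (0, 500, 1000, 7000, 13000):
    print(step, round(lr_multiplier(step, warmup_steps=1000, total_steps=13000), 4))
```

The multiplier reaches 1.0 exactly at the end of warmup and bottoms out at `lr_min_ratio` on the final step. Both optimizers share this shape, so the puzzle-embedding LR keeps its fixed ratio to the main LR throughout training.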

## 9. W&B Setup (Optional)

In [None]:
# Initialize W&B if enabled
if config["use_wandb"]:
    import wandb
    
    run_name = f"{config['controller_type']}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    wandb.init(
        project=config["wandb_project"],
        name=run_name,
        config=config,
    )
    wandb.watch(model, log="gradients", log_freq=100)
    print(f"W&B initialized: {run_name}")
else:
    print("W&B logging disabled. Set config['use_wandb'] = True to enable.")
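
The halt histogram handled by `log_halt_histogram_to_wandb` in section 7 maps a halting step to the number of samples that stopped there. A sketch of the summary statistics on a made-up histogram, in plain Python (`summarize_halt_histogram` is a hypothetical, reduced version of that helper):

```python
def summarize_halt_histogram(hist):
    """Return average and maximum halting step for a {step: count} mapping."""
    total = sum(hist.values())
    if total == 0:
        return {}
    avg = sum(step * count for step, count in hist.items()) / total
    return {"halt_avg_steps": avg, "halt_max_step": max(hist)}


# Made-up counts: 600 samples halted at step 1, 400 at step 2.
stats = summarize_halt_histogram({1: 600, 2: 400})
print(stats)
```

With `halt_max_steps=2` as configured, an average close to 2 would suggest the model rarely halts early.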

## 10. Training Loop

In [None]:
# Create checkpoint directory
save_dir = Path(config["save_dir"]) / config["controller_type"]
save_dir.mkdir(parents=True, exist_ok=True)
print(f"Checkpoints will be saved to: {save_dir}")

In [None]:
# Training history
history = {
    'train_loss': [],
    'train_cell_acc': [],
    'train_grid_acc': [],
    'val_loss': [],
    'val_cell_acc': [],
    'val_grid_acc': [],
    'lr': [],
}

# Resume from checkpoint if specified
start_epoch = 1
best_grid_acc = 0
resume_checkpoint = None

if config.get("resume_from") and os.path.exists(config["resume_from"]):
    print(f"Resuming from checkpoint: {config['resume_from']}")
    resume_checkpoint = torch.load(config["resume_from"], map_location=device)
    
    # Load model weights
    model.load_state_dict(resume_checkpoint['model_state_dict'])
    print(f"  ✓ Loaded model weights")
    
    # Load optimizer states
    if 'optimizer_state_dict' in resume_checkpoint:
        optimizer.load_state_dict(resume_checkpoint['optimizer_state_dict'])
        print(f"  ✓ Loaded optimizer state")
    
    if puzzle_optimizer and 'puzzle_optimizer_state_dict' in resume_checkpoint:
        puzzle_optimizer.load_state_dict(resume_checkpoint['puzzle_optimizer_state_dict'])
        print(f"  ✓ Loaded puzzle optimizer state")
    
    # Load scheduler states
    if 'scheduler_state_dict' in resume_checkpoint:
        scheduler.load_state_dict(resume_checkpoint['scheduler_state_dict'])
        print(f"  ✓ Loaded scheduler state")
    
    if puzzle_scheduler and 'puzzle_scheduler_state_dict' in resume_checkpoint:
        puzzle_scheduler.load_state_dict(resume_checkpoint['puzzle_scheduler_state_dict'])
        print(f"  ✓ Loaded puzzle scheduler state")
    
    # Restore training state
    if 'epoch' in resume_checkpoint:
        start_epoch = resume_checkpoint['epoch'] + 1
        print(f"  ✓ Resuming from epoch {start_epoch}")
    
    if 'best_grid_acc' in resume_checkpoint:
        best_grid_acc = resume_checkpoint['best_grid_acc']
        print(f"  ✓ Best grid accuracy so far: {best_grid_acc:.2f}%")
    elif 'grid_acc' in resume_checkpoint:
        best_grid_acc = resume_checkpoint['grid_acc']
        print(f"  ✓ Grid accuracy from checkpoint: {best_grid_acc:.2f}%")

print(f"\n{'='*60}")
print(f"Starting {config['controller_type'].upper()} Controller Training")
if start_epoch > 1:
    print(f"  (Resuming from epoch {start_epoch})")
print(f"{'='*60}\n")

In [None]:
# =============================================================================
# MAIN TRAINING LOOP (HRM-compatible with dual optimizers)
# =============================================================================

use_async = config.get("async_batch", True)
eval_interval = config.get("eval_interval", 10)
halt_histogram_interval = config.get("halt_histogram_interval", 50)
val_metrics = None  # Will be set after first eval

for epoch in range(start_epoch, config["epochs"] + 1):
    # Determine if we should track halt histograms this epoch
    track_halt = (epoch % halt_histogram_interval == 0) or (epoch == 1)
    
    # Train (with dual optimizers and per-step scheduler)
    if use_async:
        train_metrics = train_epoch_async(
            model, train_loader, optimizer, puzzle_optimizer,
            device, epoch, use_poh=True,
            scheduler=scheduler, puzzle_scheduler=puzzle_scheduler,
            track_halt_histogram=track_halt,
        )
    else:
        train_metrics = train_epoch(
            model, train_loader, optimizer, puzzle_optimizer,
            device, epoch, use_poh=True,
            scheduler=scheduler, puzzle_scheduler=puzzle_scheduler,
        )
    
    # Call on_epoch_end (for augmentation reset, etc.)
    if hasattr(train_loader.dataset, 'on_epoch_end'):
        train_loader.dataset.on_epoch_end()
    
    # Evaluate (every eval_interval epochs or first/last epoch)
    do_eval = (epoch % eval_interval == 0) or (epoch == 1) or (epoch == config["epochs"])
    if do_eval:
        val_metrics = evaluate(model, val_loader, device, use_poh=True, track_halt_histogram=track_halt)
    
    # Get current LR
    current_lr = scheduler.get_last_lr()[0]
    
    # Store history
    history['train_loss'].append(train_metrics['loss'])
    history['train_cell_acc'].append(train_metrics['cell_acc'] / 100)  # Normalize to 0-1
    history['train_grid_acc'].append(train_metrics['grid_acc'] / 100)
    # val_metrics keeps the most recent evaluation, so epochs without an eval repeat those values
    if val_metrics:
        history['val_loss'].append(val_metrics['loss'])
        history['val_cell_acc'].append(val_metrics['cell_acc'] / 100)
        history['val_grid_acc'].append(val_metrics['grid_acc'] / 100)
    history['lr'].append(current_lr)
    
    # Print progress
    train_str = f"loss={train_metrics['loss']:.4f}, cell={train_metrics['cell_acc']:.1f}%, grid={train_metrics['grid_acc']:.1f}%"
    if val_metrics:
        val_str = f"loss={val_metrics['loss']:.4f}, cell={val_metrics['cell_acc']:.1f}%, grid={val_metrics['grid_acc']:.1f}%"
        print(f"Epoch {epoch}/{config['epochs']} | Train: {train_str} | Val: {val_str}")
    else:
        print(f"Epoch {epoch}/{config['epochs']} | Train: {train_str}")
    
    # W&B logging
    if config["use_wandb"]:
        log_dict = {
            "epoch": epoch,
            "train/loss": train_metrics['loss'],
            "train/cell_acc": train_metrics['cell_acc'],
            "train/grid_acc": train_metrics['grid_acc'],
            "lr": current_lr,
            "best_grid_acc": best_grid_acc,
        }
        if val_metrics:
            log_dict.update({
                "val/loss": val_metrics['loss'],
                "val/cell_acc": val_metrics['cell_acc'],
                "val/grid_acc": val_metrics['grid_acc'],
            })
        
        # Log halt histograms (if tracked this epoch)
        if 'halt_histogram' in train_metrics:
            train_halt_log = log_halt_histogram_to_wandb(train_metrics['halt_histogram'], prefix="train_")
            log_dict.update(train_halt_log)
            if train_halt_log.get("train_halt_steps_raw"):
                log_dict["train_halt_histogram"] = wandb.Histogram(train_halt_log["train_halt_steps_raw"])
        
        if val_metrics and 'halt_histogram' in val_metrics:
            val_halt_log = log_halt_histogram_to_wandb(val_metrics['halt_histogram'], prefix="val_")
            log_dict.update(val_halt_log)
            if val_halt_log.get("val_halt_steps_raw"):
                log_dict["val_halt_histogram"] = wandb.Histogram(val_halt_log["val_halt_steps_raw"])
        
        wandb.log(log_dict)
    
    # Save best model (only when we have val metrics)
    if val_metrics and val_metrics['grid_acc'] > best_grid_acc:
        best_grid_acc = val_metrics['grid_acc']
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_grid_acc': best_grid_acc,
            'config': config,
        }
        if puzzle_optimizer:
            checkpoint['puzzle_optimizer_state_dict'] = puzzle_optimizer.state_dict()
        if puzzle_scheduler:
            checkpoint['puzzle_scheduler_state_dict'] = puzzle_scheduler.state_dict()
        best_model_path = save_dir / "best_model.pt"
        torch.save(checkpoint, best_model_path)
        print(f"  ✓ New best! Grid acc: {best_grid_acc:.2f}%")
        
        # Upload to W&B as artifact
        if config["use_wandb"]:
            artifact = wandb.Artifact(
                f"sudoku-{config['controller_type']}-best",
                type="model",
                metadata={"grid_acc": best_grid_acc, "epoch": epoch}
            )
            artifact.add_file(str(best_model_path))
            wandb.log_artifact(artifact, aliases=["best", "latest"])
            print(f"  📤 Uploaded to W&B artifacts")
    
    # Periodic checkpoint
    if epoch % config["save_every"] == 0:
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'grid_acc': val_metrics['grid_acc'] if val_metrics else train_metrics['grid_acc'],
            'best_grid_acc': best_grid_acc,
            'config': config,
        }
        if puzzle_optimizer:
            checkpoint['puzzle_optimizer_state_dict'] = puzzle_optimizer.state_dict()
        if puzzle_scheduler:
            checkpoint['puzzle_scheduler_state_dict'] = puzzle_scheduler.state_dict()
        torch.save(checkpoint, save_dir / f"checkpoint_epoch{epoch}.pt")
        print(f"  💾 Checkpoint saved: epoch {epoch}")

print(f"\n{'='*60}")
print(f"Training complete! Best grid accuracy: {best_grid_acc:.2f}%")
print(f"{'='*60}")

if config["use_wandb"]:
    wandb.finish()
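
The training loop evaluates on the first epoch, the last epoch, and every `eval_interval` epochs, and it tracks halt histograms on the first epoch and every `halt_histogram_interval` epochs. The sketch below reproduces those two conditions in isolation so the schedule can be inspected before a long run; `schedule_flags` is a hypothetical helper, not part of the repository.

```python
def schedule_flags(epoch, total_epochs, eval_interval=10, halt_histogram_interval=50):
    """Return (do_eval, track_halt) for one epoch, mirroring the loop's conditions."""
    track_halt = (epoch % halt_histogram_interval == 0) or (epoch == 1)
    do_eval = (epoch % eval_interval == 0) or (epoch == 1) or (epoch == total_epochs)
    return do_eval, track_halt


# Epochs that trigger validation in a 25-epoch run with the default interval.
eval_epochs = [e for e in range(1, 26) if schedule_flags(e, 25)[0]]
print(eval_epochs)  # [1, 10, 20, 25]
```

When a run resumes at an epoch that is not a multiple of `eval_interval` (for example epoch 12), no evaluation happens until the next multiple, so `val_metrics` stays `None` for those epochs and the validation history ends up shorter than the training history.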

## 11. Training Visualization

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Loss
axes[0].plot(history['train_loss'], label='Train')
axes[0].plot(history['val_loss'], label='Val')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Cell Accuracy
axes[1].plot([100*x for x in history['train_cell_acc']], label='Train')
axes[1].plot([100*x for x in history['val_cell_acc']], label='Val')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy (%)')
axes[1].set_title('Cell Accuracy')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# Grid Accuracy
axes[2].plot([100*x for x in history['train_grid_acc']], label='Train')
axes[2].plot([100*x for x in history['val_grid_acc']], label='Val')
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('Accuracy (%)')
axes[2].set_title('Grid Accuracy')
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.suptitle(f"{config['controller_type'].upper()} Controller Training", fontsize=14)
plt.tight_layout()
plt.savefig(save_dir / "training_curves.png", dpi=150)
plt.show()

print(f"\nPlot saved to: {save_dir / 'training_curves.png'}")

## 12. Load Best Model & Final Evaluation

In [None]:
# Load best model
checkpoint = torch.load(save_dir / "best_model.pt", map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

print(f"Loaded best model from epoch {checkpoint['epoch']}")
print(f"Best grid accuracy: {checkpoint['best_grid_acc']:.2f}%")  # already stored as a percentage

In [None]:
# Final evaluation
final_metrics = evaluate(model, val_loader, device, use_poh=True)

print(f"\n{'='*40}")
print(f"Final Evaluation Results")
print(f"{'='*40}")
print(f"Controller: {config['controller_type'].upper()}")
print(f"Val Loss: {final_metrics['loss']:.4f}")
# evaluate() reports accuracies as percentages, so no extra scaling is needed
print(f"Cell Accuracy: {final_metrics['cell_acc']:.2f}%")
print(f"Grid Accuracy: {final_metrics['grid_acc']:.2f}%")
print(f"{'='*40}")
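
The two accuracies reported above differ in granularity. The sketch below assumes the conventional definitions, which this notebook does not show: cell accuracy is the share of individual cells predicted correctly, and grid accuracy is the share of puzzles with all 81 cells correct. `sudoku_accuracies` is a hypothetical helper that works on plain strings, not the repository's `evaluate` function.

```python
def sudoku_accuracies(predictions, solutions):
    """Return (cell_acc, grid_acc) in percent for equal-length lists of grids.

    Each grid is a sequence of 81 digits. Assumed definitions, for
    illustration only.
    """
    total_cells = 0
    correct_cells = 0
    solved_grids = 0
    for pred, sol in zip(predictions, solutions):
        matches = sum(p == s for p, s in zip(pred, sol))
        correct_cells += matches
        total_cells += len(sol)
        # A grid counts as solved only if every cell matches.
        solved_grids += int(matches == len(sol))
    cell_acc = 100 * correct_cells / total_cells
    grid_acc = 100 * solved_grids / len(solutions)
    return cell_acc, grid_acc


solution = (
    "534678912" "672195348" "198342567"
    "859761423" "426853791" "713924856"
    "961537284" "287419635" "345286179"
)
one_wrong = "1" + solution[1:]  # first cell changed from 5 to 1

cell_acc, grid_acc = sudoku_accuracies([solution, one_wrong], [solution, solution])
print(f"cell={cell_acc:.2f}%, grid={grid_acc:.2f}%")
```

A single wrong cell barely moves cell accuracy but costs a whole puzzle in grid accuracy, which is why grid accuracy is the metric used to select the best checkpoint.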

## 14. Test Set Evaluation (422k puzzles)

Evaluate the best model on the full Sudoku-Extreme test set.

In [None]:
# ============================================================================
# TEST SET EVALUATION (422k puzzles)
# ============================================================================

# Load test set
test_pt_path = Path(config["data_dir"]) / "test.pt"

if test_pt_path.exists():
    print("Loading test set...")
    test_data = torch.load(test_pt_path, map_location="cpu", weights_only=True)
    
    # Create test dataset (no augmentation)
    test_ids = torch.zeros(len(test_data['inputs']), dtype=torch.long)
    test_dataset = SudokuDataset(
        test_data['inputs'],
        test_data['solutions'],
        test_ids,
        augment=False,
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=config["batch_size"],
        shuffle=False,
        num_workers=config.get("num_workers", 4),
        pin_memory=torch.cuda.is_available(),
    )
    
    print(f"✓ Test set loaded: {len(test_dataset):,} puzzles")
    
    # Load best model
    best_model_path = save_dir / "best_model.pt"
    if best_model_path.exists():
        checkpoint = torch.load(best_model_path, map_location=device, weights_only=True)
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"✓ Loaded best model from epoch {checkpoint['epoch']}")
    
    # Evaluate on test set
    print("\nEvaluating on test set...")
    test_metrics = evaluate(model, test_loader, device, use_poh=True)
    
    print(f"\n{'='*60}")
    print(f"🎯 TEST RESULTS ({len(test_dataset):,} puzzles)")
    print(f"{'='*60}")
    print(f"Controller: {config['controller_type'].upper()}")
    print(f"Test Loss: {test_metrics['loss']:.4f}")
    print(f"Cell Accuracy: {test_metrics['cell_acc']:.2f}%")
    print(f"Grid Accuracy: {test_metrics['grid_acc']:.2f}%")
    print(f"{'='*60}")
    
    # Log to W&B if enabled
    if config["use_wandb"]:
        wandb.log({
            "test/loss": test_metrics['loss'],
            "test/cell_acc": test_metrics['cell_acc'],
            "test/grid_acc": test_metrics['grid_acc'],
        })
        print("✓ Test results logged to W&B")
else:
    print(f"⚠️ Test set not found at {test_pt_path}")
    print("  Run download_sudoku_dataset(..., download_test=True) to download the 422k test set")
