# 🌀 Diffusion HRM Sudoku Solver

Train a **master-level Sudoku AI** using a **fully diffusion-based architecture**, where everything is diffusion!

## Architecture - Everything is Diffusion!
- **z_H, z_L**: Progressively **denoised** at different timescales (not GRU!)
- **L-level**: Fast denoising (every step)
- **H-level**: Slow denoising (learned timing via Gumbel-softmax)
- **adaLN conditioning**: DiT-style adaptive LayerNorm
- **Noise schedules**: Cosine by default for both timescales (linear and sqrt are also available via `--noise-schedule`)
- **Skip connections**: HRM-style input injection + output residual
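
The adaLN bullet can be illustrated with a minimal PyTorch sketch. It assumes the conditioning signal is a per-sample vector (for example, an embedding of the current noise level). The class name `AdaLN` and its arguments are hypothetical and are not taken from `src/pot`. As in DiT, the LayerNorm has no learned affine parameters of its own: shift and scale are predicted from the condition, and the projection is zero-initialized so the layer starts out as a plain LayerNorm. DiT's per-block gating term is omitted for brevity.

```python
import torch
import torch.nn as nn


class AdaLN(nn.Module):
    """Hypothetical DiT-style adaptive LayerNorm (not the src/pot implementation)."""

    def __init__(self, d_model: int, d_cond: int):
        super().__init__()
        # No learned affine parameters: shift and scale are predicted from the condition.
        self.norm = nn.LayerNorm(d_model, elementwise_affine=False, eps=1e-6)
        self.to_shift_scale = nn.Sequential(nn.SiLU(), nn.Linear(d_cond, 2 * d_model))
        # Zero-init, as DiT does for its modulation layer: the layer starts as a plain LayerNorm.
        nn.init.zeros_(self.to_shift_scale[1].weight)
        nn.init.zeros_(self.to_shift_scale[1].bias)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # x: [batch, cells, d_model]; cond: [batch, d_cond], e.g. a noise-level embedding
        shift, scale = self.to_shift_scale(cond).chunk(2, dim=-1)
        # Broadcast the per-sample modulation over all cells
        return self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


torch.manual_seed(0)
layer = AdaLN(d_model=64, d_cond=16)
x = torch.randn(2, 81, 64)   # 2 puzzles, 81 cells each
cond = torch.randn(2, 16)    # stand-in for a noise-level embedding
out = layer(x, cond)
```

Once training moves the projection away from zero, the same normalized features can be shifted and rescaled differently at each noise level.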

## Key Innovations
| Component | HybridHRM (GRU-based) | DiffusionHRMSolver |
|-----------|----------------------|-------------------|
| H,L updates | Deterministic GRU | Diffusion denoising |
| Timing | Fixed T=4 | Learned via Gumbel-softmax |
| State evolution | Recurrent | Progressive denoising |
| Controller | GRU/Transformer | Diffusion denoisers |
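
The learned-timing row can be sketched with `torch.nn.functional.gumbel_softmax`. This is a hypothetical gate, not the repository's implementation. It pools the fast state `z_L`, predicts two logits (keep vs. update), and draws a hard one-hot sample whose backward pass uses the soft relaxation (straight-through), so the decision of *when* to update `z_H` stays differentiable. Here `z_H_candidate` stands in for whatever the slow denoiser would produce.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class HUpdateGate(nn.Module):
    """Hypothetical learned-timing gate: decides per sample whether z_H is updated this step."""

    def __init__(self, d_model: int):
        super().__init__()
        self.to_logits = nn.Linear(d_model, 2)  # index 0 = keep z_H, index 1 = update z_H

    def forward(self, z_L, z_H, z_H_candidate, tau: float = 1.0):
        logits = self.to_logits(z_L.mean(dim=1))  # pool over the 81 cells -> [batch, 2]
        # hard=True: one-hot sample in the forward pass, soft-sample gradients in the backward pass
        choice = F.gumbel_softmax(logits, tau=tau, hard=True)
        update = choice[:, 1].view(-1, 1, 1)  # 1.0 = update, 0.0 = keep (up to float rounding)
        new_z_H = update * z_H_candidate + (1 - update) * z_H
        return new_z_H, update.view(-1)


torch.manual_seed(0)
gate = HUpdateGate(d_model=32)
z_L, z_H, z_H_cand = (torch.randn(4, 81, 32) for _ in range(3))
new_z_H, did_update = gate(z_L, z_H, z_H_cand)
```

A fixed schedule would instead hard-code `update = 1` every `T` steps; the gate lets training decide.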

## Hardware Requirements
- **GPU**: A100/H100 recommended (16-80GB VRAM)
- **Runtime**: ~12-24 hours for full training


## 1. Setup


In [None]:
# Clone PoT repository
!git clone https://github.com/Eran-BA/PoT.git /content/PoT 2>/dev/null || (cd /content/PoT && git pull)
%cd /content/PoT

# Install dependencies
!pip install -q torch torchvision torchaudio
!pip install -q tqdm numpy huggingface_hub
!pip install -q wandb  # Weights & Biases


In [None]:
# Verify GPU
import torch
print(f"GPU Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")


In [None]:
# Login to Weights & Biases
import wandb
wandb.login()


## 2. Download Sudoku Dataset

Downloads the **Sudoku-Extreme** dataset from HuggingFace:
- 10,000 extreme-difficulty puzzles
- 100 augmentations per puzzle
- Total: ~1,000,000 training samples
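
Each augmentation has to turn a valid puzzle into another valid puzzle. A common recipe, shown here as a pure-Python sketch, combines a digit relabelling, shuffles of bands/stacks and of rows/columns within them, and an optional transpose. This is an assumption about the kind of transforms involved: the exact augmentations applied by `download_sudoku_dataset` are not shown in this notebook. The flat 81-cell encoding with 0 for blanks is also assumed, and all function names below are hypothetical.

```python
import random


def shuffled_lines(rng):
    """Permutation of 0..8 that shuffles the 3 bands and the 3 lines inside each band."""
    return [band * 3 + line
            for band in rng.sample(range(3), 3)
            for line in rng.sample(range(3), 3)]


def augment_sudoku(puzzle, solution, rng):
    """Apply one random validity-preserving transform to a (puzzle, solution) pair.

    Grids are flat lists of 81 ints, 0 = blank, 1..9 = digits (assumed encoding).
    """
    digit_map = [0] + rng.sample(range(1, 10), 9)  # relabel 1..9, keep 0 as blank
    rows, cols = shuffled_lines(rng), shuffled_lines(rng)
    transpose = rng.random() < 0.5

    def transform(grid):
        out = []
        for r in range(9):
            for c in range(9):
                src_r, src_c = rows[r], cols[c]
                if transpose:
                    src_r, src_c = src_c, src_r
                out.append(digit_map[grid[src_r * 9 + src_c]])
        return out

    # The same transform is applied to both grids so they stay consistent
    return transform(puzzle), transform(solution)


def is_valid_solution(grid):
    """True if every row, column and 3x3 box contains 1..9 exactly once."""
    units = ([[r * 9 + c for c in range(9)] for r in range(9)]
             + [[r * 9 + c for r in range(9)] for c in range(9)]
             + [[(br * 3 + r) * 9 + bc * 3 + c for r in range(3) for c in range(3)]
                for br in range(3) for bc in range(3)])
    return all(sorted(grid[i] for i in unit) == list(range(1, 10)) for unit in units)


rng = random.Random(0)
base_solution = [(r * 3 + r // 3 + c) % 9 + 1 for r in range(9) for c in range(9)]
base_puzzle = [v if i % 3 == 0 else 0 for i, v in enumerate(base_solution)]  # toy puzzle, 27 givens
augmented = [augment_sudoku(base_puzzle, base_solution, rng) for _ in range(100)]
```

Drawing 100 such transforms for each of the 10,000 base puzzles gives 10,000 × 100 = 1,000,000 samples, which matches the total quoted above.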


In [None]:
# Download Sudoku-Extreme dataset from HuggingFace
from src.data import download_sudoku_dataset

download_sudoku_dataset(
    output_dir='data/sudoku-extreme-10k-aug-100',
    subsample_size=10000,
    num_aug=100,
)


## 3. Training Script with W&B Integration

The script below implements:
- **DiffusionSudokuSolver** with fully diffusion-based H,L cycles
- **W&B logging** for loss, accuracy, gradients, and hyperparameters
- **Cosine LR schedule** with warmup
- **Gradient clipping** for stable training
- **Best model checkpointing**, with the checkpoint uploaded to the W&B run via `wandb.save`
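
To see what the warmup-plus-cosine schedule does numerically, the sketch below reproduces the `lr_lambda` formula from `get_cosine_schedule_with_warmup` in pure Python (using `math` instead of `numpy`). The run length of 100,000 steps is a hypothetical value chosen only to make the numbers easy to read.

```python
import math


def lr_multiplier(step, warmup_steps, total_steps, min_lr_ratio=0.1):
    """Same formula as lr_lambda in get_cosine_schedule_with_warmup (math instead of numpy)."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)  # linear ramp from 0 to 1
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    # Cosine decay from 1.0 down to min_lr_ratio
    return min_lr_ratio + (1 - min_lr_ratio) * 0.5 * (1 + math.cos(math.pi * progress))


base_lr, warmup, total = 1e-4, 2000, 100_000  # total is a hypothetical run length
for step in (0, 1000, 2000, 51_000, 100_000):
    print(f"step {step:>7}: lr = {base_lr * lr_multiplier(step, warmup, total):.2e}")
```

The multiplier is 0 at step 0, so the very first optimizer update uses a learning rate of 0. It reaches the full `--lr` at the end of warmup, is halfway between 1.0 and `--lr-min-ratio` at the midpoint of the decay, and ends at `--lr-min-ratio` times the base rate.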


In [None]:
%%writefile train_diffusion_sudoku.py
#!/usr/bin/env python3
"""
Diffusion HRM Sudoku Training Script with W&B Integration
==========================================================

Trains a fully diffusion-based Sudoku solver where:
- H,L cycles use diffusion denoising (not GRU)
- Timing for H-updates is learned via Gumbel-softmax
- adaLN conditioning (DiT-style)

Author: Eran Ben Artzy
Year: 2025
License: Apache 2.0
"""

import sys
import os
sys.path.insert(0, '/content/PoT')

import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import wandb
from datetime import datetime

from src.data import SudokuDataset
from src.pot.models.diffusion_hrm_solver import DiffusionSudokuSolver


def get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps, min_lr_ratio=0.1):
    """Cosine learning rate schedule with warmup."""
    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return min_lr_ratio + (1 - min_lr_ratio) * 0.5 * (1 + np.cos(np.pi * progress))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


def evaluate(model, loader, device):
    """Evaluate model on dataset."""
    model.eval()
    correct_cells = 0
    total_cells = 0
    correct_grids = 0
    total_grids = 0
    total_steps = 0
    num_batches = 0
    
    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating", leave=False):
            inputs = batch['input'].to(device)
            targets = batch['label'].to(device)
            puzzle_ids = batch.get('puzzle_id', torch.zeros(inputs.size(0), dtype=torch.long)).to(device)
            
            logits, _, _, steps = model(inputs, puzzle_ids)
            preds = logits.argmax(dim=-1)
            
            mask = (inputs == 0)
            correct_cells += ((preds == targets) & mask).sum().item()
            total_cells += mask.sum().item()
            
            grid_correct = ((preds == targets) | ~mask).all(dim=1)
            correct_grids += grid_correct.sum().item()
            total_grids += inputs.size(0)
            
            total_steps += steps
            num_batches += 1
    
    return {
        'cell_acc': 100 * correct_cells / max(1, total_cells),
        'grid_acc': 100 * correct_grids / max(1, total_grids),
        'avg_steps': total_steps / max(1, num_batches),
    }


def train_epoch(model, loader, optimizer, scheduler, device, epoch, log_interval=50):
    """Train for one epoch with W&B logging."""
    model.train()
    total_loss = 0
    num_batches = 0
    global_step = (epoch - 1) * len(loader)
    
    pbar = tqdm(loader, desc=f"Epoch {epoch}")
    for batch_idx, batch in enumerate(pbar):
        inputs = batch['input'].to(device)
        targets = batch['label'].to(device)
        puzzle_ids = batch.get('puzzle_id', torch.zeros(inputs.size(0), dtype=torch.long)).to(device)
        
        logits, q_halt, q_continue, steps = model(inputs, puzzle_ids)
        
        mask = (inputs == 0)
        loss = F.cross_entropy(logits[mask], targets[mask])
        
        optimizer.zero_grad()
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        
        total_loss += loss.item()
        num_batches += 1
        global_step += 1
        
        # Skip W&B logging when no run is active (e.g. with --no-wandb)
        if wandb.run is not None and batch_idx % log_interval == 0:
            wandb.log({
                'train/loss': loss.item(),
                'train/lr': scheduler.get_last_lr()[0],
                'train/grad_norm': grad_norm.item(),
                'train/act_steps': steps,
                'train/q_halt_mean': q_halt.mean().item(),
                'train/q_continue_mean': q_continue.mean().item(),
                'global_step': global_step,
            })
        
        pbar.set_postfix({'loss': f'{loss.item():.4f}', 'lr': f'{scheduler.get_last_lr()[0]:.2e}', 'steps': steps})
    
    return total_loss / max(1, num_batches)


def main():
    parser = argparse.ArgumentParser(description='Train Diffusion HRM Sudoku Solver')
    
    # Model args
    parser.add_argument('--d-model', type=int, default=512)
    parser.add_argument('--n-heads', type=int, default=8)
    parser.add_argument('--max-steps', type=int, default=32, help='Diffusion steps')
    parser.add_argument('--T', type=int, default=4, help='Base H/L timescale ratio')
    parser.add_argument('--noise-schedule', type=str, default='cosine', choices=['linear', 'cosine', 'sqrt'])
    parser.add_argument('--learned-timing', action='store_true', default=True)
    parser.add_argument('--no-learned-timing', action='store_false', dest='learned_timing')
    parser.add_argument('--halt-max-steps', type=int, default=4, help='ACT outer steps')
    parser.add_argument('--num-puzzles', type=int, default=10000)
    
    # Training args
    parser.add_argument('--epochs', type=int, default=10000)
    parser.add_argument('--batch-size', type=int, default=256)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--weight-decay', type=float, default=0.1)
    parser.add_argument('--warmup-steps', type=int, default=2000)
    parser.add_argument('--lr-min-ratio', type=float, default=0.1)
    parser.add_argument('--dropout', type=float, default=0.0)
    
    # Data args
    parser.add_argument('--data-dir', type=str, default='data/sudoku-extreme-10k-aug-100')
    parser.add_argument('--num-workers', type=int, default=4)
    
    # Eval args
    parser.add_argument('--eval-interval', type=int, default=100)
    parser.add_argument('--save-dir', type=str, default='experiments/results/diffusion_sudoku')
    
    # W&B args
    parser.add_argument('--wandb-project', type=str, default='diffusion-sudoku')
    parser.add_argument('--wandb-entity', type=str, default=None)
    parser.add_argument('--wandb-name', type=str, default=None)
    parser.add_argument('--wandb-tags', type=str, nargs='+', default=['diffusion', 'sudoku', 'hrm'])
    parser.add_argument('--no-wandb', action='store_true', help='Disable W&B logging')
    
    args = parser.parse_args()
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    os.makedirs(args.save_dir, exist_ok=True)
    
    # Initialize W&B
    if not args.no_wandb:
        run_name = args.wandb_name or f"diffusion-sudoku-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
        wandb.init(project=args.wandb_project, entity=args.wandb_entity, name=run_name, tags=args.wandb_tags, config=vars(args))
        print(f"W&B run: {wandb.run.url}")
    
    # Load data
    print("Loading datasets...")
    train_dataset = SudokuDataset(args.data_dir, 'train')
    val_dataset = SudokuDataset(args.data_dir, 'val')
    
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True)
    
    print(f"Train: {len(train_dataset)} samples, Val: {len(val_dataset)} samples")
    
    # Create model
    print("\nCreating Diffusion HRM Sudoku Solver...")
    model = DiffusionSudokuSolver(
        d_model=args.d_model, n_heads=args.n_heads, max_steps=args.max_steps, T=args.T,
        noise_schedule=args.noise_schedule, num_puzzles=args.num_puzzles, halt_max_steps=args.halt_max_steps,
        dropout=args.dropout, learned_timing=args.learned_timing,
    ).to(device)
    
    num_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {num_params:,} ({num_params/1e6:.2f}M)")
    
    if not args.no_wandb:
        wandb.config.update({'num_params': num_params})
        wandb.watch(model, log='gradients', log_freq=500)
    
    # Optimizer & Scheduler
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, betas=(0.9, 0.95))
    total_steps = args.epochs * len(train_loader)
    scheduler = get_cosine_schedule_with_warmup(optimizer, args.warmup_steps, total_steps, args.lr_min_ratio)
    
    # Training loop
    best_grid_acc = 0.0
    
    print(f"\n{'='*60}")
    print("Starting Training - Diffusion HRM Sudoku Solver")
    print(f"{'='*60}")
    print(f"  Diffusion steps: {args.max_steps}, T: {args.T}, Noise: {args.noise_schedule}")
    print(f"  Learned timing: {args.learned_timing}, ACT steps: {args.halt_max_steps}")
    print(f"  Epochs: {args.epochs}, Batch: {args.batch_size}, LR: {args.lr}")
    print(f"{'='*60}\n")
    
    for epoch in range(1, args.epochs + 1):
        train_loss = train_epoch(model, train_loader, optimizer, scheduler, device, epoch)
        
        if not args.no_wandb:
            wandb.log({'epoch': epoch, 'train/epoch_loss': train_loss})
        
        if epoch % args.eval_interval == 0 or epoch == 1:
            val_metrics = evaluate(model, val_loader, device)
            
            print(f"\nEpoch {epoch}: loss={train_loss:.4f}, val_cell={val_metrics['cell_acc']:.2f}%, val_grid={val_metrics['grid_acc']:.2f}%")
            
            if not args.no_wandb:
                wandb.log({'val/cell_acc': val_metrics['cell_acc'], 'val/grid_acc': val_metrics['grid_acc'], 
                          'val/avg_steps': val_metrics['avg_steps'], 'val/best_grid_acc': max(best_grid_acc, val_metrics['grid_acc']), 'epoch': epoch})
            
            if val_metrics['grid_acc'] > best_grid_acc:
                best_grid_acc = val_metrics['grid_acc']
                checkpoint = {'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(),
                             'val_cell_acc': val_metrics['cell_acc'], 'val_grid_acc': val_metrics['grid_acc'], 'args': vars(args)}
                save_path = f"{args.save_dir}/diffusion_best.pt"
                torch.save(checkpoint, save_path)
                print(f"  ✓ New best model saved! Grid accuracy: {best_grid_acc:.2f}%")
                
                if not args.no_wandb:
                    wandb.save(save_path)
                    wandb.run.summary['best_grid_acc'] = best_grid_acc
                    wandb.run.summary['best_epoch'] = epoch
    
    print(f"\n{'='*60}")
    print(f"Training complete! Best grid accuracy: {best_grid_acc:.2f}%")
    print(f"{'='*60}")
    
    if not args.no_wandb:
        wandb.finish()


if __name__ == '__main__':
    main()


## 4. Full Training (A100/H100)

Run full training of the diffusion-based model with W&B logging.
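
Because the script stretches the cosine schedule over `total_steps = args.epochs * len(train_loader)`, it helps to check how many optimizer steps these flags imply. The sketch below assumes the training split holds about 1,000,000 samples (the figure from section 2); the real value is whatever `len(train_dataset)` reports, and the helper name is hypothetical.

```python
import math


def optimizer_steps(num_samples, batch_size, epochs):
    """Steps implied by total_steps = epochs * len(train_loader) (DataLoader default drop_last=False)."""
    steps_per_epoch = math.ceil(num_samples / batch_size)
    return steps_per_epoch, steps_per_epoch * epochs


# Assumption: ~1,000,000 training samples; check len(train_dataset) for the real value.
per_epoch, total = optimizer_steps(1_000_000, batch_size=512, epochs=10_000)
warmup_fraction = 2000 / total
print(per_epoch, total, f"{warmup_fraction:.4%}")
```

Under that assumption, the 2,000 warmup steps cover only a tiny fraction of the schedule, so the learning rate is effectively governed by the cosine decay for the whole run.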


In [None]:
# Full training with W&B integration
!python train_diffusion_sudoku.py \
    --d-model 512 \
    --n-heads 8 \
    --max-steps 32 \
    --T 4 \
    --noise-schedule cosine \
    --learned-timing \
    --halt-max-steps 4 \
    --epochs 10000 \
    --batch-size 512 \
    --lr 1e-4 \
    --weight-decay 0.1 \
    --warmup-steps 2000 \
    --eval-interval 100 \
    --wandb-project diffusion-sudoku \
    --wandb-tags diffusion sudoku hrm full-training


## 5. Quick Test (~1-2 hours on T4)

For a quick sanity check with a smaller model and fewer epochs.


In [None]:
# Quick test with W&B
!python train_diffusion_sudoku.py \
    --d-model 256 \
    --n-heads 8 \
    --max-steps 16 \
    --T 4 \
    --halt-max-steps 2 \
    --epochs 500 \
    --batch-size 128 \
    --lr 1e-4 \
    --warmup-steps 200 \
    --eval-interval 50 \
    --wandb-project diffusion-sudoku \
    --wandb-tags diffusion sudoku quick-test


## 6. Load and Evaluate Best Model


In [None]:
import torch
from src.pot.models.diffusion_hrm_solver import DiffusionSudokuSolver
from src.data import SudokuDataset
from torch.utils.data import DataLoader
from tqdm import tqdm

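device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load checkpoint (map_location lets a GPU-trained checkpoint load on any device)
checkpoint = torch.load('experiments/results/diffusion_sudoku/diffusion_best.pt', map_location=device)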
args = checkpoint['args']

# Recreate model with same config
model = DiffusionSudokuSolver(
    d_model=args['d_model'],
    n_heads=args['n_heads'],
    max_steps=args['max_steps'],
    T=args['T'],
    noise_schedule=args['noise_schedule'],
    num_puzzles=args['num_puzzles'],
    halt_max_steps=args['halt_max_steps'],
    dropout=args['dropout'],
    learned_timing=args['learned_timing'],
).to(device)

model.load_state_dict(checkpoint['model_state_dict'])
print(f"Loaded model from epoch {checkpoint['epoch']}")
print(f"  Val Cell Accuracy: {checkpoint['val_cell_acc']:.2f}%")
print(f"  Val Grid Accuracy: {checkpoint['val_grid_acc']:.2f}%")

# Evaluate on test set
test_dataset = SudokuDataset(args['data_dir'], 'test')
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4)

print(f"\nEvaluating on {len(test_dataset)} test puzzles...")
model.eval()
correct_cells = 0
total_cells = 0
correct_grids = 0
total_grids = 0

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Testing"):
        inputs = batch['input'].to(device)
        targets = batch['label'].to(device)
        puzzle_ids = batch.get('puzzle_id', torch.zeros(inputs.size(0), dtype=torch.long)).to(device)
        
        logits, _, _, _ = model(inputs, puzzle_ids)
        preds = logits.argmax(dim=-1)
        
        mask = (inputs == 0)
        correct_cells += ((preds == targets) & mask).sum().item()
        total_cells += mask.sum().item()
        
        grid_correct = ((preds == targets) | ~mask).all(dim=1)
        correct_grids += grid_correct.sum().item()
        total_grids += inputs.size(0)

print(f"\n{'='*50}")
print(f"FINAL TEST RESULTS (Diffusion HRM Solver)")
print(f"{'='*50}")
print(f"  Cell Accuracy: {100*correct_cells/total_cells:.2f}%")
print(f"  Grid Accuracy: {100*correct_grids/total_grids:.2f}%")


## 7. Visualize Diffusion Process


In [None]:
import matplotlib.pyplot as plt
import torch

# Visualize how noise levels evolve
from src.pot.core.diffusion_hl_cycles import get_noise_schedule

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for ax, schedule_type in zip(axes, ['linear', 'cosine', 'sqrt']):
    L_schedule = get_noise_schedule(schedule_type, 32, torch.device('cpu'))
    H_schedule = get_noise_schedule(schedule_type, 8, torch.device('cpu'))  # 32 // T
    
    ax.plot(range(32), L_schedule.numpy(), 'b-', label='L-level (fast)', linewidth=2)
    ax.plot(range(0, 32, 4), H_schedule.numpy(), 'r--', label='H-level (slow)', linewidth=2, marker='o')
    ax.set_xlabel('Diffusion Step')
    ax.set_ylabel('Noise Level (σ)')
    ax.set_title(f'{schedule_type.capitalize()} Schedule')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.suptitle('Dual-Timescale Noise Schedules', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('noise_schedules.png', dpi=150)
plt.show()

# Log to W&B if available
try:
    import wandb
    if wandb.run is not None:
        wandb.log({"noise_schedules": wandb.Image('noise_schedules.png')})
except Exception:
    pass

print("Saved: noise_schedules.png")


## 8. Compare with Standard HybridHRM

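Compare parameter counts and architectural differences. Note that the two models below are built with different `num_puzzles` values (10,000 vs. 1), so if either model keeps a per-puzzle embedding table, part of the gap in parameter count comes from that table rather than from the H/L update mechanism.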


In [None]:
# Compare architectures
import torch
from src.pot.models.diffusion_hrm_solver import DiffusionSudokuSolver
from src.pot.models import HybridPoHHRMSolver

# Count parameters for each architecture
diffusion_model = DiffusionSudokuSolver(
    d_model=512, n_heads=8, max_steps=32, T=4,
    halt_max_steps=4, num_puzzles=10000,
)

hybrid_model = HybridPoHHRMSolver(
    d_model=512, n_heads=8, H_cycles=2, L_cycles=8,
    H_layers=2, L_layers=2, halt_max_steps=4, num_puzzles=1,
)

diff_params = sum(p.numel() for p in diffusion_model.parameters())
hybrid_params = sum(p.numel() for p in hybrid_model.parameters())

print("Architecture Comparison")
print("="*60)
print(f"DiffusionSudokuSolver: {diff_params:,} params ({diff_params/1e6:.2f}M)")
print(f"HybridPoHHRMSolver:    {hybrid_params:,} params ({hybrid_params/1e6:.2f}M)")
print()
print("Key Differences:")
print("  Diffusion Architecture:")
print("    - z_H, z_L are progressively denoised")
print("    - H-update timing is learned (Gumbel-softmax)")
print("    - Uses adaLN conditioning (DiT-style)")
print("    - Noise schedules: linear, cosine, or sqrt")
print()
print("  Hybrid (GRU-based) Architecture:")
print("    - z_H, z_L are updated via deterministic GRU")
print("    - H-updates are fixed every T steps")
print("    - Uses standard LayerNorm")


## 9. W&B Hyperparameter Sweep

Run a hyperparameter sweep to find optimal settings.


In [None]:
%%writefile sweep_config.yaml
program: train_diffusion_sudoku.py
method: bayes
metric:
  name: val/grid_acc
  goal: maximize
parameters:
  # Names match the script's argparse flags, because ${args} passes each one as --name=value
  d-model:
    values: [256, 384, 512]
  max-steps:
    values: [16, 24, 32]
  T:
    values: [2, 4, 6]
  noise-schedule:
    values: ["linear", "cosine", "sqrt"]
  lr:
    distribution: log_uniform_values
    min: 1.0e-5
    max: 1.0e-3
  batch-size:
    values: [128, 256, 512]
  halt-max-steps:
    values: [2, 4, 6]
command:
  - python
  - ${program}
  - --epochs
  - "1000"
  - --eval-interval
  - "50"
  - --wandb-project
  - diffusion-sudoku-sweep
  - ${args}


In [None]:
# Initialize sweep (uncomment to run)
# This will create a sweep in your W&B project
# !wandb sweep --project diffusion-sudoku-sweep sweep_config.yaml

# Then run the sweep agent:
# !wandb agent YOUR_ENTITY/diffusion-sudoku-sweep/SWEEP_ID


---

## Summary

This notebook trains a **fully diffusion-based Sudoku solver** where:

1. **H,L cycles** use diffusion denoising (not GRU)
2. **Timing** for H-updates is learned via Gumbel-softmax
3. **adaLN conditioning** provides DiT-style modulation
4. **Dual noise schedules** create two-timescale reasoning
5. **HRM-style skip connections** help keep training stable
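
To make the dual-schedule idea concrete, here is a pure-Python sketch of a cosine schedule evaluated at two resolutions: `max_steps` levels for L and `max_steps // T` levels for H, the same pairing plotted in section 7. It assumes the noise level is the standard deviation sqrt(1 - ᾱ_t) from the Improved-DDPM cosine schedule (Nichol & Dhariwal, 2021). `get_noise_schedule` in `src/pot/core/diffusion_hl_cycles.py` may define or normalize its schedules differently, and `cosine_sigmas` is a hypothetical helper.

```python
import math


def cosine_sigmas(num_steps, s=0.008):
    """Noise std sqrt(1 - alpha_bar_t) under the Improved-DDPM cosine schedule,
    listed in denoising order (noisiest level first)."""
    def f(t):
        return math.cos((t / num_steps + s) / (1 + s) * math.pi / 2) ** 2

    alpha_bar = [f(t) / f(0) for t in range(1, num_steps + 1)]
    sigmas = [math.sqrt(max(0.0, 1 - a)) for a in alpha_bar]
    return sigmas[::-1]


max_steps, T = 32, 4
sigma_L = cosine_sigmas(max_steps)        # fast level: one noise level per step
sigma_H = cosine_sigmas(max_steps // T)   # slow level: one noise level every T steps
```

Both levels start near σ = 1 and decrease monotonically, but H takes larger jumps between its fewer levels, which is what gives the two states different timescales.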

### Key Files
- `src/pot/core/diffusion_hl_cycles.py`: core diffusion H,L module
- `src/pot/models/diffusion_hrm_solver.py`: standalone solver
- `tests/test_diffusion_hl_cycles.py`: 47 tests (all passing)

### W&B Metrics Tracked
| Metric | Description |
|--------|-------------|
| `train/loss` | Cross-entropy loss on blank cells |
| `train/lr` | Current learning rate |
| `train/grad_norm` | Total gradient norm before clipping (the value returned by `clip_grad_norm_`, max norm 1.0) |
| `train/act_steps` | Number of ACT outer steps used |
| `train/q_halt_mean` | Mean Q-value for halting |
| `train/q_continue_mean` | Mean Q-value for continuing |
| `val/cell_acc` | Cell-level accuracy on validation set |
| `val/grid_acc` | Grid-level accuracy: share of puzzles with every blank cell correct |
| `val/avg_steps` | Average ACT steps per validation batch |
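
The difference between `val/cell_acc` and `val/grid_acc` is easiest to see on a toy batch. The sketch below mirrors the masking logic of `evaluate()` in pure Python (only cells that are blank in the input are scored), using hypothetical 4-cell "grids" instead of 81 cells.

```python
def sudoku_accuracies(inputs, preds, targets):
    """Pure-Python mirror of the metric logic in evaluate(): only blank cells (input == 0) count."""
    correct_cells = total_cells = correct_grids = 0
    for inp, pred, tgt in zip(inputs, preds, targets):
        blanks = [i for i, v in enumerate(inp) if v == 0]
        hits = sum(pred[i] == tgt[i] for i in blanks)
        correct_cells += hits
        total_cells += len(blanks)
        correct_grids += int(hits == len(blanks))  # a grid counts only if every blank is right
    cell_acc = 100 * correct_cells / max(1, total_cells)
    grid_acc = 100 * correct_grids / max(1, len(inputs))
    return cell_acc, grid_acc


# Toy 4-cell "grids"; 0 marks a blank in the input
inputs = [[0, 2, 0, 4], [1, 0, 0, 0]]
targets = [[1, 2, 3, 4], [1, 2, 3, 4]]
preds = [[1, 7, 3, 4], [1, 2, 9, 4]]  # the 7 sits on a given cell and is ignored
print(sudoku_accuracies(inputs, preds, targets))  # (80.0, 50.0)
```

A single wrong blank costs only one cell in `cell_acc` but the entire puzzle in `grid_acc`, which is why grid accuracy is the stricter metric and the one used for checkpoint selection.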

### References
- [HRM Paper](https://arxiv.org/abs/2506.21734): Two-timescale reasoning
- [DiT Paper](https://arxiv.org/abs/2212.09748): Diffusion Transformers (adaLN)
- [DDPM Paper](https://arxiv.org/abs/2006.11239): Denoising Diffusion
