In [1]:
import sys
!{sys.executable} -m pip install wandb optuna "ray[tune]" -q

[0m

In [2]:
!cd /workspace && git clone https://github.com/Eran-BA/PoT.git


Cloning into 'PoT'...
remote: Enumerating objects: 2243, done.[K
remote: Counting objects: 100% (18/18), done.[K
remote: Compressing objects: 100% (13/13), done.[K
remote: Total 2243 (delta 7), reused 15 (delta 5), pack-reused 2225 (from 1)[K
Receiving objects: 100% (2243/2243), 1.39 MiB | 4.46 MiB/s, done.
Resolving deltas: 100% (1404/1404), done.


In [3]:
import os

# Work from the workspace root; the repository is cloned underneath it.
WORKSPACE_DIR = '/workspace'
os.chdir(WORKSPACE_DIR)

In [3]:
# `!cd` runs in a throwaway subshell and never changes the kernel's working directory,
# so change it in-process instead.
os.chdir('/workspace/PoT')


In [None]:
import wandb
import os
from getpass import getpass

# Never hardcode the API key in the notebook (it ends up in the saved file / VCS).
# Reuse WANDB_API_KEY if it is already set in the environment; otherwise prompt for it.
api_key = os.environ.get("WANDB_API_KEY") or getpass("W&B API key: ")
# Export it so child processes launched from this kernel (e.g. the HPO script) inherit it.
os.environ["WANDB_API_KEY"] = api_key

# Force re-login with this key
wandb.login(key=api_key, relogin=True, force=True)

# Smoke test: open a throwaway run, log one value, close it.
run = wandb.init(project="sudoku-hpo-test", name="test-run")
wandb.log({"test": 1})
wandb.finish()
print("✓ Success!")

[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33meranbt92[0m ([33meranbt92-open-university-of-israel[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
test,▁

0,1
test,1


✓ Success!


In [None]:
!/venv/main/bin/python /workspace/PoT/experiments/sudoku_hpo.py \
    --n-trials 10 \
    --epochs-per-trial 500 \
    --num-gpus 4 \
    --batch-size 768 \
    --eval-interval 100 \
    --grace-period 500 \
    --download \
    --wandb-project sudoku-hpo-general-reasoning \
    --study-name sudoku_hpo_4xB200

Downloading Sudoku-Extreme dataset from HuggingFace...
Note: Augmentation is now ON-THE-FLY (not pre-computed)
train.csv: 100%|██████████████████████████████| 719M/719M [00:05<00:00, 125MB/s]
  Train puzzles: 9000, Val puzzles: 1000
Processing train: 100%|██████████████████| 9000/9000 [00:00<00:00, 43217.57it/s]
  train: 9000 puzzles (augmentation: on-the-fly)
Processing val: 100%|████████████████████| 1000/1000 [00:00<00:00, 38812.79it/s]
  val: 1000 puzzles (augmentation: on-the-fly)
test.csv: 100%|████████████████████████████| 79.4M/79.4M [00:01<00:00, 77.1MB/s]
Processing test: 422786it [00:10, 41254.51it/s]
  test: 422786 puzzles
✓ Dataset saved to /workspace/PoT/data/sudoku-extreme-10k-aug-100
✓ Data verified at: /workspace/PoT/data/sudoku-extreme-10k-aug-100
Loading data into memory...
[train] Loaded 9000 puzzles
  Augmentation: ON-THE-FLY
[val] Loaded 1000 puzzles
  Augmentation: OFF
✓ Loaded 9000 train, 1000 val puzzles
2025-12-10 02:15:46,356	INFO worker.py:2023 -- Started a 

In [None]:
# ============================================================
# PHASE 2: Continue Training from Best HPO Checkpoint
# ============================================================
import os
import sys
import glob
import torch
import wandb
from datetime import datetime

os.chdir('/workspace/PoT')
# Guard so re-running this cell does not keep prepending duplicate sys.path entries.
if '/workspace/PoT' not in sys.path:
    sys.path.insert(0, '/workspace/PoT')

# ============================================================
# Step 1: Find Best Checkpoint from HPO
# ============================================================
print("=" * 60)
print("PHASE 2: Continue Training from Best HPO Model")
print("=" * 60)

# Sorted so that ties on accuracy resolve deterministically (first path wins).
checkpoints = sorted(glob.glob('experiments/hpo_results/checkpoints/*_best.pt'))
if not checkpoints:
    raise FileNotFoundError("No HPO checkpoints found!")

# Pick the checkpoint with the highest recorded 'best_grid_acc'.
# - Start from -inf (not 0) so a checkpoint is selected even if every trial scored 0%.
# - Keep the winner in memory instead of loading it from disk a second time.
# NOTE: weights_only=False unpickles arbitrary objects — only use on checkpoints you produced.
best_ckpt_path = None
best_acc = float('-inf')
checkpoint = None
for ckpt_path in checkpoints:
    ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    acc = ckpt['best_grid_acc']
    if acc > best_acc:
        best_acc = acc
        best_ckpt_path = ckpt_path
        checkpoint = ckpt
del ckpt  # drop the reference to the last-scanned checkpoint (it may not be the best)

if best_ckpt_path is None:
    # Only reachable if every accuracy compares false against -inf (e.g. all NaN).
    raise ValueError("No HPO checkpoint has a comparable 'best_grid_acc'")

print(f"✓ Best HPO checkpoint: {best_ckpt_path}")
print(f"  Grid Accuracy: {best_acc:.2f}%")

# Config and epoch come from the selected checkpoint
config = checkpoint['config']
start_epoch = checkpoint['epoch']

print(f"  Starting from epoch: {start_epoch}")
print(f"  Config: lr={config['lr']:.2e}, L_cycles={config['L_cycles']}")

# ============================================================
# Step 2: Setup Model and Data
# ============================================================
from torch.utils.data import DataLoader
from src.data import SudokuDataset
from src.pot.models import HybridPoHHRMSolver
from src.training import train_epoch, train_epoch_async, evaluate

# Prefer the GPU whenever PyTorch can see one.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"\n✓ Device: {device}")

# Splits written by the HPO script's --download step (see its output above).
data_dir = 'data/sudoku-extreme-10k-aug-100'
train_dataset, val_dataset, test_dataset = [
    SudokuDataset(data_dir, split) for split in ('train', 'val', 'test')
]

# Shared loader settings; only the training split is shuffled.
batch_size = 768
loader_kwargs = {'batch_size': batch_size, 'num_workers': 4, 'pin_memory': True}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

print(f"✓ Data: {len(train_dataset)} train, {len(val_dataset)} val, {len(test_dataset)} test")

# Build the model from the HPO config, except for the halting budget: Phase 2 raises
# halt_max_steps. The override is written into `config` *before* the model is built, so the
# model, every checkpoint saved later (they store `config`) and the logged config agree.
HALT_MAX_STEPS = 3
hpo_halt_max_steps = config.get('halt_max_steps')  # value the HPO trial was trained with
config['halt_max_steps'] = HALT_MAX_STEPS

model = HybridPoHHRMSolver(
    d_model=config['d_model'],
    n_heads=config['n_heads'],
    H_layers=config['H_layers'],
    L_layers=config['L_layers'],
    d_ff=config['d_ff'],
    dropout=config['dropout'],
    H_cycles=config['H_cycles'],
    L_cycles=config['L_cycles'],
    T=config['T'],
    num_puzzles=1,
    hrm_grad_style=config['hrm_grad_style'],
    halt_max_steps=config['halt_max_steps'],
    halt_exploration_prob=config['halt_exploration'],
).to(device)

# Load the HPO weights. NOTE(review): assumes halt_max_steps does not affect any parameter
# shapes (load_state_dict is strict by default) — confirm if this ever raises.
model.load_state_dict(checkpoint['model_state_dict'])
print(f"✓ Model loaded from checkpoint")
print(f"  halt_max_steps: {hpo_halt_max_steps} -> {config['halt_max_steps']}")
# ============================================================
# Step 3: Setup Optimizers (with lower LR for fine-tuning)
# ============================================================
import math

finetune_lr = config['lr'] * 0.1  # 10x lower for fine-tuning
total_epochs = 3000  # Additional epochs
eval_interval = 50

puzzle_lr = finetune_lr * config['puzzle_lr_multiplier']
betas = (0.9, config['beta2'])

# Split parameters into two groups: the puzzle embedding vs. everything else.
# Membership is tested by object identity via a set built once. The previous
# `p not in set(puzzle_params)` rebuilt the set for every parameter (O(n*m)) and relied on
# tensor hashing/equality.
puzzle_params = list(model.puzzle_emb.parameters())
puzzle_param_ids = {id(p) for p in puzzle_params}
model_params = [p for p in model.parameters() if id(p) not in puzzle_param_ids]
assert len(model_params) + len(puzzle_params) == len(list(model.parameters())), \
    "puzzle_emb parameters are not a subset of model.parameters()"

optimizer = torch.optim.AdamW(model_params, lr=finetune_lr, weight_decay=config['weight_decay'], betas=betas)
puzzle_optimizer = torch.optim.AdamW(puzzle_params, lr=puzzle_lr, weight_decay=config['puzzle_weight_decay'], betas=betas)

# Cosine LR schedule: step budget = epochs x batches per epoch, with a linear warmup.
total_steps = total_epochs * len(train_loader)
warmup_steps = 500

def lr_lambda(step, warmup=None, total=None, min_ratio=0.1):
    """LR multiplier for ``LambdaLR``: linear warmup, then cosine decay to a floor.

    Args:
        step: Scheduler step count (``LambdaLR`` passes only this argument).
        warmup: Warmup length in steps; defaults to the global ``warmup_steps``.
        total: Total schedule length in steps; defaults to the global ``total_steps``.
        min_ratio: Floor of the multiplier once decay completes (0.1 == 10% of base LR).

    Returns:
        ``step / warmup`` during warmup, then a cosine curve from 1.0 down to
        ``min_ratio`` at ``total``. It stays at ``min_ratio`` afterwards instead of
        rising again if the scheduler is stepped past ``total``.
    """
    if warmup is None:
        warmup = warmup_steps
    if total is None:
        total = total_steps
    if step < warmup:
        return float(step) / float(max(1, warmup))
    progress = float(step - warmup) / float(max(1, total - warmup))
    progress = min(progress, 1.0)  # past the budget: hold at the floor, don't climb back up
    return min_ratio + (1.0 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * progress))

# One LambdaLR per optimizer, both driven by the same warmup + cosine multiplier.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
puzzle_scheduler = torch.optim.lr_scheduler.LambdaLR(puzzle_optimizer, lr_lambda)

print(f"✓ Optimizer: lr={finetune_lr:.2e} (10x lower for fine-tuning)")

# ============================================================
# Step 4: W&B Setup
# ============================================================
# Read the halting budget from `config` (already overridden for Phase 2) instead of repeating
# a literal, so the run name and logged config always describe the model actually built.
halt_max_steps = config['halt_max_steps']
run = wandb.init(
    project="sudoku-finetune-general-reasoning",
    name=f"phase2_maxhalt{halt_max_steps}_{datetime.now().strftime('%Y%m%d_%H%M')}",
    config={
        **config,
        "phase": 2,
        "start_epoch": start_epoch,
        "finetune_lr": finetune_lr,
        "total_epochs": total_epochs,
        "best_hpo_acc": best_acc,
        "halt_max_steps": halt_max_steps,  # explicit override
    }
)
print(f"✓ W&B: {run.url}")

# ============================================================
# Step 5: Training Loop
# ============================================================
use_async = config['async_batch']
save_dir = 'experiments/hpo_results/finetune_checkpoints'
os.makedirs(save_dir, exist_ok=True)
best_model_path = os.path.join(save_dir, 'best_model.pt')


def save_best_checkpoint(epoch_abs, grid_acc):
    """Overwrite ``best_model_path`` with the current model and optimizer state.

    Args:
        epoch_abs: Epoch number to record (``start_epoch`` + Phase 2 epoch).
        grid_acc: Validation grid accuracy (%) that made this the best model so far.
    """
    torch.save({
        'epoch': epoch_abs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # Stored as well so a resumed run can restore both parameter groups.
        'puzzle_optimizer_state_dict': puzzle_optimizer.state_dict(),
        'best_grid_acc': grid_acc,
        'config': config,
    }, best_model_path)


# Baseline: re-score the loaded HPO weights under the Phase 2 config (halt_max_steps was
# overridden) and checkpoint them before training. Seeding `best_grid_acc` with the
# HPO-reported accuracy would compare numbers measured under different settings. Worse, if
# fine-tuning never beat it, no best_model.pt would be written, and the test step would
# evaluate the last-epoch weights or a stale best_model.pt from an earlier run.
baseline_metrics = evaluate(model, val_loader, device, use_poh=True)
best_grid_acc = baseline_metrics['grid_acc']
save_best_checkpoint(start_epoch, best_grid_acc)
print(f"✓ Baseline val_grid={best_grid_acc:.2f}% (HPO reported {best_acc:.2f}%)")
wandb.log({
    'epoch': start_epoch,
    'val_loss': baseline_metrics['loss'],
    'val_cell_acc': baseline_metrics['cell_acc'],
    'val_grid_acc': baseline_metrics['grid_acc'],
    'best_grid_acc': best_grid_acc,
})

print(f"\n{'='*60}")
print(f"Starting Phase 2 Training: {total_epochs} epochs")
print(f"{'='*60}\n")

# `use_async` is fixed for the whole run, so choose the epoch function once.
train_fn = train_epoch_async if use_async else train_epoch

for epoch in range(1, total_epochs + 1):
    # Train (schedulers are passed through to the epoch function)
    train_metrics = train_fn(
        model, train_loader, optimizer, puzzle_optimizer,
        device, epoch, use_poh=True,
        scheduler=scheduler, puzzle_scheduler=puzzle_scheduler,
    )

    train_dataset.on_epoch_end()

    # Evaluate after the first epoch and every `eval_interval` epochs
    if epoch % eval_interval == 0 or epoch == 1:
        val_metrics = evaluate(model, val_loader, device, use_poh=True)

        is_best = val_metrics['grid_acc'] > best_grid_acc
        if is_best:
            best_grid_acc = val_metrics['grid_acc']
            save_best_checkpoint(start_epoch + epoch, best_grid_acc)
            print(f"  🏆 New best: {best_grid_acc:.2f}%")

        print(f"Epoch {epoch}: train_loss={train_metrics['loss']:.4f}, "
              f"val_cell={val_metrics['cell_acc']:.1f}%, val_grid={val_metrics['grid_acc']:.1f}%, "
              f"best={best_grid_acc:.1f}%")

        wandb.log({
            'epoch': start_epoch + epoch,
            'train_loss': train_metrics['loss'],
            'train_cell_acc': train_metrics['cell_acc'],
            'train_grid_acc': train_metrics['grid_acc'],
            'val_loss': val_metrics['loss'],
            'val_cell_acc': val_metrics['cell_acc'],
            'val_grid_acc': val_metrics['grid_acc'],
            'best_grid_acc': best_grid_acc,
        })

# ============================================================
# Step 6: Final Evaluation on Test Set (422k puzzles)
# ============================================================
print(f"\n{'='*60}")
print("Final Evaluation on Test Set (422k puzzles)")
print(f"{'='*60}")

# Load the best checkpoint written by the training loop. If it is missing, say so loudly
# instead of silently reporting last-epoch weights as if they were the best model.
best_model_path = os.path.join(save_dir, 'best_model.pt')
has_best_ckpt = os.path.exists(best_model_path)
if has_best_ckpt:
    best_state = torch.load(best_model_path, map_location=device, weights_only=False)
    model.load_state_dict(best_state['model_state_dict'])
    print(f"✓ Loaded best model from: {best_model_path} (epoch {best_state.get('epoch')})")
    del best_state  # free the state dict copy; the weights now live in `model`
else:
    print(f"⚠️ No best checkpoint found, using last epoch weights")

test_metrics = evaluate(model, test_loader, device, use_poh=True)

print(f"\n🎯 TEST RESULTS (422k puzzles):")
print(f"   Cell Accuracy: {test_metrics['cell_acc']:.2f}%")
print(f"   Grid Accuracy: {test_metrics['grid_acc']:.2f}%")

wandb.log({
    'test_cell_acc': test_metrics['cell_acc'],
    'test_grid_acc': test_metrics['grid_acc'],
})

wandb.finish()

if has_best_ckpt:
    print(f"\n✓ Done! Best model saved to: {best_model_path}")
else:
    print("\n✓ Done! (no best checkpoint was saved)")

In [None]:
# ============================================================
# Step 6: Final Evaluation on Test Set (422k puzzles)
# ============================================================
print(f"\n{'='*60}")
print("Final Evaluation on Test Set (422k puzzles)")
print(f"{'='*60}")

# Load best model
best_model_path = os.path.join(save_dir, 'best_model.pt')
if os.path.exists(best_model_path):
    best_state = torch.load(best_model_path, map_location=device, weights_only=False)
    model.load_state_dict(best_state['model_state_dict'])
    del best_state  # free the state dict copy; the weights now live in `model`
    print(f"✓ Loaded best model from: {best_model_path}")
else:
    print(f"⚠️ No best checkpoint found, using last epoch weights")

test_metrics = evaluate(model, test_loader, device, use_poh=True)

print(f"\n🎯 TEST RESULTS (422k puzzles):")
print(f"   Cell Accuracy: {test_metrics['cell_acc']:.2f}%")
print(f"   Grid Accuracy: {test_metrics['grid_acc']:.2f}%")

# The Phase 2 cell ends with wandb.finish(), so when it ran to completion there is no active
# run and wandb.log would fail. Re-attach to that run so the test metrics land on it.
if wandb.run is None:
    wandb.init(project="sudoku-finetune-general-reasoning", id=run.id, resume="allow")

wandb.log({
    'test_cell_acc': test_metrics['cell_acc'],
    'test_grid_acc': test_metrics['grid_acc'],
})

wandb.finish()

print(f"\n✓ Done! Best model saved to: {best_model_path}")