# 🎯 BDH Pathfinding Training (CORRECTED)

**KEY FIX**: Train to predict DIRECTION (0-3) instead of CELL (0-99)

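Every prediction is therefore a step to an adjacent cell. A predicted step can still point into a wall or off the board, so check validity at inference time.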

**Time**: 30-60 minutes on GPU

In [None]:
# Setup: fetch the BDH repository on first run, make it the working directory
# so that `from bdh import BDH, BDHParameters` (model cell) resolves, and
# report the PyTorch version / CUDA availability.
import os
if not os.path.exists('bdh'):
    !git clone https://github.com/krychu/bdh.git
# NOTE(review): after this `%cd`, re-running the cell checks for bdh/bdh
# relative to the new cwd and may clone again or fail to `%cd`; restart the
# kernel (or `%cd ..`) before re-running. Confirm against the repo layout.
%cd bdh

import torch
import numpy as np
from collections import deque
import random
print(f"âœ… PyTorch: {torch.__version__}")
print(f"âœ… CUDA: {torch.cuda.is_available()}")

## Dataset (CORRECTED - Predicts Direction)

In [None]:
class DirectionDataset(torch.utils.data.Dataset):
    """Next-move dataset for grid pathfinding: predict a DIRECTION, not a cell.

    Each sample pairs a flattened ``board_size * board_size`` board with the
    direction of the next step along a BFS shortest path from the current
    cell to the goal.

    Board tokens (int64):
        0 = open cell, 1 = wall, 2 = start, 3 = end, 4 = current position.
        On the first step of a path the current position is the start cell,
        so the 4 marker overwrites the 2 marker there.
    Direction labels: 0 = up, 1 = down, 2 = left, 3 = right.

    Args:
        num_samples: number of (board, direction) pairs to keep.
        board_size: side length of the square board.
        wall_prob: probability that each cell is independently a wall.
        min_path_len: minimum number of cells (endpoints included) on an
            accepted start-to-end path; shorter mazes are rejected.

    Generation stops after ``3 * num_samples`` maze attempts, so fewer than
    ``num_samples`` samples may be produced if too many mazes are rejected.
    """

    # (row delta, col delta) of a single step -> direction label.
    _STEP_TO_DIRECTION = {(-1, 0): 0, (1, 0): 1, (0, -1): 2, (0, 1): 3}

    def __init__(self, num_samples=50000, board_size=10, *,
                 wall_prob=0.25, min_path_len=5):
        self.board_size = board_size
        self.wall_prob = wall_prob
        self.min_path_len = min_path_len
        self.samples = []

        print(f"Generating {num_samples} samples...")
        attempts = 0

        # Each accepted maze contributes one sample per step. Many random mazes
        # are rejected, so the attempt cap guarantees termination.
        while len(self.samples) < num_samples and attempts < num_samples * 3:
            sample_set = self._generate_sample()
            if sample_set:
                self.samples.extend(sample_set)
            attempts += 1

            if attempts % 1000 == 0:
                print(f"  {len(self.samples)}/{num_samples}...")

        # The last maze can overshoot the target; trim to exactly num_samples.
        self.samples = self.samples[:num_samples]
        print(f"âœ… Generated {len(self.samples)} samples")

    def _bfs_path(self, board, start, end):
        """Return a shortest 4-connected path from ``start`` to ``end``.

        The path is a list of (row, col) tuples including both endpoints, or
        None if ``end`` is unreachable. Cells equal to 1 are walls. Each
        cell's parent is recorded once, which avoids copying the path per
        enqueued cell. Neighbour order matches the original (right, down,
        left, up), so the same path is returned.
        """
        parents = {start: None}
        queue = deque([start])

        while queue:
            cell = queue.popleft()
            if cell == end:
                # Walk the parent links back to the start, then reverse.
                path = []
                while cell is not None:
                    path.append(cell)
                    cell = parents[cell]
                path.reverse()
                return path

            row, col = cell
            for dr, dc in ((0, 1), (1, 0), (0, -1), (-1, 0)):
                nr, nc = row + dr, col + dc
                if (0 <= nr < self.board_size and 0 <= nc < self.board_size
                        and (nr, nc) not in parents and board[nr, nc] != 1):
                    parents[(nr, nc)] = cell
                    queue.append((nr, nc))
        return None

    def _get_direction(self, current, next_pos):
        """Convert a single-cell step into a direction label (0-3).

        Returns None if ``next_pos`` is not orthogonally adjacent to
        ``current``.
        """
        step = (next_pos[0] - current[0], next_pos[1] - current[1])
        return self._STEP_TO_DIRECTION.get(step)

    def _generate_sample(self):
        """Build one random maze and return its per-step samples, or None.

        The maze is rejected (None) when start or end lands on a wall, when
        they coincide, or when no path of at least ``min_path_len`` cells
        exists.
        """
        n = self.board_size

        # Each cell is independently a wall with probability wall_prob.
        board = np.zeros((n, n), dtype=np.int64)
        for i in range(n):
            for j in range(n):
                if random.random() < self.wall_prob:
                    board[i, j] = 1

        # Draw endpoints over the full n x n grid (randint bounds are inclusive).
        start = (random.randint(0, n - 1), random.randint(0, n - 1))
        end = (random.randint(0, n - 1), random.randint(0, n - 1))

        if board[start] == 1 or board[end] == 1 or start == end:
            return None

        path = self._bfs_path(board, start, end)
        if not path or len(path) < self.min_path_len:
            return None

        samples = []
        for current, nxt in zip(path, path[1:]):
            board_state = board.copy()
            board_state[start] = 2
            board_state[end] = 3
            # Mark the agent's current cell; at the first step this overwrites start.
            board_state[current] = 4

            direction = self._get_direction(current, nxt)
            if direction is not None:
                samples.append((board_state.flatten(), direction))

        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return (board tokens as a LongTensor of length board_size**2, direction label)."""
        board, direction = self.samples[idx]
        return torch.from_numpy(board).long(), torch.tensor(direction, dtype=torch.long)

# Build the 50k-sample direction dataset and wrap it in a shuffling,
# two-worker loader with mini-batches of 32.
dataset = DirectionDataset(num_samples=50000)
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
)
print(f"\nâœ… Dataset ready: {len(dataset)} samples")

## Model (V=4 for 4 directions!)

In [None]:
from bdh import BDH, BDHParameters

# Board cells arrive as tokens 0-4 (0 open, 1 wall, 2 start, 3 end,
# 4 current position), while the target is one of 4 directions.
# NOTE(review): this assumes BDHParameters.V sizes BOTH the input token
# embedding and the output head (GPT-style); verify in bdh.py. Under that
# assumption V=4 cannot embed token 4, so V must cover the board tokens.
# Only the first NUM_DIRECTIONS logits are meant to be read as direction
# scores; the extra logit is simply never a training target.
NUM_BOARD_TOKENS = 5
NUM_DIRECTIONS = 4

params = BDHParameters(
    V=max(NUM_BOARD_TOKENS, NUM_DIRECTIONS),
    T=100,        # sequence length = board_size ** 2 (10x10 boards)
    H=4,
    N=2048,
    D=64,
    L=12,
    dropout=0.1,
    use_rope=True,
    use_abs_pos=False
)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = BDH(params).to(device)

print(f"\nâœ… Model created")
print(f"   Vocabulary: {params.V} ({NUM_BOARD_TOKENS} board tokens; first {NUM_DIRECTIONS} logits = directions)")
print(f"   Device: {device}")

## Training

In [None]:
import torch.nn as nn
import time

# Number of direction classes (0 up, 1 down, 2 left, 3 right). The model's
# output width V may exceed this (V may also have to cover every board token),
# so only the first NUM_DIRECTIONS logits are scored. With V == 4 the slice
# below is a no-op.
NUM_DIRECTIONS = 4

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()

losses = []      # mean training loss per epoch
accuracies = []  # running training accuracy (%) per epoch
best_acc = 0

print("\n" + "="*70)
print("ðŸš€ TRAINING")
print("="*70)

start_time = time.time()

for epoch in range(100):
    model.train()
    epoch_loss = 0
    correct = 0
    total = 0

    for batch_idx, (boards, directions) in enumerate(loader):
        boards, directions = boards.to(device), directions.to(device)

        optimizer.zero_grad()
        logits = model(boards, capture_frames=False)
        # Classify from the final sequence position only.
        # NOTE(review): presumes logits is (B, T, V) and that the last position
        # has seen the whole board; confirm in bdh.py.
        last_logits = logits[:, -1, :NUM_DIRECTIONS]  # [B, NUM_DIRECTIONS]

        loss = criterion(last_logits, directions)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        epoch_loss += loss.item()
        preds = last_logits.argmax(dim=-1)
        correct += (preds == directions).sum().item()
        total += directions.size(0)

        if batch_idx % 100 == 0:
            acc = 100 * correct / total
            print(f"Epoch {epoch+1:3d} | Batch {batch_idx:4d} | Loss: {loss.item():.4f} | Acc: {acc:.1f}%")

    avg_loss = epoch_loss / len(loader)
    accuracy = 100 * correct / total
    losses.append(avg_loss)
    accuracies.append(accuracy)

    print(f"\nEpoch {epoch+1} | Loss: {avg_loss:.4f} | Acc: {accuracy:.2f}%\n")

    # NOTE(review): checkpoint selection and early stopping use the running
    # TRAINING accuracy (computed in train mode while weights change). There
    # is no held-out split, so this may overstate generalisation.
    if accuracy > best_acc:
        best_acc = accuracy
        torch.save(model.state_dict(), 'bdh_pathfinding_directions.pth')
        print(f"âœ… New best: {best_acc:.2f}%\n")

    if accuracy > 95:
        print("ðŸŽ‰ Excellent accuracy!")
        break

print(f"\nâœ… Training complete: {(time.time()-start_time)/60:.1f}min")
print(f"Best accuracy: {best_acc:.2f}%")

## Save & Download

In [None]:
import matplotlib.pyplot as plt

# Plot per-epoch training loss (left) and accuracy (right) side by side,
# save the figure as training.png, then display it inline.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
for panel, series, colour, heading in (
    (ax1, losses, '#6366f1', 'Loss'),
    (ax2, accuracies, '#10b981', 'Accuracy'),
):
    panel.plot(series, linewidth=2, color=colour)
    panel.set_title(heading)
plt.tight_layout()
plt.savefig('training.png', dpi=150)
plt.show()

# Tell the user which checkpoint to download and where it belongs.
for message in (
    f"\nðŸ“¥ Download: bdh_pathfinding_directions.pth",
    f"   Place in: checkpoints/bdh_pathfinding_trained.pth",
    f"\nðŸŽ‰ This model will work MUCH better!",
):
    print(message)