# 🎯 BDH Pathfinding Training

Train BDH to solve mazes by predicting, at each step, the next cell along a shortest path (found with BFS) from the current cell to the goal.

**Key Innovation**: set the vocabulary size V=100 (one class per cell of the 10×10 board) so that BDH's output at the last position directly predicts the index of the next cell.

**Expected Time**: 30-60 minutes on T4 GPU

**IMPORTANT** (notebook settings):
- Enable a **GPU** accelerator (T4 or P100)
- Enable **Internet** access (needed to clone the BDH repository in Step 1)

## Step 1: Clone Repository & Setup

In [None]:
import os

# Clone the BDH repository into ./bdh unless something named 'bdh' already
# exists in the current working directory.
# NOTE(review): after `%cd bdh` below, re-running this cell evaluates the
# check *inside* the clone, which may trigger a second, nested clone
# (bdh/bdh) — restart the kernel or `%cd ..` before re-running.
if not os.path.exists('bdh'):
    !git clone https://github.com/krychu/bdh.git
    # IPython's `!` does not raise when the shell command fails, so this
    # message is printed even if the clone did not succeed.
    print("‚úÖ Repository cloned")
else:
    print("‚úÖ Repository already exists")

# Make the clone the working directory so that the later
# `from bdh import BDH, BDHParameters` resolves against it.
# NOTE(review): assumes the repository exposes an importable `bdh` module at
# its top level — confirm.
%cd bdh

In [None]:
# Report the PyTorch build and whether a CUDA GPU is visible, then install
# numpy and matplotlib with pip (-q keeps pip's output quiet).
# NOTE(review): torch is imported here, not installed — the environment must
# already provide it.
import torch
print(f"‚úÖ PyTorch version: {torch.__version__}")
print(f"‚úÖ CUDA available: {torch.cuda.is_available()}")

# IPython's `!` does not raise on failure, so the message below is printed
# regardless of whether pip succeeded.
!pip install numpy matplotlib -q
print("‚úÖ Dependencies installed")

## Step 2: Create Pathfinding Dataset

In [None]:
import numpy as np
import random
from collections import deque
from typing import List, Tuple, Optional

class PathfindingDataset(torch.utils.data.Dataset):
    """Next-cell prediction dataset for training BDH on grid pathfinding.

    Each item is ``(board_tokens, target_idx)``:

    * ``board_tokens`` -- the board flattened row-major into
      ``board_size ** 2`` int64 tokens: 0 = free, 1 = wall, 2 = start,
      3 = end, 4 = current position. The current marker overwrites the start
      marker on the first step of each path.
    * ``target_idx`` -- flat index ``row * board_size + col`` of the next cell
      on a BFS shortest path towards the end cell.

    Boards are random: every cell is a wall independently with probability
    ``wall_density``. Each accepted maze contributes ``len(path) - 1``
    consecutive samples (one per step of its path).

    NOTE: when several shortest paths exist, only the one BFS finds first
    (neighbour order (0, 1), (1, 0), (0, -1), (-1, 0)) provides the labels, so
    an equally short alternative move is scored as wrong.

    Args:
        num_samples: maximum number of samples to keep.
        board_size: side length of the square board.
        wall_density: per-cell wall probability.
        min_path_length: minimum number of cells (both endpoints included) a
            path must have for its maze to be used.
        seed: optional seed for a private ``random.Random``, making the
            dataset reproducible. ``None`` (default) draws from the global
            ``random`` module, as the dataset always did.
    """

    def __init__(self, num_samples=50000, board_size=10, wall_density=0.25,
                 min_path_length=5, seed=None):
        self.num_samples = num_samples
        self.board_size = board_size
        self.wall_density = wall_density
        self.min_path_length = min_path_length
        # Only the seed is stored (not an RNG object) so the dataset stays
        # picklable for DataLoader worker processes.
        self.seed = seed

        print(f"Generating {num_samples} pathfinding samples...")
        self.samples = self._generate_dataset()
        print(f"‚úÖ Generated {len(self.samples)} valid samples")

    def _bfs_path(self, board, start, end):
        """Return a shortest 4-connected path from ``start`` to ``end``.

        Cells whose value is 1 are walls. The result is a list of
        ``(row, col)`` tuples including both endpoints, or None when ``end``
        is unreachable. Parent pointers are recorded instead of copying the
        partial path for every enqueued cell; the reconstructed path is the
        same one the copying approach would return.
        """
        parents = {start: None}  # doubles as the visited set
        queue = deque([start])

        while queue:
            cell = queue.popleft()

            if cell == end:
                path = []
                while cell is not None:
                    path.append(cell)
                    cell = parents[cell]
                return path[::-1]

            row, col = cell
            for dr, dc in ((0, 1), (1, 0), (0, -1), (-1, 0)):
                neighbour = (row + dr, col + dc)
                if (0 <= neighbour[0] < self.board_size
                        and 0 <= neighbour[1] < self.board_size
                        and neighbour not in parents
                        and board[neighbour] != 1):
                    parents[neighbour] = cell
                    queue.append(neighbour)

        return None

    def _generate_sample(self, rng=None):
        """Generate all samples for one random solvable maze.

        Args:
            rng: object providing ``random()`` and ``randint()``; defaults to
                the global ``random`` module.

        Returns:
            A list of ``(flat_board, target_idx)`` pairs, one per step of the
            BFS path, or None if 10 attempts did not yield a path of at least
            ``min_path_length`` cells.
        """
        rng = random if rng is None else rng
        size = self.board_size

        for _ in range(10):  # Max 10 attempts
            # Walls are drawn row-major, one rng.random() call per cell.
            board = np.zeros((size, size), dtype=np.int64)
            for i in range(size):
                for j in range(size):
                    if rng.random() < self.wall_density:
                        board[i, j] = 1

            start = (rng.randint(0, size - 1), rng.randint(0, size - 1))
            end = (rng.randint(0, size - 1), rng.randint(0, size - 1))

            if board[start] == 1 or board[end] == 1 or start == end:
                continue

            path = self._bfs_path(board, start, end)
            if path is None or len(path) < self.min_path_length:
                continue

            samples = []
            for current, next_pos in zip(path, path[1:]):
                board_state = board.copy()
                board_state[start] = 2    # Start
                board_state[end] = 3      # End
                board_state[current] = 4  # Current (replaces start on step 0)
                target_idx = next_pos[0] * size + next_pos[1]
                samples.append((board_state.flatten(), target_idx))
            return samples

        return None

    def _generate_dataset(self):
        """Collect up to ``num_samples`` samples.

        Generation stops after ``3 * num_samples`` maze attempts, so fewer
        samples than requested may be returned on hard settings.
        """
        rng = random.Random(self.seed) if self.seed is not None else random
        samples = []
        attempts = 0

        while len(samples) < self.num_samples and attempts < self.num_samples * 3:
            sample_set = self._generate_sample(rng)
            if sample_set:
                samples.extend(sample_set)
            attempts += 1

            if attempts % 1000 == 0:
                print(f"  Generated {len(samples)}/{self.num_samples} samples...")

        return samples[:self.num_samples]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(board_tokens, target)`` as int64 tensors."""
        board_state, target = self.samples[idx]
        return (
            torch.from_numpy(board_state).long(),
            torch.tensor(target, dtype=torch.long)
        )

# Build the training set (default wall density / minimum path length) and a
# shuffling loader that yields batches of 32 flattened boards with targets.
dataset = PathfindingDataset(num_samples=50000, board_size=10)
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
)

n_train_samples, n_train_batches = len(dataset), len(loader)
print(f"\n‚úÖ Dataset ready: {n_train_samples} samples, {n_train_batches} batches")

## Step 3: Create BDH Model

In [None]:
from bdh import BDH, BDHParameters

# One vocabulary entry per cell of the 10x10 board, so the output at each
# position is a distribution over cells; the flattened board is also exactly
# one sequence of this length.
NUM_CELLS = 10 * 10

params = BDHParameters(
    V=NUM_CELLS,      # vocabulary: one token / class per cell
    T=NUM_CELLS,      # sequence length: the flattened board
    H=4,              # heads
    N=2048,           # neurons
    D=64,             # latent dimension
    L=12,             # layers
    dropout=0.1,
    use_rope=True,
    use_abs_pos=False,
)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = BDH(params)
model.to(device)

total_params = sum(p.numel() for p in model.parameters())
rule = '=' * 70

print(f"\n{rule}")
print(f"üß† BDH Model Created")
print(rule)
for summary_line in (
    f"Parameters: {total_params:,}",
    f"Device: {device}",
    f"Vocabulary: {params.V} (maps to 10x10 board)",
    rule,
):
    print(summary_line)

## Step 4: Train Model

In [None]:
import torch.nn as nn
import time

# Training setup
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()

num_epochs = 100
early_stop_acc = 95.0  # stop once held-out accuracy (%) exceeds this
losses = []            # mean training loss per epoch
accuracies = []        # running training accuracy (%) per epoch, dropout active
val_accuracies = []    # held-out accuracy (%) per epoch, eval mode
best_acc = 0.0         # best held-out accuracy so far; drives checkpointing

# Held-out mazes generated independently of the training set. Checkpoint
# selection and early stopping use accuracy on these rather than the running
# training accuracy, which is measured with dropout on while the weights are
# still changing and rewards fitting the training mazes themselves.
val_dataset = PathfindingDataset(num_samples=5000, board_size=dataset.board_size)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=64, shuffle=False)


def evaluate(net, eval_loader):
    """Return next-cell accuracy (%) of ``net`` on ``eval_loader``.

    Runs in eval mode (dropout off) without gradient tracking. The caller is
    responsible for switching the model back to train mode.
    """
    net.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for eval_boards, eval_targets in eval_loader:
            eval_boards = eval_boards.to(device)
            eval_targets = eval_targets.to(device)
            eval_logits = net(eval_boards, capture_frames=False)  # [B, T, V]
            eval_preds = eval_logits[:, -1, :].argmax(dim=-1)
            n_correct += (eval_preds == eval_targets).sum().item()
            n_seen += eval_targets.size(0)
    return 100 * n_correct / n_seen if n_seen else 0.0


print("\n" + "="*70)
print("üöÄ STARTING TRAINING")
print("="*70)

start_time = time.time()

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (boards, targets) in enumerate(loader):
        boards, targets = boards.to(device), targets.to(device)

        optimizer.zero_grad()

        # Forward pass; the prediction is read from the last position
        logits = model(boards, capture_frames=False)  # [B, T, V=100]
        last_logits = logits[:, -1, :]  # [B, 100]

        # Compute loss
        loss = criterion(last_logits, targets)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        epoch_loss += loss.item()

        # Running training accuracy
        preds = last_logits.argmax(dim=-1)
        correct += (preds == targets).sum().item()
        total += targets.size(0)

        if batch_idx % 100 == 0:
            elapsed = time.time() - start_time
            acc = 100 * correct / total if total > 0 else 0
            print(f"Epoch {epoch+1:3d}/{num_epochs} | "
                  f"Batch {batch_idx:4d}/{len(loader)} | "
                  f"Loss: {loss.item():.4f} | "
                  f"Acc: {acc:.1f}% | "
                  f"Time: {elapsed/60:.1f}m")

    # Epoch summary (guards avoid dividing by zero on an empty loader)
    avg_loss = epoch_loss / len(loader) if len(loader) > 0 else 0.0
    accuracy = 100 * correct / total if total > 0 else 0.0
    val_acc = evaluate(model, val_loader)
    losses.append(avg_loss)
    accuracies.append(accuracy)
    val_accuracies.append(val_acc)

    elapsed = time.time() - start_time
    print(f"\n{'='*70}")
    print(f"Epoch {epoch+1:3d} Complete | "
          f"Loss: {avg_loss:.4f} | "
          f"Train Acc: {accuracy:.2f}% | "
          f"Val Acc: {val_acc:.2f}% | "
          f"Time: {elapsed/60:.1f}m")
    print(f"{'='*70}\n")

    # Save the best model as judged on held-out mazes
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), 'bdh_pathfinding_trained.pth')
        print(f"‚úÖ New best! Val acc: {best_acc:.2f}%\n")

    # Early stopping on held-out accuracy
    if val_acc > early_stop_acc:
        print("üéâ Excellent accuracy! Stopping early.")
        break

print("\n" + "="*70)
print("‚úÖ TRAINING COMPLETE!")
print(f"Time: {(time.time() - start_time)/60:.1f} minutes")
print(f"Best validation accuracy: {best_acc:.2f}%")
print("="*70)

## Step 5: Plot Results

In [None]:
import matplotlib.pyplot as plt

# Per-epoch training curves. Epochs are numbered from 1 on the x-axis so they
# match the "Epoch N" numbering printed during training.
epoch_numbers = range(1, len(losses) + 1)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

ax1.plot(epoch_numbers, losses, linewidth=2, color='#6366f1')
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('Loss', fontsize=12)
ax1.set_title('Training Loss', fontsize=14, fontweight='bold')
ax1.grid(True, alpha=0.3)

ax2.plot(epoch_numbers, accuracies, linewidth=2, color='#10b981')
ax2.set_xlabel('Epoch', fontsize=12)
ax2.set_ylabel('Accuracy (%)', fontsize=12)
ax2.set_title('Training Accuracy', fontsize=14, fontweight='bold')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
# Save before show(): showing can leave an empty current figure behind.
plt.savefig('pathfinding_training.png', dpi=150)
plt.show()

print(f"\nFinal Results:")
if losses:
    print(f"  Loss: {losses[-1]:.4f}")
    print(f"  Accuracy: {accuracies[-1]:.2f}%")
else:
    # No epoch finished (e.g. training was interrupted before the first one).
    print("  No completed epochs.")
print(f"  Best Accuracy: {best_acc:.2f}%")

## Step 6: Verify Checkpoint

In [None]:
import os

# Confirm the training cell wrote a checkpoint and report its size in MB.
ckpt_path = 'bdh_pathfinding_trained.pth'

if not os.path.exists(ckpt_path):
    print("‚ùå Checkpoint not found!")
else:
    size_mb = os.path.getsize(ckpt_path) / 1e6
    print(f"‚úÖ Checkpoint saved: {ckpt_path} ({size_mb:.1f} MB)")
    print("\nüì• Download this file and place it in your project's checkpoints/ directory")

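## ✅ Training Complete!

### Next Steps:

1. **Download the checkpoint**:
   - Find `bdh_pathfinding_trained.pth` in the notebook's Output section
   - Download it

2. **Deploy to Brain Explorer**:
   ```bash
   mv ~/Downloads/bdh_pathfinding_trained.pth checkpoints/
   ```

3. **Update the backend** to use the trained model for pathfinding. The checkpoint contains only the model's `state_dict`, so the backend must build BDH with the same `BDHParameters` used here (V=100, T=100, H=4, N=2048, D=64, L=12, use_rope=True, use_abs_pos=False) before loading it.

4. **Test** the model in the Pathfinder Live module

🎉 **BDH can now solve mazes, one predicted step at a time!**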