# Rubik's Cube 2x2 Training Analysis

This notebook analyzes the HRM model training for 2x2 Rubik's cube solving.

## Current Status

### Two Approaches
1. **Policy Network** (dev branch): Predicts full solution sequence
2. **Heuristic Network** (heuristic branch): Learns distance-to-solved for A* search
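
The heuristic branch plugs the network's distance estimate into A* search. Below is a minimal sketch assuming unit move costs and hashable states (for example a tuple of 24 sticker colors). `neighbors`, `is_goal` and `heuristic` are hypothetical callbacks, not functions from py222 or this project. The `expanded` counter is one way to measure the "nodes during search" figures quoted in this notebook.

```python
import heapq
from itertools import count


def a_star(start, is_goal, neighbors, heuristic):
    """Minimal A* with unit move costs (a sketch, not the project's solver).

    neighbors(state) -> iterable of successor states (hypothetical interface)
    heuristic(state) -> estimated moves to goal, e.g. a trained network's output
    Returns (path_length, nodes_expanded), or (None, nodes_expanded) if no path.
    """
    tie = count()  # tie-breaker so states never have to be compared directly
    frontier = [(heuristic(start), next(tie), 0, start)]
    best_g = {start: 0}  # cheapest known cost to reach each state
    expanded = 0
    while frontier:
        _, _, g, state = heapq.heappop(frontier)
        if g > best_g.get(state, float("inf")):
            continue  # stale entry: a cheaper path to this state was found later
        if is_goal(state):
            return g, expanded
        expanded += 1
        for nxt in neighbors(state):
            g_next = g + 1
            if g_next < best_g.get(nxt, float("inf")):
                best_g[nxt] = g_next
                heapq.heappush(frontier, (g_next + heuristic(nxt), next(tie), g_next, nxt))
    return None, expanded


# Toy check on the integer line: states are ints, moves are +/-1, goal is 0.
line_moves = lambda s: (s - 1, s + 1)
print(a_star(7, lambda s: s == 0, line_moves, abs))          # → (7, 7)
print(a_star(7, lambda s: s == 0, line_moves, lambda s: 0))  # → (7, 13)
```

With an exact heuristic A* expands only the states on the solution path; with a zero heuristic it degrades to uniform-cost search and expands far more. A learned heuristic is not guaranteed to be admissible, so the returned solution may be longer than optimal.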

### Key Findings
- A small fully connected network (4 layers, 512 hidden units) with **one-hot encoding** works much better as a heuristic
- Training on just **1000 states** already gives a useful heuristic (about 2500 nodes expanded on average during search)
- HRM (recurrent) may be overkill for the heuristic, which predicts a single value
- **Overfitting issue**: good on train, poor on test → dataset bias suspected

### Dataset Bias Problem
States closer to solved are **over-represented** because the dataset is generated from solution sequences:
- Each scramble yields a solution sequence, and the states along it become training samples
- States near the end of a solution (closest to solved) are shared by many sequences, so they appear more often
- The model therefore learns to predict low distances too often
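
To see how strong this effect is, the sketch below counts the labels produced under a simplified, hypothetical generation scheme: a solution of `n` moves contributes one state at each distance `n, n-1, ..., 0`. The function name `label_histogram` is illustrative only and does not mirror the dataset builder.

```python
from collections import Counter


def label_histogram(solution_lengths):
    """Count distance labels when every state on each solution path is kept.

    Assumes (hypothetically) that a solution of n moves contributes the states
    at distance n, n-1, ..., 1, 0 from solved, each labeled with its distance.
    """
    counts = Counter()
    for n in solution_lengths:
        counts.update(range(n + 1))  # one sample per distance along the path
    return counts


# One scramble at each optimal length 1..11 (11 is the 2x2 maximum in the half-turn metric)
hist = label_histogram(range(1, 12))
print([hist[d] for d in range(12)])
# → [11, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
```

Even with one scramble per depth, distances 0 and 1 appear eleven times as often as distance 11, whereas the vast majority of the 2x2's 3,674,160 states are 8 or more moves from solved.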

## Goals
1. Analyze dataset distribution
2. Test different input encodings (state string vs one-hot)
3. Compare HRM vs simple MLP performance
4. Fix overfitting and improve generalization

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json
import sys

# Add the current working directory (assumed to be the project root) so py222 can be imported
sys.path.insert(0, str(Path.cwd()))
import py222

sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

## 1. Load and Analyze Heuristic Dataset

In [2]:
def load_dataset_split(data_dir: Path, split: str):
    """Load a dataset split (train/test/val)"""
    split_dir = data_dir / split
    if not split_dir.exists():
        raise FileNotFoundError(f"Split {split} not found in {data_dir}")

    # Load metadata
    with open(split_dir / "dataset.json", "r") as f:
        metadata = json.load(f)

    # Load arrays
    inputs = np.load(split_dir / "all__inputs.npy")
    labels = np.load(split_dir / "all__labels.npy")
    puzzle_indices = np.load(split_dir / "all__puzzle_indices.npy")

    return {
        'metadata': metadata,
        'inputs': inputs,
        'labels': labels.squeeze(),  # Remove extra dimension
        'puzzle_indices': puzzle_indices,
    }

# Load heuristic dataset
data_dir = Path("data/cube-2-by-2-heuristic")
if data_dir.exists():
    train_data = load_dataset_split(data_dir, "train")
    test_data = load_dataset_split(data_dir, "test")
    
    print(f"Train samples: {len(train_data['inputs'])}")
    print(f"Test samples: {len(test_data['inputs'])}")
    print(f"Input shape: {train_data['inputs'].shape}")
    print(f"Label shape: {train_data['labels'].shape}")
else:
    print("⚠️ Heuristic dataset not found. Run: python dataset/build_2x2_heuristic.py")

⚠️ Heuristic dataset not found. Run: python dataset/build_2x2_heuristic.py


## 2. Analyze Distance Distribution (BIAS CHECK)

In [None]:
if data_dir.exists():
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # Train distribution
    axes[0].hist(train_data['labels'], bins=range(int(train_data['labels'].max()) + 2), 
                 edgecolor='black', alpha=0.7)
    axes[0].set_title('Train: Distance to Solved Distribution')
    axes[0].set_xlabel('Distance (moves)')
    axes[0].set_ylabel('Count')
    axes[0].grid(alpha=0.3)
    
    # Test distribution
    axes[1].hist(test_data['labels'], bins=range(int(test_data['labels'].max()) + 2), 
                 edgecolor='black', alpha=0.7, color='orange')
    axes[1].set_title('Test: Distance to Solved Distribution')
    axes[1].set_xlabel('Distance (moves)')
    axes[1].set_ylabel('Count')
    axes[1].grid(alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Statistics
    print("\nüìä Distribution Statistics:")
    print(f"Train - Mean: {train_data['labels'].mean():.2f}, Median: {np.median(train_data['labels']):.0f}, Std: {train_data['labels'].std():.2f}")
    print(f"Test  - Mean: {test_data['labels'].mean():.2f}, Median: {np.median(test_data['labels']):.0f}, Std: {test_data['labels'].std():.2f}")
    
    # Check for bias
    print("\n⚠️ BIAS CHECK:")
    train_low = (train_data['labels'] <= 3).sum() / len(train_data['labels']) * 100
    test_low = (test_data['labels'] <= 3).sum() / len(test_data['labels']) * 100
    print(f"States within 3 moves of solved: Train={train_low:.1f}%, Test={test_low:.1f}%")
    
    if train_low > 40:
        print("🔴 SEVERE BIAS: >40% of states are close to solved!")
    elif train_low > 25:
        print("🟡 MODERATE BIAS: Dataset skewed toward easier states")
    else:
        print("🟢 Distribution looks reasonable")

## 3. Compare Input Encodings

Compare two input encodings:
1. **State String**: current approach, 24 integers in 0-5 (one color index per sticker)
2. **One-Hot**: each position → 6-dimensional vector (24 × 6 = 144 dims)
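
The layout of the one-hot vector can be checked without torch. This numpy sketch is equivalent in output to the `F.one_hot` version in the next cell, assuming colors are already 0-5: the color `c` of sticker `i` lands in column `6 * i + c`. The solved-state sticker order used here (4 stickers per face, faces in order) is a hypothetical example, not necessarily py222's ordering.

```python
import numpy as np


def onehot_numpy(states, num_colors=6):
    """One-hot encode sticker colors with numpy only.

    states: integer array of shape (batch, 24) with values in 0..num_colors-1.
    Returns a float32 array of shape (batch, 24 * num_colors).
    """
    states = np.asarray(states, dtype=np.int64)
    eye = np.eye(num_colors, dtype=np.float32)
    onehot = eye[states]  # fancy indexing: (batch, 24, num_colors)
    return onehot.reshape(states.shape[0], -1)


# Hypothetical solved state: stickers grouped 4 per face, colors 0..5
solved = np.repeat(np.arange(6), 4)[None, :]
x = onehot_numpy(solved)
print(x.shape, x.sum())  # → (1, 144) 24.0
```

Each of the 24 blocks of six columns contains exactly one 1, so every encoded state sums to 24.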

In [None]:
def state_to_onehot(state, num_colors=6):
    """Convert state array to one-hot encoding"""
    # state: (batch, 24) → (batch, 24, 6) → (batch, 144)
    import torch
    import torch.nn.functional as F
    
    if isinstance(state, np.ndarray):
        state = torch.from_numpy(state)
    
    onehot = F.one_hot(state.long(), num_classes=num_colors)
    return onehot.reshape(onehot.shape[0], -1).float()

# Example
if data_dir.exists():
    sample_states = train_data['inputs'][:5]
    print("Original encoding (first sample):")
    print(sample_states[0])
    print(f"Shape: {sample_states.shape}")
    
    onehot_sample = state_to_onehot(sample_states)
    print("\nOne-hot encoding (first sample):")
    print(onehot_sample[0].numpy())
    print(f"Shape: {onehot_sample.shape}")
    
    print("\n✅ One-hot encoding: 24 positions × 6 colors = 144 features")

## 4. Simple MLP Baseline (What Your Colleague Found Works Best)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleMLP(nn.Module):
    """Simple 4-layer MLP with one-hot encoding
    
    This is what your colleague found works well:
    - 4 fully connected layers
    - 512 hidden units
    - One-hot input encoding
    - Trains in ~5 min on CPU
    - Expands ~350 nodes on average during search
    """
    def __init__(self, input_dim=144, hidden_dim=512, output_dim=1):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, output_dim)
        
    def forward(self, x):
        # x: (batch, 24) state indices
        # Convert to one-hot
        x_onehot = F.one_hot(x.long(), num_classes=6).float()
        x_onehot = x_onehot.reshape(x_onehot.shape[0], -1)  # (batch, 144)
        
        x = F.relu(self.fc1(x_onehot))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)  # Output: distance estimate
        return x.squeeze(-1)

# Create model
simple_model = SimpleMLP()
print(simple_model)
print(f"\nParameters: {sum(p.numel() for p in simple_model.parameters()):,}")

# Test forward pass
if data_dir.exists():
    sample_input = torch.from_numpy(train_data['inputs'][:4])
    with torch.no_grad():
        output = simple_model(sample_input)
    print(f"\nTest output shape: {output.shape}")
    print(f"Sample predictions: {output.numpy()}")
    print(f"True labels: {train_data['labels'][:4]}")
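
The parameter count printed by this cell follows directly from the layer widths: each `nn.Linear(n_in, n_out)` holds `n_in * n_out` weights plus `n_out` biases. The helper name `mlp_param_count` is illustrative.

```python
def mlp_param_count(layer_dims):
    """Weights plus biases of a stack of fully connected layers with these widths."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layer_dims, layer_dims[1:]))


# SimpleMLP defaults: 144 → 512 → 512 → 512 → 1
print(mlp_param_count([144, 512, 512, 512, 1]))  # → 600065
```

Almost 90% of those parameters sit in the two 512 × 512 hidden layers, so `hidden_dim` is the main lever on model size and CPU training time.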

## 5. Quick Training Test (CPU-friendly)

Train a small model on CPU to verify everything works

In [None]:
from torch.utils.data import TensorDataset, DataLoader
from tqdm.auto import tqdm

def train_simple_model(model, train_data, test_data, epochs=10, lr=1e-3, batch_size=64, device='cpu'):
    """Quick training function for simple MLP"""
    
    # Prepare data
    train_dataset = TensorDataset(
        torch.from_numpy(train_data['inputs']),
        torch.from_numpy(train_data['labels'].astype(np.float32))
    )
    test_dataset = TensorDataset(
        torch.from_numpy(test_data['inputs']),
        torch.from_numpy(test_data['labels'].astype(np.float32))
    )
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    
    history = {'train_loss': [], 'test_loss': [], 'test_mae': []}
    
    for epoch in range(epochs):
        # Train
        model.train()
        train_loss = 0
        for inputs, targets in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=False):
            inputs, targets = inputs.to(device), targets.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
        
        train_loss /= len(train_loader)
        
        # Evaluate (weight each batch by its size so a short final batch is not over-counted)
        model.eval()
        test_loss = 0.0
        test_mae = 0.0
        n_test = 0
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                n = targets.shape[0]
                test_loss += criterion(outputs, targets).item() * n
                test_mae += torch.abs(outputs - targets).sum().item()
                n_test += n
        
        test_loss /= n_test
        test_mae /= n_test
        
        history['train_loss'].append(train_loss)
        history['test_loss'].append(test_loss)
        history['test_mae'].append(test_mae)
        
        print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Test Loss={test_loss:.4f}, Test MAE={test_mae:.4f}")
    
    return history

# Train on small subset for quick test
if data_dir.exists():
    print("Training simple MLP on small subset (1000 samples)...")
    small_train = {
        'inputs': train_data['inputs'][:1000],
        'labels': train_data['labels'][:1000]
    }
    small_test = {
        'inputs': test_data['inputs'][:200],
        'labels': test_data['labels'][:200]
    }
    
    model = SimpleMLP()
    history = train_simple_model(model, small_train, small_test, epochs=5, lr=1e-3)
    
    # Plot
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Train Loss')
    plt.plot(history['test_loss'], label='Test Loss')
    plt.xlabel('Epoch')
    plt.ylabel('MSE Loss')
    plt.legend()
    plt.grid(alpha=0.3)
    plt.title('Training Progress')
    
    plt.subplot(1, 2, 2)
    plt.plot(history['test_mae'])
    plt.xlabel('Epoch')
    plt.ylabel('Mean Absolute Error')
    plt.grid(alpha=0.3)
    plt.title('Test MAE (Lower is Better)')
    
    plt.tight_layout()
    plt.show()
    
    print(f"\n✅ Final Test MAE: {history['test_mae'][-1]:.4f} moves")
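
A single test MAE can hide the failure mode described under "Dataset Bias Problem". The hypothetical helper below breaks errors down by true distance; it assumes you have already collected model predictions for the test set as an array (for example by running the model under `torch.no_grad()` and calling `.numpy()`).

```python
import numpy as np


def mae_by_distance(preds, labels):
    """Per true distance: (mean absolute error, mean signed error, sample count).

    A negative signed error means the model under-estimates the distance,
    which is what the dataset-bias hypothesis predicts for far-from-solved states.
    """
    preds = np.asarray(preds, dtype=np.float64)
    labels = np.rint(np.asarray(labels)).astype(np.int64)
    report = {}
    for d in np.unique(labels):  # sorted distinct distances
        mask = labels == d
        err = preds[mask] - d
        report[int(d)] = (float(np.abs(err).mean()), float(err.mean()), int(mask.sum()))
    return report


# Toy numbers, not real results
report = mae_by_distance([0.2, 1.1, 5.0, 6.0], [0, 1, 8, 8])
for d, (mae, signed, n) in report.items():
    print(f"distance {d}: MAE={mae:.2f}, signed={signed:+.2f}, n={n}")
```

If the bias explanation holds, the signed error should become increasingly negative as the true distance grows.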

## 6. Next Steps

### Immediate Actions
1. **Fix dataset bias**: 
   - Generate uniformly random states (e.g., long random scrambles) instead of taking every state along solution sequences
   - Or weight samples inversely to the frequency of their distance label
   
2. **Test simple MLP vs HRM**:
   - Train both on same data
   - Compare speed, accuracy, and search performance
   
3. **Input encoding experiments**:
   - One-hot (current best)
   - Learned embeddings
   - Relative position encoding
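
A minimal sketch of the inverse-frequency weighting option, using a hypothetical helper name. Each sample gets `N / (K * count[label])`, where `N` is the number of samples and `K` the number of distinct distance labels, so every distance contributes the same total weight and the weights sum to `N`.

```python
import numpy as np


def inverse_frequency_weights(labels):
    """Per-sample weights giving every distance label equal total weight.

    w_i = N / (K * count[label_i]); the weights sum to N.
    """
    labels = np.rint(np.asarray(labels)).astype(np.int64)
    values, inverse, counts = np.unique(labels, return_inverse=True, return_counts=True)
    return len(labels) / (len(values) * counts[inverse])


# Toy labels: three near-solved samples, one far sample
w = inverse_frequency_weights([0, 0, 0, 1])
print(w)  # the lone distance-1 sample gets 3x the weight of each distance-0 sample
```

The weights can be used either to resample batches (for example with `torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)`) or to scale a per-sample loss computed with `nn.MSELoss(reduction='none')`.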

### For Presentation
- Dataset distribution analysis (this notebook)
- Model architecture comparison (HRM vs MLP)
- Search performance metrics (nodes expanded, solution length)
- Visualizations of cube solving
- Comparison to baselines (BFS, DeepCubeA)
- Extension to 3x3 (if time permits)

### Code Cleanup
- Refactor solver.py with Hydra CLI
- Add more natural comments
- Clean up AI-generated code sections