# Chess EqProp Training

Train the holomorphic equilibrium propagation model on chess positions using PyTorch Lightning and Weights & Biases.

In [1]:
import sys
sys.path.append('..')

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
import wandb

from pathlib import Path
import numpy as np
from typing import Optional, Dict, Any

from src.eqprop import LinearHolomorphicEQProp
import chess

print(f"PyTorch version: {torch.__version__}")
print(f"PyTorch Lightning version: {pl.__version__}")
print(f"Device: {torch.device('mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu')}")

  Referenced from: <EB3FF92A-5EB1-3EE8-AF8B-5923C1265422> /Users/c/eq-chess/env/lib/python3.11/site-packages/torchvision/image.so
  warn(


PyTorch version: 2.5.1
PyTorch Lightning version: 2.5.6
Device: mps


## Configuration

Set all hyperparameters and training configuration:

In [2]:
# Paths
SCRATCH_DIR = Path('scratch')
CHECKPOINT_DIR = Path('checkpoints')
CHECKPOINT_DIR.mkdir(exist_ok=True)

# Model architecture
INPUT_SIZE = 8 * 8 * 12   # Chess board encoding (768)
HIDDEN_SIZE = 2048        # Hidden layer size

# Output structure: [8 previous board states] + [1 next board state]
# Network always predicts from white's perspective
BOARD_STATE_SIZE = 8 * 8 * 12  # 768 per board
NUM_HISTORY_STATES = 8         # Predict last 8 board states (8 half-moves of history)
NUM_FUTURE_STATES = 1          # Predict next board state
TOTAL_OUTPUT_STATES = NUM_HISTORY_STATES + NUM_FUTURE_STATES  # 9 total states
OUTPUT_SIZE = TOTAL_OUTPUT_STATES * BOARD_STATE_SIZE  # 9 * 768 = 6912

WEIGHT_STD = 0.05        # Weight initialization std (smaller for larger output)

# EqProp dynamics parameters
T1 = 20          # Free phase settling time
T2 = 4           # Nudged phase settling time
N = 4            # Number of holomorphic phases (1=standard EP, 4=holomorphic)
                 # ✓ N>1 now properly implemented with complex dynamics!
BETA = 0.5       # Nudging strength
LR_DYNAMICS = 0.5    # Learning rate for state dynamics
NOISE_STD = 0.0      # Noise during dynamics (0 = no noise)

# Training parameters
BATCH_SIZE = 16  # Smaller batch due to larger output
NUM_EPOCHS = 100
LR_LEARNING = 0.001  # Learning rate for weight updates

# State initialization configuration
STATE_INIT_MODE = 'zeros'  # 'zeros', 'random', 'custom'
STATE_INIT_STD = 0.01      # Std for random initialization

# Wandb configuration
WANDB_PROJECT = 'eq-chess'
WANDB_ENTITY = None  # Set to your wandb username/team
EXPERIMENT_NAME = 'eqprop-history-prediction'

# Hardware
ACCELERATOR = 'auto'  # 'auto', 'cpu', 'gpu', 'mps'
NUM_WORKERS = 0       # DataLoader workers (0 for debugging)

print("Configuration:")
print(f"  Model: Input={INPUT_SIZE}, Hidden={HIDDEN_SIZE}, Output={OUTPUT_SIZE}")
print(f"  Output breakdown: {NUM_HISTORY_STATES} history states + {NUM_FUTURE_STATES} future state = {TOTAL_OUTPUT_STATES} total states")
print(f"  Total output dimensions: {TOTAL_OUTPUT_STATES} × {BOARD_STATE_SIZE} = {OUTPUT_SIZE}")
print(f"  EqProp: T1={T1}, T2={T2}, N={N}, beta={BETA}")
print(f"  Holomorphic EP: Using {N} phases on unit circle with complex dynamics")
print(f"  Training: Batch={BATCH_SIZE}, Epochs={NUM_EPOCHS}, LR={LR_LEARNING}")
print(f"  State init: {STATE_INIT_MODE}")

Configuration:
  Model: Input=768, Hidden=2048, Output=6912
  Output breakdown: 8 history states + 1 future state = 9 total states
  Total output dimensions: 9 × 768 = 6912
  EqProp: T1=20, T2=4, N=4, beta=0.5
  Holomorphic EP: Using 4 phases on unit circle with complex dynamics
  Training: Batch=16, Epochs=100, LR=0.001
  State init: zeros


## Load Datasets

Load the pre-generated chess datasets from the scratch directory:

In [3]:
print("Loading datasets...\n")

# Load SEQUENCE datasets (not board states)
train_seq_dataset = torch.load(SCRATCH_DIR / 'train_seq_dataset.pt')
val_seq_dataset = torch.load(SCRATCH_DIR / 'val_seq_dataset.pt')
test_seq_dataset = torch.load(SCRATCH_DIR / 'test_seq_dataset.pt')

print(f"Sequence dataset sizes:")
print(f"  Train: {len(train_seq_dataset)}")
print(f"  Val:   {len(val_seq_dataset)}")
print(f"  Test:  {len(test_seq_dataset)}")

# We need to create a custom dataset that extracts training samples from sequences
# Each sample: input = board state, target = [8 history states] + [1 next state]
from torch.utils.data import Dataset

class ChessHistoryDataset(Dataset):
    """
    Dataset that creates training samples from game sequences.
    Each sample consists of:
    - Input: Current board state (from white's perspective)
    - Target: [8 previous states] + [1 next state] (9 total states)
    """
    
    def __init__(self, sequence_dataset, num_history=8):
        self.sequences = sequence_dataset.sequences
        self.num_history = num_history
        self.samples = []
        
        # Extract valid training positions from sequences
        for seq_idx, sequence in enumerate(self.sequences):
            # Need at least num_history + 1 positions (history + current + next)
            if len(sequence) < num_history + 2:
                continue
            
            # Create samples from positions that have enough history and a next state
            for pos_idx in range(num_history, len(sequence) - 1):
                self.samples.append((seq_idx, pos_idx))
        
        print(f"Created {len(self.samples)} training samples from {len(self.sequences)} sequences")
    
    def __len__(self):
        return len(self.samples)
    
    def flip_board_for_white_perspective(self, board):
        """
        Flip board representation so it's always from white's perspective.
        If it's black's turn, flip the board vertically and swap colors.
        """
        from src.chess import encode_onehot_gamestate
        
        encoding = encode_onehot_gamestate(board)  # (8, 8, 12)
        
        if board.turn == chess.BLACK:
            # Flip board vertically (rank 0 <-> rank 7)
            encoding = torch.flip(encoding, dims=[0])
            # Swap white and black channels (0-5 <-> 6-11)
            white_channels = encoding[:, :, 0:6].clone()
            black_channels = encoding[:, :, 6:12].clone()
            encoding[:, :, 0:6] = black_channels
            encoding[:, :, 6:12] = white_channels
        
        return encoding
    
    def __getitem__(self, idx):
        seq_idx, pos_idx = self.samples[idx]
        sequence = self.sequences[seq_idx]
        
        # Current position (input)
        current_board = sequence[pos_idx]
        input_encoding = self.flip_board_for_white_perspective(current_board)
        
        # History states (8 previous positions)
        history_encodings = []
        for i in range(NUM_HISTORY_STATES):
            hist_idx = pos_idx - NUM_HISTORY_STATES + i
            hist_board = sequence[hist_idx]
            hist_encoding = self.flip_board_for_white_perspective(hist_board)
            history_encodings.append(hist_encoding.flatten())
        
        # Next state (1 future position)
        next_board = sequence[pos_idx + 1]
        next_encoding = self.flip_board_for_white_perspective(next_board)
        next_flat = next_encoding.flatten()
        
        # Combine history + next state
        # Output: [state_-7, state_-6, ..., state_-1, state_+1]
        target = torch.cat(history_encodings + [next_flat])
        
        return input_encoding, target

# Create datasets
train_dataset = ChessHistoryDataset(train_seq_dataset, num_history=NUM_HISTORY_STATES)
val_dataset = ChessHistoryDataset(val_seq_dataset, num_history=NUM_HISTORY_STATES)
test_dataset = ChessHistoryDataset(test_seq_dataset, num_history=NUM_HISTORY_STATES)

# Create DataLoaders
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True
)

print(f"\nBatches per epoch:")
print(f"  Train: {len(train_loader)}")
print(f"  Val:   {len(val_loader)}")

Loading datasets...



  train_seq_dataset = torch.load(SCRATCH_DIR / 'train_seq_dataset.pt')
  val_seq_dataset = torch.load(SCRATCH_DIR / 'val_seq_dataset.pt')
  val_seq_dataset = torch.load(SCRATCH_DIR / 'val_seq_dataset.pt')

  test_seq_dataset = torch.load(SCRATCH_DIR / 'test_seq_dataset.pt')


Sequence dataset sizes:
  Train: 700
  Val:   150
  Test:  150
Created 36159 training samples from 700 sequences
Created 7800 training samples from 150 sequences
Created 7771 training samples from 150 sequences

Batches per epoch:
  Train: 2260
  Val:   488


## PyTorch Lightning Module

Wrapper for the EqProp model with Lightning training logic:

In [4]:
class ChessEqPropLightning(pl.LightningModule):
    """
    PyTorch Lightning module for training chess EqProp model.
    Predicts 8 historical board states + 1 next board state from current position.
    All from white's perspective.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        
        # Save hyperparameters
        self.save_hyperparameters(config)
        
        # Extract config
        self.input_size = config['input_size']
        self.hidden_size = config['hidden_size']
        self.output_size = config['output_size']
        self.board_state_size = config.get('board_state_size', 768)
        self.num_history_states = config.get('num_history_states', 8)
        self.num_future_states = config.get('num_future_states', 1)
        self.total_output_states = config.get('total_output_states', 9)
        
        self.T1 = config['T1']
        self.T2 = config['T2']
        self.N = config['N']
        self.beta = config['beta']
        self.lr_dynamics = config['lr_dynamics']
        self.lr_learning = config['lr_learning']
        self.noise_std = config['noise_std']
        
        self.state_init_mode = config.get('state_init_mode', 'zeros')
        self.state_init_std = config.get('state_init_std', 0.01)
        
        # Create EqProp model
        self.model = LinearHolomorphicEQProp(
            input_size=self.input_size,
            hidden_size=self.hidden_size,
            output_size=self.output_size,
            weight_std=config.get('weight_std', 0.1)
        )
        
        # Loss function
        self.mse_loss = nn.MSELoss()

    def configure_state(self, batch_size: Optional[int] = None):
        """
        Configure/reset the model state.
        
        This is where you can implement custom state initialization.
        Called at the beginning of each forward pass or can be called manually.
        
        Args:
            batch_size: If provided, configure for batched processing
        """
        if self.state_init_mode == 'zeros':
            self.model.state.zero_()
        elif self.state_init_mode == 'random':
            self.model.state.normal_(0, self.state_init_std)
            # Keep input state at zero
            self.model.state[:self.input_size] = 0.0
        elif self.state_init_mode == 'custom':
            # Custom state initialization
            self.model.state.zero_()
            # Initialize output state with prior over history reconstruction
            # Bias toward more recent states being more accurate
            o_start = self.input_size + self.hidden_size
            
            # History states: more recent = less noise
            for i in range(self.num_history_states):
                # More recent history (higher i) gets smaller initialization noise
                noise_scale = 0.01 * (1.0 - i / self.num_history_states)
                start_idx = o_start + i * self.board_state_size
                end_idx = start_idx + self.board_state_size
                self.model.state[start_idx:end_idx] = torch.randn(self.board_state_size) * noise_scale
            
            # Future state: slightly higher noise (it's a prediction)
            future_start = o_start + self.num_history_states * self.board_state_size
            future_end = future_start + self.board_state_size
            self.model.state[future_start:future_end] = torch.randn(self.board_state_size) * 0.02

    def parse_output(self, output: torch.Tensor):
        """
        Parse model output into history states and next state prediction.
        
        Args:
            output: Output tensor (output_size,) or (batch_size, output_size)
        
        Returns:
            (history_states, next_state)
        """
        if output.dim() == 1:
            # Single sample
            output_reshaped = output.reshape(self.total_output_states, self.board_state_size)
            history = output_reshaped[:self.num_history_states]  # First 8 states
            next_state = output_reshaped[self.num_history_states:]  # Last 1 state
            return history, next_state
        else:
            # Batch
            batch_size = output.shape[0]
            output_reshaped = output.reshape(batch_size, self.total_output_states, self.board_state_size)
            history = output_reshaped[:, :self.num_history_states, :]  # First 8 states
            next_state = output_reshaped[:, self.num_history_states:, :]  # Last 1 state
            return history, next_state

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the model.
        
        Args:
            x: Input tensor (batch_size, 8, 8, 12) - chess board encoding
        
        Returns:
            Output tensor (batch_size, output_size)
        """
        # Flatten board encoding
        batch_size = x.shape[0]
        x_flat = x.reshape(batch_size, -1)  # (batch_size, 768)
        
        # Process each sample in batch
        outputs = []
        for i in range(batch_size):
            # Configure state for this sample
            self.configure_state()
            
            # Forward pass (settles to equilibrium)
            output = self.model.evaluate(
                x_flat[i],
                noise_std=0.0,  # No noise during inference
                T_settle=self.T1,
                lr_dynamics=self.lr_dynamics
            )
            outputs.append(output)
        
        return torch.stack(outputs)

    def training_step(self, batch, batch_idx):
        """
        Training step using equilibrium propagation.
        """
        # batch = (input, target)
        # input: (batch_size, 8, 8, 12)
        # target: (batch_size, output_size) = (batch_size, 9 * 768)
        x, targets = batch
        batch_size = x.shape[0]
        x_flat = x.reshape(batch_size, -1)
        
        # Train each sample with EqProp
        total_loss = 0.0
        total_history_loss = 0.0
        total_next_loss = 0.0
        
        for i in range(batch_size):
            # Configure state
            self.configure_state()
            
            # Set input
            self.model.set_inputs(x_flat[i])
            
            # Create mask: apply stronger supervision to next state prediction than history
            mask = torch.ones_like(targets[i])
            history_size = self.num_history_states * self.board_state_size
            mask[:history_size] = 0.3  # Lower weight on history reconstruction
            mask[history_size:] = 1.0  # Full weight on next state prediction
            
            # Perform EqProp learning
            self.model.learn(
                target=targets[i],
                mask=mask,
                beta=self.beta,
                T1=self.T1,
                T2=self.T2,
                N=self.N,
                lr_dynamics=self.lr_dynamics,
                lr_learning=self.lr_learning,
                noise_std=self.noise_std
            )
            
            # Compute losses for logging
            with torch.no_grad():
                output = self.model.output_state
                
                # Parse output
                pred_history, pred_next = self.parse_output(output)
                target_history, target_next = self.parse_output(targets[i])
                
                # History reconstruction loss (MSE)
                history_loss = self.mse_loss(pred_history, target_history)
                
                # Next state prediction loss (MSE)
                next_loss = self.mse_loss(pred_next, target_next)
                
                # Combined loss
                loss = 0.3 * history_loss + next_loss
                
                total_loss += loss.item()
                total_history_loss += history_loss.item()
                total_next_loss += next_loss.item()
        
        avg_loss = total_loss / batch_size
        avg_history_loss = total_history_loss / batch_size
        avg_next_loss = total_next_loss / batch_size
        
        # Log metrics
        self.log('train_loss', avg_loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_history_loss', avg_history_loss, on_step=False, on_epoch=True)
        self.log('train_next_loss', avg_next_loss, on_step=False, on_epoch=True)
        
        return {'loss': avg_loss}

    def validation_step(self, batch, batch_idx):
        """
        Validation step - inference only.
        """
        x, targets = batch
        batch_size = x.shape[0]
        
        # Forward pass
        outputs = self.forward(x)
        
        # Compute losses
        total_history_loss = 0.0
        total_next_loss = 0.0
        
        for i in range(batch_size):
            pred_history, pred_next = self.parse_output(outputs[i])
            target_history, target_next = self.parse_output(targets[i])
            
            # History reconstruction loss
            history_loss = self.mse_loss(pred_history, target_history)
            total_history_loss += history_loss.item()
            
            # Next state prediction loss
            next_loss = self.mse_loss(pred_next, target_next)
            total_next_loss += next_loss.item()
        
        avg_history_loss = total_history_loss / batch_size
        avg_next_loss = total_next_loss / batch_size
        avg_loss = 0.3 * avg_history_loss + avg_next_loss
        
        # Log metrics
        self.log('val_loss', avg_loss, on_epoch=True, prog_bar=True)
        self.log('val_history_loss', avg_history_loss, on_epoch=True)
        self.log('val_next_loss', avg_next_loss, on_epoch=True, prog_bar=True)
        
        return {'val_loss': avg_loss}

    def configure_optimizers(self):
        """
        Configure optimizer.
        
        Note: EqProp handles its own weight updates, so this is primarily
        for compatibility with Lightning. We use a dummy optimizer.
        """
        # Return a dummy optimizer since EqProp handles updates internally
        optimizer = torch.optim.SGD(self.parameters(), lr=0.0)
        return optimizer

    def on_train_epoch_end(self):
        """Called at the end of training epoch."""
        # Log weight statistics
        with torch.no_grad():
            W = self.model.W
            self.log('weight_mean', W.mean())
            self.log('weight_std', W.std())
            self.log('weight_max', W.max())
            self.log('weight_min', W.min())

print("ChessEqPropLightning module defined.")

ChessEqPropLightning module defined.


## Initialize Model

Create the model with configuration:

In [5]:
# Configuration dictionary
config = {
    'input_size': INPUT_SIZE,
    'hidden_size': HIDDEN_SIZE,
    'output_size': OUTPUT_SIZE,
    'board_state_size': BOARD_STATE_SIZE,
    'num_history_states': NUM_HISTORY_STATES,
    'num_future_states': NUM_FUTURE_STATES,
    'total_output_states': TOTAL_OUTPUT_STATES,
    'weight_std': WEIGHT_STD,
    'T1': T1,
    'T2': T2,
    'N': N,
    'beta': BETA,
    'lr_dynamics': LR_DYNAMICS,
    'lr_learning': LR_LEARNING,
    'noise_std': NOISE_STD,
    'state_init_mode': STATE_INIT_MODE,
    'state_init_std': STATE_INIT_STD,
}

# Create model
model = ChessEqPropLightning(config)

print(f"Model created:")
print(f"  Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"  Total state size: {model.model.total_size:,}")
print(f"  Weight matrix shape: {model.model.W.shape}")
print(f"\n  Output structure:")
print(f"    - History states: {NUM_HISTORY_STATES} × {BOARD_STATE_SIZE} = {NUM_HISTORY_STATES * BOARD_STATE_SIZE}")
print(f"    - Next state:     {NUM_FUTURE_STATES} × {BOARD_STATE_SIZE} = {NUM_FUTURE_STATES * BOARD_STATE_SIZE}")
print(f"    - Total output:   {TOTAL_OUTPUT_STATES} × {BOARD_STATE_SIZE} = {OUTPUT_SIZE}")

Model created:
  Total parameters: 94,633,984
  Total state size: 9,728
  Weight matrix shape: torch.Size([9728, 9728])

  Output structure:
    - History states: 8 × 768 = 6144
    - Next state:     1 × 768 = 768
    - Total output:   9 × 768 = 6912


## Setup Wandb Logging

Initialize Weights & Biases for experiment tracking:

In [6]:
# Initialize wandb
wandb.init(
    project=WANDB_PROJECT,
    entity=WANDB_ENTITY,
    name=EXPERIMENT_NAME,
    config=config,
)

print("Wandb initialized.")
print(f"  Project: {WANDB_PROJECT}")
print(f"  Run name: {EXPERIMENT_NAME}")
print(f"  Run URL: {wandb.run.get_url() if wandb.run else 'N/A'}")

[34m[1mwandb[0m: Currently logged in as: [33mcharles-s-strauss[0m ([33mcharles-s-strauss-n-a[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin




Wandb initialized.
  Project: eq-chess
  Run name: eqprop-history-prediction
  Run URL: https://wandb.ai/charles-s-strauss-n-a/eq-chess/runs/xa365r14


## Train Model

Run the manual EqProp training loop.

This directly calls `model.learn()` for each sample, which performs:
1. Free phase: Settle to equilibrium with beta=0
2. Nudged phase(s): Settle with beta (holomorphic N-phase sampling)
3. Weight update: Based on difference between free and nudged equilibria

In [None]:
print("Starting training...\n")

# Manual EqProp training loop
from tqdm import tqdm
import time

# Training history
history = {
    'train_loss': [],
    'train_history_loss': [],
    'train_next_loss': [],
    'val_loss': [],
    'val_history_loss': [],
    'val_next_loss': [],
}

best_val_loss = float('inf')

for epoch in range(NUM_EPOCHS):
    epoch_start = time.time()
    
    # ========== Training Phase ==========
    model.model.train()
    train_loss = 0.0
    train_history_loss = 0.0
    train_next_loss = 0.0
    num_train_samples = 0
    
    train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Train]")
    for batch_idx, (x, targets) in enumerate(train_pbar):
        batch_size = x.shape[0]
        x_flat = x.reshape(batch_size, -1)
        
        # Move to device
        if torch.backends.mps.is_available():
            x_flat = x_flat.to('mps')
            targets = targets.to('mps')
            model.model = model.model.to('mps')
        elif torch.cuda.is_available():
            x_flat = x_flat.cuda()
            targets = targets.cuda()
            model.model = model.model.cuda()
        
        # Train each sample with EqProp
        batch_loss = 0.0
        batch_history_loss = 0.0
        batch_next_loss = 0.0
        
        for i in range(batch_size):
            # Configure state
            model.configure_state()
            
            # Set input
            model.model.set_inputs(x_flat[i])
            
            # Create mask: stronger supervision on next state
            mask = torch.ones_like(targets[i])
            history_size = NUM_HISTORY_STATES * BOARD_STATE_SIZE
            mask[:history_size] = 0.3  # Lower weight on history
            mask[history_size:] = 1.0  # Full weight on next state
            
            # Perform EqProp learning (this updates weights)
            model.model.learn(
                target=targets[i],
                mask=mask,
                beta=BETA,
                T1=T1,
                T2=T2,
                N=N,
                lr_dynamics=LR_DYNAMICS,
                lr_learning=LR_LEARNING,
                noise_std=NOISE_STD
            )
            
            # Compute losses for logging
            with torch.no_grad():
                output = model.model.output_state
                pred_history, pred_next = model.parse_output(output)
                target_history, target_next = model.parse_output(targets[i])
                
                history_loss = model.mse_loss(pred_history, target_history)
                next_loss = model.mse_loss(pred_next, target_next)
                loss = 0.3 * history_loss + next_loss
                
                batch_loss += loss.item()
                batch_history_loss += history_loss.item()
                batch_next_loss += next_loss.item()
        
        # Average over batch
        avg_batch_loss = batch_loss / batch_size
        avg_batch_history_loss = batch_history_loss / batch_size
        avg_batch_next_loss = batch_next_loss / batch_size
        
        train_loss += batch_loss
        train_history_loss += batch_history_loss
        train_next_loss += batch_next_loss
        num_train_samples += batch_size
        
        # Update progress bar
        train_pbar.set_postfix({
            'loss': f'{avg_batch_loss:.4f}',
            'hist': f'{avg_batch_history_loss:.4f}',
            'next': f'{avg_batch_next_loss:.4f}'
        })
        
        # Log to wandb (per batch)
        if wandb.run is not None:
            wandb.log({
                'batch_train_loss': avg_batch_loss,
                'batch_train_history_loss': avg_batch_history_loss,
                'batch_train_next_loss': avg_batch_next_loss,
                'epoch': epoch,
                'batch': batch_idx
            })
    
    # Average training losses
    avg_train_loss = train_loss / num_train_samples
    avg_train_history_loss = train_history_loss / num_train_samples
    avg_train_next_loss = train_next_loss / num_train_samples
    
    # ========== Validation Phase ==========
    model.model.eval()
    val_loss = 0.0
    val_history_loss = 0.0
    val_next_loss = 0.0
    num_val_samples = 0
    
    val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Val]")
    with torch.no_grad():
        for x, targets in val_pbar:
            batch_size = x.shape[0]
            
            # Forward pass (inference only)
            outputs = model.forward(x)
            
            # Compute losses
            for i in range(batch_size):
                pred_history, pred_next = model.parse_output(outputs[i])
                target_history, target_next = model.parse_output(targets[i])
                
                history_loss = model.mse_loss(pred_history, target_history)
                next_loss = model.mse_loss(pred_next, target_next)
                loss = 0.3 * history_loss + next_loss
                
                val_loss += loss.item()
                val_history_loss += history_loss.item()
                val_next_loss += next_loss.item()
            
            num_val_samples += batch_size
            
            # Update progress bar
            val_pbar.set_postfix({
                'loss': f'{val_loss/num_val_samples:.4f}',
                'next': f'{val_next_loss/num_val_samples:.4f}'
            })
    
    # Average validation losses
    avg_val_loss = val_loss / num_val_samples
    avg_val_history_loss = val_history_loss / num_val_samples
    avg_val_next_loss = val_next_loss / num_val_samples
    
    # ========== Logging & Checkpointing ==========
    epoch_time = time.time() - epoch_start
    
    # Save to history
    history['train_loss'].append(avg_train_loss)
    history['train_history_loss'].append(avg_train_history_loss)
    history['train_next_loss'].append(avg_train_next_loss)
    history['val_loss'].append(avg_val_loss)
    history['val_history_loss'].append(avg_val_history_loss)
    history['val_next_loss'].append(avg_val_next_loss)
    
    # Print epoch summary
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS} - {epoch_time:.1f}s")
    print(f"  Train - Loss: {avg_train_loss:.4f}, History: {avg_train_history_loss:.4f}, Next: {avg_train_next_loss:.4f}")
    print(f"  Val   - Loss: {avg_val_loss:.4f}, History: {avg_val_history_loss:.4f}, Next: {avg_val_next_loss:.4f}")
    
    # Log to wandb (per epoch)
    if wandb.run is not None:
        # Weight statistics
        W = model.model.W
        
        wandb.log({
            'epoch': epoch,
            'train_loss': avg_train_loss,
            'train_history_loss': avg_train_history_loss,
            'train_next_loss': avg_train_next_loss,
            'val_loss': avg_val_loss,
            'val_history_loss': avg_val_history_loss,
            'val_next_loss': avg_val_next_loss,
            'weight_mean': W.mean().item(),
            'weight_std': W.std().item(),
            'weight_max': W.max().item(),
            'weight_min': W.min().item(),
            'epoch_time': epoch_time
        })
    
    # Save best model
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        checkpoint_path = CHECKPOINT_DIR / f'best_model_epoch_{epoch+1}.pt'
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.model.state_dict(),
            'val_loss': avg_val_loss,
            'config': config,
        }, checkpoint_path)
        print(f"  → Saved best model to {checkpoint_path}")
    
    # Save checkpoint every 10 epochs
    if (epoch + 1) % 10 == 0:
        checkpoint_path = CHECKPOINT_DIR / f'checkpoint_epoch_{epoch+1}.pt'
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.model.state_dict(),
            'val_loss': avg_val_loss,
            'config': config,
            'history': history,
        }, checkpoint_path)
        print(f"  → Saved checkpoint to {checkpoint_path}")
    
    print()

print("\nTraining complete!")

Starting training...



Epoch 1/100 [Train]:   0%|                     | 4/2260 [01:22<12:43:52, 20.32s/it, loss=1.2595, hist=0.9726, next=0.9677] hist=0.9735, next=0.9699]:   0%|                     | 1/2260 [00:42<13:52:57, 22.12s/it, loss=1.2589, hist=0.9743, next=0.9666]619, hist=0.9735, next=0.9699]| 1/2260 [00:22<13:52:57, 22.12s/it, loss=1.2619, hist=0.9735, next=0.9699]60 [00:22<?, ?it/s, loss=1.2619, hist=0.9735, next=0.9699]Epoch 1/100 [Train]:   0%|                     | 2/2260 [00:42<13:08:11, 20.94s/it, loss=1.2589, hist=0.9743, next=0.9666]ext=0.9699]| 1/2260 [00:22<13:52:57, 22.12s/it, loss=1.2619, hist=0.9735, next=0.9699]60 [00:22<?, ?it/s, loss=1.2619, hist=0.9735, next=0.9699]:   0%|                     | 1/2260 [00:42<13:52:57, 22.12s/it, loss=1.2589, hist=0.9743, next=0.9666]619, hist=0.9735, next=0.9699]| 1/2260 [00:22<13:52:57, 22.12s/it, loss=1.2619, hist=0.9735, next=0.9699]60 [00:22<?, ?it/s, loss=1.2619, hist=0.9735, next=0.9699]1:02<13:08:11, 20.94s/it, loss=1.2587, hist=0.9710, ne

## Test Model

Evaluate on the test set:

In [None]:
print("Testing model on test set...\n")

# Load best model
best_checkpoint = list(CHECKPOINT_DIR.glob('best_model_*.pt'))
if best_checkpoint:
    checkpoint = torch.load(best_checkpoint[0])
    model.model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded best model from epoch {checkpoint['epoch']+1}")
else:
    print("No checkpoint found, using current model")

# Test evaluation
model.model.eval()
test_loss = 0.0
test_history_loss = 0.0
test_next_loss = 0.0
num_test_samples = 0

test_pbar = tqdm(test_loader, desc="Testing")
with torch.no_grad():
    for x, targets in test_pbar:
        batch_size = x.shape[0]
        
        # Forward pass
        outputs = model.forward(x)
        
        # Compute losses
        for i in range(batch_size):
            pred_history, pred_next = model.parse_output(outputs[i])
            target_history, target_next = model.parse_output(targets[i])
            
            history_loss = model.mse_loss(pred_history, target_history)
            next_loss = model.mse_loss(pred_next, target_next)
            loss = 0.3 * history_loss + next_loss
            
            test_loss += loss.item()
            test_history_loss += history_loss.item()
            test_next_loss += next_loss.item()
        
        num_test_samples += batch_size
        
        # Update progress bar
        test_pbar.set_postfix({
            'loss': f'{test_loss/num_test_samples:.4f}',
            'next': f'{test_next_loss/num_test_samples:.4f}'
        })

# Average test losses
avg_test_loss = test_loss / num_test_samples
avg_test_history_loss = test_history_loss / num_test_samples
avg_test_next_loss = test_next_loss / num_test_samples

print(f"\n{'='*60}")
print(f"Test Results:")
print(f"  Loss:         {avg_test_loss:.4f}")
print(f"  History Loss: {avg_test_history_loss:.4f}")
print(f"  Next Loss:    {avg_test_next_loss:.4f}")
print(f"{'='*60}")

# Log to wandb
if wandb.run is not None:
    wandb.log({
        'test_loss': avg_test_loss,
        'test_history_loss': avg_test_history_loss,
        'test_next_loss': avg_test_next_loss,
    })

## Custom State Configuration Examples

Examples of how to configure state in special ways:

In [None]:
# Example 1: Initialize hidden state with specific pattern
def init_state_chess_prior(model):
    """
    Initialize state with chess-specific priors.
    """
    model.model.state.zero_()
    
    # Initialize hidden units with small random values
    h_start = model.input_size
    h_end = model.input_size + model.hidden_size
    model.model.state[h_start:h_end] = torch.randn(model.hidden_size) * 0.01
    
    # Initialize output units with chess move priors
    # (e.g., center moves slightly more likely)
    o_start = model.input_size + model.hidden_size
    output_state = torch.zeros(model.output_size)
    # TODO: Add chess-specific initialization
    model.model.state[o_start:] = output_state

# Example 2: Warm-start from previous game state
def init_state_from_previous(model, previous_state):
    """
    Initialize state from a previous equilibrium.
    Useful for sequential game positions.
    """
    model.model.state.copy_(previous_state)
    # Clear input portion
    model.model.state[:model.input_size] = 0.0

# Example 3: Initialize with noise in specific regions
def init_state_region_noise(model, region_indices, noise_std=0.1):
    """
    Initialize specific regions with noise.
    """
    model.model.state.zero_()
    model.model.state[region_indices] = torch.randn(len(region_indices)) * noise_std

print("Custom state initialization functions defined.")
print("\nTo use custom initialization, modify the configure_state() method in ChessEqPropLightning.")

## Save/Load Model

Save and load trained models:

In [None]:
# Save final model manually
final_save_path = CHECKPOINT_DIR / 'final_model.pt'
torch.save({
    'model_state_dict': model.model.state_dict(),
    'config': config,
    'history': history,
}, final_save_path)
print(f"Final model saved to: {final_save_path}")

# Load model example
def load_eqprop_model(checkpoint_path, device='cpu'):
    """Load a saved EqProp model."""
    checkpoint = torch.load(checkpoint_path, map_location=device)
    
    # Recreate model with saved config
    loaded_config = checkpoint['config']
    loaded_model = ChessEqPropLightning(loaded_config)
    loaded_model.model.load_state_dict(checkpoint['model_state_dict'])
    
    print(f"Loaded model from {checkpoint_path}")
    if 'epoch' in checkpoint:
        print(f"  Epoch: {checkpoint['epoch']+1}")
    if 'val_loss' in checkpoint:
        print(f"  Val Loss: {checkpoint['val_loss']:.4f}")
    
    return loaded_model

# Example: Load best model
best_models = list(CHECKPOINT_DIR.glob('best_model_*.pt'))
if best_models:
    loaded_model = load_eqprop_model(best_models[0])
    print(f"\nAccess EqProp module: loaded_model.model")
    print(f"Weight matrix shape: {loaded_model.model.W.shape}")

## Finish wandb run

In [None]:
wandb.finish()
print("Wandb run finished.")

## Notes

### Output Structure

The network predicts **9 board states** from white's perspective:
1. **8 Historical Board States** (6144 dimensions): Reconstructs the last 8 half-moves (4 full turns)
2. **1 Next Board State** (768 dimensions): Predicts the next board position after the current move

**Total output: 6,912 dimensions (9 × 768)**

This is simpler than encoding moves as from-to vectors - the network learns to predict entire board states, which:
- Captures all move types naturally (including castling, en passant, promotion)
- Provides richer supervision signal
- Enables the network to learn board patterns holistically

### Perspective Normalization

All boards are flipped to white's perspective:
- If it's white's turn: board stays as-is
- If it's black's turn: board is flipped vertically and colors are swapped

This ensures the network always learns from white's point of view, simplifying the learning task.

### Masking for Partial Supervision

The training uses masks to weight different parts of the output:
- **History reconstruction: 0.3 weight** (auxiliary task - helps learn temporal dynamics)
- **Next state prediction: 1.0 weight** (primary task - the actual prediction goal)

You can customize these weights in the `training_step()` method.

### State Configuration

The `configure_state()` method has three modes:

1. **'zeros'** (default): Start from zero state
2. **'random'**: Small random initialization
3. **'custom'**: Temporal prior initialization
   - More recent history → less noise (should be easier to recall)
   - Older history → more noise (harder to recall)
   - Next state → slightly higher noise (it's a prediction, not memory)

This biases the network toward better recent memory and uncertain predictions.

### Loss Functions

- **History loss**: MSE between predicted and target historical states
- **Next state loss**: MSE between predicted and target next state
- **Combined loss**: 0.3 × history_loss + 1.0 × next_loss

Logged separately in wandb as `train_history_loss` and `train_next_loss`.

### Evaluation Metrics

Currently using MSE loss for both history and prediction. Future additions:
- Per-timestep history accuracy
- Piece-wise prediction accuracy
- Legal position validation
- Move extraction from state diff

### Performance Tips

- Output size is ~7K dimensions (9 board states)
- Batch size reduced to 16 for memory efficiency
- Hidden layer increased to 2048 to handle temporal complexity
- Monitor history vs next-state loss separately to ensure both tasks are learning
- The network learns both **memory** (history) and **prediction** (next state) simultaneously