# Chess Dataset Generation

This notebook generates chess board state datasets for training the equilibrium propagation model.

We create two types of datasets:
1. **Board State Datasets**: Random chess positions for static position evaluation
2. **Sequence Datasets**: Time-series of game states for learning game dynamics

In [1]:
import sys
sys.path.append('..')

import torch
from src.chess import (
    encode_onehot_gamestate,
    generate_random_board_states,
    generate_game_sequence,
    generate_multiple_game_sequences,
    ChessBoardDataset,
    ChessSequenceDataset,
    create_random_board_dataset,
    create_game_sequence_dataset,
)
import chess
import os
from pathlib import Path

## Configuration

Set the parameters for dataset generation:

In [2]:
# Output directory
SCRATCH_DIR = Path('scratch')
SCRATCH_DIR.mkdir(exist_ok=True)

# Board state dataset parameters
NUM_BOARD_STATES = 10000  # Total number of random board positions
MAX_MOVES_FROM_START = 30  # Maximum random moves from starting position

# Sequence dataset parameters
NUM_GAME_SEQUENCES = 1000  # Number of game sequences
HALF_MOVES_PER_GAME = 60   # Number of half-moves (plies) per game

# Train/val/test split ratios
TRAIN_RATIO = 0.7
VAL_RATIO = 0.15
TEST_RATIO = 0.15

print(f"Output directory: {SCRATCH_DIR.absolute()}")
print(f"Board states: {NUM_BOARD_STATES}")
print(f"Game sequences: {NUM_GAME_SEQUENCES}")
print(f"Split ratios - Train: {TRAIN_RATIO}, Val: {VAL_RATIO}, Test: {TEST_RATIO}")

Output directory: /Users/c/eq-chess/notebooks/scratch
Board states: 10000
Game sequences: 1000
Split ratios - Train: 0.7, Val: 0.15, Test: 0.15
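
How `create_random_board_dataset` and `create_game_sequence_dataset` apply these ratios internally is not shown in this notebook. As a rough illustration only, a ratio-based split can be sketched as below. `split_indices` and its `seed` parameter are hypothetical names, not part of `src.chess`.

```python
import random


def split_indices(n, train_ratio, val_ratio, seed=0):
    """Shuffle range(n) and cut it into train/val/test index lists."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    # round() guards against float products landing just below a whole number.
    n_train = round(n * train_ratio)
    n_val = round(n * val_ratio)
    # Whatever remains goes to the test split, so rounding never drops a sample.
    return (
        indices[:n_train],
        indices[n_train:n_train + n_val],
        indices[n_train + n_val:],
    )


train_idx, val_idx, test_idx = split_indices(10000, 0.7, 0.15)
print(len(train_idx), len(val_idx), len(test_idx))
```

Assigning the remainder to the test split, rather than computing it from `TEST_RATIO`, is one way to keep the three sizes summing to the total.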


## Test Encoding Function

Verify that the board encoding works correctly:

In [3]:
# Test with starting position
test_board = chess.Board()
encoding = encode_onehot_gamestate(test_board)

print(f"Encoding shape: {encoding.shape}")
print(f"Expected shape: (8, 8, 12)")
print(f"Encoding dtype: {encoding.dtype}")
print(f"Total pieces encoded: {encoding.sum().item()}")
print(f"Expected total (32 pieces): 32")

# Visualize a slice - white pawns (channel 0)
print("\nWhite pawns (channel 0, row index 1 = rank 2):")
print(encoding[1, :, 0])  # Row index 1 is rank 2, where White's pawns start

print("\nBlack pawns (channel 6, row index 6 = rank 7):")
print(encoding[6, :, 6])  # Row index 6 is rank 7, where Black's pawns start

Encoding shape: torch.Size([8, 8, 12])
Expected shape: (8, 8, 12)
Encoding dtype: torch.float32
Total pieces encoded: 32.0
Expected total (32 pieces): 32

White pawns (channel 0, row index 1 = rank 2):
tensor([1., 1., 1., 1., 1., 1., 1., 1.])

Black pawns (channel 6, row index 6 = rank 7):
tensor([1., 1., 1., 1., 1., 1., 1., 1.])
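
The layout checked above can be illustrated without any dependencies. The sketch below is not the `encode_onehot_gamestate` implementation from `src.chess`. It uses a hypothetical helper, `onehot_from_fen`, and assumes the channel order P, N, B, R, Q, K for White (channels 0 to 5) followed by the same order for Black (channels 6 to 11). Only the two pawn channels and the row convention (index 0 is rank 1) are confirmed by the output above.

```python
# Hypothetical stand-in for encode_onehot_gamestate; the channel order is an assumption.
PIECE_CHANNELS = {piece: i for i, piece in enumerate("PNBRQKpnbrqk")}

START_FEN = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"


def onehot_from_fen(fen):
    """Return an 8x8x12 nested list indexed as [rank_index][file_index][channel]."""
    planes = [[[0.0] * 12 for _ in range(8)] for _ in range(8)]
    placement = fen.split()[0]
    # FEN lists rank 8 first, so reverse to make row index 0 correspond to rank 1.
    for rank_index, row in enumerate(reversed(placement.split("/"))):
        file_index = 0
        for symbol in row:
            if symbol.isdigit():
                file_index += int(symbol)  # a digit means that many empty squares
            else:
                planes[rank_index][file_index][PIECE_CHANNELS[symbol]] = 1.0
                file_index += 1
    return planes


encoding_sketch = onehot_from_fen(START_FEN)
total = sum(v for rank in encoding_sketch for square in rank for v in square)
print(total)  # one active channel per occupied square
```

Because each occupied square sets exactly one channel, summing the whole encoding counts the pieces on the board, which is why the check above expects 32 for the starting position.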


## Generate Board State Dataset

Create random chess positions for static position evaluation:

In [4]:
print("Generating random board states...")

train_board_dataset, val_board_dataset, test_board_dataset = create_random_board_dataset(
    num_boards=NUM_BOARD_STATES,
    max_moves=MAX_MOVES_FROM_START,
    train_ratio=TRAIN_RATIO,
    val_ratio=VAL_RATIO,
    test_ratio=TEST_RATIO
)

print(f"\nBoard State Dataset Sizes:")
print(f"  Training:   {len(train_board_dataset)}")
print(f"  Validation: {len(val_board_dataset)}")
print(f"  Test:       {len(test_board_dataset)}")
print(f"  Total:      {len(train_board_dataset) + len(val_board_dataset) + len(test_board_dataset)}")

# Test dataset access
sample = train_board_dataset[0]
print(f"\nSample tensor shape: {sample.shape}")
print(f"Sample tensor dtype: {sample.dtype}")

Generating random board states...

Board State Dataset Sizes:
  Training:   7000
  Validation: 1500
  Test:       1500
  Total:      10000

Sample tensor shape: torch.Size([8, 8, 12])
Sample tensor dtype: torch.float32
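
For intuition, one plausible way to produce such positions with `python-chess` is to play a random number of random legal moves from the starting position. This is a sketch under assumptions: `random_position` is a hypothetical helper, and whether `generate_random_board_states` samples the move count uniformly, as done here, is not confirmed by this notebook.

```python
import random

import chess


def random_position(max_moves, rng):
    """Play between 0 and max_moves random legal moves from the starting position."""
    board = chess.Board()
    for _ in range(rng.randint(0, max_moves)):
        if board.is_game_over():
            break  # the game has ended; keep the position reached so far
        board.push(rng.choice(list(board.legal_moves)))
    return board


rng = random.Random(0)
boards_sketch = [random_position(30, rng) for _ in range(5)]
for b in boards_sketch:
    print(b.fen())
```

Positions generated this way stay close to the opening, which is consistent with the high piece counts reported in the statistics section.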


## Generate Sequence Dataset

Create game sequences for learning temporal dynamics:

In [5]:
print("Generating game sequences...")

train_seq_dataset, val_seq_dataset, test_seq_dataset = create_game_sequence_dataset(
    num_games=NUM_GAME_SEQUENCES,
    num_half_moves_per_game=HALF_MOVES_PER_GAME,
    train_ratio=TRAIN_RATIO,
    val_ratio=VAL_RATIO,
    test_ratio=TEST_RATIO
)

print(f"\nSequence Dataset Sizes:")
print(f"  Training:   {len(train_seq_dataset)} sequences")
print(f"  Validation: {len(val_seq_dataset)} sequences")
print(f"  Test:       {len(test_seq_dataset)} sequences")
print(f"  Total:      {len(train_seq_dataset) + len(val_seq_dataset) + len(test_seq_dataset)} sequences")

# Test sequence dataset access
sample_seq = train_seq_dataset[0]
print(f"\nSample sequence shape: {sample_seq.shape}")
print(f"Sample sequence dtype: {sample_seq.dtype}")
print(f"Sequence length: {sample_seq.shape[0]}")

Generating game sequences...

Sequence Dataset Sizes:
  Training:   700 sequences
  Validation: 150 sequences
  Test:       150 sequences
  Total:      1000 sequences

Sample sequence shape: torch.Size([61, 8, 8, 12])
Sample sequence dtype: torch.float32
Sequence length: 61
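
The sequence length is 61 rather than 60 because the starting position is stored along with the position after each of the 60 half-moves. The sketch below shows that bookkeeping with `python-chess`. `random_game_states` is a hypothetical helper, and restarting a game that ends early is this sketch's own assumption for keeping lengths fixed, not necessarily what `generate_game_sequence` does.

```python
import random

import chess


def random_game_states(num_half_moves, rng):
    """Return num_half_moves + 1 boards: the start position plus one per ply."""
    while True:
        board = chess.Board()
        states = [board.copy()]
        for _ in range(num_half_moves):
            if board.is_game_over():
                break
            board.push(rng.choice(list(board.legal_moves)))
            states.append(board.copy())
        if len(states) == num_half_moves + 1:
            return states
        # Otherwise the game ended early, so start over with a fresh game.


states_sketch = random_game_states(60, random.Random(0))
print(len(states_sketch))
```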


## Save Datasets

Save the datasets to disk for reuse in other notebooks:

In [6]:
print("Saving board state datasets...")

# Save board state datasets
torch.save(train_board_dataset, SCRATCH_DIR / 'train_board_dataset.pt')
torch.save(val_board_dataset, SCRATCH_DIR / 'val_board_dataset.pt')
torch.save(test_board_dataset, SCRATCH_DIR / 'test_board_dataset.pt')

print("Saving sequence datasets...")

# Save sequence datasets
torch.save(train_seq_dataset, SCRATCH_DIR / 'train_seq_dataset.pt')
torch.save(val_seq_dataset, SCRATCH_DIR / 'val_seq_dataset.pt')
torch.save(test_seq_dataset, SCRATCH_DIR / 'test_seq_dataset.pt')

print("\nAll datasets saved successfully!")
print("\nSaved files:")
for file in sorted(SCRATCH_DIR.glob('*.pt')):
    size_mb = file.stat().st_size / (1024 * 1024)
    print(f"  {file.name:<30} {size_mb:>8.2f} MB")

Saving board state datasets...
Saving sequence datasets...

All datasets saved successfully!

Saved files:
  test_board_dataset.pt              4.20 MB
  test_seq_dataset.pt               12.39 MB
  train_board_dataset.pt            19.64 MB
  train_seq_dataset.pt              57.63 MB
  val_board_dataset.pt               4.21 MB
  val_seq_dataset.pt                12.45 MB
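
The cell above passes whole dataset objects to `torch.save`, which pickles them, so the classes that define them must be importable wherever the files are loaded. If that coupling is unwanted, one alternative is to save only a stacked tensor of encodings. The sketch below uses stand-in data and a hypothetical file name. It is not how this notebook stores its datasets.

```python
import tempfile
from pathlib import Path

import torch

# Stand-in data: 4 empty boards in the (8, 8, 12) layout used above.
boards_tensor = torch.zeros(4, 8, 8, 12)

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "boards_tensor.pt"
    torch.save(boards_tensor, path)
    # A plain tensor can be loaded with weights_only=True, which restricts unpickling.
    restored = torch.load(path, weights_only=True)

print(restored.shape)
```

The trade-off is that a plain tensor drops anything else the dataset object carries, such as the `chess.Board` instances used in the visualization section.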


## Dataset Statistics

Analyze the generated datasets:

In [7]:
import numpy as np

print("Computing dataset statistics...\n")

# Board state statistics
print("Board State Dataset Statistics:")
print("="*50)

# Use the first 100 boards (or fewer, if the dataset is smaller) to estimate piece counts
sample_size = min(100, len(train_board_dataset))
piece_counts = []

for i in range(sample_size):
    encoding = train_board_dataset[i]
    piece_count = encoding.sum().item()
    piece_counts.append(piece_count)

print(f"Average pieces per board: {np.mean(piece_counts):.2f}")
print(f"Min pieces: {np.min(piece_counts):.0f}")
print(f"Max pieces: {np.max(piece_counts):.0f}")
print(f"Std dev: {np.std(piece_counts):.2f}")

# Sequence statistics
print("\nSequence Dataset Statistics:")
print("="*50)

sample_size = min(100, len(train_seq_dataset))
seq_lengths = []

for i in range(sample_size):
    seq = train_seq_dataset[i]
    seq_lengths.append(seq.shape[0])

print(f"Average sequence length: {np.mean(seq_lengths):.2f}")
print(f"Min sequence length: {np.min(seq_lengths):.0f}")
print(f"Max sequence length: {np.max(seq_lengths):.0f}")
print(f"Std dev: {np.std(seq_lengths):.2f}")

Computing dataset statistics...

Board State Dataset Statistics:
Average pieces per board: 31.37
Min pieces: 27
Max pieces: 32
Std dev: 0.99

Sequence Dataset Statistics:
Average sequence length: 61.00
Min sequence length: 61
Max sequence length: 61
Std dev: 0.00


## Example: Loading Saved Datasets

Demonstrate how to load the datasets in other notebooks:

In [8]:
print("Example: Loading saved datasets\n")

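# These files hold pickled dataset objects rather than plain tensors, so PyTorch
# releases where weights_only defaults to True need weights_only=False.
# Only use weights_only=False on files you created yourself.

# Load board state datasets
loaded_train_board = torch.load(SCRATCH_DIR / 'train_board_dataset.pt', weights_only=False)
loaded_val_board = torch.load(SCRATCH_DIR / 'val_board_dataset.pt', weights_only=False)
loaded_test_board = torch.load(SCRATCH_DIR / 'test_board_dataset.pt', weights_only=False)

# Load sequence datasets
loaded_train_seq = torch.load(SCRATCH_DIR / 'train_seq_dataset.pt', weights_only=False)
loaded_val_seq = torch.load(SCRATCH_DIR / 'val_seq_dataset.pt', weights_only=False)
loaded_test_seq = torch.load(SCRATCH_DIR / 'test_seq_dataset.pt', weights_only=False)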

print("Loaded board state datasets:")
print(f"  Train: {len(loaded_train_board)}")
print(f"  Val:   {len(loaded_val_board)}")
print(f"  Test:  {len(loaded_test_board)}")

print("\nLoaded sequence datasets:")
print(f"  Train: {len(loaded_train_seq)}")
print(f"  Val:   {len(loaded_val_seq)}")
print(f"  Test:  {len(loaded_test_seq)}")

print("\nDatasets loaded successfully!")

Example: Loading saved datasets





Loaded board state datasets:
  Train: 7000
  Val:   1500
  Test:  1500

Loaded sequence datasets:
  Train: 700
  Val:   150
  Test:  150

Datasets loaded successfully!


## Example: Using with DataLoader

Show how to use these datasets with PyTorch DataLoaders for training:

In [9]:
from torch.utils.data import DataLoader

# Create DataLoaders
batch_size = 32

train_loader = DataLoader(
    loaded_train_board,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0  # 0 loads data in the main process; raise it to load batches in parallel workers
)

val_loader = DataLoader(
    loaded_val_board,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0
)

print(f"DataLoader created:")
print(f"  Batch size: {batch_size}")
print(f"  Training batches: {len(train_loader)}")
print(f"  Validation batches: {len(val_loader)}")

# Test batch retrieval
batch = next(iter(train_loader))
print(f"\nSample batch shape: {batch.shape}")
print(f"Expected: ({batch_size}, 8, 8, 12)")

DataLoader created:
  Batch size: 32
  Training batches: 219
  Validation batches: 47

Sample batch shape: torch.Size([32, 8, 8, 12])
Expected: (32, 8, 8, 12)
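
The batch counts above follow from the dataset sizes and the batch size. With the default `drop_last=False`, a `DataLoader` over a map-style dataset yields a final partial batch, so the count is a ceiling division. `num_batches` below is a hypothetical helper written only to show the arithmetic.

```python
import math


def num_batches(num_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields for a map-style dataset."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


print(num_batches(7000, 32))  # 219, matching the training loader above
print(num_batches(1500, 32))  # 47, matching the validation loader above
print(7000 - 218 * 32)        # 24 samples in the final training batch
```

Since the last training batch holds 24 samples rather than 32, code that consumes these loaders should not assume a fixed leading dimension unless `drop_last=True` is set.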


## Visualize a Chess Position

Optionally visualize one of the generated positions:

In [10]:
# Take the first board from the training dataset
sample_board = train_board_dataset.boards[0]

print("Sample chess position:")
print(sample_board)
print(f"\nFEN: {sample_board.fen()}")
print(f"Legal moves: {len(list(sample_board.legal_moves))}")
print(f"Turn: {'White' if sample_board.turn else 'Black'}")
print(f"Game over: {sample_board.is_game_over()}")

Sample chess position:
r . b q k b n r
p p p p p . p .
. . n . . p . .
. . . . . . . p
Q P . . P . . P
. . P . . . . .
P . . P . P P .
R N B . K B N R

FEN: r1bqkbnr/ppppp1p1/2n2p2/7p/QP2P2P/2P5/P2P1PP1/RNB1KBNR w KQ - 1 6
Legal moves: 34
Turn: White
Game over: False
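
A FEN string is enough to reconstruct a position later without reloading the dataset. As a small sketch, the FEN printed above can be passed back to `python-chess`. The variable names here are illustrative only.

```python
import chess

fen = "r1bqkbnr/ppppp1p1/2n2p2/7p/QP2P2P/2P5/P2P1PP1/RNB1KBNR w KQ - 1 6"
board_from_fen = chess.Board(fen)

print(board_from_fen)
print("Pieces on board:", len(board_from_fen.piece_map()))
print("Legal moves:", board_from_fen.legal_moves.count())
print("Turn:", "White" if board_from_fen.turn == chess.WHITE else "Black")
```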


## Summary

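The datasets have been generated and saved to `notebooks/scratch/`. Because the files contain pickled dataset objects, the module that defines their classes (here `src.chess`) must be importable before calling `torch.load`, for example via `sys.path.append('..')`. From another notebook in the same directory, load them with: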

```python
import torch
from pathlib import Path

SCRATCH_DIR = Path('scratch')

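# Load datasets (weights_only=False is needed for pickled dataset objects
# on PyTorch releases where weights_only defaults to True; use it only on trusted files)
train_dataset = torch.load(SCRATCH_DIR / 'train_board_dataset.pt', weights_only=False)
val_dataset = torch.load(SCRATCH_DIR / 'val_board_dataset.pt', weights_only=False)
test_dataset = torch.load(SCRATCH_DIR / 'test_board_dataset.pt', weights_only=False)

# Or for sequences
train_seq = torch.load(SCRATCH_DIR / 'train_seq_dataset.pt', weights_only=False)
val_seq = torch.load(SCRATCH_DIR / 'val_seq_dataset.pt', weights_only=False)
test_seq = torch.load(SCRATCH_DIR / 'test_seq_dataset.pt', weights_only=False)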
```