In [74]:
# Report the installed version of each library this notebook relies on.
# A missing package is reported rather than raising, so the cell always completes.
try:
    import datasets
except ImportError:
    print("datasets: Not installed.")
else:
    print(f"datasets: {datasets.__version__}")

try:
    import torch
except ImportError:
    print("torch: Not installed.")
else:
    print(f"torch: {torch.__version__}")

try:
    import numpy
except ImportError:
    print("numpy: Not installed.")
else:
    print(f"numpy: {numpy.__version__}")

try:
    import tqdm
except ImportError:
    print("tqdm: Not installed.")
else:
    print(f"tqdm: {tqdm.__version__}")

datasets: 3.5.0
torch: 2.5.1
numpy: 2.0.1
tqdm: 4.67.1


In [75]:
from datasets import load_dataset, load_from_disk, DatasetDict, Dataset
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR
import os
import numpy as np
# import torch.nn.functional as F
from torch.utils.data import DataLoader
from preprocessing import add_representations, fen_to_piece_maps, fen_to_token_ids
from tqdm import tqdm

# Select the 'medium' float32 matmul precision mode for the whole process.
# NOTE(review): presumably chosen to trade some float32 matmul accuracy for GPU throughput —
# confirm against the PyTorch docs that the reduced precision is acceptable for this regression task.
torch.set_float32_matmul_precision('medium')

In [11]:
# Pick the compute device once; the training and evaluation cells move batches and the model to it.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
    print("Using CPU, not recommended")

In [4]:
def collate_fn(batch):
    """Turn a list of dataset examples into an ``(inputs, labels)`` tensor pair.

    Args:
        batch: list of examples, each a mapping with a ``'fen'`` entry (passed to
            ``fen_to_token_ids``) and a numeric ``'target'`` entry.

    Returns:
        inputs: ``torch.long`` tensor made by stacking the token ids of every
            example along a new leading batch dimension.
        labels: ``torch.float32`` tensor holding one target per example.
    """
    token_rows = []
    targets = []
    for example in batch:
        # Tokenisation happens per example; torch.stack below requires every
        # row to have the same length.
        token_rows.append(
            torch.tensor(fen_to_token_ids(example['fen']), dtype=torch.long)
        )
        targets.append(example['target'])

    inputs = torch.stack(token_rows)
    labels = torch.tensor(targets, dtype=torch.float32)
    return inputs, labels

In [62]:
# Load the three pre-processed splits from disk (paths are relative to the current working directory).
_cwd = os.getcwd()
_train_split = load_from_disk(os.path.join(_cwd, "processed_data/lichess_db_eval_10m/train"))
_val_split = load_from_disk(os.path.join(_cwd, "processed_data/lichess_db_eval_10m/validation"))
_test_split = load_from_disk(os.path.join(_cwd, "processed_data/lichess_db_eval_10m/test"))

# The length has to be read before the conversion to an iterable dataset below.
num_training_examples = len(_train_split)

# Stream each split as an iterable dataset and shuffle it through a 10,000-example buffer.
# Only the training split is explicitly sharded (32 shards).
train_dataset = _train_split.to_iterable_dataset(num_shards=32).shuffle(buffer_size=10000)
val_dataset = _val_split.to_iterable_dataset().shuffle(buffer_size=10000)
test_dataset = _test_split.to_iterable_dataset().shuffle(buffer_size=10000)

In [63]:
# All three loaders share the same batch size and collate function.
_loader_kwargs = {"batch_size": 256, "collate_fn": collate_fn}
train_loader, val_loader, test_loader = (
    DataLoader(split, **_loader_kwargs)
    for split in (train_dataset, val_dataset, test_dataset)
)

In [83]:
class ChessEvalTransformer(nn.Module):
    """Transformer encoder that regresses one scalar in [-1, 1] per token sequence.

    Token embeddings are summed with learned positional embeddings, passed
    through a stack of ``nn.TransformerEncoderLayer``s, mean-pooled over the
    sequence, layer-normalised, and mapped to a scalar by a small MLP head
    followed by ``tanh``.

    Args:
        vocab_size: number of distinct token ids. The default of 31 is
            12 piece tokens + 16 castling tokens + 2 side-to-move tokens + '0'.
        d_model: embedding / hidden width.
        n_heads: attention heads per encoder layer.
        n_layers: number of encoder layers.
        seq_len: maximum sequence length supported by the positional embedding.
            The default of 66 is 64 + 1 + 1 (presumably 64 board squares plus
            one side-to-move and one castling token — confirm against the tokenizer).
        dropout: dropout probability used in the encoder layers and the head.
    """

    def __init__(self, vocab_size=31, d_model=128, n_heads=4, n_layers=8, seq_len=64 + 1 + 1, dropout=0.1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(seq_len, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model, n_heads, dim_feedforward=d_model*4, dropout=dropout
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, n_layers)
        self.norm = nn.LayerNorm(d_model)
        self.reg_head = nn.Sequential(
            nn.Linear(d_model, d_model//2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model//2, 1)
        )

    def forward(self, x):
        """Predict one value in [-1, 1] for each sequence in the batch.

        Args:
            x: integer token ids of shape (B, L), with L no larger than the
                ``seq_len`` the model was built with.

        Returns:
            Tensor of shape (B,) with values in [-1, 1].
        """
        B, L = x.size()

        tok_emb = self.embed(x)                     # (B,L,d)
        # Build the position indices on the input's own device rather than a
        # notebook-level global, so the module works wherever it is moved.
        pos = torch.arange(L, device=x.device)
        pos_emb = self.pos_embed(pos).unsqueeze(0)  # (1,L,d)

        h = tok_emb + pos_emb                       # (B,L,d)
        h = self.transformer(h.permute(1,0,2))      # (L,B,d)     - transformer expects (seq, batch, d)
        h = h.mean(dim=0)                           # (B,d)       - mean-pool over the sequence
        h = self.norm(h)

        out = self.reg_head(h).squeeze(-1)          # (B,)
        return torch.tanh(out)                      # in the range [-1,1]

In [85]:
NUM_EPOCHS = 3

# Optimiser steps per epoch, derived from the real training-set size and the loader's batch size
# (ceil division so the final, possibly partial, batch is counted). This replaces the previously
# hard-coded 10_000_000 // 256 + 1 so the cosine schedule stays in sync if either value changes.
steps_per_epoch = (num_training_examples + train_loader.batch_size - 1) // train_loader.batch_size
total_iters = NUM_EPOCHS * steps_per_epoch

model = ChessEvalTransformer().to(DEVICE)
# NOTE(review): lr=1e-2 is the original setting and is left unchanged; the captured training log
# plateaus around train_loss 0.22, so the learning rate may be worth revisiting.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2, weight_decay=1e-2)
# The scheduler is stepped once per batch in the training loop, hence T_max is in iterations.
scheduler = CosineAnnealingLR(optimizer, T_max=total_iters, eta_min=1e-4)
criterion = nn.MSELoss()

steps, train_losses, train_maes, val_losses, val_maes = [], [], [], [], []       # for tracking performance



In [None]:
# Training loop with periodic logging, periodic validation, best-model checkpointing and
# early stopping on validation MAE.
best_val_loss = float('inf')
best_val_mae = float('inf')
num_iterations = 0      # Global optimiser-step counter across all epochs
patience_counter = 0    # Count iterations without sufficient improvement
PATIENCE = 15000        # How many iterations to wait, noting that total_iters is around 117k for 3 epochs and batch size of 256
MIN_IMPROVEMENT = 5e-4  # Minimum improvement required for early stopping to not trigger
VAL_ITERS = 5000
LOG_ITERS = 1000

early_stop = False

for epoch in range(1, NUM_EPOCHS + 1):
    if early_stop:
        break

    # Tell the shuffled iterable dataset which epoch this is so that each epoch uses a different
    # shuffle order; without this every epoch replays the same order.
    set_epoch = getattr(train_loader.dataset, "set_epoch", None)
    if set_epoch is not None:
        set_epoch(epoch)

    model.train()
    # Per-epoch running sums; epoch_samples is the matching denominator.
    total_loss = 0.0
    total_mae = 0.0
    epoch_samples = 0

    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch}"):
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        preds = model(inputs)
        loss = criterion(preds, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()   # LR schedule advances once per batch

        batch_size = inputs.size(0)
        batch_loss = loss.item()
        batch_mae = torch.mean(torch.abs(preds.detach() - labels)).item()

        # Weight by the actual batch size so a short final batch is averaged correctly.
        total_loss += batch_loss * batch_size
        total_mae += batch_mae * batch_size
        epoch_samples += batch_size
        num_iterations += 1

        should_log = num_iterations % LOG_ITERS == 0
        should_validate = num_iterations % VAL_ITERS == 0

        if should_log or should_validate:
            # Running averages over the samples seen so far in THIS epoch. The sums are reset
            # every epoch, so the denominator must be too (not the global iteration count).
            avg_train_loss = total_loss / epoch_samples
            avg_train_mae = total_mae / epoch_samples

        # Every LOG_ITERS steps, log training loss and MAE
        if should_log:
            print(f"Step {num_iterations} — train_loss: {avg_train_loss:.4f}, train_mae: {avg_train_mae:.4f}")

        # Every VAL_ITERS steps, run validation and record performance
        if should_validate:
            model.eval()
            val_loss_sum = 0.0
            val_mae_sum = 0.0
            val_samples = 0
            with torch.no_grad():
                for val_inputs, val_labels in val_loader:
                    val_inputs, val_labels = val_inputs.to(DEVICE), val_labels.to(DEVICE)
                    val_pred = model(val_inputs)
                    val_loss = criterion(val_pred, val_labels)
                    val_mae = torch.mean(torch.abs(val_pred - val_labels))

                    val_batch_size = val_inputs.size(0)
                    val_loss_sum += val_loss.item() * val_batch_size
                    val_mae_sum += val_mae.item() * val_batch_size
                    val_samples += val_batch_size

            if val_samples == 0:
                raise RuntimeError("val_loader yielded no samples; cannot compute validation metrics")

            # Divide by the number of samples actually evaluated instead of a hard-coded split size.
            avg_val_loss = val_loss_sum / val_samples
            avg_val_mae = val_mae_sum / val_samples

            print(f"Validating! Step {num_iterations} — train_loss: {avg_train_loss:.4f}, train_mae: {avg_train_mae:.4f}, val_loss: {avg_val_loss:.4f}, val_mae: {avg_val_mae:.4f}, current_LR: {scheduler.get_last_lr()[0]:.4f}")

            steps.append(num_iterations)
            train_losses.append(avg_train_loss)
            train_maes.append(avg_train_mae)
            val_losses.append(avg_val_loss)
            val_maes.append(avg_val_mae)

            # Checkpoint best (by validation loss)
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                torch.save(model.state_dict(), "best_chess_transformer.pth")
                print(f"Saved new best model after {num_iterations} iters")

            # Check if validation MAE improved by at least MIN_IMPROVEMENT
            if avg_val_mae + MIN_IMPROVEMENT < best_val_mae:
                best_val_mae = avg_val_mae
                patience_counter = 0  # Reset patience
                print(f"--------Validation MAE improved to {best_val_mae:.6f}--------")
            else:
                patience_counter += VAL_ITERS  # because we validate after every VAL_ITERS
                print(f"--------No significant MAE improvement for {patience_counter} iterations--------")

            # Early stopping
            if patience_counter >= PATIENCE:
                print(f"Early stopping at {num_iterations} iterations (no significant MAE improvement)")
                early_stop = True
                break

            model.train()  # back to training mode after validation

Epoch 1: 101it [00:12,  7.24it/s]

Step 100 — train_loss: 0.2395, train_mae: 0.3369, current_LR: 0.0100


Epoch 1: 201it [00:27,  6.55it/s]

Step 200 — train_loss: 0.2295, train_mae: 0.3282, current_LR: 0.0100


Epoch 1: 301it [00:43,  6.48it/s]

Step 300 — train_loss: 0.2245, train_mae: 0.3238, current_LR: 0.0100


Epoch 1: 401it [00:58,  6.44it/s]

Step 400 — train_loss: 0.2229, train_mae: 0.3222, current_LR: 0.0100


Epoch 1: 499it [01:13,  6.46it/s]

Step 500 — train_loss: 0.2209, train_mae: 0.3206, current_LR: 0.0100


Epoch 1: 501it [02:01, 10.04s/it]

Validating! Step 500 — train_loss: 0.2209, train_mae: 0.3206, val_loss: 0.2162, val_mae: 0.3132
Saved new best model after 500 iters
--------Validation MAE improved to 0.313151--------


Epoch 1: 601it [02:16,  6.10it/s]

Step 600 — train_loss: 0.2210, train_mae: 0.3211, current_LR: 0.0100


Epoch 1: 701it [02:33,  6.50it/s]

Step 700 — train_loss: 0.2208, train_mae: 0.3213, current_LR: 0.0100


Epoch 1: 801it [02:48,  6.37it/s]

Step 800 — train_loss: 0.2204, train_mae: 0.3210, current_LR: 0.0100


Epoch 1: 901it [03:05,  6.04it/s]

Step 900 — train_loss: 0.2202, train_mae: 0.3208, current_LR: 0.0100


Epoch 1: 915it [03:07,  4.88it/s]


KeyboardInterrupt: 

In [None]:
# Evaluate on the held-out test split and report sample-weighted MSE and MAE.
# NOTE(review): this scores the weights currently held by `model` (the state at the end of training
# or at interruption), not the checkpoint written to "best_chess_transformer.pth" — confirm intended.
model.eval()
test_loss_sum = 0.0
test_mae_sum = 0.0
test_samples = 0
with torch.no_grad():
    for test_inputs, test_labels in test_loader:
        test_inputs, test_labels = test_inputs.to(DEVICE), test_labels.to(DEVICE)
        test_pred = model(test_inputs)
        test_loss = criterion(test_pred, test_labels)
        test_mae = torch.mean(torch.abs(test_pred - test_labels))

        # Weight each batch mean by its size so a short final batch is averaged correctly.
        test_batch_size = test_inputs.size(0)
        test_loss_sum += test_loss.item() * test_batch_size
        test_mae_sum += test_mae.item() * test_batch_size
        test_samples += test_batch_size

if test_samples == 0:
    raise RuntimeError("test_loader yielded no samples; cannot compute test metrics")

# Divide by the number of samples actually evaluated instead of a hard-coded split size.
avg_test_loss = test_loss_sum / test_samples
avg_test_mae = test_mae_sum / test_samples

print(f"Test results: test_loss: {avg_test_loss:.4f}, test_mae: {avg_test_mae:.4f}")
