In [59]:
from datasets import load_dataset, load_from_disk, DatasetDict, Dataset
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR
import os
import numpy as np
# import torch.nn.functional as F
from torch.utils.data import DataLoader
from preprocessing import add_representations, fen_to_piece_maps, fen_to_token_ids
from tqdm import tqdm

torch.set_float32_matmul_precision('medium')

In [11]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if DEVICE == torch.device("cpu"):
    print("Using CPU, not recommended")

In [4]:
def collate_fn(batch):
    batch_fens = [example['fen'] for example in batch]
    labels = torch.tensor(
        [example['target'] for example in batch],
        dtype=torch.float32
    )
    inputs = torch.stack([
        torch.tensor(fen_to_token_ids(fen), dtype=torch.long)
        for fen in batch_fens
    ])
    return inputs, labels

In [62]:
train_dataset = load_from_disk(os.path.join(os.getcwd(), "processed_data/lichess_db_eval_10m/train"))
val_dataset = load_from_disk(os.path.join(os.getcwd(), "processed_data/lichess_db_eval_10m/validation"))
test_dataset = load_from_disk(os.path.join(os.getcwd(), "processed_data/lichess_db_eval_10m/test"))

num_training_examples = len(train_dataset)

train_dataset = train_dataset.to_iterable_dataset(num_shards=32)
val_dataset = val_dataset.to_iterable_dataset()
test_dataset = test_dataset.to_iterable_dataset()

train_dataset = train_dataset.shuffle(buffer_size=10000)
val_dataset = val_dataset.shuffle(buffer_size=10000)
test_dataset = test_dataset.shuffle(buffer_size=10000)

In [63]:
train_loader = DataLoader(train_dataset, batch_size=256, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=256, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=256, collate_fn=collate_fn)

In [None]:
class ChessEvalTransformer(nn.Module):
    def __init__(self, vocab_size=31, d_model=256, n_heads=8, n_layers=6):      # vocab size is 31 because there are 12 piece tokens, 16 castling tokens, 2 side-to-move tokens + '0'
        super().__init__()
        seq_len = 64 + 1 + 1
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(seq_len, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model, n_heads, dim_feedforward=d_model*4, dropout=0.2
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, n_layers)
        self.norm = nn.LayerNorm(d_model)
        self.reg_head = nn.Sequential(
            nn.Linear(d_model, d_model//2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(d_model//2, 1)
        )

    def forward(self, x):
        # x: (B, 66)
        B, L = x.size()

        tok_emb = self.embed(x)                     # (B,66,d)
        pos = torch.arange(L, device=DEVICE)
        pos_emb = self.pos_embed(pos).unsqueeze(0)  # (1,66,d)

        h = tok_emb + pos_emb                       # (B,66,d)
        h = self.transformer(h.permute(1,0,2))      # (66,B,d)     - transformer expects (seq, batch, d)
        h = h.mean(dim=0)                           # (B,d)
        h = self.norm(h)

        out = self.reg_head(h).squeeze(-1)          # (B,)
        return torch.tanh(out)                      # in the range [-1,1]

In [66]:
NUM_EPOCHS = 3
total_iters = NUM_EPOCHS * ((10_000_000 // 256) + 1)

model = ChessEvalTransformer().to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2, weight_decay=1e-2)
scheduler = CosineAnnealingLR(optimizer, T_max=total_iters, eta_min=5e-4)
criterion = nn.MSELoss()

steps, train_losses, train_maes, val_losses, val_maes = [], [], [], [], []       # for tracking performance



In [None]:
best_val_loss = float('inf')
best_val_mae = float('inf')
num_iterations = 0
patience_counter = 0    # Count iterations without sufficient improvement
PATIENCE = 15000        # How many iterations to wait, noting that total_iters is around 117k for 3 epochs and batch size of 256
MIN_IMPROVEMENT = 5e-4  # Minimum improvement required for early stopping to not trigger
VAL_ITERS = 500
LOG_ITERS = 100

early_stop = False

for epoch in range(1, NUM_EPOCHS + 1):
    if early_stop:
        break

    model.train()
    total_loss = 0.0
    total_mae = 0.0

    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch}"):
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        preds = model(inputs)
        loss = criterion(preds, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()   

        batch_loss = loss.item()
        batch_mae = torch.mean(torch.abs(preds - labels)).item()

        total_loss += batch_loss * inputs.size(0)
        total_mae += batch_mae * inputs.size(0)
        num_iterations += 1

        # Every 1000 steps, log training loss and MAE
        if num_iterations % LOG_ITERS == 0:
            avg_train_loss = total_loss / (num_iterations * train_loader.batch_size)
            avg_train_mae = total_mae / (num_iterations * train_loader.batch_size)
            print(f"Step {num_iterations} — train_loss: {avg_train_loss:.4f}, train_mae: {avg_train_mae:.4f}, current_LR: {scheduler.get_last_lr()[0]:.4f}")

        # Every 5000 steps, run validation and record performance
        if num_iterations % VAL_ITERS == 0:
            model.eval()
            val_loss_sum = 0.0
            val_mae_sum = 0.0
            with torch.no_grad():
                for val_inputs, val_labels in val_loader:
                    val_inputs, val_labels = val_inputs.to(DEVICE), val_labels.to(DEVICE)
                    val_pred = model(val_inputs)
                    val_loss = criterion(val_pred, val_labels)
                    val_mae = torch.mean(torch.abs(val_pred - val_labels))

                    val_loss_sum += val_loss.item() * val_inputs.size(0)
                    val_mae_sum += val_mae.item() * val_inputs.size(0)

            avg_val_loss = val_loss_sum / 250_000
            avg_val_mae = val_mae_sum / 250_000

            print(f"Validating! Step {num_iterations} — train_loss: {avg_train_loss:.4f}, train_mae: {avg_train_mae:.4f}, val_loss: {avg_val_loss:.4f}, val_mae: {avg_val_mae:.4f}")

            steps.append(num_iterations)
            train_losses.append(avg_train_loss)
            train_maes.append(avg_train_mae)
            val_losses.append(avg_val_loss)
            val_maes.append(avg_val_mae)

            # Checkpoint best
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                torch.save(model.state_dict(), "best_chess_transformer.pth")
                print(f"Saved new best model after {num_iterations} iters")
            
            # Check if validation MAE improved enough
            if avg_val_mae + MIN_IMPROVEMENT < best_val_mae:
                best_val_mae = avg_val_mae
                patience_counter = 0  # Reset patience
                print(f"--------Validation MAE improved to {best_val_mae:.6f}--------")
            else:
                patience_counter += VAL_ITERS  # because we validate after every VAL_ITERS 
                print(f"--------No significant MAE improvement for {patience_counter} iterations--------")

            # Early stopping
            if patience_counter >= PATIENCE:
                print(f"Early stopping at {num_iterations} iterations (no significant MAE improvement)")
                early_stop = True
                break

            model.train() 

Epoch 1: 101it [00:10,  9.68it/s]

Step 100 — train_loss: 0.2339, train_mae: 0.3302, current_LR: 0.0100


Epoch 1: 201it [00:21,  7.87it/s]

Step 200 — train_loss: 0.2264, train_mae: 0.3249, current_LR: 0.0100


Epoch 1: 301it [00:34,  8.22it/s]

Step 300 — train_loss: 0.2223, train_mae: 0.3218, current_LR: 0.0100


Epoch 1: 401it [00:46,  8.21it/s]

Step 400 — train_loss: 0.2213, train_mae: 0.3209, current_LR: 0.0100


Epoch 1: 499it [00:59,  7.68it/s]

Step 500 — train_loss: 0.2196, train_mae: 0.3195, current_LR: 0.0100


Epoch 1: 501it [01:39,  8.53s/it]

Validating! Step 500 — train_loss: 0.2196, train_mae: 0.3195, val_loss: 0.2162, val_mae: 0.3128
Saved new best model after 500 iters
--------Validation MAE improved to 0.312769--------


Epoch 1: 601it [01:53,  5.88it/s]

Step 600 — train_loss: 0.2200, train_mae: 0.3203, current_LR: 0.0100


Epoch 1: 701it [02:08,  7.01it/s]

Step 700 — train_loss: 0.2199, train_mae: 0.3206, current_LR: 0.0100


Epoch 1: 801it [02:22,  7.90it/s]

Step 800 — train_loss: 0.2196, train_mae: 0.3204, current_LR: 0.0100


Epoch 1: 901it [02:36,  6.66it/s]

Step 900 — train_loss: 0.2194, train_mae: 0.3203, current_LR: 0.0100


Epoch 1: 999it [02:51,  6.57it/s]

Step 1000 — train_loss: 0.2196, train_mae: 0.3204, current_LR: 0.0100


Epoch 1: 1001it [03:37,  9.88s/it]

Validating! Step 1000 — train_loss: 0.2196, train_mae: 0.3204, val_loss: 0.2160, val_mae: 0.3178
Saved new best model after 1000 iters
--------No significant MAE improvement for 500 iterations--------


Epoch 1: 1101it [03:52,  6.87it/s]

Step 1100 — train_loss: 0.2195, train_mae: 0.3202, current_LR: 0.0100


Epoch 1: 1201it [04:06,  7.09it/s]

Step 1200 — train_loss: 0.2194, train_mae: 0.3200, current_LR: 0.0100


Epoch 1: 1301it [04:22,  7.26it/s]

Step 1300 — train_loss: 0.2194, train_mae: 0.3201, current_LR: 0.0100


Epoch 1: 1401it [04:37,  7.05it/s]

Step 1400 — train_loss: 0.2193, train_mae: 0.3202, current_LR: 0.0100


Epoch 1: 1499it [04:51,  6.84it/s]

Step 1500 — train_loss: 0.2194, train_mae: 0.3203, current_LR: 0.0100


Epoch 1: 1501it [05:37,  9.86s/it]

Validating! Step 1500 — train_loss: 0.2194, train_mae: 0.3203, val_loss: 0.2162, val_mae: 0.3224
--------No significant MAE improvement for 1000 iterations--------


Epoch 1: 1601it [05:53,  5.81it/s]

Step 1600 — train_loss: 0.2195, train_mae: 0.3204, current_LR: 0.0100


Epoch 1: 1701it [06:09,  6.20it/s]

Step 1700 — train_loss: 0.2193, train_mae: 0.3201, current_LR: 0.0100


Epoch 1: 1801it [06:24,  6.73it/s]

Step 1800 — train_loss: 0.2193, train_mae: 0.3200, current_LR: 0.0100


Epoch 1: 1901it [06:38,  7.03it/s]

Step 1900 — train_loss: 0.2190, train_mae: 0.3198, current_LR: 0.0100


Epoch 1: 1999it [06:52,  7.15it/s]

Step 2000 — train_loss: 0.2189, train_mae: 0.3197, current_LR: 0.0100


Epoch 1: 2001it [07:38,  9.67s/it]

Validating! Step 2000 — train_loss: 0.2189, train_mae: 0.3197, val_loss: 0.2163, val_mae: 0.3119
--------Validation MAE improved to 0.311883--------


Epoch 1: 2101it [07:50,  7.88it/s]

Step 2100 — train_loss: 0.2187, train_mae: 0.3195, current_LR: 0.0100


Epoch 1: 2201it [08:03,  7.83it/s]

Step 2200 — train_loss: 0.2187, train_mae: 0.3195, current_LR: 0.0100


Epoch 1: 2301it [08:18,  6.81it/s]

Step 2300 — train_loss: 0.2188, train_mae: 0.3196, current_LR: 0.0100


Epoch 1: 2401it [08:32,  6.61it/s]

Step 2400 — train_loss: 0.2189, train_mae: 0.3197, current_LR: 0.0100


Epoch 1: 2499it [08:47,  6.55it/s]

Step 2500 — train_loss: 0.2189, train_mae: 0.3197, current_LR: 0.0100


Epoch 1: 2501it [09:36, 10.36s/it]

Validating! Step 2500 — train_loss: 0.2189, train_mae: 0.3197, val_loss: 0.2162, val_mae: 0.3215
--------No significant MAE improvement for 500 iterations--------


Epoch 1: 2601it [09:51,  6.62it/s]

Step 2600 — train_loss: 0.2188, train_mae: 0.3195, current_LR: 0.0100


Epoch 1: 2701it [10:07,  6.31it/s]

Step 2700 — train_loss: 0.2185, train_mae: 0.3193, current_LR: 0.0100


Epoch 1: 2801it [10:22,  6.53it/s]

Step 2800 — train_loss: 0.2185, train_mae: 0.3194, current_LR: 0.0100


Epoch 1: 2901it [10:38,  6.13it/s]

Step 2900 — train_loss: 0.2184, train_mae: 0.3193, current_LR: 0.0100


Epoch 1: 2999it [10:54,  6.07it/s]

Step 3000 — train_loss: 0.2184, train_mae: 0.3192, current_LR: 0.0100


Epoch 1: 3001it [11:44, 10.61s/it]

Validating! Step 3000 — train_loss: 0.2184, train_mae: 0.3192, val_loss: 0.2163, val_mae: 0.3118
--------No significant MAE improvement for 1000 iterations--------


Epoch 1: 3101it [12:00,  6.10it/s]

Step 3100 — train_loss: 0.2183, train_mae: 0.3191, current_LR: 0.0100


Epoch 1: 3201it [12:17,  5.81it/s]

Step 3200 — train_loss: 0.2183, train_mae: 0.3190, current_LR: 0.0100


Epoch 1: 3301it [12:35,  6.03it/s]

Step 3300 — train_loss: 0.2183, train_mae: 0.3191, current_LR: 0.0100


Epoch 1: 3401it [12:51,  6.10it/s]

Step 3400 — train_loss: 0.2184, train_mae: 0.3191, current_LR: 0.0100


Epoch 1: 3499it [13:07,  6.04it/s]

Step 3500 — train_loss: 0.2183, train_mae: 0.3192, current_LR: 0.0100


Epoch 1: 3501it [13:57, 10.51s/it]

Validating! Step 3500 — train_loss: 0.2183, train_mae: 0.3192, val_loss: 0.2160, val_mae: 0.3187
--------No significant MAE improvement for 1500 iterations--------


Epoch 1: 3601it [14:14,  6.08it/s]

Step 3600 — train_loss: 0.2184, train_mae: 0.3193, current_LR: 0.0100


Epoch 1: 3701it [14:30,  5.86it/s]

Step 3700 — train_loss: 0.2184, train_mae: 0.3193, current_LR: 0.0100


Epoch 1: 3801it [14:47,  5.92it/s]

Step 3800 — train_loss: 0.2183, train_mae: 0.3192, current_LR: 0.0100


Epoch 1: 3901it [15:04,  5.94it/s]

Step 3900 — train_loss: 0.2183, train_mae: 0.3193, current_LR: 0.0100


Epoch 1: 3979it [15:17,  6.23it/s]