In [1]:
from datasets import load_dataset, load_from_disk, DatasetDict, Dataset
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR
import os
import numpy as np
# import torch.nn.functional as F
from torch.utils.data import DataLoader
from preprocessing import fen_to_piece_maps
from tqdm import tqdm

torch.set_float32_matmul_precision('medium')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if DEVICE == torch.device("cpu"):
    print("Using CPU, not recommended")

In [3]:
def collate_fn(batch):
    batch_fens = [example['fen'] for example in batch]
    labels = torch.tensor(
        [example['target'] for example in batch],
        dtype=torch.float32
    )
    inputs = torch.stack([
        torch.tensor(fen_to_piece_maps(fen), dtype=torch.float32)
        for fen in batch_fens
    ])
    return inputs, labels

In [None]:
# train_dataset = load_from_disk(os.path.join(os.getcwd(), "processed_data/lichess_db_eval_10m/train"))
# val_dataset = load_from_disk(os.path.join(os.getcwd(), "processed_data/lichess_db_eval_10m/validation"))
# test_dataset = load_from_disk(os.path.join(os.getcwd(), "processed_data/lichess_db_eval_10m/test"))

# num_training_examples = len(train_dataset)

# train_dataset = train_dataset.to_iterable_dataset(num_shards=32)
# val_dataset = val_dataset.to_iterable_dataset()
# test_dataset = test_dataset.to_iterable_dataset()

# train_dataset = train_dataset.shuffle(buffer_size=10000)
# val_dataset = val_dataset.shuffle(buffer_size=10000)
# test_dataset = test_dataset.shuffle(buffer_size=10000)

In [None]:
dataset = load_from_disk(os.path.join(os.getcwd(), "processed_data/lichess_db_eval_part1"))

# Split the dataset into train, validation, and test sets
train_size = int(0.98 * len(dataset))
val_size = int(0.01 * len(dataset))
test_size = len(dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size, test_size]
)

# Convert to iterable datasets
train_dataset = train_dataset.dataset.to_iterable_dataset(num_shards=32)
val_dataset = val_dataset.dataset.to_iterable_dataset()
test_dataset = test_dataset.dataset.to_iterable_dataset()

# Shuffle the datasets
train_dataset = train_dataset.shuffle(buffer_size=10000)
val_dataset = val_dataset.shuffle(buffer_size=10000)
test_dataset = test_dataset.shuffle(buffer_size=10000)

In [12]:
train_loader = DataLoader(train_dataset, batch_size=256, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=256, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=256, collate_fn=collate_fn)

In [13]:
class ResidualBlock(nn.Module):
    def __init__(self, channels, dropout_prob=0.2):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout2d(p=dropout_prob)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.dropout(out)
        out = self.bn2(self.conv2(out))
        out += identity
        out = self.relu(out)
        return out

# Resnet-Like
class ChessEvalResNet(nn.Module):
    def __init__(self, input_planes=17, channels=128, num_blocks=10, dropout_prob=0.2):
        super().__init__()
        self.initial_conv = nn.Sequential(
            nn.Conv2d(input_planes, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True)
        )
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(channels, dropout_prob=dropout_prob) for _ in range(num_blocks)]
        )
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(channels, 1)

    def forward(self, x):
        x = self.initial_conv(x)
        x = self.residual_blocks(x)
        x = self.global_pool(x).view(x.size(0), -1)
        x = self.fc(x)
        return x

In [14]:
input_tensor = torch.randn(32, 17, 8, 8)  # A batch of 32 chessboard positions

model = ChessEvalResNet()
output = model(input_tensor)

print(output.shape)

torch.Size([32, 1])


In [15]:
NUM_EPOCHS = 3
total_iters = NUM_EPOCHS * ((10_000_000 // 256) + 1)

model = ChessEvalResNet().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = CosineAnnealingLR(optimizer, T_max=total_iters, eta_min=1e-4)
criterion = nn.MSELoss()

steps, train_losses, train_maes, val_losses, val_maes = [], [], [], [], []       # for tracking performance

In [16]:
best_val_loss = float('inf')
best_val_mae = float('inf')
num_iterations = 0
patience_counter = 0    # Count iterations without sufficient improvement
PATIENCE = 15000        # How many iterations to wait, noting that total_iters is around 117k for 3 epochs and batch size of 256
MIN_IMPROVEMENT = 5e-4  # Minimum improvement required for early stopping to not trigger
VAL_ITERS = 500
LOG_ITERS = 100

early_stop = False

for epoch in range(1, NUM_EPOCHS + 1):
    if early_stop:
        break

    model.train()
    total_loss = 0.0
    total_mae = 0.0

    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch}"):
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        preds = model(inputs)
        loss = criterion(preds, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()   

        batch_loss = loss.item()
        batch_mae = torch.mean(torch.abs(preds - labels)).item()

        total_loss += batch_loss * inputs.size(0)
        total_mae += batch_mae * inputs.size(0)
        num_iterations += 1

        # Every 1000 steps, log training loss and MAE
        if num_iterations % LOG_ITERS == 0:
            avg_train_loss = total_loss / (num_iterations * train_loader.batch_size)
            avg_train_mae = total_mae / (num_iterations * train_loader.batch_size)
            print(f"Step {num_iterations} — train_loss: {avg_train_loss:.4f}, train_mae: {avg_train_mae:.4f}")

        # Every 5000 steps, run validation and record performance
        if num_iterations % VAL_ITERS == 0:
            model.eval()
            val_loss_sum = 0.0
            val_mae_sum = 0.0
            with torch.no_grad():
                for val_inputs, val_labels in val_loader:
                    val_inputs, val_labels = val_inputs.to(DEVICE), val_labels.to(DEVICE)
                    val_pred = model(val_inputs)
                    val_loss = criterion(val_pred, val_labels)
                    val_mae = torch.mean(torch.abs(val_pred - val_labels))

                    val_loss_sum += val_loss.item() * val_inputs.size(0)
                    val_mae_sum += val_mae.item() * val_inputs.size(0)

            avg_val_loss = val_loss_sum / 250_000
            avg_val_mae = val_mae_sum / 250_000

            print(f"Validating! Step {num_iterations} — train_loss: {avg_train_loss:.4f}, train_mae: {avg_train_mae:.4f}, val_loss: {avg_val_loss:.4f}, val_mae: {avg_val_mae:.4f}, current_LR: {scheduler.get_last_lr()[0]:.4f}")

            steps.append(num_iterations)
            train_losses.append(avg_train_loss)
            train_maes.append(avg_train_mae)
            val_losses.append(avg_val_loss)
            val_maes.append(avg_val_mae)

            # Checkpoint best
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                torch.save(model.state_dict(), "best_chess_transformer.pth")
                print(f"Saved new best model after {num_iterations} iters")
            
            # Check if validation MAE improved by 0.0005 at least
            if avg_val_mae + MIN_IMPROVEMENT < best_val_mae:
                best_val_mae = avg_val_mae
                patience_counter = 0  # Reset patience
                print(f"--------Validation MAE improved to {best_val_mae:.6f}--------")
            else:
                patience_counter += VAL_ITERS  # because we validate after every VAL_ITERS 
                print(f"--------No significant MAE improvement for {patience_counter} iterations--------")

            # Early stopping
            if patience_counter >= PATIENCE:
                print(f"Early stopping at {num_iterations} iterations (no significant MAE improvement)")
                early_stop = True
                break

            model.train() 

  return F.mse_loss(input, target, reduction=self.reduction)
Epoch 1: 102it [00:10, 10.82it/s]

Step 100 — train_loss: 1.2811, train_mae: 0.4952


Epoch 1: 202it [00:19, 10.63it/s]

Step 200 — train_loss: 0.7552, train_mae: 0.4167


Epoch 1: 302it [00:29, 10.32it/s]

Step 300 — train_loss: 0.5787, train_mae: 0.3879


Epoch 1: 401it [00:38,  9.79it/s]

Step 400 — train_loss: 0.4905, train_mae: 0.3736


Epoch 1: 498it [00:48,  9.08it/s]

Step 500 — train_loss: 0.4366, train_mae: 0.3641


  return F.mse_loss(input, target, reduction=self.reduction)
Epoch 1: 501it [01:57, 11.89s/it]

Validating! Step 500 — train_loss: 0.4366, train_mae: 0.3641, val_loss: 0.2184, val_mae: 0.3167, current_LR: 0.0010
Saved new best model after 500 iters
--------Validation MAE improved to 0.316702--------


Epoch 1: 602it [02:06, 10.49it/s]

Step 600 — train_loss: 0.4005, train_mae: 0.3576


Epoch 1: 702it [02:16,  9.99it/s]

Step 700 — train_loss: 0.3746, train_mae: 0.3527


Epoch 1: 802it [02:26, 10.03it/s]

Step 800 — train_loss: 0.3554, train_mae: 0.3489


Epoch 1: 901it [02:37,  9.60it/s]

Step 900 — train_loss: 0.3410, train_mae: 0.3466


Epoch 1: 999it [02:47, 10.15it/s]

Step 1000 — train_loss: 0.3291, train_mae: 0.3443


Epoch 1: 1001it [03:55, 11.06s/it]

Validating! Step 1000 — train_loss: 0.3291, train_mae: 0.3443, val_loss: 0.2168, val_mae: 0.3102, current_LR: 0.0010
Saved new best model after 1000 iters
--------Validation MAE improved to 0.310151--------


Epoch 1: 1101it [04:05, 10.48it/s]

Step 1100 — train_loss: 0.3189, train_mae: 0.3419


Epoch 1: 1201it [04:15,  9.34it/s]

Step 1200 — train_loss: 0.3106, train_mae: 0.3400


Epoch 1: 1302it [04:25,  9.94it/s]

Step 1300 — train_loss: 0.3036, train_mae: 0.3386


Epoch 1: 1401it [04:36,  9.10it/s]

Step 1400 — train_loss: 0.2974, train_mae: 0.3373


Epoch 1: 1499it [04:46,  9.13it/s]

Step 1500 — train_loss: 0.2920, train_mae: 0.3361


Epoch 1: 1501it [05:58, 15.03s/it]

Validating! Step 1500 — train_loss: 0.2920, train_mae: 0.3361, val_loss: 0.2161, val_mae: 0.3194, current_LR: 0.0010
Saved new best model after 1500 iters
--------No significant MAE improvement for 500 iterations--------


Epoch 1: 1600it [06:07, 10.44it/s]

Step 1600 — train_loss: 0.2875, train_mae: 0.3351


Epoch 1: 1702it [06:18,  9.86it/s]

Step 1700 — train_loss: 0.2835, train_mae: 0.3343


Epoch 1: 1802it [06:28,  9.51it/s]

Step 1800 — train_loss: 0.2799, train_mae: 0.3336


Epoch 1: 1901it [06:39,  9.42it/s]

Step 1900 — train_loss: 0.2767, train_mae: 0.3328


Epoch 1: 1999it [06:49,  9.39it/s]

Step 2000 — train_loss: 0.2742, train_mae: 0.3325


Epoch 1: 2001it [08:00, 12.56s/it]

Validating! Step 2000 — train_loss: 0.2742, train_mae: 0.3325, val_loss: 0.2167, val_mae: 0.3223, current_LR: 0.0010
--------No significant MAE improvement for 1000 iterations--------


Epoch 1: 2102it [08:10, 10.06it/s]

Step 2100 — train_loss: 0.2717, train_mae: 0.3320


Epoch 1: 2201it [08:20,  9.89it/s]

Step 2200 — train_loss: 0.2691, train_mae: 0.3314


Epoch 1: 2300it [08:30,  4.50it/s]


Step 2300 — train_loss: 0.2670, train_mae: 0.3309


KeyboardInterrupt: 