In [None]:
# Kaggle Chess Policy Network Training Notebook

import os
import random

import numpy as np
import pandas as pd
import polars as pl

import chess
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Dataset paths (adapt as needed outside Kaggle)
BASE_PATH = "/kaggle/input/chess-dataset-splitted/Chess-dataset"


def _read_split(split_name):
    """Read the parquet file for one split ("train", "val" or "test") under BASE_PATH."""
    return pl.read_parquet(f"{BASE_PATH}/{split_name}_2450_2550.parquet")


train_df, val_df, test_df = (_read_split(name) for name in ("train", "val", "test"))

print("Train/Val/Test shapes:", train_df.shape, val_df.shape, test_df.shape)

# Plane index for each piece type (pawn=0 ... king=5). board_to_tensor uses
# these for white pieces and adds 6 for black pieces.
PIECE_TO_PLANE = {
    piece_type: plane
    for plane, piece_type in enumerate(
        (chess.PAWN, chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN, chess.KING)
    )
}


def board_to_tensor(board: chess.Board):
    """Encode a position as an (18, 8, 8) float32 numpy array.

    Layout, indexed as [plane, square // 8, square % 8]:
      * planes 0-5   : white pieces (plane from PIECE_TO_PLANE)
      * planes 6-11  : black pieces (same order, offset by 6)
      * plane 12     : int(board.turn), i.e. 1 when White is to move
      * planes 13-16 : castling rights (white K-side, white Q-side,
                       black K-side, black Q-side)
      * plane 17     : fullmove number / 100
    """
    planes = np.zeros((18, 8, 8), dtype=np.float32)

    for sq, piece in board.piece_map().items():
        rank, file = divmod(sq, 8)
        colour_offset = 0 if piece.color == chess.WHITE else 6
        planes[PIECE_TO_PLANE[piece.piece_type] + colour_offset, rank, file] = 1

    planes[12] = int(board.turn)

    castling_flags = (
        board.has_kingside_castling_rights(chess.WHITE),
        board.has_queenside_castling_rights(chess.WHITE),
        board.has_kingside_castling_rights(chess.BLACK),
        board.has_queenside_castling_rights(chess.BLACK),
    )
    for plane, flag in enumerate(castling_flags, start=13):
        planes[plane] = flag

    planes[17] = board.fullmove_number / 100.0
    return planes


def move_to_index(move: chess.Move):
    """Flatten a move to ``64 * from_square + to_square``.

    The promotion piece is not encoded, so e.g. e7e8q and e7e8n map to the
    same index.
    """
    src, dst = move.from_square, move.to_square
    return 64 * src + dst


class ChessPositionDataset(Dataset):
    """Map-style dataset yielding one ``(position, next move)`` pair per game row.

    Each row of ``df`` is expected to have a ``moves_uci`` column holding a
    list of UCI move strings. ``__getitem__`` replays a random prefix of that
    list and returns ``(board_to_tensor(board), move_to_index(next_move))`` as
    a float32 tensor and a long scalar tensor. The ply is drawn with Python's
    ``random`` on every access, so the same index yields different positions
    on different calls.
    """

    def __init__(self, df: pl.DataFrame):
        self.df = df
        # The column position is fixed for this frame; resolve it once
        # instead of scanning df.columns on every item.
        self._moves_col = df.columns.index("moves_uci")

    def __len__(self):
        return self.df.height

    def _usable_moves(self, row_idx):
        """Return the move list at ``row_idx`` if it has >= 2 moves, else None."""
        moves = self.df.row(row_idx)[self._moves_col]
        if moves is None or len(moves) < 2:
            return None
        return moves

    def __getitem__(self, idx):
        n_rows = len(self)

        # Unusable rows are skipped by scanning forward with wrap-around.
        # The previous version recursed once per skipped row, which raised
        # RecursionError on long runs of unusable rows (and always when no
        # row was usable). A bounded loop with an explicit error avoids that.
        moves = None
        for offset in range(n_rows):
            row_idx = idx if offset == 0 else (idx + offset) % n_rows
            moves = self._usable_moves(row_idx)
            if moves is not None:
                break
        if moves is None:
            raise ValueError(
                "ChessPositionDataset: no row has at least 2 moves in 'moves_uci'"
            )

        # NOTE(review): with upper bound len(moves) - 2 the final move of a
        # game is never used as a target (and one-move games are skipped);
        # confirm this is intended.
        ply_idx = random.randint(0, len(moves) - 2)

        # Replay the first ply_idx moves; moves[ply_idx] is the target.
        board = chess.Board()
        for uci in moves[:ply_idx]:
            board.push_uci(uci)

        x = board_to_tensor(board)
        y = move_to_index(chess.Move.from_uci(moves[ply_idx]))

        return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.long)


class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm stages with an identity skip connection.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``. Channel count and
    spatial size are preserved (``channels -> channels``, ``padding=1``).
    """

    def __init__(self, channels):
        super().__init__()
        # Attribute names form the state_dict keys of saved checkpoints, so
        # they must stay conv1 / bn1 / conv2 / bn2 (created in this order).
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        skip = x
        h = self.conv1(x)
        h = torch.relu(self.bn1(h))
        h = self.conv2(h)
        h = self.bn2(h)
        return torch.relu(h + skip)


class ChessPolicyNet(nn.Module):
    """Residual CNN mapping encoded positions to move logits.

    Input ``(batch, in_planes, 8, 8)`` -> output ``(batch, num_moves)`` logits.

    Parameters (defaults reproduce the original architecture, so existing
    checkpoints keep loading with identical state_dict keys and shapes):
      in_planes     : input feature planes (board_to_tensor produces 18)
      channels      : width of the stem and of every ResidualBlock
      num_blocks    : number of stacked ResidualBlocks
      policy_planes : channels of the 1x1 policy conv before the linear head
      num_moves     : size of the output logit vector

    NOTE: move_to_index yields indices in [0, 4096), so with the default
    num_moves=4672 the last 576 logits are never used as targets.
    """

    def __init__(self, in_planes=18, channels=256, num_blocks=10,
                 policy_planes=73, num_moves=4672):
        super().__init__()
        self.conv = nn.Conv2d(in_planes, channels, 3, padding=1)
        self.bn = nn.BatchNorm2d(channels)
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(channels) for _ in range(num_blocks)]
        )
        self.policy = nn.Conv2d(channels, policy_planes, 1)
        # All convs preserve the 8x8 board, hence policy_planes * 8 * 8 inputs.
        self.fc = nn.Linear(policy_planes * 8 * 8, num_moves)

    def forward(self, x):
        x = torch.relu(self.bn(self.conv(x)))
        x = self.res_blocks(x)
        x = self.policy(x)
        x = torch.flatten(x, 1)
        return self.fc(x)


# Seed every RNG the notebook draws from: Python `random` (ply sampling in
# ChessPositionDataset), numpy (move sampling at inference) and torch (weight
# init, DataLoader shuffling and DataLoader worker seeds), so runs are
# reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ChessPolicyNet().to(device)
criterion = nn.CrossEntropyLoss()
# NOTE: the training-configuration cell builds a new AdamW (lr=3e-4) that
# replaces this optimizer before training starts.
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

print("Model on device:", device)



In [None]:
# DataLoaders

BATCH_SIZE = 128


def _make_loader(split_df, shuffle, **loader_kwargs):
    """Wrap one split in a ChessPositionDataset and a DataLoader of BATCH_SIZE."""
    return DataLoader(
        ChessPositionDataset(split_df),
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        **loader_kwargs,
    )


# Only the training loader uses worker processes and pinned memory.
train_loader = _make_loader(train_df, shuffle=True, num_workers=2, pin_memory=True)
val_loader = _make_loader(val_df, shuffle=False)
test_loader = _make_loader(test_df, shuffle=False)



In [None]:
# Training configuration and resume logic

import torch
import torch.nn as nn
import torch.optim as optim

import matplotlib.pyplot as plt

import os

# Paths (adapt if needed)
INPUT_BEST_MODEL = "/kaggle/input/chess-model/pytorch/default/1/best_model.pth"
OUTPUT_DIR = "/kaggle/working"
os.makedirs(OUTPUT_DIR, exist_ok=True)

OUTPUT_BEST_MODEL = f"{OUTPUT_DIR}/best_model.pth"
CHECKPOINT_PATH = f"{OUTPUT_DIR}/checkpoint_latest.pth"
METRICS_PATH = f"{OUTPUT_DIR}/metrics_history.pt"

EPOCHS = 100
PATIENCE = 5
MIN_DELTA = 1e-4

start_epoch = 0
best_val_loss = float("inf")
early_stop_counter = 0

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)

# Halve the LR once val loss has failed to improve for more than 2
# consecutive epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    patience=2,
)

METRIC_KEYS = ("epochs", "train_loss", "val_loss", "train_acc", "val_acc")
metrics = {key: [] for key in METRIC_KEYS}

if os.path.exists(CHECKPOINT_PATH):
    # Resume from the per-epoch checkpoint written by the training loop.
    # Delete CHECKPOINT_PATH to restart from the baseline model instead.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    # "scheduler" / "early_stop_counter" are optional: checkpoints without
    # them still resume, just with a fresh scheduler and counter.
    if "scheduler" in checkpoint:
        scheduler.load_state_dict(checkpoint["scheduler"])
    early_stop_counter = checkpoint.get("early_stop_counter", 0)
    best_val_loss = checkpoint.get("best_val_loss", float("inf"))
    last_epoch = checkpoint["epoch"]
    start_epoch = last_epoch + 1

    if os.path.exists(METRICS_PATH):
        saved_metrics = torch.load(METRICS_PATH)
        # Metrics are saved before the checkpoint each epoch, so drop any
        # entries newer than the checkpoint to avoid duplicating an epoch.
        keep = [i for i, ep in enumerate(saved_metrics["epochs"]) if ep <= last_epoch]
        metrics = {key: [saved_metrics[key][i] for i in keep] for key in METRIC_KEYS}

    # A checkpoint written before the best-model update stores a
    # best_val_loss one epoch stale; the recorded history gives the true
    # minimum, so an older, worse model can't overwrite the saved best.
    if metrics["val_loss"]:
        best_val_loss = min(best_val_loss, min(metrics["val_loss"]))

    print(f"Resumed from {CHECKPOINT_PATH}; continuing at epoch {start_epoch + 1}/{EPOCHS}")
else:
    # Fresh run: start from the baseline best model.
    model.load_state_dict(torch.load(INPUT_BEST_MODEL, map_location=device))
    model.to(device)
    print("Loaded baseline best model from Kaggle INPUT")



In [None]:
# Main training loop


def _train_one_epoch(net, loader, opt, loss_fn):
    """Run one optimisation pass over ``loader``.

    Returns ``(mean per-sample loss, top-1 accuracy)`` for the epoch.
    Raises ValueError if the loader yields no samples.
    """
    net.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for x, y in loader:
        x, y = x.to(device), y.to(device)

        opt.zero_grad()
        logits = net(x)
        loss = loss_fn(logits, y)
        loss.backward()
        opt.step()

        # loss_fn is assumed to average over the batch (nn.CrossEntropyLoss
        # default), so re-weight by batch size for a per-sample epoch mean.
        loss_sum += loss.item() * x.size(0)
        n_correct += (logits.argmax(1) == y).sum().item()
        n_seen += y.size(0)

    if n_seen == 0:
        raise ValueError("training loader produced no samples")
    return loss_sum / n_seen, n_correct / n_seen


def _evaluate(net, loader, loss_fn):
    """Score ``net`` on ``loader`` without gradient tracking.

    Returns ``(mean per-sample loss, top-1 accuracy)``.
    Raises ValueError if the loader yields no samples.
    """
    net.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = net(x)
            loss_sum += loss_fn(logits, y).item() * x.size(0)
            n_correct += (logits.argmax(1) == y).sum().item()
            n_seen += y.size(0)

    if n_seen == 0:
        raise ValueError("evaluation loader produced no samples")
    return loss_sum / n_seen, n_correct / n_seen


for epoch in range(start_epoch, EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")

    # ----- TRAIN / VALIDATION -----
    train_loss, train_acc = _train_one_epoch(model, train_loader, optimizer, criterion)
    val_loss, val_acc = _evaluate(model, val_loader, criterion)

    # ----- METRICS -----
    for key, value in (
        ("epochs", epoch),
        ("train_loss", train_loss),
        ("val_loss", val_loss),
        ("train_acc", train_acc),
        ("val_acc", val_acc),
    ):
        metrics[key].append(value)

    torch.save(metrics, METRICS_PATH)

    print(
        f"Train Loss {train_loss:.4f} | Train Acc {train_acc:.4f} | "
        f"Val Loss {val_loss:.4f} | Val Acc {val_acc:.4f}"
    )

    scheduler.step(val_loss)

    # ----- SAVE NEW BEST MODEL -----
    if best_val_loss - val_loss > MIN_DELTA:
        best_val_loss = val_loss
        early_stop_counter = 0

        torch.save(model.state_dict(), OUTPUT_BEST_MODEL)
        print("✅ NEW BEST MODEL saved to /kaggle/working")
    else:
        early_stop_counter += 1
        print(f"No improvement ({early_stop_counter}/{PATIENCE})")

    # ----- CHECKPOINT -----
    # Written after the best-model update so best_val_loss and
    # early_stop_counter reflect this epoch. Scheduler state is included so
    # a resumed run keeps its plateau / LR-reduction progress.
    torch.save(
        {
            "epoch": epoch,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "best_val_loss": best_val_loss,
            "early_stop_counter": early_stop_counter,
        },
        CHECKPOINT_PATH,
    )

    if early_stop_counter >= PATIENCE:
        print("⛔ Early stopping triggered")
        break



In [None]:
# Plot loss and accuracy curves

metrics = torch.load(METRICS_PATH)
epochs_axis = metrics["epochs"]

fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(14, 5))

# (axis, metric suffix, train label, val label, y label, title)
panels = (
    (ax_loss, "loss", "Train Loss", "Val Loss", "Loss", "Loss Curve"),
    (ax_acc, "acc", "Train Acc", "Val Acc", "Accuracy", "Accuracy Curve"),
)
for ax, suffix, train_label, val_label, y_label, title in panels:
    ax.plot(epochs_axis, metrics[f"train_{suffix}"], label=train_label)
    ax.plot(epochs_axis, metrics[f"val_{suffix}"], label=val_label)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.legend()

plt.show()



In [None]:
# Reload best model and define inference helpers

# Prefer the best model trained in this session (the training loop saves it
# to /kaggle/working); fall back to the published baseline if it is absent.
TRAINED_MODEL_PATH = "/kaggle/working/best_model.pth"
BASELINE_MODEL_PATH = "/kaggle/input/chess-model/pytorch/default/1/best_model.pth"
MODEL_PATH = (
    TRAINED_MODEL_PATH if os.path.exists(TRAINED_MODEL_PATH) else BASELINE_MODEL_PATH
)

model = ChessPolicyNet().to(device)
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.eval()

print("Model loaded on", device, "from", MODEL_PATH)


def select_bot_move(board: chess.Board, model, temperature=0.7):
    """Choose a legal move for the side to move using the policy network.

    Parameters
    ----------
    board : position to move from (read only, not modified).
    model : network whose output row is indexed by move_to_index.
    temperature : if > 0, sample from softmax(logits / temperature) over the
        candidate moves; otherwise pick the highest-logit candidate.

    Returns
    -------
    chess.Move

    Raises
    ------
    ValueError if the position has no legal moves.
    """
    # The policy indexes moves by (from, to) only, so all four promotions to
    # one square share a logit. Keeping only queen promotions stops that
    # square from receiving 4x the sampling mass (3/4 of it under-promotions).
    # Dropping under-promotions never empties the list: they are legal
    # exactly when the matching queen promotion is.
    candidates = [
        mv for mv in board.legal_moves
        if mv.promotion is None or mv.promotion == chess.QUEEN
    ]
    if not candidates:
        raise ValueError("select_bot_move: no legal moves in this position")

    x = torch.tensor(board_to_tensor(board), dtype=torch.float32).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(x)[0].cpu().numpy()

    legal_logits = logits[[move_to_index(mv) for mv in candidates]].astype(np.float64)

    if temperature > 0:
        # Subtract the max before exponentiating: large logits would
        # otherwise overflow to inf and produce NaN probabilities.
        scaled = (legal_logits - legal_logits.max()) / temperature
        probs = np.exp(scaled)
        probs /= probs.sum()
        idx = np.random.choice(len(candidates), p=probs)
    else:
        idx = int(np.argmax(legal_logits))

    return candidates[idx]



In [None]:
# Visualization helpers and interactive play

import chess.svg
from IPython.display import display, HTML, SVG


def show_boards_side_by_side(board_left, board_right, size=300,
                             left_title="After Human Move",
                             right_title="After Bot Move"):
    """Display two boards next to each other in a single HTML flex row.

    Both boards are rendered as SVG at ``size`` pixels, each under an <h4>
    heading. The titles are interpolated into the HTML without escaping.
    """
    left_svg, right_svg = (
        chess.svg.board(board=b, size=size) for b in (board_left, board_right)
    )

    markup = f"""
    <div style="display:flex; gap:30px; align-items:flex-start;">
        <div style="text-align:center;">
            <h4>{left_title}</h4>
            {left_svg}
        </div>
        <div style="text-align:center;">
            <h4>{right_title}</h4>
            {right_svg}
        </div>
    </div>
    """
    display(HTML(markup))


def show_board(board, size=300):
    """Display a single board inline as an SVG image of ``size`` pixels."""
    svg_markup = chess.svg.board(board=board, size=size)
    display(SVG(svg_markup))


# Shared game state used by human_move / new_game; starts at the standard
# initial position.
board = chess.Board(chess.STARTING_FEN)


def human_move(move_uci, temperature=0.7):
    """Play the human's move on the global ``board``, then let the bot reply.

    Parameters
    ----------
    move_uci : the human's move in UCI notation (e.g. "e2e4", "e7e8q").
    temperature : forwarded to select_bot_move for the bot's reply.

    Prints a message and returns without changing the board if the game is
    already over, the text is not valid UCI, or the move is illegal.
    Otherwise it shows the boards after the human and bot moves side by side
    and announces the result if the game has ended.
    """
    global board

    if board.is_game_over():
        print("Game over:", board.result())
        return

    try:
        move = chess.Move.from_uci(move_uci)
    except ValueError:
        # from_uci raises on malformed text (e.g. "e9e4", "hello"); report it
        # like an illegal move instead of crashing the cell.
        print("Illegal move:", move_uci)
        return
    if move not in board.legal_moves:
        print("Illegal move:", move_uci)
        return

    board.push(move)
    print("Human plays:", move_uci)

    board_after_human = board.copy()

    if board.is_game_over():
        show_boards_side_by_side(board_after_human, board_after_human)
        print("Game over:", board.result())
        return

    bot_move = select_bot_move(board, model, temperature)
    board.push(bot_move)
    print("Bot plays:", bot_move)

    board_after_bot = board.copy()

    show_boards_side_by_side(
        board_after_human,
        board_after_bot,
        size=280,
    )

    if board.is_game_over():
        print("Game over:", board.result())


def new_game():
    """Start over: replace the global ``board`` with a fresh starting
    position, announce that the human plays White, and render the board."""
    global board
    starting_position = chess.Board()
    board = starting_position
    print("New game started. You are White.")
    show_board(starting_position)


# Begin the first game as soon as this cell runs.
new_game()



In [None]:
# Example: single move from a custom position

def select_bot_move_temperature(board: chess.Board, model, temperature=1.0):
    """Pick a move exactly as select_bot_move does, with a default temperature of 1.0.

    This used to be a copy-pasted duplicate of select_bot_move. It is now a
    thin alias so the sampling logic lives in one place and existing callers
    keep working.
    """
    return select_bot_move(board, model, temperature=temperature)


EXAMPLE_FEN = "r1bqkbnr/pppp1ppp/2n5/4p3/2B1P3/5N2/PPPP1PPP/RNBQK2R w KQkq - 2 3"
board_example = chess.Board(EXAMPLE_FEN)

# Sample one move for the side to move at temperature 0.8.
move = select_bot_move_temperature(board_example, model, temperature=0.8)

print("Predicted move:", move)
print("UCI format:", move.uci())



In [None]:
# Top-k accuracy evaluation


def top_k_accuracy(model, loader, k=5, device=None):
    """Fraction of samples whose target class is among the model's top-``k`` logits.

    Parameters
    ----------
    model : module mapping a batch ``x`` to ``(batch, num_classes)`` logits.
    loader : iterable of ``(x, y)`` batches with integer class targets ``y``.
    k : number of highest-scoring classes counted as a hit.
    device : where to move each batch. Defaults to the device of the model's
        first parameter (CPU for a parameter-free model), so the function no
        longer depends on a notebook-global ``device``.

    Returns
    -------
    float in [0, 1].

    Raises
    ------
    ValueError if ``loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    correct = 0
    total = 0

    model.eval()
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)

            topk = logits.topk(k, dim=1).indices
            correct += (topk == y.unsqueeze(1)).any(dim=1).sum().item()
            total += y.size(0)

    if total == 0:
        raise ValueError("top_k_accuracy: loader produced no samples")
    return correct / total


# test_loader runs in the main process (no num_workers) and
# ChessPositionDataset draws a fresh random ply on every access. Reseeding
# `random` before each pass makes top-1 and top-5 use the same positions.
EVAL_SEED = 0
for k in (1, 5):
    random.seed(EVAL_SEED)
    print(f"Top-{k}:", top_k_accuracy(model, test_loader, k=k))

