In [None]:
# Kaggle Chess Policy Network Training Notebook

import os
import random

import numpy as np
import pandas as pd
import polars as pl

import chess
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Dataset location. Defaults to the Kaggle input mount; set the CHESS_DATA_DIR
# environment variable to point at a local copy when running outside Kaggle.
BASE_PATH = os.environ.get(
    "CHESS_DATA_DIR", "/kaggle/input/chess-dataset-splitted/Chess-dataset"
)
# Suffix shared by all three split files (presumably a rating range — confirm).
SPLIT_SUFFIX = "2450_2550"

train_df = pl.read_parquet(os.path.join(BASE_PATH, f"train_{SPLIT_SUFFIX}.parquet"))
val_df = pl.read_parquet(os.path.join(BASE_PATH, f"val_{SPLIT_SUFFIX}.parquet"))
test_df = pl.read_parquet(os.path.join(BASE_PATH, f"test_{SPLIT_SUFFIX}.parquet"))

print("Train/Val/Test shapes:", train_df.shape, val_df.shape, test_df.shape)

# Plane offset of each piece type (pawn=0 ... king=5); black pieces use +6.
PIECE_TO_PLANE = {
    piece_type: plane
    for plane, piece_type in enumerate(
        (chess.PAWN, chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN, chess.KING)
    )
}


def board_to_tensor(board: chess.Board):
    """Encode ``board`` as an (18, 8, 8) float32 array of feature planes.

    Planes:
        0-5   white pawn, knight, bishop, rook, queen, king occupancy
        6-11  the same piece types for black
        12    side to move (filled with int(board.turn))
        13-16 castling rights: white king-side, white queen-side,
              black king-side, black queen-side (1.0 / 0.0)
        17    fullmove number divided by 100 (not clipped)
    """
    planes = np.zeros((18, 8, 8), dtype=np.float32)

    # python-chess numbers squares a1=0 ... h8=63, so divmod(square, 8)
    # yields (rank index, file index).
    for piece_type, plane in PIECE_TO_PLANE.items():
        for colour, colour_offset in ((chess.WHITE, 0), (chess.BLACK, 6)):
            for square in board.pieces(piece_type, colour):
                rank, file = divmod(square, 8)
                planes[plane + colour_offset, rank, file] = 1

    planes[12] = int(board.turn)

    castling_flags = (
        board.has_kingside_castling_rights(chess.WHITE),
        board.has_queenside_castling_rights(chess.WHITE),
        board.has_kingside_castling_rights(chess.BLACK),
        board.has_queenside_castling_rights(chess.BLACK),
    )
    for offset, flag in enumerate(castling_flags):
        planes[13 + offset] = flag

    planes[17] = board.fullmove_number / 100.0
    return planes


def move_to_index(move: chess.Move):
    """Map ``move`` to a flat class index ``64 * from_square + to_square``.

    With squares in 0..63 the result lies in [0, 4096).

    NOTE(review): the promotion piece is not encoded, so e.g. ``e7e8q`` and
    ``e7e8n`` share one index; under-promotions cannot be distinguished.
    """
    origin = move.from_square
    destination = move.to_square
    return 64 * origin + destination


class ChessPositionDataset(Dataset):
    """Yield one (encoded position, next-move index) pair per game row.

    Each row of ``df`` must have a ``moves_uci`` column holding the game's
    moves (presumably a list of UCI strings — TODO confirm against the parquet
    schema). ``__getitem__`` replays a randomly chosen prefix of the game on a
    fresh board and returns the encoded position plus the index of the move
    played next. A different ply is therefore sampled on every access.
    """

    def __init__(self, df: pl.DataFrame):
        self.df = df
        # Resolve the column position once rather than on every item access.
        # Raises ValueError immediately if the column is missing.
        self._moves_col = df.columns.index("moves_uci")

    def __len__(self):
        return self.df.height

    def _first_usable_moves(self, idx):
        """Return moves of the first row at/after ``idx`` (wrapping) with >= 2 moves.

        Raises:
            ValueError: if no row in the frame has at least two moves.
        """
        n_rows = len(self)
        # Bounded scan instead of recursion: the recursive version raised
        # RecursionError on long runs of short games and never terminated
        # when no row was usable.
        for offset in range(n_rows):
            moves = self.df.row((idx + offset) % n_rows)[self._moves_col]
            if moves is not None and len(moves) >= 2:
                return moves
        raise ValueError("no row in 'moves_uci' has at least two moves")

    def __getitem__(self, idx):
        """Return ``(x, y)``: a float32 (18, 8, 8) tensor and a long scalar target.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``. This
                keeps Python's legacy ``for item in dataset`` iteration finite.
        """
        n_rows = len(self)
        if not -n_rows <= idx < n_rows:
            raise IndexError(f"index {idx} out of range for dataset of size {n_rows}")
        moves = self._first_usable_moves(idx % n_rows)

        # NOTE(review): the target ply is drawn from [0, len(moves) - 2], so the
        # last element of `moves` is never used. Kept unchanged because it may
        # deliberately skip a trailing non-move token — confirm against the data.
        ply_idx = random.randint(0, len(moves) - 2)

        board = chess.Board()
        for uci in moves[:ply_idx]:
            board.push_uci(uci)

        # board_to_tensor returns a fresh float32 array, so sharing its memory
        # via from_numpy is safe and avoids an extra copy.
        x = torch.from_numpy(board_to_tensor(board))
        target_move = chess.Move.from_uci(moves[ply_idx])
        y = torch.tensor(move_to_index(target_move), dtype=torch.long)
        return x, y


class ResidualBlock(nn.Module):
    """Conv-BN-ReLU-Conv-BN with an identity skip connection and a final ReLU.

    Both convolutions are 3x3 with padding 1 and map ``channels`` to
    ``channels``. Input and output shapes therefore match, and the input can
    be added back directly.
    """

    def __init__(self, channels):
        super().__init__()
        # Creation order fixes the RNG sequence used for weight initialisation.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        skip = x
        hidden = self.bn1(self.conv1(x))
        hidden = nn.functional.relu(hidden)
        hidden = self.conv2(hidden)
        hidden = self.bn2(hidden)
        return nn.functional.relu(hidden + skip)


class ChessPolicyNet(nn.Module):
    """Residual convolutional policy network: encoded board -> move logits.

    Architecture:
        1. 3x3 conv stem, then BatchNorm and ReLU.
        2. ``num_blocks`` residual blocks.
        3. A 1x1 conv down to ``policy_planes`` planes.
        4. A fully connected layer emitting ``num_moves`` unnormalised logits
           (no softmax is applied).

    All arguments are keyword-only and default to the original hard-coded
    sizes. ``ChessPolicyNet()`` therefore builds the same layers, in the same
    creation order and with the same state_dict keys, as before.

    Args:
        in_planes: Input feature planes; must match the board encoder
            (``board_to_tensor`` here produces 18).
        channels: Width of the stem and the residual tower.
        num_blocks: Number of residual blocks.
        policy_planes: Output planes of the 1x1 policy convolution.
        num_moves: Number of output logits.
        board_size: Side length of the square board input.

    NOTE(review):
        - ``move_to_index`` as written yields indices below 4096, so with the
          default ``num_moves=4672`` the top 576 logits are never targets.
        - There is no non-linearity between ``policy`` and ``fc``, so the two
          layers act as a single affine map.
    """

    def __init__(self, *, in_planes=18, channels=256, num_blocks=10,
                 policy_planes=73, num_moves=4672, board_size=8):
        super().__init__()
        # Layer creation order matches the original so seeded init is unchanged.
        self.conv = nn.Conv2d(in_planes, channels, 3, padding=1)
        self.bn = nn.BatchNorm2d(channels)
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(channels) for _ in range(num_blocks)]
        )
        self.policy = nn.Conv2d(channels, policy_planes, 1)
        self.fc = nn.Linear(policy_planes * board_size * board_size, num_moves)

    def forward(self, x):
        """Return logits of shape (batch, num_moves).

        ``x`` must be (batch, in_planes, board_size, board_size); ``fc``'s
        input width requires exactly that spatial size.
        """
        x = torch.relu(self.bn(self.conv(x)))
        x = self.res_blocks(x)
        x = self.policy(x)
        # flatten (reshape semantics) instead of view: view raises on
        # non-contiguous activations, e.g. under channels_last memory format.
        x = torch.flatten(x, start_dim=1)
        return self.fc(x)


# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model = ChessPolicyNet().to(device)
# Expects raw logits from the model and integer class targets.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(params=model.parameters(), lr=1e-3, weight_decay=1e-4)

print("Model on device:", device)

