In [1]:
!wget -nc https: // database.lichess.org/lichess_db_puzzle.csv.zst

zsh:1: command not found: wget


In [2]:
!pip install pandas tqdm scikit-learn zstandard chess torchinfo torcheval lightning matplotlib tensorboard multimethod vit-pytorch


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m24.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [3]:
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

# Register tqdm's progress-bar helpers (progress_apply etc.) on pandas objects.
tqdm.pandas()

# Load the Lichess puzzle table (compressed CSV) that serves as the training set.
train_df = pd.read_csv("lichess_db_puzzle.csv.zst")
# Optional subsample for quick experiments:
# train_df = train_df.iloc[:500000]
train_df.head()

Unnamed: 0,PuzzleId,FEN,Moves,Rating,RatingDeviation,Popularity,NbPlays,Themes,GameUrl,OpeningTags
0,00008,r6k/pp2r2p/4Rp1Q/3p4/8/1N1P2R1/PqP2bPP/7K b - ...,f2g3 e6e7 b2b1 b3c1 b1c1 h6c1,1913,75,94,6230,crushing hangingPiece long middlegame,https://lichess.org/787zsVup/black#47,
1,0000D,5rk1/1p3ppp/pq3b2/8/8/1P1Q1N2/P4PPP/3R2K1 w - ...,d3d6 f8d8 d6d8 f6d8,1429,73,96,26521,advantage endgame short,https://lichess.org/F8M8OS71#52,
2,0008Q,8/4R3/1p2P3/p4r2/P6p/1P3Pk1/4K3/8 w - - 1 64,e7f7 f5e5 e2f1 e5e6,1419,75,90,560,advantage endgame rookEndgame short,https://lichess.org/MQSyb3KW#126,
3,0009B,r2qr1k1/b1p2ppp/pp4n1/P1P1p3/4P1n1/B2P2Pb/3NBP...,b6c5 e2g4 h3g4 d1g4,1112,74,87,569,advantage middlegame short,https://lichess.org/4MWQCxQ6/black#31,Kings_Pawn_Game Kings_Pawn_Game_Leonardis_Vari...
4,000Vc,8/8/4k1p1/2KpP2p/5PP1/8/8/8 w - - 0 53,g4h5 g6h5 f4f5 e6e5 f5f6 e5f6,1556,81,73,91,crushing endgame long pawnEndgame,https://lichess.org/l6AejDMO#104,


In [4]:
# Load the puzzles to be rated; the preview below shows no Rating / RatingDeviation columns.
test_df = pd.read_csv("test_data_with_themes.csv")

test_df.head()

Unnamed: 0,PuzzleId,FEN,Moves,Tags
0,p0003,2R5/3rq1k1/p6p/1p3pNQ/Pb2p3/4P2P/5PP1/6K1 w - ...,c8e8 e7g5 e8g8 g7g8,short
1,p0027,7r/2b2pk1/Q4np1/p2pp1R1/qp3P2/1N6/1PP4P/4R2K b...,e5f4 g5g6 g7g6 e1g1 g6f5 b3d4,attraction sacrifice long
2,p0025,5qk1/p1pR1p2/1b3Rp1/4r3/1r6/7P/PP4P1/1B1Q3K b ...,e5e7 b1g6 e7d7 d1d7,short
3,p0029,4rrk1/pp3pp1/2p5/6qb/P1PP4/4nRB1/1P1QN1B1/R5K1...,a1a3 h5f3 g2f3 f7f5,defensiveMove short
4,p0031,4r1k1/1qr3Bp/p3pPp1/1p2Q3/8/2P2p2/PP1R1PPP/6K1...,e5e6 e8e6 d2d8 e6e8,hangingPiece short


In [5]:
import chess

# Identity matrix whose rows are the one-hot codes of the 12 (colour, piece type) planes.
piece_vectors = np.eye(2 * len(chess.PIECE_TYPES))


def piece_vector(piece):
    """Return the one-hot row of ``piece_vectors`` that encodes ``piece``.

    White pieces map to rows 0-5 and black pieces to rows 6-11, ordered by
    ``piece.piece_type - 1`` within each colour.
    """
    row = piece.piece_type - 1
    if not piece.color:
        # Black pieces occupy the second half of the table.
        row += 6
    return piece_vectors[row]


def get_position_planes(board):
    """Encode piece placement as a ``(12, 8, 8)`` array of one-hot planes.

    Axis 0 is the piece plane (see ``piece_vector``), axis 1 is
    ``square // 8`` and axis 2 is ``square % 8``. Empty squares stay zero.
    """
    planes = np.zeros((2 * len(chess.PIECE_TYPES), 8, 8))
    for square, piece in board.piece_map().items():
        rank, file = divmod(square, 8)
        planes[:, rank, file] = piece_vector(piece)
    return planes


def get_castling_planes(board):
    """Return a ``(4, 8, 8)`` array of constant planes, one per castling right.

    Plane order: white kingside, white queenside, black kingside, black
    queenside. A plane is all ones when the right is available, else all zeros.
    """
    rights = (
        board.has_kingside_castling_rights(chess.WHITE),
        board.has_queenside_castling_rights(chess.WHITE),
        board.has_kingside_castling_rights(chess.BLACK),
        board.has_queenside_castling_rights(chess.BLACK),
    )
    planes = np.zeros((4, 8, 8))
    for channel, available in enumerate(rights):
        planes[channel, :, :] = available
    return planes


def get_legal_moves_planes(board):
    """Encode the legal moves of ``board`` as three ``(12, 8, 8)`` arrays.

    Returns ``(moves, captures, checks)``. Each array is indexed by the moving
    piece's plane (white 0-5, black 6-11) and the destination square
    (``to_square // 8``, ``to_square % 8``):

    * ``moves`` is 1 wherever some legal move lands,
    * ``captures`` holds ``board.is_capture(move)``,
    * ``checks`` holds ``board.gives_check(move)``.

    When several legal moves share a piece type and destination, the one
    iterated last determines the capture and check entries.
    """
    shape = (12, 8, 8)
    moves_arr = np.zeros(shape)
    captures_arr = np.zeros(shape)
    checks_arr = np.zeros(shape)

    for move in board.legal_moves:
        piece = board.piece_at(move.from_square)
        channel = piece.piece_type - 1
        if piece.color != chess.WHITE:
            channel = 6 + piece.piece_type - 1
        rank, file = divmod(move.to_square, 8)

        moves_arr[channel, rank, file] = 1
        captures_arr[channel, rank, file] = board.is_capture(move)
        checks_arr[channel, rank, file] = board.gives_check(move)

    return moves_arr, captures_arr, checks_arr


def get_move_vector(board, move):
    """Describe ``move`` (to be played on ``board``) as a 28-element vector.

    The first 14 slots are used when a white piece moves and the last 14 when
    a black piece moves; the other half is zero. Within a half the entries are:
    en passant, capture, zeroing, irreversible, kingside castling, queenside
    castling, gives check, then piece-type indicators (king, queen, pawn,
    bishop, knight, rook) and a constant 1.
    """
    origin = move.from_square
    piece_type = board.piece_type_at(origin)
    features = [
        board.is_en_passant(move),
        board.is_capture(move),
        board.is_zeroing(move),
        board.is_irreversible(move),
        board.is_kingside_castling(move),
        board.is_queenside_castling(move),
        board.gives_check(move),
        piece_type == chess.KING,
        piece_type == chess.QUEEN,
        piece_type == chess.PAWN,
        piece_type == chess.BISHOP,
        piece_type == chess.KNIGHT,
        piece_type == chess.ROOK,
        1,
    ]
    move_vec = np.array(features)
    blank = np.zeros_like(move_vec)
    mover_is_white = board.color_at(origin) == chess.WHITE
    halves = [move_vec, blank] if mover_is_white else [blank, move_vec]
    return np.hstack(halves)


def row_to_item(row, num_moves):
    """Turn one puzzle row into the model's input planes (and targets, if any).

    Starting from ``row.FEN``, the first ``num_moves`` moves of ``row.Moves``
    are played one by one. After each move 52 planes are emitted: 12 piece
    planes, 4 castling planes and 3 x 12 legal-move / capture / check planes.
    Puzzles with fewer than ``num_moves`` moves are zero-padded, so the result
    always has ``52 * num_moves`` planes when ``num_moves`` is positive.

    Returns a dict with key ``"position"`` and, when the row has a ``Rating``
    entry, also ``"rating_deviation"`` and ``"rating"``.
    """
    board = chess.Board()
    board.set_fen(row.FEN)
    uci_moves = row.Moves.split()
    # The second FEN field is the side to move in the starting position.
    flip = row.FEN.split()[1] == "w"

    planes = []
    for _, uci in zip(range(num_moves), uci_moves):
        board.push(chess.Move.from_uci(uci))

        pos_planes = get_position_planes(board)
        if flip:
            # Reverses both the plane axis and the rank axis of the piece planes.
            # NOTE(review): castling and legal-move planes are not flipped.
            pos_planes = pos_planes[::-1, ::-1].copy()
        castling_planes = get_castling_planes(board)
        moves_planes, captures_planes, checks_planes = get_legal_moves_planes(board)
        planes.extend([pos_planes, castling_planes, moves_planes, captures_planes, checks_planes])

    missing = num_moves - len(uci_moves)
    if missing > 0:
        # Zero planes stand in for the moves a short puzzle does not have.
        planes.append(np.zeros((missing * 52, 8, 8)))

    item = {"position": np.concatenate(planes, axis=0)}
    if "Rating" in row:
        item.update(rating_deviation=row.RatingDeviation, rating=row.Rating)
    return item

In [6]:
import torch.utils.data as data

# Module-level generator shared by the augmentation helpers (unseeded).
rng = np.random.default_rng()


def random_flip(arr, axis, prob):
    """Flip ``arr`` along ``axis`` with probability ``prob``.

    Returns a flipped copy when the draw succeeds, otherwise ``arr`` itself.
    Exactly one sample is drawn from ``rng`` per call.
    """
    should_flip = rng.random() < prob
    return np.flip(arr, axis).copy() if should_flip else arr


class PositionDataset(data.Dataset):
    """Map-style dataset that encodes puzzle rows lazily via ``row_to_item``."""

    def __init__(self, df, num_moves):
        """Store the puzzle table ``df`` and the number of moves to encode per puzzle."""
        self.num_moves = num_moves
        self.df = df

    def __len__(self):
        """Number of rows in the underlying table."""
        return len(self.df)

    def __getitem__(self, item):
        """Encode the row at positional index ``item``."""
        return row_to_item(self.df.iloc[item], self.num_moves)


In [7]:
import math
from dataclasses import dataclass
from typing import List

import torch
from multimethod import multimethod


@multimethod
def g(phi: float):
    """Glicko-2 weighting ``g(phi) = 1 / sqrt(1 + 3 * phi^2 / pi^2)`` for a scalar deviation."""
    denominator = math.sqrt(1 + 3 * phi ** 2 / math.pi ** 2)
    return 1.0 / denominator


@multimethod
def g(phi: torch.Tensor):
    """Element-wise tensor overload of ``g``."""
    denominator = torch.sqrt(1 + 3 * phi ** 2 / torch.pi ** 2)
    return 1.0 / denominator


@multimethod
def E(mu_1: float, mu_2: float, phi: float):
    """Expected score of player 1 against player 2: ``1 / (1 + exp(-g(phi) * (mu_1 - mu_2)))``."""
    exponent = -g(phi) * (mu_1 - mu_2)
    return 1.0 / (1 + math.exp(exponent))


@multimethod
def E(mu_1: torch.Tensor, mu_2: torch.Tensor, phi: torch.Tensor):
    """Element-wise tensor overload of ``E``."""
    exponent = -g(phi) * (mu_1 - mu_2)
    return 1.0 / (1 + torch.exp(exponent))


# Scale factor between the public rating scale and the internal (mu, phi) scale:
# mu = (rating - 1500) / MULTIPLIER and phi = rating_deviation / MULTIPLIER.
MULTIPLIER = 173.7178
# Initial rating, deviation and volatility given to an item before any results are applied.
DEFAULT_RATING = 1500.0
DEFAULT_DEVIATION = 350.0
DEFAULT_VOLATILITY = 0.09
# System constant used by the volatility update in calculate_new_rating.
TAU = 0.09


@dataclass
class Rating:
    """Rating state that calculate_new_rating consumes and produces."""
    rating: float  # rating on the public scale (converted with MULTIPLIER around 1500)
    rating_deviation: float  # rating deviation (RD) on the public scale
    volatility: float  # volatility sigma; its square is log-transformed in the update


@dataclass
class Result:
    """A single result against one opponent, as fed to calculate_new_rating."""
    opponent_rating: Rating  # the opponent's rating state
    outcome: float  # score compared against the expected score E(...); this notebook passes predicted probabilities


def calculate_new_rating(rating_1: Rating, results: List[Result]) -> Rating:
    """Apply one Glicko-2 rating period to ``rating_1``.

    Args:
        rating_1: Current rating state of the player being updated.
        results: Results of the period, each holding an opponent rating and
            the score obtained against that opponent. May be empty.

    Returns:
        A new ``Rating`` with the updated rating, deviation and volatility.
        ``rating_1`` is not modified. With no results, rating and volatility
        are unchanged and only the deviation grows (Glicko-2 step 6).

    Raises:
        RuntimeError: If the volatility iteration does not converge within
            1000 steps.
    """
    # Step 2: convert to the internal scale.
    mu = (rating_1.rating - 1500.0) / MULTIPLIER
    phi = rating_1.rating_deviation / MULTIPLIER

    if not results:
        # No games played: the sum in step 3 would be empty (division by zero),
        # and Glicko-2 prescribes only the deviation increase of step 6.
        phi_star = math.sqrt(phi ** 2 + rating_1.volatility ** 2)
        return Rating(rating_1.rating, MULTIPLIER * phi_star, rating_1.volatility)

    # Steps 3 and 4 share g(phi_j) and E(mu, mu_j, phi_j), so accumulate both
    # sums in one pass over the results.
    inv_v = 0.0
    delta = 0.0
    for res in results:
        mu_j = (res.opponent_rating.rating - 1500.0) / MULTIPLIER
        phi_j = res.opponent_rating.rating_deviation / MULTIPLIER
        g_j = g(phi_j)
        e_j = E(mu, mu_j, phi_j)
        inv_v += g_j ** 2 * e_j * (1.0 - e_j)
        delta += g_j * (res.outcome - e_j)
    v = 1.0 / inv_v
    delta *= v

    # Step 5: solve f(x) = 0 for x = ln(sigma'^2) with a bracketing
    # regula-falsi iteration.
    eps = 1e-6
    max_iterations = 1000
    alpha = math.log(rating_1.volatility ** 2)

    def f(x: float):
        return math.exp(x) * (delta ** 2 - phi ** 2 - v - math.exp(x)) / (
                2 * (phi ** 2 + v + math.exp(x)) ** 2) - (x - alpha) / TAU ** 2

    A = alpha
    if delta ** 2 > phi ** 2 + v:
        B = math.log(delta ** 2 - phi ** 2 - v)
    else:
        # Walk left in steps of TAU until f becomes non-negative.
        k = 1
        while f(alpha - k * TAU) < 0.0:
            k += 1
        B = alpha - k * TAU
    f_A, f_B = f(A), f(B)
    it = 0
    while abs(B - A) > eps and it < max_iterations:
        C = A + (A - B) * f_A / (f_B - f_A)
        f_C = f(C)
        if f_C * f_B <= 0.0:
            A = B
            f_A = f_B
        else:
            # Halving f_A keeps the bracket from stalling on one side.
            f_A /= 2.0
        B = C
        f_B = f_C
        it += 1
    if it == max_iterations:
        raise RuntimeError(
            "Glicko-2 volatility iteration did not converge within 1000 steps")

    new_sigma = math.exp(A / 2.0)

    # Step 6: pre-rating-period deviation.
    phi_star = math.sqrt(phi ** 2 + new_sigma ** 2)

    # Step 7: updated deviation and rating on the internal scale.
    new_phi = 1.0 / math.sqrt(1.0 / phi_star ** 2 + 1.0 / v)
    new_mu = mu + new_phi ** 2 * delta / v

    # Step 8: convert back to the public scale.
    new_rating = MULTIPLIER * new_mu + 1500.0
    new_rd = MULTIPLIER * new_phi

    return Rating(
        new_rating,
        new_rd,
        new_sigma
    )



In [8]:
from torch.optim.lr_scheduler import LinearLR
from vit_pytorch import SimpleViT
import math
from torchvision.ops import SqueezeExcitation
import torch
import torch.nn as nn
import lightning as L


class ConvBlock(nn.Module):
    """3x3 convolution (padding 1) followed by batch normalisation and LeakyReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(),
        ]
        # Kept as a Sequential named `block` so parameter names stay `block.<i>.*`.
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply conv -> batch norm -> activation; spatial size is preserved."""
        out = self.block(x)
        return out


class ResBlock(nn.Module):
    """Residual block: two ConvBlocks and a squeeze-excitation layer with a skip connection."""

    def __init__(self, in_channels, out_channels, se_channels):
        super().__init__()
        stages = (
            ConvBlock(in_channels, out_channels),
            ConvBlock(out_channels, out_channels),
            SqueezeExcitation(out_channels, se_channels),
        )
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``relu(conv(x) + x)``; the sum requires matching shapes."""
        summed = self.conv(x) + x
        return nn.functional.relu(summed)


class GeM(nn.Module):
    """Generalised-mean pooling over the last two dimensions with a learnable exponent ``p``."""

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        """Pool ``x`` with the module's current ``p`` and ``eps``."""
        return self.gem(x, p=self.p, eps=self.eps)

    def gem(self, x, p=3, eps=1e-6):
        """Clamp ``x`` to at least ``eps``, raise to ``p``, average over H x W, take the ``1/p`` root."""
        powered = x.clamp(min=eps).pow(p)
        window = (x.size(-2), x.size(-1))
        pooled = nn.functional.avg_pool2d(powered, window)
        return pooled.pow(1. / p)

    def __repr__(self):
        p_value = self.p.data.tolist()[0]
        return f"{self.__class__.__name__}(p={p_value:.4f}, eps={self.eps})"


class PuzzleEncoder(nn.Module):
    """Convolutional encoder: input ConvBlock, a tower of ResBlocks, then GeM pooling."""

    def __init__(self, in_channels, num_blocks, num_filters, se_channels):
        super().__init__()
        self.input = nn.Sequential(ConvBlock(in_channels, num_filters))
        tower = [
            ResBlock(num_filters, num_filters, se_channels)
            for _ in range(num_blocks)
        ]
        self.res_tower = nn.Sequential(*tower)
        self.pooling = GeM()

    def forward(self, x):
        """Encode ``x`` and return the pooled features flattened from dim 1."""
        features = self.res_tower(self.input(x))
        return self.pooling(features).flatten(1)


def init_weights(m):
    """Xavier-normal initialise the weight of ``nn.Linear`` modules; leave other modules untouched."""
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_normal_(m.weight)


class LitModel(L.LightningModule):
    """Pairwise puzzle comparison model trained against Glicko expected scores.

    Each batch is split into two halves. Both halves are encoded, the
    embeddings are subtracted and a linear head turns the difference into a
    logit. The target is the expected score ``E(mu_1, mu_2, phi)`` computed
    from the puzzles' ratings and rating deviations.
    """

    def __init__(self, in_channels, num_blocks, num_filters, se_channels):
        super().__init__()
        # `num_blocks` and `se_channels` configured an earlier PuzzleEncoder
        # backbone; the ViT encoder ignores them, but they stay in the
        # signature so existing callers keep working.
        self.encoder = SimpleViT(
            image_size=8,
            patch_size=1,
            num_classes=num_filters,
            channels=in_channels,
            dim=1024,
            depth=6,
            heads=16,
            mlp_dim=2048,
        )
        self.fc = nn.Sequential(
            nn.Linear(num_filters, 1)
        )
        self.loss_fn = nn.BCEWithLogitsLoss()

    def _pairwise_step(self, batch):
        """Run the pairwise forward pass shared by training and validation.

        Returns ``(loss, accuracy, y_1, y_2, rating, rd)`` where ``y_1`` and
        ``y_2`` are the embeddings of the two batch halves and ``rating`` /
        ``rd`` are the float tensors for the whole batch.
        """
        pos = batch["position"].float()
        rd = batch["rating_deviation"].float()
        rating = batch["rating"].float()
        phi = rd / MULTIPLIER
        mu = (rating - 1500.0) / MULTIPLIER
        half = len(pos) // 2

        # A leftover element from an odd-sized batch lands in a third chunk and is dropped.
        pos_1, pos_2, *_ = torch.split(pos, half, dim=0)
        phi_1, phi_2, *_ = torch.split(phi, half, dim=0)
        mu_1, mu_2, *_ = torch.split(mu, half, dim=0)
        target = E(mu_1, mu_2, torch.sqrt(phi_1 ** 2 + phi_2 ** 2))

        y_1, y_2 = self.encoder(pos_1), self.encoder(pos_2)
        logits = self.fc(y_1 - y_2).squeeze(-1)
        loss = self.loss_fn(logits, target)

        # `logits` feed BCEWithLogitsLoss, so the decision boundary is
        # logit 0 (probability 0.5), not logit 0.5.
        accuracy = ((mu_1 >= mu_2) == (logits >= 0.0)).float().mean()
        return loss, accuracy, y_1, y_2, rating, rd

    def training_step(self, batch, batch_idx):
        """Log the pairwise loss and accuracy and return the loss."""
        loss, accuracy, *_ = self._pairwise_step(batch)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", accuracy, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log pairwise metrics plus the squared error of a Glicko estimate.

        The first puzzle of the batch is rated from scratch by comparing it
        with every other puzzle that was encoded, one result at a time.
        """
        loss, accuracy, y_1, y_2, rating, rd = self._pairwise_step(batch)
        self.log("val_loss", loss, prog_bar=True, on_epoch=True)
        self.log("val_acc", accuracy, prog_bar=True, on_epoch=True)

        # Estimate rating for the first element of the batch
        rating_1 = Rating(DEFAULT_RATING, DEFAULT_DEVIATION, DEFAULT_VOLATILITY)
        # Opponent i corresponds to batch element i + 1.
        opponents = torch.cat([y_1[1:, :], y_2], dim=0)
        anchor = y_1[0, :].tile((len(opponents), 1))
        outcomes = torch.sigmoid(self.fc(anchor - opponents).squeeze(-1))
        for i, p in enumerate(outcomes):
            rating_2 = Rating(rating[i + 1].item(), rd[i + 1].item(), DEFAULT_VOLATILITY)
            # calculate_new_rating expects a list of Result objects.
            rating_1 = calculate_new_rating(rating_1, [Result(rating_2, p.item())])
        self.log("val_mse", (rating_1.rating - rating[0].item()) ** 2, prog_bar=True, on_epoch=True)

    def configure_optimizers(self):
        """AdamW with a per-step linear warm-up of the learning rate over 20000 steps."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-5)
        scheduler = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=20000)
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]





In [9]:
# Number of puzzle moves encoded per sample; each contributes 52 input planes (see in_channels).
num_moves = 12
# Restore the trained model; the constructor arguments are supplied explicitly here.
model = LitModel.load_from_checkpoint("models/glicko_sim/epoch=9-step=71410.ckpt", in_channels=52 * num_moves,
                                      num_blocks=20, num_filters=256, se_channels=32)

/Users/andry/PycharmProjects/chess-analysis/venv/lib/python3.9/site-packages/lightning/pytorch/utilities/migration/utils.py:56: The loaded checkpoint was produced with Lightning v2.3.3, which is newer than your current Lightning version: v2.3.0


In [20]:
from more_itertools import chunked
from tqdm.notebook import trange

num_workers = 11
if __name__ == "__main__":
    # Use CUDA when it is available; otherwise run on the CPU instead of
    # failing on the first transfer to a missing device.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    use_cuda = device == "cuda"

    # Reference pool: training puzzles whose rating deviation is below 80.
    train_ds = PositionDataset(train_df[train_df.RatingDeviation < 80.0], num_moves)
    train_loader = data.DataLoader(train_ds, batch_size=3072, shuffle=True, pin_memory=use_cuda,
                                   num_workers=num_workers, persistent_workers=True)
    test_ds = PositionDataset(test_df, num_moves)
    test_loader = data.DataLoader(test_ds, batch_size=1, shuffle=False, pin_memory=use_cuda,
                                  num_workers=num_workers, persistent_workers=True)

    # Put the model on the device and in eval mode BEFORE any forward pass,
    # so the reference embeddings come from the same model state as the test
    # embeddings and the inputs share the model's device.
    model.eval()
    model.to(device)

    # Embed 10 random training batches to serve as rated "opponents".
    train_y = []
    train_rating = []
    train_rd = []
    train_iter = iter(train_loader)
    with torch.no_grad():
        for _ in trange(10):
            train_batch = next(train_iter)
            train_position = train_batch["position"].float().to(device)
            train_rating.append(train_batch["rating"].float())
            train_rd.append(train_batch["rating_deviation"].float())
            train_y.append(model.encoder(train_position))

    train_rating = torch.cat(train_rating, dim=0)
    train_rd = torch.cat(train_rd, dim=0)
    train_y = torch.cat(train_y, dim=0)

    # Plain Python floats avoid a tensor `.item()` call per opponent per puzzle.
    opponent_ratings = train_rating.tolist()
    opponent_rds = train_rd.tolist()

    # Number of opponents grouped into one Glicko rating period.
    minibatch_size = 5

    ratings = []
    rd = []
    for test_batch in tqdm(test_loader):
        position = test_batch["position"].float().to(device)
        # Initialize rating
        rating_1 = Rating(DEFAULT_RATING, DEFAULT_DEVIATION, DEFAULT_VOLATILITY)

        # Mixed precision is only enabled on CUDA, as before.
        with torch.no_grad(), torch.amp.autocast(device_type=device, enabled=use_cuda):
            # Compute test item embedding; broadcasting against train_y
            # gives one difference row per reference puzzle.
            test_y = model.encoder(position)

            # Estimate outcomes
            pred = torch.sigmoid(model.fc(test_y - train_y).squeeze(-1))

        # One device-to-host transfer per test puzzle.
        outcomes = pred.float().cpu().tolist()

        # Estimate Glicko: feed the opponents in periods of `minibatch_size`.
        opponents = zip(opponent_ratings, opponent_rds, outcomes)
        for period in chunked(opponents, minibatch_size):
            results = [
                Result(Rating(opp_rating, opp_rd, DEFAULT_VOLATILITY), outcome)
                for opp_rating, opp_rd, outcome in period]
            rating_1 = calculate_new_rating(rating_1, results)

        rd.append(rating_1.rating_deviation)
        ratings.append(str(round(rating_1.rating)))

    # One rounded rating per line, in test-set order.
    with open("submission_glicko.txt", "w") as f:
        f.write("\n".join(ratings))

  0%|          | 0/2282 [00:00<?, ?it/s]

In [None]:
# Tabulate the estimates collected by the inference cell: `ratings` holds the rounded
# ratings as strings and `rd` the final rating deviations.
ratings_df = pd.DataFrame({"rating": ratings, "rd": rd})
ratings_df.head()