In [1]:
# Download the Lichess puzzle database; -nc (no-clobber) skips the download when the file is already present.
!wget -nc https://database.lichess.org/lichess_db_puzzle.csv.zst

File ‘lichess_db_puzzle.csv.zst’ already there; not retrieving.


In [2]:
!pip install pandas tqdm scikit-learn zstandard chess torchinfo torcheval lightning matplotlib tensorboard multimethod vit-pytorch

[0m

In [3]:
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

# Register tqdm's progress-bar integration with pandas.
tqdm.pandas()

# Load the puzzle table downloaded above (zstd-compressed CSV).
df = pd.read_csv("lichess_db_puzzle.csv.zst")
# Uncomment to iterate faster on a subset of the rows:
# df = df.iloc[:500000]
df.head()

Unnamed: 0,PuzzleId,FEN,Moves,Rating,RatingDeviation,Popularity,NbPlays,Themes,GameUrl,OpeningTags
0,00008,r6k/pp2r2p/4Rp1Q/3p4/8/1N1P2R1/PqP2bPP/7K b - ...,f2g3 e6e7 b2b1 b3c1 b1c1 h6c1,1974,76,94,6337,crushing hangingPiece long middlegame,https://lichess.org/787zsVup/black#47,
1,0000D,5rk1/1p3ppp/pq3b2/8/8/1P1Q1N2/P4PPP/3R2K1 w - ...,d3d6 f8d8 d6d8 f6d8,1473,73,96,27319,advantage endgame short,https://lichess.org/F8M8OS71#52,
2,0008Q,8/4R3/1p2P3/p4r2/P6p/1P3Pk1/4K3/8 w - - 1 64,e7f7 f5e5 e2f1 e5e6,1444,75,90,575,advantage endgame rookEndgame short,https://lichess.org/MQSyb3KW#126,
3,0009B,r2qr1k1/b1p2ppp/pp4n1/P1P1p3/4P1n1/B2P2Pb/3NBP...,b6c5 e2g4 h3g4 d1g4,1112,74,87,569,advantage middlegame short,https://lichess.org/4MWQCxQ6/black#31,Kings_Pawn_Game Kings_Pawn_Game_Leonardis_Vari...
4,000Vc,8/8/4k1p1/2KpP2p/5PP1/8/8/8 w - - 0 53,g4h5 g6h5 f4f5 e6e5 f5f6 e5f6,1556,81,73,94,crushing endgame long pawnEndgame,https://lichess.org/l6AejDMO#104,


In [4]:
# Puzzle themes fed to the model, in the column order used by `themes_df`.
themes = """
    advancedPawn arabianMate attackingF2F7 attraction
    backRankMate bishopEndgame bodenMate capturingDefender
    castling clearance defensiveMove deflection
    discoveredAttack doubleCheck dovetailMate enPassant
    exposedKing fork hangingPiece hookMate interference
    intermezzo kingsideAttack knightEndgame long mate
    mateIn1 mateIn2 mateIn3 mateIn4 mateIn5 oneMove
    pawnEndgame pin promotion queenEndgame queenRookEndgame
    queensideAttack quietMove rookEndgame sacrifice short
    skewer smotheredMate trappedPiece veryLong xRayAttack
""".split()
# One indicator column per theme found in the space-separated `Themes` field,
# restricted to (and ordered like) the selection above.
themes_df = df["Themes"].str.get_dummies(sep=" ").loc[:, themes]
themes_df.head()

Unnamed: 0,advancedPawn,arabianMate,attackingF2F7,attraction,backRankMate,bishopEndgame,bodenMate,capturingDefender,castling,clearance,...,queensideAttack,quietMove,rookEndgame,sacrifice,short,skewer,smotheredMate,trappedPiece,veryLong,xRayAttack
0,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,1,0,0,0,0,0
2,0,0,0,0,0,0,0,0,0,0,...,0,0,1,0,1,0,0,0,0,0
3,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,1,0,0,0,0,0
4,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [5]:
import chess

# Identity matrix whose rows serve as one-hot codes, one row per (colour, piece type) pair.
piece_vectors = np.eye(2 * len(chess.PIECE_TYPES))


def piece_vector(piece):
    """Return the one-hot row of `piece_vectors` that encodes `piece`.

    White pieces use row ``piece.piece_type - 1``; black pieces use the same
    index shifted by 6.
    """
    colour_offset = 0 if piece.color else 6  # white: 0, black: 6
    return piece_vectors[colour_offset + piece.piece_type - 1]


def get_position_planes(board):
    """Encode the piece placement of `board` as a stack of one-hot planes.

    Returns a float array of shape ``(2 * len(chess.PIECE_TYPES), 8, 8)``.
    For a piece on square index ``s`` the column ``[:, s // 8, s % 8]`` holds
    the one-hot vector produced by `piece_vector`; empty squares stay zero.
    """
    planes = np.zeros((2 * len(chess.PIECE_TYPES), 8, 8))
    for square, piece in board.piece_map().items():
        row, col = divmod(square, 8)
        planes[:, row, col] = piece_vector(piece)
    return planes


def get_castling_planes(board):
    """Encode castling rights as four constant 8x8 planes.

    Plane order: white kingside, white queenside, black kingside, black
    queenside. Each plane is filled with 1.0 when the right is available and
    0.0 otherwise. Returns a float array of shape ``(4, 8, 8)``.
    """
    rights = (
        board.has_kingside_castling_rights(chess.WHITE),
        board.has_queenside_castling_rights(chess.WHITE),
        board.has_kingside_castling_rights(chess.BLACK),
        board.has_queenside_castling_rights(chess.BLACK),
    )
    planes = np.zeros((4, 8, 8))
    for channel, available in enumerate(rights):
        planes[channel] = available  # broadcast the flag over the whole board
    return planes


def get_legal_moves_planes(board):
    """Encode the legal moves of `board` as three stacks of target-square planes.

    Each stack has shape ``(12, 8, 8)``: channels 0-5 are white piece types
    (``piece_type - 1``), channels 6-11 the same piece types for black, and
    the spatial index is ``(to_square // 8, to_square % 8)``.

    Returns:
        ``(moves_arr, captures_arr, checks_arr)`` where a cell is 1 when at
        least one legal move of that piece type reaches the square, captures
        there, or gives check from there, respectively.
    """
    moves_arr = np.zeros((12, 8, 8))
    captures_arr = np.zeros((12, 8, 8))
    checks_arr = np.zeros((12, 8, 8))

    for move in board.legal_moves:
        piece = board.piece_at(move.from_square)
        # Same colour convention as `piece_vector`: white -> 0-5, black -> 6-11.
        channel = 6 * int(not piece.color) + piece.piece_type - 1
        cell = (channel, move.to_square // 8, move.to_square % 8)

        moves_arr[cell] = 1
        # Several moves can share a cell (two pieces of one type reaching the
        # same square, or the promotion variants of one pawn move). Accumulate
        # with OR so the result does not depend on move iteration order.
        if board.is_capture(move):
            captures_arr[cell] = 1
        if board.gives_check(move):
            checks_arr[cell] = 1

    return moves_arr, captures_arr, checks_arr


def get_move_vector(board, move):
    """Describe `move` on `board` as a 28-element feature vector.

    The 14 base features are: en passant, capture, zeroing, irreversible,
    kingside castling, queenside castling, gives check, a one-hot of the
    moving piece type in the order king, queen, pawn, bishop, knight, rook,
    and a constant 1. The base vector fills the first half of the result when
    the moving piece is white and the second half otherwise; the other half is
    zero.
    """
    mover_type = board.piece_type_at(move.from_square)
    features = [
        board.is_en_passant(move),
        board.is_capture(move),
        board.is_zeroing(move),
        board.is_irreversible(move),
        board.is_kingside_castling(move),
        board.is_queenside_castling(move),
        board.gives_check(move),
    ]
    type_order = (chess.KING, chess.QUEEN, chess.PAWN, chess.BISHOP, chess.KNIGHT, chess.ROOK)
    features.extend(mover_type == kind for kind in type_order)
    features.append(1)  # constant entry marking the populated half

    move_vec = np.array(features)
    blank = np.zeros_like(move_vec)
    mover_is_white = board.color_at(move.from_square) == chess.WHITE
    halves = [move_vec, blank] if mover_is_white else [blank, move_vec]
    return np.hstack(halves)

In [6]:
import torch.utils.data as data

# Shared random generator for data augmentation (unseeded).
rng = np.random.default_rng()


def random_flip(arr, axis, prob):
    """Return a flipped copy of `arr` along `axis` with probability `prob`.

    When the draw does not trigger a flip, the input array itself is returned
    (no copy).
    """
    should_flip = rng.random() < prob
    if not should_flip:
        return arr
    return np.flip(arr, axis=axis).copy()


class PositionDataset(data.Dataset):
    """Dataset that turns one puzzle row into a stack of board planes.

    For each of the first `num_moves` moves of the puzzle solution, the move
    is played and the resulting position is encoded with
    `get_position_planes` (12 planes per position). Puzzles with fewer moves
    are zero-padded so every sample has ``12 * num_moves`` planes.

    Args:
        df: puzzle table with at least the columns ``FEN``, ``Moves``,
            ``Rating`` and ``RatingDeviation``.
        themes_df: theme indicator table, row-aligned with `df`.
        num_moves: number of positions encoded per puzzle.
    """

    def __init__(self, df, themes_df, num_moves):
        self.df = df
        self.themes_df = themes_df
        self.num_moves = num_moves

    def __len__(self):
        return len(self.df)

    def __getitem__(self, item):
        """Return a dict with keys ``position``, ``themes``, ``rating_deviation`` and ``rating``."""
        row = self.df.iloc[item]  # look the row up once instead of once per field
        board = chess.Board()
        board.set_fen(row.FEN)
        moves = row.Moves.split()
        # Second FEN field is the side to move in the starting position.
        flip = row.FEN.split()[1] == "w"

        position_arr = []
        for uci in moves[:self.num_moves]:
            board.push(chess.Move.from_uci(uci))
            pos_planes = get_position_planes(board)
            if flip:
                # Normalise the point of view: mirror the ranks and swap the
                # white (0-5) and black (6-11) plane groups. Reversing the
                # whole channel axis would also reverse the piece-type order
                # inside each colour, so flipped and unflipped samples would
                # disagree on which channel holds which piece.
                mirrored = np.flip(pos_planes, axis=1)
                pos_planes = np.vstack([mirrored[6:], mirrored[:6]])
            position_arr.append(pos_planes)
            # Optional extra inputs: get_castling_planes(board) and the three
            # outputs of get_legal_moves_planes(board) can be appended here
            # (the model's in_channels must be adjusted accordingly).

        if len(moves) < self.num_moves:
            # pad with zeros
            position_arr.append(np.zeros(((self.num_moves - len(moves)) * 12, 8, 8)))

        position_arr = np.vstack(position_arr)
        return {"position": position_arr,
                "themes": self.themes_df.iloc[item].to_numpy(),
                "rating_deviation": row.RatingDeviation,
                "rating": row.Rating}




In [7]:
import math
from dataclasses import dataclass

import torch
from multimethod import multimethod


@multimethod
def g(phi: float):
    """Scalar overload: return ``1 / sqrt(1 + 3 * phi**2 / pi**2)``."""
    scaled_variance = 3 * phi ** 2 / math.pi ** 2
    return 1.0 / math.sqrt(1 + scaled_variance)


@multimethod
def g(phi: torch.Tensor):
    """Tensor overload: return ``1 / sqrt(1 + 3 * phi**2 / pi**2)`` element-wise."""
    scaled_variance = 3 * phi ** 2 / torch.pi ** 2
    return 1.0 / torch.sqrt(1 + scaled_variance)


@multimethod
def E(mu_1: float, mu_2: float, phi: float):
    """Scalar overload: logistic of ``g(phi) * (mu_1 - mu_2)``."""
    exponent = -g(phi) * (mu_1 - mu_2)
    return 1.0 / (1 + math.exp(exponent))


@multimethod
def E(mu_1: torch.Tensor, mu_2: torch.Tensor, phi: torch.Tensor):
    """Tensor overload: element-wise logistic of ``g(phi) * (mu_1 - mu_2)``."""
    exponent = -g(phi) * (mu_1 - mu_2)
    return 1.0 / (1 + torch.exp(exponent))


# Divisor that converts ratings and rating deviations to the internal mu/phi
# scale (and back) in `calculate_new_rating` and in the model's training targets.
MULTIPLIER = 173.7178
# Starting values for a rating that is estimated from scratch (see the rating
# estimate in `LitModel.validation_step`).
DEFAULT_RATING = 1500.0
DEFAULT_DEVIATION = 350.0
DEFAULT_VOLATILITY = 0.09
# Constant used in the volatility iteration (step 5) of `calculate_new_rating`.
TAU = 0.09


@dataclass
class Rating:
    """Rating state consumed and produced by `calculate_new_rating`."""
    rating: float  # rating value; 1500 is subtracted before dividing by MULTIPLIER
    rating_deviation: float  # deviation on the same scale as `rating`
    volatility: float  # volatility; its log-square seeds the step-5 iteration


def calculate_new_rating(rating_1: Rating, rating_2: Rating, player_1_win: float) -> Rating:
    """Compute player 1's new rating after one game against player 2.

    The step comments (2-8) keep the numbering used throughout this function.

    Args:
        rating_1: current rating of the player being updated.
        rating_2: rating of the opponent.
        player_1_win: score obtained by player 1 in the game.

    Returns:
        A new `Rating` holding the updated rating, rating deviation and
        volatility of player 1.

    Raises:
        RuntimeError: if the volatility iteration of step 5 has not converged
            after 1000 iterations.
    """
    max_iterations = 1000

    # Step 2: convert to the internal scale
    mu_1 = (rating_1.rating - 1500.0) / MULTIPLIER
    phi_1 = rating_1.rating_deviation / MULTIPLIER
    mu_2 = (rating_2.rating - 1500.0) / MULTIPLIER
    phi_2 = rating_2.rating_deviation / MULTIPLIER

    # g and E only depend on the inputs, so evaluate them once.
    g_2 = g(phi_2)
    expected = E(mu_1, mu_2, phi_2)

    # Step 3
    v = 1.0 / (g_2 ** 2 * expected * (1.0 - expected))

    # Step 4
    delta = v * g_2 * (player_1_win - expected)

    # Step 5: find the new volatility as a root of f
    eps = 1e-6
    alpha = math.log(rating_1.volatility ** 2)

    def f(x: float):
        return math.exp(x) * (delta ** 2 - phi_1 ** 2 - v - math.exp(x)) / (
                2 * (phi_1 ** 2 + v + math.exp(x)) ** 2) - (x - alpha) / TAU ** 2

    # Bracket the root between A and B
    A = alpha
    if delta ** 2 > phi_1 ** 2 + v:
        B = math.log(delta ** 2 - phi_1 ** 2 - v)
    else:
        k = 1
        while f(alpha - k * TAU) < 0.0:
            k += 1
        B = alpha - k * TAU
    f_A, f_B = f(A), f(B)
    it = 0
    while abs(B - A) > eps and it < max_iterations:
        C = A + (A - B) * f_A / (f_B - f_A)
        f_C = f(C)
        if f_C * f_B <= 0.0:
            A = B
            f_A = f_B
        else:
            f_A /= 2.0
        B = C
        f_B = f_C
        it += 1
    # Test the tolerance itself rather than the iteration counter, so that
    # converging exactly on the last allowed iteration is not reported as a failure.
    if abs(B - A) > eps:
        raise RuntimeError(
            f"volatility iteration did not converge within {max_iterations} iterations")

    new_sigma = math.exp(A / 2.0)

    # Step 6
    phi_star = math.sqrt(phi_1 ** 2 + new_sigma ** 2)

    # Step 7
    new_phi = 1.0 / math.sqrt(1.0 / phi_star ** 2 + 1.0 / v)
    new_mu = mu_1 + new_phi ** 2 * g_2 * (player_1_win - expected)

    # Step 8: convert back to the original scale
    new_rating = MULTIPLIER * new_mu + 1500.0
    new_rd = MULTIPLIER * new_phi

    return Rating(
        new_rating,
        new_rd,
        new_sigma
    )



In [8]:
# import importlib
# import glicko2
# importlib.reload(glicko2)
# 
# gl = glicko2.Glicko2(tau=TAU, epsilon=1e-6)
# rating_1 = gl.create_rating(DEFAULT_RATING, DEFAULT_DEVIATION, DEFAULT_VOLATILITY)
# rating_2 = gl.create_rating(1600.0, DEFAULT_DEVIATION, DEFAULT_VOLATILITY)
# 
# print(gl.rate(rating_1, [(1.0, rating_2)]))
# 
# print(calculate_new_rating(Rating(DEFAULT_RATING, DEFAULT_DEVIATION, DEFAULT_VOLATILITY),
#                            Rating(1600.0 , DEFAULT_DEVIATION, DEFAULT_VOLATILITY),
#                            True))


In [16]:
from vit_pytorch import SimpleViT
from torch.optim.lr_scheduler import LinearLR
# from vit_pytorch.simple_vit_3d import SimpleViT
import math
from torchvision.ops import SqueezeExcitation
import torch
import torch.nn as nn
import lightning as L


class ConvBlock(nn.Module):
    """3x3 convolution (padding 1) followed by batch normalisation and LeakyReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        out = self.block(x)
        return out


class ResBlock(nn.Module):
    """Two `ConvBlock`s and a squeeze-excitation layer wrapped in an identity skip connection.

    The skip connection adds the input to the branch output, so the input must
    already have `out_channels` channels.
    """

    def __init__(self, in_channels, out_channels, se_channels):
        super().__init__()
        stages = (
            ConvBlock(in_channels, out_channels),
            ConvBlock(out_channels, out_channels),
            SqueezeExcitation(out_channels, se_channels),
        )
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        branch = self.conv(x)
        return nn.functional.relu(branch + x)


class GeM(nn.Module):
    """Generalised-mean pooling over the last two dimensions with a learnable exponent `p`.

    The input is clamped below at `eps`, raised to the power `p`, averaged over
    the full spatial window and raised to ``1 / p``.
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.p = nn.Parameter(p * torch.ones(1))
        self.eps = eps

    def forward(self, x):
        return self.gem(x, p=self.p, eps=self.eps)

    def gem(self, x, p=3, eps=1e-6):
        powered = x.clamp(min=eps).pow(p)
        window = (x.size(-2), x.size(-1))  # pool over the whole spatial extent
        pooled = nn.functional.avg_pool2d(powered, window)
        return pooled.pow(1. / p)

    def __repr__(self):
        exponent = self.p.data.tolist()[0]
        return f"{self.__class__.__name__}(p={exponent:.4f}, eps={self.eps})"


class PuzzleEncoder(nn.Module):
    """Residual convolutional encoder for board planes, conditioned on puzzle themes.

    The theme vector goes through a three-layer MLP and is added to every
    spatial location of the stem output before the residual tower. The tower
    output is reduced to one vector per sample with `GeM` pooling.
    """

    def __init__(self, in_channels, num_blocks, num_filters, se_channels, num_themes):
        super().__init__()

        self.input = nn.Sequential(ConvBlock(in_channels, num_filters))

        # Three (Linear, BatchNorm1d, LeakyReLU) groups; only the first Linear
        # takes `num_themes` inputs.
        theme_layers = []
        width_in = num_themes
        for _ in range(3):
            theme_layers += [
                nn.Linear(width_in, num_filters),
                nn.BatchNorm1d(num_filters),
                nn.LeakyReLU(),
            ]
            width_in = num_filters
        self.themes_encoder = nn.Sequential(*theme_layers)

        tower = [ResBlock(num_filters, num_filters, se_channels) for _ in range(num_blocks)]
        self.res_tower = nn.Sequential(*tower)
        self.pooling = GeM()

    def forward(self, x, z):
        features = self.input(x)
        # Append two singleton axes so the theme embedding broadcasts over the board.
        theme_bias = self.themes_encoder(z)[..., None, None]
        features = self.res_tower(features + theme_bias)
        return self.pooling(features).flatten(1)


def init_weights(m):
    """Re-initialise the weight of an `nn.Linear` module with Xavier-normal values.

    Modules of any other type are left untouched.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_normal_(m.weight)


class LitModel(L.LightningModule):
    """Siamese model trained on pairwise puzzle comparisons.

    Each batch is split into two halves that are paired element-wise. Both
    halves go through the same `PuzzleEncoder`; a linear head maps the
    difference of the two embeddings to a logit. The training target is
    ``E(mu_1, mu_2, sqrt(phi_1**2 + phi_2**2))`` computed from the puzzles'
    ratings and rating deviations.

    Args:
        in_channels: number of input planes per sample.
        depth: number of residual blocks in the encoder.
        dim: encoder width (also the embedding size); the squeeze-excitation
            bottleneck uses ``dim // 4`` channels.
        num_themes: length of the theme indicator vector.
    """

    def __init__(self, in_channels, depth, dim, num_themes):
        super().__init__()
        self.encoder = PuzzleEncoder(in_channels=in_channels,
                                     num_blocks=depth,
                                     num_filters=dim,
                                     se_channels=dim // 4,
                                     num_themes=num_themes)
        # self.encoder = SimpleViT(
        #     image_size=8,
        #     # image_patch_size=1,
        #     # frames=12,
        #     # frame_patch_size=1,
        #     patch_size=1,
        #     num_classes=dim,
        #     channels=in_channels,
        #     dim=dim,
        #     depth=depth,
        #     heads=16,
        #     mlp_dim=2 * dim,
        # )
        self.fc = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(dim, 1)
        )
        self.loss_fn = nn.BCEWithLogitsLoss()

    def on_fit_start(self):
        # Keep a handle on the SWA callback when the trainer was given one.
        # NOTE: `StochasticWeightAveraging` is imported in the training cell,
        # which must have run before `trainer.fit` is called.
        for cb in self.trainer.callbacks:
            if isinstance(cb, StochasticWeightAveraging):
                self.swa = cb

    def _pairwise_step(self, batch):
        """Run the shared forward pass on one batch.

        Returns ``None`` when the batch holds fewer than two samples (no pair
        can be formed), otherwise the tuple
        ``(loss, accuracy, y_1, y_2, rating, rd)`` where ``y_1``/``y_2`` are
        the embeddings of the two halves and ``rating``/``rd`` are the
        full-batch float tensors.
        """
        pos = batch["position"].float()
        themes = batch["themes"].float()
        rd = batch["rating_deviation"].float()
        rating = batch["rating"].float()

        half = len(pos) // 2
        if half == 0:
            # torch.split rejects a chunk size of 0 on a non-empty dimension.
            return None

        phi = rd / MULTIPLIER
        mu = (rating - 1500.0) / MULTIPLIER

        # With an odd batch size torch.split yields a third, single-element
        # chunk; it is discarded by the starred target.
        pos_1, pos_2, *_ = torch.split(pos, half, dim=0)
        themes_1, themes_2, *_ = torch.split(themes, half, dim=0)
        phi_1, phi_2, *_ = torch.split(phi, half, dim=0)
        mu_1, mu_2, *_ = torch.split(mu, half, dim=0)

        target = E(mu_1, mu_2, torch.sqrt(phi_1 ** 2 + phi_2 ** 2))

        y_1, y_2 = self.encoder(pos_1, themes_1), self.encoder(pos_2, themes_2)
        pred = self.fc(y_1 - y_2).squeeze(-1)

        loss = self.loss_fn(pred, target)
        # `pred` holds logits (the loss is BCEWithLogitsLoss), so the decision
        # boundary is 0, i.e. sigmoid(pred) == 0.5.
        accuracy = ((mu_1 >= mu_2) == (pred >= 0)).float().mean()
        return loss, accuracy, y_1, y_2, rating, rd

    def training_step(self, batch, batch_idx):
        result = self._pairwise_step(batch)
        if result is None:
            return None  # Lightning skips the batch
        loss, accuracy, *_ = result
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", accuracy, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        result = self._pairwise_step(batch)
        if result is None:
            return
        loss, accuracy, y_1, y_2, rating, rd = result
        self.log("val_loss", loss, prog_bar=True, on_epoch=True)
        self.log("val_acc", accuracy, prog_bar=True, on_epoch=True)

        # Estimate the rating of the first element of the batch by "playing"
        # it against every other paired element, using the predicted win
        # probability as the game score.
        estimate = Rating(DEFAULT_RATING, DEFAULT_DEVIATION, DEFAULT_VOLATILITY)
        opponents = torch.cat([y_1[1:, :], y_2], dim=0)
        anchor = y_1[0, :].tile((len(opponents), 1))
        win_prob = torch.sigmoid(self.fc(anchor - opponents).squeeze(-1))

        # One bulk transfer to Python floats instead of one `.item()` per element.
        win_probs = win_prob.tolist()
        ratings = rating.tolist()
        deviations = rd.tolist()
        for i, p in enumerate(win_probs):
            opponent = Rating(ratings[i + 1], deviations[i + 1], DEFAULT_VOLATILITY)
            estimate = calculate_new_rating(estimate, opponent, p)
        self.log("val_mse", (estimate.rating - ratings[0]) ** 2, prog_bar=True, on_epoch=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4)
        # Linear learning-rate warm-up over the first 10000 optimizer steps.
        scheduler = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=10000)
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]





In [10]:
from sklearn.model_selection import train_test_split

num_moves = 12  # positions encoded per puzzle (12 planes each)
# Fixed seed so the train/validation partition is identical across kernel
# restarts; otherwise validation metrics of different runs are not comparable.
SPLIT_SEED = 42
train_df, val_df, train_themes, val_themes = train_test_split(
    df, themes_df, test_size=0.1, random_state=SPLIT_SEED)
train_ds = PositionDataset(train_df, train_themes, num_moves)
val_ds = PositionDataset(val_df, val_themes, num_moves)

In [None]:
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor, StochasticWeightAveraging
from lightning import Trainer

use_lightning = True

num_workers = 31
num_epochs = 20
if __name__ == "__main__":
    # Callbacks: keep a checkpoint for every epoch and log the learning rate at every step.
    checkpoint_callback = ModelCheckpoint(every_n_epochs=1, save_top_k=-1)
    lr_monitor_callback = LearningRateMonitor(logging_interval="step")
    # swa_callback = StochasticWeightAveraging(swa_lrs=1e-5, swa_epoch_start=10, annealing_epochs=1, device=None)

    # Data loaders: shuffled training batches of 512, ordered validation batches of 1024.
    train_loader = data.DataLoader(
        train_ds,
        batch_size=512,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )
    val_loader = data.DataLoader(
        val_ds,
        batch_size=1024,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    # Each position contributes 12 planes, hence 12 * num_moves input channels.
    model = LitModel(
        in_channels=12 * num_moves,
        depth=5,
        dim=512,
        num_themes=len(themes_df.columns),
    )
    trainer = Trainer(
        max_epochs=num_epochs,
        accelerator="gpu",
        precision="16-mixed",
        callbacks=[checkpoint_callback, lr_monitor_callback],
    )
    trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)
    # trainer.save_checkpoint("swa.ckpt")

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type              | Params | Mode 
------------------------------------------------------
0 | encoder | PuzzleEncoder     | 24.9 M | train
1 | fc      | Sequential        | 513    | train
2 | loss_fn | BCEWithLogitsLoss | 0      | train
------------------------------------------------------
24.9 M    Trainable params
0         Non-trainable params
24.9 M    Total params
99.730    Total estimated model params size (MB)
103       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]