In [1]:
#dependencies
!pip install python-chess
import numpy as np
from chess.pgn import Game
from chess import Board
from typing import List
import torch
import torch.nn as nn
from torch.optim import SGD
from torch.utils.data import DataLoader, Dataset
import chess.pgn
import numpy as np
from tqdm import tqdm

Collecting python-chess
  Downloading python_chess-1.999-py3-none-any.whl.metadata (776 bytes)
Collecting chess<2,>=1 (from python-chess)
  Downloading chess-1.11.2.tar.gz (6.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.1/6.1 MB[0m [31m59.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Downloading python_chess-1.999-py3-none-any.whl (1.4 kB)
Building wheels for collected packages: chess
  Building wheel for chess (setup.py) ... [?25l[?25hdone
  Created wheel for chess: filename=chess-1.11.2-py3-none-any.whl size=147775 sha256=57073f51ed9423bcb345e63007e8a0fbcbed298a1cf7a88990b3bba3df6fceae
  Stored in directory: /root/.cache/pip/wheels/fb/5d/5c/59a62d8a695285e59ec9c1f66add6f8a9ac4152499a2be0113
Successfully built chess
Installing collected packages: chess, python-chess
Successfully installed chess-1.11.2 python-chess-1.999


In [2]:
#CNN

class ChessCNN(nn.Module):
    #model architecture: input -> conv2d -> relu -> conv2d -> relu -> conv2d -> relu -> flatten -> linear -> relu -> linear
    #input is an 8 x 8 matrix (representing a chess board) with 13 channels (12 for each unique piece and 1 for legal moves)
    #num_classes is the total number of unique moves in the dataset

    def __init__(self, num_classes):
        super().__init__()

        self.conv1 = nn.Conv2d(13,64,3,stride=1,padding=1)
        self.conv2 = nn.Conv2d(64,64,3,stride=1,padding=1)
        self.conv3 = nn.Conv2d(64,128,3,stride=1,padding=1)

        self.fc1 = nn.Linear(8*8*128,512)
        self.fc2 = nn.Linear(512,num_classes)

        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))

        x = self.flatten(x)

        x = self.relu(self.fc1(x))
        x = self.fc2(x)

        return x

In [3]:
#auxiliary functions

#returns a tensor representation of a chess board. Tensor is of shape (13,8,8)
#requires: board is of type Board
def board_to_tensor(board: Board):
    tensor = np.zeros((13,8,8))
    piece_map = board.piece_map()
    for square, piece in piece_map.items():
        row,col = divmod(square,8)
        piece_type = piece.piece_type - 1
        piece_color = 0 if piece.color else 6
        tensor[piece_type + piece_color, row, col] = 1

    legal_moves = board.legal_moves
    for move in legal_moves:
        to_square = move.to_square
        row_to, col_to = divmod(to_square,8)
        tensor[12,row_to,col_to] = 1

    return tensor

#returns an np.array of board tensors and an np.array of labels, where the board tensors are (13,8,8) and the labels are uci formatted strings.
#label y_i is the move that was played in position X_i
#requires: games is of type List[Game]
def games_to_input(games: List[Game]):
    X = []
    y = []
    for game in games:
        board = game.board()
        for move in game.mainline_moves():
            X.append(board_to_tensor(board))
            y.append(move.uci())
            board.push(move)
    return np.array(X, dtype=np.float32), np.array(y)

#returns an np.array of moves encoded as ints, a dict mapping moves to ints, and a dict mapping ints to moves.
#requires: moves is a list of uci formatted strings
def encode_moves(moves):
    unique_moves = list(set(moves))
    move_to_int = {move: int for int, move in enumerate(unique_moves)}
    int_to_move = {int: move for int, move in enumerate(unique_moves)}
    moves = [move_to_int[move] for move in moves]
    return np.array(moves, dtype=np.float32), move_to_int, int_to_move

In [4]:
#custom dataset

#X are the board tensors and y are the labels
class ChessDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

    def __len__(self):
        return len(self.X)

In [6]:
#training loop

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device: ", device)

pgn = open('./lichess_elite_2020-08.pgn')

print("processing games...")
games = []
i = 0
while True and i<=1000:
    game = chess.pgn.read_game(pgn)
    if game is None:
        break
    else:
        games.append(game)
    i += 1
print("games processed")

print("converting games to input...")
X, y = games_to_input(games)
y, moves_to_int, int_to_moves = encode_moves(y)
print("games converted")

X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.long)

print(f"number of inputs = {len(X)}")
num_classes = len(moves_to_int)
print(f"num_classes = {num_classes}")

dataset = ChessDataset(X,y)
model = ChessCNN(num_classes=num_classes).to(device)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

loss_fn = nn.CrossEntropyLoss()

optimizer = SGD(model.parameters(), lr=0.1)

for epoch in range(100):

    model.train()

    total_loss = 0

    for input, label in tqdm(dataloader):
        input = input.to(device)
        label = label.to(device)
        output = model(input)
        loss = loss_fn(output, label)
        loss.backward()
        total_loss += float(loss)
        optimizer.step()
        optimizer.zero_grad()

    print(f"loss for epoch{epoch} = {total_loss}")

import os
os.makedirs("./checkpoints", exist_ok=True)

torch.save(model.state_dict(), "./checkpoints/test_path.pth")

ERROR:chess.pgn:illegal san: 'Bd3' in rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1 while parsing <Game at 0x7a47e0458390 ('?' vs. '?', '????.??.??' at '?')>


Using device:  cuda
processing games...
games processed
converting games to input...
games converted
number of inputs = 84950
num_classes = 1807


100%|██████████| 2655/2655 [00:07<00:00, 357.64it/s]


loss for epoch0 = 16402.388924121857


100%|██████████| 2655/2655 [00:06<00:00, 406.10it/s]


loss for epoch1 = 14150.058712005615


100%|██████████| 2655/2655 [00:06<00:00, 423.92it/s]


loss for epoch2 = 11940.820191860199


100%|██████████| 2655/2655 [00:06<00:00, 438.30it/s]


loss for epoch3 = 10004.203803300858


100%|██████████| 2655/2655 [00:06<00:00, 423.97it/s]


loss for epoch4 = 8635.778760194778


100%|██████████| 2655/2655 [00:06<00:00, 436.88it/s]


loss for epoch5 = 7475.585998177528


100%|██████████| 2655/2655 [00:06<00:00, 427.53it/s]


loss for epoch6 = 6350.445924639702


100%|██████████| 2655/2655 [00:06<00:00, 439.64it/s]


loss for epoch7 = 5179.12811934948


100%|██████████| 2655/2655 [00:06<00:00, 420.93it/s]


loss for epoch8 = 4007.3418013453484


100%|██████████| 2655/2655 [00:06<00:00, 402.51it/s]


loss for epoch9 = 2998.798249512911


100%|██████████| 2655/2655 [00:06<00:00, 417.29it/s]


loss for epoch10 = 2254.850557282567


100%|██████████| 2655/2655 [00:05<00:00, 457.01it/s]


loss for epoch11 = 1759.9377104341984


100%|██████████| 2655/2655 [00:06<00:00, 416.56it/s]


loss for epoch12 = 1392.6569809019566


100%|██████████| 2655/2655 [00:05<00:00, 450.15it/s]


loss for epoch13 = 1170.5935730859637


100%|██████████| 2655/2655 [00:06<00:00, 414.59it/s]


loss for epoch14 = 952.7079080250114


100%|██████████| 2655/2655 [00:05<00:00, 449.73it/s]


loss for epoch15 = 804.6554706813768


100%|██████████| 2655/2655 [00:06<00:00, 408.57it/s]


loss for epoch16 = 690.6676504435018


100%|██████████| 2655/2655 [00:06<00:00, 415.32it/s]


loss for epoch17 = 599.561608534772


100%|██████████| 2655/2655 [00:06<00:00, 383.47it/s]


loss for epoch18 = 531.8639446767047


100%|██████████| 2655/2655 [00:06<00:00, 436.98it/s]


loss for epoch19 = 503.8573249850888


100%|██████████| 2655/2655 [00:06<00:00, 415.76it/s]


loss for epoch20 = 473.25766187463887


100%|██████████| 2655/2655 [00:05<00:00, 454.43it/s]


loss for epoch21 = 447.49956290749833


100%|██████████| 2655/2655 [00:06<00:00, 411.98it/s]


loss for epoch22 = 429.45703486935236


100%|██████████| 2655/2655 [00:05<00:00, 452.54it/s]


loss for epoch23 = 415.18000579893123


100%|██████████| 2655/2655 [00:06<00:00, 409.54it/s]


loss for epoch24 = 406.55841566761956


100%|██████████| 2655/2655 [00:05<00:00, 456.43it/s]


loss for epoch25 = 396.6004428579472


100%|██████████| 2655/2655 [00:06<00:00, 411.46it/s]


loss for epoch26 = 391.9941503284499


100%|██████████| 2655/2655 [00:05<00:00, 455.08it/s]


loss for epoch27 = 382.1396534984233


100%|██████████| 2655/2655 [00:06<00:00, 413.68it/s]


loss for epoch28 = 378.1771206657868


100%|██████████| 2655/2655 [00:05<00:00, 454.31it/s]


loss for epoch29 = 373.81534895440564


100%|██████████| 2655/2655 [00:06<00:00, 416.62it/s]


loss for epoch30 = 370.3868843583623


100%|██████████| 2655/2655 [00:05<00:00, 456.63it/s]


loss for epoch31 = 365.9400028967066


100%|██████████| 2655/2655 [00:06<00:00, 416.97it/s]


loss for epoch32 = 356.95607827056665


100%|██████████| 2655/2655 [00:05<00:00, 453.78it/s]


loss for epoch33 = 355.5370217623422


100%|██████████| 2655/2655 [00:06<00:00, 413.27it/s]


loss for epoch34 = 351.30205466679763


100%|██████████| 2655/2655 [00:05<00:00, 460.31it/s]


loss for epoch35 = 349.64136759447865


100%|██████████| 2655/2655 [00:06<00:00, 381.04it/s]


loss for epoch36 = 347.54339981090743


100%|██████████| 2655/2655 [00:05<00:00, 456.45it/s]


loss for epoch37 = 342.6951189621468


100%|██████████| 2655/2655 [00:06<00:00, 417.09it/s]


loss for epoch38 = 340.6568666488165


100%|██████████| 2655/2655 [00:05<00:00, 453.88it/s]


loss for epoch39 = 339.64984165050555


100%|██████████| 2655/2655 [00:06<00:00, 416.93it/s]


loss for epoch40 = 336.73492827406153


100%|██████████| 2655/2655 [00:05<00:00, 457.04it/s]


loss for epoch41 = 334.8938437856268


100%|██████████| 2655/2655 [00:06<00:00, 411.97it/s]


loss for epoch42 = 334.71635723958025


100%|██████████| 2655/2655 [00:05<00:00, 455.30it/s]


loss for epoch43 = 331.2479234606144


100%|██████████| 2655/2655 [00:06<00:00, 414.72it/s]


loss for epoch44 = 329.3009665723657


100%|██████████| 2655/2655 [00:05<00:00, 453.80it/s]


loss for epoch45 = 327.27358552027727


100%|██████████| 2655/2655 [00:06<00:00, 418.23it/s]


loss for epoch46 = 325.697005637805


100%|██████████| 2655/2655 [00:05<00:00, 454.92it/s]


loss for epoch47 = 324.3566900950391


100%|██████████| 2655/2655 [00:06<00:00, 416.37it/s]


loss for epoch48 = 321.7749005156802


100%|██████████| 2655/2655 [00:05<00:00, 449.39it/s]


loss for epoch49 = 320.82330007787095


100%|██████████| 2655/2655 [00:06<00:00, 418.07it/s]


loss for epoch50 = 319.8807694806019


100%|██████████| 2655/2655 [00:05<00:00, 447.04it/s]


loss for epoch51 = 317.5828817267902


100%|██████████| 2655/2655 [00:06<00:00, 403.16it/s]


loss for epoch52 = 318.3342948589125


100%|██████████| 2655/2655 [00:05<00:00, 449.08it/s]


loss for epoch53 = 315.4753844187362


100%|██████████| 2655/2655 [00:06<00:00, 422.26it/s]


loss for epoch54 = 314.1027706951136


100%|██████████| 2655/2655 [00:06<00:00, 441.25it/s]


loss for epoch55 = 313.83242458169116


100%|██████████| 2655/2655 [00:06<00:00, 421.88it/s]


loss for epoch56 = 313.3838455755613


100%|██████████| 2655/2655 [00:05<00:00, 443.89it/s]


loss for epoch57 = 312.90529434371274


100%|██████████| 2655/2655 [00:06<00:00, 421.90it/s]


loss for epoch58 = 311.4358921691892


100%|██████████| 2655/2655 [00:06<00:00, 439.82it/s]


loss for epoch59 = 309.39696727256523


100%|██████████| 2655/2655 [00:06<00:00, 419.92it/s]


loss for epoch60 = 309.01000576728256


100%|██████████| 2655/2655 [00:05<00:00, 445.66it/s]


loss for epoch61 = 308.53501164107


100%|██████████| 2655/2655 [00:06<00:00, 418.82it/s]


loss for epoch62 = 308.5493279672228


100%|██████████| 2655/2655 [00:06<00:00, 398.13it/s]


loss for epoch63 = 306.164573430171


100%|██████████| 2655/2655 [00:05<00:00, 454.67it/s]


loss for epoch64 = 306.0346781363478


100%|██████████| 2655/2655 [00:06<00:00, 411.66it/s]


loss for epoch65 = 306.00792668008944


100%|██████████| 2655/2655 [00:05<00:00, 454.80it/s]


loss for epoch66 = 305.1084704474197


100%|██████████| 2655/2655 [00:06<00:00, 404.22it/s]


loss for epoch67 = 304.76280500827124


100%|██████████| 2655/2655 [00:05<00:00, 455.12it/s]


loss for epoch68 = 302.7710373760201


100%|██████████| 2655/2655 [00:06<00:00, 416.29it/s]


loss for epoch69 = 303.82471421855735


100%|██████████| 2655/2655 [00:05<00:00, 457.23it/s]


loss for epoch70 = 302.3456906584033


100%|██████████| 2655/2655 [00:06<00:00, 415.30it/s]


loss for epoch71 = 300.9110079087259


100%|██████████| 2655/2655 [00:05<00:00, 452.80it/s]


loss for epoch72 = 300.6056458951207


100%|██████████| 2655/2655 [00:06<00:00, 411.92it/s]


loss for epoch73 = 299.0067733776232


100%|██████████| 2655/2655 [00:05<00:00, 455.28it/s]


loss for epoch74 = 300.07615393423475


100%|██████████| 2655/2655 [00:06<00:00, 415.89it/s]


loss for epoch75 = 299.6291577446682


100%|██████████| 2655/2655 [00:05<00:00, 455.53it/s]


loss for epoch76 = 298.3036810992053


100%|██████████| 2655/2655 [00:06<00:00, 408.43it/s]


loss for epoch77 = 298.0896858157648


100%|██████████| 2655/2655 [00:05<00:00, 451.38it/s]


loss for epoch78 = 297.9754398572841


100%|██████████| 2655/2655 [00:06<00:00, 413.53it/s]


loss for epoch79 = 296.50921417502104


100%|██████████| 2655/2655 [00:05<00:00, 455.70it/s]


loss for epoch80 = 296.8655061541649


100%|██████████| 2655/2655 [00:06<00:00, 417.47it/s]


loss for epoch81 = 295.91468519097543


100%|██████████| 2655/2655 [00:05<00:00, 454.59it/s]


loss for epoch82 = 296.5921469157911


100%|██████████| 2655/2655 [00:06<00:00, 410.83it/s]


loss for epoch83 = 295.6371503476694


100%|██████████| 2655/2655 [00:05<00:00, 456.24it/s]


loss for epoch84 = 295.21860988572007


100%|██████████| 2655/2655 [00:06<00:00, 416.89it/s]


loss for epoch85 = 295.3116123738291


100%|██████████| 2655/2655 [00:05<00:00, 457.45it/s]


loss for epoch86 = 293.8260864512704


100%|██████████| 2655/2655 [00:06<00:00, 414.68it/s]


loss for epoch87 = 294.1726795329596


100%|██████████| 2655/2655 [00:05<00:00, 461.77it/s]


loss for epoch88 = 293.4229795690335


100%|██████████| 2655/2655 [00:06<00:00, 379.89it/s]


loss for epoch89 = 292.9248963568243


100%|██████████| 2655/2655 [00:05<00:00, 453.66it/s]


loss for epoch90 = 292.9094701014692


100%|██████████| 2655/2655 [00:06<00:00, 412.27it/s]


loss for epoch91 = 292.1214895732264


100%|██████████| 2655/2655 [00:05<00:00, 453.49it/s]


loss for epoch92 = 292.0034111015266


100%|██████████| 2655/2655 [00:06<00:00, 411.09it/s]


loss for epoch93 = 291.33328238835384


100%|██████████| 2655/2655 [00:05<00:00, 451.83it/s]


loss for epoch94 = 291.9569262377918


100%|██████████| 2655/2655 [00:06<00:00, 415.14it/s]


loss for epoch95 = 291.30208884342574


100%|██████████| 2655/2655 [00:05<00:00, 458.57it/s]


loss for epoch96 = 290.5754368316848


100%|██████████| 2655/2655 [00:06<00:00, 413.41it/s]


loss for epoch97 = 290.1775126620778


100%|██████████| 2655/2655 [00:05<00:00, 457.57it/s]


loss for epoch98 = 289.77018970428617


100%|██████████| 2655/2655 [00:06<00:00, 410.20it/s]


loss for epoch99 = 290.1898272160179
