In [None]:
#dependencies
import numpy as np
from chess.pgn import Game
from chess import Board
from typing import List
import torch
import torch.nn as nn
from torch.optim import SGD
import funcs
from cnn_model import ChessCNN
from torch.utils.data import DataLoader, Dataset
import chess.pgn
import numpy as np
from dataset import ChessDataset
from tqdm import tqdm

In [None]:
#CNN

class ChessCNN(nn.Module):
    """Convolutional classifier that scores candidate moves for a board position.

    Architecture:
        input -> conv2d -> relu -> conv2d -> relu -> conv2d -> relu
              -> flatten -> linear -> relu -> linear

    The input is an 8 x 8 matrix (a chess board) with 13 channels: 12 for
    each unique piece and 1 for legal moves. ``num_classes`` is the total
    number of unique moves in the dataset; the network returns one raw score
    (logit) per class, with no softmax applied.
    """

    def __init__(self, num_classes):
        super().__init__()

        # 3x3 kernels with stride 1 and padding 1 leave the 8 x 8 board size
        # untouched, which is what the 8*8*128 input width of fc1 relies on.
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(in_channels=13, out_channels=64, **conv_kwargs)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, **conv_kwargs)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, **conv_kwargs)

        # classifier head
        self.fc1 = nn.Linear(in_features=8 * 8 * 128, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=num_classes)

        # parameter-free layers shared by every stage
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()

    def forward(self, x):
        """Map a batch of (13, 8, 8) board tensors to (batch, num_classes) scores."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.relu(conv(x))

        hidden = self.relu(self.fc1(self.flatten(x)))
        return self.fc2(hidden)

In [None]:
#auxiliary functions

def board_to_tensor(board: Board):
    """Encode a chess position as a NumPy array of shape (13, 8, 8).

    Channel layout:
        0-5   pieces whose ``color`` flag is truthy (white in python-chess),
              indexed by ``piece_type - 1``
        6-11  pieces of the other colour, same ordering
        12    destination squares of every currently legal move

    A square index ``s`` is placed at row ``s // 8`` and column ``s % 8``.

    Requires: ``board`` is of type ``Board``.
    Returns: the array produced by ``np.zeros`` (default float dtype) with
    ones written at the occupied / reachable squares.
    """
    planes = np.zeros((13, 8, 8))

    # one-hot piece planes
    for sq, pc in board.piece_map().items():
        colour_offset = 0 if pc.color else 6
        channel = (pc.piece_type - 1) + colour_offset
        planes[channel, sq // 8, sq % 8] = 1

    # last plane: every square some legal move lands on
    for mv in board.legal_moves:
        target = mv.to_square
        planes[12, target // 8, target % 8] = 1

    return planes

def games_to_input(games: List[Game]):
    """Turn a list of games into (positions, moves) training arrays.

    Every mainline move of every game contributes one sample: the board
    tensor (13, 8, 8) of the position *before* the move, and the move itself
    as a uci formatted string. Label y_i is the move that was played in
    position X_i.

    Requires: ``games`` is of type ``List[Game]``.
    Returns: ``(X, y)`` where ``X`` is a float32 np.array of board tensors and
    ``y`` is an np.array of uci strings.
    """

    def _replay(game):
        # Walk the mainline, emitting the position first and only then
        # advancing the board, so each tensor matches the move played from it.
        position = game.board()
        for played in game.mainline_moves():
            yield board_to_tensor(position), played.uci()
            position.push(played)

    samples = [pair for game in games for pair in _replay(game)]
    boards = [tensor for tensor, _ in samples]
    labels = [uci for _, uci in samples]
    return np.array(boards, dtype=np.float32), np.array(labels)

def encode_moves(moves):
    """Encode uci move strings as class indices.

    Requires: ``moves`` is a list (or array) of uci formatted strings.

    Returns a 3-tuple:
        encoded      np.array with one index per entry of ``moves``. The dtype
                     is float32 (kept for backward compatibility); cast to an
                     integer type before using the values as class labels.
        move_to_int  dict mapping each unique move to its index
        int_to_move  dict mapping each index back to its move

    The unique moves are sorted before indices are assigned. Iterating a plain
    ``set`` of strings yields an order that can change from one interpreter
    run to the next (string hashing is randomised), which would make a model
    trained in one session impossible to decode in another.
    """
    unique_moves = sorted(set(moves))
    # `idx` rather than `int`, so the builtin is not shadowed
    move_to_int = {move: idx for idx, move in enumerate(unique_moves)}
    int_to_move = dict(enumerate(unique_moves))
    encoded = [move_to_int[move] for move in moves]
    return np.array(encoded, dtype=np.float32), move_to_int, int_to_move

In [None]:
#custom dataset

class ChessDataset(Dataset):
    """Map-style dataset that pairs each board tensor with its label.

    ``X`` are the board tensors and ``y`` are the labels; sample ``i`` is the
    tuple ``(X[i], y[i])``.
    """

    def __init__(self, X, y):
        self.X, self.y = X, y

    def __len__(self):
        # the number of samples is taken from the inputs
        return len(self.X)

    def __getitem__(self, idx):
        board, label = self.X[idx], self.y[idx]
        return board, label

In [None]:
#training loop

# End-to-end training script:
#   1. read up to MAX_GAMES games from the PGN file
#   2. convert them to board tensors and integer move labels
#   3. train ChessCNN with SGD + cross-entropy
#   4. save the trained weights to ./checkpoints
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device: ", device)

# Upper bound on the number of games read from the PGN file.
MAX_GAMES = 1000

print("processing games...")
games = []
# The context manager closes the PGN handle even if parsing raises.
with open('./data/lichess_elite_2020-08.pgn') as pgn:
    while len(games) < MAX_GAMES:
        game = chess.pgn.read_game(pgn)
        if game is None:
            # no more games in the file
            break
        games.append(game)
print("games processed")

print("converting games to input...")
X, y = funcs.games_to_input(games)
# Star-unpacking accepts both a 2-tuple (encoded, move_to_int) and a 3-tuple
# that additionally carries the int_to_move mapping, so this line does not
# fail with "too many values to unpack" when the helper returns all three.
y, moves_to_int, *_ = funcs.encode_moves(y)
print("games converted")

X = torch.tensor(X, dtype=torch.float32)
# labels are class indices for CrossEntropyLoss, hence an integer tensor
y = torch.tensor(y, dtype=torch.long)

print(f"number of inputs = {len(X)}")
num_classes = len(moves_to_int)
print(f"num_classes = {num_classes}")

dataset = ChessDataset(X,y)
model = ChessCNN(num_classes=num_classes).to(device)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

loss_fn = nn.CrossEntropyLoss()

optimizer = SGD(model.parameters(), lr=0.1)

for epoch in range(100):

    model.train()

    # sum of the per-batch losses over the epoch (not an average)
    total_loss = 0

    # `inputs` / `labels` instead of `input` / `label`: avoids shadowing the
    # builtin `input`
    for inputs, labels in tqdm(dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # clear gradients left over from the previous step before computing new ones
        optimizer.zero_grad()
        output = model(inputs)
        loss = loss_fn(output, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"loss for epoch{epoch} = {total_loss}")

os.makedirs("./checkpoints", exist_ok=True)

# Only the model weights are stored; the move <-> index mapping is not part of
# this checkpoint.
torch.save(model.state_dict(), "./checkpoints/test_path.pth")