In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
from torch.utils.data import DataLoader, TensorDataset
from data_processing import generate_dataset_from_pgn, label_to_move_table, fen_to_board
import chess
import random

In [6]:


os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

dataset = generate_dataset_from_pgn("masaurus101-white.pgn", 100)
train_to_test_ratio = 0.8

train_size = int(len(dataset) * train_to_test_ratio)
test_size = len(dataset) - train_size

# split the dataset
train_data = dataset[:train_size]
test_data = dataset[train_size:]

# convert to tensors (simpler now since labels are already integers!)
X_train = torch.stack([board for board, move, winner in train_data])  # (N, 8, 8, 12)
t_train = torch.tensor([(move, winner) for board, move, winner in train_data])  # (N, 2)

X_test = torch.stack([board for board, move, winner in test_data])
t_test = torch.tensor([(move, winner) for board, move, winner in test_data])

# create DataLoaders
batch_size = 32
train_dataset = TensorDataset(X_train, t_train)
test_dataset = TensorDataset(X_test, t_test)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


Total games: 1
Total games: 2
Total games: 3
Total games: 4
Total games: 5
Total games: 6
Total games: 7
Total games: 8
Total games: 9
Total games: 10
Total games: 11
Total games: 12
Total games: 13
Total games: 14
Total games: 15
Total games: 16
Total games: 17
Total games: 18
Total games: 19
Total games: 20
Total games: 21
Total games: 22
Total games: 23
Total games: 24
Total games: 25
Total games: 26
Total games: 27
Total games: 28
Total games: 29
Total games: 30
Total games: 31
Total games: 32
Total games: 33
Total games: 34
Total games: 35
Total games: 36
Total games: 37
Total games: 38
Total games: 39
Total games: 40
Total games: 41
Total games: 42
Total games: 43
Total games: 44
Total games: 45
Total games: 46
Total games: 47
Total games: 48
Total games: 49
Total games: 50
Total games: 51
Total games: 52
Total games: 53
Total games: 54
Total games: 55
Total games: 56
Total games: 57
Total games: 58
Total games: 59
Total games: 60
Total games: 61
Total games: 62
Total games: 63
T

In [7]:
class ResNetBlock(nn.Module):
    def __init__(self, channels, stride=1):
        super(ResNetBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            channels, channels, kernel_size=3, stride=stride, padding=1
        )
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(
            channels, channels, kernel_size=3, stride=1, padding=1
        )
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual
        out = F.relu(out)
        return out

class SLPolicyValueNetwork(nn.Module):
    def __init__(self, blocks=10, channels=256, num_possible_moves=20480):
        super(SLPolicyValueNetwork, self).__init__()

        # shared trunk
        self.conv1 = nn.Conv2d(
            in_channels=12, out_channels=channels, kernel_size=3, padding=1 # padding=1 for same size output
        )

        self.norm = nn.BatchNorm2d(channels)

        self.blocks = nn.ModuleList(
            [ResNetBlock(channels) for _ in range(blocks)]
        )

        self.fc_shared = nn.Linear(channels * 8 * 8, 512)

        # policy head
        self.fc_policy = nn.Linear(512, num_possible_moves)

        # value head
        self.value_conv = nn.Conv2d(in_channels=channels, out_channels=1, kernel_size=1)
        self.value_bn = nn.BatchNorm2d(1)
        self.value_fc1 = nn.Linear(8 * 8, 256)
        self.value_fc2 = nn.Linear(256, 1)


    def forward(self, x):
        x = x.permute(0, 3, 1, 2)
        x = F.relu(self.conv1(x))
        conv_out = self.norm(x)

        for block in self.blocks:
            conv_out = block(conv_out)

        # policy head
        policy_flat = torch.flatten(conv_out, start_dim=1)  # exclude batch dimension
        policy_features = F.relu(self.fc_shared(policy_flat))
        policy_logits = self.fc_policy(policy_features)

        # value head
        v = self.value_conv(conv_out)
        v = self.value_bn(v)
        v = F.relu(v)
        v = v.view(v.size(0), -1)
        v = self.value_fc1(v)
        v = F.relu(v)
        v = self.value_fc2(v) # convert to single scalar
        value = torch.tanh(v)


        return policy_logits, value


model = SLPolicyValueNetwork()
# model.load_state_dict(torch.load("sl_policy_network_KC.pth", map_location=torch.device("cpu")))
policy_criterion = nn.CrossEntropyLoss() # softmax regression loss function
value_criterion = nn.BCEWithLogitsLoss() # logistic regression loss function
optimizer = optim.Adam(model.parameters(), lr=0.1e-4)

In [8]:

def predict_move(model, board_tensor):
    """
    Takes a board tensor (8, 8, 12) and returns the predicted UCI move.
    """
    label_to_uci = label_to_move_table()
    model.eval()  # Set to evaluation mode
    

    with torch.no_grad():  # no gradients needed for inference
        # add batch dimension to fit model params: (8, 8, 12) -> (1, 8, 8, 12)
        board_batch = board_tensor.unsqueeze(0)

        # Get model output
        logits, val = model(board_batch)  # Shape: (1, 20480)
        probabilities = F.softmax(logits, dim=1)

        # Get the highest scoring move
        predicted_label = torch.argmax(probabilities, dim=1).item()

        # Convert to UCI
        predicted_uci = label_to_uci[predicted_label]

    return predicted_uci, val


def list_predicted_moves(model, board_tensor, num_moves):
    label_to_uci = label_to_move_table()

    model.eval()
    with torch.no_grad():
        board_batch = board_tensor.unsqueeze(0)
        logits, val = model(board_batch)  
        probabilities = F.softmax(logits, dim=1)
        score, moves = torch.topk(probabilities, num_moves)
        # moves = [label_to_uci[int(move)] for move in moves[0]]
        

    return moves, val



In [9]:
epochs = 1

for epoch in range(epochs):
    for batch_idx, (data, target) in enumerate(train_dataloader):
        batch_move_target = target[:, 0]
        batch_val_target = target[:, 1].float().unsqueeze(1)

        pred_policy, pred_val = model(data)  # calculate predictions for this batch
        policy_loss = policy_criterion(pred_policy, batch_move_target)  # calculate loss for policy
        value_loss = value_criterion(pred_val, batch_val_target) # calculate loss for value
        loss = policy_loss + value_loss
        optimizer.zero_grad()  # reset gradient
        loss.backward()  # calculate gradient
        optimizer.step()  # update parameters

        if batch_idx % 10 == 0:
            print(f"Epoch {epoch+1}: Loss = {loss.item():.4f}")

    # doesn't really make sense to calculate validation accuracy as opening move has many possible moves

    # model.eval()
    # test_loss = 0
    # correct = 0

    # with torch.no_grad():
    #     for data, target in test_dataloader:
    #         # data, target = data.to(device), target.to(device)
    #         output = model(data)
    #         test_loss += criterion(output, target).item()
    #         correct += (output.argmax(1) == target).type(torch.float).sum().item()

    # print('epoch: {}, test loss: {:.6f}, test accuracy: {:.6f}'.format(
    #     epoch + 1,
    #     test_loss / len(test_dataloader),
    #     correct / len(test_dataloader.dataset)
    #     ))

Epoch 1: Loss = 10.6795
Epoch 1: Loss = 10.1292
Epoch 1: Loss = 10.2155
Epoch 1: Loss = 10.4192
Epoch 1: Loss = 9.6584
Epoch 1: Loss = 9.4563
Epoch 1: Loss = 9.3433


KeyboardInterrupt: 

In [10]:
import chess
board = chess.Board()
# board.push_uci('d2d4')
board_tensor = fen_to_board(board.fen())
print(predict_move(model, board_tensor))
print(list_predicted_moves(model, board_tensor, 5))

label_to_move_table()

# MODEL PREDICTS ILLEGAL MOVES

('e2e4', tensor([[-0.7302]]))
(tensor([[10735,  8135, 15570, 12340,  6817]]), tensor([[-0.7302]]))


{0: 'a1a1',
 1: 'a1a1q',
 2: 'a1a1r',
 3: 'a1a1b',
 4: 'a1a1n',
 5: 'a1a2',
 6: 'a1a2q',
 7: 'a1a2r',
 8: 'a1a2b',
 9: 'a1a2n',
 10: 'a1a3',
 11: 'a1a3q',
 12: 'a1a3r',
 13: 'a1a3b',
 14: 'a1a3n',
 15: 'a1a4',
 16: 'a1a4q',
 17: 'a1a4r',
 18: 'a1a4b',
 19: 'a1a4n',
 20: 'a1a5',
 21: 'a1a5q',
 22: 'a1a5r',
 23: 'a1a5b',
 24: 'a1a5n',
 25: 'a1a6',
 26: 'a1a6q',
 27: 'a1a6r',
 28: 'a1a6b',
 29: 'a1a6n',
 30: 'a1a7',
 31: 'a1a7q',
 32: 'a1a7r',
 33: 'a1a7b',
 34: 'a1a7n',
 35: 'a1a8',
 36: 'a1a8q',
 37: 'a1a8r',
 38: 'a1a8b',
 39: 'a1a8n',
 40: 'a1b1',
 41: 'a1b1q',
 42: 'a1b1r',
 43: 'a1b1b',
 44: 'a1b1n',
 45: 'a1b2',
 46: 'a1b2q',
 47: 'a1b2r',
 48: 'a1b2b',
 49: 'a1b2n',
 50: 'a1b3',
 51: 'a1b3q',
 52: 'a1b3r',
 53: 'a1b3b',
 54: 'a1b3n',
 55: 'a1b4',
 56: 'a1b4q',
 57: 'a1b4r',
 58: 'a1b4b',
 59: 'a1b4n',
 60: 'a1b5',
 61: 'a1b5q',
 62: 'a1b5r',
 63: 'a1b5b',
 64: 'a1b5n',
 65: 'a1b6',
 66: 'a1b6q',
 67: 'a1b6r',
 68: 'a1b6b',
 69: 'a1b6n',
 70: 'a1b7',
 71: 'a1b7q',
 72: 'a1b7r',
 73

In [None]:
board = chess.Board()
moves_played = []
model_turn = 1

while not board.is_game_over():
    move = None
    if model_turn:
        print("model")

        board_tensor = fen_to_board(board.fen())
        moves, probs = list_predicted_moves(model, board_tensor, 20480)
        
        for move in moves:
            try:
                board.push_uci(move)
            except chess.IllegalMoveError:
                continue
            break
        model_turn = 0
    else:
        print("random")
        moves = board.legal_moves
        move_index = random.randint(0, moves.count()-1)
        moves = [move for move in moves]
        move = moves[move_index]
        board.push(move)
        model_turn = 1

    moves_played.append(move)
    print(board)
    print(moves_played)
    

model
r n b q k b n r
p p p p p p p p
. . . . . . . .
. . . . . . . .
. . . . P . . .
. . . . . . . .
P P P P . P P P
R N B Q K B N R
['e2e4']
random
r n b q k b n r
p . p p p p p p
. . . . . . . .
. p . . . . . .
. . . . P . . .
. . . . . . . .
P P P P . P P P
R N B Q K B N R
['e2e4', Move.from_uci('b7b5')]
model
r n b q k b n r
p . p p p p p p
. . . . . . . .
. p . . . . . .
. . . P P . . .
. . . . . . . .
P P P . . P P P
R N B Q K B N R
['e2e4', Move.from_uci('b7b5'), 'd2d4']
random
r n b q k b n r
p . p . p p p p
. . . p . . . .
. p . . . . . .
. . . P P . . .
. . . . . . . .
P P P . . P P P
R N B Q K B N R
['e2e4', Move.from_uci('b7b5'), 'd2d4', Move.from_uci('d7d6')]
model
r n b q k b n r
p . p . p p p p
. . . p . . . .
. p . . . . . .
. . . P P . . .
. . N . . . . .
P P P . . P P P
R . B Q K B N R
['e2e4', Move.from_uci('b7b5'), 'd2d4', Move.from_uci('d7d6'), 'b1c3']
random
r n b . k b n r
p . p q p p p p
. . . p . . . .
. p . . . . . .
. . . P P . . .
. . N . . . . .
P P P . . 

In [None]:

def export_game_from_board(board: chess.Board, file_name: str):
    game = chess.pgn.Game()
    game.headers["Event"] = "AI Self Play"
    game.headers["White"] = "Your Model"
    game.headers["Black"] = "Your Model"
    game.headers["Result"] = board.result()

    # add moves to the game node
    node = game
    for move in board.move_stack:
        node = node.add_variation(move)

    # save to PGN file
    with open(f"{file_name}.pgn", "w", encoding="utf-8") as pgn_file:
        print(game, file=pgn_file)
    pgn_file.close()


20


In [None]:
export_game_from_board(board, "model_vs_random")