In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
from torch.utils.data import DataLoader, TensorDataset
from data_processing import generate_dataset_from_pgn, label_to_move_table, fen_to_board
import chess

g4f7n


In [12]:


# Presumably set to avoid Intel OpenMP aborting when more than one copy of its
# runtime gets loaded (a common conda + PyTorch issue) — TODO confirm still needed.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

dataset = generate_dataset_from_pgn("masaurus101-white.pgn")
# Despite the name, this is the FRACTION of positions used for training
# (0.8 -> 80% train / 20% test), not a train:test ratio.
train_to_test_ratio = 0.8

train_size = int(len(dataset) * train_to_test_ratio)
test_size = len(dataset) - train_size

# torch.stack raises an opaque "expects a non-empty TensorList" error on an
# empty split, so fail early with a message that names the real problem.
if train_size == 0 or test_size == 0:
    raise ValueError(
        f"dataset has {len(dataset)} positions, too few for a "
        f"{train_to_test_ratio:.0%} train / rest test split"
    )

# Positional split (no shuffle): the tail of the dataset becomes the test set.
train_data = dataset[:train_size]
test_data = dataset[train_size:]

# Each dataset item is a (board, label) pair; boards are stacked into one
# tensor and labels are integer move indices.
X_train = torch.stack([board for board, label in train_data])  # (N, 8, 8, 12)
t_train = torch.tensor([label for board, label in train_data])  # (N,)

X_test = torch.stack([board for board, label in test_data])
t_test = torch.tensor([label for board, label in test_data])

# Create DataLoaders (only the training loader shuffles).
batch_size = 32
train_dataset = TensorDataset(X_train, t_train)
test_dataset = TensorDataset(X_test, t_test)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)




AssertionError: Torch not compiled with CUDA enabled

In [3]:

class SLPolicyNetwork(nn.Module):
    """Supervised-learning policy network: board planes in, move logits out.

    Input is a batch of channels-last boards shaped (N, 8, 8, 12); it is
    permuted to channels-first before the convolutions. Output is raw,
    unnormalised logits of shape (N, num_possible_moves).

    The default of 20480 matches the label table used in this notebook:
    64 from-squares x 64 to-squares x 5 variants (plain move plus the
    q/r/b/n promotion suffixes).
    """

    def __init__(self, num_possible_moves=20480):
        super().__init__()

        # Three 3x3 convolutions with padding=1 keep the 8x8 board size while
        # widening the channels 12 -> 32 -> 64 -> 128. Attribute names and
        # creation order are fixed so weight init and state_dict keys are stable.
        self.conv1 = nn.Conv2d(12, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)

        # Fully connected head over the flattened 128 x 8 x 8 feature map.
        self.fc1 = nn.Linear(128 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, num_possible_moves)

    def forward(self, x):
        # (N, 8, 8, 12) -> (N, 12, 8, 8): nn.Conv2d expects channels first.
        features = x.permute(0, 3, 1, 2)
        for conv in (self.conv1, self.conv2, self.conv3):
            features = F.relu(conv(features))

        hidden = F.relu(self.fc1(features.flatten(1)))  # keep the batch dim
        return self.fc2(hidden)


# Seed torch's global RNG so weight initialisation is reproducible across
# Restart & Run All. With the same cell order, the training DataLoader's
# shuffle (which draws from the global generator when iterated) becomes
# deterministic too.
SEED = 0
torch.manual_seed(SEED)

model = SLPolicyNetwork()
# CrossEntropyLoss takes raw logits and integer class labels.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-5)  # 1e-5 == the original 0.1e-4

In [4]:

def predict_move(model, board_tensor, label_to_uci=None):
    """
    Takes a board tensor (8, 8, 12) and returns the predicted UCI move.

    Args:
        model: policy network returning logits of shape (1, num_moves) for a
            batch containing one board.
        board_tensor: a single board without a batch dimension, (8, 8, 12).
        label_to_uci: optional precomputed mapping from label index to UCI
            string. Defaults to ``label_to_move_table()``. Pass it in when
            calling repeatedly (e.g. once per ply) to avoid rebuilding the
            table on every call.

    Returns:
        (uci, probability): the highest-scoring move and its softmax
        probability as a 0-d tensor. The move is NOT checked for legality.
    """
    if label_to_uci is None:
        label_to_uci = label_to_move_table()

    # Switch to eval mode for this call only, then restore the caller's mode
    # so a later training cell is not silently left in eval mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():  # No gradients needed for inference
            # Add batch dimension: (8, 8, 12) -> (1, 8, 8, 12)
            board_batch = board_tensor.unsqueeze(0)

            logits = model(board_batch)  # (1, num_moves)
            probabilities = F.softmax(logits, dim=1)

            # Get the highest scoring move and convert its label to UCI.
            predicted_label = torch.argmax(probabilities, dim=1).item()
            predicted_uci = label_to_uci[predicted_label]
    finally:
        model.train(was_training)

    return predicted_uci, probabilities[0][predicted_label]


def list_predicted_moves(model, board_tensor, num_moves, label_to_uci=None):
    """Return the model's ``num_moves`` highest-probability moves for one board.

    Args:
        model: policy network returning logits of shape (1, num_labels) for a
            batch containing one board.
        board_tensor: a single board without a batch dimension, (8, 8, 12).
        num_moves: how many top moves to return. Values larger than the
            number of labels are capped, so callers can ask for "everything"
            without knowing the vocabulary size.
        label_to_uci: optional precomputed mapping from label index to UCI
            string. Defaults to ``label_to_move_table()``. Pass it in when
            calling repeatedly to avoid rebuilding the table each time.

    Returns:
        (moves, score): ``moves`` is a list of UCI strings, best first;
        ``score`` is a (1, k) tensor of their softmax probabilities, where
        k = min(num_moves, num_labels). Moves are NOT filtered for legality.
    """
    if label_to_uci is None:
        label_to_uci = label_to_move_table()

    # Eval mode for this call only; restore the caller's mode afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            board_batch = board_tensor.unsqueeze(0)  # (1, 8, 8, 12)
            logits = model(board_batch)
            probabilities = F.softmax(logits, dim=1)
            # torch.topk raises if k exceeds the number of classes.
            k = min(num_moves, probabilities.shape[1])
            score, labels = torch.topk(probabilities, k)
            moves = [label_to_uci[int(label)] for label in labels[0]]
    finally:
        model.train(was_training)

    return moves, score



In [5]:
epochs = 1

for epoch in range(epochs):
    # The inference helpers in this notebook call model.eval(). Switch back to
    # training mode explicitly so this cell does not depend on which cells ran
    # before it (this matters once dropout/batch-norm layers are added).
    model.train()

    for batch_idx, (data, target) in enumerate(train_dataloader):
        output = model(data)  # logits, (batch, num_possible_moves)
        loss = criterion(output, target)  # cross-entropy vs. integer move labels
        optimizer.zero_grad()  # clear gradients from the previous step
        loss.backward()  # backpropagate
        optimizer.step()  # update parameters

        if batch_idx % 100 == 0:
            print(f"Epoch {epoch+1}: Loss = {loss.item():.4f}")

    # Top-1 validation accuracy is of limited value here: a position (e.g. the
    # opening) usually has several reasonable moves, so it understates quality.

    # model.eval()
    # test_loss = 0
    # correct = 0

    # with torch.no_grad():
    #     for data, target in test_dataloader:
    #         # data, target = data.to(device), target.to(device)
    #         output = model(data)
    #         test_loss += criterion(output, target).item()
    #         correct += (output.argmax(1) == target).type(torch.float).sum().item()

    # print('epoch: {}, test loss: {:.6f}, test accuracy: {:.6f}'.format(
    #     epoch + 1,
    #     test_loss / len(test_dataloader),
    #     correct / len(test_dataloader.dataset)
    #     ))

Epoch 1: Loss = 9.9320
Epoch 1: Loss = 9.9107
Epoch 1: Loss = 9.6090
Epoch 1: Loss = 7.7745
Epoch 1: Loss = 6.9639
Epoch 1: Loss = 7.1183
Epoch 1: Loss = 6.9550
Epoch 1: Loss = 6.5209
Epoch 1: Loss = 7.5304
Epoch 1: Loss = 6.6756


In [6]:
import chess
# Spot-check: after 1. d4 it is Black to move. Show the single top prediction,
# then the five highest-ranked moves with their probabilities.
board = chess.Board()
board.push_uci('d2d4')

board_tensor = fen_to_board(board.fen())

print(predict_move(model, board_tensor))
print(list_predicted_moves(model, board_tensor, 5))

# Observation: the raw policy is not restricted to legal moves — its ranking
# can include moves that are illegal in the current position.

('b1c3', tensor(0.0180))
(['b1c3', 'g8f6', 'd2d4', 'g1f3', 'e2e4'], tensor([[0.0180, 0.0175, 0.0161, 0.0156, 0.0151]]))


{0: 'a1a1',
 1: 'a1a1q',
 2: 'a1a1r',
 3: 'a1a1b',
 4: 'a1a1n',
 5: 'a1a2',
 6: 'a1a2q',
 7: 'a1a2r',
 8: 'a1a2b',
 9: 'a1a2n',
 10: 'a1a3',
 11: 'a1a3q',
 12: 'a1a3r',
 13: 'a1a3b',
 14: 'a1a3n',
 15: 'a1a4',
 16: 'a1a4q',
 17: 'a1a4r',
 18: 'a1a4b',
 19: 'a1a4n',
 20: 'a1a5',
 21: 'a1a5q',
 22: 'a1a5r',
 23: 'a1a5b',
 24: 'a1a5n',
 25: 'a1a6',
 26: 'a1a6q',
 27: 'a1a6r',
 28: 'a1a6b',
 29: 'a1a6n',
 30: 'a1a7',
 31: 'a1a7q',
 32: 'a1a7r',
 33: 'a1a7b',
 34: 'a1a7n',
 35: 'a1a8',
 36: 'a1a8q',
 37: 'a1a8r',
 38: 'a1a8b',
 39: 'a1a8n',
 40: 'a1b1',
 41: 'a1b1q',
 42: 'a1b1r',
 43: 'a1b1b',
 44: 'a1b1n',
 45: 'a1b2',
 46: 'a1b2q',
 47: 'a1b2r',
 48: 'a1b2b',
 49: 'a1b2n',
 50: 'a1b3',
 51: 'a1b3q',
 52: 'a1b3r',
 53: 'a1b3b',
 54: 'a1b3n',
 55: 'a1b4',
 56: 'a1b4q',
 57: 'a1b4r',
 58: 'a1b4b',
 59: 'a1b4n',
 60: 'a1b5',
 61: 'a1b5q',
 62: 'a1b5r',
 63: 'a1b5b',
 64: 'a1b5n',
 65: 'a1b6',
 66: 'a1b6q',
 67: 'a1b6r',
 68: 'a1b6b',
 69: 'a1b6n',
 70: 'a1b7',
 71: 'a1b7q',
 72: 'a1b7r',
 73

In [7]:
# Self-play: at every ply, play the model's highest-ranked move that is legal.
board = chess.Board()

moves_played = []
while not board.is_game_over():
    board_tensor = fen_to_board(board.fen())
    # Rank the whole move vocabulary (20480 labels), best first.
    moves, _scores = list_predicted_moves(model, board_tensor, 20480)

    # Match candidates against python-chess's legal moves instead of trying
    # push_uci on each one. The vocabulary contains strings such as 'a1a1'
    # that python-chess may reject with an exception other than
    # IllegalMoveError, and membership filtering guarantees that only a move
    # actually played is recorded.
    legal_ucis = {legal.uci() for legal in board.legal_moves}
    move = next((candidate for candidate in moves if candidate in legal_ucis), None)
    if move is None:
        raise RuntimeError(f"no legal move among the model's predictions for {board.fen()}")

    board.push_uci(move)
    moves_played.append(move)

    print(board)
    print(moves_played)
    

r n b q k b n r
p p p p p p p p
. . . . . . . .
. . . . . . . .
. . . . . . . .
. . N . . . . .
P P P P P P P P
R . B Q K B N R
['b1c3']
r n b q k b . r
p p p p p p p p
. . . . . n . .
. . . . . . . .
. . . . . . . .
. . N . . . . .
P P P P P P P P
R . B Q K B N R
['b1c3', 'g8f6']
r n b q k b . r
p p p p p p p p
. . . . . n . .
. . . . . . . .
. . . P . . . .
. . N . . . . .
P P P . P P P P
R . B Q K B N R
['b1c3', 'g8f6', 'd2d4']
r . b q k b . r
p p p p p p p p
. . n . . n . .
. . . . . . . .
. . . P . . . .
. . N . . . . .
P P P . P P P P
R . B Q K B N R
['b1c3', 'g8f6', 'd2d4', 'b8c6']
r . b q k b . r
p p p p p p p p
. . n . . n . .
. . . . . . . .
. . . P . . . .
. . N . . N . .
P P P . P P P P
R . B Q K B . R
['b1c3', 'g8f6', 'd2d4', 'b8c6', 'g1f3']
r . b q k b . r
p p p . p p p p
. . n . . n . .
. . . p . . . .
. . . P . . . .
. . N . . N . .
P P P . P P P P
R . B Q K B . R
['b1c3', 'g8f6', 'd2d4', 'b8c6', 'g1f3', 'd7d5']
r . b q k b . r
p p p . p p p p
. . n . . n . .
. . . p . 