In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
import chess

from torch.utils.data import DataLoader, TensorDataset
from data_processing import generate_dataset_from_pgn, label_to_move_table, fen_to_board

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

Using device: NVIDIA GeForce RTX 4070 Laptop GPU


In [3]:
dataset = generate_dataset_from_pgn("tal.pgn") # dataset is a list of all moves in a game (8,8,12)
train_to_test_ratio = 0.8

train_size = int(len(dataset) * train_to_test_ratio)
test_size = len(dataset) - train_size

# Split the dataset
train_data = dataset[:train_size]
test_data = dataset[train_size:]

# Convert to tensors (simpler now since labels are already integers!)
X_train = torch.stack([board for board, label in train_data])  # (N, 8, 8, 12)
t_train = torch.tensor([label for board, label in train_data])  # (N,)

X_test = torch.stack([board for board, label in test_data])
t_test = torch.tensor([label for board, label in test_data])

# Create DataLoaders
batch_size = 64
train_dataset = TensorDataset(X_train, t_train) # pairs x and y
test_dataset = TensorDataset(X_test, t_test)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

label_to_uci = label_to_move_table()

In [4]:
class SLPolicyNetwork(nn.Module):
    def __init__(self, num_possible_moves=20480):
        super(SLPolicyNetwork, self).__init__()

        self.conv1 = nn.Conv2d(
            in_channels=12, out_channels=32, kernel_size=3, padding=1
        )
        self.conv2 = nn.Conv2d(
            in_channels=32, out_channels=64, kernel_size=3, padding=1
        )
        self.conv3 = nn.Conv2d(
            in_channels=64, out_channels=128, kernel_size=3, padding=1
        )

        self.fc1 = nn.Linear(128 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, num_possible_moves)

    def forward(self, x):
        x = x.permute(0, 3, 1, 2)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))

        x = torch.flatten(x, start_dim=1)  # exclude batch dimension
        x = F.relu(self.fc1(x))
        x = self.fc2(x)

        return x

In [5]:
model = SLPolicyNetwork().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.1e-3)

epochs = 100

for epoch in range(epochs):
    total_loss = 0
    model.train()

    for x_batch, y_batch in train_dataloader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        optimizer.zero_grad()
        outputs = model(x_batch)
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(train_dataloader)
    print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

torch.save(model.state_dict(), "sl_policy_network.pth")

Epoch 1/100, Loss: 6.3119
Epoch 2/100, Loss: 5.7362
Epoch 3/100, Loss: 5.4182
Epoch 4/100, Loss: 5.1370
Epoch 5/100, Loss: 4.8850
Epoch 6/100, Loss: 4.6486
Epoch 7/100, Loss: 4.4293
Epoch 8/100, Loss: 4.2182
Epoch 9/100, Loss: 4.0222
Epoch 10/100, Loss: 3.8417
Epoch 11/100, Loss: 3.6753
Epoch 12/100, Loss: 3.5228
Epoch 13/100, Loss: 3.3836
Epoch 14/100, Loss: 3.2545
Epoch 15/100, Loss: 3.1338
Epoch 16/100, Loss: 3.0203
Epoch 17/100, Loss: 2.9127
Epoch 18/100, Loss: 2.8090
Epoch 19/100, Loss: 2.7109
Epoch 20/100, Loss: 2.6154
Epoch 21/100, Loss: 2.5242
Epoch 22/100, Loss: 2.4348
Epoch 23/100, Loss: 2.3473
Epoch 24/100, Loss: 2.2617
Epoch 25/100, Loss: 2.1773
Epoch 26/100, Loss: 2.0984
Epoch 27/100, Loss: 2.0176
Epoch 28/100, Loss: 1.9386
Epoch 29/100, Loss: 1.8633
Epoch 30/100, Loss: 1.7885
Epoch 31/100, Loss: 1.7136
Epoch 32/100, Loss: 1.6423
Epoch 33/100, Loss: 1.5719
Epoch 34/100, Loss: 1.5037
Epoch 35/100, Loss: 1.4377
Epoch 36/100, Loss: 1.3691
Epoch 37/100, Loss: 1.3056
Epoch 38/1

In [6]:
def predict_move(model, board_tensor):
    """
    Takes a board tensor (8, 8, 12) and returns the predicted UCI move.
    """
    model.eval()  # Set to evaluation mode

    with torch.no_grad():  # No gradients needed for inference
        # Add batch dimension: (8, 8, 12) -> (1, 8, 8, 12)
        board_batch = board_tensor.unsqueeze(0).to(device)

        # Get model output
        outputs = model(board_batch)  # Shape: (1, 20480)

        # Get the highest scoring move
        predicted_label = torch.argmax(outputs, dim=1).item()

        # Convert to UCI
        predicted_uci = label_to_uci[predicted_label]

    return predicted_uci, predicted_label

def list_predicted_moves(model, board_tensor, num_moves):
    label_to_uci = label_to_move_table()

    model.eval()
    with torch.no_grad():
        board_batch = board_tensor.unsqueeze(0).to(device)
        logits = model(board_batch)  
        probabilities = F.softmax(logits, dim=1)
        score, moves = torch.topk(probabilities, num_moves)
        moves = [label_to_uci[int(move)] for move in moves[0]]
        

    return moves, score

In [7]:
board = chess.Board()
board.push_uci('d2d4')
board_tensor = fen_to_board(board.fen())
model = SLPolicyNetwork()
model.load_state_dict(torch.load("sl_policy_network.pth", map_location=device))
model.to(device)    
predict_move(model, board_tensor)
list_predicted_moves(model, board_tensor, 5)

# MODEL PREDICTS ILLEGAL MOVES

board = chess.Board()
moves_played = []
move_num = 0

while not board.is_game_over():
    board_tensor = fen_to_board(board.fen())
    
    # Get top predicted moves from the model
    moves, probs = list_predicted_moves(model, board_tensor, 10)
    
    # Filter predictions to only legal moves
    legal_moves = [move.uci() for move in board.legal_moves]
    move_played = None
    
    for move in moves:
        if move in legal_moves:
            board.push_uci(move)
            move_played = move
            break

    # If none of the top moves are legal, pick a random legal move
    if move_played is None:
        move_played = next(iter(board.legal_moves)).uci()
        board.push_uci(move_played)

    moves_played.append(move_played)
    move_num += 1
    
    # Print board and move number
    print(board)
    print(move_num)

r n b q k b n r
p p p p p p p p
. . . . . . . .
. . . . . . . .
. . . . P . . .
. . . . . . . .
P P P P . P P P
R N B Q K B N R
1
r n b q k b n r
p p . p p p p p
. . . . . . . .
. . p . . . . .
. . . . P . . .
. . . . . . . .
P P P P . P P P
R N B Q K B N R
2
r n b q k b n r
p p . p p p p p
. . . . . . . .
. . p . . . . .
. . . . P . . .
. . . . . N . .
P P P P . P P P
R N B Q K B . R
3
r n b q k b n r
p p . . p p p p
. . . p . . . .
. . p . . . . .
. . . . P . . .
. . . . . N . .
P P P P . P P P
R N B Q K B . R
4
r n b q k b n r
p p . . p p p p
. . . p . . . .
. . p . . . . .
. . . P P . . .
. . . . . N . .
P P P . . P P P
R N B Q K B . R
5
r n b q k b n r
p p . . p p p p
. . . p . . . .
. . . . . . . .
. . . p P . . .
. . . . . N . .
P P P . . P P P
R N B Q K B . R
6
r n b q k b n r
p p . . p p p p
. . . p . . . .
. . . . . . . .
. . . N P . . .
. . . . . . . .
P P P . . P P P
R N B Q K B . R
7
r n b q k b . r
p p . . p p p p
. . . p . n . .
. . . . . . . .
. . . N P . . .
. . . . . 

In [9]:
# create a game object
game = chess.pgn.Game()
game.headers["Event"] = "AI Self Play"
game.headers["White"] = "Your Model"
game.headers["Black"] = "Your Model"
game.headers["Result"] = board.result()

# add moves to the game node
node = game
for move in board.move_stack:
    node = node.add_variation(move)

# save to PGN file
with open("output_game1.pgn", "w", encoding="utf-8") as pgn_file:
    print(game, file=pgn_file)