In [6]:
import torch
import chess
import random

In [7]:
from src.model import ResNet
from src.dataclass import ChessDataset
from src.encode import get_canonical_board, index_to_uci

In [8]:
# Step 1: Build the validation set (the January 2016 CSV) and wrap it in a batched loader
val_dataset = ChessDataset("data/csv/le2016-01.csv")
val_loader = torch.utils.data.DataLoader(
    dataset=val_dataset,
    batch_size=256,
    shuffle=False,  # evaluation does not need a random order
    num_workers=4,
)

In [9]:
# Step 2: Pick the compute device (GPU when one is available, CPU otherwise) and build the network
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Architecture arguments: 128 filters and 6 residual blocks.
model = ResNet(filters=128, res_blocks=6)

In [10]:
# Step 3: Load the trained model
# map_location=device remaps the stored tensors onto the device selected in Step 2, so a
# checkpoint that was saved on a GPU can still be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
checkpoint = torch.load("models/model5/model.28.pth", map_location=device)
# Left as the last expression so the notebook shows the key-matching report.
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [12]:
# Step 4: Move the model onto the chosen device and switch it to evaluation mode.
# Both calls return the module itself, so the chained expression still displays the model summary.
model.to(device).eval()

ResNet(
  (start_block): Sequential(
    (0): Conv2d(18, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (res_tower): ModuleList(
    (0-5): 6 x ResBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (batch1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (batch2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (policy_head): Sequential(
    (0): Conv2d(128, 2, kernel_size=(1, 1), stride=(1, 1))
    (1): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Flatten(start_dim=1, end_dim=-1)
    (4): Linear(in_features=128, out_features=4544, bias=True)
  )
  (value_head): Se

In [13]:
# Step 5: Count how often the labelled move is the model's first choice (top-1) or among its
# three highest-ranked choices (top-3) over the whole validation loader.
correct_top_1 = 0
correct_top_3 = 0
total = 0

with torch.no_grad():
    # Each batch yields (inputs, labels, values); the value targets are not needed for move accuracy.
    for inputs, labels, _values in val_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        policy, _value = model(inputs)

        # Softmax is monotonic, so ranking the raw policy logits gives the same ordering
        # as ranking the probabilities; the softmax pass is skipped.
        pred_top_1 = policy.argmax(dim=1)
        correct_top_1 += (pred_top_1 == labels).sum().item()

        # Compare all three candidates against the label at once instead of looping over
        # samples in Python (which forced one device-to-host sync per sample).
        # Assumes labels has shape (batch,) holding class indices - TODO confirm.
        pred_top_3 = torch.topk(policy, k=3, dim=1).indices
        correct_top_3 += (pred_top_3 == labels.unsqueeze(1)).any(dim=1).sum().item()

        total += labels.size(0)

if total == 0:
    # Fail with a clear message instead of a bare ZeroDivisionError below.
    raise ValueError("Validation loader produced no samples; cannot compute accuracy")

top_1_accuracy = 100 * correct_top_1 / total
top_3_accuracy = 100 * correct_top_3 / total

In [14]:
# Build the report first, then emit it with a single print call.
validation_report = f"\nValidation Results:\nTop-1 accuracy: {top_1_accuracy:.2f}%\nTop-3 accuracy: {top_3_accuracy:.2f}%"
print(validation_report)


Validation Results:
Top-1 accuracy: 53.00%
Top-3 accuracy: 80.95%


In [15]:
# With only 5 epochs of training, the model finds the best move 28% of the time, and about half the time the best
# move is among its three top predictions. This looks promising, but we have to visualize the performance more.

# With helper planes and 50 epochs, the loss is 1.7; top-1 accuracy is 48.72% and top-3 accuracy is 74.73%.

# The ResNet seems to perform worse than our second-version model. One reason might be that the dataset is not that large.

# Training the ConvNet with both a policy and a value head this time again gives top-1 47.99% and top-3 74.1%.

# Training the ResNet with a policy and a value head (AlphaZero style), with 128 filters and 6 residual blocks, on a
# dataset of 25 million positions: after 28 epochs it achieves top-1 accuracy 53% and top-3 accuracy 80.95% on the
# validation set.

In [16]:
# Step 6: Create a function to have the model predict a legal move
def predict_move(board, top_k=3):
    """Choose a legal move for ``board`` from the policy output of the global ``model``.

    The ``top_k`` highest-ranked policy entries are tried in random order and the
    first one that is legal is played, which adds variety between games. If none
    of them is legal, the remaining entries are tried from highest to lowest
    rank, so the fallback is the model's best legal move rather than a uniformly
    random one.

    Args:
        board: Position to move in. It must provide ``fen()`` and ``legal_moves``
            (an iterable of moves exposing ``uci()``), as ``chess.Board`` does.
        top_k: Number of highest-ranked entries to sample from. Defaults to 3.

    Returns:
        A ``chess.Move`` built from the selected UCI string. As a last resort
        (no policy entry maps to a legal move) a random legal move is returned.

    Raises:
        IndexError: If the position has no legal moves.
    """
    # A set makes each legality check O(1) instead of scanning a list.
    legal_ucis = {move.uci() for move in board.legal_moves}
    model_input = torch.from_numpy(get_canonical_board(board.fen())).unsqueeze(0).to(device)
    with torch.no_grad():
        logits, _ = model(model_input)

    # Softmax is monotonic, so sorting the raw logits gives the same ranking as
    # sorting the probabilities. A single tolist() avoids one .item() sync per candidate.
    ranked = torch.argsort(logits.squeeze(0), descending=True).tolist()

    # Stage 1: the top-k candidates, visited in random order.
    k = max(0, min(top_k, len(ranked)))
    for position in torch.randperm(k).tolist():
        uci_move = index_to_uci[ranked[position]]
        if uci_move in legal_ucis:
            return chess.Move.from_uci(uci_move)

    # Stage 2: everything below the top-k, best first.
    for index in ranked[k:]:
        try:
            uci_move = index_to_uci[index]
        except (KeyError, IndexError):
            # Policy slot without a UCI mapping; it cannot be played, so skip it.
            continue
        if uci_move in legal_ucis:
            return chess.Move.from_uci(uci_move)

    # Safety net, as before: no policy entry matched a legal move.
    return random.choice(list(board.legal_moves))

In [21]:
# Step 7: Let the network play a full game against itself and save it as PGN for review
import chess.pgn

# Single source of truth for the output file, used for both writing and the final message.
pgn_path = "data/eval/v5.5.pgn"

game = chess.pgn.Game()
game.headers["Event"] = "Evaluation Game"
game.headers["Site"] = "Local"
# predict_move chooses every move in the loop below, so both sides are the model.
game.headers["White"] = "Model"
game.headers["Black"] = "Model"

# `node` tracks the end of the main line so each new move is appended after the previous one.
node = game

board = chess.Board()
while not board.is_game_over():
    move = predict_move(board)
    board.push(move)
    node = node.add_variation(move)

# Record the outcome in the PGN headers so the saved file carries the result.
game.headers["Result"] = board.result()
print("\nGame over:", board.result())
with open(pgn_path, "w") as pgn_file:
    print(game, file=pgn_file)

# Report the path that was actually written.
print(f"Game saved to {pgn_path} — upload it to Lichess to review!")


Game over: 1-0
Game saved to game_output.pgn — upload it to Lichess to review!
