In [4]:
import os

import chess.pgn
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import encoder

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [16]:
def reformat_games(file_names, new_dir):
    """Split multi-game PGN files into one PGN file per game.

    Every game read from ``file_names`` (processed in the given order) is written
    to ``new_dir/game_<k>.pgn``. The counter ``k`` is shared across all input
    files, so output filenames never collide.

    Args:
        file_names: Paths of the PGN files to read.
        new_dir: Output directory; created (with parents) if it does not exist.

    Returns:
        The number of games written.
    """
    # Without this, the first open(..., 'w') fails when the output dir is missing.
    os.makedirs(new_dir, exist_ok=True)
    file_name_idx = 0

    for file_name in file_names:
        with open(file_name, 'r') as pgn_fh:
            while True:
                game = chess.pgn.read_game(pgn_fh)
                # python-chess signals end of file by returning None.
                if game is None:
                    break
                new_file_name = os.path.join(new_dir, f'game_{file_name_idx}.pgn')
                with open(new_file_name, 'w') as new_file:
                    print(game, file=new_file, end='\n\n')
                file_name_idx += 1
                if file_name_idx % 1000 == 0:
                    print(f'Processed {file_name_idx} games so far...')

    return file_name_idx


original_dir = '../../data/2_5_million_chess_games/original/train'
# os.listdir returns entries in arbitrary order; sort so the game_<k>.pgn
# numbering produced by reformat_games is reproducible across runs/machines.
file_names = [os.path.join(original_dir, name) for name in sorted(os.listdir(original_dir))]
reformat_dir = '../../data/2_5_million_chess_games/reformated/train'

In [18]:
# Splitting the archive into per-game files is slow, so skip it when the output
# directory already has entries (e.g. on Restart & Run All). A partially
# completed run is skipped too -- delete `reformat_dir` to force a full rebuild.
reformat_already_done = False
if os.path.isdir(reformat_dir):
    with os.scandir(reformat_dir) as entries:
        reformat_already_done = next(entries, None) is not None

if reformat_already_done:
    print(f'Skipping reformat: {reformat_dir} is not empty.')
else:
    reformat_games(file_names, reformat_dir)

In [19]:
class ConvBlock(nn.Module):
    """3x3 convolution -> BatchNorm -> ReLU.

    With ``padding=1`` (and the default stride of 1) the spatial size of the
    input is preserved; only the channel count changes to ``num_filters``.
    """

    def __init__(self, input_channels, num_filters):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=input_channels, out_channels=num_filters, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=num_filters)
        self.relu1 = nn.ReLU()

    # Implement `forward`, not `__call__`: nn.Module.__call__ runs hooks and
    # other bookkeeping before dispatching to `forward`.
    def forward(self, x):
        return self.relu1(self.bn1(self.conv1(x)))

In [20]:
class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers with an identity skip connection.

    The skip is added before the final ReLU. Channel count (``num_filters``)
    and spatial size are both preserved, so blocks can be stacked freely.
    """

    def __init__(self, num_filters):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=num_filters, out_channels=num_filters, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=num_filters)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=num_filters, out_channels=num_filters, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(num_features=num_filters)
        self.relu2 = nn.ReLU()

    # Implement `forward` (not `__call__`) so nn.Module hooks keep working.
    def forward(self, x):
        residual = x
        temp = self.relu1(self.bn1(self.conv1(x)))
        output = self.relu2(self.bn2(self.conv2(temp)) + residual)
        return output


In [21]:
class ValueHead(nn.Module):
    """Scalar value head: 1x1 conv to one plane -> BN -> ReLU -> FC(64, 256) -> ReLU -> FC(256, 1) -> tanh.

    ``fc1`` takes 64 features from a single-channel plane, so the input must have
    an 8x8 spatial size. The output has shape ``(batch, 1)`` with values in [-1, 1].
    """

    def __init__(self, input_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=input_channels, out_channels=1, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(num_features=1)
        self.relu1 = nn.ReLU()
        self.fc1 = nn.Linear(64, 256)
        self.relu2 = nn.ReLU()
        self.fc2 = nn.Linear(256, 1)
        self.tanh1 = nn.Tanh()

    # Implement `forward` (not `__call__`) so nn.Module hooks keep working.
    def forward(self, x):
        temp1 = self.relu1(self.bn1(self.conv1(x)))
        view = temp1.view(temp1.shape[0], 64)
        temp2 = self.tanh1(self.fc2(self.relu2(self.fc1(view))))
        return temp2


In [22]:
class PolicyHead(nn.Module):
    """Policy head: 1x1 conv to two planes -> BN -> ReLU -> FC(128, 4608) raw logits.

    ``fc1`` takes 128 features from two planes, so the input must have an 8x8
    spatial size. The output has shape ``(batch, 4608)`` and is unnormalized
    (no softmax here).
    """

    def __init__(self, input_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=input_channels, out_channels=2, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(num_features=2)
        self.relu1 = nn.ReLU()
        # NOTE(review): 4608 presumably matches the size of the move encoding used
        # by `encoder` -- confirm before changing either side.
        self.fc1 = nn.Linear(128, 4608)

    # Implement `forward` (not `__call__`) so nn.Module hooks keep working.
    def forward(self, x):
        temp = self.relu1(self.bn1(self.conv1(x)))
        view = temp.view(temp.shape[0], 128)
        return self.fc1(view)

In [23]:
class AlphaZero(nn.Module):
    """Conv stem + residual tower feeding a value head and a policy head.

    The stem expects 16 input channels (presumably the encoder's board planes --
    TODO confirm against `encoder`).

    Forward behavior depends on the module mode:
      * training (``self.training``): returns ``(value_loss, policy_loss)``,
        computed with MSE on the value and cross-entropy on the policy logits;
        ``value_target`` and ``policy_target`` are required.
      * eval: returns ``(value, policy_probs)``. When ``policy_mask`` is given,
        the softmax is restricted to entries where the mask is non-zero
        (weighted by the mask values); otherwise a plain softmax is returned.
    """

    def __init__(self, num_blocks, num_filters):
        super().__init__()
        self.convBlock1 = ConvBlock(16, num_filters)
        self.residualBlocks = nn.ModuleList([ResidualBlock(num_filters) for _ in range(num_blocks)])
        self.valueHead = ValueHead(num_filters)
        self.policyHead = PolicyHead(num_filters)
        self.softmax1 = nn.Softmax(dim=1)
        self.mseLoss = nn.MSELoss()
        self.crossEntropyLoss = nn.CrossEntropyLoss()

    # Implement `forward` (not `__call__`) so nn.Module hooks keep working;
    # calling `model(x, ...)` still works because nn.Module.__call__ dispatches here.
    def forward(self, x, value_target=None, policy_target=None, policy_mask=None):
        x = self.convBlock1(x)
        for block in self.residualBlocks:
            x = block(x)
        value, policy = self.valueHead(x), self.policyHead(x)

        if self.training:
            if value_target is None or policy_target is None:
                raise ValueError('value_target and policy_target are required in training mode')
            value_loss = self.mseLoss(value, value_target)
            # CrossEntropyLoss wants class indices of shape (batch,).
            policy_target = policy_target.view(policy_target.shape[0])
            policy_loss = self.crossEntropyLoss(policy, policy_target)
            return value_loss, policy_loss

        if policy_mask is None:
            return value, self.softmax1(policy)

        mask = policy_mask.view(policy_mask.shape[0], -1).to(policy.dtype)
        # Numerically stable masked softmax: subtract the max over *allowed*
        # logits before exponentiating (a raw exp overflows to inf -> nan), and
        # send disallowed logits to -inf so they contribute exactly 0.
        masked_logits = policy.masked_fill(mask == 0, float('-inf'))
        row_max = masked_logits.max(dim=1, keepdim=True).values
        policy_exp = torch.exp(masked_logits - row_max) * mask
        policy_softmax = policy_exp / policy_exp.sum(dim=1, keepdim=True)
        return value, policy_softmax


In [24]:
class AlphaZeroDataset(Dataset):
    """Map-style dataset yielding one randomly sampled training position per PGN file.

    For each item, a move index is drawn uniformly from the game's mainline. The
    moves before it are replayed, and (board, move to predict, game result) is
    encoded with ``encoder.encode_training_point``. The ply is re-sampled on every
    access, so repeated reads of the same index can return different positions.
    """

    def __init__(self, dataset_path):
        self.dataset_path = dataset_path
        # Sorted so the index -> file mapping is deterministic (os.listdir order is arbitrary).
        self.pgn_files = sorted(os.listdir(dataset_path))

    def __len__(self):
        return len(self.pgn_files)

    def __getitem__(self, i):
        pgn_file_path = os.path.join(self.dataset_path, self.pgn_files[i])
        with open(pgn_file_path) as pgn_file:
            game = chess.pgn.read_game(pgn_file)
        if game is None:
            raise ValueError(f'No game found in {pgn_file_path}')

        move_sequence = list(game.mainline_moves())
        if not move_sequence:
            # Such files should ideally be filtered out before training.
            raise ValueError(f'Game in {pgn_file_path} has no moves to learn from')

        # Index of the move to predict; any ply (including the opening position,
        # index 0) can be sampled. numpy's `high` bound is exclusive.
        target_index = np.random.randint(0, len(move_sequence))
        board = game.board()
        for move in move_sequence[:target_index]:
            board.push(move)
        next_move = move_sequence[target_index]

        winner = encoder.game_result(game.headers['Result'])
        position, policy_target, value_target, mask = encoder.encode_training_point(board, next_move, winner)

        # torch.tensor (not the legacy torch.Tensor constructor, which rejects a
        # `dtype=` keyword) builds the targets directly in the dtype the losses need.
        return {
            'position': torch.from_numpy(position),
            'policy': torch.tensor([policy_target], dtype=torch.long),
            'value': torch.tensor([value_target], dtype=torch.float32),
            'mask': torch.from_numpy(mask),
        }


In [41]:
EPOCHS = 40
RESIDUAL_BLOCKS = 20
NUM_FILTERS = 128
DATASET_PATH = reformat_dir
BATCH_SIZE = 256
# Never request more DataLoader workers than the machine has CPUs.
NUM_WORKERS = min(48, os.cpu_count() or 1)

train_loader = DataLoader(
    AlphaZeroDataset(DATASET_PATH),
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    # Page-locked host memory only helps host->GPU copies.
    pin_memory=(device.type == 'cuda'),
)

# `device` is a torch.device; a torch.device never compares equal to the str
# 'cuda', so move the model and batches with `.to(device)` instead.
az = AlphaZero(RESIDUAL_BLOCKS, NUM_FILTERS).to(device)
optimizer = optim.Adam(az.parameters())

for epoch in range(EPOCHS):
    az.train()  # AlphaZero returns (value_loss, policy_loss) only in training mode
    total_value_loss, total_policy_loss = 0.0, 0.0

    for data in train_loader:
        optimizer.zero_grad()

        # NOTE(review): network weights are float32; confirm the encoder's
        # position arrays are float32 as well.
        position = data['position'].to(device)
        value_target = data['value'].to(device)
        policy_target = data['policy'].to(device)

        value_loss, policy_loss = az(position, value_target=value_target, policy_target=policy_target)
        total_value_loss += value_loss.item()
        total_policy_loss += policy_loss.item()

        loss = value_loss + policy_loss
        loss.backward()
        optimizer.step()

    # Mean of the per-batch losses for this epoch.
    num_batches = len(train_loader)
    avg_value_loss = total_value_loss / num_batches
    avg_policy_loss = total_policy_loss / num_batches
    print(f"Epoch {epoch:03} | Value Loss: {avg_value_loss:.5f} | Policy Loss: {avg_policy_loss:.5f}")

    # Checkpoint after every epoch.
    model_filename = f'alpha_zero_net_epoch_{epoch:03}.pt'
    torch.save(az.state_dict(), model_filename)
    print(f"Model saved at epoch {epoch}\n")

Epoch 000 | Value Loss: 0.01679 | Policy Loss: 0.00806
Model saved at epoch 0
Epoch 001 | Value Loss: 0.01289 | Policy Loss: 0.00907
Model saved at epoch 1
Epoch 002 | Value Loss: 0.00103 | Policy Loss: 0.01157
Model saved at epoch 2
Epoch 003 | Value Loss: 0.01130 | Policy Loss: 0.00865
Model saved at epoch 3
Epoch 004 | Value Loss: 0.00843 | Policy Loss: 0.01550
Model saved at epoch 4
Epoch 005 | Value Loss: 0.00848 | Policy Loss: 0.00005
Model saved at epoch 5
Epoch 006 | Value Loss: 0.01554 | Policy Loss: 0.00943
Model saved at epoch 6
Epoch 007 | Value Loss: 0.00005 | Policy Loss: 0.01392
Model saved at epoch 7
Epoch 008 | Value Loss: 0.01128 | Policy Loss: 0.00005
Model saved at epoch 8
Epoch 009 | Value Loss: 0.00627 | Policy Loss: 0.01506
Model saved at epoch 9
Epoch 010 | Value Loss: 0.00574 | Policy Loss: 0.00005
Model saved at epoch 10
Epoch 011 | Value Loss: 0.00474 | Policy Loss: 0.00389
Model saved at epoch 11
Epoch 012 | Value Loss: 0.00265 | Policy Loss: 0.00028
Model s