## Training a specialised model (CNN+RNN) for Othello/Reversi

This notebook presents a way to estimate the next move to play in a game of Othello using Supervised Learning. The datasets come from the [Fédération Française d'Othello](https://www.ffothello.org/informatique/la-base-wthor/). 

### Data Handling

In [2]:
import struct   # for reading the .wtb files
import os       # for file/path/directories,...  handling

#### Extracting data from the WThor database

Some functions were taken or modified from the [dnnothello repo](https://github.com/wjaskowski/dnnothello/blob/master/games/othello_data.py)

The header of a .wthor file is 16 bytes long and contains the following fields:
- 1 byte: century of the file's creation
- 1 byte: year of the file's creation
- 1 byte: month of the file's creation
- 1 byte: day of the file's creation
- 4 bytes (int): number of games in the file ($\leq$ 2 147 483 648)
- 2 bytes (short): 0 here (but for other type of files : number of players, tournaments, or number of empty squares in the board ($\leq$ 65 535))
- 1 byte: year of the games
- 1 byte: size of the board {0: 8x8, 8: 8x8, 10: 10x10}
- 1 byte: 0 here the games type (1 if "solitaire", 0 otherwise)
- 1 byte: the games depth
- 1 byte: reserved

The games are stored in the file in the following format:
- 2 bytes (short): label of the tournament
- 2 bytes (short): id number of the black player
- 2 bytes (short): id number of the white player
- 1 byte: true score of the black player
- 1 byte: theoretic score of the black player

And then each move is stored as a 60 byte long record (list of moves).

In [3]:
BOARD_SIZE = 8

HEADER_LENGTH = 16
HEADER_FORMAT = "<BBBBIHHBBBB"  # Byte, Byte, Byte, Byte, Int, Short, Short, Byte, Byte, Byte, (Reserved) Byte

GAME_INFO_LENGTH = 8    
GAME_INFO_FORMAT = "<HHHBB"     # Short, Short, Short, Byte, Byte

MOVES_LENGTH = 60
MOVES_FORMAT = "<" + "B"*MOVES_LENGTH

POSSIBLE_SIZE = [0, 8]

def read_all_wtb_files(directory):
    """Generator to read all .wtb files in a directory."""
    for file_name in os.listdir(directory):
        if file_name.endswith(".wtb"):
            yield from read_wtb(os.path.join(directory, file_name))

def read_wtb(file_path):
    """Generator to read a .wtb file and yield game information and played moves."""
    with open(file_path, 'rb') as f:
        header = struct.unpack(HEADER_FORMAT, f.read(HEADER_LENGTH))
        assert header[7] in POSSIBLE_SIZE   # Check the board size
        
        for _ in range(header[4]):  # Number of games
            game_info = struct.unpack(GAME_INFO_FORMAT, f.read(GAME_INFO_LENGTH))
            played_moves = struct.unpack(MOVES_FORMAT, f.read(MOVES_LENGTH))
            yield game_info[3], played_moves    # Black player true score, moves

In [4]:
reader = read_wtb('data/raw/WTH_2001.wtb')
print(next(reader))

full_reader = read_all_wtb_files('data/raw')
print(next(full_reader))

(11, (56, 64, 53, 46, 35, 63, 34, 66, 65, 74, 37, 43, 57, 33, 76, 24, 75, 26, 83, 36, 73, 38, 25, 16, 14, 15, 17, 47, 13, 68, 48, 58, 52, 28, 67, 23, 12, 61, 32, 42, 31, 86, 51, 41, 27, 84, 85, 82, 71, 18, 72, 11, 21, 22, 62, 81, 77, 78, 88, 87))
(34, (56, 64, 33, 36, 46, 34, 43, 67, 66, 65, 53, 63, 74, 84, 75, 57, 35, 24, 47, 38, 76, 52, 58, 37, 42, 62, 83, 82, 73, 85, 86, 87, 48, 68, 25, 14, 13, 31, 61, 51, 15, 26, 77, 23, 41, 88, 21, 72, 16, 32, 12, 22, 78, 71, 81, 11, 17, 27, 28, 18))


In [5]:
from utils.bitwise_func import set_state, cell_count
from node import Node, replay
from game import init_bit_board

In [6]:
def decode_game(moves):
    """Decode moves played in a game from the board representation to the bitboard representation."""
    enemy, own = init_bit_board(BOARD_SIZE)
    node = Node(None, own, enemy, -1, BOARD_SIZE, -1)
    for move in moves:
        if move == 0:
            break
        node.expand() # Generate the possible moves
        x, y = decode_move(move)
        move = set_state(0, x, y, BOARD_SIZE)
        
        if move not in node.moves: # then it means it is a pass and the other player plays, or it is the end of the game
            node.invert()
            node.expand()
            if move in node.moves:
                node = node.set_child(move)
            else:
                break
        else:
            node = node.set_child(move)
    return node

            
def decode_move(move):
    """Decode a move from the board representation to the (x, y) representation."""
    return move // 10 - 1, move % 10 - 1

In [7]:
true_score, game_moves = next(full_reader)
print(f"Expected score: {true_score}")
first_game = decode_game(game_moves)
# replay(first_game, BOARD_SIZE)
print(f"Score : {cell_count(first_game.own_pieces), cell_count(first_game.enemy_pieces)}") if first_game.turn == -1 else print(f"Score : {cell_count(first_game.enemy_pieces), cell_count(first_game.own_pieces)}")
while true_score in [cell_count(first_game.own_pieces), cell_count(first_game.enemy_pieces)]:
    true_score, game_moves = next(full_reader)
    first_game = decode_game(game_moves)
print(f"Expected score: {true_score}")
print(f"Score : {cell_count(first_game.own_pieces), cell_count(first_game.enemy_pieces)}") if first_game.turn == -1 else print(f"Score : {cell_count(first_game.enemy_pieces), cell_count(first_game.own_pieces)}")
# replay(first_game, BOARD_SIZE)

Expected score: 52
Score : (52, 12)
Expected score: 64
Score : (63, 0)


True score is the number of pieces of the black player + the empty ones if he won, or only the number of pieces if he lost. The score is always given from the black player's perspective.

#### Data Preprocessing

In [8]:
import numpy as np
import pickle

def node_to_board(node: Node) -> np.ndarray:
    """Convert a Node object to a numpy array."""
    board = np.zeros((BOARD_SIZE, BOARD_SIZE))
    for i in range(BOARD_SIZE):
        for j in range(BOARD_SIZE):
            if node.own_pieces & (1 << (i * BOARD_SIZE + j)):
                board[i, j] = node.turn
            elif node.enemy_pieces & (1 << (i * BOARD_SIZE + j)):
                board[i, j] = -node.turn
    return board

def bitboardMove_to_x_y(move: int) -> (int, int):
    """Convert a move from the bitboard representation to the (x, y) representation."""
    for i in range(BOARD_SIZE):
        for j in range(BOARD_SIZE):
            if (move & (1 << (i * BOARD_SIZE + j))) != 0:
                return i, j
    return -1, -1

def find_move(current_node: Node, next_node: Node) -> (int, int):
    """Find the move that was played between two nodes."""
    for move in current_node.moves:
        if current_node.set_child(move) == next_node:
            # convert the binary move to (x, y) representation
            return bitboardMove_to_x_y(move)
    return -1, -1

def dump_data(directory, output_file_black, output_file_white, batch_size=1000):
    data_black = []
    data_white = []
    batch_count_black = 0
    batch_count_white = 0
    i = 0
    data_reader = read_all_wtb_files(directory)
    

    for (score, moves) in data_reader:
        game = decode_game(moves)
        move_list = replay(game, BOARD_SIZE, False)
        for j in range(len(move_list) - 1):
            current_node = move_list[j]
            next_node = move_list[j + 1]
            current_board = node_to_board(current_node)
            next_move = find_move(current_node, next_node)
            
            if next_move == (-1, -1):
                continue
            
            # Append Black moves to the data of the black player if he won, and to the data of the white player otherwise
            if score >= 32:
                data_black.append((current_board, next_move))
            else:
                data_white.append((current_board, next_move))

            if len(data_black) == batch_size:
                with open(f'{output_file_black}_batch_{batch_count_black}.pkl', 'wb') as f:
                    pickle.dump(data_black, f)
                print(f"Dumped batch {batch_count_black} to file after processing {i + 1} games, total {len(data_black)} samples.")
                data_black = []  # Reset data for the next batch
                batch_count_black += 1
            if len(data_white) == batch_size:
                with open(f'{output_file_white}_batch_{batch_count_white}.pkl', 'wb') as f:
                    pickle.dump(data_white, f)
                print(f"Dumped batch {batch_count_white} to file after processing {i + 1} games, total {len(data_white)} samples.")
                data_white = []  # Reset data for the next batch
                batch_count_white += 1
        i += 1

    # Dump any remaining data not fitting the batch size
    if data_black:
        with open(f'{output_file_black}_batch_{batch_count_black}.pkl', 'wb') as f:
            pickle.dump(data_black, f)
        print(f"Dumped final batch {batch_count_black} to file, total {len(data_black)} samples.")
    if data_white:
        with open(f'{output_file_white}_batch_{batch_count_white}.pkl', 'wb') as f:
            pickle.dump(data_white, f)
        print(f"Dumped final batch {batch_count_white} to file, total {len(data_white)} samples.")

In [9]:
# dump_data('data/raw', 'data/black/data', 'data/white/data', 1000)

In [10]:
black_data_size = 3847 * 1000
white_data_size = 3682 * 1000
print(f"Black data size: {black_data_size}")
print(f"White data size: {white_data_size}")

Black data size: 3847000
White data size: 3682000


In [11]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import glob
import pickle


# Dataset class for Othello (credits to https://github.com/zatomos for coming up with this absolute masterpiece of a name)
class Othelload(Dataset):
    def __init__(self, file_list, nb_samples_by_file):
        self.file_list = file_list
        self.nb_samples_by_file = nb_samples_by_file
        
    def __len__(self):
        return len(self.file_list) * self.nb_samples_by_file
    
    def __getitem__(self, index):
        try:
            # Get the corresponding file
            file_index = index // self.nb_samples_by_file
            file = self.file_list[file_index]
            # Get the corresponding sample
            sample_index = index % self.nb_samples_by_file
            with open(file, 'rb') as f:
                data = pickle.load(f)
            board, move = data[sample_index]
            # convert the board to a 32 float tensor. Convert the move (x,y) to one hot encoding
            board = torch.tensor(board, dtype=torch.float32).unsqueeze(0)
            move = torch.tensor(move[0] * 8 + move[1], dtype=torch.long)
            return board, move
        except Exception as e:
            print(f"Error while generating pair (sample, label) at index {index}\n{e}")
            raise e
    
# Create the dataset
file_list_black = glob.glob('data/black/data*.pkl')
file_list_white = glob.glob('data/white/data*.pkl')
print(len(file_list_black), len(file_list_white))
othelload_black = Othelload(file_list_black, nb_samples_by_file=1000)
othelload_white = Othelload(file_list_white, nb_samples_by_file=1000)

# Create the dataloaders
dataloader_black = DataLoader(othelload_black, batch_size=8, shuffle=True)
dataloader_white = DataLoader(othelload_white, batch_size=8, shuffle=True)

3847 3682


#### Model Definition

In [12]:
# Imports
import torch.nn as nn
import torch.nn.functional as F

# Model definition : Let's start simple, just a CNN with 8 conv layers (+BN and Relu) and 2 FC layers
# conv64 → conv64 → conv128 → conv128 → conv256 → conv256 → conv256 → conv256 → f c128 → f c60
class OthelloNet(nn.Module):
    def __init__(self):
        super(OthelloNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(128)
        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2d(256)
        self.conv6 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.bn6 = nn.BatchNorm2d(256)
        self.conv7 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.bn7 = nn.BatchNorm2d(256)
        self.conv8 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.bn8 = nn.BatchNorm2d(256)
        self.fc1 = nn.Linear(256 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 64)
        
    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = F.relu(self.bn5(self.conv5(x)))
        x = F.relu(self.bn6(self.conv6(x)))
        x = F.relu(self.bn7(self.conv7(x)))
        x = F.relu(self.bn8(self.conv8(x)))
        x = x.view(-1, 256 * 8 * 8)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [13]:
from torchsummary import summary
summary(OthelloNet().cuda(), (1, 8, 8), batch_size=8)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1              [8, 64, 8, 8]             640
       BatchNorm2d-2              [8, 64, 8, 8]             128
            Conv2d-3              [8, 64, 8, 8]          36,928
       BatchNorm2d-4              [8, 64, 8, 8]             128
            Conv2d-5             [8, 128, 8, 8]          73,856
       BatchNorm2d-6             [8, 128, 8, 8]             256
            Conv2d-7             [8, 128, 8, 8]         147,584
       BatchNorm2d-8             [8, 128, 8, 8]             256
            Conv2d-9             [8, 256, 8, 8]         295,168
      BatchNorm2d-10             [8, 256, 8, 8]             512
           Conv2d-11             [8, 256, 8, 8]         590,080
      BatchNorm2d-12             [8, 256, 8, 8]             512
           Conv2d-13             [8, 256, 8, 8]         590,080
      BatchNorm2d-14             [8, 25

In [14]:
# Define criterion, optimizer and scheduler
import torch.optim as optim

# Define the model
model = OthelloNet()

# Define the criterion
criterion = nn.CrossEntropyLoss()

# Define the optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Define the scheduler : decrease the learning rate by a factor of 10 every 10 epochs
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

In [15]:
# Create Training, Validation, and Test sets
from torch.utils.data.sampler import SubsetRandomSampler

# Define the percentage of samples for each set
validation_split = 0.1
test_split = 0.15

# Define the indices
dataset_size = len(othelload_black)
indices = list(range(dataset_size))
split1 = int(np.floor(validation_split * dataset_size))
split2 = int(np.floor(test_split * dataset_size))
np.random.shuffle(indices)
train_indices, val_indices, test_indices = indices[split2:], indices[:split1], indices[split1:split2]

# Define the samplers
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)
test_sampler = SubsetRandomSampler(test_indices)

# Define the dataloaders
train_loader = DataLoader(othelload_black, batch_size=8, sampler=train_sampler)
valid_loader = DataLoader(othelload_black, batch_size=8, sampler=valid_sampler)
test_loader = DataLoader(othelload_black, batch_size=8, sampler=test_sampler)

# Display the number of samples in each set
print(f"Number of samples in the training set: {len(train_indices)}")
print(f"Number of samples in the validation set: {len(val_indices)}")
print(f"Number of samples in the test set: {len(test_indices)}")

Number of samples in the training set: 3269950
Number of samples in the validation set: 384700
Number of samples in the test set: 192350


In [16]:
# Create a class to save best model and track the loss
class EarlyStopping:
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func
    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
    def save_checkpoint(self, val_loss, model):
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [17]:
import torch
from tqdm import tqdm

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Move the model to the chosen device
model = model.to(device)
model.train()

n_epochs = 100
early_stopping = EarlyStopping(patience=10, verbose=True, path='data/model.pth')

for epoch in range(n_epochs):
    running_loss = 0.0
    # Adding tqdm for progress tracking in training loop
    train_loader_pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{n_epochs} Training")
    for i, (boards, moves) in enumerate(train_loader_pbar):
        boards = boards.to(device)  # Move boards to GPU
        moves = moves.to(device)    # Move moves to GPU

        if any(moves < 0) or any(moves >= 64):
            print(f"Invalid move: {moves}")
            continue

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward + backward + optimize
        outputs = model(boards)
        loss = criterion(outputs, moves)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # Update tqdm progress bar
        train_loader_pbar.set_postfix(loss=running_loss / (i + 1))

    # Validation loss
    model.eval()
    val_loss = 0.0
    # Adding tqdm for progress tracking in validation loop
    valid_loader_pbar = tqdm(valid_loader, desc=f"Epoch {epoch + 1}/{n_epochs} Validation")
    with torch.no_grad():
        for i, (boards, moves) in enumerate(valid_loader_pbar):
            boards = boards.to(device)  # Move boards to GPU
            moves = moves.to(device)    # Move moves to GPU

            if any(moves < 0) or any(moves >= 64):
                print(f"Invalid move: {moves}")
                continue

            outputs = model(boards)
            loss = criterion(outputs, moves)
            val_loss += loss.item()
            # Update tqdm progress bar
            valid_loader_pbar.set_postfix(val_loss=val_loss / (i + 1))

    val_loss /= len(valid_loader)
    print(f"Epoch {epoch + 1}, Training loss: {running_loss / len(train_loader)}, Validation loss: {val_loss}")

    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping")
        break

    model.train()  # Ensure model is in training mode after evaluation

Using device: cuda


Epoch 1/100 Training:   6%|▌         | 23887/408744 [16:02<4:18:34, 24.81it/s, loss=4.08] 


KeyboardInterrupt: 

In [None]:
# Test the model
model.load_state_dict(torch.load('data/model.pth'))
model.eval()
test_loss = 0.0
with torch.no_grad():
    for i, (boards, moves) in enumerate(test_loader):
        if any(moves < 0) or any(moves >= 64):
            print(f"Invalid move: {moves}")
            continue
        outputs = model(boards)
        loss = criterion(outputs, moves)
        test_loss += loss.item()
    test_loss /= len(test_loader)
    print(f"Test loss: {test_loss}")