# Chess Move Legality Classification Using The BERT Model

In this notebook I will attempt to fine-tune the BERT model to decide whether chess moves are legal. The model is given a chess position in FEN (Forsyth–Edwards Notation) together with a move written in UCI long algebraic form, such as e2e3, and must classify the move as legal or illegal.
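
To make the input format concrete, here is a minimal sketch (not part of the notebook's pipeline) that expands the piece-placement field of a FEN string and checks one necessary, but far from sufficient, condition for legality: the origin square of the move must hold a piece belonging to the side to move. The helper names `expand_placement`, `piece_at` and `moves_own_piece` are hypothetical, and the sketch assumes four-character moves such as e2e3 (promotion suffixes like e7e8q are ignored). The two positions are the first two rows of the training data shown further down.

```python
from typing import List


def expand_placement(placement: str) -> List[List[str]]:
    """Expand the FEN piece-placement field into 8 rows, rank 8 first.

    Digits in FEN encode runs of empty squares; '.' marks an empty square here.
    """
    rows = []
    for rank in placement.split("/"):
        row = []
        for ch in rank:
            if ch.isdigit():
                row.extend(["."] * int(ch))
            else:
                row.append(ch)
        rows.append(row)
    return rows


def piece_at(fen: str, square: str) -> str:
    """Return the piece letter on a square such as 'g7', or '.' if it is empty."""
    rows = expand_placement(fen.split()[0])
    file_idx = ord(square[0]) - ord("a")  # 'a' -> 0 ... 'h' -> 7
    rank_idx = 8 - int(square[1])         # rank 8 is row 0
    return rows[rank_idx][file_idx]


def moves_own_piece(fen: str, move: str) -> bool:
    """Necessary check only: the origin square holds a piece of the side to move."""
    side_to_move = fen.split()[1]         # 'w' or 'b'
    piece = piece_at(fen, move[:2])
    if piece == ".":
        return False
    # Uppercase letters are white pieces, lowercase letters are black pieces
    return piece.isupper() == (side_to_move == "w")


fen_0 = "4Q2k/p4Qrp/1p6/8/3R3p/8/PPP3PP/5RK1 b - - 0 37"
fen_1 = "8/6p1/5p1p/1p1p1P2/1P1K1kP1/P4P1P/8/8 b - - 0 38"
print(piece_at(fen_0, "g7"), moves_own_piece(fen_0, "g7d8"))  # r True
print(piece_at(fen_1, "f3"), moves_own_piece(fen_1, "f3f2"))  # P False
```

Note that g7d8 passes this check even though a rook cannot travel from g7 to d8, which is why a simple rule like this cannot replace a full move generator (or a learned classifier).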

<b>Note:</b> This notebook was run in Google Colab. To reproduce the training results, upload this notebook to Colab, connect to a GPU runtime, download the data from the Google Drive folder linked in the project's README.md, and upload it to an NLP-Chess_Data directory in your Drive. Running it locally may require some changes.
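
One of those changes for a local run is pointing main_directory at a local folder instead of the mounted Drive. As a minimal sketch, assuming only the three data file names declared in the constants cell, a hypothetical helper `missing_data_files` can report which expected files are absent before any training starts:

```python
from pathlib import Path
from typing import List

# File names from the constants cell (train_path, val_path, test_path)
DATA_FILES = ["NLP_Train.csv.gz", "NLP_Val.csv.gz", "NLP_Test.csv.gz"]


def missing_data_files(directory: str) -> List[str]:
    """Return the expected data files that are not present in `directory`."""
    root = Path(directory)
    return [name for name in DATA_FILES if not (root / name).is_file()]


# In Colab this is the mounted Drive folder; locally, any folder you choose
print(missing_data_files("/content/drive/MyDrive/NLP-Chess_Data"))
```

An empty list means all three CSV files were found.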

## Installs and Imports

In [72]:
!pip install transformers



In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
from transformers import BertForSequenceClassification
import pandas as pd
from pandas import DataFrame
import numpy as np
from torch.optim import AdamW
import warnings
from typing import Dict, Callable, Tuple
from torch.optim.optimizer import Optimizer
from torch.nn import Module
from torch.nn.modules.loss import _Loss
from google.colab import drive

In [74]:
warnings.filterwarnings("ignore")

## Data Drive Mounting and Constants Declarations

In [75]:
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [76]:
EPOCHS = 40
BATCH_SIZE = 16
PADDING = 82
main_directory = "/content/drive/MyDrive/NLP-Chess_Data"
train_path = "NLP_Train.csv.gz"
val_path = "NLP_Val.csv.gz"
test_path = "NLP_Test.csv.gz"
model_path = "BERT_Classifier.pth"

## Functions

In [77]:
def preprocess(df: pd.DataFrame, board_col: str = "prev_board",
               move_col: str = "move") -> pd.DataFrame:
    '''
    Combine the previous board state and the move into a single
    model input separated by [SEP], splitting every non-space
    character into its own space-separated token.

    Parameters:
    - df (DataFrame): DataFrame containing game data.
    - board_col (str): Name of the column containing
    the previous board state. Default is "prev_board".
    - move_col (str): Name of the column containing the
    move. Default is "move".

    Returns:
    DataFrame: Preprocessed DataFrame with a new column "model_input"
    containing combined board state and move.
    '''
    df["model_input"] = df[board_col] + " [SEP] " + df[move_col]
    new_input = []
    for input_text in df["model_input"]:
        i = 0
        temp_str = ""
        while i < len(input_text):
            if input_text[i] != " " and input_text[i] != "[":
                temp_str += input_text[i] + " "
            elif input_text[i] == "[":
                temp_str += "[SEP] "
                i += 4
            i += 1
        new_input.append(temp_str)
    df["model_input"] = new_input
    return df
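
The loop in preprocess() turns each board/move pair into a sequence of single characters separated by spaces, keeping [SEP] intact; multi-digit move counters such as 37 become "3 7". Presumably this is done so that the bert-base-cased WordPiece tokenizer emits one token per FEN character instead of splitting strings like 4Q2k into arbitrary subwords (the notebook does not state the reason). A shorter equivalent for a single pair, written as a hypothetical helper `spaced_input`, is:

```python
def spaced_input(board: str, move: str) -> str:
    """Character-level model input equivalent to the loop in preprocess()."""
    # Drop the spaces that separate FEN fields and keep every other character
    board_chars = [ch for ch in board if ch != " "]
    move_chars = [ch for ch in move if ch != " "]
    # preprocess() leaves a trailing space, reproduced here for an exact match
    return " ".join(board_chars + ["[SEP]"] + move_chars) + " "


text = spaced_input("4Q2k/p4Qrp/1p6/8/3R3p/8/PPP3PP/5RK1 b - - 0 37", "g7d8")
print(repr(text))
```

On a DataFrame it could be applied row by row with df.apply(lambda row: spaced_input(row[board_col], row[move_col]), axis=1).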

In [78]:
def transform_BERT_input(BERT_input: Dict[str, torch.Tensor],
                         device: torch.device) -> Dict[str, torch.Tensor]:
    '''
    Reshape BERT input tensors from (batch, 1, seq_len) to
    (batch, seq_len) and move them to the specified device.

    Parameters:
    - BERT_input (dict): Dictionary containing input tensors for BERT
    model (input_ids, token_type_ids, attention_mask).
    - device (torch.device): Device to which tensors should be
    moved.

    Returns:
    dict: Transformed BERT input tensors on the specified
    device.
    '''
    batch_size = BERT_input["input_ids"].shape[0]
    padding_size = BERT_input["input_ids"].shape[-1]
    # Drop the singleton dimension added by return_tensors='pt':
    # (batch, 1, seq_len) -> (batch, seq_len)
    for key in ("input_ids", "token_type_ids", "attention_mask"):
        BERT_input[key] = BERT_input[key].view(batch_size, padding_size).to(device)
    return BERT_input
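
transform_BERT_input() is needed because each dataset item comes from the tokenizer with return_tensors='pt', so every tensor already carries a leading dimension of size 1. The sketch below uses dummy zero/one tensors in place of real tokenizer output to show the shapes that DataLoader's default collation produces and how view() removes the extra dimension (squeeze(1) would work equally well):

```python
import torch
from torch.utils.data import default_collate

SEQ_LEN = 82  # mirrors PADDING in the constants cell
BATCH = 4     # small illustrative batch size

# Stand-ins for tokenizer output with return_tensors='pt': shape (1, SEQ_LEN)
items = [
    {
        "input_ids": torch.zeros(1, SEQ_LEN, dtype=torch.long),
        "token_type_ids": torch.zeros(1, SEQ_LEN, dtype=torch.long),
        "attention_mask": torch.ones(1, SEQ_LEN, dtype=torch.long),
    }
    for _ in range(BATCH)
]

# Default collation stacks the per-item tensors into (BATCH, 1, SEQ_LEN)
batch = default_collate(items)
print(batch["input_ids"].shape)  # torch.Size([4, 1, 82])

# BERT expects (BATCH, SEQ_LEN), so the middle dimension is removed
flat = {key: value.view(BATCH, SEQ_LEN) for key, value in batch.items()}
print(flat["input_ids"].shape)   # torch.Size([4, 82])
```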

In [79]:
class Chess_Dataset(Dataset):
    '''
    Custom dataset class for chess data.

    Parameters:
    - X (DataFrame): Input data.
    - y (array-like): Target labels.
    - tokenizer (callable): Tokenizer function to encode input text.
    - padding (int): Maximum length of the input sequence after padding.
    '''
    def __init__(self, X: DataFrame, y: np.ndarray,
                 tokenizer: Callable, padding: int):
        self.tokenizer = tokenizer
        self.X = X.values.reshape(-1)
        self.y = y.astype(np.int8)
        self.padding = padding

    def __len__(self) -> int:
        return len(self.X)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        '''
        Get an item from the dataset.

        Parameters:
        - idx (int): Index of the item to retrieve.

        Returns:
        tuple: Tuple containing the encoded input and the label tensor.
        '''
        input_text = self.X[idx]
        label = self.y[idx]
        # Use the tokenizer passed to the constructor, not a global
        encoded_input = self.tokenizer(input_text, padding='max_length',
                                       max_length=self.padding,
                                       return_tensors='pt')
        label_tensor = torch.tensor(label)
        return encoded_input, label_tensor
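
The tokenizer call pads every example to max_length=self.padding but does not truncate, so PADDING (82 in the constants cell) must be at least the length of the longest tokenized input; otherwise the tensors in a batch end up with different lengths and cannot be stacked. The notebook does not say how 82 was chosen. One way to derive such a value is sketched below with a hypothetical helper `required_max_length`, under the assumption (not verified here) that every space-separated character maps to exactly one bert-base-cased token:

```python
from typing import Iterable


def required_max_length(spaced_inputs: Iterable[str]) -> int:
    """Smallest max_length that avoids truncation, assuming one token per item."""
    # Each space-separated item (a character or [SEP]) is assumed to be one token;
    # the tokenizer adds [CLS] at the start and [SEP] at the end, hence + 2.
    return max(len(text.split()) for text in spaced_inputs) + 2


examples = [
    "8 / 8 / 8 / 8 / 8 / 8 / 8 / K 6 k w - - 0 1 [SEP] a 1 a 2 ",
    "8 / 8 w - - 0 1 [SEP] e 2 e 4 ",
]
print(required_max_length(examples))
```

The one-token-per-character assumption can be checked against the real tokenizer with len(tokenizer(text)["input_ids"]) on a few inputs.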

In [80]:
def train(epoch: int, model: Module,
          optimizer: Optimizer, train_loader: DataLoader,
          criterion: _Loss, num_examples: int,
          device: torch.device) -> None:
    '''
    Train the model for one epoch.

    Parameters:
    - epoch (int): Current epoch number.
    - model (torch.nn.Module): Model to train.
    - optimizer (torch.optim.Optimizer): Optimizer for
    updating model parameters.
    - train_loader (torch.utils.data.DataLoader): DataLoader
    for training data.
    - criterion (torch.nn.Module): Loss function.
    - num_examples (int): Total number of examples in the
    dataset.
    - device (torch.device): Device to use for training.

    Returns:
    None
    '''
    model.train()
    total_loss = 0.0
    total_correct = 0
    for batch in train_loader:
        optimizer.zero_grad()
        bert_inputs, labels = batch
        # Reshape to (batch, seq_len) and move every input tensor to the device
        bert_inputs = transform_BERT_input(bert_inputs, device)
        # Tensor.to() is not in-place, so the result must be reassigned
        labels = labels.float().to(device)
        # One logit per example; squeeze (batch, 1) -> (batch,)
        logits = model(**bert_inputs).logits.squeeze(-1)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # A probability >= 0.5 counts as a "legal" prediction
        predictions = (torch.sigmoid(logits.detach()) >= 0.5).long()
        total_correct += (predictions == labels.long()).sum().item()
    avg_loss = round(total_loss / len(train_loader), 4)
    accuracy = round(total_correct / num_examples * 100, 4)
    print(f"Epoch: {epoch} Training loss: {avg_loss} "
          f"Training accuracy: {accuracy}%")
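
In train() (and evaluate()) a prediction counts as "legal" when the sigmoid of the single logit is at least 0.5. Because the sigmoid is monotonic and sigmoid(0) = 0.5, that is the same as checking whether the raw logit is non-negative. A pure-Python sketch of the accuracy computation, with made-up logits and labels for illustration:

```python
import math
from typing import Sequence


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def accuracy_from_logits(logits: Sequence[float], labels: Sequence[int]) -> float:
    """Fraction of examples whose thresholded prediction matches the 0/1 label."""
    # sigmoid(z) >= 0.5 holds exactly when z >= 0, so the sigmoid can be skipped
    predictions = [1 if z >= 0.0 else 0 for z in logits]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return correct / len(labels)


logits = [-1.2, 0.0, 2.5, -0.1]
labels = [0, 1, 1, 1]
# Same decisions whether thresholding probabilities at 0.5 or logits at 0
assert [int(sigmoid(z) >= 0.5) for z in logits] == [int(z >= 0) for z in logits]
print(accuracy_from_logits(logits, labels))  # 0.75
```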

In [81]:
def evaluate(epoch: int, model: Module,
             test_loader: DataLoader, criterion: _Loss,
             num_examples: int, device: torch.device,
             loss_type: str = "Validation") -> None:
    '''
    Evaluate the model on validation or test data.

    Parameters:
    - epoch (int): Current epoch number.
    - model (torch.nn.Module): Model to evaluate.
    - test_loader (torch.utils.data.DataLoader): DataLoader
    for test data.
    - criterion (torch.nn.Module): Loss function.
    - num_examples (int): Total number of
    examples in the dataset.
    - device (torch.device): Device to use for evaluation.
    - loss_type (str): Type of loss being evaluated
    (default is "Validation").

    Returns:
    None
    '''
    model.eval()
    total_loss = 0.0
    total_correct = 0
    with torch.no_grad():
        for batch in test_loader:
            bert_inputs, labels = batch
            # Reshape to (batch, seq_len) and move every input tensor to the device
            bert_inputs = transform_BERT_input(bert_inputs, device)
            labels = labels.float().to(device)
            # One logit per example; squeeze (batch, 1) -> (batch,)
            logits = model(**bert_inputs).logits.squeeze(-1)
            loss = criterion(logits, labels)
            total_loss += loss.item()
            # A probability >= 0.5 counts as a "legal" prediction
            predictions = (torch.sigmoid(logits) >= 0.5).long()
            total_correct += (predictions == labels.long()).sum().item()
    avg_loss = round(total_loss / len(test_loader), 4)
    accuracy = round(total_correct / num_examples * 100, 4)
    print(f"Epoch {epoch}: {loss_type} loss: {avg_loss} "
          f"{loss_type} accuracy: {accuracy}%")
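
Both train() and evaluate() report total_loss / len(loader), the mean of per-batch mean losses. With 500 validation or test examples and BATCH_SIZE = 16, the loader yields 31 batches of 16 and one batch of 4, so the small last batch is weighted as heavily as a full one. The difference is usually minor, but if a per-example average is wanted, each batch loss can be weighted by its size, as in this sketch with hypothetical numbers:

```python
from typing import Sequence


def mean_of_batch_means(batch_losses: Sequence[float]) -> float:
    # What train()/evaluate() report: total_loss / len(loader)
    return sum(batch_losses) / len(batch_losses)


def per_example_mean(batch_losses: Sequence[float], batch_sizes: Sequence[int]) -> float:
    # Weight each batch's mean loss by the number of examples it contains
    total = sum(loss * size for loss, size in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)


losses = [0.2, 0.2, 0.8]  # hypothetical per-batch mean losses
sizes = [16, 16, 4]       # last batch is smaller, as with 500 examples and BATCH_SIZE = 16
print(round(mean_of_batch_means(losses), 4), round(per_example_mean(losses, sizes), 4))
```

In the loops above this would mean accumulating loss.item() * labels.shape[0] and dividing by num_examples.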

## Implementation

### Data Reading

In [82]:
train_data = pd.read_csv(f"{main_directory}/{train_path}", compression="gzip")
val_data = pd.read_csv(f"{main_directory}/{val_path}", compression="gzip")
test_data = pd.read_csv(f"{main_directory}/{test_path}", compression="gzip")

In [83]:
train_data = train_data.head(10000)
val_data = val_data.head(500)
test_data = test_data.head(500)

In [84]:
feature_cols = ["move", "prev_board"]
# Select each split's own columns; .copy() avoids chained-assignment warnings in preprocess()
X_train = train_data[feature_cols].copy()
y_train = train_data["legal"].values.astype(bool)
X_val = val_data[feature_cols].copy()
y_val = val_data["legal"].values.astype(bool)
X_test = test_data[feature_cols].copy()
y_test = test_data["legal"].values.astype(bool)
del train_data
del val_data
del test_data
X_train.head()

| | move | prev_board |
|---|---|---|
| 0 | g7d8 | 4Q2k/p4Qrp/1p6/8/3R3p/8/PPP3PP/5RK1 b - - 0 37 |
| 1 | f3f2 | 8/6p1/5p1p/1p1p1P2/1P1K1kP1/P4P1P/8/8 b - - 0 38 |
| 2 | a6g1 | 4rb2/5kp1/pp1p1pp1/2pP4/P1P3R1/1PB3PP/5PK1/8 w... |
| 3 | c8b8 | 2r2k2/p4p2/2B2p2/P2Pp3/4PnR1/4KP1p/5P1P/8 b - ... |
| 4 | e6f5 | 6k1/6pp/p3p3/3bKR2/1p6/1P2R3/P7/6r1 b - - 0 33 |


In [85]:
X_train = preprocess(X_train)
X_val = preprocess(X_val)
X_test = preprocess(X_test)
X_train = X_train.drop(columns=["move", "prev_board"])
X_val = X_val.drop(columns=["move", "prev_board"])
X_test = X_test.drop(columns=["move", "prev_board"])
X_train.head()

| | model_input |
|---|---|
| 0 | 4 Q 2 k / p 4 Q r p / 1 p 6 / 8 / 3 R 3 p / 8 ... |
| 1 | 8 / 6 p 1 / 5 p 1 p / 1 p 1 p 1 P 2 / 1 P 1 K ... |
| 2 | 4 r b 2 / 5 k p 1 / p p 1 p 1 p p 1 / 2 p P 4 ... |
| 3 | 2 r 2 k 2 / p 4 p 2 / 2 B 2 p 2 / P 2 P p 3 / ... |
| 4 | 6 k 1 / 6 p p / p 3 p 3 / 3 b K R 2 / 1 p 6 / ... |


Checking that the legal/illegal class split is roughly balanced in each set:

In [86]:
unique_values, counts = np.unique(y_train, return_counts=True)
for value, count in zip(unique_values, counts):
    print(f"{value}: {count} occurrences")

False: 4830 occurrences
True: 5170 occurrences


In [87]:
unique_values, counts = np.unique(y_val, return_counts=True)
for value, count in zip(unique_values, counts):
    print(f"{value}: {count} occurrences")

False: 225 occurrences
True: 275 occurrences


In [88]:
unique_values, counts = np.unique(y_test, return_counts=True)
for value, count in zip(unique_values, counts):
    print(f"{value}: {count} occurrences")

False: 236 occurrences
True: 264 occurrences
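
The counts show that the classes are close to balanced, with a slight majority of legal moves in every split. A quick way to make that concrete is the accuracy of a trivial classifier that always predicts the majority class, which gives a reference point for the accuracies reported during training and testing:

```python
from typing import Dict


def majority_baseline(counts: Dict[bool, int]) -> float:
    """Accuracy of always predicting the most frequent class."""
    return max(counts.values()) / sum(counts.values())


# Class counts copied from the three outputs above
splits = {
    "train": {False: 4830, True: 5170},
    "val": {False: 225, True: 275},
    "test": {False: 236, True: 264},
}
for name, counts in splits.items():
    print(f"{name}: {majority_baseline(counts):.1%}")
```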


### Model Creation

In [89]:
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
train_dataset = Chess_Dataset(X_train, y_train, tokenizer, PADDING)
val_dataset = Chess_Dataset(X_val, y_val, tokenizer, PADDING)
test_dataset = Chess_Dataset(X_test, y_test, tokenizer, PADDING)

In [90]:
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [95]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Available Device: {device}")
model = BertForSequenceClassification.from_pretrained(
    "bert-base-cased", num_labels=1,
    output_attentions=False,
    output_hidden_states=False)
model.to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = AdamW(model.parameters(), lr=1e-5)

Available Device: cuda


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
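
The cell above gives BertForSequenceClassification a single output (num_labels=1) and trains it with nn.BCEWithLogitsLoss computed outside the model. A common alternative is num_labels=2 with cross-entropy. Mathematically the two describe the same model: a single logit z gives the same per-example loss as the two-class logits [0, z], because softmax([0, z])[1] = e^z / (1 + e^z) = sigmoid(z). A pure-Python check of that identity:

```python
import math
from typing import Sequence


def bce_with_logit(z: float, y: int) -> float:
    """Per-example binary cross-entropy on sigmoid(z), as BCEWithLogitsLoss defines it."""
    p = 1.0 / (1.0 + math.exp(-z))
    return -(y * math.log(p) + (1 - y) * math.log(1.0 - p))


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Per-example softmax cross-entropy, as CrossEntropyLoss defines it."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_sum_exp - logits[target]


# A single logit z behaves like the two-class logits [0, z]
for z in (-2.0, 0.3, 1.5):
    for y in (0, 1):
        print(z, y, round(bce_with_logit(z, y), 6), round(cross_entropy([0.0, z], y), 6))
```

Computing the loss outside the model also matters with num_labels=1: if labels were passed to BertForSequenceClassification in that configuration, Hugging Face Transformers would treat the task as regression and apply a mean-squared-error loss instead.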


### Training

In [96]:
num_training_examples = len(X_train)
num_val_examples = len(X_val)
num_test_examples = len(X_test)

In [97]:
for epoch in range(EPOCHS):
    train(epoch, model, optimizer, train_loader,
          criterion, num_training_examples, device)
    evaluate(epoch, model, val_loader, criterion,
             num_val_examples, device)

Epoch: 0 Training loss: 0.6264 Training accuracy: 63.36%
Epoch 0: Validation loss: 0.465 Validation accuracy: 78.8%
Epoch: 1 Training loss: 0.4657 Training accuracy: 78.84%
Epoch 1: Validation loss: 0.3792 Validation accuracy: 85.6%
Epoch: 2 Training loss: 0.3874 Training accuracy: 83.94%
Epoch 2: Validation loss: 0.3373 Validation accuracy: 87.6%
Epoch: 3 Training loss: 0.3501 Training accuracy: 85.73%
Epoch 3: Validation loss: 0.3285 Validation accuracy: 86.8%
Epoch: 4 Training loss: 0.3203 Training accuracy: 86.67%
Epoch 4: Validation loss: 0.328 Validation accuracy: 87.2%
Epoch: 5 Training loss: 0.2965 Training accuracy: 87.89%
Epoch 5: Validation loss: 0.3181 Validation accuracy: 88.2%
Epoch: 6 Training loss: 0.2751 Training accuracy: 88.68%
Epoch 6: Validation loss: 0.3028 Validation accuracy: 88.8%
Epoch: 7 Training loss: 0.2564 Training accuracy: 89.04%
Epoch 7: Validation loss: 0.3215 Validation accuracy: 90.0%
Epoch: 8 Training loss: 0.2368 Training accuracy: 90.04%
Epoch 8: 

In [98]:
torch.save(model, f"{main_directory}/{model_path}")

### Evaluation

In [99]:
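# The whole model object was pickled, so weights_only=False is needed
# on PyTorch 2.6+, where torch.load defaults to weights_only=True
model = torch.load(f"{main_directory}/{model_path}",
                   map_location=device, weights_only=False)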

In [100]:
evaluate(0, model, test_loader, criterion,
         num_test_examples, device, loss_type="Test")

Epoch Test: Test loss: 0.7671 Test accuracy: 85.0%
