In [None]:
import os
import chess
import chess.pgn
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from classes import SimpleDataset

In [None]:
# Input PGN path and destination for the encoded tensors; both are left blank
# here and must be filled in before the processing cell is run.
file_path = ""
output_file = ""

# Piece types with a None placeholder at position 0, so chess.PAWN sits at
# list index 1 and chess.KING at list index 6.
PIECE_TYPES = [
    None,
    chess.PAWN,
    chess.KNIGHT,
    chess.BISHOP,
    chess.ROOK,
    chess.QUEEN,
    chess.KING,
]

# The two piece colors.
PIECE_COLORS = [
    chess.WHITE,
    chess.BLACK,
]

def encode_board(board):
    """Flatten a board into a list of per-square one-hot vectors.

    Every square in ``chess.SQUARES`` contributes 13 entries: 12 piece slots
    followed by one "empty square" slot (the last entry). The piece slot is
    the piece type's position in ``PIECE_TYPES`` minus one, shifted by 6 when
    ``piece.color`` is truthy.

    Args:
        board: Object exposing ``piece_at(square)`` (e.g. a ``chess.Board``).

    Returns:
        A flat list of 0/1 ints, 13 entries per square.
    """
    slots_per_square = 13  # 12 piece slots + 1 empty-square slot
    flat = []
    for sq in chess.SQUARES:
        occupant = board.piece_at(sq)
        if occupant:
            hot = PIECE_TYPES.index(occupant.piece_type) + (occupant.color * 6) - 1
        else:
            hot = -1  # last slot marks an empty square
        one_hot = [0] * slots_per_square
        one_hot[hot] = 1
        flat.extend(one_hot)
    return flat

def encode_move(move):
    """One-hot encode a move by its origin and destination squares.

    The hot position is ``move.from_square * 64 + move.to_square`` inside a
    list of length 4096 (64 * 64). Only the two squares are read from ``move``.

    Args:
        move: Object with integer ``from_square`` and ``to_square`` attributes.

    Returns:
        A list of 4096 ints containing a single 1.
    """
    one_hot = [0] * 4096  # 64 origin squares * 64 destination squares
    one_hot[move.from_square * 64 + move.to_square] = 1
    return one_hot

def process_pgn_to_tensors(pgn_file_path, output_file_path, max_games=1000):
    """Encode the positions and moves of a PGN file and save them as tensors.

    For every mainline move of every game read, the position *before* the move
    is encoded with ``encode_board`` and the move itself with ``encode_move``.
    The two resulting float32 tensors (one row per position) are saved together
    as a ``(board_states, moves)`` tuple via ``torch.save``.

    Args:
        pgn_file_path: Path of the PGN file to read.
        output_file_path: Path the tensor tuple is written to.
        max_games: Maximum number of games to process; ``None`` reads the
            whole file. Defaults to 1000.

    Returns:
        None. The result is written to ``output_file_path``.
    """
    board_states = []
    moves = []
    games_processed = 0
    with open(pgn_file_path) as pgn:
        # Check the cap *before* reading so exactly `max_games` games are used.
        while max_games is None or games_processed < max_games:
            game = chess.pgn.read_game(pgn)
            if game is None:
                break  # End of PGN file

            board = game.board()
            for move in game.mainline_moves():
                board_states.append(encode_board(board))
                moves.append(encode_move(move))
                board.push(move)

            # Count only after the game is fully encoded so the progress
            # message reflects games that were actually processed.
            games_processed += 1
            if games_processed % 100 == 0:
                print(f"Processed {games_processed} games")

    # Convert lists to tensors
    board_states_tensor = torch.tensor(board_states, dtype=torch.float32)
    moves_tensor = torch.tensor(moves, dtype=torch.float32)

    # Save tensors
    torch.save((board_states_tensor, moves_tensor), output_file_path)


In [None]:
# Encode the configured PGN file and write the tensors to the configured output path.
process_pgn_to_tensors(
    pgn_file_path=file_path,
    output_file_path=output_file,
)

In [None]:
class ChessNet(nn.Module):
    """Fully connected network mapping an encoded board to move scores.

    Input is a batch of 832-dimensional board encodings (64 squares * 13
    states per square); output is a batch of 4096 scores, one per
    (from_square, to_square) pair.

    ``forward`` returns raw logits rather than probabilities, because
    ``nn.CrossEntropyLoss`` applies log-softmax internally and expects
    unnormalized scores. Call ``F.softmax(logits, dim=1)`` on the result when
    probabilities are needed (the argmax is the same either way).
    """

    def __init__(self):
        super().__init__()

        self.input_size = 832  # 64 squares * 13 possible states per square
        self.hidden_sizes = [1024, 1024, 2048, 2048, 4096]
        self.output_size = 4096  # 64 starting squares * 64 destination squares

        # Consecutive pairs of sizes define the Linear layers, created in
        # order: input -> hidden[0] -> ... -> hidden[-1] -> output.
        layer_sizes = [self.input_size, *self.hidden_sizes, self.output_size]
        self.layers = nn.ModuleList(
            [
                nn.Linear(n_in, n_out)
                for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
            ]
        )

    def forward(self, x):
        """Return unnormalized move logits for a batch of encoded boards.

        Args:
            x: Tensor whose last dimension matches the first layer's input size.

        Returns:
            Tensor of raw scores from the final linear layer (no softmax).
        """
        # ReLU after every layer except the last.
        for layer in self.layers[:-1]:
            x = F.relu(layer(x))
        # No softmax here: the loss function normalizes the logits itself.
        return self.layers[-1](x)

In [None]:
# Fraction of the samples (taken from the front) that goes to the training set.
train_split = 0.8

# Load the (inputs, targets) pair saved earlier.
x, y = torch.load("/Users/mikil/Desktop/SideProjects/MiscML/Chess_Classification/Chess Dataset/data.pt")

# Sequential split: the leading part trains, the remainder is held out.
x_cut = int(train_split * len(x))
y_cut = int(train_split * len(y))
x_train, x_test = x[:x_cut], x[x_cut:]
y_train, y_test = y[:y_cut], y[y_cut:]

train_dataset = SimpleDataset(x_train, y_train)
test_dataset = SimpleDataset(x_test, y_test)

In [None]:
# Training hyperparameters.
epochs = 10
batch_size = 32
learning_rate = 0.0001

# Use Apple's MPS backend when present, otherwise fall back to CUDA and then
# CPU, so the notebook does not fail on machines without MPS support.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = ChessNet()
print(f"Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

model.to(device)

optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Reduce the learning rate once the monitored value stops improving for 3 steps.
# NOTE(review): `verbose` is flagged as deprecated in recent PyTorch releases -
# confirm against the installed version.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)

# Name (including its spelling) is referenced by the evaluation and training cells.
loss_funtion  = nn.CrossEntropyLoss()

# Shuffle only the training data; keep evaluation order fixed.
data_train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
data_test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

In [None]:
def compute_test_accuracy():
    """Return the average per-sample loss of ``model`` over ``data_test_loader``.

    Despite the name, no accuracy is computed: the single float returned is
    the sum of each batch's loss weighted by its batch size, divided by the
    number of samples seen. Gradients are disabled during the pass. Relies on
    the module-level ``model``, ``device``, ``loss_funtion`` and
    ``data_test_loader``.
    """
    loss_sum = 0
    samples_seen = 0
    with torch.no_grad():
        for batch_inputs, batch_targets in tqdm(data_test_loader):
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)
            batch_loss = loss_funtion(model(batch_inputs), batch_targets)

            # The loss is a batch mean, so weight it by the batch size to
            # keep a smaller final batch from being over-represented.
            n_in_batch = batch_inputs.size(0)
            loss_sum += batch_loss.item() * n_in_batch
            samples_seen += n_in_batch

    return loss_sum / samples_seen

compute_test_accuracy()

In [None]:
# Train for `epochs` passes over the training loader, evaluating on the test
# loader after each pass.
for epoch in range(epochs):
    epoch_loss = 0.0
    average_loss = 0.0  # stays defined even if the loader yields no batches
    progress_bar = tqdm(data_train_loader, desc=f"Epoch {epoch + 1}/{epochs}")

    # Count batches with enumerate: tqdm only refreshes its `n` attribute when
    # it redraws the bar, so it is not a reliable per-iteration counter.
    for batches_seen, (inputs, targets) in enumerate(progress_bar, start=1):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_funtion(outputs, targets)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        average_loss = epoch_loss / batches_seen
        progress_bar.set_description(f"Epoch {epoch + 1}/{epochs}, Loss: {average_loss:.4f}")

    # Validate the model. compute_test_accuracy() returns one float (the
    # average test loss), so it must not be unpacked into two values.
    val_loss = compute_test_accuracy()
    # The progress bar has finished by now, so report with a plain print.
    print(f"Epoch {epoch + 1}/{epochs} - train_loss: {average_loss:.4f}, val_loss: {val_loss:.4f}")

    # Update the learning rate scheduler (monitors the running training loss).
    scheduler.step(average_loss)