# Training example
An example of training a neural network model on a chess dataset (in FEN format).

In [None]:
import numpy as np
import pytorch_lightning as pl
import torch
import os
import chess

pl.seed_everything(100)
os.chdir("../")

from pretrain.utils.data import ChessDataModule
from pretrain.utils.board import board_to_tensor, INT_TO_UCI_MAP, UCI_TO_INT_MAP
from pretrain.utils.preprocess import PreprocessTensorDataset
from pretrain.utils.dataset import DirectTensorDataset
from luna.luna import Luna_Network
from luna.game import ChessGame
from pytorch_lightning import Trainer, LightningModule, Callback
from pytorch_lightning.loggers import TensorBoardLogger
from stockfish import Stockfish
from typing import Dict
from torch import nn, optim
from tqdm.notebook import tqdm

In [None]:
# Data module that is later handed to `trainer.fit`.
data_module = ChessDataModule(
    data_dir='pretrain/data/torch_10p',
    batch_size=1024,
    # NOTE(review): 20 workers are requested here, although the note that used to
    # sit on this line advised against using workers (the dataset is copied per
    # worker and the os.chdir call at the top of the notebook has implications
    # for them) — confirm which of the two is intended.
    num_workers=20,
    preprocessing=[  # For e.g. FEN dataset preprocessing has to be done during batch creation
        PreprocessTensorDataset(),  # Converts to standard and flexible dataset representation
    ],
    transforms=[
        # Here space for augmentation etc. Operates on results of preprocessing.
    ],
    dataset_class=DirectTensorDataset,
)

In [None]:
# Build the Luna network wrapper around a chess game object.
net = Luna_Network(
    ChessGame()
)
# `net.nnet` is the underlying network that gets trained below; call its
# weight initialisation and print it for inspection.
net.nnet.init_weights()
print(net.nnet)

## Training info

The model is trained with four losses: policy, value, L2 and entropy.

The policy loss is the cross-entropy
between the output policy and the target label. A label is used instead of a probability distribution because
the move feedback in the dataset is binary, and this way the calculations are slightly faster.

The value loss is the Mean Squared Error between the predicted value and the target value, which is the common approach.

The L2 loss penalizes large weights (which corresponds to pushing the weights towards a zero-mean normal
distribution), so that the model does not exploit specific false patterns too much. It is added to prevent overfitting.

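The entropy loss is taken from PPO (Proximal Policy Optimisation). It penalizes too low an entropy of the model's
policy. It is here to ensure that the model will not learn to always have minimal entropy in its predictions: in chess
there are usually more options than a single correct move.

Note: `ExampleNetLightning` accepts an `entropy_lambda` argument, but the entropy term is not currently added to the
loss in `training_step`.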



In [None]:
class StockfishTestCallback(Callback):
    """
    A PyTorch Lightning callback that plays test games with your model at
    regular intervals during training and logs the results.

    The model's moves come from ``pl_module.predict_move(board)``. The opponent
    currently plays uniformly random legal moves; the Stockfish engine
    (configured with the given ELO rating) is only used to score the position
    before and after each model move.

    Args:
        stockfish_path: Path to the Stockfish executable.
        test_frequency: The test runs whenever ``batch_idx`` is a multiple of
            this value (which includes the first batch of every epoch).
        elo_rating: Value passed to the engine as ``UCI_Elo``.
        num_test_games: Number of games played per test run.
    """

    def __init__(
        self,
        stockfish_path: str,
        test_frequency: int = 1000,
        elo_rating: int = 1500,
        num_test_games: int = 50
    ):
        super().__init__()
        self.stockfish_path = stockfish_path
        self.test_frequency = test_frequency
        self.elo_rating = elo_rating
        self.num_test_games = num_test_games

    def on_train_batch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, output, batch, batch_idx):
        """Run our custom test periodically during training and log its metrics."""
        if batch_idx % self.test_frequency != 0:
            return
        if trainer.logger is None:
            # Nothing to log to, so do not spend time playing the games.
            return

        results = self._test_against_stockfish(pl_module)

        trainer.logger.log_metrics({
            "stockfish_test/win_rate": results["win_rate"],
            "stockfish_test/draw_rate": results["draw_rate"],
            "stockfish_test/loss_rate": results["loss_rate"],
            "stockfish_test/evaluation": results["evaluation"],
        }, step=trainer.global_step)

    def _test_against_stockfish(self, pl_module) -> Dict[str, float]:
        """Play ``num_test_games`` games and summarise them.

        The model alternates colours between games (white in even-numbered
        games).

        Returns:
            A dict with ``win_rate``, ``draw_rate`` and ``loss_rate`` (fractions
            of the played games) and ``evaluation``: the mean change of the
            engine's centipawn score caused by a model move, seen from the
            model's side. Move pairs where either score is a mate score are
            skipped; ``evaluation`` is 0 when no pair was collected.
        """
        # Configure Stockfish with the specified ELO rating
        params = {
            "UCI_LimitStrength": True,
            "UCI_Elo": self.elo_rating,
            "Skill Level": 0,  # Will be adjusted based on ELO
            "Threads": 1,      # Single thread for consistent performance
            "Hash": 16         # Small hash size (MB)
        }

        stockfish = Stockfish(path=self.stockfish_path, parameters=params)

        # Track game results
        wins = 0
        draws = 0
        losses = 0
        evaluation = 0
        total_evals = 0

        # Play test games
        for game_idx in tqdm(range(self.num_test_games), desc="Evaluating against Stockfish", unit="game", leave=False, position=2):
            # Reset board to starting position
            stockfish.set_position([])
            board = chess.Board()

            # Determine who plays white (alternate)
            model_plays_white = game_idx % 2 == 0
            # The `stockfish` package reports scores from White's point of view
            # (positive = advantage White), so flip the sign when the model is Black.
            perspective = 1 if model_plays_white else -1

            # Play until game is over
            bar = tqdm(desc="Playing game", unit="moves", leave=False, position=3)
            while not board.is_game_over():
                is_model_turn = (board.turn == chess.WHITE) == model_plays_white

                if is_model_turn:
                    # Score the position, let the model move, score it again.
                    before = stockfish.get_evaluation()
                    move = pl_module.predict_move(board)

                    # Update both boards
                    board.push(move)
                    stockfish.make_moves_from_current_position([move.uci()])
                    after = stockfish.get_evaluation()

                    if before["type"] == "cp" and after["type"] == "cp":
                        evaluation += perspective * (after["value"] - before["value"])
                        total_evals += 1
                else:
                    # The opponent plays a uniformly random legal move. To play
                    # against the engine instead, take the move from
                    # `stockfish.get_best_move()` here.
                    opponent_move = np.random.choice(list(board.legal_moves))

                    # Update both boards
                    board.push(opponent_move)
                    stockfish.make_moves_from_current_position([opponent_move.uci()])
                bar.update(1)
            bar.close()

            # Record game result. Once the game is over `board.outcome()` is not
            # None; a draw is signalled by `winner` being None.
            outcome = board.outcome()
            if outcome is None or outcome.winner is None:
                draws += 1
            elif (outcome.winner == chess.WHITE) == model_plays_white:
                wins += 1
            else:
                losses += 1

        # Calculate metrics
        total_games = wins + draws + losses
        return {
            "win_rate": wins / total_games if total_games > 0 else 0,
            "draw_rate": draws / total_games if total_games > 0 else 0,
            "loss_rate": losses / total_games if total_games > 0 else 0,
            'evaluation': evaluation / total_evals if total_evals > 0 else 0,
        }

In [None]:
class ExampleNetLightning(LightningModule):
    """Lightning module that trains the network held by a ``Luna_Network``.

    The training loss is the sum of a policy cross-entropy, a value MSE and an
    L2 penalty on all parameters scaled by ``l2_lambda``.

    Args:
        model: The ``Luna_Network`` wrapper; its ``nnet`` attribute is the
            network that is trained and queried.
        l2_lambda: Weight of the L2 penalty.
        entropy_lambda: Stored on the instance; not used by ``training_step``.
    """

    def __init__(self, model: Luna_Network, l2_lambda: float, entropy_lambda: float):
        super().__init__()
        self.model = model.nnet
        self.luna = model
        self.l2_lambda = l2_lambda
        self.entropy_lambda = entropy_lambda

    def training_step(self, batch: Dict):
        """Compute and log the training loss for one batch.

        Reads the ``"state"``, ``"mask"``, ``"value"`` and ``"label"`` entries
        of ``batch`` and returns the summed loss.
        """
        target_value, label = batch["value"], batch["label"]
        board_and_valid = batch["state"], batch["mask"]

        policy, value = self.model(board_and_valid)

        # Standard loss. Samples labelled "Terminal" do not contribute to the policy loss.
        loss_policy = nn.functional.cross_entropy(policy.clone(), label.long(),
                                                  ignore_index=UCI_TO_INT_MAP["Terminal"]).mean()
        loss_value = nn.functional.mse_loss(value.flatten(), target_value.flatten()).mean()
        # Sum of squared L2 norms of every parameter tensor.
        loss_l2 = self.l2_lambda * torch.mean(sum(torch.norm(param, 2) ** 2 for param in self.model.parameters()))

        loss = loss_policy + loss_value + loss_l2

        self.log('train_loss_policy', loss_policy.clone(), on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_loss_value', loss_value.clone(), on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_loss', loss.clone(), on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_loss_l2', loss_l2.clone(), on_step=True, on_epoch=True)

        return loss

    @torch.no_grad()
    def predict_move(self, board: chess.Board):
        """Return the legal move the network ranks highest for ``board``.

        The network is evaluated in eval mode on the device its parameters live
        on; the previous train/eval mode is restored afterwards.

        Raises:
            RuntimeError: If none of the ranked moves is legal on ``board``.
        """
        was_training = self.training
        self.eval()
        try:
            device = next(self.model.parameters()).device

            state = board_to_tensor(board)
            # NOTE(review): python-chess' `is_repetition(count=0)` appears to be
            # always true, which would make this counter always 0 — confirm this
            # matches how `repetitions` is encoded in the training data.
            repetitions = 0
            while not board.is_repetition(count=repetitions):
                repetitions += 1
            repetitions = torch.tensor(repetitions, dtype=torch.int8)
            clock = torch.tensor(len(board.move_stack), dtype=torch.int16)
            castling_rights = torch.tensor([
                int(board.has_kingside_castling_rights(chess.WHITE)),
                int(board.has_queenside_castling_rights(chess.WHITE)),
                int(board.has_kingside_castling_rights(chess.BLACK)),
                int(board.has_queenside_castling_rights(chess.BLACK)),
            ], dtype=torch.int8)
            state = PreprocessTensorDataset.one_hot_encoding({
                'state': state, 'clock': clock, 'castling_rights': castling_rights, 'repetitions': repetitions
            }).to(device)

            # All-ones mask: legality is checked below against python-chess instead.
            mask = torch.ones((1, len(INT_TO_UCI_MAP)), dtype=torch.int8, device=device)
            policy, value = self.model((state.unsqueeze(0), mask))
        finally:
            self.train(was_training)

        # Walk the moves from the highest to the lowest policy output.
        ranked_moves = torch.argsort(policy.squeeze(), descending=True).tolist()
        for move_index in ranked_moves:
            decoded = INT_TO_UCI_MAP[move_index]
            if decoded == "Terminal":
                continue
            uci_move = chess.Move.from_uci(decoded)
            if uci_move in board.legal_moves:
                return uci_move
        raise RuntimeError("There is no legal move.")

    def configure_optimizers(self):
        """Use Adam with a learning rate of 1e-3 on the network's parameters."""
        return optim.Adam(self.model.parameters(), lr=1e-3)

In [None]:
# Wrap the network for training. `entropy_lambda` is stored on the module but is
# not used by `training_step`.
model = ExampleNetLightning(net, l2_lambda=1e-4, entropy_lambda=1e-4)

In [None]:
# Trainer configuration: 3 epochs on GPU with 16-bit mixed precision, logging to
# TensorBoard under "tensorboard/luna_training".
trainer = Trainer(
    max_epochs=3,
    logger=TensorBoardLogger(
        save_dir="tensorboard",
        name="luna_training",
    ),
    log_every_n_steps=50,
    accelerator="gpu",
    precision="16-mixed",
    # Positional arguments: stockfish_path, test_frequency=5_000 batches,
    # elo_rating=600, num_test_games=10. The path points to a Windows binary.
    callbacks=[StockfishTestCallback("pretrain/data/stockfish/stockfish-windows-x86-64-avx2.exe", 5_000, 600, 10)],
    deterministic=True,
    accumulate_grad_batches=1,
)

In [None]:
# Run training on the data served by `data_module`.
trainer.fit(model, data_module)

In [None]:
# Saves the whole module object (not just a state_dict), so loading it back
# requires the class definitions from this notebook to be available.
torch.save(model, "model_rtx5090.pt")

In [None]:
# `weights_only=False` unpickles arbitrary Python objects: only load files you
# created yourself or otherwise trust.
model = torch.load("model_rtx5090.pt", weights_only=False)

In [None]:
import time
from IPython.display import display, clear_output

# Move the model to the GPU before running the inference games below.
model.cuda()
def test_play(color):
    """Play one game of the global ``model`` against a uniformly random opponent.

    Args:
        color: Side the model plays. It is compared against ``chess.BLACK``;
            the loop in this notebook passes ``i % 2`` (0 for black, 1 for
            white).

    Returns:
        The result of ``board.outcome()`` for the finished game.
    """
    board = chess.Board()
    if color == chess.BLACK:
        # The opponent (white) opens with a fixed move so the model moves second.
        board.push_uci("e2e4")
    while not board.is_game_over():
        board.push(model.predict_move(board))
        if board.is_game_over():
            break
        # Opponent reply: a uniformly random legal move.
        board.push(np.random.choice(list(board.legal_moves)))
    return board.outcome()

# Score 50 games from the model's point of view: 1 = win, 0 = draw, -1 = loss.
outcomes = []

for i in tqdm(range(50)):
    # `i % 2` alternates the model's colour (0 = black, 1 = white).
    res = test_play(i % 2)
    if res.winner is None:
        outcomes.append(0)
    elif res.winner == i % 2:
        outcomes.append(1)
    else:
        outcomes.append(-1)
print(np.mean(outcomes))

In [None]:
# Fraction of games scored 1 (model won).
np.mean(np.array(outcomes) == 1)

In [None]:
# Fraction of games scored 0 (draw).
np.mean(np.array(outcomes) == 0)

In [None]:
# Fraction of games scored -1 (model lost).
np.mean(np.array(outcomes) == -1)

In [None]:
def _test_against_stockfish(
    n,
    *,
    stockfish_path="pretrain/data/stockfish/stockfish-windows-x86-64-avx2.exe",
    elo_rating=200,
    move_delay=0.50,
):
    """Play ``n`` displayed games of the global ``model`` (white) against Stockfish.

    After every move the board is redrawn in the notebook output.

    Args:
        n: Number of games to play.
        stockfish_path: Path to the Stockfish executable.
        elo_rating: Value passed to the engine as ``UCI_Elo``.
        move_delay: Seconds to wait after each move before redrawing the board.

    Returns:
        A dict with ``win_rate``, ``draw_rate`` and ``loss_rate`` as fractions
        of the played games (all 0 when ``n`` is 0).
    """
    # Configure Stockfish with the specified ELO rating
    params = {
        "UCI_LimitStrength": True,
        "UCI_Elo": elo_rating,
        "Skill Level": 0,  # Will be adjusted based on ELO
        "Threads": 1,      # Single thread for consistent performance
        "Hash": 16         # Small hash size (MB)
    }

    stockfish = Stockfish(path=stockfish_path, parameters=params)

    # Track game results
    wins = 0
    draws = 0
    losses = 0

    # Play test games
    for game_idx in tqdm(range(n), desc="Evaluating against Stockfish", unit="game", leave=False, position=2):
        # Reset board to starting position
        stockfish.set_position([])
        board = chess.Board()

        # The model always plays white here (no alternation).
        model_plays_white = True

        # Play until game is over
        bar = tqdm(desc="Playing game", unit="moves", leave=False, position=3)
        while not board.is_game_over():
            is_model_turn = (board.turn == chess.WHITE) == model_plays_white

            if is_model_turn:
                # Get model's move
                move = model.predict_move(board)
                print(f"Model makes move {move.uci()}")
            else:
                # Get Stockfish's move
                move = chess.Move.from_uci(stockfish.get_best_move())

            # Update both boards
            board.push(move)
            stockfish.make_moves_from_current_position([move.uci()])
            bar.update(1)

            time.sleep(move_delay)
            clear_output()
            display(board)

        bar.close()

        # Record game result. Once the game is over `board.outcome()` is not
        # None; a draw is signalled by `winner` being None.
        outcome = board.outcome()
        if outcome is None or outcome.winner is None:
            draws += 1
        elif (outcome.winner == chess.WHITE) == model_plays_white:
            wins += 1
        else:
            losses += 1

    # Calculate metrics
    total_games = wins + draws + losses
    return {
        "win_rate": wins / total_games if total_games > 0 else 0,
        "draw_rate": draws / total_games if total_games > 0 else 0,
        "loss_rate": losses / total_games if total_games > 0 else 0,
    }

In [None]:
# Play a single displayed game against Stockfish.
results = _test_against_stockfish(1)

In [None]:
# Show the win/draw/loss rates of the game(s) played above.
results