# Training example
Example of training a neural network model on a chess dataset (stored in FEN format).

In [1]:
import pytorch_lightning as pl
import torch
import os

# Seed the random number generators up front so repeated runs are comparable.
pl.seed_everything(100)
# Move the working directory one level up; the relative data path used later
# ('pretrain/data/out') is resolved against it, and presumably so are the
# project imports below (verify against how the kernel sets up sys.path).
# NOTE(review): this is a relative chdir, so it is not idempotent -- running
# this cell a second time in the same kernel moves up one more level.
# Restart the kernel before re-running the notebook.
os.chdir("../")

# Project-local modules (imported only after the chdir above).
from pretrain.utils.data import ChessDataModule
from pretrain.utils.preprocess import FenDataset, LunaPreprocessing
from luna.luna import Luna_Network
from luna.game import ChessGame
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.loggers import TensorBoardLogger
from typing import Dict
from torch import nn, optim

Seed set to 100


In [2]:
# Preprocessing is applied while a batch is being built (needed e.g. for the
# FEN dataset). The steps run in list order.
# NOTE(review): `training_step` later reads batch["mask"]; confirm that
# LunaPreprocessing(use_mask=False) still provides that key.
preprocessing_steps = [
    FenDataset(),                       # to the standard, flexible dataset representation
    LunaPreprocessing(use_mask=False),  # to a Luna sample
]

# Transforms operate on the results of preprocessing; this is the place for
# augmentation and similar. None are used in this example.
batch_transforms = []

data_module = ChessDataModule(
    data_dir='pretrain/data/out',
    batch_size=1024,
    # Load in the main process: per the original note, workers copy the
    # dataset and have implications with the os.chdir call made earlier.
    num_workers=0,
    schema=FenDataset.Schema,
    preprocessing=preprocessing_steps,
    transforms=batch_transforms,
)

In [3]:
# Wrap a fresh chess game in the Luna network, initialise the weights of the
# underlying torch module and print its layer summary.
chess_game = ChessGame()
net = Luna_Network(chess_game)

luna_nn = net.nnet
luna_nn.init_weights()
print(luna_nn)

LunaNN(
  (conv1): Conv3d(1, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (conv2): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (conv3): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1))
  (conv4): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1))
  (conv5): Conv3d(128, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1))
  (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn3): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn4): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn5): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc1): Linear(in_features=2048, out_features=1024, bias=True)
  (fc_bn1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc2):

## Training info

The model is trained with four losses: policy, value, L2 and entropy.

The policy loss is the cross-entropy between the output policy and the target label.
A label is used instead of a probability distribution because the move feedback in the dataset
is binary, and this way the calculations are slightly faster.

The value loss is Mean Squared Error Loss. Common approach.

The L2 loss pushes the weights towards following a normal distribution, which in consequence keeps the model
from exploiting specific false patterns too much. It is added to prevent overfitting.

The entropy loss is taken from PPO (Proximal Policy Optimisation). It penalizes predictions whose entropy is too low.
It is here to ensure that the model will not learn to always have minimal entropy in its predictions. In chess there
are always more options than one correct move.



In [4]:
class ExampleNetLightning(LightningModule):
    """Lightning module that trains the Luna network with four loss terms.

    The total loss is ``policy + value + L2 + entropy``:

    * policy  -- cross-entropy between the policy logits and the target label
      (samples labelled ``LunaPreprocessing.BAD_INDEX`` are ignored),
    * value   -- mean squared error between predicted and target value,
    * L2      -- ``l2_lambda`` times the sum of squared parameters,
    * entropy -- ``-entropy_lambda`` times the mean policy entropy, i.e. an
      entropy bonus as in PPO: low-entropy predictions increase the loss.

    Args:
        model: Luna wrapper; its ``nnet`` attribute is the network that is
            trained.
        l2_lambda: weight of the L2 term.
        entropy_lambda: weight of the entropy bonus.
    """

    def __init__(self, model: Luna_Network, l2_lambda: float, entropy_lambda: float):
        super().__init__()
        self.model = model.nnet  # network whose parameters are optimised
        self.luna = model        # full wrapper, kept for reference (unused here)
        self.l2_lambda = l2_lambda
        self.entropy_lambda = entropy_lambda

    def training_step(self, batch: Dict, batch_idx: int = 0):
        """Compute and log the combined training loss for one batch.

        Args:
            batch: mapping with the keys ``"state"``, ``"mask"``, ``"value"``
                and ``"label"``.
            batch_idx: index of the batch; accepted for compatibility with the
                standard Lightning hook signature and not used.

        Returns:
            The scalar total loss to back-propagate.
        """
        target_value, label = batch["value"], batch["label"]
        board_and_valid = batch["state"], batch["mask"]

        policy, value = self.model(board_and_valid)

        # Standard losses. Both functionals already reduce with the mean.
        loss_policy = nn.functional.cross_entropy(
            policy, label, ignore_index=LunaPreprocessing.BAD_INDEX)
        loss_value = nn.functional.mse_loss(value.flatten(), target_value.flatten())
        loss_l2 = self.l2_lambda * sum(
            param.pow(2).sum() for param in self.model.parameters())

        # Entropy bonus (as in PPO). The probabilities are deliberately NOT
        # detached: the cross-entropy of a distribution against a detached copy
        # of itself has a (numerically) zero gradient w.r.t. the logits, which
        # would turn this term into a no-op. The epsilon keeps log() finite
        # for zero probabilities.
        probs = policy.softmax(-1)
        entropy = -torch.sum(probs * torch.log(probs + 1e-8), dim=-1)
        # Samples without a usable label do not contribute to the bonus.
        entropy = entropy.masked_fill(label == LunaPreprocessing.BAD_INDEX, 0.0)
        # Negative sign: higher entropy lowers the loss, low entropy is penalised.
        loss_entropy = -self.entropy_lambda * entropy.mean()

        loss = loss_policy + loss_value + loss_l2 + loss_entropy

        self.log('train_loss_policy', loss_policy, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_loss_value', loss_value, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_loss_entropy', loss_entropy, on_step=True, on_epoch=True)
        self.log('train_loss_l2', loss_l2, on_step=True, on_epoch=True)

        return loss

    def configure_optimizers(self):
        """Return the Adam optimiser over the network parameters (lr=0.001)."""
        return optim.Adam(self.model.parameters(), lr=0.001)

In [5]:
# Weights of the two regularisation terms in the training loss.
L2_LAMBDA = 1e-4
ENTROPY_LAMBDA = 1e-3

model = ExampleNetLightning(
    net,
    l2_lambda=L2_LAMBDA,
    entropy_lambda=ENTROPY_LAMBDA,
)

In [6]:
# Metrics logged by the module are written under tensorboard/luna_training.
tensorboard_logger = TensorBoardLogger(save_dir="tensorboard", name="luna_training")

# accelerator="gpu" requires a GPU backend to be available (the recorded run
# below used Apple MPS).
trainer = Trainer(
    accelerator="gpu",
    max_epochs=10,
    logger=tensorboard_logger,
)

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [7]:
# Run training: the trainer pulls batches from `data_module` and calls
# `ExampleNetLightning.training_step` on each, for at most the `max_epochs`
# configured on the trainer. Metrics go to the TensorBoard logger.
trainer.fit(model, data_module)


  | Name  | Type   | Params | Mode 
-----------------------------------------
0 | model | LunaNN | 6.1 M  | train
-----------------------------------------
6.1 M     Trainable params
0         Non-trainable params
6.1 M     Total params
24.435    Total estimated model params size (MB)
19        Modules in train mode
0         Modules in eval mode
/Users/fsociety/Programing/PyCharm/ChessRL/.venv/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

In [None]:
# Persist the trained module to "lightning_model.pt" in the current working
# directory.
# NOTE(review): this pickles the whole ExampleNetLightning object rather than
# a state_dict, so loading it requires the same class and project modules to be
# importable. Saving `model.state_dict()` is the more portable alternative --
# confirm what downstream loaders expect before changing the format.
torch.save(model, "lightning_model.pt")