In [None]:
# Colab setup: install the chess and SVG-rendering packages and fetch the project repo.
# The `%cd` matters: the next cells import the `GRPO` package, presumably from the repo root.
# NOTE(review): this cell is not re-runnable — on a second run `git clone` reports that the
# directory already exists and the relative `%cd RL-Chess/` is resolved from inside RL-Chess/.
!pip install python-chess cairosvg
!git clone https://github.com/EmilGou/RL-Chess.git
%cd RL-Chess/

In [None]:
!pip install -U -q gdown
FILE_ID="1BSBuF2dKOnVWuR5CNjp-o7QBYb-10JTO"
!gdown --id $FILE_ID -O moves

In [11]:
from GRPO.data import ChessGameDataset
from GRPO.tokenize import SPECIAL_TOKENS, untokenize
from GRPO.model import AutoregressiveTransformer, ChessConfig
from GRPO.pretrain.utils import sample_game_masked, sample_game_to_video

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import AdamW

import random
import os

In [12]:
path = "/content/moves"

# Read the corpus with a context manager so the file handle is closed promptly.
with open(path, "r") as f:
    moves = f.read()

# Split into blank-line-separated chunks (one per game) and each chunk into lines.
# The element after the final separator is dropped, as is the last line of every chunk.
# NOTE(review): presumably that last line is an empty/terminator line — confirm against the file format.
moves = moves.split('\n\n')[:-1]
GAMES = [m.split('\n')[:-1] for m in moves]
assert GAMES, f"no games parsed from {path}"

In [13]:
SEED = 42
TRAIN_FRAC = 0.8   # fraction of games used for training; the remainder is held out

# Seed Python's RNG (drives the split below) AND torch's RNG: DataLoader(shuffle=True)
# draws its permutation from torch's default generator, as does (presumably) the
# weight initialisation in the next cell. Seeding only `random` leaves both unseeded.
random.seed(SEED)
torch.manual_seed(SEED)

n = len(GAMES)
indices = list(range(n))
random.shuffle(indices)
split = int(n * TRAIN_FRAC)
train_idx, test_idx = indices[:split], indices[split:]

train_games = [GAMES[i] for i in train_idx]
test_games  = [GAMES[i] for i in test_idx]
assert len(train_games) + len(test_games) == n

max_len = 196
train_ds = ChessGameDataset(train_games, max_seq_len=max_len)
test_ds  = ChessGameDataset(test_games,  max_seq_len=max_len)

bsz = 32
train_loader = DataLoader(train_ds, batch_size=bsz, shuffle=True)
test_loader  = DataLoader(test_ds,  batch_size=bsz, shuffle=False)
debug_display = False

# Optional sanity check: print raw and decoded first sample of the first two training batches.
if debug_display:
    for idx, (batch, labels) in enumerate(train_loader):
        print(batch[0], labels[0])
        print("Decoded batch:")
        print(untokenize(batch[0].tolist()))
        print("Decoded labels:")
        print(untokenize(labels[0].tolist()))
        if idx == 1:
            break

In [None]:
# Vocabulary covers every id up to the largest special-token id.
# NOTE(review): assumes special tokens hold the highest ids — confirm in GRPO.tokenize.
vocab_size = max(SPECIAL_TOKENS.values()) + 1

config = ChessConfig()
config.vocab_size = vocab_size
config.pad_id = SPECIAL_TOKENS['<pad>']
config.d_model = 1_024
config.d_ff = 4_096
config.num_layers = 8
# NOTE(review): larger than the dataset's max_seq_len (196); presumably headroom for sampling.
config.max_len = 256 + 1

model = AutoregressiveTransformer(config).cuda()
optimizer = AdamW(model.parameters(), lr=1e-4)
start_epoch = 0   # replaced by ckpt epoch + 1 when resuming via load_ckpt

num_params = sum(p.numel() for p in model.parameters())
print("Total parameters:", num_params)

CHECKPOINT_DIR = '/content/checkpoints/'
name = 'v1'   # run tag embedded in checkpoint filenames
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Trailing `;` stops the notebook from dumping the whole module repr as cell output.
model.train();

In [16]:
def load_ckpt(model, optimizer, path, device='cuda'):
  """Restore model and optimizer state from a checkpoint written by the training loop.

  Args:
    model: module whose state is overwritten (keys must match exactly, strict=True).
    optimizer: optimizer whose state is restored from the checkpoint.
    path: checkpoint file path (or any file-like object accepted by `torch.load`).
    device: device tensors are mapped to while loading; the model is moved there too.

  Returns:
    (model, optimizer, start_epoch, last_loss): start_epoch is the stored epoch + 1,
    so training resumes with the next epoch; last_loss is the stored loss as a float
    (older checkpoints stored a tensor, which is converted).
  """
  # Load from the `path` argument. The previous version read the global
  # CHECKPOINT_PATH, silently ignoring `path` (and raising NameError on a fresh kernel).
  ckpt = torch.load(path, map_location=device)
  model.load_state_dict(ckpt['model_state'], strict=True)
  # Move the model before restoring optimizer state: Optimizer.load_state_dict casts
  # state tensors to the device of the parameters they belong to.
  model.to(device)
  optimizer.load_state_dict(ckpt['opt_state'])

  start_epoch = ckpt['epoch'] + 1
  last_loss   = ckpt['loss']
  if isinstance(last_loss, torch.Tensor):
    last_loss = last_loss.item()

  return model, optimizer, start_epoch, last_loss

#CHECKPOINT_PATH = f"{CHECKPOINT_DIR}/chess_v1_vocab_size=2008_pad_id=2006_d_model=1_024_d_ff=4_096_num_layers=8_epoch=35.pt"
#model, optimizer, start_epoch, last_loss = load_ckpt(model, optimizer, CHECKPOINT_PATH)

In [None]:
NUM_EPOCHS = 100      # exclusive upper bound; the loop resumes from start_epoch
LOG_EVERY = 100       # steps between loss prints
SAMPLE_EVERY = 500    # steps between sample games (also fires at step 0 of every epoch)
SNAPSHOT_EVERY = 5    # epochs between permanent per-epoch checkpoints
PAD_ID = SPECIAL_TOKENS['<pad>']

# Checkpoint filename stem built from the live config so names cannot drift from the
# trained architecture. The `:_` format spec renders 1024 as "1_024", keeping the
# names byte-identical to the previously hard-coded ones.
CKPT_PREFIX = (
    f'{CHECKPOINT_DIR}/chess_{name}_vocab_size={vocab_size}_pad_id={PAD_ID}'
    f'_d_model={config.d_model:_}_d_ff={config.d_ff:_}_num_layers={config.num_layers}'
)


def save_ckpt(path, epoch, loss_value):
    """Write model/optimizer state for `epoch` to `path` in the layout load_ckpt reads."""
    torch.save({
        'epoch':       epoch,
        'model_state': model.state_dict(),
        'opt_state':   optimizer.state_dict(),
        'loss':        loss_value,
    }, path)
    print(f"Checkpoint saved to {path}")


for epoch in range(start_epoch, NUM_EPOCHS):
    loss = None  # stays None if train_loader yields no batches this epoch
    for step, (x, y) in enumerate(train_loader):
        x = x.cuda(); y = y.cuda()
        # NOTE(review): inputs drop the last position and targets the first (next-token
        # prediction) — assumes the dataset's labels are not already shifted; confirm.
        input_seq = x[:, :-1]
        target_seq = y[:, 1:]

        logits = model(input_seq)
        # reshape (not view) so a non-contiguous logits tensor cannot raise.
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            target_seq.reshape(-1),
            ignore_index=PAD_ID,
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % LOG_EVERY == 0:
            print(f"Epoch {epoch} Step {step} | Loss: {loss.item():.4f}")

        if step % SAMPLE_EVERY == 0:
            # Sample in eval mode without building autograd graphs, then explicitly return
            # to train mode so training never continues in eval mode (e.g. if the helpers
            # change the mode). NOTE(review): assumes the GRPO sampling helpers need no grads.
            model.eval()
            with torch.no_grad():
                print('No masking:')
                _ = sample_game_to_video(model, max_moves=200, frame_duration=0.5, top_k=5)
                print("Masking:")
                _ = sample_game_masked(model,
                                       max_moves=100,
                                       temperature=1.0,
                                       frame_duration=0.5,
                                       video_path="chess_masked.mp4")
            model.train()

    # Store a plain float instead of the loss tensor (which lives on the GPU and
    # carries autograd metadata).
    loss_value = loss.item() if loss is not None else None

    CHECKPOINT_PATH = f'{CKPT_PREFIX}_latest.pt'
    save_ckpt(CHECKPOINT_PATH, epoch, loss_value)
    if epoch % SNAPSHOT_EVERY == 0:
        CHECKPOINT_PATH = f'{CKPT_PREFIX}_epoch={epoch}.pt'
        save_ckpt(CHECKPOINT_PATH, epoch, loss_value)

