In [None]:
!pip install python-chess cairosvg

In [None]:
!wget https://raw.githubusercontent.com/EmilGou/A-Chess-Transformer/main/uci_moves.py

In [None]:
import torch
from torch import nn
import torch.nn.functional as F

In [None]:
from google.colab import drive
# Mount Google Drive under /content/drive (Colab-only helper imported above).
drive.mount('/content/drive')

In [None]:
# Location of the raw move dump on the mounted Drive.
path = "/content/drive/MyDrive/1.moves"

# Read inside a context manager so the file handle is closed deterministically
# instead of being left open until garbage collection.
with open(path, "r") as moves_file:
    moves = moves_file.read()

# Games are separated by a blank line; whatever follows the final separator is dropped.
moves = moves.split('\n\n')[:-1]
# One list of lines per game. The last line of every chunk is discarded too
# (presumably a non-move trailer line -- confirm against the file format).
GAMES = [m.split('\n')[:-1] for m in moves]

In [None]:
from uci_moves import UCI_MOVES

In [None]:
import chess
import chess.svg
from IPython.display import SVG, display

def uci_moves_to_fen(moves: list[str], show_board: bool = False) -> str:
    """
    Replay a sequence of UCI moves from the standard starting position.

    Args:
        moves: Moves in UCI notation, e.g. ['e2e4', 'e7e5'], applied in order.
        show_board: When True, render the final position as an SVG in the notebook.
    Returns:
        The FEN string of the position reached after the last move.
    Raises:
        ValueError: if `board.push_uci` rejects a move; the message names that move.
    """
    position = chess.Board()
    for move in moves:
        try:
            position.push_uci(move)
        except ValueError as e:
            # Re-raise with the offending move spelled out for easier debugging.
            raise ValueError(f"Invalid move '{move}': {e}")

    if not show_board:
        return position.fen()

    display(SVG(chess.svg.board(board=position)))
    return position.fen()


In [None]:
import chess

# 1) Reverse lookup for UCI move tokens: id -> move string.
UCI_IDS = {move_id: uci for uci, move_id in UCI_MOVES.items()}

# 2) BUILD FEN VOCABULARY (covers full FEN)
# NOTE: 'b' occurs twice below (black bishop and the b-file). The dict built from
# this list therefore keeps the later index for 'b', and the earlier id is never
# produced. The list is left as-is so that token ids stay unchanged.
FEN_CHARS = (
    list('/ -')            # separators & dash
    + list('PNBRQK')       # white pieces
    + list('pnbrqk')       # black pieces
    + list('0123456789')   # digits for counters
    + list('abcdefgh')     # files (en passant targets)
    + ['w']                # side to move
)
# FEN character ids start right after the UCI move ids.
FEN_CHAR_TO_ID = {ch: tok_id for tok_id, ch in enumerate(FEN_CHARS, start=len(UCI_MOVES))}
ID_TO_FEN_CHAR = {tok_id: ch for ch, tok_id in FEN_CHAR_TO_ID.items()}

# Highest id handed out so far; special tokens continue from here.
max_idx = max(FEN_CHAR_TO_ID.values())

# ——————————————————————————————————————————————————————————————————————
# 3) SPECIAL TOKENS
SPECIAL_TOKENS = {
    tag: max_idx + rank
    for rank, tag in enumerate(
        ("<board>", "</board>", "<moves>", "</moves>", "<pad>"), start=1
    )
}
ID_TO_SPECIAL = {tok_id: tag for tag, tok_id in SPECIAL_TOKENS.items()}

# ——————————————————————————————————————————————————————————————————————
# 4) TOKENIZERS / UNTOKENIZER
def tokenize_fen(fen: str) -> list[int]:
    """
    Encode a FEN string character by character.

    Every character present in FEN_CHAR_TO_ID contributes one id, in order;
    characters missing from that table are skipped silently.
    """
    token_ids = []
    for ch in fen:
        token_id = FEN_CHAR_TO_ID.get(ch)
        if token_id is not None:
            token_ids.append(token_id)
    return token_ids

def tokenize_uci(moves: list[str]) -> list[int]:
    """Map UCI move strings to their ids, silently skipping moves not in UCI_MOVES."""
    move_ids = []
    for uci in moves:
        if uci in UCI_MOVES:
            move_ids.append(UCI_MOVES[uci])
    return move_ids

def untokenize(tokens: list[int]) -> list[str]:
    """
    Turn token ids back into their string form.

    Each id is looked up in the special-token table first, then the FEN
    character table, then the UCI move table. Ids found in none of them are
    rendered as "<unk:ID>".
    """
    lookup_order = (ID_TO_SPECIAL, ID_TO_FEN_CHAR, UCI_IDS)
    decoded = []
    for token_id in tokens:
        for table in lookup_order:
            if token_id in table:
                decoded.append(table[token_id])
                break
        else:
            # No table knows this id.
            decoded.append(f"<unk:{token_id}>")
    return decoded


In [None]:
class ChessGameDataset(torch.utils.data.Dataset):
    """
    Turns each game into one (input_ids, labels) pair for next-move training.

    A random cut splits the game into `past` and `future` moves. The input is
        <board> FEN-of-past </board> <moves> future... </moves>
    and the labels carry the future move ids at the same positions, with the pad
    id everywhere else. Both tensors are padded/truncated to `max_seq_len`.
    The one-step shift between inputs and targets is NOT applied here.

    Args:
        games: list of games, each a list of UCI move strings.
        max_seq_len: fixed length of the returned tensors.
    """

    def __init__(self, games: list[list[str]], max_seq_len: int):
        self.games = games
        self.max_seq_len = max_seq_len
        self.pad_token = SPECIAL_TOKENS["<pad>"]

    def __len__(self):
        return len(self.games)

    def __getitem__(self, idx):
        # 1) split and pick random cut
        moves = [m.lower() for m in self.games[idx]]
        # randint(1, len - 1) raises ValueError for games with fewer than two
        # moves; clamping the upper bound keeps such games usable (they simply
        # yield a sample without any move labels). Games with >= 2 moves are
        # sampled exactly as before.
        cutoff = random.randint(1, max(1, len(moves) - 1))
        past, future = moves[:cutoff], moves[cutoff:]

        # 2) build FEN from past
        board = chess.Board()
        for m in past:
            board.push_uci(m)
        fen = board.fen()  # full 6-field FEN

        # 3) tokenize
        fen_tokens    = tokenize_fen(fen)
        future_tokens = tokenize_uci(future)

        # 4) assemble input + labels
        input_seq = (
            [SPECIAL_TOKENS["<board>"]] +
            fen_tokens +
            [SPECIAL_TOKENS["</board>"],
             SPECIAL_TOKENS["<moves>"]] +
            future_tokens +
            [SPECIAL_TOKENS["</moves>"]]
        )
        # Index of <moves> in input_seq: one <board>, the FEN tokens, one </board>.
        moves_start = len(fen_tokens) + 2
        labels = (
            [self.pad_token] * (moves_start + 1) +  # pad up through <moves>
            future_tokens +
            [self.pad_token]                     # do not predict </moves>
        )

        # 5) truncate & pad: over-pad first, then cut to the fixed length
        input_seq = (input_seq + [self.pad_token] * self.max_seq_len)[:self.max_seq_len]
        labels    = (labels    + [self.pad_token] * self.max_seq_len)[:self.max_seq_len]

        return torch.tensor(input_seq, dtype=torch.long), \
               torch.tensor(labels,    dtype=torch.long)


In [None]:
import random
from torch.utils.data import DataLoader

# 1) Fix the random seeds for reproducibility.
#    `random` drives the split below and the per-sample cut-offs in the datasets;
#    DataLoader(shuffle=True) draws its permutation from torch's global RNG, so
#    torch has to be seeded as well for the batch order to be repeatable.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# 2) Shuffle indices and split 80 / 20
n = len(GAMES)
indices = list(range(n))
random.shuffle(indices)
split = int(n * 0.8)
train_idx, test_idx = indices[:split], indices[split:]

# 3) Slice out train / test game-lists
train_games = [GAMES[i] for i in train_idx]
test_games  = [GAMES[i] for i in test_idx]

# 4) Create datasets
max_len = 196
train_ds = ChessGameDataset(train_games, max_seq_len=max_len)
test_ds  = ChessGameDataset(test_games,  max_seq_len=max_len)

# 5) Create loaders
bsz = 32
train_loader = DataLoader(train_ds, batch_size=bsz, shuffle=True)
test_loader  = DataLoader(test_ds,  batch_size=bsz, shuffle=False)

# Visualize: print the first sample of the first two training batches, raw and decoded.
for idx, (batch, labels) in enumerate(train_loader):
    print(batch[0], labels[0])
    print("Decoded batch:")
    print(untokenize(batch[0].tolist()))
    print("Decoded labels:")
    print(untokenize(labels[0].tolist()))
    if idx == 1:
        break

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

def make_custom_attention_mask(input_ids: torch.Tensor) -> torch.Tensor:
    """
    Build an additive (B, S, S) attention mask from a batch of token ids.

    Entries are 0 where attention is allowed and -inf where it is blocked:
    - the <board>...</board> span attends to itself in both directions
    - the <moves>...</moves> span is causal (each position sees the span start
      up to itself); if </moves> is absent the span runs to the end of the row
    - every row and column belonging to a <pad> token is -inf

    Each tag is located with `.item()`, so a row is expected to contain at most
    one occurrence of each tag.
    """
    batch_size, seq_len = input_ids.shape
    neg_inf = float('-inf')
    mask = torch.full((batch_size, seq_len, seq_len), neg_inf, device=input_ids.device)

    for b, row in enumerate(input_ids):
        board_open = (row == SPECIAL_TOKENS["<board>"]).nonzero(as_tuple=False)
        board_close = (row == SPECIAL_TOKENS["</board>"]).nonzero(as_tuple=False)
        moves_open = (row == SPECIAL_TOKENS["<moves>"]).nonzero(as_tuple=False)
        moves_close = (row == SPECIAL_TOKENS["</moves>"]).nonzero(as_tuple=False)

        # Bidirectional attention inside <board>...</board>
        if len(board_open) and len(board_close):
            lo, hi = board_open.item(), board_close.item()
            mask[b, lo:hi + 1, lo:hi + 1] = 0

        # Causal attention inside <moves>...</moves>
        if len(moves_open):
            lo = moves_open.item()
            hi = moves_close.item() if len(moves_close) else seq_len - 1
            hi = min(hi, seq_len - 1)  # clamp
            if hi >= lo:
                # Open only the lower triangle of the span; the view writes through to `mask`.
                span = mask[b, lo:hi + 1, lo:hi + 1]
                span.masked_fill_(torch.tril(torch.ones_like(span, dtype=torch.bool)), 0)

        # Block <pad> tokens as both queries and keys
        is_pad = row == SPECIAL_TOKENS["<pad>"]
        mask[b, is_pad, :] = neg_inf
        mask[b, :, is_pad] = neg_inf

    return mask


class AutoregressiveTransformer(nn.Module):
    """Decoder-only (causal) Transformer language model with learned positions.

    - all tensors are batch-first
    - positions are embedded with ``nn.Embedding`` rather than sinusoids
    - attention is restricted by a plain upper-triangular causal mask plus a
      key-padding mask derived from ``pad_id``

    Args:
        vocab_size: number of token ids (embedding rows and output logits).
        d_model: embedding / hidden width.
        n_heads: attention heads per layer.
        num_layers: number of stacked encoder layers.
        d_ff: feed-forward width inside each layer.
        max_len: longest sequence the position table supports.
        dropout: dropout probability inside each layer.
        pad_id: token id treated as padding in the key-padding mask.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        n_heads: int = 8,
        num_layers: int = 6,
        d_ff: int = 2048,
        max_len: int = 512,
        dropout: float = 0.1,
        pad_id: int = 0,
    ) -> None:
        super().__init__()
        self.pad_id = pad_id
        self.max_len = max_len

        # Token and position lookup tables.
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.n_heads = n_heads

        # An encoder stack driven with a causal mask serves as the decoder-only body.
        layer_config = dict(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_ff,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_config), num_layers
        )

        # Output projection to vocabulary logits (weights are not tied to token_emb).
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

        # Re-initialise the position table with a small standard deviation.
        nn.init.normal_(self.pos_emb.weight, mean=0.0, std=0.02)

    # --- utility masks ------------------------------------------------------
    def _causal_mask(self, size: int, device: torch.device) -> torch.Tensor:
        """Boolean (size, size) mask, True strictly above the diagonal (True = masked)."""
        steps = torch.arange(size, device=device)
        # Column index greater than row index means "a future position".
        return steps.unsqueeze(0) > steps.unsqueeze(1)

    # --- forward pass -------------------------------------------------------
    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        """Return logits of shape (B, S, vocab_size) for a (B, S) batch of token ids.

        Raises:
            ValueError: if S exceeds ``max_len``.
        """
        _, seq_len = tokens.shape
        if seq_len > self.max_len:
            raise ValueError(f"Sequence length {seq_len} exceeds max_len {self.max_len}.")

        # The (S, D) position embeddings broadcast over the batch dimension.
        positions = torch.arange(seq_len, device=tokens.device)
        hidden = self.token_emb(tokens) + self.pos_emb(positions).unsqueeze(0)

        # Plain causal mask; make_custom_attention_mask is not wired in here.
        future_mask = self._causal_mask(size=seq_len, device=tokens.device)
        padding_mask = tokens.eq(self.pad_id)               # (B, S)

        hidden = self.transformer(hidden, mask=future_mask, src_key_padding_mask=padding_mask)
        return self.lm_head(hidden)

In [None]:
from torch.utils.data import DataLoader
from torch.optim import AdamW

# Vocabulary size follows the largest special-token id registered at the time this
# cell runs (it grows if extra special tokens have already been added).
vocab_size = max(SPECIAL_TOKENS.values()) + 1

# Use the GPU when one is attached, otherwise fall back to CPU instead of
# crashing on an unconditional .cuda().
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = AutoregressiveTransformer(vocab_size=vocab_size, pad_id=SPECIAL_TOKENS['<pad>'], d_model=1_024, d_ff=4_096, num_layers=8, max_len=256+1).to(device)
optimizer = AdamW(model.parameters(), lr=1e-4)

# Total number of scalar parameters in the model.
c = sum(p.numel() for p in model.parameters())
print("Total parameters:", c)

model.train()

In [None]:
import torch
import chess
import chess.svg
import chess.pgn
import torch.nn.functional as F
import cairosvg
import imageio.v2 as imageio
import os
import random
import base64
from typing import Optional
from tempfile import TemporaryDirectory
from IPython.display import SVG, display, Video, HTML

# ——————————————————————————————————————————————————————————————————————
# 1) FULL-FEN TOKENIZER
#    consumes every character in the 6-field FEN string
def tokenize_fen(fen: str) -> list[int]:
    """
    Character-level encoder for a full FEN such as
      "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"

    Returns one id per character found in FEN_CHAR_TO_ID, in order; characters
    absent from that vocabulary are dropped. This repeats the `tokenize_fen`
    defined earlier in the notebook and shadows it once this cell has run.
    """
    known_chars = (ch for ch in fen if ch in FEN_CHAR_TO_ID)
    return list(map(FEN_CHAR_TO_ID.__getitem__, known_chars))

# ——————————————————————————————————————————————————————————————————————
# 2) SAMPLE + LOG-PROB DIAGNOSTIC
@torch.no_grad()
def sample_model_moves(model, dataset, max_moves=10, temperature=1.0, top_k=10):
    """
    Diagnostic on one random dataset example.

    Shows the board encoded in the example, prints the per-step probability the
    model assigns to every non-pad label (teacher forcing) together with the mean
    NLL / perplexity, then samples up to `max_moves` continuation tokens after
    <moves> and prints them next to the ground-truth moves.

    Args:
        model: language model called as model(ids) -> (B, S, V) logits.
        dataset: indexable dataset yielding (input_ids, labels) tensors.
        max_moves: maximum number of tokens to sample.
        temperature: divisor applied to the logits before sampling.
        top_k: sample among the k largest logits; falsy means the full vocabulary.
    Returns:
        The sampled moves as UCI strings.

    Side effect: the model is switched to eval mode and left there.
    """
    model.eval()
    device = next(model.parameters()).device
    pad_id = SPECIAL_TOKENS["<pad>"]

    # Pick a random example and add a batch dimension.
    example_idx = random.randrange(len(dataset))
    input_ids, labels = dataset[example_idx]
    input_ids = input_ids.unsqueeze(0).to(device)
    labels = labels.unsqueeze(0).to(device)
    sequence = input_ids[0]

    def position_of(tag):
        # Index of the single occurrence of `tag` in the sequence.
        return (sequence == SPECIAL_TOKENS[tag]).nonzero(as_tuple=True)[0].item()

    # Rebuild the full FEN string from the ids between <board> and </board>.
    fen_begin = position_of("<board>") + 1
    fen_end = position_of("</board>")
    fen_full = "".join(ID_TO_FEN_CHAR[i] for i in sequence[fen_begin:fen_end].tolist())
    board = chess.Board(fen=fen_full)

    display(SVG(chess.svg.board(board=board, size=350)))
    print("📍 Full FEN:", fen_full, "\n")

    # Show the raw example.
    print("Input tokens: ", " ".join(untokenize(input_ids[0].tolist())))
    print("Label tokens:", " ".join(untokenize(labels[0].tolist())), "\n")

    # Teacher forcing: position i of the shifted input predicts label i + 1.
    logits = model(input_ids[:, :-1])              # (1, S-1, V)
    targets = labels[:, 1:]
    log_probs = F.log_softmax(logits, dim=-1)

    step_logps = []
    for pos, target_id in enumerate(targets[0].tolist()):
        if target_id == pad_id:
            continue
        lp = log_probs[0, pos, target_id].item()
        print(f" Step {pos:2d} P({untokenize([target_id])[0]}) = {math.exp(lp):.4f}  logp={lp:.4f}")
        step_logps.append(lp)

    if step_logps:
        mean_lp = sum(step_logps) / len(step_logps)
        print(f"\n→ mean NLL: {-mean_lp:.4f}, ppl: {math.exp(-mean_lp):.2f}\n")
    else:
        print("⚠️ no valid labels for loss\n")

    # Free-running sampling, starting right after <moves>.
    stop_ids = (SPECIAL_TOKENS["</moves>"], pad_id)
    start = position_of("<moves>") + 1
    gen = sequence[:start].tolist()

    for _ in range(max_moves):
        context = torch.tensor(gen, device=device).unsqueeze(0)
        last_logits = model(context)[0, -1, :] / temperature
        if top_k:
            top_vals, top_ids = torch.topk(last_logits, top_k)
            pick = torch.multinomial(F.softmax(top_vals, dim=0), 1)
            next_id = top_ids[pick].item()
        else:
            next_id = torch.multinomial(F.softmax(last_logits, dim=-1), 1).item()
        if next_id in stop_ids:
            break
        gen.append(next_id)

    sampled = [UCI_IDS[t] for t in gen[start:] if t in UCI_IDS]
    actual = [UCI_IDS[t] for t in labels[0, start:].tolist() if t in UCI_IDS]
    print("🔮 Sampled:", " ".join(sampled))
    print("✅ Ground truth:", " ".join(actual))
    return sampled

# ——————————————————————————————————————————————————————————————————————
# 3) SAMPLE FULL GAME → MP4 AT CORRECT FEN
@torch.no_grad()
def sample_game_to_video(
    model,
    max_moves: int = 50,
    temperature: float = 1.0,
    top_k: int = 10,
    video_path: str = "sample_game.mp4",
    frame_duration: float = 1.2
) -> Optional[chess.pgn.Game]:
    """
    Sample a game from the starting position without legality masking and
    render it to an MP4.

    The model is prompted once with the starting FEN and then extended token by
    token. Sampling stops at </moves>, <pad>, a non-move token, an illegal move,
    or after `max_moves` moves.

    Args:
        model: language model called as model(ids) -> (B, S, V) logits.
        max_moves: maximum number of moves to sample.
        temperature: divisor applied to the logits before sampling.
        top_k: sample among the k largest logits; falsy means the full vocabulary.
        video_path: where the MP4 is written.
        frame_duration: seconds each position is shown (fps = 1 / frame_duration).
    Returns:
        A chess.pgn.Game whose main line is the sequence of moves actually played.

    Side effect: the model is switched to eval mode and left there.
    """
    model.eval()
    device = next(model.parameters()).device

    # start from initial
    board = chess.Board()
    # you can also push a first move, e.g. board.push_uci("e2e4")

    # build initial input: <board> FEN </board> <moves>
    gen = (
        [SPECIAL_TOKENS["<board>"]]
        + tokenize_fen(board.fen())
        + [SPECIAL_TOKENS["</board>"], SPECIAL_TOKENS["<moves>"]]
    )

    game = chess.pgn.Game()
    # add_variation() returns the new child node. Each move must be attached to
    # the previous node; attaching every move to the root would record them as
    # alternative first moves instead of one main line.
    node = game
    last = None

    fps = 1.0/frame_duration
    with TemporaryDirectory() as tmp:
        frames = []

        def render_frame(index, lastmove):
            # Write the current position as SVG, convert it to PNG, remember the PNG.
            svg_path = os.path.join(tmp, f"f{index:03}.svg")
            png_path = os.path.join(tmp, f"f{index:03}.png")
            with open(svg_path, "w") as svg_file:  # closed as soon as it is written
                svg_file.write(chess.svg.board(board=board, size=350, lastmove=lastmove))
            cairosvg.svg2png(url=svg_path, write_to=png_path)
            frames.append(png_path)

        # frame 0: starting position
        render_frame(0, last)

        for i in range(1, max_moves+1):
            x = torch.tensor(gen, device=device).unsqueeze(0)
            lg = model(x)[0,-1,:]/temperature
            if top_k:
                v, iid = torch.topk(lg, top_k)
                pr     = F.softmax(v, dim=0)
                tok    = iid[torch.multinomial(pr,1)].item()
            else:
                pr  = F.softmax(lg, dim=-1)
                tok = torch.multinomial(pr,1).item()
            if tok in (SPECIAL_TOKENS["</moves>"], SPECIAL_TOKENS["<pad>"]):
                break
            gen.append(tok)
            if tok not in UCI_IDS:
                print("⚠️ unk", tok)
                break
            m = chess.Move.from_uci(UCI_IDS[tok])
            if not board.is_legal(m):
                print("⛔ illegal", UCI_IDS[tok])
                break
            board.push(m)
            last = m
            node = node.add_variation(m)

            render_frame(i, last)

        # write mp4 with correct fps; the context manager closes the writer even on error
        with imageio.get_writer(video_path, format="ffmpeg", fps=fps) as writer:
            for p in frames:
                writer.append_data(imageio.imread(p))

    display(Video(video_path, embed=True, html_attributes="controls autoplay loop"))
    print(f"✅ Video saved @ {fps:.2f} fps → {video_path}")
    return game


In [None]:
import torch
import chess
import chess.pgn
import chess.svg
import torch.nn.functional as F
import cairosvg
import imageio.v2 as imageio
import os
import random
import base64
from typing import Optional
from tempfile import TemporaryDirectory
from IPython.display import Video, display

@torch.no_grad()
def sample_game_masked(
    model,
    max_moves: int = 50,
    temperature: float = 1.0,
    video_path: str = "sample_game_masked.mp4",
    frame_duration: float = 1.2  # seconds per frame
) -> Optional[chess.pgn.Game]:
    """
    Samples a full game, but at each step:
      • Enumerates board.legal_moves
      • Converts them to your UCI token IDs
      • Masks out all other logits
      • Samples from the remaining legal-move distribution
    Saves as MP4 at fps = 1/frame_duration.

    Args:
        model: language model called as model(ids) -> (B, S, V) logits.
        max_moves: maximum number of moves to sample.
        temperature: divisor applied to the logits before masking.
        video_path: where the MP4 is written.
        frame_duration: seconds each position is shown.
    Returns:
        A chess.pgn.Game whose main line is the sequence of moves played.

    Side effect: the model is switched to eval mode and left there.
    """
    model.eval()
    device = next(model.parameters()).device

    board = chess.Board()
    # Optionally seed first move:
    # board.push_uci("e2e4")

    # Build initial context: <board> FEN </board> <moves>
    fen_ids = tokenize_fen(board.fen())
    seq = [SPECIAL_TOKENS["<board>"]] + fen_ids + [
        SPECIAL_TOKENS["</board>"], SPECIAL_TOKENS["<moves>"]
    ]
    generated = seq.copy()
    game = chess.pgn.Game()
    # add_variation() returns the new child node; chaining from it builds one
    # main line. Adding every move to the root would instead record each move
    # as an alternative first move.
    node = game
    last_move = None

    fps = 1.0 / frame_duration

    with TemporaryDirectory() as tmpdir:
        frames = []

        def render_frame(index, lastmove):
            # Write the current position as SVG, convert it to PNG, remember the PNG.
            svg_path = os.path.join(tmpdir, f"frame_{index:03}.svg")
            png_path = os.path.join(tmpdir, f"frame_{index:03}.png")
            with open(svg_path, "w") as svg_file:  # closed as soon as it is written
                svg_file.write(chess.svg.board(board=board, size=350, lastmove=lastmove))
            cairosvg.svg2png(url=svg_path, write_to=png_path)
            frames.append(png_path)

        # frame 0: starting position
        render_frame(0, last_move)

        for i in range(1, max_moves+1):
            x = torch.tensor(generated, device=device).unsqueeze(0)
            logits = model(x)[0, -1, :] / temperature  # (V,)

            # ids of all legal moves that exist in the vocabulary
            legal_ids = []
            for mv in board.legal_moves:
                uci = mv.uci()
                if uci in UCI_MOVES:
                    legal_ids.append(UCI_MOVES[uci])
            if not legal_ids:
                print("⛔ No legal moves tokenized – stopping.")
                break

            # additive mask: 0 for legal moves, -inf for everything else
            mask = torch.full_like(logits, float('-inf'))
            mask[legal_ids] = 0.0
            filtered_logits = logits + mask

            probs = F.softmax(filtered_logits, dim=-1)
            token = torch.multinomial(probs, num_samples=1).item()

            # The checks below are defensive: with the mask applied only legal
            # move ids have non-zero probability.
            if token in (SPECIAL_TOKENS["</moves>"], SPECIAL_TOKENS["<pad>"]):
                break
            generated.append(token)

            # apply move
            uci = UCI_IDS.get(token, None)
            if uci is None:
                print(f"⚠️ Generated unknown token {token}.")
                break
            move = chess.Move.from_uci(uci)
            if not board.is_legal(move):
                print(f"⛔ Illegal generated move: {uci}.")
                break

            board.push(move)
            last_move = move
            node = node.add_variation(move)

            # render frame
            render_frame(i, last_move)

        # write MP4
        with imageio.get_writer(video_path, format='ffmpeg', fps=fps) as writer:
            for p in frames:
                writer.append_data(imageio.imread(p))

    display(Video(video_path, embed=True, html_attributes="controls autoplay loop"))
    print(f"✅ Masked‐sampling video saved to {video_path} @ {fps:.2f} fps")
    return game


In [None]:
import os
# Run tag embedded in checkpoint file names.
name = 'v1'
# Checkpoints live on the mounted Drive; create the folder on first use.
CHECKPOINT_DIR = '/content/drive/MyDrive/chess_checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

In [None]:
# Show which checkpoint files currently exist in CHECKPOINT_DIR.
os.listdir(CHECKPOINT_DIR)

In [None]:
# CHECKPOINT_PATH = f"{CHECKPOINT_DIR}/chess_v1_vocab_size={vocab_size-1}_pad_id={SPECIAL_TOKENS['<pad>']}_d_model=1_024_d_ff=4_096_num_layers=8_latest.pt"
CHECKPOINT_PATH = f"{CHECKPOINT_DIR}/chess_v1_vocab_size=2008_pad_id=2006_d_model=1_024_d_ff=4_096_num_layers=8_latest.pt"

# Pick the device once and use it for both deserialisation and the model, so the
# cell also works on a CPU-only runtime (previously map_location fell back to CPU
# but the model was then moved to 'cuda' unconditionally).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ckpt = torch.load(CHECKPOINT_PATH, map_location=device)

# Move the model first so the optimizer state is restored onto the same device
# as the parameters it belongs to.
model.to(device)
model.load_state_dict(ckpt['model_state'])
optimizer.load_state_dict(ckpt['opt_state'])
start_epoch = ckpt['epoch'] + 1   # resume with the epoch after the saved one
last_loss   = ckpt['loss']

print(f"✅ Loaded checkpoint from epoch {ckpt['epoch']}, loss={last_loss:.4f}")

In [None]:
# Unmasked sampling demo; the returned PGN game is discarded (assigned to `_`).
_ = sample_game_to_video(model, max_moves=200, frame_duration=0.5, top_k=5)

In [None]:
# Legal-move-masked sampling demo; the returned PGN game is discarded (assigned to `_`).
_ = sample_game_masked(model,
                       max_moves=200,
                       temperature=1.0,
                       frame_duration=0.5,
                       video_path="chess_masked.mp4")

In [None]:
import torch
import torch.nn.functional as F
import chess
import chess.svg
import chess.pgn
import cairosvg
import imageio.v2 as imageio
import os
import random
from tempfile import TemporaryDirectory
from IPython.display import SVG, Video, display

@torch.no_grad()
def sample_game_alpha(
    model,
    alpha: Optional[int] = None,     # reset every alpha moves; None → never reset
    max_moves: int = 50,
    temperature: float = 1.0,
    top_k: Optional[int] = 10,
    video_path: str = "sample_game_alpha.mp4",
    frame_duration: float = 1.2
) -> chess.pgn.Game:
    """
    Sample up to max_moves legal moves, rebuilding the prompt from the current
    board every `alpha` moves, and render the game to an MP4.

    If alpha is None or >= max_moves, context is never reset (full-history).

    Args:
        model: language model called as model(ids) -> (B, S, V) logits.
        alpha: number of moves after which the context is rebuilt from the
            current FEN; None disables resets.
        max_moves: maximum number of moves to sample.
        temperature: divisor applied to the logits before sampling.
        top_k: keep only the k most likely legal moves; None keeps all of them.
        video_path: where the MP4 is written.
        frame_duration: seconds each position is shown.
    Returns:
        A chess.pgn.Game whose main line is the sequence of moves played.

    Side effect: the model is switched to eval mode and left there.
    """
    model.eval()
    device = next(model.parameters()).device

    # Initialize board + PGN
    board = chess.Board()
    game = chess.pgn.Game()
    # add_variation() returns the new child node; chaining from it builds one
    # main line. Adding every move to the root would instead record each move
    # as an alternative first move.
    node = game
    last_move = None

    # Helper that builds the current context from the board
    def make_ctx():
        fen_ids = tokenize_fen(board.fen())
        return [
            SPECIAL_TOKENS["<board>"],
            *fen_ids,
            SPECIAL_TOKENS["</board>"],
            SPECIAL_TOKENS["<moves>"]
        ]

    generated = make_ctx()
    moves_this_block = 0

    fps = 1.0 / frame_duration
    with TemporaryDirectory() as tmpdir:
        frames = []

        def render_frame(index, lastmove):
            # Write the current position as SVG, convert it to PNG, remember the PNG.
            svg_path = os.path.join(tmpdir, f"f{index:03}.svg")
            png_path = os.path.join(tmpdir, f"f{index:03}.png")
            with open(svg_path, "w") as svg_file:  # closed as soon as it is written
                svg_file.write(chess.svg.board(board=board, size=350, lastmove=lastmove))
            cairosvg.svg2png(url=svg_path, write_to=png_path)
            frames.append(png_path)

        # --- Frame 0: starting position ---
        render_frame(0, last_move)

        for step in range(max_moves):
            # reset context if we've reached the block size
            if alpha is not None and moves_this_block >= alpha:
                generated = make_ctx()
                moves_this_block = 0

            # model forward
            x = torch.tensor(generated, device=device).unsqueeze(0)  # (1, L)
            logits = model(x)[0, -1, :] / temperature              # (V,)

            # gather only legal moves
            legal_ids = [UCI_MOVES[m.uci()] for m in board.legal_moves if m.uci() in UCI_MOVES]
            if not legal_ids:
                print("⛔ no legal moves left – stopping.")
                break

            # sample from legal logits directly
            legal_logits = logits[legal_ids]                       # (L,)
            probs        = F.softmax(legal_logits, dim=-1)         # (L,)

            # optionally cap to top_k among the legal moves; the kept
            # probabilities are used as unnormalised sampling weights
            if top_k is not None and top_k < probs.size(0):
                vals, idxs = torch.topk(probs, top_k)
                probs      = vals
                legal_ids  = [legal_ids[i] for i in idxs.tolist()]

            choice = torch.multinomial(probs, num_samples=1).item()
            token  = legal_ids[choice]

            # defensive: legal_ids only ever holds move ids
            if token in (SPECIAL_TOKENS["</moves>"], SPECIAL_TOKENS["<pad>"]):
                break

            # append, apply to board, record PGN
            generated.append(token)
            uci = UCI_IDS[token]
            move = chess.Move.from_uci(uci)
            if not board.is_legal(move):
                print(f"⛔ illegal move {uci} – stopping.")
                break
            board.push(move)
            node = node.add_variation(move)
            last_move = move
            moves_this_block += 1

            # render new frame
            render_frame(step + 1, last_move)

        # write out MP4; the context manager closes the writer even on error
        with imageio.get_writer(video_path, format='ffmpeg', fps=fps) as writer:
            for p in frames:
                writer.append_data(imageio.imread(p))

    display(Video(video_path, embed=True, html_attributes="controls autoplay loop"))
    print(f"✅ Video saved @ {fps:.2f} fps → {video_path}")
    return game


In [None]:
# NOTE: all three calls use the default video_path ("sample_game_alpha.mp4"), so each
# run overwrites the file written by the previous one; only the last call's return
# value is shown as the cell output.

# never reset (current behavior)
sample_game_alpha(model, alpha=None, frame_duration=0.5)

# reset every move
sample_game_alpha(model, alpha=1, frame_duration=0.5)

# reset every 4 moves
sample_game_alpha(model, alpha=4, frame_duration=0.5)

In [None]:
# Train on whatever device the model already lives on.
train_device = next(model.parameters()).device

for epoch in range(start_epoch, 100):
    for step, (x, y) in enumerate(train_loader):
        # The sampling helpers called below put the model into eval mode and do
        # not switch it back; without this call every step after the first
        # sampling round would train with dropout disabled.
        model.train()

        x = x.to(train_device); y = y.to(train_device)
        # Shift by one: position i of the input predicts label i + 1.
        input_seq = x[:, :-1]
        target_seq = y[:, 1:]

        logits = model(input_seq)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq.reshape(-1), ignore_index=SPECIAL_TOKENS['<pad>'])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 100 == 0:
            print(f"Epoch {epoch} Step {step} | Loss: {loss.item():.4f}")

        # Qualitative check: render one unmasked and one masked sample game.
        if step % 500 == 0:
            print('No masking:')
            _ = sample_game_to_video(model, max_moves=200, frame_duration=0.5, top_k=5)
            print("Masking:")
            _ = sample_game_masked(model,
                                   max_moves=200,
                                   temperature=1.0,
                                   frame_duration=0.5,
                                   video_path="chess_masked.mp4")

    # End of epoch: one state snapshot, written to "latest" and (every 5th epoch)
    # to an epoch-stamped file. The loss is stored as a plain float rather than a
    # tensor that is still attached to the autograd graph and the training device.
    checkpoint = {
        'epoch':      epoch,
        'model_state': model.state_dict(),
        'opt_state':  optimizer.state_dict(),
        'loss':       loss.item(),
    }
    # NOTE(review): the file name embeds the current vocab_size, while the loading
    # cell uses a hard-coded name -- confirm the two agree.
    CHECKPOINT_PATH = f'{CHECKPOINT_DIR}/chess_{name}_vocab_size={vocab_size}_pad_id={SPECIAL_TOKENS["<pad>"]}_d_model=1_024_d_ff=4_096_num_layers=8_latest.pt'
    torch.save(checkpoint, CHECKPOINT_PATH)
    print(f"✅ Checkpoint saved to {CHECKPOINT_PATH}")
    if epoch % 5 == 0:
        CHECKPOINT_PATH = f'{CHECKPOINT_DIR}/chess_{name}_vocab_size={vocab_size}_pad_id={SPECIAL_TOKENS["<pad>"]}_d_model=1_024_d_ff=4_096_num_layers=8_epoch={epoch}.pt'
        torch.save(checkpoint, CHECKPOINT_PATH)
        print(f"✅ Checkpoint saved to {CHECKPOINT_PATH}")



In [None]:

# 3) Save a checkpoint
# Assuming you have:
#   model   -> your nn.Module
#   optimizer -> your optimizer (e.g. AdamW)
#   epoch   -> current epoch number (int)
#   loss    -> last loss value (float)

# Store the loss as a plain Python float: float() accepts both a 0-dim tensor and
# a float, and keeps graph-attached / device tensors out of the checkpoint file.
torch.save({
    'epoch':      epoch,
    'model_state': model.state_dict(),
    'opt_state':  optimizer.state_dict(),
    'loss':       float(loss),
}, CHECKPOINT_PATH)
print(f"✅ Checkpoint saved to {CHECKPOINT_PATH}")

# 4) Later, to load it back:

In [None]:
# Register the reasoning tags right after the current highest special-token id.
# Tags that already exist are left alone, so re-running this cell no longer
# shifts their ids or leaves stale entries in ID_TO_SPECIAL. On a first run the
# resulting ids are max + 1 and max + 2.
for _tag in ("<think>", "</think>"):
    if _tag not in SPECIAL_TOKENS:
        SPECIAL_TOKENS[_tag] = max(SPECIAL_TOKENS.values()) + 1
ID_TO_SPECIAL.update({v: k for k, v in SPECIAL_TOKENS.items()})


In [None]:
from transformers import PreTrainedTokenizerBase
import json, os, re

# One token -> id table for the whole vocabulary, merged in this order:
# UCI moves, then FEN characters (ids start at len(UCI_MOVES)), then special tokens.
# On a key clash the later table wins.
ALL_IDS = dict(UCI_MOVES)
ALL_IDS.update(FEN_CHAR_TO_ID)
ALL_IDS.update(SPECIAL_TOKENS)
# Reverse table: id -> token string.
ID_TO_TOKEN = {token_id: token for token, token_id in ALL_IDS.items()}

class ChessTokenizer(PreTrainedTokenizerBase):
    """
    Whitespace tokenizer over the notebook's combined vocabulary.

    Text is expected to be space-separated tokens that are keys of ALL_IDS;
    "<pad>" is the padding token and "</moves>" the end-of-sequence token.
    Unknown tokens or ids raise KeyError.
    """

    def __init__(self):
        super().__init__(
            pad_token="<pad>",
            eos_token="</moves>",      # generation finishes here
        )
        # Both tables are the module-level dicts themselves, not copies.
        self.vocab = ALL_IDS
        self.inv_vocab = ID_TO_TOKEN

    # ----- required HF methods --------------------------------------------
    def _tokenize(self, text):
        """Split already space-separated text into tokens."""
        stripped = text.strip()
        return stripped.split()

    def _convert_token_to_id(self, token):
        """Look up the id of a single token."""
        return self.vocab[token]

    def _convert_id_to_token(self, idx):
        """Look up the token string of a single id."""
        return self.inv_vocab[idx]

    def get_vocab(self):
        """Return the token -> id table."""
        return self.vocab

    # ----- helper ----------------------------------------------------------
    def encode(self, s, **kw):
        """Encode a space-separated string into a list of ids (extra kwargs are ignored)."""
        token_ids = []
        for piece in self._tokenize(s):
            token_ids.append(self.vocab[piece])
        return token_ids

    def decode(self, ids, **kw):
        """Join the tokens for `ids` with single spaces, skipping <pad> (extra kwargs are ignored)."""
        pieces = []
        for token_id in ids:
            if token_id != self.vocab["<pad>"]:
                pieces.append(self.inv_vocab[token_id])
        return " ".join(pieces)


In [None]:
# Inspect the last five (id, token) entries of ID_TO_TOKEN in insertion order.
list(ID_TO_TOKEN.items())[-5:]

In [None]:
class RLThinkDataset(torch.utils.data.Dataset):
    """
    Returns only the prompt prefix; the trainer will call .generate() to
    produce <think> … </think> <moves> MOVE </moves>.

    Each item is a dict with
      "input_ids":      <board> FEN </board> <think>, right-padded to max_len
      "attention_mask": 1 for prompt tokens, 0 for padding

    Args:
        games: list of games, each a list of UCI move strings.
        max_len: fixed length of the returned tensors.
    """
    def __init__(self, games, max_len):
        self.games, self.max_len = games, max_len
        self.pad = SPECIAL_TOKENS["<pad>"]

    def __len__(self):
        return len(self.games)

    def __getitem__(self, idx):
        moves = [m.lower() for m in self.games[idx]]
        # Replay a random strict prefix so at least one game move remains after
        # the sampled position. The clamp makes an empty game fall back to the
        # starting position instead of raising from randint(0, -1).
        cutoff = random.randint(0, max(0, len(moves) - 1))
        board  = chess.Board()
        for m in moves[:cutoff]:
            board.push_uci(m)
        fen_tokens = tokenize_fen(board.fen())

        prompt_ids = (
            [SPECIAL_TOKENS["<board>"]] +
            fen_tokens +
            [SPECIAL_TOKENS["</board>"],
             SPECIAL_TOKENS["<think>"]]       # leave open tag
        )

        # Truncate, then right-pad ids and mask to exactly max_len entries.
        prompt_ids = prompt_ids[:self.max_len]
        attn_mask  = [1]*len(prompt_ids)
        pad_needed = self.max_len - len(prompt_ids)
        prompt_ids += [self.pad]*pad_needed
        attn_mask  += [0]*pad_needed

        return {
            "input_ids": torch.tensor(prompt_ids),
            "attention_mask": torch.tensor(attn_mask),
        }


In [None]:
# ▸ run this in a Colab code cell
!sudo apt-get update -y
!sudo apt-get install -y stockfish   # installs v16-series in <30 s

In [None]:
# Make sure the reasoning tags exist without disturbing ids that were already
# assigned. Unconditional reassignment moved <think> to a fresh id on every run
# and, because the second line added 2 to an already-grown maximum, left an
# unused id between the two tags -- putting SPECIAL_TOKENS out of step with
# tables built from the earlier ids.
for _tag in ('<think>', '</think>'):
    if _tag not in SPECIAL_TOKENS:
        SPECIAL_TOKENS[_tag] = max(SPECIAL_TOKENS.values()) + 1

In [None]:
import os, shutil, chess, chess.engine, torch
# Make the stockfish binary discoverable; only extend PATH once so re-running the
# cell does not keep appending the same entry.
if "/usr/games" not in os.environ["PATH"].split(":"):
    os.environ["PATH"] += ":/usr/games"

assert shutil.which("stockfish")  # sanity check

# Smoke test: analyse the starting position at depth 10. The context manager
# shuts the engine process down even if the analysis raises.
board  = chess.Board()
with chess.engine.SimpleEngine.popen_uci("stockfish") as engine:
    print(engine.analyse(board, chess.engine.Limit(depth=10)))

In [None]:
# NOTE(review): next() with no arguments always raises TypeError. Presumably left here
# on purpose to halt "Run All" before the RL cells below -- confirm, then replace or remove.
next()

In [None]:
import chess.engine

import shutil

STOCKFISH_PATH = "/usr/local/bin/stockfish"   # adjust as needed
# Prefer the `stockfish` found on PATH (same as before); fall back to
# STOCKFISH_PATH only when PATH has none.
engine = chess.engine.SimpleEngine.popen_uci(shutil.which("stockfish") or STOCKFISH_PATH)

# regex to grab 1st UCI move between <moves> … </moves>
MOVE_RE = re.compile(r"<moves>\s*([a-h][1-8][a-h][1-8][qrbn]?)\s*</moves>")
tokenizer = ChessTokenizer()


def reward_fn(batch_prompts, batch_outputs, limit=0.5, depth=12):
    """
    Score each generated move by how much it changes the Stockfish evaluation.

    For every (prompt, output) pair the position is rebuilt from the FEN tokens
    in the prompt, the first move inside <moves> … </moves> of the decoded output
    is extracted, and the reward is
        (eval after the move - eval before the move) / 100
    in pawns from the point of view of the side that moved, clipped to
    [-limit, +limit]. A missing, malformed or illegal move earns -limit.

    Args:
        batch_prompts: iterable of 1-D id tensors holding <board> FEN </board> ...
        batch_outputs: iterable of 1-D id tensors with the generated sequence.
        limit: clipping bound, also the magnitude of the illegal-move penalty.
        depth: Stockfish search depth per evaluation.
    Returns:
        A 1-D float tensor with one reward per pair (empty for an empty batch).
    """
    rewards = []
    board_open = tokenizer.vocab["<board>"]
    board_close = tokenizer.vocab["</board>"]
    search_limit = chess.engine.Limit(depth=depth)

    for prompt_ids, generated_ids in zip(batch_prompts, batch_outputs):
        prompt_list = prompt_ids.tolist()
        full = tokenizer.decode(generated_ids.tolist())

        # 1) recover FEN from prompt. decode() puts a space between tokens, which
        #    would corrupt the FEN, so the characters are joined straight from
        #    the ids between <board> and </board>.
        fen_begin = prompt_list.index(board_open) + 1
        fen_end = prompt_list.index(board_close)
        fen = "".join(tokenizer.inv_vocab[i] for i in prompt_list[fen_begin:fen_end])
        board = chess.Board(fen)

        # 2) extract move; from_uci can still reject strings the regex accepts
        #    (e.g. identical from/to squares), which counts as illegal.
        m = MOVE_RE.search(full)
        move = None
        if m is not None:
            try:
                move = chess.Move.from_uci(m[1])
            except ValueError:
                move = None
        if move is None or not board.is_legal(move):
            rewards.append(torch.tensor(-limit))   # illegal – punish
            continue

        # 3) Stockfish evaluations, both from the mover's point of view.
        #    analyse() blocks until the search is done and returns the final info.
        mover = board.turn
        base_eval = engine.analyse(board, search_limit)["score"].pov(mover).score(mate_score=10000)
        board.push(move)
        after_eval = engine.analyse(board, search_limit)["score"].pov(mover).score(mate_score=10000)

        # positive if the position improved for the side that moved
        delta = (after_eval - base_eval) / 100.0          # centipawns → pawns
        delta = max(min(delta,  limit), -limit)           # clip
        rewards.append(torch.tensor(delta))

    return torch.stack(rewards) if rewards else torch.empty(0)

# The engine must stay open while reward_fn is in use; call engine.close() once
# reward computation is finished.

In [None]:
# Prompt dataset built over GAMES; the debugging cell below indexes it as `d[idx]`.
# NOTE(review): `max_len` is not set in this cell — it is expected to come from
# an earlier cell; confirm it is defined before running.
d = RLThinkDataset(GAMES, max_len)

In [None]:
import random, torch, chess, chess.svg, pprint, textwrap
from IPython.display import SVG, display

# ──────────────────────────────────────────────────────────────────────────────
# 0.  Verify FEN tokeniser round-trip ─────────────────────────────────────────
fen_example = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
# Encode the start position, then map every id straight back to its character.
ids        = tokenize_fen(fen_example)
fen_back   = "".join(map(ID_TO_FEN_CHAR.__getitem__, ids))
# Show original and reconstruction on consecutive lines for eyeballing.
print(fen_example, fen_back, sep="\n")
assert fen_back == fen_example, "FEN round-trip failed!"
print("✅ Tokeniser round-trips FEN correctly.\n")

# ──────────────────────────────────────────────────────────────────────────────
# 1.  Helper: rebuild FEN *purely from IDs* ───────────────────────────────────
def fen_from_prompt_ids(prompt_ids: torch.Tensor) -> str:
    """Decode the FEN that sits between the <board> and </board> markers.

    Works on token ids only: every id strictly between the first <board> and
    the first </board> is looked up in ID_TO_FEN_CHAR and the characters are
    concatenated.

    Raises:
        ValueError: if either marker id is absent from `prompt_ids`.
    """
    token_ids = prompt_ids.tolist()
    try:
        start = token_ids.index(SPECIAL_TOKENS["<board>"]) + 1
        stop = token_ids.index(SPECIAL_TOKENS["</board>"])
    except ValueError as err:
        raise ValueError("<board> or </board> missing") from err

    chars = []
    for token in token_ids[start:stop]:
        chars.append(ID_TO_FEN_CHAR[token])
    return "".join(chars)

# ──────────────────────────────────────────────────────────────────────────────
# 2.  Visual sanity check on one dataset sample ───────────────────────────────
def visual_debug(idx: int | None = None):
    """Show one dataset sample, a fabricated completion and its reward.

    Args:
        idx: index into the module-level dataset `d`; a random valid index is
            drawn when omitted.

    Prints the decoded prompt, the dummy completion and the value returned by
    `reward_fn`, then renders the board as SVG.
    """
    if idx is None:
        # Draw from the dataset that is actually indexed below (`d`), so the
        # index is always in range.
        idx = random.randrange(len(d))

    sample      = d[idx]
    prompt_ids  = sample["input_ids"]
    prompt_txt  = tokenizer.decode(prompt_ids.tolist(), skip_special_tokens=False)
    print(prompt_txt)
    fen_str     = fen_from_prompt_ids(prompt_ids)
    board       = chess.Board(fen_str)

    # fabricate a dummy CoT & legal move so reward_fn has something to chew on
    legal_move  = random.choice(list(board.legal_moves)).uci()
    cot_txt     = " e1e8"
    generated   = f"{prompt_txt}{cot_txt} </think> <moves> {legal_move} </moves>"
    print(generated)
    output_ids  = torch.tensor(tokenizer.encode(generated))

    # reward_fn works on batches, hence the added leading dimension
    reward_val  = reward_fn(prompt_ids.unsqueeze(0),
                            output_ids.unsqueeze(0)).item()

    # ── display ──────────────────────────────────────────────────────────────
    print(f"\nDataset index {idx}")
    print("Prompt tokens (decoded):")
    print(textwrap.fill(prompt_txt, width=90), "\n")
    print("Generated dummy output:")
    print(textwrap.fill(generated, width=90), "\n")
    print(f"Reward = {reward_val:+.3f}\n")
    display(SVG(chess.svg.board(board=board)))

visual_debug()                # run; pass idx=42 to inspect a specific sample


In [None]:
# --- 1.  make a tiny HF config ---------------------------------------------
from transformers import PreTrainedModel, GenerationMixin, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutput

class ChessConfig(PretrainedConfig):
    """Minimal Hugging Face config for the chess transformer wrapper.

    Args:
        vocab_size: size of the token vocabulary.
        pad_token_id: id of the padding token, forwarded to PretrainedConfig.
        **kwargs: passed through to PretrainedConfig unchanged.
    """
    model_type = "chess_transformer"

    # Every argument has a default: PretrainedConfig instantiates the class
    # with no arguments when it serialises a config (repr / save), which would
    # raise TypeError if `vocab_size` were required.
    def __init__(self, vocab_size: int = 0, pad_token_id: int = 0, **kwargs):
        super().__init__(pad_token_id=pad_token_id, **kwargs)
        self.vocab_size = vocab_size                    # needed by generate()

# --- 2.  hug-friendly wrapper around your pure-torch model ------------------
class HFChessModel(PreTrainedModel, GenerationMixin):
    """Adapter that exposes the plain-PyTorch transformer through the HF model API."""
    config_class = ChessConfig

    def __init__(self, inner: AutoregressiveTransformer):
        # The config mirrors the wrapped network: vocabulary size comes from
        # its embedding table, the pad id from its own attribute.
        config = ChessConfig(
            vocab_size=inner.token_emb.num_embeddings,
            pad_token_id=inner.pad_id,
        )
        super().__init__(config)
        # The real network lives inside the wrapper.
        self.inner = inner

    # ---- the two methods GRPOTrainer cares about ---------------------------
    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Run the wrapped network on `input_ids` and wrap its logits.

        `attention_mask` and any extra keyword arguments are accepted but not
        forwarded to the inner model.
        """
        # logits are presumably (batch, seq, vocab) — confirm against the inner model
        return CausalLMOutput(logits=self.inner(input_ids))

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Hand only `input_ids` on to `forward`; all other kwargs are dropped."""
        return dict(input_ids=input_ids)

# --- 3.  build policy / reference and launch trainer ------------------------
import copy

policy      = HFChessModel(model).cuda()
# The reference needs its *own* parameters.  Wrapping the same `model` object
# twice would make policy and reference share weights, so the reference would
# change with every policy update.
ref_policy  = HFChessModel(copy.deepcopy(model)).cuda()
ref_policy.eval()
ref_policy.requires_grad_(False)                       # frozen copy

from trl import GRPOConfig, GRPOTrainer
args = GRPOConfig(output_dir="ckpts/chess-grpo", logging_steps=10)
'''
trainer = GRPOTrainer(
    model=policy,
    ref_model=ref_policy,
    args=args,
    tokenizer=my_tokenizer,       # can be a dummy PreTrainedTokenizerFast
    train_dataset=prompts_ds,     # any Dataset yielding {"prompt": str}
    reward_funcs=my_reward_fn     # or list of funcs
)
trainer.train()
'''

In [None]:
# NOTE(review): the cell above already imports `trl`; on a fresh runtime this
# install has to run before that import can succeed — consider moving it up.
!pip install trl