In [1]:
%%capture

!uv pip install tokenizers polars

In [2]:
import polars as pl

In [3]:
START_GAME = "<|g_start|>"
END_GAME = "<|g_end|>"

In [4]:
SPECIAL_TOKENS = [START_GAME, END_GAME]

In [5]:
df = pl.read_csv("../.data/chess_games_2025-01-15.csv", null_values=["None"])

In [6]:
sample = df.select("PGN").sample(n=500)

In [7]:
training_text = []

for game in sample.iter_rows():
    if game[0]:
        training_text.append(START_GAME + game[0].strip() + END_GAME)

In [8]:
import re

# ignore `1.`, ` 2.`, ` `, etc. and get the actual moves as separate entries
chunk_pattern = re.compile(r""" ?\d+\.|\. ?| ?[-\w]+|[#+]|\s+""")

In [31]:
from tokenizers import Tokenizer, Regex
from tokenizers.models import BPE
from tokenizers.normalizers import NFD
from tokenizers.pre_tokenizers import (
    Split,
    ByteLevel,
    Sequence,
    WhitespaceSplit,
    PreTokenizer,
)
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
from tokenizers.processors import ByteLevel as ByteLevelProcessor
from tokenizers.trainers import BpeTrainer

In [46]:
tokenizer = Tokenizer(
    BPE(unk_token="[UNK]", fuse_unk=True, continuing_subword_prefix="")
)

tokenizer.normalizer = NFD()

tokenizer.pre_tokenizer = Sequence(
    [
        # WhitespaceSplit(),
        Split(pattern=Regex(r""" ?\d+\.|\. ?| ?[-\w]+|[#+]"""), behavior="isolated"),
        # TODO: figure why this adds random Ġ characters
        # everywhere when we just want to avoid spaces
        # ByteLevel(add_prefix_space=False),
    ]
)

tokenizer.post_processor = ByteLevelProcessor(trim_offsets=True)
tokenizer.decoder = ByteLevelDecoder()

trainer = BpeTrainer(vocab_size=3072, show_progress=True)

In [47]:
tokenizer.train_from_iterator([training_text], trainer=trainer)






In [48]:
sample = df.sample(1).select("PGN").item()

print(sample)

1.e4 c6 2.d4 d5 3.e5 Bf5 4.Nf3 e6 5.Be3 h6 6.c3 Nd7 7.Qb3 Qb6 8.Nbd2 Qxb3 9.axb3 c5 10.Bb5 cxd4 11.Nxd4 Bh7 12.O-O Ne7 13.f4 Nf5 14.Nxf5 Bxf5 15.Bxa7 Kd8 16.Bd4 Rc8 17.b4 Be7 18.Ra7 Rb8 19.Nb3 f6 20.Na5 fxe5 21.fxe5 Kc8 22.Bxd7+ Kxd7 23.Rxb7+ Rxb7 24.Nxb7 Kc6 25.Na5+ Kb5 26.h3 Bd8 27.Nb7 Kc6 28.Nd6 Bg6 29.b5+ Kc7 30.Ra1 Kb8 31.Ra7 Rf8 32.Rb7+ Ka8 33.Ra7+ Kb8 34.b6 Bh7 35.Rd7 Bd3 36.Rb7+ Ka8 37.Ra7+ Kb8 38.b7 Bh7 39.Ra8+ 


In [49]:
output = tokenizer.encode(sample)

print(output.ids)
print(output.tokens)

[45, 73, 211, 106, 69, 174, 107, 121, 336, 104, 95, 181, 105, 196, 207, 108, 71, 227, 109, 326, 296, 110, 357, 1064, 111, 721, 158, 112, 273, 219, 113, 231, 1485, 114, 67, 306, 115, 76, 454, 117, 821, 618, 120, 1195, 614, 122, 408, 241, 123, 133, 187, 124, 519, 259, 128, 362, 264, 129, 940, 620, 132, 423, 1066, 135, 767, 2, 1067, 136, 774, 2, 1120, 139, 1042, 500, 140, 940, 2, 1497, 143, 146, 580, 147, 1097, 500, 153, 505, 552, 156, 164, 2, 435, 162, 483, 706, 166, 519, 358, 171, 531, 2, 1499, 173, 519, 2, 706, 180, 167, 1485, 186, 541, 700, 189, 531, 2, 1499, 192, 519, 2, 706, 197, 177, 1485, 200, 639, 2, 0]
['1.', 'e4', ' c6', ' 2.', 'd4', ' d5', ' 3.', 'e5', ' Bf5', ' 4.', 'Nf3', ' e6', ' 5.', 'Be3', ' h6', ' 6.', 'c3', ' Nd7', ' 7.', 'Qb3', ' Qb6', ' 8.', 'Nbd2', ' Qxb3', ' 9.', 'axb3', ' c5', ' 10.', 'Bb5', ' cxd4', ' 11.', 'Nxd4', ' Bh7', ' 12.', 'O-O', ' Ne7', ' 13.', 'f4', ' Nf5', ' 14.', 'Nxf5', ' Bxf5', ' 15.', 'Bxa7', ' Kd8', ' 16.', 'Bd4', ' Rc8', ' 17.', 'b4', ' Be7', ' 18

In [50]:
test = tokenizer.decode(output.ids)
print(output)
print(test)
print(test == sample)

Encoding(num_tokens=126, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing])
1.e4 c6 2.d4 d5 3.e5 Bf5 4.Nf3 e6 5.Be3 h6 6.c3 Nd7 7.Qb3 Qb6 8.Nbd2 Qxb3 9.axb3 c5 10.Bb5 cxd4 11.Nxd4 Bh7 12.O-O Ne7 13.f4 Nf5 14.Nxf5 Bxf5 15.Bxa7 Kd8 16.Bd4 Rc8 17.b4 Be7 18.Ra7 Rb8 19.Nb3 f6 20.Na5 fxe5 21.fxe5 Kc8 22.Bxd7+ Kxd7 23.Rxb7+ Rxb7 24.Nxb7 Kc6 25.Na5+ Kb5 26.h3 Bd8 27.Nb7 Kc6 28.Nd6 Bg6 29.b5+ Kc7 30.Ra1 Kb8 31.Ra7 Rf8 32.Rb7+ Ka8 33.Ra7+ Kb8 34.b6 Bh7 35.Rd7 Bd3 36.Rb7+ Ka8 37.Ra7+ Kb8 38.b7 Bh7 39.Ra8+ 
True


In [51]:
tokenizer.decode(tokenizer.encode("1.d4 d5 2.Nf3 Bf5").ids)

'1.d4 d5 2.Nf3 Bf5'