In [2]:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="yp-edu/gpt2-stockfish-debug")

config.json:   0%|          | 0.00/907 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/498M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/119 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Device set to use cuda:0


In [3]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("yp-edu/gpt2-stockfish-debug")
model = AutoModelForCausalLM.from_pretrained("yp-edu/gpt2-stockfish-debug")

In [None]:
import chess
from transformers import AutoModelForCausalLM, AutoTokenizer


def next_move(model, tokenizer, fen):
    input_ids = tokenizer(f"FEN: {fen}\nMOVE:", return_tensors="pt")
    input_ids = {k: v.to(model.device) for k, v in input_ids.items()}
    out = model.generate(
        **input_ids,
        max_new_tokens=10,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.1,
    )
    out_str = tokenizer.batch_decode(out)[0]
    return out_str.split("MOVE:")[-1].replace("<|endoftext|>", "").strip()


board = chess.Board()
model = AutoModelForCausalLM.from_pretrained("yp-edu/gpt2-stockfish-debug")
tokenizer = AutoTokenizer.from_pretrained("yp-edu/gpt2-stockfish-debug")  # or "gpt2"
tokenizer.pad_token = tokenizer.eos_token
for i in range(100):
    fen = board.fen()
    move_uci = next_move(model, tokenizer, fen)
    try:
        print(move_uci)
        move = chess.Move.from_uci(move_uci)
        if move not in board.legal_moves:
            raise chess.IllegalMoveError
        board.push(move)
        outcome = board.outcome()
        if outcome is not None:
            print(board)
            print(outcome.result())
            break
    except chess.IllegalMoveError:
        print(board)
        print("Illegal move", i)
        break
else:
    print(board)


e2e4
e7e6
d2d4
d7d5
b1c3
f8b4
e4d5
e6d5
f1d3
g8f6
d1e2
c8e6
g1f3
e8g8
e1g1
b8c6
c1f4
f8e8
a2a3
b4c3
b2c3
f6e4
f1e1
e4c3
e2e3
c3e4
f4g3
e4g3
h2g3
a7a5
a1b1
d8d7
c2c3
a5a4
e3d2
b7b6
d2c2
a8b8
c2d2
e6g4
e1e8
b8e8
b1e1
e8e1
d2e1
g4f3
e1e8
d7e8
g2f3
e8e1
g1g2
e1d1
d3b5
d1c2
g2h3
c2f2
b5c6
f2f3
c6d5
f3h5
h3g2
h5d5
g2f3
. . . . . . k .
. . p . . p p p
. p . . . . . .
. . . q . . . .
p . . P . . . .
P . P . . . P .
. . . . . . K .
. . . . . . . .
Illegal move 62


In [26]:
from project.dataset.lit_module import SeqAnnotationDM
label = "Result"
dm = SeqAnnotationDM(
    "data/games_0001/train_100K.parquet",
    "data/games_0001/val_100K.parquet",
    "data/games_0001/test_100K.parquet",
    32,
    ["Moves"],
    [label],
    8,
)
dm.setup()


Loading parquet file @  data/games_0001/train_100K.parquet  with columns  ['Moves', 'Result']
Loaded 70000 rows and 2 columns


Building lookup tables: 100%|██████████| 2/2 [00:00<00:00,  3.62it/s]
Converting columns to indices: 100%|██████████| 2/2 [00:00<00:00,  2.65it/s]


Loading parquet file @  data/games_0001/val_100K.parquet  with columns  ['Moves', 'Result']
Loaded 19999 rows and 2 columns


Building lookup tables: 100%|██████████| 2/2 [00:00<00:00, 11.91it/s]
Converting columns to indices: 100%|██████████| 2/2 [00:00<00:00,  8.31it/s]


Loading parquet file @  data/games_0001/test_100K.parquet  with columns  ['Moves', 'Result']
Loaded 10001 rows and 2 columns


Building lookup tables: 100%|██████████| 2/2 [00:00<00:00, 25.36it/s]
Converting columns to indices: 100%|██████████| 2/2 [00:00<00:00, 16.09it/s]


In [29]:
dm.fit_set.df

Unnamed: 0,Moves,Result
75721,"[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14...",1
80184,"[1, 154, 155, 6, 156, 157, 158, 8, 159, 160, 1...",2
19864,"[1, 268, 3, 6, 5, 269, 270, 157, 13, 266, 188,...",3
76699,"[5, 6, 3, 157, 50, 2, 123, 4, 13, 154, 295, 27...",2
92991,"[1, 2, 5, 494, 210, 268, 495, 4, 496, 304, 3, ...",2
...,...,...
7075,"[5, 494, 3, 157, 123, 278, 295, 268, 35, 531, ...",3
51072,"[5, 157, 3, 278, 50, 494, 13, 268, 738, 576, 1...",1
71582,"[5, 157, 3, 278, 9, 268, 123, 531, 164, 277, 1...",3
21635,"[3, 157, 50, 2, 5, 4, 13, 494, 738, 576, 1, 63...",3


In [27]:
import numpy as np

value_counts = dm.fit_set.df[label].explode().value_counts()
weights = 1 / value_counts
weights = weights / weights.sum()
weights = weights.sort_index().values
weights = np.concatenate([np.zeros(1), weights])
weights, value_counts.sort_index()


(array([0.        , 0.30691389, 0.47402682, 0.2190593 ]),
 Result
 1    22962
 2    14867
 3    32171
 Name: count, dtype: int64)

In [24]:
from datetime import datetime
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
timestamp

'20250520091935'