In [1]:
import os
import transformers
from transformers import GPT2LMHeadModel
import lightning
import torch
from src.gpt_module import GPT2Lightning
from src.chess_tokenizers import UciTileTokenizer
import src.utils as utils

  from .autonotebook import tqdm as notebook_tqdm


First, load the checkpoint from pytorch lightning training and save it as a Huggingface model.

In [6]:
ckpt_path = "outputs/checkpoint/epoch=1-step=38093.ckpt"
output_path = os.path.join(os.path.dirname(ckpt_path),"..","hf_model")
model: GPT2LMHeadModel = GPT2Lightning.load_from_checkpoint(ckpt_path).model
model.save_pretrained(output_path)
tokenizer = UciTileTokenizer(upper_promotions=True)
tokenizer.save_pretrained(output_path)
del model

LLM dtype set to float32.


In [8]:
model = GPT2LMHeadModel.from_pretrained(output_path).train(False).requires_grad_(False)
tokenizer = UciTileTokenizer(upper_promotions=True)
forward = lambda s: model(tokenizer(s, return_tensors='pt')['input_ids'])
generate = lambda s, n: tokenizer.batch_decode(model.generate(tokenizer([s], return_tensors='pt')['input_ids'], max_new_tokens=n))[0]
loss = lambda s: model(tokenizer(s, return_tensors='pt')['input_ids'], labels=tokenizer(s, return_tensors='pt')['input_ids'])
print(f'{forward("e2e4")=}')

forward("e2e4")=CausalLMOutputWithCrossAttentions(loss=None, logits=tensor([[[-2.6358e+00, -4.5087e+00, -8.9719e+00, -9.8046e+00,  3.1368e-01,
           5.6080e+00, -1.0170e-01, -1.1656e-01,  2.9919e-01,  1.6640e-01,
           7.5141e+00, -8.8405e-01,  4.0575e+00,  6.9123e+00,  7.4935e+00,
           9.2505e+00,  1.0022e+01,  6.2085e+00,  6.9270e+00,  3.9415e+00,
          -9.4112e-01, -1.0826e+00, -5.6983e-02, -5.5098e-01, -7.2798e-01,
          -1.0171e-01, -8.2171e-01, -1.2121e+00, -3.7130e-01, -5.6842e-01,
          -9.1448e-02,  3.3276e-01, -3.2669e-01, -1.0012e+00, -1.3123e+00,
          -9.8074e-01, -1.5437e+00, -1.2318e+00, -1.2625e+00, -5.8870e-01,
          -1.6770e+00, -1.0776e+00, -1.7345e+00, -7.5281e-01, -1.9129e+00,
          -2.0421e+00, -1.5546e+00, -2.4365e+00, -1.9854e+00, -9.9128e-01,
          -1.5286e+00, -1.4279e+00, -1.2885e+00, -1.1187e+00, -2.7727e-01,
          -5.5760e-01, -5.7518e-01, -1.3797e+00, -1.4624e+00, -7.4012e-01,
          -2.4939e+00, -1.5268e+

In [14]:
synth_game = generate("e2e4", 184)
pgn_str = utils.uci_to_pgn(synth_game)
print(pgn_str)


[Event "?"]
[Site "?"]
[Date "????.??.??"]
[Round "?"]
[White "?"]
[Black "?"]
[Result "*"]

1. e4 e5 2. Nf3 Nc6 3. Bc4 Bc5 4. c3 Nf6 5. d4 exd4 6. cxd4 Bb4+ 7. Nc3 Nxe4 8. O-O Bxc3 9. bxc3 d5 10. Bd3 O-O 11. Qc2 Bf5 12. Ne5 Nxe5 13. dxe5 Qe7 14. f3 Qc5+ 15. Kh1 Nxc3 16. Bxf5 d4 17. Bxh7+ Kh8 18. Bd3 Qxe5 19. Bb2 Qe3 20. Rae1 Qh6 21. Bxc3 dxc3 22. Qxc3 c6 23. Re4 Rad8 24. Rh4 Rxd3 25. Rxh6+ Kg8 26. Qxd3 gxh6 27. Qd7 Kg7 28. Qxb7 Re8 29. Qxc6 Re6 30. Qc3+ Rf6 31. Re1 Kg6 32. Re4 Rf5 33. Rg4+ Rg5 34. Rxg5+ hxg5 35. Qc7 f6 36. Qxa7 Kf5 37. Qb6 Kg6 38. a4 Kf5 39. a5 Kg6 40. a6 Kf5 41. a7 Kg6 42. a8=Q Kf5 43. Qe4# *


In [15]:
model.push_to_hub("chessGPT_d12_1500",commit_message="Trained on 2025-02-27/16-29-15 https://wandb.ai/austinleedavis/train/runs/cb3q0p93")

model.safetensors: 100%|██████████| 344M/344M [02:05<00:00, 2.73MB/s]    


CommitInfo(commit_url='https://huggingface.co/austindavis/chessGPT_d12_1500/commit/0cfaa55fd6c3d51c054698558bf3ccd7dd789883', commit_message='Trained on 2025-02-27/16-29-15 https://wandb.ai/austinleedavis/train/runs/cb3q0p93', commit_description='', oid='0cfaa55fd6c3d51c054698558bf3ccd7dd789883', pr_url=None, repo_url=RepoUrl('https://huggingface.co/austindavis/chessGPT_d12_1500', endpoint='https://huggingface.co', repo_type='model', repo_id='austindavis/chessGPT_d12_1500'), pr_revision=None, pr_num=None)