In [1]:
import os
import transformers
from transformers import GPT2LMHeadModel
import lightning
import torch
from omegaconf import DictConfig, OmegaConf
from src.gpt_module import GPT2Lightning
from src.chess_tokenizers import UciTileTokenizer, PgnCharTokenizer
import src.utils as utils

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# !rsync -avz adavis@newton.ist.ucf.edu:/home/adavis/git/train-transformer/outputs/2025-03-02/14-39-12/checkpoint/ .

First, load the checkpoint from pytorch lightning training and save it as a Huggingface model.

In [6]:
ckpt_path = "outputs/checkpoint/epoch=18-step=285786.ckpt"
ckpt = torch.load(ckpt_path)
ckpt.keys()

Next, we'll load the checkpoint into our GPT2Lightning module, save the hugginface model to disk, and clear memory.

In [None]:
output_path = os.path.join(os.path.dirname(ckpt_path),"..","hf_model",os.path.basename(ckpt_path))
ckpt = torch.load(ckpt_path)
model:GPT2Lightning = GPT2Lightning(OmegaConf.create(ckpt['hyper_parameters']))
model.load_state_dict(ckpt['state_dict'])
model.model.save_pretrained(output_path)
tokenizer = UciTileTokenizer(upper_promotions=True)
tokenizer.save_pretrained(output_path)
del model

LLM dtype set to torch.float32.


In [4]:
model = GPT2LMHeadModel.from_pretrained(output_path).train(False).requires_grad_(False)
forward = lambda s: model(tokenizer(s, return_tensors='pt')['input_ids'])
generate = lambda s, n: tokenizer.batch_decode(model.generate(tokenizer([s], return_tensors='pt')['input_ids'], attention_mask=tokenizer([s], return_tensors='pt', return_attention_mask=True)['attention_mask'],max_new_tokens=n))[0]
loss = lambda s: model(tokenizer(s, return_tensors='pt')['input_ids'], labels=tokenizer(s, return_tensors='pt')['input_ids'])
print(f'{forward("")=}')

forward("")=CausalLMOutputWithCrossAttentions(loss=None, logits=tensor([[[ -0.8370,  -5.0638,  -5.9900, -13.0756,  -0.3998,   5.3587,  -2.8244,
           -2.7212,  -1.6958,  -1.6414,   8.1915,  -2.7699,   4.1968,   6.5402,
            7.7799,   9.5483,   9.8756,   5.6187,   6.1569,   3.9617,  -1.3667,
           -3.8264,  -1.0056,  -1.7576,  -2.8185,  -1.7683,  -3.0384,  -3.9871,
           -0.3712,  -0.5421,  -0.4377,   0.1427,  -1.6314,  -2.5468,  -1.2421,
           -0.8518,  -0.9784,  -1.7661,  -0.2989,  -0.5568,  -4.2425,  -0.5624,
           -5.5635,  -0.4996,  -1.4173,  -5.3151,  -2.2342,  -4.9424,  -3.2522,
           -3.0730,  -0.9018,  -1.5437,  -0.8215,  -1.1626,  -0.2576,  -0.4123,
           -0.3545,  -1.6461,  -3.3882,  -1.1151,  -3.4613,  -2.5481,  -3.8879,
           -4.4299,  -5.7905,  -4.6589,  -3.4310,  -2.4511,  -2.3585,  -3.7082,
           -4.1983,  -3.7257]]]), past_key_values=((tensor([[[[ 1.6229e-01, -1.1270e-01,  8.5852e-01,  3.3557e-01, -1.6215e-01,
        

In [33]:
import chess
from typing import Union, Iterable



In [36]:
truncate_illegal_moves("e2e4 c7c5 g1f3 d7d6 d2d4 c5d4 f3d4 g8f6 b1c3 a7a6 c1g5 e7e6 f2f4 f8e7 d1f3 d8c7 e1c1 b8d7 g2g4 b7b5 g5f6 d7f6 g4g5 f6d7 f4f5 d7c5 f5f6 g7f6 g5f6 e7f8 h1g1 b5b4 c3d5")

('e2e4 c7c5 g1f3 d7d6 d2d4 c5d4 f3d4 g8f6 b1c3 a7a6 c1g5 e7e6 f2f4 f8e7 d1f3 d8c7 e1c1 b8d7 g2g4 b7b5 g5f6 d7f6 g4g5 f6d7 f4f5 d7c5 f5f6 g7f6 g5f6 e7f8 h1g1 b5b4 c3d5',
 33,
 '')

In [42]:
utils.pgn_to_uci("""[Event "Bled op 22nd"]
[Site "Bled"]
[Date "2001.03.26"]
[Round "3"]
[White "Jukic, Branimir"]
[Black "Kos, Toni"]
[Result "0-1"]
[GameId "KWkVQlM8"]
[WhiteElo "2386"]
[BlackElo "2402"]
[Variant "Standard"]
[TimeControl "-"]
[ECO "B99"]
[Opening "Sicilian Defense: Najdorf Variation, Main Line"]
[Termination "Normal"]
[Annotator "lichess.org"]

1. e4 c5 2. Nf3 d6 3. d4 cxd4 4. Nxd4 Nf6 5. Nc3 a6 6. Bg5 e6 7. f4 Be7 8. Qf3 Qc7 9. O-O-O Nbd7 { B99 Sicilian Defense: Najdorf Variation, Main Line } 10. g4 b5 11. Bxf6 Nxf6 12. g5 Nd7 13. f5 Nc5 14. f6 gxf6 15. gxf6 Bf8 16. Rg1 b4 17. Nd5 exd5 18. exd5 Bd7 19. Rg7 O-O-O 20. Rxf7 Bh6+ 21. Kb1 Rdf8 22. Re7 Bg5 23. Rxd7 Nxd7 24. Ne6 Qb6 25. Nxg5 Rxf6 26. Qg4 h5 27. Qh3 Kc7 28. Qh4 Kb8 29. Bh3 Ne5 30. Ne4 Rg6 31. Qe7 Rg1 32. Bf1 Nc4 33. Ng3 h4 34. Qf6 Rc8 35. Ne2 Rh1 36. Nd4 Ne3 37. Nc6+ Rxc6 38. dxc6 Nxd1 39. Qf8+ Ka7 40. Qf7+ Ka8 41. Qf8+ Qb8 42. Qf3 Nc3+ 43. Kc1 Rxf1+ 44. Qxf1 Nb5 45. Qf4 Qd8 46. a4 bxa3 47. bxa3 Qh8 48. Qf3 Qa1+ 49. Kd2 Qd4+ 50. Kc1 Kb8 51. Qf7 Qe3+ 52. Kb1 Qb6 53. Qe8+ Ka7 54. Qd7+ Nc7+ 55. Ka2 d5 56. h3 Qa5 57. Qd8 Kb6 58. Qxh4 Qc5 59. Kb3 d4 60. c4 a5 61. Qe1 Qxc6 62. Qe2 a4+ 63. Kc2 Nd5 64. Qd3 Ne3+ 65. Kb1""")

'e2e4 c7c5 g1f3 d7d6 d2d4 c5d4 f3d4 g8f6 b1c3 a7a6 c1g5 e7e6 f2f4 f8e7 d1f3 d8c7 e1c1 b8d7 g2g4 b7b5 g5f6 d7f6 g4g5 f6d7 f4f5 d7c5 f5f6 g7f6 g5f6 e7f8 h1g1 b5b4 c3d5 e6d5 e4d5 c8d7 g1g7 e8c8 g7f7 f8h6 c1b1 d8f8 f7e7 h6g5 e7d7 c5d7 d4e6 c7b6 e6g5 f8f6 f3g4 h7h5 g4h3 c8c7 h3h4 c7b8 f1h3 d7e5 g5e4 f6g6 h4e7 g6g1 h3f1 e5c4 e4g3 h5h4 e7f6 h8c8 g3e2 g1h1 e2d4 c4e3 d4c6 c8c6 d5c6 e3d1 f6f8 b8a7 f8f7 a7a8 f7f8 b6b8 f8f3 d1c3 b1c1 h1f1 f3f1 c3b5 f1f4 b8d8 a2a4 b4a3 b2a3 d8h8 f4f3 h8a1 c1d2 a1d4 d2c1 a8b8 f3f7 d4e3 c1b1 e3b6 f7e8 b8a7 e8d7 b5c7 b1a2 d6d5 h2h3 b6a5 d7d8 a7b6 d8h4 a5c5 a2b3 d5d4 c2c4 a6a5 h4e1 c5c6 e1e2 a5a4 b3c2 c7d5 e2d3 d5e3 c2b1'

In [46]:
# test_pgn = """1. e4 e5 2. Nf3 Nc6 3. Bc4 Bc5 4. c3 Nf6 5. d4 exd4 6. cxd4 Bb4+ 7. Nc3 Nxe4 8. O-O Bxc3 9. bxc3 d5 10. Bd3 O-O 11. Qc2 Bf5 12. Ne5 Nxe5 13. dxe5 Qe7 14. f3 Qc5+ 15. Kh1 Nxc3 16. Bxf5 d4 17. Bxh7+ Kh8 18. Bd3 Qxe5 19. Bb2 Qe3 20. Rae1 Qh6 21. Bxc3 dxc3 22. Qxc3 c6 23. Re4 Rad8 24. Rh4 Rxd3 25. Rxh6+ Kg8 26. Qxd3 gxh6 27. Qd7 Kg7 28. Qxb7 Re8 29. Qxc6 Re6 30. Qc3+ Rf6 31. Re1 Kg6 32. Re4 Rf5 33. Rg4+ Rg5 34. Rxg5+ hxg5 35. Qc7 f6 36. Qxa7 Kf5 37. Qb6 Kg6 38. a4 Kf5 39. a5 Kg6 40. a6 Kf5 41. a7 Kg6 42. a8=Q Kf5 43. Qe4"""
prefix = 'e2e4 c7c5 g1f3 d7d6 d2d4 c5d4 f3d4 g8f6 b1c3 a7a6 c1g5 e7e6 f2f4 f8e7 d1f3 d8c7 e1c1 b8d7 g2g4 b7b5 g5f6 d7f6 g4g5 f6d7 f4f5 d7c5 f5f6 g7f6 g5f6 e7f8 h1g1 b5b4 c3d5 e6d5 e4d5 c8d7 g1g7 e8c8 g7f7 f8h6 c1b1 d8f8 f7e7 h6g5 e7d7 c5d7 d4e6 c7b6 e6g5 f8f6 f3g4 h7h5 g4h3 c8c7 h3h4 c7b8 f1h3 d7e5 g5e4 f6g6 h4e7 g6g1 h3f1 e5c4 e4g3 h5h4 e7f6 h8c8 g3e2 g1h1 e2d4 c4e3 d4c6 c8c6 d5c6 e3d1 f6f8 b8a7 f8f7 a7a8 f7f8 b6b8 f8f3 d1c3 b1c1 h1f1 f3f1 c3b5 f1f4 b8d8 a2a4 b4a3 b2a3 d8h8 f4f3 h8a1 c1d2 a1d4 d2c1 a8b8 f3f7 d4e3 c1b1 e3b6 f7e8 b8a7 e8d7 b5c7 b1a2 d6d5 h2h3 b6a5 d7d8 a7b6 d8h4 a5c5 a2b3 d5d4 c2c4 a6a5 h4e1 c5c6 e1e2'
synth_game = generate(prefix, 10)
legal_moves, n, bad_move = truncate_illegal_moves(synth_game)
print(f"{legal_moves=}")
print(f"{bad_move=} {n=} of {len(synth_game.split())} of {len(prefix.split())}")
print(" ".join(synth_game.split()[n:]))

pgn_str = utils.uci_to_pgn(legal_moves)
print(pgn_str)


legal_moves='e2e4 c7c5 g1f3 d7d6 d2d4 c5d4 f3d4 g8f6 b1c3 a7a6 c1g5 e7e6 f2f4 f8e7 d1f3 d8c7 e1c1 b8d7 g2g4 b7b5 g5f6 d7f6 g4g5 f6d7 f4f5 d7c5 f5f6 g7f6 g5f6 e7f8 h1g1 b5b4 c3d5 e6d5 e4d5 c8d7 g1g7 e8c8 g7f7 f8h6 c1b1 d8f8 f7e7 h6g5 e7d7 c5d7 d4e6 c7b6 e6g5 f8f6 f3g4 h7h5 g4h3 c8c7 h3h4 c7b8 f1h3 d7e5 g5e4 f6g6 h4e7 g6g1 h3f1 e5c4 e4g3 h5h4 e7f6 h8c8 g3e2 g1h1 e2d4 c4e3 d4c6 c8c6 d5c6 e3d1 f6f8 b8a7 f8f7 a7a8 f7f8 b6b8 f8f3 d1c3 b1c1 h1f1 f3f1 c3b5 f1f4 b8d8 a2a4 b4a3 b2a3 d8h8 f4f3 h8a1 c1d2 a1d4 d2c1 a8b8 f3f7 d4e3 c1b1 e3b6 f7e8 b8a7 e8d7 b5c7 b1a2 d6d5 h2h3 b6a5 d7d8 a7b6 d8h4 a5c5 a2b3 d5d4 c2c4 a6a5 h4e1 c5c6 e1e2'
bad_move='e8Q#' n=123 of 124 of 123
e8Q#
[Event "?"]
[Site "?"]
[Date "????.??.??"]
[Round "?"]
[White "?"]
[Black "?"]
[Result "*"]

1. e4 c5 2. Nf3 d6 3. d4 cxd4 4. Nxd4 Nf6 5. Nc3 a6 6. Bg5 e6 7. f4 Be7 8. Qf3 Qc7 9. O-O-O Nbd7 10. g4 b5 11. Bxf6 Nxf6 12. g5 Nd7 13. f5 Nc5 14. f6 gxf6 15. gxf6 Bf8 16. Rg1 b4 17. Nd5 exd5 18. exd5 Bd7 19. Rg7 O-O-O 20. Rxf7 Bh6+ 21. 

In [26]:
model.push_to_hub("chessGPT_d12_pgn_elite",commit_message="Trained on 2025-02-28/18-12-46 https://wandb.ai/austinleedavis/train/runs/0bqysocc")

model.safetensors: 100%|██████████| 343M/343M [02:00<00:00, 2.84MB/s] 


CommitInfo(commit_url='https://huggingface.co/austindavis/chessGPT_d12_pgn_elite/commit/0cf18dce281de3d498a02f7e4d28d6ed87158517', commit_message='Trained on 2025-02-28/18-12-46 https://wandb.ai/austinleedavis/train/runs/0bqysocc', commit_description='', oid='0cf18dce281de3d498a02f7e4d28d6ed87158517', pr_url=None, repo_url=RepoUrl('https://huggingface.co/austindavis/chessGPT_d12_pgn_elite', endpoint='https://huggingface.co', repo_type='model', repo_id='austindavis/chessGPT_d12_pgn_elite'), pr_revision=None, pr_num=None)