In [4]:
import torch
import torch.nn as nn

from py_pytorch_chess_model import ChessModel

MAX_SEQUENCE_LENGTH = 512
VOCAB_SIZE = 370

In [2]:
from torch.utils.data import Dataset, DataLoader

class ChessMovesDataset(Dataset):
    def __init__(self, texts):
        self.texts = texts
        
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        return self.texts[idx]
    
    
with open("games.txt", "r") as file:
    sentences = file.read().splitlines()

# Create the dataset


In [24]:
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
from py_get_bert_word_embeddings import EmbeddingFromSentence
from torch.utils.data import DataLoader
import os

# Create a directory to save checkpoints
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ChessModel(MAX_SEQUENCE_LENGTH, vocab_size=VOCAB_SIZE).to(device)



optimizer = optim.Adam(model.parameters(), lr = 0.001)
loss_fn = nn.CrossEntropyLoss()
embedder = EmbeddingFromSentence(MAX_SEQUENCE_LENGTH, chess_vocab_size=VOCAB_SIZE)

text_dataset = ChessMovesDataset(sentences)


batch_size = 24
data_loader = DataLoader(text_dataset, batch_size=batch_size, shuffle=True)
losses_list = []
epochs = 10
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for idx, text in enumerate(data_loader):
        
        optimizer.zero_grad()
        embeds, attn_mask, ids = embedder.one_hot_from_sentence(text)
        
        embeds = F.pad(embeds, (0,0,0,1),  value=0)
        attn_mask = F.pad(attn_mask, (0,1),  value=0).bool()
        ids = F.pad(ids, (0,1),  value=0)
        inpt = embeds[:,:-1,:].to(device)
        inpt_mask = attn_mask[:,:-1].bool().to(device)
        labels = ids[:,1:].to(device)
        

        logits = model(inpt, inpt_mask)

        logits = logits.view(-1, VOCAB_SIZE)  # Shape: (batch_size * sequence_length, 30522)
        labels = labels.view(-1)          # Shape: (batch_size * sequence_length)

        attn_mask_reshape = inpt_mask.view(-1)
        valid_logits = logits[attn_mask_reshape]
        valid_labels = labels[attn_mask_reshape]
        
        loss = loss_fn(valid_logits, valid_labels)
        
        
        
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        losses_list.append(loss.item())
        if (idx % 10 == 0):
            print()
            print(f"Iteration: {idx}, Average Loss: {total_loss/(idx+1)}", end=" | ")
            
            
        # Checkpoint the model every 1000 iterations
        if (idx % 300 == 0):  # Avoid checkpointing at the very start
            checkpoint_path = os.path.join(checkpoint_dir, f"chess_model_epoch{epoch+1}_iter{idx}.pt")
            torch.save({
                'epoch': epoch,
                'iteration': idx,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss.item(),
            }, checkpoint_path)
            print()
            print(f"Checkpoint saved at {checkpoint_path}")

        print(" " + str(loss.item()), end="")
        
        
    avg_loss = total_loss / len(data_loader)
    print(f"Epoch [{epoch + 1}/{epochs}], Loss: {avg_loss:.4f}")


Iteration: 0, Average Loss: 5.913638114929199 | 
Checkpoint saved at checkpoints\chess_model_epoch1_iter0.pt
 5.913638114929199 5.908865928649902 5.900223255157471 5.90372896194458 5.8990092277526855 5.902172088623047 5.914328575134277 5.912727355957031 5.901181221008301 5.891996383666992
Iteration: 10, Average Loss: 5.90559907393022 |  5.9137187004089355 5.912589073181152 5.898606777191162 5.90002965927124 5.889636039733887 5.897335529327393 5.89759635925293 5.897603988647461 5.893815994262695 5.900088787078857
Iteration: 20, Average Loss: 5.902080876486642 |  5.894806385040283 5.894341468811035 5.898996829986572 5.896486282348633 5.894387722015381 5.90232515335083 5.903687477111816 5.893354415893555 5.895406246185303 5.8980302810668945
Iteration: 30, Average Loss: 5.900552949597759 |  5.896427154541016 5.8966264724731445 5.902865409851074 5.896948337554932 5.899543285369873 5.895835876464844 5.885085582733154 5.891657829284668 5.892608642578125 5.900711536407471
Iteration: 40, Avera

RuntimeError: [enforce fail at inline_container.cc:603] . unexpected pos 124237952 vs 124237840

Model output device: cuda:0


In [25]:
MAX_SEQUENCE_LENGTH = 512
from py_get_bert_word_embeddings import EmbeddingFromSentence
from py_pytorch_chess_model import ChessModel
model = ChessModel(MAX_SEQUENCE_LENGTH)

embedder = EmbeddingFromSentence(MAX_SEQUENCE_LENGTH)
embeds, mask, ids = embedder.one_hot_from_sentence(["e4"])

output = model(embeds, mask)

In [26]:
output[0,0,:].argmax()

tensor(315)

In [27]:
embeds, attn_mask, ids = embedder.get_embeddings_from_sentence("Nf3 d5 g3 Nc6 Bg2 e5 d3 Nf6 Bf5 Re1 Qd7 Nh4 Bg6 e3 Nc3 Bb4 Bd2 Bf5 Nf5 Qf5 a3 Bc5 Qe2 h5 Qf1 g5 e4 e4 Ne4 h4 Nc5 g3 g3")
        
embeds = F.pad(embeds, (0,0,0,1),  value=0)
attn_mask = F.pad(attn_mask, (0,1),  value=0).bool()
ids = F.pad(ids, (0,1),  value=0)

inpt = embeds[:,:-1,:].to(device)
inpt_mask = attn_mask[:,:-1].bool().to(device)
labels = ids[:,1:].to(device)
print(attn_mask.shape)

torch.Size([1, 513])


In [13]:
output[0][0][2549]

tensor(1.0000, device='cuda:0', grad_fn=<SelectBackward0>)

In [1]:
from py_get_bert_word_embeddings import EmbeddingFromSentence

x  = EmbeddingFromSentence(max_sequence_length=512)
o, m, i = x.one_hot_from_sentence(["e4 e5 h4 Ne4 Qf2 Qf1 Ke3"]) 

In [10]:
o[0][0]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 

In [16]:
from transformers import BertModel, BertTokenizer
tokenizer: BertTokenizer = BertTokenizer.from_pretrained('bert-base-uncased', clean_up_tokenization_spaces=True)

tokenizer.decode(tokenizer("Nf3 d5 g3 Nc6 Bg2 e5 d3 Nf6 Bf5 Re1 Qd7 Nh4 Bg6 e3 Nc3 Bb4 Bd2 Bf5 Nf5 Qf5 a3 Bc5 Qe2 h5 Qf1 g5 e4 e4 Ne4 h4 Nc5 g3 g3")["input_ids"])

'[CLS] nf3 d5 g3 nc6 bg2 e5 d3 nf6 bf5 re1 qd7 nh4 bg6 e3 nc3 bb4 bd2 bf5 nf5 qf5 a3 bc5 qe2 h5 qf1 g5 e4 e4 ne4 h4 nc5 g3 g3 [SEP]'

In [17]:
for idx, text in enumerate(data_loader):
    print(text)
    break

24


In [18]:
labels

tensor([[ 1050,  2546,  2509,  1040,  2629,  1043,  2509, 13316,  2575,  1038,
          2290,  2475,  1041,  2629,  1040,  2509,  1050,  2546,  2575, 28939,
          2629,  2128,  2487,  1053,  2094,  2581, 18699,  2549,  1038,  2290,
          2575,  1041,  2509, 13316,  2509, 22861,  2549,  1038,  2094,  2475,
         28939,  2629,  1050,  2546,  2629,  1053,  2546,  2629,  1037,  2509,
          4647,  2629,  1053,  2063,  2475,  1044,  2629,  1053,  2546,  2487,
          1043,  2629,  1041,  2549,  1041,  2549, 11265,  2549,  1044,  2549,
         13316,  2629,  1043,  2509,  1043,  2509,   102,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,  