In [2]:
import torch
import torch.nn as nn

from py_pytorch_chess_model import ChessModel

# Sequence length handed to both ChessModel and EmbeddingFromSentence
# (presumably the maximum number of tokens per game — confirm in those modules).
MAX_SEQUENCE_LENGTH: int = 512
# Vocabulary size given to the model and the embedder; the training loop
# reshapes the model's logits to (-1, VOCAB_SIZE), so it must match the
# model's output dimension.
VOCAB_SIZE: int = 370

In [3]:
from torch.utils.data import Dataset, DataLoader

class ChessMovesDataset(Dataset):
    """Map-style dataset that serves raw game strings by index.

    Items are returned unchanged; no tokenization or embedding happens here.

    Args:
        texts: an indexable, sized sequence of game strings. It is stored by
            reference, not copied.
    """

    def __init__(self, texts):
        self.texts = texts

    def __len__(self):
        """Number of games held by the dataset."""
        count = len(self.texts)
        return count

    def __getitem__(self, idx):
        """Return the game stored at position ``idx``."""
        game = self.texts[idx]
        return game
    
    
# Load one game per line. An explicit encoding keeps decoding independent of
# the platform's locale default. Blank or whitespace-only lines (e.g. extra
# newlines at the end of the file) are skipped so they don't become empty
# training examples. The Dataset wrapping these lines is built in the
# training cell.
with open("games.txt", "r", encoding="utf-8") as file:
    sentences = [line for line in file.read().splitlines() if line.strip()]


In [None]:
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
from py_get_bert_word_embeddings import EmbeddingFromSentence
from torch.utils.data import DataLoader
import os

# Create a directory to save checkpoints
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

# Seed torch's global RNG (all devices) before building the model. This makes
# torch-based weight init and the DataLoader's shuffle order repeatable across
# Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ChessModel(MAX_SEQUENCE_LENGTH, vocab_size=VOCAB_SIZE).to(device)

LEARNING_RATE = 0.001
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.CrossEntropyLoss()
embedder = EmbeddingFromSentence(MAX_SEQUENCE_LENGTH, chess_vocab_size=VOCAB_SIZE)

text_dataset = ChessMovesDataset(sentences)

batch_size = 24
# shuffle=True draws its permutation from torch's default generator seeded above.
data_loader = DataLoader(text_dataset, batch_size=batch_size, shuffle=True)
losses_list = []
epochs = 10
log_interval = 10          # print the running average loss every N iterations
checkpoint_interval = 300  # save a checkpoint every N iterations (never at iteration 0)


def _save_checkpoint(path, epoch, iteration, loss_value):
    """Write model/optimizer state plus bookkeeping fields to ``path``."""
    torch.save({
        'epoch': epoch,
        'iteration': iteration,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss_value,
    }, path)
    print()
    print(f"Checkpoint saved at {path}")


if len(data_loader) == 0:
    raise ValueError("data_loader yields no batches; nothing to train on")

for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    for idx, text in enumerate(data_loader):
        optimizer.zero_grad()
        embeds, attn_mask, ids = embedder.one_hot_from_sentence(text)

        # Append one padding position, then shift: input position t is
        # trained to predict token t + 1.
        embeds = F.pad(embeds, (0, 0, 0, 1), value=0)
        attn_mask = F.pad(attn_mask, (0, 1), value=0).bool()
        ids = F.pad(ids, (0, 1), value=0)
        inpt = embeds[:, :-1, :].to(device)
        inpt_mask = attn_mask[:, :-1].to(device)  # already bool
        labels = ids[:, 1:].to(device)

        logits = model(inpt, inpt_mask)

        # Flatten to (batch * seq_len, VOCAB_SIZE) and (batch * seq_len,).
        # reshape() is required instead of view(): the slices above are
        # non-contiguous, and on CPU .to(device) returns them unchanged, so
        # .view(-1) would raise a RuntimeError.
        logits = logits.reshape(-1, VOCAB_SIZE)
        labels = labels.reshape(-1)
        attn_mask_reshape = inpt_mask.reshape(-1)

        # Score only positions covered by the input attention mask.
        valid_logits = logits[attn_mask_reshape]
        valid_labels = labels[attn_mask_reshape]
        loss = loss_fn(valid_logits, valid_labels)

        loss.backward()
        optimizer.step()

        loss_value = loss.item()  # one host sync per step
        total_loss += loss_value
        losses_list.append(loss_value)

        if idx % log_interval == 0:
            print()
            print(f"Iteration: {idx}, Average Loss: {total_loss/(idx+1)}", end=" | ")

        # Periodic checkpoint; idx > 0 skips a save right after the first step.
        if idx > 0 and idx % checkpoint_interval == 0:
            _save_checkpoint(
                os.path.join(checkpoint_dir, f"chess_model_epoch{epoch+1}_iter{idx}.pt"),
                epoch, idx, loss_value,
            )

        print(" " + str(loss_value), end="")

    avg_loss = total_loss / len(data_loader)
    print()  # finish the line of per-step losses before the summary
    print(f"Epoch [{epoch + 1}/{epochs}], Loss: {avg_loss:.4f}")
    # End-of-epoch checkpoint so the latest weights (including the final
    # model) are always on disk.
    _save_checkpoint(
        os.path.join(checkpoint_dir, f"chess_model_epoch{epoch+1}_end.pt"),
        epoch, idx, loss_value,
    )