In [1]:
from datasets import load_dataset
from huggingface_hub import login
from dotenv import load_dotenv
import os
from src.tokenizer.tokenizer import ChessTokenizer

# Load variables from a local .env file into the environment, then log in to
# the Hugging Face Hub with the HF_TOKEN entry (os.getenv yields None if unset).
load_dotenv()

hf_token = os.getenv("HF_TOKEN")
login(token=hf_token)


  from .autonotebook import tqdm as notebook_tqdm


The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: fineGrained).
Your token has been saved to C:\Users\James\.cache\huggingface\token
Login successful


In [2]:
# Fetch the dataset from the Hub, then drop every example whose "moves"
# field is None so downstream tokenization only ever sees real values.
raw_dataset = load_dataset("jimbowyer123/chessformers")
dataset = raw_dataset.filter(lambda example: example["moves"] is not None)

# Tokenizer used both to size the model vocabulary and to build batches.
tokenizer = ChessTokenizer()


In [3]:
from transformers import GPT2LMHeadModel, GPT2Config

# GPT-2 hyperparameters; the vocabulary size is taken from the chess tokenizer.
# NOTE(review): n_ctx (256) disagrees with n_positions (1024) — confirm which
# sequence length is actually intended.
gpt2_hyperparams = {
    "vocab_size": tokenizer.vocab_size,
    "n_positions": 1024,
    "n_ctx": 256,
    "n_embd": 768,
    "n_layer": 12,
    "n_head": 12,
}
config = GPT2Config(**gpt2_hyperparams)

# Build a fresh model directly from the configuration (not via from_pretrained).
model = GPT2LMHeadModel(config)

print(f"Initialized GPT-2 model with vocabulary size: {tokenizer.vocab_size}")


Initialized GPT-2 model with vocabulary size: 8514


In [4]:
import torch
from torch.nn.utils.rnn import pad_sequence

class ChessDataCollator:
    """Collate raw dataset examples into a causal language-modelling batch.

    The ``"moves"`` entry of every example is tokenized as one padded batch and
    a ``"labels"`` entry is added so the batch can be passed straight to the
    model as keyword arguments.

    Args:
        tokenizer: Callable tokenizer exposing a ``max_length`` attribute.
            Calling it with ``(moves, padding=True, return_tensors=True)`` must
            return a mapping that contains an ``"input_ids"`` tensor.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.max_length = tokenizer.max_length

    def __call__(self, examples):
        """Tokenize ``examples`` and attach next-token-prediction labels.

        Args:
            examples: Iterable of mappings, each with a ``"moves"`` entry.

        Returns:
            The tokenizer's output mapping with an extra ``"labels"`` tensor of
            the same shape as ``"input_ids"``.
        """
        moves = [example["moves"] for example in examples]
        inputs = self.tokenizer(moves, padding=True, return_tensors=True)

        # Labels are an unshifted copy of the inputs: GPT2LMHeadModel shifts
        # logits/labels by one position internally when computing the loss, so
        # shifting here as well would train the model to predict the token two
        # positions ahead.
        labels = inputs["input_ids"].clone()

        # Exclude padding from the loss (-100 is the ignore index of the
        # cross-entropy loss). Only possible when the tokenizer reports a mask.
        # NOTE(review): assumes the tokenizer output is mapping-like and that
        # attention_mask uses 0 for padding — confirm against ChessTokenizer.
        if "attention_mask" in inputs:
            labels[inputs["attention_mask"] == 0] = -100

        inputs["labels"] = labels

        return inputs

# Create the single collator instance that is handed to the DataLoader.
data_collator = ChessDataCollator(tokenizer)

print("Custom PyTorch data collator created for chess move prediction.")



Custom PyTorch data collator created for chess move prediction.


In [5]:
from torch.utils.data import DataLoader

# Batch the training split two examples at a time; the custom collator
# tokenizes each batch and attaches its labels.
# NOTE(review): no shuffle argument is passed, so examples are read in dataset
# order — confirm that is intended for training.
train_split = dataset["train"]
train_dataloader = DataLoader(
    dataset=train_split,
    batch_size=2,
    collate_fn=data_collator,
)



In [6]:
from transformers import AdamW, get_linear_schedule_with_warmup
from tqdm.auto import tqdm

import wandb

# Authenticate with Weights & Biases and open a run for this training session.
wandb.login()

wandb.init(project="chess-training")

# Optimizer: torch.optim.AdamW replaces the deprecated transformers.AdamW.
# eps and weight_decay are set explicitly to the defaults of the deprecated
# implementation so the training hyperparameters stay the same.
optimizer = torch.optim.AdamW(
    model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0
)

# Linear decay of the learning rate over all optimizer steps, without warmup.
num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

# Training loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

batch_logging_interval = 10  # Log to W&B every 10 batches

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}")

    for batch_idx, batch in enumerate(progress_bar):
        # Move every tensor of the batch to the training device.
        batch = {k: v.to(device) for k, v in batch.items()}

        # The batch carries "labels", so the model returns the LM loss.
        outputs = model(**batch)
        loss = outputs.loss

        # Read the scalar once; each .item() call copies the value to the host.
        loss_value = loss.item()
        total_loss += loss_value

        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        progress_bar.set_postfix({"loss": loss_value})

        if batch_idx % batch_logging_interval == 0:
            wandb.log({"epoch": epoch + 1, "batch": batch_idx, "loss": loss_value})

    # Mean per-batch loss over the epoch.
    avg_loss = total_loss / len(train_dataloader)
    print(f"Epoch {epoch + 1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

    wandb.log({"epoch": epoch + 1, "avg_loss": avg_loss})

print("Training completed!")


[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mbower-james1996[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch 1/3:   0%|          | 692/1761463 [08:30<414:49:33,  1.18it/s, loss=4.08]