In [None]:
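# Uncomment to install the dependencies into this kernel's environment
# (%pip, unlike !pip, installs into the environment of the running kernel):
# %pip install -r requirements.txt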

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import tiktoken
from dataset import ShakespeareDataset
from model import TransformerModel
import config

In [None]:
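# Data source: Kaggle dataset "nenadblagovcanin/shakespeare".
# To download it, uncomment the lines below (requires the `kagglehub` package),
# then copy the text file from the path it returns into the working directory
# as shakespeare.txt, which is where the next cell reads it from.
# import kagglehub
# kagglehub.dataset_download("nenadblagovcanin/shakespeare")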


In [None]:
# Load the raw corpus and tokenize it with the GPT-2 BPE encoding.
# The encoding is explicit: without it, open() falls back to the platform locale
# encoding (e.g. cp1252 on Windows), which can raise or mis-decode non-ASCII text.
with open('shakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()

encoder = tiktoken.get_encoding("gpt2")
tokens = encoder.encode(text)

In [None]:
# Seed torch's global RNG so that model initialisation (next cell), DataLoader
# shuffling (its default sampler draws from this RNG) and any dropout are
# reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

dataset = ShakespeareDataset(tokens, config.sequence_length)
# Fail fast: an empty dataset would make the training loop a silent no-op.
assert len(dataset) > 0, "not enough tokens to build a single training sequence"

dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)

In [None]:
# The model predicts over the full GPT-2 vocabulary produced by `encoder`.
vocab_size = encoder.n_vocab

# NOTE(review): the dataset is built with config.sequence_length while the model
# gets config.max_seq_length; presumably the former must not exceed the latter — confirm.
model = TransformerModel(
    vocab_size=vocab_size,
    embedding_size=config.embedding_size,
    num_heads=config.num_heads,
    num_layers=config.num_layers,
    hidden_size=config.hidden_size,
    max_seq_length=config.max_seq_length,
)
# nn.Module.to moves the parameters in place and returns the same module.
model = model.to(config.device)

# Token-level cross-entropy over vocabulary logits, optimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)

In [None]:
# Train with teacher forcing, log the loss every 100 steps, then save the weights.
model.train()
for epoch in range(config.num_epochs):
    for batch_idx, (inputs, targets) in enumerate(dataloader):
        inputs, targets = inputs.to(config.device), targets.to(config.device)

        # Decoder input = targets shifted right by one position, with token id 0
        # prepended as the start token. new_zeros allocates directly on targets'
        # device with targets' dtype, so there is no per-batch CPU allocation,
        # cast and host->device copy.
        # NOTE(review): with the gpt2 encoding, id 0 is an ordinary token, not a
        # dedicated BOS id — confirm this is intended.
        # NOTE(review): if dataset.py yields targets = inputs shifted by one, the
        # encoder already sees targets[:, :-1] through `inputs`; unless the model
        # masks the encoder causally it can copy the answer — confirm in model.py.
        start_tokens = targets.new_zeros((targets.size(0), 1))
        decoder_input = torch.cat([start_tokens, targets[:, :-1]], dim=1)

        # Forward pass
        outputs = model(inputs, decoder_input)

        # CrossEntropyLoss expects (N, vocab) logits and (N,) class ids. reshape,
        # unlike view, also works when the model returns a non-contiguous tensor.
        # NOTE(review): assumes outputs are (batch, seq, vocab) so the flattened rows
        # line up with the flattened targets — confirm in model.py.
        loss = criterion(outputs.reshape(-1, vocab_size), targets.reshape(-1))

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print(f"Epoch [{epoch+1}/{config.num_epochs}], Step [{batch_idx}/{len(dataloader)}], Loss: {loss.item():.4f}")

# Save the model
torch.save(model.state_dict(), "shakespeare_transformer.pth")