In [None]:
from model import Transformer
from dataset import Dataset, play
import torch
from tqdm import tqdm

In [None]:
# Pick the best available accelerator: Apple MPS, then CUDA, else CPU.
# Hard-coding "mps" makes this cell fail outright on non-Apple machines.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# NOTE(review): min_partition_length=32 matches context_size below; presumably it
# filters out partitions too short to fill a training window — confirm against Dataset.
dataset = Dataset(device, min_partition_length=32)

print(len(dataset))

In [None]:
# `vocab_size` is only assigned in the next cell, so read it from the dataset directly;
# printing the bare name raises NameError on a fresh kernel (Restart & Run All).
print(dataset.vocab_size)

In [None]:
vocab_size = dataset.vocab_size

print(vocab_size)

# Model hyperparameters. context_size is passed both to the Transformer and to dataset.sample().
context_size = 32
n_embd = 256
n_head = 4
n_layer = 4
dropout = 0.1

# Optimisation settings. Each "epoch" in train() is a single optimizer step on one batch
# drawn via dataset.sample(), not a full pass over the dataset.
lr = 3e-4
epochs = 500
batch_size = 512

# When False, the checkpoint written by the training cell ("model.pth") is loaded instead,
# so generation does not run on randomly initialised weights.
should_train = True

model = Transformer(vocab_size, context_size, n_embd, n_head, n_layer, dropout).to(device)
if not should_train:
    # map_location lets a checkpoint saved on one device (e.g. mps) load on another.
    model.load_state_dict(torch.load("model.pth", map_location=device))
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [None]:
@torch.no_grad()
def compute_test_score(n_sample: int = 50, topk: int = 1):
    """Evaluate the global ``model`` on one batch from the non-training split.

    Args:
        n_sample: number of sequences requested via ``dataset.sample(..., train=False)``
            (presumably the held-out split — confirm against Dataset).
        topk: forwarded to ``model.accuracy``.

    Returns:
        ``(test_loss, test_accuracy)`` exactly as returned by ``model.loss`` and
        ``model.accuracy``.

    The model's train/eval mode is restored to whatever it was on entry (even if an
    exception is raised), so calling this on an eval-mode model does not silently
    re-enable dropout.
    """
    was_training = model.training
    model.eval()
    try:
        batch = dataset.sample(n_sample, context_size, train=False)
        # Targets are the inputs shifted left by one position (next-token prediction).
        inputs, targets = batch[:, :-1], batch[:, 1:]
        predictions = model(inputs)
        test_loss = model.loss(predictions, targets)
        test_accuracy = model.accuracy(predictions, targets, topk=topk)
    finally:
        model.train(was_training)
    return test_loss, test_accuracy

In [None]:
def train(model, eval_every: int = 100):
    """Run ``epochs`` optimisation steps on ``model``.

    Each step draws one batch of ``batch_size`` sequences of ``context_size`` from the
    training split, computes next-token loss/accuracy and steps the global ``optimizer``.
    Every ``eval_every`` steps (a positive int), and once more after the final step, a
    top-3 test score from ``compute_test_score`` is written below the progress bar.

    NOTE(review): ``optimizer`` was built over the *global* model's parameters and
    ``compute_test_score`` evaluates the *global* model, so passing any other model here
    trains/evaluates inconsistently — TODO thread both through as parameters.
    """
    pbar = tqdm(range(epochs))

    for epoch in pbar:
        # Request the training split explicitly. The Dataset already owns `device`, and
        # sample()'s third slot appears to be `train` (compute_test_score passes
        # train=False by keyword), so passing `device` positionally there was at best
        # redundant and at worst bound to the wrong parameter.
        batch = dataset.sample(batch_size, context_size, train=True)
        inputs, targets = batch[:, :-1], batch[:, 1:]

        optimizer.zero_grad()
        predictions = model(inputs)
        loss = model.loss(predictions, targets)
        accuracy = model.accuracy(predictions, targets)
        loss.backward()
        optimizer.step()

        pbar.set_description(f"Loss: {loss.item():.4f}, Accuracy: {accuracy.item():.4f}")

        if epoch % eval_every == 0 or epoch == epochs - 1:
            test_loss, test_accuracy = compute_test_score(topk=3)
            # pbar.write keeps the progress bar intact; a bare print() tears it across lines.
            pbar.write(
                f"Test Loss: {test_loss.item():.4f}, "
                f"Test Top-3 Accuracy: {test_accuracy.item():.4f}"
            )

In [None]:
# Train only when enabled in the config cell, then checkpoint the learned weights to disk
# so they can be restored later with model.load_state_dict(torch.load(...)).
if should_train:
    train(model)
    trained_weights = model.state_dict()
    torch.save(trained_weights, "model.pth")

In [None]:
# The model is still in training mode here (nn.Module's default, and train() leaves it
# that way), so switch off dropout before sampling and skip autograd bookkeeping.
# NOTE(review): model.generate may already do either internally — harmless if so.
model.eval()
with torch.no_grad():
    generated = model.generate(100, temperature=1.0)
partition = dataset.decode_partition(generated)
play(partition)