In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt

from transformer import TransformerConfig
from lm import LM

from data import Dataset

In [32]:
dataset = Dataset()

config = TransformerConfig(d_model=128, n_layers=1, n_heads=4, max_len=dataset.max_len, dropout=0.)
model = LM(config, vocab_size=len(dataset.vocabulaire))
optim = torch.optim.AdamW(model.parameters(), lr=3e-4)


In [33]:
sum(p.numel() for p in model.parameters())

274432

In [35]:
for i in range(1000):
    X, Y = dataset.get_batch('train', 32) # (B, L)
    logits = model(X) # (B, L, vocab_size)

    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1), ignore_index=dataset.char_to_int['<pad>'])
    optim.zero_grad()
    loss.backward()
    optim.step()
    
    if i%100==0:
        print(loss)

tensor(1.9733, grad_fn=<NllLossBackward0>)
tensor(1.8507, grad_fn=<NllLossBackward0>)
tensor(1.8540, grad_fn=<NllLossBackward0>)
tensor(1.8158, grad_fn=<NllLossBackward0>)
tensor(1.9138, grad_fn=<NllLossBackward0>)
tensor(1.5096, grad_fn=<NllLossBackward0>)
tensor(1.7285, grad_fn=<NllLossBackward0>)
tensor(1.8739, grad_fn=<NllLossBackward0>)
tensor(1.9907, grad_fn=<NllLossBackward0>)
tensor(1.8522, grad_fn=<NllLossBackward0>)


In [41]:
X, Y = dataset.get_batch('val', 16) # (B, L)
logits = model(X) # (B, L, vocab_size)

loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1), ignore_index=dataset.char_to_int['<pad>'])
loss

tensor(1.7709, grad_fn=<NllLossBackward0>)

In [42]:
torch.save(model.state_dict(), "transformer.pth")