In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt

from transformer import TransformerConfig
from lm import LM

from data import Dataset

In [2]:
dataset = Dataset()

config = TransformerConfig(d_model=64, n_layers=1, n_heads=8, max_len=dataset.max_len, dropout=0.)
model = LM(config, vocab_size=len(dataset.vocabulaire))
model.load_state_dict(torch.load("transformer_d_model64_n_heads8_1_57_1_68.pth", map_location=torch.device('cpu')))
optim = torch.optim.AdamW(model.parameters(), lr=3e-4)

In [3]:
sum(p.numel() for p in model.parameters())

71680

In [43]:
for i in range(10000):
    X, Y = dataset.get_batch('train', 32) # (B, L)
    logits = model(X) # (B, L, vocab_size)

    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1), ignore_index=dataset.char_to_int['<pad>'])
    optim.zero_grad()
    loss.backward()
    optim.step()
    
    if i%1000==0:
        train_loss = loss.item()

        X, Y = dataset.get_batch('test', 128) # (B, L)
        logits = model(X) # (B, L, vocab_size)
        test_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1), ignore_index=dataset.char_to_int['<pad>']).item()

        print(f"train loss : {train_loss:.2f} | test loss : {test_loss:.2f}")

train loss : 1.87 | test loss : 2.02
train loss : 2.08 | test loss : 1.94
train loss : 1.79 | test loss : 1.91
train loss : 1.98 | test loss : 2.02
train loss : 2.00 | test loss : 1.86
train loss : 1.86 | test loss : 1.98
train loss : 2.01 | test loss : 1.92
train loss : 1.93 | test loss : 1.96
train loss : 2.07 | test loss : 2.00
train loss : 2.04 | test loss : 1.83


In [5]:
X, Y = dataset.get_batch('val', 256) # (B, L)
logits = model(X) # (B, L, vocab_size)

loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1), ignore_index=dataset.char_to_int['<pad>'])
loss

tensor(1.6019, grad_fn=<NllLossBackward0>)

In [46]:
torch.save(model.state_dict(), "transformer_d_model16.pth")

In [35]:
# transformer_d_model16 : loss de 1.85


In [6]:
def sample(self, prompt = "", g = torch.Generator()):
    idx = torch.tensor([dataset.char_to_int[c] for c in prompt], dtype=torch.int32).unsqueeze(0)
    idx = torch.cat([torch.tensor(dataset.char_to_int['<SOS>']).view(1, 1), idx], dim=1)
    next_id = -1

    while next_id != dataset.char_to_int['<EOS>']:
        logits = self(idx) # (1, l, d_model)

        probs = F.softmax(logits[:, -1, :], dim=-1)
        next_id = torch.multinomial(probs, num_samples=1, generator=g).item()
        idx = torch.cat([idx, torch.tensor(next_id).view(1, 1)], dim=1)
        
    return "".join([dataset.int_to_char[p.item()] for p in idx[0, 1:-1]])

In [14]:
for _ in range(10):
    print(sample(model))

passis
mont-sainte
aveslès-sur-boz
chaymont-en-calès
cherley
gragille
cignac
saint-andoux-le-creix
saullevern-le-sec
saint-vincent-des-prés
