In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt

from transformer import TransformerConfig
from lm import LM

In [2]:
# chargement des données
fichier = open('villes.txt')
donnees = fichier.read()
villes = donnees.replace('\n', ',').split(',')
villes = [ville for ville in villes if len(ville) > 2]

In [3]:
# création du vocabulaire

vocabulaire = sorted(list(set(''.join(villes))))
vocabulaire = ["<pad>", "<SOS>", "<EOS>"] + vocabulaire
# <SOS> et <EOS> sont ajoutés respectivement au début et à la fin de chaque séquence
# <pad> est utilisé pour faire en sorte que toutes les séquences aient la même longueur

# pour convertir char <-> int
char_to_int = {}
int_to_char = {}

for (c, i) in zip(vocabulaire, range(len(vocabulaire))):
    char_to_int[c] = i
    int_to_char[i] = c

In [4]:
num_sequences = len(villes)
max_len = max([len(ville) for ville in villes]) + 2 # <SOS> et <EOS>

X = torch.zeros((num_sequences, max_len), dtype=torch.int32)

for i in range(num_sequences):
    X[i] = torch.tensor([char_to_int['<SOS>']] + [char_to_int[c] for c in villes[i]] + [char_to_int['<EOS>']] + [char_to_int['<pad>']] * (max_len - len(villes[i]) - 2))

n_split = int(0.9*X.shape[0])

idx_permut = torch.randperm(X.shape[0])
idx_train, _ = torch.sort(idx_permut[:n_split])
idx_val, _ = torch.sort(idx_permut[n_split:])

X_train = X[idx_train]
X_val = X[idx_val]

In [5]:
def get_batch(split, batch_size):
    data = X_train if split == 'train' else X_val

    idx_seed = torch.randint(low=int(batch_size/2), high=int(data.shape[0]-batch_size/2), size=(1,), dtype=torch.int32).item()

    batch = data[int(idx_seed-batch_size/2):int(idx_seed+batch_size/2)]
    X = batch[:, :-1] # (B, L=max_len-1=46)
    Y = batch[:, 1:] # (B, L)
    return X, Y.long()

In [32]:
config = TransformerConfig(d_model=128, n_layers=1, n_heads=4, max_len=max_len, dropout=0.)
model = LM(config, vocab_size=len(vocabulaire))
optim = torch.optim.AdamW(model.parameters(), lr=3e-4)

In [33]:
sum(p.numel() for p in model.parameters())

274432

In [35]:
for i in range(1000):
    X, Y = get_batch('train', 32) # (B, L)
    logits = model(X) # (B, L, vocab_size)

    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1), ignore_index=char_to_int['<pad>'])
    optim.zero_grad()
    loss.backward()
    optim.step()
    
    if i%100==0:
        print(loss)

tensor(1.9733, grad_fn=<NllLossBackward0>)
tensor(1.8507, grad_fn=<NllLossBackward0>)
tensor(1.8540, grad_fn=<NllLossBackward0>)
tensor(1.8158, grad_fn=<NllLossBackward0>)
tensor(1.9138, grad_fn=<NllLossBackward0>)
tensor(1.5096, grad_fn=<NllLossBackward0>)
tensor(1.7285, grad_fn=<NllLossBackward0>)
tensor(1.8739, grad_fn=<NllLossBackward0>)
tensor(1.9907, grad_fn=<NllLossBackward0>)
tensor(1.8522, grad_fn=<NllLossBackward0>)


In [41]:
X, Y = get_batch('val', 16) # (B, L)
logits = model(X) # (B, L, vocab_size)

loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1), ignore_index=char_to_int['<pad>'])
loss

tensor(1.7709, grad_fn=<NllLossBackward0>)

In [42]:
torch.save(model.state_dict(), "transformer.pth")