In [29]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt

In [30]:
# chargement des données
fichier = open('villes.txt')
donnees = fichier.read()
villes = donnees.replace('\n', ',').split(',')
villes = [ville for ville in villes if len(ville) > 2]

In [150]:
# création du vocabulaire

vocabulaire = sorted(list(set(''.join(villes))))
vocabulaire = ["<pad>", "<SOS>", "<EOS>"] + vocabulaire

# pour convertir char <-> int
char_to_int = {}
int_to_char = {}

for (c, i) in zip(vocabulaire, range(len(vocabulaire))):
    char_to_int[c] = i
    int_to_char[i] = c

In [151]:
num_sequences = len(villes)
max_len = max([len(ville) for ville in villes]) + 2 # account for <SOS> and <EOS>, which are appened at the beg. and end of each seq

X = torch.zeros((num_sequences, max_len), dtype=torch.int32)

for i in range(num_sequences):
    X[i] = torch.tensor([char_to_int['<SOS>']] + [char_to_int[c] for c in villes[i]] + [char_to_int['<EOS>']] + [char_to_int['<pad>']] * (max_len - len(villes[i]) - 2))

n_split = int(0.9*X.shape[0])

idx_permut = torch.randperm(X.shape[0])
idx_train, _ = torch.sort(idx_permut[:n_split])
idx_val, _ = torch.sort(idx_permut[n_split:])

X_train = X[idx_train]
X_val = X[idx_val]

In [200]:
def get_batch(split, batch_size):
    data = X_train if split == 'train' else X_val

    idx_seed = torch.randint(low=int(batch_size/2), high=int(data.shape[0]-batch_size/2), size=(1,), dtype=torch.int32).item()

    batch = data[int(idx_seed-batch_size/2):int(idx_seed+batch_size/2)]
    X = batch[:, :-1] # (B, L=max_len-1=46)
    Y = batch[:, 1:] # (B, L)
    return X, Y.long()

In [462]:
class BengioMLP(nn.Module):
    def __init__(self, d_model, d_hidden, n_context, vocabulaire):
        super().__init__()

        self.vocabulaire = vocabulaire
        self.n_context = n_context

        self.embed = nn.Embedding(len(vocabulaire), d_model)

        self.fc1 = nn.Linear(n_context * d_model, d_hidden)
        self.fc2 = nn.Linear(d_hidden, len(vocabulaire))

    def forward(self, idx):
        embeddings = []
        for _ in range(self.n_context):
            embd = self.embed(idx)
            idx = torch.roll(idx, 1, 1)
            idx[:, 0] = char_to_int['<SOS>']
            embeddings.append(embd)

        embeddings = torch.cat(embeddings, -1) # (B, L, n_context*d_model)

        x = F.tanh(self.fc1(embeddings)) # (B, L, d_hidden)
        logits = self.fc2(x) # (B, L, vocab_size)

        return logits
    
    def sample(self, prompt = "", g = torch.Generator()):
        idx = torch.tensor([char_to_int[c] for c in prompt], dtype=torch.int32).unsqueeze(0)
        idx = torch.cat([torch.tensor(char_to_int['<SOS>']).view(1, 1), idx], dim=1)
        next_id = -1

        while next_id != char_to_int['<EOS>']:
            idx_cond = idx if idx.size(1) <= self.n_context else idx[:, -self.n_context:]
            #print(idx_cond.shape)
            logits = self.forward(idx_cond) # (1, l, vocab_size)
            probs = F.softmax(logits[:, -1, :], dim=-1)
            next_id = torch.multinomial(probs, num_samples=1, generator=g).item()
            idx = torch.cat([idx, torch.tensor(next_id).view(1, 1)], dim=1)
            #print("idx.shape : ", idx.shape)
        
        return "".join([int_to_char[p.item()] for p in idx[0, :-1]])

In [477]:
model = BengioMLP(d_model=2, d_hidden=100, n_context=3, vocabulaire=vocabulaire)
optim = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [478]:
for i in range(10000):
    X, Y = get_batch('train', 16) # (B, L)
    logits = model(X) # (B, L, vocab_size)

    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1), ignore_index=char_to_int['<pad>'])
    optim.zero_grad()
    loss.backward()
    optim.step()
    
    if i%1000==0:
        print(loss)


tensor(3.8545, grad_fn=<NllLossBackward0>)
tensor(2.4990, grad_fn=<NllLossBackward0>)
tensor(2.2816, grad_fn=<NllLossBackward0>)
tensor(2.3737, grad_fn=<NllLossBackward0>)
tensor(2.2873, grad_fn=<NllLossBackward0>)
tensor(2.1053, grad_fn=<NllLossBackward0>)
tensor(2.0822, grad_fn=<NllLossBackward0>)
tensor(2.1053, grad_fn=<NllLossBackward0>)
tensor(2.2142, grad_fn=<NllLossBackward0>)
tensor(2.0799, grad_fn=<NllLossBackward0>)


In [479]:
X, Y = get_batch('val', 16) # (B, L)
logits = model(X) # (B, L, vocab_size)

loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1), ignore_index=char_to_int['<pad>'])
loss

tensor(1.9785, grad_fn=<NllLossBackward0>)

In [480]:
model.sample("saint")

'<SOS>saint-re-de-ourt'