In [8]:
import torch
import torch.nn.functional as F

from models.mamba.mamba import MambaConfig
from models.lm import LM

from data import Dataset

In [9]:
device = "cpu" # mettre "cuda" si gpu

d_model = 512 # dimension du modèle
n_layers = 1 # nombre de couches
dropout = 0.

lr = 3e-4
batch_size = 64 # mamba.py est gourmand en RAM, ne pas hésiter à baisser

In [10]:
dataset = Dataset(device=device) # toute la partie données de 2_mlp.py a été encapsulée dans l'objet Dataset

config = MambaConfig(d_model=d_model, n_layers=n_layers)
model = LM(config, vocab_size=len(dataset.vocabulaire)).to(device)
optim = torch.optim.AdamW(model.parameters(), lr=lr)

print(f"Nombre de paramètres : {sum(p.numel() for p in model.parameters())}")

Nombre de paramètres : 1719296


In [4]:
for i in range(10000):
    X, Y = dataset.get_batch('train', batch_size) # (B, L)
    logits = model(X) # (B, L, vocab_size)

    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1), ignore_index=dataset.char_to_int['<pad>'])
    optim.zero_grad()
    loss.backward()
    optim.step()
    
    if i%1000==0:
        X, Y = dataset.get_batch('test', batch_size) # (B, L)
        logits = model(X) # (B, L, vocab_size)
        val_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1), ignore_index=dataset.char_to_int['<pad>']).item()

        print(f"train loss : {loss.item():.2f} | val loss : {val_loss:.2f}")

train loss : 3.84 | val loss : 3.83
train loss : 2.07 | val loss : 2.09
train loss : 1.95 | val loss : 1.92
train loss : 1.93 | val loss : 2.02
train loss : 1.87 | val loss : 1.83
train loss : 1.86 | val loss : 1.94
train loss : 1.87 | val loss : 1.85
train loss : 1.89 | val loss : 1.93
train loss : 1.76 | val loss : 1.92
train loss : 1.83 | val loss : 1.83


In [14]:
X = dataset.X_train[:, :-1].to(device)[:200] # (B, L=max_len-1=46)
Y = dataset.X_train[:, 1:].long().to(device)[:200] # (B, L)

logits = model(X) # (B, L, vocab_size)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1), ignore_index=dataset.char_to_int['<pad>'])
print(f"total train loss : {loss.item():.2f}")

X = dataset.X_val[:, :-1].to(device) # (B, L=max_len-1=46)
Y = dataset.X_val[:, 1:].long().to(device) # (B, L)

logits = model(X) # (B, L, vocab_size)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1), ignore_index=dataset.char_to_int['<pad>'])
print(f"total val loss   : {loss.item():.2f}")

total train loss : 1.79
total val loss   : 1.85


In [15]:
torch.save(model.state_dict(), f"mamba_d_model{d_model}_n_layers{n_layers}.pth")

## génération

In [4]:
model.load_state_dict(torch.load('mamba_d_model512.pth', map_location='cpu'))

<All keys matched successfully>

In [5]:
def sample(self, prompt = "", g = torch.Generator(device)):
    idx = torch.tensor([dataset.char_to_int[c] for c in prompt], dtype=torch.int32, device=device).unsqueeze(0)
    idx = torch.cat([torch.tensor(dataset.char_to_int['<SOS>'], device=device).view(1, 1), idx], dim=1)
    next_id = -1

    while next_id != dataset.char_to_int['<EOS>']:
        logits = self(idx) # (1, l, d_model)

        probs = F.softmax(logits[:, -1, :], dim=-1)
        next_id = torch.multinomial(probs, num_samples=1, generator=g).item()
        idx = torch.cat([idx, torch.tensor(next_id, device=device).view(1, 1)], dim=1)
        
    return "".join([dataset.int_to_char[p.item()] for p in idx[0, 1:-1]])

In [7]:
g = torch.Generator(device).manual_seed(123456789+4)

for _ in range(50):
    #print(sample(model, g=g))

    ville = sample(model, g=g)
    if ville in dataset.villes:
        print(ville)

quessigny
saint-amant
maxou
triqueville
chaux
ognéville
blécourt
le charme
langogne
brion
