# This notebook will serve as a way to implement character generation MHA

In [1]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn as nn
import re
import numpy as np
import pickle
import random
from torch.utils.data import random_split
import torch.nn.functional as F
from torchtune.modules import RotaryPositionalEmbeddings

In [2]:
with open("/kaggle/input/rnn-input/encoding_map.pkl", "rb") as f:
    mapping = pickle.load(f)

mapping["PAD"] = len(mapping)

# Data preparation

In [3]:
# Decode
int2char = {i: ch for ch, i in mapping.items()}
print(int2char)

nb_char = len(int2char)

{0: '\n', 1: ' ', 2: '!', 3: '$', 4: '%', 5: '&', 6: "'", 7: ')', 8: '+', 9: ',', 10: '-', 11: '.', 12: '/', 13: '0', 14: '1', 15: '2', 16: '3', 17: '4', 18: '5', 19: '6', 20: '7', 21: '8', 22: '9', 23: ':', 24: ';', 25: '?', 26: 'a', 27: 'b', 28: 'c', 29: 'd', 30: 'e', 31: 'f', 32: 'g', 33: 'h', 34: 'i', 35: 'j', 36: 'k', 37: 'l', 38: 'm', 39: 'n', 40: 'o', 41: 'p', 42: 'q', 43: 'r', 44: 's', 45: 't', 46: 'u', 47: 'v', 48: 'w', 49: 'x', 50: 'y', 51: 'z', 52: 'à', 53: 'â', 54: 'ç', 55: 'è', 56: 'é', 57: 'ê', 58: 'ë', 59: 'î', 60: 'ï', 61: 'ô', 62: 'ù', 63: 'û', 64: 'α', 65: 'β', 66: 'γ', 67: 'ε', 68: 'ζ', 69: 'η', 70: 'θ', 71: '€', 72: 'PAD'}


## Creation of the dataset

To increase the randomness during the training :

For each epoch the entire corpus will have a random specific offset value in order that the model during training doesn't see the exact same text during X epochs.

In [4]:
class SongDataset(Dataset):
    def __init__(self, texts, length_seq, stride, pad_id=mapping["PAD"], use_offset=True):
        self.samples = []
        self.length_seq = length_seq
        self.stride = stride
        self.pad_id = pad_id

        for text in texts:
            # --- Sécurité : s'assurer d'un tensor 1D long CPU ---
            if not isinstance(text, torch.Tensor):
                text = torch.tensor(text, dtype=torch.long)
            else:
                text = text.clone().detach().to(dtype=torch.long, device="cpu").contiguous()

            L = len(text)
            if L < 2:
                continue

            offset = torch.randint(0, stride, (1,)).item() if use_offset else 0

            # --- Boucle principale ---
            for start in range(offset, max(1, L - self.length_seq - 1), self.stride):
                x_start, x_end = start, start + self.length_seq
                y_start, y_end = start + 1, start + 1 + self.length_seq

                x = text[x_start:x_end]
                y = text[y_start:y_end]

                # --- Padding uniforme ---
                def pad_to_len(seq, pad_id, target_len):
                    pad_len = target_len - len(seq)
                    if pad_len > 0:
                        seq = torch.cat([seq, torch.full((pad_len,), pad_id, dtype=seq.dtype)])
                    return seq

                x = pad_to_len(x, self.pad_id, self.length_seq)
                y = pad_to_len(y, self.pad_id, self.length_seq)

                self.samples.append((x, y))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

In [5]:
dataset_ = np.load("/kaggle/input/rnn-input/corpora_encoded.npy","r")

result = []
for t in dataset_:
    if t == 64 : 
        current = []
        current.append(t)
    elif t == 70:
        current.append(t)
        result.append(torch.tensor(current))
    else :
        current.append(t)
if current:  
    result.append(torch.tensor(current))

In [6]:
len_train = int(len(result) * 0.8)
len_test = len(result) - len_train

generator = torch.Generator().manual_seed(42)
train, test = random_split(result, [len_train, len_test], generator = generator)

In [7]:
train_ds = SongDataset(train,length_seq=256, stride = 64, use_offset = True)
test_ds = SongDataset(test,length_seq=256, stride = 64, use_offset = False)

In [8]:
print("".join([int2char[i] for i in train_ds[0][0][:256].numpy()]))
print()
print("".join([int2char[i] for i in train_ds[1][0][:256].numpy()]))

ans tes yeux je peux le voir
tu penses sûrement qu'j'suis plus heureux que toi
rien ne dure tout est éphémère frérot
j'ai rien à cacher comme les femens
avant j'avais qu'des p'tites pièces comme le passeur d'âme
et là je brunch avec madame à amsterdam
fair

eureux que toi
rien ne dure tout est éphémère frérot
j'ai rien à cacher comme les femens
avant j'avais qu'des p'tites pièces comme le passeur d'âme
et là je brunch avec madame à amsterdam
faire plaisir à ceux qu'on aime ça n'a pas d'prix
pour tout le reste


## Creation of Dataloader and co

In [9]:
batch_size = 516

train_dl = DataLoader(train_ds, batch_size=batch_size, pin_memory=True, pin_memory_device="cuda:0", shuffle=False, drop_last=True) #Shuffle False because we need the RNN to use previous sequences data to predict next one
test_dl = DataLoader(test_ds, batch_size=batch_size, pin_memory=True, pin_memory_device="cuda:0", shuffle=False, drop_last=True)

In [10]:
len(dataset_),len(train_dl)

(3705039, 80)

## Models

### Training part

In [11]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, nheads, dropout, max_seq, bias=True):
        super().__init__()

        self.nheads = nheads
        assert embed_dim % nheads == 0, "Embedding dim is not divisible by nheads"
        self.head_dim = embed_dim // nheads
        self.dropout = dropout
        
        self.proj_qkv = nn.Linear(embed_dim, embed_dim * 3, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.rope = RotaryPositionalEmbeddings(dim=self.head_dim, max_seq_len = max_seq)

    def forward(self, x: torch.Tensor, attn_mask=None, is_causal=False) -> torch.Tensor:
        
        # Step 1
        B,L,_ = x.shape
        result = self.proj_qkv(x)
        q, k, v = torch.chunk(result, 3, dim=-1)

        # Step 2
        # (N, L_t, head_dim) -> (N, L_t, nheads, head_dim) -> (N, nheads, L_t, head_dim)
        q = q.view(B, L, self.nheads, self.head_dim)
        k = k.view(B, L, self.nheads, self.head_dim)
        v = v.view(B, L, self.nheads, self.head_dim)

#        q = self.rope(q) #Doesn't need Rope / overkill because char didn't have the same relation
#        k = self.rope(k)

        #Adapt dim 
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        #Local attention window = 64
        mask = (torch.tril(torch.ones([L,L]), diagonal=-64)+torch.tril(torch.ones([L,L])))==1
        attn_mask = ~mask.to(device)

        # Step 3
        # (N, nheads, L_t, E_head)
        attn_output = F.scaled_dot_product_attention(
            q, k, v, dropout_p=self.dropout, attn_mask = attn_mask, is_causal=is_causal)
        # (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)
        attn_output = attn_output.transpose(1, 2).flatten(-2)

        return self.out_proj(attn_output)

In [12]:
class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, max_seq, d_ff=256*4, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, n_heads, dropout, max_seq)
        self.norm1 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # masked multi-head self-attention
        x = x + self.attn(self.norm1(x), is_causal=False) #Is causal = for token t, mask on every tokens after (cannot see what's coming after)
        # feed-forward
        x = x + self.ff(self.norm2(x))
        return x

class DecoderOnlyTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=256, n_heads=4, n_layers=4, max_len=256):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList([TransformerBlock(d_model, n_heads, max_seq=max_len) for _ in range(n_layers)])
        self.norm = nn.LayerNorm(d_model)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        B, T = x.shape
        x = self.emb(x) 
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        return self.fc(x)

In [13]:
device = torch.device("cuda:0")

In [14]:
vocab_size = len(mapping)
num_epoch = 100

nb_step_train = len(train_dl)
nb_step_test = len(test_dl)

model = DecoderOnlyTransformer(vocab_size, d_model=256, n_heads=4, n_layers=8).to(device) 
model = torch.compile(model)
 
loss_fn = nn.CrossEntropyLoss(ignore_index=72)
val_fn = nn.CrossEntropyLoss(ignore_index=72)

opti = torch.optim.AdamW(model.parameters(), lr=0.0025, weight_decay=1e-3)
sched_warm = torch.optim.lr_scheduler.LinearLR(opti, start_factor=0.2, end_factor=1, total_iters=nb_step_train*5)
#sched_post = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opti, T_0=nb_step_train*10, T_mult=2, eta_min=0.0005) #1 epoch => 2 => 4 => 8

In [15]:
import math

@torch.no_grad()
def evaluate_tf1(model, dl, loss_fn, device, vocab_size):
    """Validation with teacher forcing = 1 (parallel, fast). Returns (ppl, acc)."""
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    correct = 0
    total = 0

    for X, Y in dl:
        X = X.to(device)
        Y = Y.to(device, dtype=torch.long)
        bs, sl = X.size(0), X.size(1)

        with torch.amp.autocast(device_type="cuda"):
            pred = model(X)  # (bs, sl, vocab)
            loss = loss_fn(pred.view(-1, vocab_size), Y.view(-1))

        total_loss += loss.item() * bs * sl
        total_tokens += bs * sl

        pred_ids = pred.argmax(dim=-1)
        correct += (pred_ids == Y).sum().item()
        total += bs * sl

    ppl = math.exp(total_loss / max(1, total_tokens))
    acc = correct / max(1, total)
    return ppl, acc


@torch.no_grad()
def evaluate_free(model, dl, loss_fn, device):
    """
    Autoregressive validation (teacher forcing = 0).
    Steps one token at a time and feeds predictions back in.
    Returns ppl (computed on next-token NLL).
    """
    model.eval()
    total_loss = 0.0
    total_tokens = 0

    for X, Y in dl:
        X = X.to(device)
        Y = Y.to(device, dtype=torch.long)
        bs, sl = X.size(0), X.size(1)

        # Start with the first input token
        inp = X[:, :1]  # (bs, 1)
        for t in range(sl):
            with torch.amp.autocast(device_type="cuda"):
                pred = model(inp)          # (bs, 1, vocab)
                logits = pred[:, -1, :]              # (bs, vocab)
                loss = loss_fn(logits, Y[:, t])      # CE over current step

            total_loss += loss.item() * bs
            total_tokens += bs

            # Greedy next-token to feed back in
            next_token = logits.argmax(dim=-1).unsqueeze(1)  # (bs, 1)
            inp = next_token

    ppl = math.exp(total_loss / max(1, total_tokens))
    return ppl

def sample_with_temp(logits, temp=1.0):
    probs = (logits / temp).softmax(dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)
    return next_token

def distinct_n_chars(text, n=3):
    ngrams = [text[i:i+n] for i in range(len(text)-n+1)]
    return len(set(ngrams)) / max(1, len(ngrams))

In [16]:
#opti = torch.optim.AdamW(model.parameters(), lr=0.003, weight_decay=1e-3)
#sched_warm = torch.optim.lr_scheduler.LinearLR(opti, start_factor=0.2, end_factor=1, total_iters=nb_step_train*5)
#sched_post = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opti, T_0=nb_step_train*10, T_mult=2, eta_min=0.0001) #1 epoch => 2 => 4 => 8

In [17]:
#opti = torch.optim.AdamW(model.parameters(), lr=0.0015, weight_decay=1e-3)

In [18]:
from time import time

In [None]:
l_tot = []
best_val = float("inf")

print(time())

scaler = torch.amp.GradScaler()
import math
for epoch in range(num_epoch) :
    
    #Offset the datas
#    train_ds.set_offset(offset); test_ds.set_offset(offset)

    model.train();

    #Create loss per epoch
    l_train = 0.0
    l_test = 0.0
    
    for X,Y in iter(train_dl) :
        X = X.to(device); Y= Y.to(device, dtype=torch.long)
        opti.zero_grad(set_to_none=True)
        
        #Computation of model        
        with torch.amp.autocast(device_type="cuda:0"):
            pred = model(X)
            loss = loss_fn(pred.view(-1, vocab_size), Y.view(-1))

        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        scaler.step(opti); scaler.update()
        l_train += loss.detach().item()

        #Scheduler part
        #Warm start
        
#        if sched_post.T_cur == 0 and epoch > 6:  #After warm restart decrease the max learning rate
#            sched_post.base_lrs[0] = sched_post.base_lrs[0] * 0.8
#            sched_post.eta_min = sched_post.eta_min * 1.5
#            print(f"Decrease {sched_post.base_lrs[0]}, {sched_post.eta_min}")

#        step_scheduler = sched_warm if epoch < 5 else sched_post
        if epoch < 5 :
            sched_warm.step()

    # -------------- VALIDATION --------------
    val_ppl_tf1, val_acc = evaluate_tf1(model, test_dl, val_fn, device, vocab_size)
#    val_ppl_free = evaluate_free(model, test_dl, val_fn, device)

    train_ppl = math.exp(l_train/nb_step_train)
    print(
        f"Epoch {epoch} | "
        f"Train PPL: {train_ppl:.3f} | "
        f"Val PPL (TF=1): {val_ppl_tf1:.3f} | "
#        f"Val PPL (free): {val_ppl_free:.3f} | "
        f"Val Acc: {val_acc:.3f}"
    )

    # --------- Sample generation + diversity metrics ---------
    if epoch % 4 == 0 :
        model.eval()
        with torch.no_grad():
        # prends les 5 premiers chars du batch courant comme seed
            start = X[0:1, :20]  
            inp = start

            gen_chars = []
            for t in range(200):  # génère 200 caractères
                pred = model(inp)
                logits = pred[:, -1, :]  # dernier pas
                next_char = sample_with_temp(logits, temp=0.4)
                gen_chars.append(int2char[next_char.item()])
                inp = next_char # feed back

            gen_text = "".join(gen_chars)

        d2 = distinct_n_chars(gen_text, n=2)
        d3 = distinct_n_chars(gen_text, n=3)
    
        print("\n=== Initial text ===")
        print("".join([int2char[i] for i in X[0:1,:].squeeze(0).tolist()]))
        print("\n=== Sample Generation ===")
        print(gen_text[:200])  # affiche les 200 premiers chars
        print(f"Distinct-2: {d2:.3f} | Distinct-3: {d3:.3f}", end="\n")

    # Record accuracy
    if val_ppl_tf1 < best_val :
        best_val = val_ppl_tf1
        l_tot.append(val_acc)
        torch.save(
                {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": opti.state_dict(),
#                "scheduler_state_dict": sched_post.state_dict(),
                "val_ppl": val_ppl_tf1,
            },
            "model_forced",
        )
print(time())

1760795172.5968385


W1018 13:46:24.569000 36 torch/_inductor/utils.py:1137] [0/0] Not enough SMs to use max_autotune_gemm mode


Epoch 0 | Train PPL: 13.501 | Val PPL (TF=1): 11.074 | Val Acc: 0.282

=== Initial text ===
n'fait pas l'man
faut parfois s'debrouiller seul
ya pas d'famille au parloir
pour s'installer au sommet
car l'paradis n'est pas louable
on connait les mauvaises passes
les frappes un deux trois sommeil
on essaye d'pas faire de mal au gens
d'pas mettre la s

=== Sample Generation ===
s d'aile de l's l'a mes de des pont j'd's r les le j'a j'pais s pon de ples l'ons l'eur t d'le d'e qurais de din quns faillan vis pas llentoure condis c'vinde prge d'a de mman de li la les de j'es lai
Distinct-2: 0.347 | Distinct-3: 0.662
Epoch 1 | Train PPL: 10.898 | Val PPL (TF=1): 10.685 | Val Acc: 0.292
Epoch 2 | Train PPL: 10.623 | Val PPL (TF=1): 10.498 | Val Acc: 0.294
Epoch 3 | Train PPL: 10.435 | Val PPL (TF=1): 10.364 | Val Acc: 0.296
Epoch 4 | Train PPL: 10.269 | Val PPL (TF=1): 10.220 | Val Acc: 0.296

=== Initial text ===
n'fait pas l'man
faut parfois s'debrouiller seul
ya pas d'famille au parloir
pour s

3.655