# This notebook implements a character-level LSTM for text generation

In [None]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn as nn
import re
import numpy as np
import pickle
import random
from torch.utils.data import random_split

In [None]:
# Read the pickled encoding map from the Kaggle input directory.
with open("/kaggle/input/rnn-input/encoding_map.pkl", "rb") as f:
    mapping = pickle.load(f)

# Register the padding entry; its id is the size of the map before insertion.
mapping.update({"PAD": len(mapping)})

# Data preparation

In [None]:
# Decode: inverse lookup table (value of `mapping` -> key of `mapping`).
int2char = dict(zip(mapping.values(), mapping.keys()))
print(int2char)

# Number of distinct ids in the inverse table.
nb_char = len(int2char)

## Creation of the dataset

In [None]:
class SongDataset(Dataset):
    """Fixed-length next-token windows cut out of a collection of encoded texts.

    Every sample is a pair ``(x, y)`` of 1-D ``torch.long`` CPU tensors of
    length ``length_seq``; ``y`` is the slice of the text that starts one
    position after ``x``. Slices that run past the end of a text are
    right-padded with ``pad_id``.

    Args:
        texts: Iterable of token-id sequences (tensors, or anything
            ``torch.tensor`` accepts). Texts with fewer than 2 tokens are
            skipped.
        length_seq: Number of tokens per window.
        stride: Distance between the starts of two consecutive windows.
        pad_id: Id used for right-padding. ``None`` (default) resolves to the
            module-level ``mapping["PAD"]`` when the dataset is built, so the
            class no longer needs ``mapping`` to exist at definition time.
        use_offset: If true, the first window of each text starts at a random
            position in ``[0, stride)``, clamped so that every text with at
            least 2 tokens still produces one window.

    Raises:
        ValueError: If ``length_seq`` or ``stride`` is smaller than 1.

    NOTE(review): tokens located after the last stride-aligned window are
    still never used; add an end-aligned window if the tail of each text
    matters for training.
    """

    def __init__(self, texts, length_seq, stride, pad_id=None, use_offset=True):
        if length_seq < 1 or stride < 1:
            raise ValueError("length_seq and stride must both be >= 1")
        if pad_id is None:
            pad_id = mapping["PAD"]

        self.samples = []
        self.length_seq = length_seq
        self.stride = stride
        self.pad_id = pad_id

        for text in texts:
            # Safety: always work on a private 1-D long tensor living on CPU.
            if isinstance(text, torch.Tensor):
                text = text.detach().clone().to(dtype=torch.long, device="cpu").contiguous()
            else:
                text = torch.tensor(text, dtype=torch.long)

            n_tokens = len(text)
            if n_tokens < 2:
                continue

            # Last start for which both x and y are complete windows
            # (0 when the text is too short and padding is required).
            last_start = max(0, n_tokens - length_seq - 1)

            offset = torch.randint(0, stride, (1,)).item() if use_offset else 0
            # Clamp so that a large random offset cannot discard the text.
            offset = min(offset, last_start)

            # `last_start` itself is a valid start, hence the inclusive bound.
            for start in range(offset, last_start + 1, stride):
                x = self._pad(text[start:start + length_seq])
                y = self._pad(text[start + 1:start + 1 + length_seq])
                self.samples.append((x, y))

    def _pad(self, seq):
        """Right-pad ``seq`` with ``self.pad_id`` up to ``self.length_seq``."""
        missing = self.length_seq - len(seq)
        if missing > 0:
            filler = torch.full((missing,), self.pad_id, dtype=seq.dtype)
            seq = torch.cat([seq, filler])
        return seq

    def __len__(self):
        """Return the number of stored ``(x, y)`` windows."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the ``(x, y)`` pair stored at position ``idx``."""
        return self.samples[idx]

In [None]:
# Memory-map the encoded corpus (second positional argument is mmap_mode="r").
dataset_ = np.load("/kaggle/input/rnn-input/corpora_encoded.npy","r")

# Cut the flat token stream into one tensor per segment.
# Token 64 opens a segment and token 70 closes it (presumably the
# start-of-song / end-of-song ids of the encoding map -- TODO confirm).
result = []
current = None  # tokens of the segment being built; None while outside one
for t in dataset_:
    if t == 64:
        # A new opening marker always starts a fresh segment.
        current = [t]
    elif current is None:
        # Token seen before any opening marker (or after a closing one):
        # it belongs to no segment, so it is ignored instead of raising
        # NameError / being appended to an already-saved list.
        continue
    else:
        current.append(t)
        if t == 70:
            result.append(torch.tensor(current))
            # Reset so the closed segment is not appended a second time below.
            current = None
# Keep a trailing segment that was opened but never closed.
if current:
    result.append(torch.tensor(current))

In [None]:
# 80 % of the segments go to training, the remainder to testing.
len_train = int(0.8 * len(result))
len_test = len(result) - len_train

# Seeded generator so the random split is reproducible.
generator = torch.Generator()
generator.manual_seed(42)
train, test = random_split(
    result,
    [len_train, len_test],
    generator=generator,
)

In [None]:
# Training windows use a random start offset per text; test windows do not.
train_ds = SongDataset(
    train,
    length_seq=256,
    stride=64,
    use_offset=True,
)
test_ds = SongDataset(
    test,
    length_seq=256,
    stride=64,
    use_offset=False,
)

In [None]:
# Show the decoded input of the first two training windows, separated by a
# blank line, as a quick sanity check.
for sample_idx in (0, 1):
    if sample_idx:
        print()
    print("".join([int2char[i] for i in train_ds[sample_idx][0][:256].numpy()]))

## Creation of the DataLoaders

In [None]:
batch_size = 1024

# Training loader: drop_last=True keeps every batch at exactly `batch_size`
# rows. shuffle=False was chosen by the author so that consecutive windows
# stay in order ("the RNN needs previous sequences to predict the next one").
# NOTE(review): confirm the hidden state is actually carried across batches;
# if it is re-initialised per batch, shuffling may be preferable.
train_dl = DataLoader(train_ds, batch_size=batch_size, pin_memory=True, pin_memory_device="cuda:0", shuffle=False, drop_last=True)

# Validation loader: keep the last partial batch. With drop_last=True the
# metrics silently ignored up to batch_size - 1 windows, and the loader was
# empty whenever the test set held fewer than `batch_size` windows.
test_dl = DataLoader(test_ds, batch_size=batch_size, pin_memory=True, pin_memory_device="cuda:0", shuffle=False, drop_last=False)

## Models

### Training part

In [None]:
class CharLSTM(nn.Module):
    """Character-level language model: embedding -> LSTM -> LayerNorm -> linear.

    Args:
        vocab_size: Number of distinct token ids (size of the output layer).
        embedding_dim: Size of the token embeddings fed to the LSTM.
        hidden_size: Size of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_size, num_layers=1):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            # Inter-layer dropout only exists between stacked layers; passing
            # a non-zero value with a single layer makes PyTorch emit a
            # warning and has no effect.
            dropout=0.15 if num_layers > 1 else 0.0,
        )
        self.drop = nn.Dropout(p=0.15)
        self.ln = nn.LayerNorm(hidden_size)

        # Output projection from the hidden state to vocabulary logits.
        self.fc = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, x, hidden=None):
        """Run the model on a batch of token ids.

        Args:
            x: Integer tensor of token ids, shape ``(batch, seq)``.
            hidden: ``(h0, c0)`` tuple as returned by :meth:`init_hidden`, or
                ``None`` to let the LSTM start from zeros.

        Returns:
            ``(logits, hidden)`` where ``logits`` has shape
            ``(batch, seq, vocab_size)`` and ``hidden`` is the final
            ``(h_n, c_n)`` state of the LSTM.
        """
        x = self.drop(self.embedding(x))       # (batch, seq, embedding_dim)
        out, hidden = self.lstm(x, hidden)     # (batch, seq, hidden_size)
        out = self.drop(self.ln(out))
        logits = self.fc(out)                  # (batch, seq, vocab_size)
        return logits, hidden

    def init_hidden(self, batch_size, device):
        """Return a zero ``(h0, c0)`` state for ``batch_size`` sequences on ``device``."""
        h0 = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device)
        c0 = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device)
        return (h0, c0)

In [None]:
# First CUDA device; every model and batch of this notebook is moved onto it.
device1 = torch.device("cuda", 0)

In [None]:
# ---- Hyper-parameters ----
embedding_dim = 128
hidden_size = 512
vocab_size = len(mapping)  # includes the "PAD" entry added to `mapping`
num_epoch = 100

# Number of batches per epoch; used to express scheduler lengths in epochs.
nb_step_train = len(train_dl)
nb_step_test = len(test_dl)

model = CharLSTM(vocab_size, embedding_dim, hidden_size, num_layers=3).to(device1)
model = torch.compile(model)

# Disabled experiment: per-class loss weights.
#weights = torch.ones(vocab_size).to(device1)

#Structure marker (really important)
#for p in [" ", "\n"]:
#    idx = mapping[p]
#    weights[idx] = 1.25  

#Punctuation (important) + part indic
#for p in [",", ".", "'"]:
#    idx = mapping[p]
#    weights[idx] = 1.1  

# Padding positions must not contribute to the loss. The id is read from the
# mapping instead of being hard-coded (it was the literal 72), so the losses
# stay correct if the vocabulary size changes.
loss_fn = nn.CrossEntropyLoss(ignore_index=mapping["PAD"])
val_fn = nn.CrossEntropyLoss(ignore_index=mapping["PAD"])

opti = torch.optim.AdamW(model.parameters(), lr=0.003, weight_decay=1e-4)
# Linear warm-up from 0.2 * lr to lr over the first 5 epochs' worth of steps.
sched_warm = torch.optim.lr_scheduler.LinearLR(opti, start_factor=0.2, end_factor=1.0, total_iters=nb_step_train*5)
# Cosine annealing with warm restarts, stepped once per batch: the first
# cycle lasts 10 epochs, and each following cycle is twice as long
# (10 => 20 => 40 ...). The learning rate never goes below eta_min.
sched_post = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opti, T_0=nb_step_train*10, T_mult=2, eta_min=0.001)

In [None]:
import torch

class SAM(torch.optim.Optimizer):
    """Sharpness-Aware Minimization wrapper around a base optimizer.

    Usage is a two-pass update: compute gradients, call :meth:`first_step`
    (moves the weights to the perturbed point ``w + e(w)``), compute gradients
    again at that point, then call :meth:`second_step` (restores ``w`` and
    lets the base optimizer apply the update).

    Args:
        params: Parameters (or parameter groups) to optimise.
        base_optimizer: Optimizer *class* (e.g. ``torch.optim.AdamW``); it is
            instantiated on this optimizer's parameter groups.
        rho: Size of the perturbation; must be non-negative.
        adaptive: If true, scale the perturbation element-wise by the weights.
        **kwargs: Forwarded to ``base_optimizer`` (``lr``, ``weight_decay`` ...).
    """

    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        assert rho >= 0.0, "Invalid rho"
        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super().__init__(params, defaults)
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        # Share the very same group dicts so that lr schedulers acting on one
        # optimizer are seen by the other.
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Perturb every parameter that has a gradient by ``e(w)`` and remember ``e(w)``."""
        grad_norm = self._grad_norm()
        scale = [g['rho'] / (grad_norm + 1e-12) for g in self.param_groups]
        for group, s in zip(self.param_groups, scale):
            for p in group['params']:
                if p.grad is None: continue
                e_w = (torch.pow(p, 2) if group['adaptive'] else 1.0) * p.grad * s.to(p)
                p.add_(e_w)
                self.state[p]['e_w'] = e_w
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Undo the perturbation, then run one step of the base optimizer."""
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None: continue
                e_w = self.state[p].get('e_w')
                # A parameter that only received a gradient during the second
                # pass was never perturbed: there is nothing to undo.
                if e_w is not None:
                    p.sub_(e_w)
        self.base_optimizer.step()
        if zero_grad:
            self.zero_grad()

    def _grad_norm(self):
        """Return the global L2 norm of all (optionally weight-scaled) gradients.

        Returns a zero scalar when no parameter has a gradient, instead of
        failing on an empty ``torch.stack``.
        """
        # All per-parameter norms are gathered on one device so that models
        # spread over several devices can be stacked.
        shared_device = self.param_groups[0]['params'][0].device
        norms = []
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    scale = torch.abs(p) if group['adaptive'] else 1.0
                    norms.append((scale * p.grad).norm(p=2).to(shared_device))
        if not norms:
            return torch.zeros((), device=shared_device)
        return torch.norm(torch.stack(norms), p=2)


In [None]:
import math

@torch.no_grad()
def evaluate_tf1(model, dl, loss_fn, device, vocab_size):
    """Validation with teacher forcing = 1 (parallel, fast).

    Args:
        model: Module exposing ``init_hidden(batch_size, device)`` and
            ``forward(x, hidden) -> (logits, hidden)``.
        dl: Iterable of ``(X, Y)`` batches of token ids, shape ``(bs, sl)``.
        loss_fn: Cross-entropy style loss. Its ``ignore_index`` attribute (if
            any) marks the targets excluded from every statistic.
        device: Device (or device string) on which the evaluation runs.
        vocab_size: Size of the last dimension of the logits.

    Returns:
        ``(ppl, acc)``: perplexity and next-token accuracy, both computed
        over the non-ignored target tokens only.
    """
    model.eval()
    dev = torch.device(device)
    ignore_index = getattr(loss_fn, "ignore_index", None)
    # With reduction="sum" the loss is already a total, not a per-token mean.
    loss_is_sum = getattr(loss_fn, "reduction", "mean") == "sum"

    total_loss = 0.0
    total_tokens = 0
    correct = 0

    for X, Y in dl:
        X = X.to(dev)
        Y = Y.to(dev, dtype=torch.long)
        # Use the `device` argument (the original read a global instead).
        hid = model.init_hidden(X.size(0), dev)

        # Mixed precision only on CUDA, as in the original setup.
        with torch.amp.autocast(device_type=dev.type, enabled=dev.type == "cuda"):
            pred, _ = model(X, hid)  # (bs, sl, vocab)
            loss = loss_fn(pred.reshape(-1, vocab_size), Y.reshape(-1))

        if ignore_index is None:
            valid = torch.ones_like(Y, dtype=torch.bool)
        else:
            valid = Y != ignore_index
        n_valid = int(valid.sum().item())
        if n_valid == 0:
            # Nothing to score; a mean-reduced loss would be NaN here.
            continue

        # The mean loss is averaged over the valid tokens, so it must be
        # weighted by that count (not by bs * sl) before aggregation.
        total_loss += loss.item() if loss_is_sum else loss.item() * n_valid
        total_tokens += n_valid

        pred_ids = pred.argmax(dim=-1)
        correct += ((pred_ids == Y) & valid).sum().item()

    ppl = math.exp(total_loss / max(1, total_tokens))
    acc = correct / max(1, total_tokens)
    return ppl, acc


@torch.no_grad()
def evaluate_free(model, dl, loss_fn, device):
    """Autoregressive validation (teacher forcing = 0).

    Starts from the first input token of every sequence, then steps one token
    at a time, feeding the greedy prediction back in. At each step the logits
    are scored against the reference target ``Y[:, t]``.

    Args:
        model: Module exposing ``init_hidden(batch_size, device)`` and
            ``forward(x, hidden) -> (logits, hidden)``.
        dl: Iterable of ``(X, Y)`` batches of token ids, shape ``(bs, sl)``.
        loss_fn: Cross-entropy style loss. Its ``ignore_index`` attribute (if
            any) marks the targets excluded from the perplexity.
        device: Device (or device string) on which the evaluation runs.

    Returns:
        Perplexity computed from the next-token NLL of the non-ignored targets.
    """
    model.eval()
    dev = torch.device(device)
    ignore_index = getattr(loss_fn, "ignore_index", None)
    # With reduction="sum" the loss is already a total, not a per-token mean.
    loss_is_sum = getattr(loss_fn, "reduction", "mean") == "sum"
    use_amp = dev.type == "cuda"  # mixed precision only on CUDA

    total_loss = 0.0
    total_tokens = 0

    for X, Y in dl:
        X = X.to(dev)
        Y = Y.to(dev, dtype=torch.long)
        bs, sl = X.size(0), X.size(1)
        # Use the `device` argument (the original read a global instead).
        hid = model.init_hidden(bs, dev)

        # Start with the first input token
        inp = X[:, :1]  # (bs, 1)
        for t in range(sl):
            target = Y[:, t]
            if ignore_index is None:
                n_valid = bs
            else:
                n_valid = int((target != ignore_index).sum().item())

            with torch.amp.autocast(device_type=dev.type, enabled=use_amp):
                pred, hid = model(inp, hid)          # (bs, 1, vocab)
                logits = pred[:, -1, :]              # (bs, vocab)
                # A step whose targets are all ignored has no defined mean
                # loss (it would be NaN and poison the total), so skip it.
                if n_valid > 0:
                    loss = loss_fn(logits, target)   # CE over current step

            if n_valid > 0:
                total_loss += loss.item() if loss_is_sum else loss.item() * n_valid
                total_tokens += n_valid

            # Greedy next-token to feed back in
            inp = logits.argmax(dim=-1).unsqueeze(1)  # (bs, 1)

    ppl = math.exp(total_loss / max(1, total_tokens))
    return ppl

def sample_with_temp(logits, temp=1.0):
    """Sample one token id per row from temperature-scaled logits.

    Args:
        logits: Tensor of unnormalised scores; the last dimension indexes the
            vocabulary.
        temp: Softmax temperature. ``0`` selects the most likely token
            (greedy decoding) instead of dividing by zero.

    Returns:
        Tensor of sampled ids whose last dimension has size 1.
    """
    if temp == 0:
        # Limit of the softmax as temp -> 0: all the mass on the arg-max.
        return logits.argmax(dim=-1, keepdim=True)
    probs = (logits / temp).softmax(dim=-1)
    return torch.multinomial(probs, num_samples=1)

def distinct_n_chars(text, n=3):
    """Return the share of distinct character n-grams among all n-grams of ``text``.

    The result is ``0`` when ``text`` is too short to contain any n-gram.
    """
    unique_grams = set()
    gram_count = 0
    for begin in range(len(text) - n + 1):
        unique_grams.add(text[begin:begin + n])
        gram_count += 1
    return len(unique_grams) / max(1, gram_count)

In [None]:
#base_optimizer = torch.optim.AdamW  
#opti = SAM(model.parameters(), base_optimizer, lr=0.0005, rho=0.05)

In [None]:
# Main training script: for every epoch, train on `train_dl`, validate on
# `test_dl` (teacher-forced and free-running), print a generated sample every
# 5 epochs and checkpoint the model whenever the teacher-forced validation
# perplexity improves.
l_tot = []                  # validation accuracy recorded at each new best
teacher_forcing_ratio = 1   # only used by the disabled decay code below
best_val = float("inf")

# Loss scaler for mixed-precision training.
scaler = torch.amp.GradScaler()

for epoch in range(num_epoch):

    # --- Teacher Forcing ratio decay (linear) ---
#    if epoch % 2 == 0 :
#        teacher_forcing_ratio = max(0.0, min(1.0, teacher_forcing_ratio - 0.02))
#        print(f"\nEpoch {epoch} | Teacher forcing ratio = {teacher_forcing_ratio:.2f}")

    model.train()

    # -------------- TRAIN LOOP --------------
    train_loss = 0.0

    for X, Y in iter(train_dl):
        # The hidden state is reset for every batch; this relies on every
        # training batch holding exactly `batch_size` rows.
        hid = model.init_hidden(batch_size, device1)
        X = X.to(device1)
        Y = Y.to(device1, dtype=torch.long)
        opti.zero_grad(set_to_none=True)

        if opti.__class__.__name__ == "SAM" :
            # NOTE(review): this branch clips gradients that are still scaled
            # by the GradScaler, and the second pass steps the base optimizer
            # on scaled gradients (no unscale_ / scaler.step). It is left
            # untouched because it is unused while `opti` is AdamW; revisit
            # before enabling SAM together with AMP.
            def closure():
                opti.zero_grad(set_to_none=True)
                with torch.amp.autocast(device_type="cuda"):
                    pred, hid_ = model(X, hid)
                    loss = loss_fn(pred.view(-1, vocab_size), Y.view(-1))
                scaler.scale(loss).backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                return loss

            # ----- SAM two-step -----
            # 1) First (ascent) step - move slightly uphill
            loss = closure()
            scaler.unscale_(opti.base_optimizer)  # unscale before the manual step
            opti.first_step(zero_grad=True)

            # 2) Second (descent) step - compute new grad at perturbed point
            closure()
            opti.second_step(zero_grad=True)
            scaler.update()

        else : 
            with torch.amp.autocast(device_type="cuda"):
                pred, hid = model(X, hid)
                loss = loss_fn(pred.view(-1, vocab_size), Y.view(-1))

            scaler.scale(loss).backward()
            # Gradients must be unscaled before clipping; otherwise the 0.5
            # threshold is compared against gradients multiplied by the
            # current loss scale. scaler.step() detects the earlier unscale_
            # and does not unscale a second time.
            scaler.unscale_(opti)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            scaler.step(opti)
            scaler.update()

            if sched_post.T_cur == 0 and epoch > 9:  #After warm restart decrease the max learning rate
                sched_post.base_lrs[0] = sched_post.base_lrs[0] * 0.85
                print(f"Decrease {sched_post.base_lrs[0]}, {sched_post.eta_min}")

            # Warm-up scheduler for the first 8 epochs, cosine restarts after.
            if epoch < 8:
                sched_warm.step()
            else:
                sched_post.step()
        train_loss += loss.detach().item() 

    # Perplexity of the mean per-batch training loss.
    train_ppl = math.exp(train_loss / nb_step_train)

    # -------------- VALIDATION --------------
    val_ppl_tf1, val_acc = evaluate_tf1(model, test_dl, val_fn, device1, vocab_size)
    val_ppl_free = evaluate_free(model, test_dl, val_fn, device1)

    print(
        f"Epoch {epoch} | "
        f"Train PPL: {train_ppl:.3f} | "
        f"Val PPL (TF=1): {val_ppl_tf1:.3f} | "
        f"Val PPL (free): {val_ppl_free:.3f} | "
        f"Val Acc: {val_acc:.3f}"
    )

    # --------- Sample generation + diversity metrics ---------
    if epoch % 5 == 0 :
        model.eval()
        with torch.no_grad():
            # Seed: the first 20 tokens of the first row of the last training batch.
            start = X[0:1, :20]  
            hid_gen = model.init_hidden(1,device1)
            inp = start

            gen_chars = []
            for t in range(200):  # generate 200 characters
                pred, hid_gen = model(inp, hid_gen)
                logits = pred[:, -1, :]  # last time step
                next_char = sample_with_temp(logits, temp=0.6)
                gen_chars.append(int2char[next_char.item()])
                inp = next_char # feed back

            gen_text = "".join(gen_chars)

        # Share of distinct character bigrams / trigrams in the sample.
        d2 = distinct_n_chars(gen_text, n=2)
        d3 = distinct_n_chars(gen_text, n=3)

    
        print("\n=== Initial text ===")
        print("".join([int2char[i] for i in X[0:1,:].squeeze(0).tolist()]))
        print("\n=== Sample Generation ===")
        print(gen_text[:200])  # show the first 200 characters
        print(f"Distinct-2: {d2:.3f} | Distinct-3: {d3:.3f}", end="\n")

    # Checkpoint on every improvement of the teacher-forced validation perplexity.
    if val_ppl_tf1 < best_val :
        best_val = val_ppl_tf1
        l_tot.append(val_acc)
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": opti.state_dict(),
                "scheduler_state_dict": sched_post.state_dict(),
                "val_ppl": val_ppl_tf1,
            },
            "model",
        )

In [None]:
RNN : Val PPL (TF=1): 3.6423183370497765