# This notebook will serve as a way to implement character generation RNN 

In [1]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn as nn
import re
import numpy as np
import pickle
import random
from torch.utils.data import random_split

In [2]:
with open("/kaggle/input/rnn-input/encoding_map.pkl", "rb") as f:
    mapping = pickle.load(f)

mapping["PAD"] = len(mapping)

# Data preparation

In [3]:
# Decode
int2char = {i: ch for ch, i in mapping.items()}
print(int2char)

nb_char = len(int2char)

{0: '\n', 1: ' ', 2: '!', 3: '$', 4: '%', 5: '&', 6: "'", 7: ')', 8: '+', 9: ',', 10: '-', 11: '.', 12: '/', 13: '0', 14: '1', 15: '2', 16: '3', 17: '4', 18: '5', 19: '6', 20: '7', 21: '8', 22: '9', 23: ':', 24: ';', 25: '?', 26: 'a', 27: 'b', 28: 'c', 29: 'd', 30: 'e', 31: 'f', 32: 'g', 33: 'h', 34: 'i', 35: 'j', 36: 'k', 37: 'l', 38: 'm', 39: 'n', 40: 'o', 41: 'p', 42: 'q', 43: 'r', 44: 's', 45: 't', 46: 'u', 47: 'v', 48: 'w', 49: 'x', 50: 'y', 51: 'z', 52: 'à', 53: 'â', 54: 'ç', 55: 'è', 56: 'é', 57: 'ê', 58: 'ë', 59: 'î', 60: 'ï', 61: 'ô', 62: 'ù', 63: 'û', 64: 'α', 65: 'β', 66: 'γ', 67: 'ε', 68: 'ζ', 69: 'η', 70: 'θ', 71: '€', 72: 'PAD'}


## Creation of the dataset

To increase the randomness during the training :

For each epoch the entire corpus will have a random specific offset value in order that the model during training doesn't see the exact same text during X epochs.

In [4]:
class SongDataset(Dataset):
    def __init__(self, texts, length_seq, stride, pad_id=mapping["PAD"], use_offset=True):
        self.samples = []
        self.length_seq = length_seq
        self.stride = stride
        self.pad_id = pad_id

        for text in texts:
            # --- Sécurité : s'assurer d'un tensor 1D long CPU ---
            if not isinstance(text, torch.Tensor):
                text = torch.tensor(text, dtype=torch.long)
            else:
                text = text.clone().detach().to(dtype=torch.long, device="cpu").contiguous()

            L = len(text)
            if L < 2:
                continue

            offset = torch.randint(0, stride, (1,)).item() if use_offset else 0

            # --- Boucle principale ---
            for start in range(offset, max(1, L - self.length_seq - 1), self.stride):
                x_start, x_end = start, start + self.length_seq
                y_start, y_end = start + 1, start + 1 + self.length_seq

                x = text[x_start:x_end]
                y = text[y_start:y_end]

                # --- Padding uniforme ---
                def pad_to_len(seq, pad_id, target_len):
                    pad_len = target_len - len(seq)
                    if pad_len > 0:
                        seq = torch.cat([seq, torch.full((pad_len,), pad_id, dtype=seq.dtype)])
                    return seq

                x = pad_to_len(x, self.pad_id, self.length_seq)
                y = pad_to_len(y, self.pad_id, self.length_seq)

                self.samples.append((x, y))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

In [5]:
dataset_ = np.load("/kaggle/input/rnn-input/corpora_encoded.npy","r")

result = []
for t in dataset_:
    if t == 64 : 
        current = []
        current.append(t)
    elif t == 70:
        current.append(t)
        result.append(torch.tensor(current))
    else :
        current.append(t)
if current:  
    result.append(torch.tensor(current))

In [6]:
len_train = int(len(result) * 0.8)
len_test = len(result) - len_train

generator = torch.Generator().manual_seed(42)
train, test = random_split(result, [len_train, len_test], generator = generator)

In [7]:
train_ds = SongDataset(train,length_seq=256, stride = 64, use_offset = True)
test_ds = SongDataset(test,length_seq=256, stride = 64, use_offset = False)

In [8]:
print("".join([int2char[i] for i in train_ds[0][0][:256].numpy()]))
print()
print("".join([int2char[i] for i in train_ds[1][0][:256].numpy()]))

ux le voir
tu penses sûrement qu'j'suis plus heureux que toi
rien ne dure tout est éphémère frérot
j'ai rien à cacher comme les femens
avant j'avais qu'des p'tites pièces comme le passeur d'âme
et là je brunch avec madame à amsterdam
faire plaisir à ceux q

n ne dure tout est éphémère frérot
j'ai rien à cacher comme les femens
avant j'avais qu'des p'tites pièces comme le passeur d'âme
et là je brunch avec madame à amsterdam
faire plaisir à ceux qu'on aime ça n'a pas d'prix
pour tout le reste y a ta mastercard


## Creation of Dataloader and co

In [9]:
batch_size = 1024

train_dl = DataLoader(train_ds, batch_size=batch_size, pin_memory=True, pin_memory_device="cuda:0", shuffle=False, drop_last=True) #Shuffle False because we need the RNN to use previous sequences data to predict next one
test_dl = DataLoader(test_ds, batch_size=batch_size, pin_memory=True, pin_memory_device="cuda:0", shuffle=False, drop_last=True)

### Check the good dataloading and offset validity

In [10]:
len(dataset_),len(train_dl)

(3705039, 40)

## Models

### Training part

In [11]:
class CharRNN(nn.Module):
    def __init__(self, vocab_size, emb_size, hidden_size, num_layers=1, dropout = 0):
        super(CharRNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        
        self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx=72)
        self.rnn = nn.RNN(emb_size, hidden_size, num_layers, batch_first=True, dropout = dropout, nonlinearity ="relu")
        self.drop = nn.Dropout(p=dropout)
        self.ln = nn.LayerNorm(hidden_size)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden):
        x = self.drop(self.embedding(x))
        out, hidden = self.rnn(x, hidden)
#        out = self.ln(out)
        out = self.drop(out)
        out = self.fc(out)                  
        return out, hidden

    def init_hidden(self, batch_size):
        return torch.zeros(self.num_layers, batch_size, self.hidden_size)

In [12]:
device1 = torch.device("cuda:0")

In [13]:
embedding_dim = 128
vocab_size = len(int2char)
hidden_size = 512
num_epoch = 150

nb_step_train = len(train_dl)
nb_step_test = len(test_dl)

model = CharRNN(vocab_size, embedding_dim, hidden_size, num_layers=4, dropout = 0.2).to(device1)
model = torch.compile(model)

#weights = torch.ones(vocab_size).to(device1)

#Structure marker (really important)
#for p in [" ", "\n"]:
#    idx = mapping[p]
#    weights[idx] = 1.5  

#Ponctuation (important) + part indic
#for p in [",", ".", "'", 'α','β','γ','ε','ζ','η','θ']:
#    idx = mapping[p]
#    weights[idx] = 1.25  

loss_fn = nn.CrossEntropyLoss(ignore_index=72)
val_fn = nn.CrossEntropyLoss(ignore_index=72)

opti = torch.optim.AdamW(model.parameters(), lr=0.002, weight_decay=1e-4)
sched_warm = torch.optim.lr_scheduler.LinearLR(opti,start_factor=0.2,end_factor=1.0,total_iters=nb_step_train * 5)
sched_post = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opti, T_0=nb_step_train*10, T_mult=2, eta_min=0.0002) 
#sched_plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(opti,mode="max",factor=0.6,patience=3,min_lr=1e-5,verbose=True)

In [14]:
import math

@torch.no_grad()
def evaluate_tf1(model, dl, loss_fn, device, vocab_size):
    """Validation with teacher forcing = 1 (parallel, fast). Returns (ppl, acc)."""
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    correct = 0
    total = 0

    for X, Y in dl:
        X = X.to(device)
        Y = Y.to(device, dtype=torch.long)
        bs, sl = X.size(0), X.size(1)
        hid = model.init_hidden(bs).to(device)

        with torch.amp.autocast(device_type="cuda"):
            pred, _ = model(X, hid)  # (bs, sl, vocab)
            loss = loss_fn(pred.view(-1, vocab_size), Y.view(-1))

        total_loss += loss.item() * bs * sl
        total_tokens += bs * sl

        pred_ids = pred.argmax(dim=-1)
        correct += (pred_ids == Y).sum().item()
        total += bs * sl

    ppl = math.exp(total_loss / max(1, total_tokens))
    acc = correct / max(1, total)
    return ppl, acc


@torch.no_grad()
def evaluate_free(model, dl, loss_fn, device):
    """
    Autoregressive validation (teacher forcing = 0).
    Steps one token at a time and feeds predictions back in.
    Returns ppl (computed on next-token NLL).
    """
    model.eval()
    total_loss = 0.0
    total_tokens = 0

    for X, Y in dl:
        X = X.to(device)
        Y = Y.to(device, dtype=torch.long)
        bs, sl = X.size(0), X.size(1)
        hid = model.init_hidden(bs).to(device)

        # Start with the first input token
        inp = X[:, :1]  # (bs, 1)
        for t in range(sl):
            with torch.amp.autocast(device_type="cuda"):
                pred, hid = model(inp, hid)          # (bs, 1, vocab)
                logits = pred[:, -1, :]              # (bs, vocab)
                loss = loss_fn(logits, Y[:, t])      # CE over current step

            total_loss += loss.item() * bs
            total_tokens += bs

            # Greedy next-token to feed back in
            next_token = logits.argmax(dim=-1).unsqueeze(1)  # (bs, 1)
            inp = next_token

    ppl = math.exp(total_loss / max(1, total_tokens))
    return ppl

def sample_with_temp(logits, temp=1.0):
    probs = (logits / temp).softmax(dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)
    return next_token

def distinct_n_chars(text, n=3):
    ngrams = [text[i:i+n] for i in range(len(text)-n+1)]
    return len(set(ngrams)) / max(1, len(ngrams))

In [None]:
list_offset = []
l_tot = []
teacher_forcing_ratio = 1
best_val = float("inf")

scaler = torch.amp.GradScaler()

for epoch in range(num_epoch):

    # --- Teacher Forcing ratio decay (linear) ---
#    if epoch % 10 == 0 :
#        teacher_forcing_ratio = max(0.0, min(1.0, teacher_forcing_ratio - 0.02))
#        print(f"\nEpoch {epoch} | Teacher forcing ratio = {teacher_forcing_ratio:.2f}")

    model.train()

    # -------------- TRAIN LOOP --------------
    train_loss = 0.0
    train_loss_sum = 0.0
    train_tokens = 0

    for X, Y in iter(train_dl):
        hid = model.init_hidden(batch_size).to(device1)
        X = X.to(device1)
        Y = Y.to(device1, dtype=torch.long)
        opti.zero_grad(set_to_none=True)

        if teacher_forcing_ratio == 1.0:
            with torch.amp.autocast(device_type="cuda"):
                pred, hid = model(X, hid)
                loss = loss_fn(pred.view(-1, vocab_size), Y.view(-1))
        else:
            # ---- Pass 1: forward with TF=1 (parallel) ----
            with torch.no_grad(), torch.amp.autocast(device_type="cuda"):
                pred_tf, _ = model(X, hid)     # (batch, seq_len, vocab)
            pred_tokens = pred_tf.argmax(dim=-1)  # (batch, seq_len)

            # ---- Random mask for partial TF ----
            mask = (torch.rand_like(X.float()) < teacher_forcing_ratio)
            X_mixed = torch.where(mask, X, pred_tokens)

            # ---- Pass 2: forward with partial TF ----
            with torch.amp.autocast(device_type="cuda"):
                pred, hid = model(X_mixed, hid)
                loss = loss_fn(pred.view(-1, vocab_size), Y.view(-1))

        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        scaler.step(opti)
        scaler.update()

        # book-keeping
        train_loss += loss.detach().item()
        train_ppl = math.exp(train_loss/nb_step_train)
        if epoch < 5:
            sched_warm.step()
        else:
            sched_post.step()

        if getattr(sched_post, "T_cur", None) == 0 and epoch > 5:
            sched_post.base_lrs[0] *= 0.85
            # sched_post.eta_min *= 1.25
            print(f"Decrease {sched_post.base_lrs[0]}, {sched_post.eta_min}")

    # -------------- VALIDATION --------------
    val_ppl_tf1, val_acc = evaluate_tf1(model, test_dl, val_fn, device1, vocab_size)
    val_ppl_free = evaluate_free(model, test_dl, val_fn, device1)

    print(
        f"Epoch {epoch} | "
        f"Train PPL: {train_ppl:.3f} | "
        f"Val PPL (TF=1): {val_ppl_tf1:.3f} | "
        f"Val PPL (free): {val_ppl_free:.3f} | "
        f"Val Acc: {val_acc:.3f}"
    )

    # --------- Sample generation + diversity metrics ---------
    if epoch % 10 == 0 :
        model.eval()
        with torch.no_grad():
        # prends les 5 premiers chars du batch courant comme seed
            start = X[0:1, :20]  
            hid_gen = model.init_hidden(1).to(device1)
            inp = start

            gen_chars = []
            for t in range(200):  # génère 200 caractères
                pred, hid_gen = model(inp, hid_gen)
                logits = pred[:, -1, :]  # dernier pas
                next_char = sample_with_temp(logits, temp=0.6)
                gen_chars.append(int2char[next_char.item()])
                inp = next_char # feed back

            gen_text = "".join(gen_chars)

        d2 = distinct_n_chars(gen_text, n=2)
        d3 = distinct_n_chars(gen_text, n=3)
    
        print("\n=== Initial text ===")
        print("".join([int2char[i] for i in X[0:1,:].squeeze(0).tolist()]))
        print("\n=== Sample Generation ===")
        print(gen_text[:200])  # affiche les 200 premiers chars
        print(f"Distinct-2: {d2:.3f} | Distinct-3: {d3:.3f}", end="\n")

    # Record accuracy
    l_tot.append(val_ppl_tf1)
    if val_ppl_tf1 < best_val :
        best_val = val_ppl_tf1
        torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": opti.state_dict(),
                    "scheduler_state_dict": sched_post.state_dict(),
                    "val_ppl": val_ppl_tf1,
                },
                "model_forced",
        )

Epoch 0 | Train PPL: 19.439 | Val PPL (TF=1): 10.996 | Val PPL (free): 70.999 | Val Acc: 0.298

=== Initial text ===
 non, oui, oui; oui et non
/ε
γ
le temps rend malheureux les gens comme rer, rer
combien perdent leur religion comme rem, rem?
vu qu'le monde parait irréel même irl, irl
j'essaye de voir en eux comme irm
je veux, veux, veux, veux mettre un trait d'union
fe

=== Sample Generation ===
mes pand mais qur dait fres ç'mass la on m'ante je cormis me an pantes enx j'aus
d'maraste t'an la m'ais c'aste mante à
p'ra d've fout jans bont do sout anton le pans d'te de mant nou ce dont d'an ja 
Distinct-2: 0.447 | Distinct-3: 0.747
Epoch 1 | Train PPL: 9.566 | Val PPL (TF=1): 7.701 | Val PPL (free): 303.023 | Val Acc: 0.389
Epoch 2 | Train PPL: 7.311 | Val PPL (TF=1): 6.284 | Val PPL (free): 142.110 | Val Acc: 0.444
Epoch 3 | Train PPL: 6.259 | Val PPL (TF=1): 5.573 | Val PPL (free): 246.475 | Val Acc: 0.474
Epoch 4 | Train PPL: 5.693 | Val PPL (TF=1): 5.171 | Val PPL (free): 288.340 |

Val PPL (TF=1): 3.756 