# This notebook implements a character-generation LSTM and other related implementations

In [None]:
import torch
from torch.utils.data import DataLoader, random_split, Dataset
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
import re
import numpy as np
import pickle
from transformers import PreTrainedTokenizerFast
import math

from functools import partial

# Data preparation

In [None]:
# NOTE(review): this binds a plain Python variable; nothing in the visible code reads it
# and it does not modify os.environ — confirm whether an environment variable was intended.
TOKENIZERS_PARALLELISM = True

# Fast tokenizer restored from its serialized JSON definition, using "<PAD>" as pad token.
TOKENIZER_JSON = "/kaggle/input/subword-rnn-lstm/rap_tokenizer.json"
tokenizer = PreTrainedTokenizerFast(tokenizer_file=TOKENIZER_JSON, pad_token="<PAD>")

## Creation of the dataset

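To increase randomness during training:

For each epoch, the entire corpus is shifted by a random offset, so that the model does not see exactly the same text windows epoch after epoch. (Note: in the code below the offset is drawn when `SongDataset` is constructed, so the dataset has to be rebuilt to obtain a new offset.)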

In [None]:
# Token stream of the whole corpus, memory-mapped read-only.
dataset_ = np.load("/kaggle/input/subword-rnn-lstm/encoded.npy", mmap_mode="r")

# Ids of the two marker tokens used below: "α" opens a sequence and "θ" closes it.
# Looked up once instead of once per token of the corpus.
song_start_id = tokenizer.convert_tokens_to_ids("α")
song_end_id = tokenizer.convert_tokens_to_ids("θ")

# `result` collects one 1-D tensor per sequence (markers included).
result = []
current = []  # defined up front so a stream that does not start with "α" cannot raise NameError

for t in dataset_:
    if t == song_start_id:
        # A new sequence starts: whatever was pending is discarded.
        current = [t]
    elif t == song_end_id:
        current.append(t)
        result.append(torch.tensor(current))
        # Reset so the finished sequence is not appended a second time after the loop.
        current = []
    else:
        current.append(t)

# Keep a trailing sequence that was opened but never closed by "θ".
if current:
    result.append(torch.tensor(current))

In [None]:
class SongDataset(Dataset):
    """Sliding-window next-token dataset built from a collection of token sequences.

    Every sequence in ``texts`` is cut into windows of at most ``length_seq``
    tokens taken every ``stride`` tokens. Each sample is a pair ``(x, y)``
    where ``y`` is ``x`` shifted one position to the right, so ``y[i]`` is the
    token that follows ``x[i]``.

    Args:
        texts: iterable of sliceable 1-D token sequences (tensors in this file).
        length_seq: maximum number of tokens per window.
        stride: distance between the starts of two consecutive windows.
        use_offset: when true, the first window of each sequence starts at a
            random position in ``[0, stride)``. The offset is drawn once per
            sequence while the dataset is being constructed; a new offset
            requires building a new dataset.

    ``x`` and ``y`` always have the same length: for a sequence that is not
    longer than ``length_seq`` the input is trimmed to the number of available
    targets, and a window with no target at all is skipped.
    """

    def __init__(self, texts, length_seq, stride, use_offset=True):
        self.samples = []
        self.length_seq = length_seq
        self.stride = stride

        for text in texts:
            L = len(text)

            offset = torch.randint(0, stride, (1,)).item() if use_offset else 0

            # --- Main loop ---
            # max(1, ...) lets a sequence shorter than length_seq still yield
            # one (shorter) window when its offset is 0.
            for start in range(offset, max(1, L - self.length_seq), self.stride):
                y = text[start + 1:start + 1 + self.length_seq]
                if len(y) == 0:
                    # Nothing to predict (single-token sequence).
                    continue

                # Trim the input to the targets so the pair is always aligned.
                x = text[start:start + len(y)]

                self.samples.append((x, y))

    def __len__(self):
        """Return the number of (x, y) windows."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the ``(x, y)`` pair stored at position ``idx``."""
        return self.samples[idx]

In [None]:
# 80 % of the sequences go to training, the remainder to evaluation.
len_train = int(0.8 * len(result))
len_test = len(result) - len_train

# Seeded generator so the split is reproducible from run to run.
generator = torch.Generator().manual_seed(42)
train, test = random_split(result, lengths=[len_train, len_test], generator=generator)

# Both splits use the same windowing; only the training split gets a random start offset.
window_cfg = {"length_seq": 256, "stride": 16}
train_ds = SongDataset(train, use_offset=True, **window_cfg)
test_ds = SongDataset(test, use_offset=False, **window_cfg)

## Creation of Dataloader and co

In [None]:
def collate_batch(batch, tokenizer):
    """Pad a list of ``(x, y)`` samples into two batch-first tensors.

    Args:
        batch: list of ``(x, y)`` tensor pairs as produced by the dataset.
        tokenizer: object exposing ``pad_token_id``, used as the padding value.

    Returns:
        ``[X_padded, Y_padded]``: inputs and targets, each padded on its own
        to the longest sequence it contains.
    """
    pad_id = tokenizer.pad_token_id

    # Transpose the list of pairs into (all inputs, all targets).
    inputs, targets = zip(*batch)

    return [
        pad_sequence(seqs, batch_first=True, padding_value=pad_id)
        for seqs in (inputs, targets)
    ]

# Bind the tokenizer once so that collate_fn can be called with the batch alone.
collate_fn = partial(collate_batch, tokenizer=tokenizer)

# functools.partial stores the keyword arguments given here and supplies them on every
# later call, so only the remaining arguments have to be passed by the caller:
# collate_fn(batch) == collate_batch(batch, tokenizer=tokenizer)

In [None]:
batch_size = 512

# Options shared by both loaders.
# shuffle=False: per the original author, the RNN should use data from previous
# sequences to predict the next one.
# drop_last=True: every batch has exactly `batch_size` rows.
_loader_common = dict(
    batch_size=batch_size,
    pin_memory=True,
    pin_memory_device="cuda:0",
    shuffle=False,
    drop_last=True,
    collate_fn=collate_fn,
)

train_dl = DataLoader(train_ds, num_workers=4, prefetch_factor=4, **_loader_common)
test_dl = DataLoader(test_ds, num_workers=2, prefetch_factor=2, **_loader_common)

## Models

### Training part

In [None]:
class LSTM_3_layers(nn.Module):
    """Stack of ``nn.LSTMCell`` layers that processes one timestep per call.

    Pipeline of a call: embedding -> for each layer (LSTMCell -> LayerNorm ->
    Dropout) -> linear projection to vocabulary logits. A single ``LayerNorm``
    instance is shared by all layers.

    Args:
        vocab_size: number of embedding rows and of output logits.
        embed_dim: embedding width (input size of the first cell).
        hidden_size: hidden/cell state width of every layer.
        batch_size: stored on the instance; used as the default batch size of
            ``init_hidden``.
        proba: dropout probability applied after each layer's normalization.
        num_layers: number of stacked cells (default 3, the original layout).
    """

    def __init__(self, vocab_size, embed_dim, hidden_size, batch_size, proba, num_layers=3):
        super().__init__()

        self.batch_size = batch_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Only the first cell consumes embeddings; the others consume hidden states.
        self.lstm_layers = nn.ModuleList(
            [
                nn.LSTMCell(embed_dim if i == 0 else hidden_size, hidden_size, bias=True)
                for i in range(num_layers)
            ]
        )

        # Index 0 is treated as the padding row of the embedding table.
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.drop = nn.Dropout(p=proba)
        self.ln = nn.LayerNorm(hidden_size)

        self.linear = nn.Linear(hidden_size, vocab_size, bias=False)

    def init_hidden(self, batch_size=None, device=None):
        """Return zero-initialized ``(h, c)`` state lists, one tensor per layer.

        Args:
            batch_size: number of rows of each state tensor; defaults to the
                ``batch_size`` given to the constructor.
            device: device of the state tensors; defaults to the device of the
                output projection weights.
        """
        if batch_size is None:
            batch_size = self.batch_size
        if device is None:
            device = self.linear.weight.device

        h = [torch.zeros(batch_size, self.hidden_size, device=device) for _ in self.lstm_layers]
        c = [torch.zeros(batch_size, self.hidden_size, device=device) for _ in self.lstm_layers]
        return h, c

    def forward(self, x, hidden):
        """Advance the network by one timestep.

        Args:
            x: token ids for the current timestep (the callers in this file
                pass one id per batch row).
            hidden: pair ``(h, c)`` of per-layer state sequences.

        Returns:
            ``(logits, (h, c))``: vocabulary logits for the next token and the
            updated per-layer states as new lists; the lists passed in are
            left untouched.
        """
        h, c = hidden
        # Work on shallow copies so the caller's lists are not modified in place.
        h, c = list(h), list(c)

        x_t = self.embed(x)
        # Normalize first, then drop out (no dropout before normalization:
        # https://arxiv.org/pdf/1801.05134).
        for i, lstm_cell in enumerate(self.lstm_layers):
            h[i], c[i] = lstm_cell(x_t, (h[i], c[i]))
            x_t = self.drop(self.ln(h[i]))

        logits = self.linear(x_t)

        return logits, (h, c)

In [None]:
device = torch.device("cuda:0")

# --- Hyper-parameters ---
embedding_dim = 384
hidden_size = 512
# NOTE(review): assumes every id produced by the tokenizer (including "<PAD>") is
# below tokenizer.vocab_size — confirm against the tokenizer file.
vocab_size = tokenizer.vocab_size
num_epoch = 100

# Number of batches per pass over each loader; used to average the epoch losses.
nb_step_train = len(train_dl)
nb_step_test = len(test_dl)

model = LSTM_3_layers(vocab_size, embedding_dim, hidden_size, batch_size, 0.2).to(device)
model = torch.compile(model)

# Targets equal to 0 are excluded from the loss (0 is also the embedding padding_idx).
# NOTE(review): assumes tokenizer.pad_token_id == 0 — confirm.
loss_fn = nn.CrossEntropyLoss(ignore_index=0)
# Rebound to a float by the validation step of the training loop.
val_loss = nn.CrossEntropyLoss(ignore_index=0)

opti = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)

# Both schedulers are stepped once per EPOCH by the training loop, so their
# horizons are expressed in epochs (not in batches): a 3-epoch linear warm-up
# followed by cosine cycles with warm restarts.
warmup_epochs = 3
first_cycle_epochs = 10

# The cosine scheduler is created first: constructing a scheduler performs an
# initial step that writes the learning rate, so the warm-up scheduler must be
# the last one built for training to actually start at lr * start_factor.
sched_post = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    opti, T_0=first_cycle_epochs, T_mult=2, eta_min=0.0001
)  # cycle lengths: 10 epochs => 20 => 40
sched_warm = torch.optim.lr_scheduler.LinearLR(
    opti, start_factor=0.2, end_factor=1.0, total_iters=warmup_epochs
)

In [None]:
@torch.no_grad()
def evaluate_free(model, dl, loss_fn, device):
    """Autoregressive validation (teacher forcing = 0).

    Only the first token of every input row is taken from the data; at each
    following timestep the greedy (argmax) prediction of the previous step is
    fed back in. The loss at step ``t`` is measured against ``Y[:, t]``.

    Args:
        model: step-wise model called as ``model(tokens, (h, c))``, exposing
            ``hidden_size`` and ``lstm_layers`` (used to size the zero states).
        dl: iterable of ``(X, Y)`` batches of token ids.
        loss_fn: criterion returning the MEAN loss over non-ignored targets;
            its ``ignore_index`` attribute (default -100) marks padding.
        device: device on which the evaluation runs.

    Returns:
        Perplexity, i.e. ``exp`` of the mean next-token NLL over all
        non-ignored targets.
    """
    model.eval()
    total_loss = 0.0
    total_tokens = 0

    dev_type = torch.device(device).type
    ignore_index = getattr(loss_fn, "ignore_index", -100)
    num_layers = len(model.lstm_layers)

    for X, Y in dl:
        X = X.to(device)
        Y = Y.to(device, dtype=torch.long)
        bs, sl = X.size(0), X.size(1)
        if sl == 0:
            continue

        # Fresh zero state per batch, sized from the actual batch.
        h = [torch.zeros(bs, model.hidden_size, device=device) for _ in range(num_layers)]
        c = [torch.zeros(bs, model.hidden_size, device=device) for _ in range(num_layers)]
        hid = (h, c)

        # Start with the first input token of every row.
        inp = X[:, 0]  # (bs,)
        for t in range(sl):
            # Mixed precision only makes sense on CUDA; disabled elsewhere.
            with torch.amp.autocast(device_type=dev_type, enabled=(dev_type == "cuda")):
                logits, hid = model(inp, hid)  # (bs, vocab)

            # Weight the step's mean loss by the number of real (non-padding)
            # targets; a step made only of padding contributes nothing.
            n_valid = int((Y[:, t] != ignore_index).sum().item())
            if n_valid:
                loss = loss_fn(logits.float(), Y[:, t])
                total_loss += loss.item() * n_valid
                total_tokens += n_valid

            # Greedy next token, fed back in as the next input.
            inp = logits.argmax(dim=-1)  # (bs,)

    ppl = math.exp(total_loss / max(1, total_tokens))
    return ppl

def sample_with_temp(logits, temp=1.0):
    """Sample one token id per row from temperature-scaled logits.

    Args:
        logits: unnormalized scores over the vocabulary (last dimension).
        temp: softmax temperature. Values below 1 sharpen the distribution,
            values above 1 flatten it. A non-positive temperature selects the
            most likely token deterministically instead of dividing by zero.

    Returns:
        Tensor of sampled indices with a trailing dimension of size 1.
    """
    if temp <= 0:
        # Limit case temp -> 0: greedy decoding, same output shape as multinomial.
        return logits.argmax(dim=-1, keepdim=True)

    probs = (logits / temp).softmax(dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)
    return next_token

def distinct_n_chars(text, n=3):
    """Return the share of distinct character n-grams among all n-grams of ``text``.

    The result is 0.0 when ``text`` is too short to contain a single n-gram.
    """
    window_count = len(text) - n + 1
    if window_count <= 0:
        return 0.0

    seen = {text[pos:pos + n] for pos in range(window_count)}
    return len(seen) / window_count

In [None]:
def training(model, train_dl, loss_fn, vocab_size, opti, scaler, device):
    """Run one training pass over ``train_dl`` with step-wise unrolling.

    For every batch the model is stepped through all timesteps starting from a
    zero state, the per-timestep losses are averaged, and a single optimizer
    step is taken using gradient scaling and gradient-norm clipping.

    Args:
        model: step-wise model called as ``model(tokens, (h, c))``, exposing
            ``hidden_size`` and ``lstm_layers`` (used to size the zero states).
        train_dl: iterable of ``(X, Y)`` batches of token ids.
        loss_fn: criterion applied to ``(logits, targets)`` at each timestep.
        vocab_size: size of the logits' last dimension.
        opti: optimizer updating ``model``'s parameters.
        scaler: gradient scaler used for the backward pass and optimizer step.
        device: device on which the computation runs.

    Returns:
        ``(train_loss, scaler)``: the sum over batches of the per-batch mean
        loss, and the scaler that was passed in.
    """
    model.train()
    train_loss = 0.0

    dev_type = torch.device(device).type
    num_layers = len(model.lstm_layers)
    state_size = model.hidden_size

    for X, Y in train_dl:

        X, Y = X.to(device), Y.to(device, dtype=torch.long)
        bs, seq_len = X.size(0), X.size(1)
        if seq_len == 0:
            continue

        # Zero state per batch, sized from the actual batch rather than from globals.
        h = [torch.zeros(bs, state_size, device=device) for _ in range(num_layers)]
        c = [torch.zeros(bs, state_size, device=device) for _ in range(num_layers)]

        opti.zero_grad(set_to_none=True)
        batch_loss = 0

        for t in range(seq_len):  # one model call per timestep of the sequence
            # Mixed precision only on CUDA; disabled on other devices.
            with torch.amp.autocast(device_type=dev_type, enabled=(dev_type == "cuda")):
                pred, (h, c) = model(X[:, t], (h, c))
                batch_loss += loss_fn(pred.view(-1, vocab_size), Y[:, t])

        batch_loss /= seq_len

        scaler.scale(batch_loss).backward()
        # Gradients must be unscaled before clipping, otherwise the threshold
        # is compared against the scaled gradient norm.
        scaler.unscale_(opti)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 2)
        scaler.step(opti)
        scaler.update()

        train_loss += batch_loss.detach().item()

    return train_loss, scaler

def evaluate(model, eval_dl, loss_fn, vocab_size, device):
    """Compute the teacher-forced loss of ``model`` over ``eval_dl``.

    Every batch is stepped through all its timesteps from a zero state, with
    the ground-truth token as input at each step. No gradients are recorded.

    Args:
        model: step-wise model called as ``model(tokens, (h, c))``, exposing
            ``hidden_size`` and ``lstm_layers`` (used to size the zero states).
        eval_dl: iterable of ``(X, Y)`` batches of token ids.
        loss_fn: criterion applied to ``(logits, targets)`` at each timestep.
        vocab_size: size of the logits' last dimension.
        device: device on which the computation runs.

    Returns:
        Sum over batches of the per-batch loss averaged over timesteps.
    """
    model.eval()
    eval_loss = 0.0

    num_layers = len(model.lstm_layers)
    state_size = model.hidden_size

    # Gradient tracking is disabled here so the function is safe to call on its own.
    with torch.no_grad():
        for X, Y in eval_dl:

            X, Y = X.to(device), Y.to(device, dtype=torch.long)
            bs, seq_len = X.size(0), X.size(1)
            if seq_len == 0:
                continue

            # Zero state per batch, sized from the actual batch rather than from globals.
            h = [torch.zeros(bs, state_size, device=device) for _ in range(num_layers)]
            c = [torch.zeros(bs, state_size, device=device) for _ in range(num_layers)]

            batch_loss = 0

            for t in range(seq_len):  # one model call per timestep of the sequence
                pred, (h, c) = model(X[:, t], (h, c))
                batch_loss += loss_fn(pred.view(-1, vocab_size), Y[:, t])

            batch_loss /= seq_len
            eval_loss += batch_loss.detach().item()

    return eval_loss

In [None]:
list_offset = []
l_tot = []  # validation perplexity of every epoch
teacher_forcing_ratio = 1
best_val = float("inf")

scaler = torch.amp.GradScaler()

for epoch in range(num_epoch):

    # -------------- TRAIN LOOP --------------
    train_loss, scaler = training(model, train_dl, loss_fn, vocab_size, opti, scaler, device)
    # training() returns the sum of per-batch mean losses; average it before exp().
    train_ppl = math.exp(train_loss / nb_step_train)

    # ---- Scheduler part ---- #

    # Right after a warm restart (T_cur back to 0), lower the peak learning
    # rate of the following cosine cycle by 10 %.
    if sched_post.T_cur == 0 and epoch > 4:
        sched_post.base_lrs[0] = sched_post.base_lrs[0] * 0.9
        print(f"Decrease {sched_post.base_lrs[0]}, {sched_post.eta_min}")

    # Linear warm-up during the first three epochs, cosine schedule afterwards.
    if epoch < 3:
        sched_warm.step()
    else:
        sched_post.step()

    # -------------- VALIDATION --------------
    with torch.no_grad():
        val_loss = evaluate(model, test_dl, loss_fn, vocab_size, device)
        val_ppl = math.exp(val_loss / nb_step_test)

    print(
        f"Epoch {epoch} | "
        f"Train PPL: {train_ppl:.3f} | "
        f"Val PPL : {val_ppl:.3f}"
    )

    # --------- Sample generation + diversity metrics ---------
    if epoch % 8 == 0 and epoch != 0:

        model.eval()
        with torch.no_grad():
            # Context priming: the first 20 tokens of the first evaluation sample.
            X, _ = next(iter(test_dl))

            start = X[0:1, :20].to(device)
            h = [torch.zeros(1, hidden_size, device=device) for _ in range(3)]
            c = [torch.zeros(1, hidden_size, device=device) for _ in range(3)]

            # Feed every prompt token except the last one: the last token is
            # the first generation input below, so it must not be consumed
            # twice by the recurrent state.
            for j in range(start.size(1) - 1):
                _, (h, c) = model(start[:, j], (h, c))

            # Generating
            inp = start[:, -1]
            gen_chars = []

            for t in range(200):
                logits, (h, c) = model(inp, (h, c))
                next_char = sample_with_temp(logits, temp=0.6)
                gen_chars.append(tokenizer.decode(next_char.item()))
                inp = next_char.view(1)  # feed the sampled token back in

            gen_text = "".join(gen_chars)

        # Share of distinct character bigrams / trigrams in the generated text.
        d2 = distinct_n_chars(gen_text, n=2)
        d3 = distinct_n_chars(gen_text, n=3)

        print("\n=== Initial text ===")
        print("".join(tokenizer.decode(X[0:1, :].squeeze(0))))
        print("\n=== Sample Generation ===")
        print(gen_text[:200])
        print(f"Distinct-2: {d2:.3f} | Distinct-3: {d3:.3f}", end="\n")

    # Record the validation perplexity and checkpoint on improvement.
    l_tot.append(val_ppl)
    if val_ppl < best_val:
        best_val = val_ppl
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": opti.state_dict(),
                "scheduler_state_dict": sched_post.state_dict(),
                "val_ppl": val_ppl,
            },
            "model",
        )