# This notebook trains a subword-level text-generation RNN

In [None]:
import math
import os
from functools import partial

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split, Dataset
from torch.nn.utils.rnn import pad_sequence
from transformers import PreTrainedTokenizerFast

In [None]:
# Hugging Face `tokenizers` reads this switch from the process environment, so a
# plain Python variable on its own has no effect; export it as well.
# setdefault lets a value already present in the environment take precedence.
# NOTE(review): an explicit "true" may keep tokenizers from auto-disabling its
# thread pool in forked DataLoader workers; prefer "false" if workers ever
# tokenize text — confirm against the tokenizers docs.
TOKENIZERS_PARALLELISM = True
os.environ.setdefault("TOKENIZERS_PARALLELISM", "true")

tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="/kaggle/input/subword-rnn-lstm/rap_tokenizer.json",
    pad_token="<PAD>",
)

device1 = torch.device("cuda:0")

## Creation of the dataset

In [None]:
def _split_songs(tokens, start_id, end_id):
    """Split a flat stream of token ids into one int64 tensor per song.

    A song runs from ``start_id`` to ``end_id`` inclusive. Tokens outside such
    a span (before the first start marker, or between an end marker and the
    next start marker) are dropped. A start marker that arrives while a song
    is still open discards that unterminated song. A final song that never
    receives its end marker is kept.

    Args:
        tokens: 1-D array-like of integer token ids.
        start_id: id of the start-of-song marker.
        end_id: id of the end-of-song marker.

    Returns:
        list[torch.Tensor]: one int64 tensor per song.
    """
    songs = []
    current = None  # ids of the song being collected, or None between songs
    # .tolist() converts the (possibly memory-mapped) array to Python ints in one pass.
    for tok in np.asarray(tokens).tolist():
        if tok == start_id:
            current = [tok]
        elif current is not None:
            current.append(tok)
            if tok == end_id:
                songs.append(torch.tensor(current))
                # Close the song so it is not appended a second time after the loop.
                current = None
    if current:
        songs.append(torch.tensor(current))
    return songs


dataset_ = np.load("/kaggle/input/subword-rnn-lstm/encoded.npy", mmap_mode="r")

# "α" opens a song and "θ" closes it; look the ids up once, not per token.
result = _split_songs(
    dataset_,
    start_id=tokenizer.convert_tokens_to_ids("α"),
    end_id=tokenizer.convert_tokens_to_ids("θ"),
)

class SongDataset(Dataset):
    """Sliding-window next-token dataset built from per-song token tensors.

    Every sample is an ``(x, y)`` pair of equal length in which ``y`` is ``x``
    shifted one position to the right within the same song.

    Args:
        texts: iterable of 1-D token tensors, one per song.
        length_seq: window length (number of input/target positions).
        stride: step between consecutive window starts within a song.
        use_offset: if True, each song's first window starts at a random
            offset in ``[0, stride)``, so window boundaries vary between
            datasets built from the same songs.

    Songs with at most ``length_seq`` input/target pairs become a single
    shorter sample (``text[:-1]``, ``text[1:]``) instead of being dropped.
    Songs of length < 2 have no target and are skipped.

    Raises:
        ValueError: if ``length_seq`` or ``stride`` is smaller than 1.
    """

    def __init__(self, texts, length_seq, stride, use_offset=True):
        if length_seq < 1 or stride < 1:
            raise ValueError("length_seq and stride must both be >= 1")
        self.samples = []
        self.length_seq = length_seq
        self.stride = stride

        for text in texts:
            # Drawn for every song, short ones included, so the random stream
            # advances once per song.
            offset = torch.randint(0, stride, (1,)).item() if use_offset else 0

            n_pairs = len(text) - 1  # number of (input token, next token) positions
            if n_pairs < 1:
                continue
            if n_pairs <= self.length_seq:
                # Whole song fits in one window: keep x and y the same length.
                self.samples.append((text[:-1], text[1:]))
                continue

            # Last start whose target window still ends inside the song.
            last_start = n_pairs - self.length_seq
            # Clamp the offset so a song only slightly longer than the window
            # still yields at least one sample.
            for start in range(min(offset, last_start), last_start + 1, self.stride):
                x = text[start:start + self.length_seq]
                y = text[start + 1:start + 1 + self.length_seq]
                self.samples.append((x, y))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

In [None]:
# Song-level 80/20 split: each whole song lands on one side, so windows of a
# held-out song never appear in training.
len_train = int(0.8 * len(result))
len_test = len(result) - len_train

generator = torch.Generator()
generator.manual_seed(42)  # fixed seed -> reproducible split
train, test = random_split(result, lengths=[len_train, len_test], generator=generator)

# Training windows: stride 16 over 256-token windows limits overlap and
# redundancy between samples (to curb overfitting); random per-song offset.
train_ds = SongDataset(train, length_seq=256, stride=16, use_offset=True)
# Validation windows: same geometry, deterministic (no random offset).
test_ds = SongDataset(test, length_seq=256, stride=16, use_offset=False)

## DataLoaders and batch collation

In [None]:
def collate_batch(batch, tokenizer):
    """Turn a list of ``(x, y)`` token-tensor pairs into two padded batch tensors.

    Inputs and targets are each right-padded, independently, to the longest
    sequence on their side, using ``tokenizer.pad_token_id`` as the fill value.

    Returns:
        list: ``[X_padded, Y_padded]``, both shaped ``(batch, max_len)``.
    """
    inputs, targets = zip(*batch)
    fill = tokenizer.pad_token_id
    return [
        pad_sequence(inputs, batch_first=True, padding_value=fill),
        pad_sequence(targets, batch_first=True, padding_value=fill),
    ]

# DataLoader invokes collate_fn with the batch as its only argument, so the
# tokenizer is bound ahead of time with functools.partial; afterwards
# collate_fn(batch) is equivalent to collate_batch(batch, tokenizer=tokenizer).
collate_fn = partial(collate_batch, tokenizer=tokenizer)

In [None]:
batch_size = 512

# The training loop builds a fresh zero hidden state for every batch, so no
# recurrent state flows from one batch to the next and sample order gives the
# model no extra context. Shuffling breaks up runs of heavily overlapping
# windows (stride 16 over 256 tokens) from the same song that would otherwise
# fill whole batches. drop_last stays True: the loop sizes the hidden state
# with exactly `batch_size` rows.
train_dl = DataLoader(
    train_ds, batch_size=batch_size, pin_memory=True, pin_memory_device="cuda:0",
    num_workers=4, prefetch_factor=4, shuffle=True, drop_last=True, collate_fn=collate_fn,
)
# Keep the last partial validation batch so every held-out window is scored,
# and so a test split smaller than one batch still yields at least one batch.
test_dl = DataLoader(
    test_ds, batch_size=batch_size, pin_memory=True, pin_memory_device="cuda:0",
    num_workers=2, prefetch_factor=2, shuffle=False, drop_last=False, collate_fn=collate_fn,
)

## Models

### Training part

In [None]:
class Subword_RNN(nn.Module):
    """Embedding -> stacked ReLU RNN -> linear projection to next-token logits.

    Args:
        vocab_size: number of embedding rows and of output logits.
        emb_size: embedding dimension.
        hidden_size: RNN hidden-state size.
        num_layers: number of stacked RNN layers.
        dropout: probability used for the embedding/output dropout and for
            nn.RNN's dropout between stacked layers.
        padding_idx: embedding row kept at zero and excluded from gradient
            updates; should equal the id batches are padded with.
    """

    def __init__(self, vocab_size, emb_size, hidden_size, num_layers=1, dropout=0, padding_idx=0):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx=padding_idx)
        self.rnn = nn.RNN(
            emb_size, hidden_size, num_layers,
            batch_first=True, dropout=dropout, nonlinearity="relu",
        )
        self.drop = nn.Dropout(p=dropout)
        # Not applied in forward(); kept so state_dict keys (ln.weight/ln.bias)
        # remain compatible with checkpoints saved from this model.
        self.ln = nn.LayerNorm(hidden_size)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden):
        """Return ``(logits, new_hidden)`` for token ids ``x``.

        ``x`` is batch-first (``batch_first=True``); ``hidden`` is laid out as
        ``init_hidden`` builds it: ``(num_layers, batch, hidden_size)``.
        """
        x = self.drop(self.embedding(x))
        out, hidden = self.rnn(x, hidden)
        out = self.drop(out)
        out = self.fc(out)
        return out, hidden

    def init_hidden(self, batch_size):
        """Return a zero hidden state ``(num_layers, batch_size, hidden_size)``.

        The tensor is created on the same device as the model's parameters,
        so a following ``.to(device)`` is a no-op.
        """
        return torch.zeros(
            self.num_layers, batch_size, self.hidden_size,
            device=self.fc.weight.device,
        )

In [None]:
embedding_dim = 384
# len(tokenizer) also counts tokens added on top of the base vocabulary (e.g. a
# pad token registered at load time); tokenizer.vocab_size excludes them, which
# would leave such ids without an embedding row / output logit.
vocab_size = len(tokenizer)
hidden_size = 512
num_epoch = 100

nb_step_train = len(train_dl)
nb_step_test = len(test_dl)

# NOTE(review): Subword_RNN's embedding uses padding_idx=0 by default; confirm
# tokenizer.pad_token_id == 0 or align the two.
model = Subword_RNN(vocab_size, embedding_dim, hidden_size, num_layers=3, dropout=0.2).to(device1)
model = torch.compile(model)

# Ignore exactly the id collate_batch pads with, so padded positions never
# contribute to the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

opti = torch.optim.AdamW(model.parameters(), lr=0.002, weight_decay=1e-4)
# Build the cosine scheduler BEFORE the warm-up scheduler. Each scheduler takes
# a step in its constructor. If CosineAnnealingWarmRestarts were built last, it
# would reset the lr to the full 0.002 after LinearLR had set 0.2 * 0.002. And
# because LinearLR scales the *current* lr multiplicatively, warm-up would then
# climb to 5 * 0.002 = 0.01. Built in this order, warm-up runs 0.0004 -> 0.002.
sched_post = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    opti, T_0=nb_step_train * 20, T_mult=2, eta_min=0.0001,
)
sched_warm = torch.optim.lr_scheduler.LinearLR(
    opti, start_factor=0.2, end_factor=1.0, total_iters=nb_step_train * 3,
)

In [None]:
@torch.no_grad()
def evaluate_tf1(model, dl, loss_fn, device, vocab_size, nb_step_test):
    """Return teacher-forced (TF=1) perplexity of ``model`` over ``dl``.

    Every batch starts from ``model.init_hidden``. Each batch's mean loss is
    weighted by its number of scored target tokens, so the result is the
    exponential of the per-token average loss even when batches differ in
    size or padding.

    Args:
        model: module with ``init_hidden(batch_size)`` and
            ``forward(x, hidden) -> (logits, hidden)``.
        dl: iterable of ``(X, Y)`` token-id batches.
        loss_fn: assumed to return the mean over non-ignored targets (the
            ``nn.CrossEntropyLoss`` default); its ``ignore_index``, when
            present, determines which targets are counted.
        device: device to run on; autocast is enabled only for CUDA devices.
        vocab_size: size of the logits' last dimension.
        nb_step_test: no longer used (batches are counted internally); kept
            for call compatibility.

    Raises:
        ValueError: if ``dl`` yields no scorable target tokens.
    """
    model.eval()
    use_amp = torch.device(device).type == "cuda"
    ignore_index = getattr(loss_fn, "ignore_index", None)
    total_nll = 0.0
    total_tokens = 0

    for X, Y in dl:
        X = X.to(device)
        Y = Y.to(device, dtype=torch.long)
        hid = model.init_hidden(X.size(0)).to(device)

        with torch.amp.autocast(device_type="cuda", enabled=use_amp):
            pred, _ = model(X, hid)
            loss = loss_fn(pred.reshape(-1, vocab_size), Y.reshape(-1))

        n_tokens = Y.numel() if ignore_index is None else int((Y != ignore_index).sum())
        if n_tokens == 0:
            continue  # all targets ignored: the mean loss would be NaN
        total_nll += loss.item() * n_tokens
        total_tokens += n_tokens

    if total_tokens == 0:
        raise ValueError("evaluate_tf1: the data loader yielded no scorable target tokens")
    return math.exp(total_nll / total_tokens)


def sample_with_temp(logits, temp=1.0):
    """Draw one token id per row from ``softmax(logits / temp)``.

    Lower ``temp`` sharpens the distribution, higher flattens it. Softmax is
    taken over the last dimension; for 2-D ``logits`` of shape ``(rows, vocab)``
    the result has shape ``(rows, 1)``.
    """
    scaled = logits / temp
    distribution = torch.softmax(scaled, dim=-1)
    return torch.multinomial(distribution, num_samples=1)

def distinct_n_chars(text, n=3):
    """Return the fraction of character n-grams in ``text`` that are unique.

    The count of distinct n-grams is divided by the total number of n-gram
    positions; with no positions (text shorter than ``n``) the denominator is
    1, so the result is 0.0.
    """
    positions = len(text) - n + 1
    unique = {text[i:i + n] for i in range(positions)}
    return len(unique) / max(1, max(0, positions))

In [None]:
list_offset = []  # NOTE(review): not read anywhere in this cell
l_tot = []  # validation perplexity recorded at each new best checkpoint
best_val = float("inf")
teacher_forcing_ratio = 1  # NOTE(review): not read in this cell; training is fully teacher-forced

scaler = torch.amp.GradScaler()

for epoch in range(num_epoch):

    model.train()

    # -------------- TRAIN LOOP --------------
    train_loss_sum = 0.0

    for X, Y in train_dl:
        # Fresh zero hidden state per batch: nothing is carried between batches.
        hid = model.init_hidden(batch_size).to(device1)
        X = X.to(device1)
        Y = Y.to(device1, dtype=torch.long)
        opti.zero_grad(set_to_none=True)

        with torch.amp.autocast(device_type="cuda"):
            pred, hid = model(X, hid)
            loss = loss_fn(pred.view(-1, vocab_size), Y.view(-1))

        scaler.scale(loss).backward()
        # At this point the gradients are still multiplied by the loss scale.
        # Unscale them in place so the max-norm of 2 applies to the real
        # gradients; scaler.step() knows unscale_ was already called and does
        # not unscale a second time.
        scaler.unscale_(opti)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 2)
        scaler.step(opti)
        scaler.update()

        train_loss_sum += loss.detach().item()

        # Per-batch LR schedule: linear warm-up for the first 3 epochs, then
        # cosine annealing with warm restarts.
        if epoch < 3:
            sched_warm.step()
        else:
            sched_post.step()

        # T_cur is 0 right after a warm restart: lower the peak LR by 10 % each time.
        if sched_post.T_cur == 0 and epoch > 4:
            sched_post.base_lrs[0] = sched_post.base_lrs[0] * 0.9
            print(f"Decrease {sched_post.base_lrs[0]}, {sched_post.eta_min}")

    # exp of the mean per-batch loss (nb_step_train is len(train_dl)).
    train_ppl = math.exp(train_loss_sum / nb_step_train)

    # -------------- VALIDATION --------------
    val_ppl = evaluate_tf1(model, test_dl, loss_fn, device1, vocab_size, nb_step_test)

    print(
        f"Epoch {epoch} | "
        f"Train PPL: {train_ppl:.3f} | "
        f"Val PPL (TF=1): {val_ppl:.3f} "
    )

    # --------- Sample generation + diversity metrics ---------
    model.eval()
    with torch.no_grad():
        # Prompt: first 20 tokens of the first row of the last training batch.
        start = X[0:1, :20]
        hid_gen = model.init_hidden(1).to(device1)
        inp = start.to(device1)

        gen_ids = []
        for t in range(200):
            pred, hid_gen = model(inp, hid_gen)
            logits = pred[:, -1, :]  # last step
            next_char = sample_with_temp(logits, temp=0.6)
            gen_ids.append(next_char.item())
            inp = next_char  # feed the sampled token back; hid_gen carries the context

        # Decode the whole id sequence at once: decoding tokens one by one and
        # concatenating may differ (e.g. spacing, characters split across tokens).
        gen_text = tokenizer.decode(gen_ids)

    # --- Distinct-n metrics (character n-grams) ---
    d2 = distinct_n_chars(gen_text, n=2)
    d3 = distinct_n_chars(gen_text, n=3)

    if epoch % 4 == 0:
        print("\n=== Initial text ===")
        print(tokenizer.decode(X[0, :20]))
        print("\n=== Sample Generation ===")
        print(gen_text[:200])  # show the first 200 characters
        print(f"Distinct-2: {d2:.3f} | Distinct-3: {d3:.3f}", end="\n")

    # Checkpoint whenever validation perplexity improves.
    if val_ppl < best_val:
        best_val = val_ppl
        l_tot.append(val_ppl)
        # NOTE(review): `model` is the torch.compile wrapper, so the saved keys
        # are likely prefixed with "_orig_mod."; the warm-up scheduler and the
        # GradScaler state are not saved — confirm this suits resuming.
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": opti.state_dict(),
                "scheduler_state_dict": sched_post.state_dict(),
                "val_ppl": val_ppl,
            },
            "model",
        )