In [None]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn as nn
import re
import numpy as np
import pickle
from transformers import PreTrainedTokenizerFast
from torch.utils.data import random_split
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from functools import partial
import math

# Data preparation

In [9]:
# Subword tokenizer for the rap corpus. "<PAD>" is declared as its padding token;
# collate_batch reads it back through tokenizer.pad_token_id.
tokenizer_path = "../../../Corpus/Encoding_RNN_LSTM/Subword/rap_tokenizer.json"
tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_path, pad_token="<PAD>")

## Creation of the dataset

In [10]:
# mmap_mode="r": memory-map the encoded corpus read-only instead of loading it eagerly.
dataset_ = np.load("../../../Corpus/Encoding_RNN_LSTM/Subword/encoded.npy", mmap_mode="r")


def split_into_songs(token_ids, start_id, end_id):
    """Split a flat 1-D stream of token ids into one int64 tensor per song.

    A song runs from a ``start_id`` marker up to and including the next ``end_id``
    marker. A second ``start_id`` before the end restarts the song, tokens outside
    any start..end span are dropped, and a final song with no end marker is kept.

    Args:
        token_ids: 1-D array-like of token ids (e.g. a numpy memmap).
        start_id: id of the song-start marker.
        end_id: id of the song-end marker.

    Returns:
        list of 1-D ``torch.int64`` tensors, one per song.
    """
    ids = np.asarray(token_ids)
    songs = []
    song_start = None  # index of the open song's start marker; None while outside a song
    # Only marker tokens change the state, so visit just their positions.
    for i in np.flatnonzero((ids == start_id) | (ids == end_id)):
        if ids[i] == start_id:
            song_start = i
        elif song_start is not None:
            songs.append(torch.from_numpy(np.array(ids[song_start:i + 1], dtype=np.int64)))
            song_start = None  # close the song so it is not appended again below
    if song_start is not None:
        # Unterminated final song.
        songs.append(torch.from_numpy(np.array(ids[song_start:], dtype=np.int64)))
    return songs


result = split_into_songs(
    dataset_,
    start_id=tokenizer.convert_tokens_to_ids("α"),
    end_id=tokenizer.convert_tokens_to_ids("θ"),
)

In [11]:
class SongDataset(Dataset):
    """Sliding-window next-token dataset built from a list of 1-D token-id sequences.

    Each sample is a pair ``(x, y)`` where ``y`` is ``x`` shifted one token to the
    right, so ``y[i]`` is the target for ``x[i]``. ``x`` and ``y`` always have the
    same length.

    Args:
        texts: iterable of 1-D sequences (one per song) supporting ``len()`` and slicing.
        length_seq: window length.
        stride: step between consecutive window starts within a song.
        use_offset: if True, each song's first window starts at a random offset in
            ``[0, stride)`` (clamped so at least one window still fits), which
            varies where the windows fall inside each song.
    """

    def __init__(self, texts, length_seq, stride, use_offset=True):
        self.samples = []
        self.length_seq = length_seq
        self.stride = stride

        for text in texts:
            L = len(text)
            if L < 2:
                continue  # no (input, target) pair can be formed

            # Largest start whose target window text[start+1 : start+1+length_seq] fits.
            last_start = L - 1 - self.length_seq
            if last_start < 0:
                # Song shorter than one window: keep it as one shorter sample with
                # x and y of equal length (x drops the last token, y the first).
                self.samples.append((text[:-1], text[1:]))
                continue

            offset = torch.randint(0, stride, (1,)).item() if use_offset else 0
            offset = min(offset, last_start)  # guarantee at least one window per song

            for start in range(offset, last_start + 1, self.stride):
                x = text[start:start + self.length_seq]
                y = text[start + 1:start + 1 + self.length_seq]
                self.samples.append((x, y))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

## Creation of the DataLoaders and collate function

In [None]:
def collate_batch(batch, tokenizer):
    """Turn a list of ``(x, y)`` token-id tensors into two padded ``(batch, max_len)`` tensors.

    ``x`` and ``y`` sequences are padded together to one shared length with
    ``tokenizer.pad_token_id``, so the two outputs always have identical shapes
    (padding them separately could give different widths when a sample's x and y
    lengths differ).

    Args:
        batch: non-empty list of ``(x, y)`` pairs of 1-D tensors.
        tokenizer: object exposing ``pad_token_id``.

    Returns:
        ``[X_padded, Y_padded]``.
    """
    X, Y = zip(*batch)
    n = len(X)
    # One pad_sequence call over all 2n sequences guarantees a single common length.
    padded = pad_sequence(list(X) + list(Y), batch_first=True,
                          padding_value=tokenizer.pad_token_id)
    return [padded[:n], padded[n:]]

# Bind the tokenizer now so the DataLoader can call collate_fn(batch) with a single argument.
collate_fn = partial(collate_batch, tokenizer=tokenizer)

# functools.partial stores tokenizer=tokenizer and passes it on every call:
#   collate_fn(batch) == collate_batch(batch, tokenizer=tokenizer)
# Unlike a lambda, a partial of a module-level function can be pickled, which matters
# if the DataLoader workers (num_workers > 0) are started with the "spawn" method.

In [None]:
# Sliding-window step, batch size and window length shared by datasets and training.
stride, batch_size, seq_length = 4, 256, 256

In [None]:
len_train = int(len(result) * 0.85)
len_test = len(result) - len_train

# Seeded split: the same songs land in train/test on every re-run of the notebook.
split_seed = 0
train, test = random_split(result, [len_train, len_test],
                           generator=torch.Generator().manual_seed(split_seed))

train_ds = SongDataset(train, length_seq=seq_length, stride=stride, use_offset=True)
test_ds = SongDataset(test, length_seq=seq_length, stride=stride, use_offset=False)

# Training windows are shuffled. The training loop calls model.init_hidden() for
# every batch, so no recurrent state is carried from one batch to the next. Keeping
# dataset order would only make each batch a run of consecutive, heavily-overlapping
# windows (stride << seq_length) from the same song.
train_dl = DataLoader(train_ds, batch_size=batch_size, pin_memory=True, pin_memory_device="cuda:0",
                      num_workers=4, prefetch_factor=4, shuffle=True, drop_last=True,
                      collate_fn=collate_fn)
# Evaluation order is irrelevant; drop_last keeps every batch at exactly batch_size.
test_dl = DataLoader(test_ds, batch_size=batch_size, pin_memory=True, pin_memory_device="cuda:0",
                     num_workers=2, prefetch_factor=2, shuffle=False, drop_last=True,
                     collate_fn=collate_fn)

## Models

### Model definition and training setup

In [None]:
class Subword_LSTM(nn.Module):
    """Token-level LSTM language model: embedding -> stacked LSTM -> LayerNorm -> linear head.

    Args:
        vocab_size: number of token ids; sizes both the embedding table and the output layer.
        embedding_dim: size of each token embedding.
        hidden_size: LSTM hidden/cell state size.
        num_layers: number of stacked LSTM layers.
        dropout: dropout probability applied to the embeddings and to the LSTM output,
            and (only when ``num_layers > 1``) between stacked LSTM layers.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_size, num_layers=1, dropout=0.2):
        super(Subword_LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # nn.LSTM's own dropout acts only between stacked layers and warns when it is
        # non-zero with a single layer, so it is disabled in that case.
        self.lstm = nn.LSTM(input_size=embedding_dim,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            batch_first=True,
                            dropout=dropout if num_layers > 1 else 0.0)
        self.drop = nn.Dropout(p=dropout)
        self.ln = nn.LayerNorm(hidden_size)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden):
        """Run the model on a batch of token ids.

        Args:
            x: (batch, seq) token ids (batch_first layout).
            hidden: ``(h, c)`` LSTM state, e.g. from ``init_hidden``.

        Returns:
            ``(logits, hidden)``: logits of shape (batch, seq, vocab_size) and the new state.
        """
        x = self.embedding(x)               # (batch, seq, embedding_dim)
        x = self.drop(x)
        out, hidden = self.lstm(x, hidden)  # (batch, seq, hidden_size)
        out = self.ln(out)
        out = self.drop(out)
        logits = self.fc(out)               # (batch, seq, vocab_size)
        return logits, hidden

    def init_hidden(self, batch_size, device):
        """Return a zero initial state ``(h0, c0)``, each (num_layers, batch_size, hidden_size).

        Zeros (what nn.LSTM uses when no state is given) make evaluation and generation
        deterministic; the previous random-normal state injected noise into every call.
        """
        self.device = device
        shape = (self.num_layers, batch_size, self.hidden_size)
        h0 = torch.zeros(shape, device=device)
        c0 = torch.zeros(shape, device=device)
        return (h0, c0)

In [None]:
# Single training/evaluation GPU (index 0).
device1 = torch.device("cuda", 0)

In [None]:
embedding_dim = 300
hidden_size = 512
# len(tokenizer) also counts tokens added on top of the base vocabulary, which
# tokenizer.vocab_size leaves out. If "<PAD>" was added that way, its id equals
# tokenizer.vocab_size and would be out of range for the embedding/output layers.
vocab_size = len(tokenizer)
num_epoch = 100

nb_step_train = len(train_dl)
nb_step_test = len(test_dl)

model = Subword_LSTM(vocab_size, embedding_dim, hidden_size, num_layers=3)

model.to(device1)

# Padding positions carry no signal: exclude them from the loss (and from the metrics
# of evaluation code that honours loss_fn.ignore_index).
loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

opti = torch.optim.AdamW(model.parameters(), lr=0.0015, weight_decay=1e-3)

warmup_epochs = 3  # must match the `epoch < 3` switch between schedulers in the training loop
# Build the cosine scheduler FIRST. Constructing a scheduler performs an initial step:
# for CosineAnnealingWarmRestarts that sets the LR to the base LR (start of its cycle).
# If it were built after LinearLR, it would overwrite LinearLR's 0.2x starting LR, and
# LinearLR's multiplicative updates would then ramp the LR up to 5x the base LR.
# Built in this order, LinearLR scales the LR down to 0.2x and warm-up ends exactly
# at the base LR, where the cosine schedule takes over.
sched_post = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    opti, T_0=nb_step_train * 20, T_mult=2, eta_min=0.0001)  # cycles of 20, 40, 80, ... epochs
sched_warm = torch.optim.lr_scheduler.LinearLR(
    opti, start_factor=0.2, end_factor=1.0, total_iters=nb_step_train * warmup_epochs)

### Evaluation and sampling helpers

In [None]:
@torch.no_grad()
def evaluate_tf1(model, dl, loss_fn, device, vocab_size, bs=None, sl=None):
    """Validation with teacher forcing = 1 (one parallel forward pass per batch).

    Args:
        model: module exposing ``init_hidden(batch_size, device)`` and
            ``forward(x, hidden) -> (logits, hidden)`` with logits (batch, seq, vocab_size).
        dl: iterable of ``(X, Y)`` token-id batches, ``Y`` being the next-token targets.
        loss_fn: mean-reduced token loss. If it has an ``ignore_index`` attribute
            (as ``nn.CrossEntropyLoss`` does), those target positions are excluded
            from both perplexity and accuracy.
        device: device (or device string) to evaluate on.
        vocab_size: number of classes in the last logits dimension.
        bs, sl: unused, kept for backward compatibility. Batch size and sequence
            length are read from each batch, so a short final batch is counted correctly.

    Returns:
        ``(ppl, acc)``: token-level perplexity and greedy next-token accuracy.
    """
    model.eval()
    device = torch.device(device)
    use_amp = device.type == "cuda"  # mixed precision on GPU only, as in training
    ignore_index = getattr(loss_fn, "ignore_index", None)

    total_loss = 0.0
    total_tokens = 0
    correct = 0

    for X, Y in dl:
        X = X.to(device)
        Y = Y.to(device, dtype=torch.long)
        hid = model.init_hidden(X.size(0), device)

        with torch.amp.autocast(device_type=device.type, enabled=use_amp):
            pred, _ = model(X, hid)
            loss = loss_fn(pred.reshape(-1, vocab_size), Y.reshape(-1))

        if ignore_index is None:
            valid = torch.ones_like(Y, dtype=torch.bool)
        else:
            valid = Y != ignore_index
        n_valid = int(valid.sum().item())
        if n_valid == 0:
            continue  # every target ignored: the mean loss is NaN and carries no information

        # loss is a mean over the valid tokens; weight it back to a sum.
        total_loss += loss.item() * n_valid
        total_tokens += n_valid
        correct += ((pred.argmax(dim=-1) == Y) & valid).sum().item()

    ppl = math.exp(total_loss / max(1, total_tokens))
    acc = correct / max(1, total_tokens)
    return ppl, acc


@torch.no_grad()
def evaluate_free(model, dl, loss_fn, device, bs=None, sl=None):
    """Autoregressive validation (teacher forcing = 0).

    Only the first ground-truth token of each sequence is fed. After that, the model
    steps one token at a time on its own greedy predictions, and each step is scored
    against the corresponding target in ``Y``.

    Args:
        model: module exposing ``init_hidden(batch_size, device)`` and
            ``forward(x, hidden) -> (logits, hidden)`` with logits (batch, seq, vocab).
        dl: iterable of ``(X, Y)`` token-id batches.
        loss_fn: mean-reduced token loss. If it has an ``ignore_index`` attribute,
            those target positions are excluded from the perplexity.
        device: device (or device string) to evaluate on.
        bs, sl: unused, kept for backward compatibility. Batch size and number of
            steps are read from each batch (``X.size(0)``, ``Y.size(1)``).

    Returns:
        Perplexity of the free-running predictions.
    """
    model.eval()
    device = torch.device(device)
    use_amp = device.type == "cuda"  # mixed precision on GPU only, as in training
    ignore_index = getattr(loss_fn, "ignore_index", None)

    total_loss = 0.0
    total_tokens = 0

    for X, Y in dl:
        X = X.to(device)
        Y = Y.to(device, dtype=torch.long)
        hid = model.init_hidden(X.size(0), device)

        inp = X[:, :1]  # the only ground-truth token ever fed
        for t in range(Y.size(1)):
            target = Y[:, t]
            with torch.amp.autocast(device_type=device.type, enabled=use_amp):
                pred, hid = model(inp, hid)
                logits = pred[:, -1, :]
                loss = loss_fn(logits, target)

            if ignore_index is None:
                n_valid = target.numel()
            else:
                n_valid = int((target != ignore_index).sum().item())
            if n_valid > 0:  # an all-ignored step yields a NaN mean; skip it
                total_loss += loss.item() * n_valid
                total_tokens += n_valid

            # Greedy next token is fed back as the next input.
            inp = logits.argmax(dim=-1, keepdim=True)

    return math.exp(total_loss / max(1, total_tokens))

def sample_with_temp(logits, temp=1.0):
    """Sample one token id per row from ``softmax(logits / temp)``.

    Args:
        logits: unnormalized scores, e.g. (batch, vocab).
        temp: sampling temperature; values < 1 sharpen and > 1 flatten the distribution.
            ``temp <= 0`` selects greedy decoding (argmax) instead of dividing by zero.

    Returns:
        Sampled token ids, shape (batch, 1) for 2-D logits.
    """
    if temp <= 0:
        return logits.argmax(dim=-1, keepdim=True)
    probs = (logits / temp).softmax(dim=-1)
    return torch.multinomial(probs, num_samples=1)

def distinct_n_chars(text, n=3):
    """Fraction of character n-grams in ``text`` that are unique (0.0 if none exist)."""
    count = max(0, len(text) - n + 1)
    unique = {text[pos:pos + n] for pos in range(count)}
    return len(unique) / max(1, count)

In [None]:
l_tot = []  # validation accuracy (TF=1), one entry per epoch
bs = batch_size
sl = seq_length
teacher_forcing_ratio = 1

scaler = torch.amp.GradScaler()

for epoch in range(num_epoch):

    # --- Teacher-forcing schedule: 1.0 at epoch 0, then -0.01 per epoch, floored at 0 ---
    # Derived from the epoch index (instead of decrementing before first use) so that
    # epoch 0 really takes the full-teacher-forcing branch below.
    teacher_forcing_ratio = max(0.0, 1.0 - 0.01 * epoch)

    model.train()

    # -------------- TRAIN LOOP --------------
    train_loss = 0.0

    for X, Y in train_dl:
        X = X.to(device1)
        Y = Y.to(device1, dtype=torch.long)
        hid = model.init_hidden(X.size(0), device1)
        opti.zero_grad(set_to_none=True)

        if teacher_forcing_ratio >= 1.0:
            with torch.amp.autocast(device_type="cuda"):
                pred, hid = model(X, hid)
                loss = loss_fn(pred.view(-1, vocab_size), Y.view(-1))
        else:
            # ---- Pass 1: TF=1 forward to get the model's own guesses ----
            with torch.no_grad(), torch.amp.autocast(device_type="cuda"):
                pred_tf, _ = model(X, hid)
            guesses = pred_tf.argmax(dim=-1)
            # pred_tf[:, t] predicts token t+1, so the model's guess for input position t
            # is guesses[:, t-1]. Position 0 has no guess and keeps the ground truth.
            prev_guess = torch.cat([X[:, :1], guesses[:, :-1].to(X.dtype)], dim=1)

            # ---- Keep each ground-truth input with probability teacher_forcing_ratio ----
            mask = torch.rand(X.shape, device=X.device) < teacher_forcing_ratio
            X_mixed = torch.where(mask, X, prev_guess)

            # ---- Pass 2: forward with partial TF ----
            with torch.amp.autocast(device_type="cuda"):
                pred, hid = model(X_mixed, hid)
                loss = loss_fn(pred.view(-1, vocab_size), Y.view(-1))

        scaler.scale(loss).backward()
        # Unscale before clipping: otherwise the 0.5 max-norm is applied to the
        # loss-scaled gradients and clips far too aggressively.
        scaler.unscale_(opti)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        scaler.step(opti)
        scaler.update()

        train_loss += loss.item()

        if epoch < 3:
            sched_warm.step()
        else:
            sched_post.step()

    train_ppl = math.exp(train_loss / nb_step_train)

    # -------------- VALIDATION --------------
    val_ppl_tf1, val_acc = evaluate_tf1(model, test_dl, loss_fn, device1, vocab_size)
    val_ppl_free = evaluate_free(model, test_dl, loss_fn, device1)

    print(
        f"Epoch {epoch} | "
        f"Train PPL: {train_ppl:.3f} | "
        f"Val PPL (TF=1): {val_ppl_tf1:.3f} | "
        f"Val PPL (free): {val_ppl_free:.3f} | "
        f"Val Acc: {val_acc:.3f}"
    )

    # --------- Sample generation + diversity metrics (only when they are shown) ---------
    if epoch % 4 == 0:
        model.eval()
        with torch.no_grad():
            # Prompt: first 20 tokens of the first window of the last training batch.
            prompt = X[0:1, :20]
            hid_gen = model.init_hidden(1, device1)
            inp = prompt

            gen_ids = []
            for _ in range(200):
                pred, hid_gen = model(inp, hid_gen)
                next_tok = sample_with_temp(pred[:, -1, :], temp=0.6)
                gen_ids.append(next_tok.item())
                inp = next_tok  # feed the sampled token back in

        # Decode the continuation in one call: decoding subword tokens one at a time can
        # differ from decoding the sequence (e.g. word boundaries, multi-byte characters).
        gen_text = tokenizer.decode(gen_ids)
        d2 = distinct_n_chars(gen_text, n=2)
        d3 = distinct_n_chars(gen_text, n=3)

        print("\n=== Initial text ===")
        print(tokenizer.decode(prompt.squeeze(0).tolist()))
        print("\n=== Sample Generation ===")
        print(gen_text[:200])
        print(f"Distinct-2: {d2:.3f} | Distinct-3: {d3:.3f}", end="\n")

    l_tot.append(val_acc)
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": opti.state_dict(),
            "scheduler_state_dict": sched_post.state_dict(),
            # Extra state needed to resume during warm-up and with the same AMP loss scale.
            "warmup_scheduler_state_dict": sched_warm.state_dict(),
            "scaler_state_dict": scaler.state_dict(),
            "val_acc": val_acc,
        },
        "model",
    )