# This notebook implements character-level text generation models (LSTM and other implementations)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import get_cosine_schedule_with_warmup


import fasttext
import math
import numpy as np
import pickle
import random
import re

In [None]:
@torch.no_grad()
def evaluate_tf1(model, dl, loss_fn, device, vocab_size):
    """Evaluate with full teacher forcing and return the perplexity.

    Every batch is pushed through the model in a single forward pass using
    the ground-truth inputs, and the per-batch losses are averaged.

    Args:
        model: Module exposing ``init_hidden(batch_size)`` and a forward
            ``model(X, local_context, h)`` returning ``(pred, h)``.
        dl: Iterable yielding ``(X, local_context, Y)`` batches.
        loss_fn: Callable taking ``(logits, targets)`` and returning a scalar
            tensor.
        device: Device (``torch.device`` or string) the batch is moved to.
        vocab_size: Size of the last dimension of ``pred``; used to flatten
            the logits before computing the loss.

    Returns:
        ``exp(mean batch loss)`` as a float, or ``nan`` when ``dl`` yields no
        batch (instead of raising ``ZeroDivisionError``).
    """
    model.eval()
    # Mixed precision is only meaningful on CUDA; elsewhere autocast is
    # switched off explicitly so the function also runs on CPU.
    use_amp = torch.device(device).type == "cuda"
    total_loss = 0.0
    n_batches = 0

    for X, local_context, Y in dl:
        X = X.to(device)
        local_context = local_context.to(device)
        Y = Y.to(device, dtype=torch.long)
        # Size the hidden state from the batch actually received rather than
        # from a hard-coded 1024, so any batch size (including a smaller
        # final batch) works.
        h = model.init_hidden(X.size(0)).to(device)

        with torch.amp.autocast(device_type="cuda", enabled=use_amp):
            pred, _ = model(X, local_context, h)
            loss = loss_fn(pred.view(-1, vocab_size), Y.view(-1))

        total_loss += loss.item()
        n_batches += 1

    if n_batches == 0:
        return float("nan")
    return math.exp(total_loss / n_batches)


@torch.no_grad()
def evaluate_free(model, dl, loss_fn, device):
    """Evaluate in free-running mode and return the perplexity.

    Only the first ground-truth token of each sequence is fed to the model;
    afterwards the greedy (argmax) prediction of step ``t`` becomes the input
    of step ``t + 1``. The loss of every step is accumulated and averaged.

    Args:
        model: Module exposing ``init_hidden(batch_size)`` and a forward
            ``model(inp, context, hid)`` returning ``(pred, hid)``.
        dl: Iterable yielding ``(X, local_context, Y)`` batches.
        loss_fn: Callable taking ``(logits, targets)`` and returning a scalar
            tensor.
        device: Device (``torch.device`` or string) the batch is moved to.

    Returns:
        ``exp(mean step loss)`` as a float, or ``nan`` when no step was
        evaluated.
    """
    model.eval()
    # Mixed precision is only meaningful on CUDA; elsewhere autocast is
    # switched off explicitly so the function also runs on CPU.
    use_amp = torch.device(device).type == "cuda"
    total_loss = 0.0
    n_steps = 0

    for X, local_context, Y in dl:
        X = X.to(device)
        local_context = local_context.to(device)
        Y = Y.to(device, dtype=torch.long)
        bs, sl = X.size(0), X.size(1)
        hid = model.init_hidden(bs).to(device)

        # Start with the first input token
        inp = X[:, :1]  # (bs, 1)
        for t in range(sl):
            # The model receives one token per step, so it must receive the
            # context of that single position; passing the whole context
            # window would not line up with the 1-token input.
            step_context = local_context[:, t:t + 1]
            with torch.amp.autocast(device_type="cuda", enabled=use_amp):
                pred, hid = model(inp, step_context, hid)
                logits = pred[:, -1, :]
                loss = loss_fn(logits, Y[:, t])

            total_loss += loss.item()
            n_steps += 1

            # Greedy next token, fed back as the next input
            inp = logits.argmax(dim=-1, keepdim=True)  # (bs, 1)

    if n_steps == 0:
        return float("nan")
    # Average over the number of accumulated losses, not over the batch size.
    return math.exp(total_loss / n_steps)

def sample_with_temp(logits, temp=1.0):
    """Sample one token id per row from temperature-scaled logits.

    Args:
        logits: Tensor of unnormalised scores whose last dimension indexes
            the candidates (1-D or 2-D, as required by ``torch.multinomial``).
        temp: Temperature dividing the logits before the softmax. ``0`` is
            treated as the greedy limit and returns the argmax instead of
            dividing by zero.

    Returns:
        Integer tensor of sampled indices with a trailing dimension of 1.
    """
    if temp == 0:
        # Dividing by zero would yield inf/nan probabilities and make
        # torch.multinomial fail; the limit of temp -> 0 is a greedy pick.
        return logits.argmax(dim=-1, keepdim=True)
    probs = torch.softmax(logits / temp, dim=-1)
    return torch.multinomial(probs, num_samples=1)

def distinct_n_chars(text, n=3):
    """Return the share of distinct character n-grams among all n-grams of ``text``.

    The result is ``0.0`` when ``text`` is too short to contain an n-gram.
    """
    window_count = len(text) - n + 1
    unique_grams = {text[pos:pos + n] for pos in range(window_count)}
    # Guard the denominator so short inputs do not divide by zero.
    return len(unique_grams) / max(1, window_count)

# Data preparation

In [None]:
# Load the character -> integer id table from the pickled preprocessing output.
# NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
with open("/kaggle/input/rnn-input/encoding_map.pkl", "rb") as f:
    mapping = pickle.load(f)

# The padding symbol receives the current table size as its id.
mapping["PAD"] = len(mapping)

# Inverse table (integer id -> character) used for decoding.
int2char = dict(zip(mapping.values(), mapping.keys()))
print(int2char)

# Number of distinct ids, padding included.
nb_char = len(int2char)

## Creation of the dataset

In [None]:
# Every array is opened as a read-only memory map (mmap_mode="r"), so it is
# paged in from disk on access rather than loaded fully into memory.
# presumably starts/ends hold the character span of each word and
# word_matrix one embedding row per word — verify against the preprocessing.
starts = np.load("/kaggle/input/rnn-input/starts.npy", mmap_mode="r")
ends = np.load("/kaggle/input/rnn-input/ends.npy", mmap_mode="r")
word_matrix = np.load("/kaggle/input/rnn-input/word_matrix.npy", mmap_mode="r")

# Integer-encoded corpus.
dataset_ = np.load("/kaggle/input/rnn-input/corpora_encoded.npy", mmap_mode="r")

# fastText model stored alongside the corpus.
fast_emb = fasttext.load_model("/kaggle/input/rnn-input/fasttext_corpus")

In [None]:
def get_context_windows(word_matrix, starts, ends, numerot_seq, n_context=3):
    """Build, for each character of a sequence, the vectors of the preceding words.

    Words whose ``[starts, ends]`` span overlaps the sequence are selected.
    Each selected word contributes one context block made of the
    ``n_context`` words that precede it *inside the selection* (missing
    predecessors are zero rows at the top), repeated once per character
    position attributed to that word.

    Args:
        word_matrix: 2-D array with one embedding row per word.
        starts: 1-D array of word start positions.
        ends: 1-D array of word end positions, aligned with ``starts``.
        numerot_seq: Non-empty sequence of character positions; only its
            first and last items are read.
        n_context: Number of preceding words kept per position (default 3,
            the value previously hard-coded).

    Returns:
        ``float32`` array of shape ``(n_positions, n_context, emb_dim)`` where
        ``emb_dim = word_matrix.shape[1]``. ``n_positions`` is 0 when no word
        overlaps the sequence.
    """
    x_begin = numerot_seq[0]
    x_end = numerot_seq[-1]
    emb_dim = word_matrix.shape[1]
    empty = np.zeros((0, n_context, emb_dim), dtype=np.float32)

    indices = np.where((starts <= x_end) & (ends >= x_begin))[0]
    if indices.size == 0:
        # No overlapping word: return an empty, correctly shaped array
        # instead of failing on indices[0].
        return empty

    first_idx = indices[0]
    last_idx = indices[-1]
    # Clip the covered span to the sequence boundaries.
    beginning = max(int(starts[first_idx]), int(x_begin))
    # NOTE(review): the last position x_end itself is excluded (end-exclusive
    # range), as in the previous implementation — confirm this is intended.
    ending = min(int(ends[last_idx]), int(x_end))

    blocks = []
    for idx in indices:
        # Previous words, never reaching before the first selected word.
        prev_words = word_matrix[max(first_idx, idx - n_context):idx]
        context = np.zeros((n_context, emb_dim), dtype=np.float32)
        if len(prev_words):
            context[n_context - len(prev_words):] = prev_words

        # The start and end clips are applied independently so that a word
        # that is both first and last is clipped on both sides.
        start_i = beginning if idx == first_idx else int(ends[idx - 1])
        end_i = ending if idx == last_idx else int(ends[idx])

        n_chars = end_i - start_i
        if n_chars > 0:
            blocks.append(np.broadcast_to(context, (n_chars, n_context, emb_dim)))

    if not blocks:
        return empty
    return np.concatenate(blocks, axis=0)

In [None]:
class SongDataset(Dataset):
    """Fixed-length (input, word-context, target) windows cut from encoded texts.

    Every window is materialised in ``__init__`` and stored in
    ``self.samples`` as a tuple ``(x, local_context, y)``:

    * ``x``: ``length_seq`` token ids starting at the window start;
    * ``y``: the same window shifted by one token (next-token targets);
    * ``local_context``: the output of ``get_context_windows`` for the
      window, padded with zeros or truncated to exactly ``length_seq`` rows.

    Args:
        texts: Iterable of ``(text, numerot)`` pairs, where ``text`` is a 1-D
            token tensor and ``numerot`` the matching list of positions
            passed to ``get_context_windows``.
        length_seq: Number of tokens per window.
        stride: Step between two consecutive window starts.
        starts, ends, word_matrix: Forwarded to ``get_context_windows``.
        pad_id: Token id used to right-pad ``x`` and ``y``.
        use_offset: When true, the first window of each text starts at a
            random offset drawn in ``[0, stride)``.
    """

    def __init__(self, texts, length_seq, stride, starts, ends, word_matrix, 
                 pad_id=mapping["PAD"], use_offset=True):
        self.samples = []
        self.length_seq = length_seq
        self.stride = stride
        self.pad_id = pad_id
        self.word_matrix = word_matrix
        self.starts = starts
        self.ends = ends

        for text, numerot in texts:
            n_tokens = len(text)
            offset = torch.randint(0, stride, (1,)).item() if use_offset else 0

            # max(1, ...) guarantees one window at position 0 for texts
            # shorter than a full window (when the offset is 0).
            last_start = max(1, n_tokens - self.length_seq - 1)
            for start in range(offset, last_start, self.stride):
                stop = start + self.length_seq

                x = self._pad_tokens(text[start:stop])
                y = self._pad_tokens(text[start + 1:stop + 1])

                local_context = get_context_windows(
                    self.word_matrix, self.starts, self.ends, numerot[start:stop]
                )
                local_context = self._fit_context(torch.from_numpy(local_context))

                self.samples.append((x, local_context, y))

    def _pad_tokens(self, seq):
        """Right-pad a 1-D token tensor with ``pad_id`` up to ``length_seq``."""
        missing = self.length_seq - len(seq)
        if missing > 0:
            filler = torch.full((missing,), self.pad_id, dtype=seq.dtype)
            seq = torch.cat([seq, filler])
        return seq

    def _fit_context(self, context):
        """Zero-pad or truncate the context so it has exactly ``length_seq`` rows."""
        n_rows = len(context)
        if n_rows > self.length_seq:
            # Samples must all share the same shape to be batched together.
            return context[:self.length_seq]
        if n_rows < self.length_seq:
            filler = torch.zeros(
                (self.length_seq - n_rows, context.shape[1], context.shape[2])
            )
            return torch.cat([context, filler], dim=0)
        return context

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

In [None]:
dataset_ = np.load("/kaggle/input/rnn-input/corpora_encoded.npy","r")


def _split_songs(tokens, start_id=63, end_id=69):
    """Split a flat token stream into ``[token_tensor, positions]`` items.

    A buffer is opened at each ``start_id`` token and flushed when the next
    ``end_id`` token is met; both delimiters are kept in the item.
    ``positions`` lists the index of every kept token in ``tokens``.
    Tokens found outside an open buffer are ignored, and a buffer still open
    at the end of the stream is flushed as a last item.

    presumably 63 / 69 are the start-of-song / end-of-song ids — verify
    against the encoding map.
    """
    songs = []
    current, num = [], []
    in_song = False
    for i, t in enumerate(tokens):
        if t == start_id:
            # A new start marker discards whatever was pending.
            current, num = [t], [i]
            in_song = True
        elif in_song:
            current.append(t)
            num.append(i)
            if t == end_id:
                songs.append([torch.tensor(current), num])
                # Fresh buffers: the stored position list must not keep
                # growing, and the final flush must not re-add this song.
                current, num = [], []
                in_song = False
    if current:
        songs.append([torch.tensor(current), num])
    return songs


result = _split_songs(dataset_)

In [None]:
# 80 % of the items go to training, the remainder to testing.
len_train = int(len(result) * 0.8)
len_test = len(result) - len_train

# Fixed seed so the split is reproducible from one run to the next.
generator = torch.Generator()
generator.manual_seed(42)
train, test = random_split(result, [len_train, len_test], generator=generator)

# Both datasets share the same windowing configuration; the random start
# offset is disabled.
_dataset_kwargs = dict(
    length_seq=256,
    stride=48,
    starts=starts,
    ends=ends,
    word_matrix=word_matrix,
    use_offset=False,
)
train_ds = SongDataset(train, **_dataset_kwargs)
test_ds = SongDataset(test, **_dataset_kwargs)

## Creation of the DataLoaders

In [None]:
batch_size = 1024

# Shared loader options. Shuffling is disabled with the stated intent of
# letting the RNN use previous sequences' data to predict the next one.
# NOTE(review): the training loop re-initialises the hidden state for every
# batch, so confirm that the sample order actually matters.
# drop_last=True discards a final incomplete batch, so every batch holds
# exactly `batch_size` samples.
_loader_kwargs = dict(
    batch_size=batch_size,
    pin_memory=True,
    pin_memory_device="cuda:0",
    shuffle=False,
    drop_last=True,
)
train_dl = DataLoader(train_ds, **_loader_kwargs)
test_dl = DataLoader(test_ds, **_loader_kwargs)

## Models

### Training part

In [None]:
class Char_GRU(nn.Module):
    """Character-level GRU language model with word context and causal self-attention.

    Pipeline of ``forward``:

    1. embed the input tokens and project the flattened word context to
       ``embedding_dim``; concatenate both as GRU input;
    2. run the multi-layer GRU, then LayerNorm and dropout;
    3. apply a causally masked dot-product attention over the GRU outputs
       (queries = outputs, keys = ``Wa(outputs)``, values = outputs);
    4. combine output and attention vector (``tanh``), project back to the
       embedding space and decode with a weight matrix shared with the
       input embedding.

    Args:
        vocab_size: Number of token ids.
        embedding_dim: Size of the token embedding.
        hidden_size: GRU hidden size.
        dropout: Probability of ``self.drop``.
        num_layers: Number of stacked GRU layers.
        padding_idx: Id of the padding token in the embedding (default 71,
            the value previously hard-coded).
        context_words: Number of context words per position (default 3).
        context_dim: Dimension of each context word vector (default 100).
    """

    def __init__(self, vocab_size, embedding_dim, hidden_size, dropout, num_layers,
                 padding_idx=71, context_words=3, context_dim=100):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # NOTE(review): `drop`, `attn_drop` and `proj_drop` are not used in
        # forward(); they are kept so the module's attributes stay unchanged.
        self.drop = nn.Dropout(p=dropout)
        self.ln = nn.LayerNorm(hidden_size)

        self.embed_drop = nn.Dropout(p=0.1)
        self.rnn_drop = nn.Dropout(p=0.2)
        self.attn_drop = nn.Dropout(p=0.1)
        self.proj_drop = nn.Dropout(p=0.1)

        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.hidden_to_emb = nn.Linear(hidden_size, embedding_dim)

        # Output layer tied to the input embedding (same weight tensor).
        self.fc = nn.Linear(embedding_dim, vocab_size, bias=False)
        self.fc.weight = self.embedding.weight

        # Projection of the flattened context: context_words x context_dim values
        self.context_proj = nn.Linear(context_words * context_dim, embedding_dim)

        # The GRU consumes the token embedding concatenated with the projected context.
        self.gru = nn.GRU(
            input_size=embedding_dim * 2,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=0.1
        )

        # Attention
        self.attn_combine = nn.Linear(hidden_size * 2, hidden_size)
        self.Wa = nn.Linear(hidden_size, hidden_size, bias=False)  # learnable attention matrix

    def forward(self, x, context, h):
        """Return ``(logits, h)`` for a batch of token ids.

        Args:
            x: Token ids, ``(batch, seq)``.
            context: Word context whose dimensions after the second one are
                flattened together; their product must equal
                ``context_words * context_dim``.
            h: Initial GRU hidden state, as produced by ``init_hidden``.
        """
        x_embed = self.embed_drop(self.embedding(x))

        # flatten() also accepts non-contiguous inputs, unlike view().
        context_flat = context.flatten(start_dim=2)
        context_proj = self.context_proj(context_flat)

        x_input = torch.cat([x_embed, context_proj], dim=-1)

        out, h = self.gru(x_input, h)
        out = self.rnn_drop(self.ln(out))

        # Scaled dot-product scores between every pair of positions.
        keys = self.Wa(out)
        attn_scores = torch.bmm(out, keys.transpose(1, 2)) / math.sqrt(out.size(-1))

        # Causal mask: position i may only attend to positions <= i.
        seq_len = out.size(1)
        causal = torch.ones(seq_len, seq_len, dtype=torch.bool, device=out.device).tril()
        attn_scores = attn_scores.masked_fill(~causal.unsqueeze(0), float("-inf"))

        attn_weights = F.softmax(attn_scores, dim=-1)
        context_vec = torch.bmm(attn_weights, out)

        combined = torch.cat((out, context_vec), dim=-1)
        combined = torch.tanh(self.attn_combine(combined))  # non-linearity

        emb_space = self.hidden_to_emb(combined)
        return self.fc(emb_space), h

    def init_hidden(self, batch_size):
        """Return a zero hidden state of shape ``(num_layers, batch_size, hidden_size)``."""
        return torch.zeros(self.num_layers, batch_size, self.hidden_size)

In [None]:
# First CUDA device; used for the model and every batch.
device1 = torch.device("cuda", 0)

In [None]:
# Model dimensions.
vocab_size = len(int2char)  # one output class per id of the decoding table
embedding_dim = 128
hidden_size = 512

# Training length.
num_epoch = 200

# Number of batches per pass over each loader.
nb_step_train = len(train_dl)
nb_step_test = len(test_dl)

In [None]:
model = Char_GRU(vocab_size, embedding_dim, hidden_size, dropout = 0.2, num_layers=3).to(device1)
model = torch.compile(model)

# Ignore the padding id the dataset pads with (mapping["PAD"]) rather than a
# hard-coded number, so the loss stays correct if the vocabulary changes.
loss_fn = nn.CrossEntropyLoss(ignore_index=mapping["PAD"])

opti = torch.optim.AdamW(model.parameters(), lr=0.0005, weight_decay=1e-4)

# Cosine decay over the whole run, after a linear warm-up of 4 epochs' worth
# of optimisation steps.
scheduler = get_cosine_schedule_with_warmup(
    opti,
    num_warmup_steps=nb_step_train*4,
    num_training_steps=nb_step_train*num_epoch
)

# Resume from a saved checkpoint. map_location places the tensors on the
# training device regardless of the device they were saved from.
# NOTE(review): the weights are loaded into the compiled model, so the
# checkpoint presumably was saved from a compiled model as well — confirm
# the state-dict keys match.
mod = torch.load("/kaggle/input/gru/pytorch/default/1/Model_GRU.pt", map_location=device1)
model.load_state_dict(mod["model_state_dict"])

In [None]:
# Fresh optimizer (lr = 2.5e-4) and cosine schedule; running this cell
# replaces any optimizer/scheduler created earlier.
opti = torch.optim.AdamW(
    model.parameters(),
    lr=2.5e-4,
    weight_decay=1e-4,
)

# Warm-up lasts 4 epochs' worth of steps; the schedule spans the whole run.
scheduler = get_cosine_schedule_with_warmup(
    opti,
    num_warmup_steps=4 * nb_step_train,
    num_training_steps=num_epoch * nb_step_train,
)

In [None]:
list_offset = []
l_tot = []  # validation perplexity of every epoch
teacher_forcing_ratio = 1
best_val = float("inf")

# Loss scaling for mixed-precision training.
scaler = torch.amp.GradScaler()

for epoch in range(num_epoch):

    model.train()

    # -------------- TRAIN LOOP --------------
    train_loss = 0.0

    for X, local_context, Y in train_dl:
        X = X.to(device1)
        local_context = local_context.to(device1)
        Y = Y.to(device1, dtype=torch.long)
        # Fresh hidden state for every batch, sized from the batch itself
        # instead of the global batch_size.
        h = model.init_hidden(X.size(0)).to(device1)
        opti.zero_grad(set_to_none=True)

        with torch.amp.autocast(device_type="cuda"):
            pred, h = model(X, local_context, h)
            loss = loss_fn(pred.view(-1, vocab_size), Y.view(-1))

        # NOTE: gradient clipping is currently disabled; if it is re-enabled,
        # call scaler.unscale_(opti) first so the norm is computed on
        # unscaled gradients.
        scaler.scale(loss).backward()
        scaler.step(opti)
        scaler.update()

        # book-keeping
        train_loss += loss.detach().item()
        scheduler.step()  # the schedule advances once per batch

    train_ppl = math.exp(train_loss/nb_step_train)

    # -------------- VALIDATION --------------
    val_ppl_tf1 = evaluate_tf1(model, test_dl, loss_fn, device1, vocab_size)

    print(
        f"Epoch {epoch} | "
        f"Train PPL: {train_ppl:.3f} | "
        f"Val PPL (TF=1): {val_ppl_tf1:.3f} | "
    )

    # Record the validation perplexity and checkpoint on improvement
    l_tot.append(val_ppl_tf1)
    if val_ppl_tf1 < best_val :
        best_val = val_ppl_tf1
        torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": opti.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "val_ppl": val_ppl_tf1,
                },
                "model",
        )