# This notebook trains a subword-level text-generation Transformer (multi-head attention + RoPE)

# Data preparation

In [None]:
import math
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, Dataset
from torch.nn.utils.rnn import pad_sequence

from transformers import PreTrainedTokenizerFast, get_cosine_schedule_with_warmup
from torchtune.modules import RotaryPositionalEmbeddings

In [None]:
import os

# The Hugging Face `tokenizers` library reads this flag from the process environment;
# a plain Python variable named TOKENIZERS_PARALLELISM has no effect on it.
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Fast tokenizer trained for this corpus; "<PAD>" is registered as the padding token
# (its id is what collate_fn pads with).
tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="/kaggle/input/subword-rnn-lstm/rap_tokenizer.json",
    pad_token = "<PAD>"
)

## Creation of the dataset

In [None]:
# Split the flat stream of token ids into songs. A song starts at the "α" token and
# ends at the "θ" token; both markers are kept in the song's tensor.
dataset_ = np.load("/kaggle/input/subword-rnn-lstm/encoded.npy", mmap_mode="r")

# Look the marker ids up once instead of twice per token.
song_start_id = tokenizer.convert_tokens_to_ids("α")
song_end_id = tokenizer.convert_tokens_to_ids("θ")

result = []
current = None  # tokens of the song being read; None while between songs

for t in dataset_:
    if t == song_start_id:
        # A new start marker opens a fresh song (an unfinished one is discarded).
        current = [t]
    elif current is not None:
        current.append(t)
        if t == song_end_id:
            result.append(torch.tensor(current))
            # Reset so the finished song is not appended a second time below.
            current = None
    # Tokens outside any song (before the first "α" or between "θ" and the next "α")
    # are skipped.

# Keep a trailing song that was opened but never closed.
if current:
    result.append(torch.tensor(current))

In [None]:
class SongDataset(Dataset):
    """Cut token sequences (songs) into (input, target) windows for next-token prediction.

    Each sample is a pair ``(x, y)`` of 1-D slices of one song where ``y`` is ``x``
    shifted by one token, so ``x[i]`` is trained to predict ``y[i]`` and
    ``len(x) == len(y)`` always holds. Windows start every ``stride`` tokens. A song
    too short for a full window yields one shorter window covering the whole song.

    Args:
        texts: iterable of sliceable 1-D token sequences (tensors in this notebook).
        length_seq: maximum window length.
        stride: distance between consecutive window starts (must be >= 1).
        use_offset: if True, each song's first window starts at a random offset in
            ``[0, stride)`` so windows are not cut on the same grid every time.
    """

    def __init__(self, texts, length_seq, stride, use_offset=True):
        self.samples = []
        self.length_seq = length_seq
        self.stride = stride

        for text in texts:
            L = len(text)
            # Exclusive bound on window starts: a full window needs length_seq + 1
            # tokens (one extra for the shifted target). Short songs still get start 0.
            stop = max(1, L - self.length_seq)

            offset = torch.randint(0, stride, (1,)).item() if use_offset else 0
            # Clamp so a song with fewer valid starts than the offset still yields a window.
            offset = min(offset, stop - 1)

            for start in range(offset, stop, self.stride):
                y = text[start + 1 : start + 1 + self.length_seq]
                if len(y) == 0:
                    continue  # a single-token song has nothing to predict
                # Trim x to y's length: for short songs the raw x slice would be one
                # token longer than y and the pair would be misaligned.
                x = text[start : start + len(y)]
                self.samples.append((x, y))

    def __len__(self):
        """Number of (x, y) windows."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the ``idx``-th ``(x, y)`` window pair."""
        return self.samples[idx]

In [None]:
# Split whole songs 80/20 *before* windowing, so no song contributes windows to both sets.
n_songs = len(result)
len_train = int(0.8 * n_songs)
len_test = n_songs - len_train

# Seeded generator: the same songs land in train/test on every run.
generator = torch.Generator().manual_seed(42)
train, test = random_split(result, [len_train, len_test], generator=generator)

## Creation of the DataLoaders and collate function

In [None]:
def collate_batch(batch, tokenizer):
    """Right-pad a batch of ``(x, y)`` token tensors into two batch tensors.

    Args:
        batch: sequence of ``(x, y)`` pairs of 1-D tensors.
        tokenizer: object exposing ``pad_token_id``; that id fills padded positions.

    Returns:
        A two-element list ``[X_padded, Y_padded]``. Each element is a
        ``(batch, longest)`` tensor, padded independently of the other.
    """
    pad_id = tokenizer.pad_token_id
    inputs, targets = zip(*batch)
    return [
        pad_sequence(seqs, batch_first=True, padding_value=pad_id)
        for seqs in (inputs, targets)
    ]

# Bind the tokenizer now so the DataLoader can call collate_fn(batch) with one argument.
collate_fn = partial(collate_batch, tokenizer=tokenizer)

# functools.partial pre-fills some arguments of a function: the returned callable stores
# tokenizer=tokenizer and supplies it automatically on every call, i.e.
# collate_fn(batch) == collate_batch(batch, tokenizer=tokenizer)

In [None]:
batch_size = 256

# 256-token windows taken every 4 tokens. The stride is smaller than in the RNN
# notebooks: a Transformer makes fewer built-in assumptions about the data, so it is
# shown more (heavily overlapping) windows to learn the same amount.
train_ds = SongDataset(train, length_seq=256, stride = 4, use_offset = True)
# No random offset on validation, so its windows are identical from run to run.
test_ds = SongDataset(test, length_seq=256, stride = 4, use_offset = False)

# Settings shared by both loaders.
# NOTE(review): drop_last=True on the validation loader skips up to batch_size - 1
# windows at every evaluation.
shared_loader_kwargs = dict(
    batch_size=batch_size,
    pin_memory=True,
    pin_memory_device="cuda:0",
    drop_last=True,
    collate_fn=collate_fn,
)
train_dl = DataLoader(train_ds, num_workers=4, prefetch_factor=4, shuffle=True, **shared_loader_kwargs)
test_dl = DataLoader(test_ds, num_workers=2, prefetch_factor=2, shuffle=False, **shared_loader_kwargs)

## Models

### Training part

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings (RoPE) on q and k.

    Args:
        embed_dim: model width; must be divisible by ``nheads``.
        nheads: number of attention heads.
        dropout: probability used for attention-weight dropout (training mode only)
            and for dropout on the output projection.
        max_seq: longest sequence length the RoPE cache is built for.
        bias: whether the q/k/v and output projections have a bias.
    """

    def __init__(self, embed_dim, nheads, dropout, max_seq, bias=True):
        super().__init__()

        self.nheads = nheads
        assert embed_dim % nheads == 0, "Embedding dim is not divisible by nheads"
        self.head_dim = embed_dim // nheads
        self.drop = nn.Dropout(dropout)
        self.prob_drop = dropout

        # One projection produces q, k and v side by side.
        self.proj_qkv = nn.Linear(embed_dim, embed_dim * 3, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.rope = RotaryPositionalEmbeddings(dim=self.head_dim, max_seq_len = max_seq)

    def forward(self, x: torch.Tensor, attn_mask=None, is_causal=True) -> torch.Tensor:
        """Attend over ``x`` of shape (B, L, embed_dim); returns the same shape.

        ``attn_mask`` and ``is_causal`` are forwarded to
        ``F.scaled_dot_product_attention``.
        """
        B, L, _ = x.shape
        q, k, v = torch.chunk(self.proj_qkv(x), 3, dim=-1)

        # (B, L, embed_dim) -> (B, L, nheads, head_dim)
        q, k, v = [t.reshape(B, L, self.nheads, self.head_dim) for t in (q, k, v)]

        # RoPE is applied to queries and keys in (B, L, nheads, head_dim) layout.
        q = self.rope(q)
        k = self.rope(k)

        # (B, nheads, L, head_dim), the layout scaled_dot_product_attention expects
        q, k, v = [t.transpose(1, 2) for t in (q, k, v)]

        # SDPA applies dropout_p whenever it is > 0, regardless of train/eval mode,
        # so it must be gated on self.training explicitly.
        attn_output = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.prob_drop if self.training else 0.0,
            attn_mask=attn_mask,
            is_causal=is_causal,
        )
        # Back to (B, L, embed_dim)
        attn_output = attn_output.transpose(1, 2).flatten(-2)

        return self.drop(self.out_proj(attn_output))

In [None]:
class TransformerBlock(nn.Module):
    """Pre-norm decoder block: causal self-attention then a GELU feed-forward, both residual.

    Args:
        d_model: model width.
        n_heads: number of attention heads.
        dropout: dropout probability inside attention and after the feed-forward.
        max_seq: longest sequence the attention's RoPE cache supports (defaults to the
            previously hard-coded 256).
    """

    def __init__(self, d_model, n_heads, dropout=0.15, max_seq=256):
        super().__init__()
        self.d_ff = d_model * 4
        # NOTE(review): one RMSNorm (a single learned gain) is shared by both
        # sub-layers; standard pre-norm blocks use a separate norm per sub-layer.
        # Left as-is so saved state_dicts keep the same keys.
        self.norm = nn.RMSNorm(d_model)
        self.attn = MultiHeadAttention(d_model, n_heads, dropout, max_seq=max_seq)
        self.ff = nn.Sequential(
            nn.Linear(d_model, self.d_ff),
            nn.GELU(approximate="tanh"),
            nn.Linear(self.d_ff, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x, mask=None):
        """Apply the block to ``x``.

        ``mask`` is accepted for interface compatibility but unused; attention is
        always causal.
        """
        # Causal self-attention: token t only attends to tokens <= t.
        x = x + self.attn(self.norm(x), is_causal=True)
        # Feed-forward
        x = x + self.ff(self.norm(x))
        return x


class DecoderOnlyTransformer(nn.Module):
    """Decoder-only language model: token embedding, a stack of TransformerBlocks, tied output head.

    Args:
        vocab_size: number of token ids (embedding rows and output logits).
        d_model: model width.
        n_heads: attention heads per block.
        n_layers: number of blocks.
        max_len: longest supported sequence; sizes each block's RoPE cache.
    """

    def __init__(self, vocab_size, d_model=256, n_heads=4, n_layers=4, max_len=256):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d_model)
        # NOTE(review): not used in forward(); positions come from RoPE inside attention.
        # Kept so state_dicts saved from this class (which contain pos.weight) still load.
        self.pos = nn.Embedding(max_len, d_model)
        self.layers = nn.ModuleList(
            [TransformerBlock(d_model, n_heads, max_seq=max_len) for _ in range(n_layers)]
        )

        self.drop = nn.Dropout(0.1)
        self.norm = nn.RMSNorm(d_model)

        # Output projection shares its weight matrix with the input embedding.
        self.fc = nn.Linear(d_model, vocab_size, bias = False)
        self.fc.weight = self.emb.weight

    def forward(self, x):
        """Map token ids ``x`` of shape (batch, seq) to logits of shape (batch, seq, vocab_size)."""
        x = self.drop(self.emb(x))
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        return self.fc(x)

In [None]:
# Every cell below assumes a single CUDA GPU, index 0.
device = torch.device("cuda", 0)

In [None]:
# NOTE(review): must cover every token id, including tokenizer.pad_token_id — confirm
# against the tokenizer's vocabulary size.
vocab_size = 4000
num_epoch = 100

nb_step_train = len(train_dl)
nb_step_test = len(test_dl)

total_steps = nb_step_train * num_epoch

model = DecoderOnlyTransformer(vocab_size, d_model=320, n_heads=5, n_layers=6).to(device)
model = torch.compile(model)

# collate_fn pads with tokenizer.pad_token_id, so that is the id the losses must ignore
# (a hard-coded 0 would count padding and ignore a real token if the pad id differs).
pad_id = tokenizer.pad_token_id
loss_fn = nn.CrossEntropyLoss(ignore_index=pad_id, label_smoothing=0.05)
val_fn = nn.CrossEntropyLoss(ignore_index=pad_id)

# NOTE(review): weight decay is applied to every parameter, including RMSNorm gains
# and the tied embedding/output matrix.
opti = torch.optim.AdamW(model.parameters(), lr=0.0004, weight_decay=2e-3)

# Linear warm-up over the first 5% of steps, then cosine decay to total_steps.
scheduler = get_cosine_schedule_with_warmup(
    opti,
    num_warmup_steps=int(0.05 * total_steps),
    num_training_steps=total_steps
)

In [None]:
#ckpt = torch.load("/kaggle/input/warmed-up-sub-mh1/pytorch/default/3/MHA_model_warmed_up.pt", map_location=device)
#model.load_state_dict(ckpt["model_state_dict"])

#last_epoch = ckpt["epoch"]

In [None]:
def sample_with_temp(logits, temp=1.0):
    """Sample one token id per row from ``logits`` with temperature ``temp``.

    Args:
        logits: (B, V) (or (V,)) unnormalized scores.
        temp: softmax temperature; lower is more conservative. ``temp <= 0`` means
            greedy decoding (argmax) instead of dividing by zero.

    Returns:
        Integer tensor of shape (B, 1) (or (1,)) with the chosen token ids.
    """
    if temp <= 0:
        return logits.argmax(dim=-1, keepdim=True)
    probs = (logits / temp).softmax(dim=-1)
    return torch.multinomial(probs, num_samples=1)

def distinct_n_chars(text, n=3):
    """Return the share of distinct character n-grams in ``text`` (Distinct-n).

    The result is the number of unique length-``n`` substrings divided by the total
    number of length-``n`` substrings. Returns 0.0 when ``text`` is shorter than ``n``.
    """
    total = len(text) - n + 1
    unique = {text[i:i + n] for i in range(total)}
    return len(unique) / max(1, total)

In [None]:
def training(model, train_dl, loss_fn, vocab_size, opti, scaler, scheduler, device) :
    """Run one epoch of mixed-precision training.

    Args:
        model: module mapping (B, L) token ids to (B, L, vocab_size) logits.
        train_dl: iterable of (X, Y) batches.
        loss_fn: loss taking flattened logits and flattened targets.
        vocab_size: size of the logits' last dimension.
        opti: optimizer over ``model``'s parameters.
        scaler: AMP gradient scaler.
        scheduler: LR scheduler, stepped once per batch.
        device: device the batches are moved to.

    Returns:
        ``(train_loss, scaler)``. ``train_loss`` is the SUM of per-batch losses;
        divide by the number of batches for the mean.
    """
    model.train()
    train_loss = 0.0

    for X, Y in train_dl:
        X = X.to(device, non_blocking=True)
        Y = Y.to(device, dtype=torch.long, non_blocking=True)
        opti.zero_grad(set_to_none=True)

        with torch.amp.autocast(device_type="cuda"):
            pred = model(X)
            loss = loss_fn(pred.reshape(-1, vocab_size), Y.reshape(-1))

        scaler.scale(loss).backward()
        # After backward the gradients are still multiplied by the loss scale; unscale
        # them first so the clipping threshold applies to the true gradient norm.
        scaler.unscale_(opti)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
        scaler.step(opti)
        scaler.update()
        scheduler.step()

        train_loss += loss.item()

    return train_loss, scaler

def evaluate(model, eval_dl, loss_fn, vocab_size, device) :
    """Compute the loss of ``model`` over ``eval_dl`` in eval mode, without gradients.

    Args:
        model: module mapping (B, L) token ids to (B, L, vocab_size) logits.
        eval_dl: iterable of (X, Y) batches.
        loss_fn: loss taking flattened logits and flattened targets.
        vocab_size: size of the logits' last dimension.
        device: device the batches are moved to.

    Returns:
        The SUM of per-batch losses (divide by the number of batches for the mean).
    """
    model.eval()
    eval_loss = 0.0

    # Disable autograd here so no graphs are built even if the caller did not.
    with torch.no_grad():
        for X, Y in eval_dl:
            X = X.to(device, non_blocking=True)
            Y = Y.to(device, dtype=torch.long, non_blocking=True)

            pred = model(X)
            loss = loss_fn(pred.reshape(-1, vocab_size), Y.reshape(-1))

            eval_loss += loss.item()

    return eval_loss

In [None]:
list_offset = []  # NOTE(review): not used in this loop; looks left over from the RNN notebook
l_tot = []  # validation perplexity of every epoch
teacher_forcing_ratio = 1  # NOTE(review): not used by this Transformer training loop
best_val = float("inf")

scaler = torch.amp.GradScaler()

# Settings for the periodic qualitative generation check.
prompt_len = 20  # tokens of real text used as the prompt
gen_len = 200  # tokens to generate
max_context = 256  # must not exceed the model's RoPE cache size (max_seq=256)

for epoch in range(num_epoch):

    # -------------- TRAIN LOOP --------------
    train_loss, scaler = training(model, train_dl, loss_fn, vocab_size, opti, scaler, scheduler, device)
    # training() returns a sum over batches. This loss includes label smoothing, so
    # Train PPL is not directly comparable with Val PPL.
    train_ppl = math.exp(train_loss / nb_step_train)

    # -------------- VALIDATION --------------
    with torch.no_grad():
        val_loss = evaluate(model, test_dl, val_fn, vocab_size, device)
        val_ppl = math.exp(val_loss / nb_step_test)

    print(
        f"Epoch {epoch} | "
        f"Train PPL: {train_ppl:.3f} | "
        f"Val PPL : {val_ppl:.3f}"
    )

    # --------- Sample generation + diversity metrics ---------
    if epoch % 5 == 0 and epoch != 0:

        model.eval()
        with torch.no_grad():
            # Prime with the first prompt_len tokens of the first validation window.
            X, _ = next(iter(test_dl))
            start = X[0:1, :prompt_len].to(device)

            # Autoregressive sampling, one token at a time.
            context = start
            for _ in range(gen_len):
                # Feed at most max_context tokens so positions stay within the RoPE range.
                logits = model(context[:, -max_context:])
                next_token = sample_with_temp(logits[:, -1, :], temp=0.6)
                context = torch.cat([context, next_token], dim=1)

            gen_text = tokenizer.decode(context[0, prompt_len:].tolist())

        d2 = distinct_n_chars(gen_text, n=2)
        d3 = distinct_n_chars(gen_text, n=3)

        print("\n=== Initial text ===")
        print(tokenizer.decode(X[0, :prompt_len].tolist()))
        print("\n=== Sample Generation ===")
        print(gen_text[:200])
        print(f"Distinct-2: {d2:.3f} | Distinct-3: {d3:.3f}", end="\n")

    # Track validation perplexity and checkpoint the best epoch so far.
    l_tot.append(val_ppl)
    if val_ppl < best_val:
        best_val = val_ppl
        # NOTE(review): `model` is torch.compile'd, so its state_dict keys likely carry
        # an "_orig_mod." prefix. Load into a compiled model or strip the prefix.
        torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": opti.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "val_ppl": val_ppl,
                },
                "model",
        )