## This code implements a Seq2seq model without attention

This code is used on Kaggle

In [None]:
import pickle
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset,DataLoader, random_split
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

## Pipeline comparison

In [None]:
# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Load the vocabulary mapping shipped with the Kaggle dataset.
# Presumably symbol -> integer id (it is later inverted with `.items()` and
# its length is used as the vocabulary size) -- TODO confirm.
# NOTE(review): pickle.load can execute arbitrary code contained in the file;
# only load this file if the dataset is trusted.
with open("/kaggle/input/new-seq2seq/Encoding_map.pkl", "rb") as f:
    mapping = pickle.load(f)

In [None]:
class MMapSeq2Seq(Dataset):
    """Dataset of (context, x, y) token sequences stored in flat ``.npy`` files.

    Each of the three streams is stored as one flat token array plus an
    offset array and a length array; sample ``i`` of a stream is
    ``tokens[offset[i] : offset[i] + length[i]]``.
    """

    def __init__(self, prefix=""):
        """Open the arrays whose file names start with ``prefix``.

        Args:
            prefix: String prepended verbatim to every file name (for a
                directory it must therefore end with a path separator).

        Raises:
            ValueError: If the three length arrays do not contain the same
                number of entries.
        """
        # Token arrays are memory-mapped read-only so they are not loaded
        # into RAM up front.
        self.ctx = np.load(prefix + "context_np.npy", mmap_mode="r")
        self.x = np.load(prefix + "X_np.npy", mmap_mode="r")
        self.y = np.load(prefix + "Y_np.npy", mmap_mode="r")

        # The index arrays are loaded fully in memory.
        self.ctx_off = np.load(prefix + "context_offset.npy")
        self.ctx_len = np.load(prefix + "context_length.npy")
        self.x_off = np.load(prefix + "X_offset.npy")
        self.x_len = np.load(prefix + "X_length.npy")
        self.y_off = np.load(prefix + "Y_offset.npy")
        self.y_len = np.load(prefix + "Y_length.npy")

        # Explicit check instead of `assert`, which disappears under `python -O`.
        n_ctx, n_x, n_y = len(self.ctx_len), len(self.x_len), len(self.y_len)
        if not (n_ctx == n_x == n_y):
            raise ValueError(
                "context/X/Y length arrays disagree on the number of samples: "
                f"{n_ctx}, {n_x}, {n_y}"
            )

    def __len__(self):
        """Return the number of samples (entries of the context length array)."""
        return len(self.ctx_len)

    @staticmethod
    def _slice(tokens, offsets, lengths, i):
        """Return sample ``i`` of one stream as a 1-D int32 CPU tensor."""
        start, size = int(offsets[i]), int(lengths[i])
        # `astype` copies the slice out of the read-only memory map, so the
        # tensor owns ordinary writable memory (this is not a zero-copy view).
        return torch.from_numpy(tokens[start:start + size].astype(np.int32))

    def __getitem__(self, i):
        """Return the ``(ctx, x, y)`` int32 tensors of sample ``i``."""
        ctx = self._slice(self.ctx, self.ctx_off, self.ctx_len, i)
        x = self._slice(self.x, self.x_off, self.x_len, i)
        y = self._slice(self.y, self.y_off, self.y_len, i)
        return ctx, x, y

In [None]:
def collate(batch):
    """Turn a list of ``(ctx, x, y)`` samples into padded batch tensors.

    Every sequence is converted to ``torch.long`` and each stream is padded
    to its own longest sequence with ``PAD_ID`` (batch dimension first).

    Returns:
        ``((ctx_pad, x_pad, y_pad), (ctx_len, x_len))`` where the lengths are
        int64 tensors holding the unpadded sizes of the context and x
        sequences (no lengths are returned for y).
    """
    def _as_long(seqs):
        # No copy is made when a tensor is already int64.
        return [seq.to(torch.long, copy=False) for seq in seqs]

    def _pad(seqs):
        return torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=PAD_ID)

    def _lengths(seqs):
        # Created on the CPU as int64, which is what pack_padded_sequence takes.
        return torch.tensor([seq.numel() for seq in seqs], dtype=torch.int64)

    # Transpose the batch: one group of 1-D tensors per stream.
    ctxs, xs, ys = (_as_long(stream) for stream in zip(*batch))

    padded = (_pad(ctxs), _pad(xs), _pad(ys))
    lengths = (_lengths(ctxs), _lengths(xs))
    return padded, lengths

### Now, pipeline prep

In [None]:
PAD_ID = 0  # token id written into padded positions by collate()
ds = MMapSeq2Seq(prefix="/kaggle/input/new-seq2seq/")  # use separate train/val prefixes if the split is done offline

In [None]:
# 85% of the samples go to training, the remainder to the held-out set.
train_size = int(len(ds) * 0.85)
test_size = len(ds) - train_size

# A fixed seed keeps the split identical from one run to the next.
train_ds, test_ds = random_split(
    ds,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(42),
)

batch_size = 256

In [None]:
# Both loaders batch with `collate` and drop the last incomplete batch.
# Only the training loader shuffles its samples.
if device != "cpu":
    # Accelerator run: worker processes plus memory pinned for `device`.
    train_ld = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        collate_fn=collate,
        num_workers=4,
        persistent_workers=False,
        pin_memory=True,
        pin_memory_device=device,
    )
    test_ld = DataLoader(
        test_ds,
        batch_size=batch_size,
        shuffle=False,
        drop_last=True,
        collate_fn=collate,
        num_workers=2,
        persistent_workers=False,
        pin_memory=True,
        pin_memory_device=device,
    )
else:
    # CPU run: load in the main process, no pinned memory.
    train_ld = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        collate_fn=collate,
    )
    test_ld = DataLoader(
        test_ds,
        batch_size=batch_size,
        shuffle=False,
        drop_last=True,
        collate_fn=collate,
    )

List of operations:

- First, the entire batch passes through the encoder (`Encodeur`), from which we extract the final hidden and cell states (h, c).

- Then we iterate over the length of the target sequence, feeding the previous token as input at each step; the decoder (`Decodeur`) starts from the encoder's (h, c).

- Finally, we compute the loss.

## Create Seq2seq

In [None]:
class Encodeur(nn.Module):
    """LSTM encoder: embeds a padded token batch and returns the final states.

    Args:
        vocab_size: Number of rows of the embedding table.
        embedding_dim: Size of each token embedding.
        hidden_size: Hidden size of the LSTM.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_size, num_layers=1):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.embed = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            # nn.LSTM applies dropout only between stacked layers; with a
            # single layer a non-zero value has no effect and makes PyTorch
            # emit a warning, so it is disabled in that case.
            dropout=0.2 if num_layers > 1 else 0.0,
            batch_first=True,
        )
        self.dropout = nn.Dropout(0.1)

    def forward(self, x, length_batch):
        """Encode a batch and return the final ``(h, c)`` of the LSTM.

        Args:
            x: Padded token ids, batch dimension first.
            length_batch: Unpadded length of every sequence in ``x``; passed
                to ``pack_padded_sequence`` (the sequences need not be sorted).

        Returns:
            Tuple ``(h, c)``, each of shape
            ``(num_layers, batch, hidden_size)``.
        """
        emb = self.dropout(self.embed(x))
        # Packing lets the LSTM skip the padded positions, so the final
        # states correspond to the last real token of each sequence.
        packed = pack_padded_sequence(emb, length_batch, batch_first=True, enforce_sorted=False)
        # Only the final states are needed; the per-step outputs are not
        # unpacked since nothing consumes them.
        _, (h, c) = self.lstm(packed)
        return (h, c)

In [None]:
class Decodeur(nn.Module):
    """LSTM decoder that emits one masked logit vector per time step.

    Args:
        vocab_size: Number of rows of the embedding table and size of the
            output layer.
        embedding_dim: Size of each token embedding.
        hidden_size: Hidden size of the LSTM.
        masked_mapping: Boolean tensor of shape ``(vocab_size,)``; ``True``
            marks token ids whose logits are forced to ``-inf``.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_size, masked_mapping, num_layers=1):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.embed = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            # nn.LSTM applies dropout only between stacked layers; with a
            # single layer a non-zero value has no effect and makes PyTorch
            # emit a warning, so it is disabled in that case.
            dropout=0.2 if num_layers > 1 else 0.0,
            batch_first=True,
        )

        self.final = nn.Linear(hidden_size, vocab_size)
        # Registered as a buffer so the mask follows the module across
        # `.to(device)`; non-persistent so existing checkpoints (which have
        # no "mask" entry) still load.
        self.register_buffer("mask", masked_mapping, persistent=False)

    def forward(self, x, h, c, length_batch, targets=None, teacher_forcing_ratio=1.0):
        """Decode step by step, starting from ``x[:, 0]`` and states ``(h, c)``.

        Args:
            x: Decoder input ids, batch first; only column 0 is read, as the
                first input token.
            h: Initial hidden state, ``(num_layers, batch, hidden_size)``.
            c: Initial cell state, same shape as ``h``.
            length_batch: Unused; kept for interface compatibility.
            targets: Ground-truth ids ``(batch, max_len)``. Their width sets
                the number of decoding steps. When ``None``, the decoder
                runs ``x.size(1)`` steps and always feeds back its own
                prediction.
            teacher_forcing_ratio: Probability, drawn once per step for the
                whole batch, of feeding ``targets[:, t]`` as the next input
                instead of the model's own argmax.

        Returns:
            ``(all_logits, (h, c))`` where ``all_logits`` has shape
            ``(batch, n_steps, vocab_size)`` with masked ids set to ``-inf``.
        """
        n_steps = x.size(1) if targets is None else targets.size(1)

        step_logits = []

        # First input: column 0 of x for the whole batch.
        input_t = x[:, 0]

        for t in range(n_steps):
            emb = self.embed(input_t).unsqueeze(1)  # (batch, 1, emb_dim)
            out, (h, c) = self.lstm(emb, (h, c))    # (batch, 1, hidden)
            logit = self.final(out.squeeze(1))      # (batch, vocab_size)
            masked_logit = logit.masked_fill(self.mask, float("-inf"))
            step_logits.append(masked_logit.unsqueeze(1))

            # Teacher forcing: feed the ground truth, otherwise the prediction.
            if targets is not None and torch.rand(1).item() < teacher_forcing_ratio:
                input_t = targets[:, t]
            else:
                input_t = masked_logit.argmax(dim=-1)

        all_logits = torch.cat(step_logits, dim=1)  # (batch, n_steps, vocab_size)
        return all_logits, (h, c)

In [None]:
vocab_size = len(mapping)
embedding_size = 96
hidden_size = 512
num_epoch = 200

# Some entries of the mapping are only meant for the encoder, so the decoder
# must never produce them: their logits are masked out in the decoder.
# NOTE(review): the slice [117:-1] relies on the insertion order of `mapping`
# and leaves the last id unmasked -- confirm against the vocabulary builder.
mapping_inverse = {i: ch for ch, i in mapping.items()}
masked_mapping = list(mapping_inverse.keys())[117:-1]

mask = torch.zeros(vocab_size, dtype=torch.bool, device=device)
mask[masked_mapping] = True

# Number of batches per epoch for each loader.
nb_step_test = len(test_ld)
nb_step_train = len(train_ld)

enco = Encodeur(vocab_size, embedding_size, hidden_size, num_layers=2).to(device)
deco = Decodeur(vocab_size, embedding_size, hidden_size, mask, num_layers=2).to(device)

loss_fn = nn.CrossEntropyLoss(ignore_index=0)  # padding positions do not contribute to the loss

params = list(enco.parameters()) + list(deco.parameters())
opti = torch.optim.AdamW(params, lr=0.01, betas=(0.9, 0.999), weight_decay=1e-3)

# Construction order matters: every scheduler writes the optimizer's LR when
# it is created. The cosine scheduler is built first (it sets the LR to the
# base value), then the warm-up scheduler scales that down to
# start_factor * base. Built the other way round, the cosine scheduler
# overwrites the warm-up start and LinearLR, whose updates are multiplicative
# on the current LR, then climbs from the base LR to
# (end_factor / start_factor) * base instead of from 0.25 * base to base.
# First cosine cycle spans 200 epochs of steps; each later cycle is twice as
# long (T_mult=2).
sched_post = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opti, T_0=nb_step_train*200, T_mult=2, eta_min=0.001)
# Linear warm-up over the first 5 epochs of steps.
sched_warm = torch.optim.lr_scheduler.LinearLR(opti, start_factor=0.25, end_factor=1.0, total_iters=nb_step_train*5)

## Training loop (encoder + decoder)

In [None]:
def token_accuracy(logits, Y, pad_id=0):
    # logits: [B, T, V], Y: [B, T]
    preds = logits.argmax(dim=-1)                 # [B, T]
    mask  = (Y != pad_id)                         # [B, T]
    correct = (preds == Y) & mask
    return correct.sum().float() / mask.sum().float()

def topk_token_accuracy(logits, Y, k=5, pad_id=0):
    """Return the fraction of non-padding tokens whose target is in the top-k.

    Args:
        logits: Scores of shape ``(B, T, V)``.
        Y: Target ids of shape ``(B, T)``; if longer than ``logits`` along
            the time axis, it is truncated to ``logits.size(1)``.
        k: Number of highest-scoring ids considered per position; capped at
            the vocabulary size ``V``.
        pad_id: Target id to exclude from the count.

    Returns:
        0-dim float tensor; ``0.0`` when ``Y`` contains no non-padding token.
    """
    # Truncate Y in case logits and Y differ slightly in length.
    Y = Y[:, :logits.size(1)]

    # topk raises when k exceeds the last dimension, so cap it.
    k = min(k, logits.size(-1))
    topk = logits.topk(k, dim=-1).indices  # (B, T, k)

    # Padding positions are ignored.
    mask = (Y != pad_id)

    # Is the target among the k best ids at each position?
    correct = (topk == Y.unsqueeze(-1)).any(dim=-1) & mask  # (B, T)

    # clamp avoids 0/0 = NaN on a batch made only of padding.
    return correct.sum().float() / mask.sum().clamp(min=1).float()


In [None]:
# Summed validation loss of every epoch.
l_tot = []
teacher_forcing_ratio = 1
# Early stopping (currently disabled)
#early_stopping_count = 0
#patience = 5
best_val = float("inf")
begin_epoch = 0

# Mixed precision is only used on CUDA. On CPU both autocast and the gradient
# scaler are explicitly disabled instead of being requested for "cuda" (which
# only produced warnings before falling back to full precision).
use_amp = device == "cuda"
scaler = torch.amp.GradScaler(device, enabled=use_amp)

for epoch in range(num_epoch):

    enco.train()
    deco.train()

    # Summed batch losses of this epoch.
    l_train = 0.0
    l_test = 0.0

#    teacher_forcing_ratio = max(0.5, 1.0 - (max(0,epoch-40)) * 0.002)
    if epoch % 10 == 0:
        print(teacher_forcing_ratio)

    for (context, X, Y), (length_cont, length_text) in train_ld:
        X = X.to(device, non_blocking=True)
        Y = Y.to(device, non_blocking=True)
        context = context.to(device, non_blocking=True)
        opti.zero_grad(set_to_none=True)

        # Encoder: reads the whole context of the batch and outputs the final
        # hidden and cell states, which initialise the decoder.
        # NOTE(review): the encoder runs outside autocast here but inside it
        # during evaluation -- confirm this asymmetry is intended.
        h_enco, c_enco = enco(context, length_cont)

        with torch.amp.autocast(device_type=device, enabled=use_amp):
            logits, (h_dec, c_dec) = deco(
                X, h_enco, c_enco, length_text,
                targets=Y, teacher_forcing_ratio=teacher_forcing_ratio,
            )
            loss = loss_fn(logits.reshape(-1, vocab_size), Y.reshape(-1))

        scaler.scale(loss).backward()
        # Gradients must be unscaled before clipping so the threshold applies
        # to their true norm; clipping guards against exploding gradients.
        scaler.unscale_(opti)
        torch.nn.utils.clip_grad_norm_(list(enco.parameters()) + list(deco.parameters()), 0.25)

        scaler.step(opti)
        scaler.update()
        l_train += loss.item()

        # After a warm restart, halve the peak learning rate of the next cycle.
        if sched_post.T_cur == 0 and epoch > 5:
            sched_post.base_lrs[0] = sched_post.base_lrs[0] * 0.5
#            sched_post.eta_min = sched_post.eta_min * 1.25
            print(f"Decrease {sched_post.base_lrs[0]}, {sched_post.eta_min}")

        # Linear warm-up for the first 5 epochs, cosine schedule afterwards.
        step_scheduler = sched_warm if epoch < 5 else sched_post
        step_scheduler.step()

    # Evaluation on the held-out set

    enco.eval()
    deco.eval()

    acc_sum, tok_sum = 0.0, 0
    topk_acc_sum = 0.0

    with torch.inference_mode(), torch.amp.autocast(device_type=device, enabled=use_amp):
        for (context, X, Y), (length_cont, length_text) in test_ld:
            X = X.to(device, non_blocking=True)
            Y = Y.to(device, non_blocking=True)
            context = context.to(device, non_blocking=True)

            h_enco, c_enco = enco(context, length_cont)

            # Evaluation always uses full teacher forcing.
            logits, (h_dec, c_dec) = deco(X, h_enco, c_enco, length_text, targets=Y, teacher_forcing_ratio=1)
            loss = loss_fn(logits.reshape(-1, vocab_size), Y.reshape(-1))
            l_test += loss.item()

            # Top-1 accuracy, accumulated as token counts over the epoch.
            acc_sum += ((logits.argmax(-1) == Y) & (Y != 0)).sum().item()
            tok_sum += (Y != 0).sum().item()

            # Top-k accuracy (k=5), accumulated per batch.
            topk_acc_sum += topk_token_accuracy(logits, Y, k=5).item()

    epoch_token_acc = acc_sum / tok_sum
    epoch_topk_acc = topk_acc_sum / nb_step_test   # mean over batches

    # Lower the teacher-forcing ratio once the best validation perplexity so
    # far (before this epoch is taken into account) is low enough.
    # NOTE(review): the first two tiers set the same value (0.9).
    best_ppl = np.exp(best_val / nb_step_test)
    if best_ppl < 3.5:
        teacher_forcing_ratio = 0.9
    elif best_ppl < 3.75:
        teacher_forcing_ratio = 0.9
    elif best_ppl < 4:
        teacher_forcing_ratio = 0.95

    # Record the validation loss of the epoch.
    l_tot.append(l_test)

    # Checkpoint whenever the validation loss improves.
    if l_test < best_val:
        print(epoch, np.exp(l_test/nb_step_test), epoch_token_acc, epoch_topk_acc, "\n")
        best_val = l_test
#        early_stopping_count = 0
        torch.save({
            "epoch": epoch,
            "encoder_state_dict": enco.state_dict(),
            "decoder_state_dict": deco.state_dict(),
            "optimizer_state_dict": opti.state_dict(),
            "scheduler_state_dict": sched_post.state_dict(),
            "val_loss": l_test,
        }, "model")

#    elif l_test >= best_val :
#        early_stopping_count += 1

#    if early_stopping_count == patience :
#        print("Early Stopping")
#        break

# `list_offset` is not defined in the visible part of this notebook; an
# unconditional print would raise NameError after training has finished, so
# it is only reported when the name actually exists.
if "list_offset" in globals():
    print(f"Liste of offset used : {list_offset}")