## This code will implement the training of the Seq2seq light attention model

This code is used on Colab

In [None]:
import gc
import pickle
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

## Pipeline comparaison

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
with open("/content/Encoding_map.pkl", "rb") as f:
    mapping = pickle.load(f)

In [5]:
class MMapSeq2Seq(Dataset):
    def __init__(self, prefix=""):
        self.ctx = np.load(prefix+"context_np.npy","r")
        self.x   = np.load(prefix+"X_np.npy","r")
        self.y   = np.load(prefix+"Y_np.npy","r")

        self.ctx_off = np.load(prefix+"context_offset.npy")
        self.ctx_len = np.load(prefix+"context_length.npy")
        self.x_off   = np.load(prefix+"X_offset.npy")
        self.x_len   = np.load(prefix+"X_length.npy")
        self.y_off   = np.load(prefix+"Y_offset.npy")
        self.y_len   = np.load(prefix+"Y_length.npy")

        assert len(self.ctx_len) == len(self.x_len) == len(self.y_len)

    def __len__(self):
        return len(self.ctx_len)

    def __getitem__(self, i):
        # vues memmap -> torch (zéro copie CPU)
        s, L = int(self.ctx_off[i]), int(self.ctx_len[i])
        ctx = torch.from_numpy(self.ctx[s:s+L].astype(np.int32))

        s, L = int(self.x_off[i]), int(self.x_len[i])
        x = torch.from_numpy(self.x[s:s+L].astype(np.int32))

        s, L = int(self.y_off[i]), int(self.y_len[i])
        y = torch.from_numpy(self.y[s:s+L].astype(np.int32))

        return ctx, x, y

In [6]:
def collate(batch):
    ctxs, xs, ys = zip(*batch)  # tuples de 1D tensors CPU
    ctxs = [t.to(torch.long, copy=False) for t in ctxs]
    xs   = [t.to(torch.long, copy=False) for t in xs]
    ys   = [t.to(torch.long, copy=False) for t in ys]

    ctx_pad = torch.nn.utils.rnn.pad_sequence(ctxs, batch_first=True, padding_value=PAD_ID)
    x_pad   = torch.nn.utils.rnn.pad_sequence(xs,   batch_first=True, padding_value=PAD_ID)
    y_pad   = torch.nn.utils.rnn.pad_sequence(ys,   batch_first=True, padding_value=PAD_ID)

    ctx_len = torch.tensor([t.numel() for t in ctxs], dtype=torch.int64)  # CPU Long (pack_padded)
    x_len   = torch.tensor([t.numel() for t in xs],   dtype=torch.int64)

    return (ctx_pad, x_pad, y_pad), (ctx_len, x_len)

### Now, pipeline prep

In [7]:
PAD_ID = 0
ds = MMapSeq2Seq(prefix="/content/")  # mets des préfixes train/val si tu split offline

In [8]:
train_size = int(0.85 * len(ds))
test_size   = len(ds) - train_size
train_ds, test_ds = random_split(ds, [train_size, test_size], torch.Generator().manual_seed(42))

batch_size = 16

In [9]:
len_max = 0
for i in ds :
    if len(i[0]) > len_max :
        len_max = len(i[0])
len_max

961

In [10]:
if device == "cpu" :
    train_ld = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True, collate_fn=collate) #Shuffle False because we need the RNN to use previous sequences data to predict next one
    test_ld = DataLoader(test_ds, batch_size=batch_size, shuffle=False, drop_last=True, collate_fn=collate)
else :
    train_ld = DataLoader(train_ds, batch_size=batch_size, pin_memory=True, pin_memory_device=device,
                        num_workers=4, shuffle=True, drop_last=True, persistent_workers=False, collate_fn=collate) #Shuffle False because we need the RNN to use previous sequences data to predict next one
    test_ld = DataLoader(test_ds, batch_size=batch_size, pin_memory=True, pin_memory_device=device,
                        num_workers=2, shuffle=False, drop_last=True, persistent_workers=False, collate_fn=collate)

List of operations :

- First the entire batch pass through the Encodeur, we extract the h,c from this

- Then we iterate over all the length of the sequence, with the previous input and we use at first the h,c from the Encodeur

- Then we calculate the loss

## Create Seq2seq

In [11]:
class Encodeur(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size, num_layers=1):
        super(Encodeur, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.embed = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(input_size = embedding_dim,
                            hidden_size = hidden_size,
                            num_layers = num_layers,
                            dropout = 0.2,
                            batch_first=True)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x, length_batch):

        emb = self.embed(x)
        emb_drop = self.dropout(emb)

        packed_x = pack_padded_sequence(emb_drop, length_batch, batch_first=True, enforce_sorted=False)
        packed_out, (h,c) = self.lstm(packed_x)
        output, length_out = pad_packed_sequence(packed_out, batch_first=True, total_length=x.size(1))
        output = self.dropout(output)

        return output, length_out, h, c

class BahdanauAttention(nn.Module):
    def __init__(self, hidden_size):
        super(BahdanauAttention, self).__init__()
        self.Wa = nn.Linear(hidden_size, hidden_size)
        self.Ua = nn.Linear(hidden_size, hidden_size)
        self.Va = nn.Linear(hidden_size, 1)
        
    def forward(self, query, keys, mask):
        #query (B,H)
        #keys/context_out (B,S,H) where S represent each unit time

        q = self.Wa(query).unsqueeze(1)            # (B,1,H)
        k = self.Ua(keys)                           # (B,S,H)
        scores = self.Va(torch.tanh(q + k))  # (B,S,1)
        scores = scores.squeeze(-1) # (B,S)

        scores = scores.masked_fill(mask == 0, float('-inf'))

        weights = F.softmax(scores, dim=-1)         # (B,S)
        context = torch.bmm(weights.unsqueeze(1), keys)  # (B,1,H)

        return context #, weights free some memory

In [12]:
class DecoderAttention(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size, masked_mapping, num_layers=1):
        super(DecoderAttention, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.mask = masked_mapping

        self.embed = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(input_size=embedding_dim + hidden_size,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            dropout=0.2,
                            batch_first=True)

        self.final = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.1)

    def forward_step(self, x, h, c, encod_out, mask_att):
        """
        x : (B,)          token ids
        h, c : (num_layers, B, H)
        encod_out : (B, L_enc, H)
        mask_att : (B, L_enc) attention mask
        """
        B, L_enc, H = encod_out.shape
        embedded = self.dropout(self.embed(x))  # (B,E)

        # ---- Luong dot-product attention ----
        # query = dernier état caché de la dernière couche
        query = h[-1]  # (B,H)

        # scores = produit scalaire (B,L_enc)
        scores = torch.bmm(encod_out, query.unsqueeze(2)).squeeze(2)

        # masque sur les PAD
        if mask_att is not None:
            scores = scores.masked_fill(mask_att == 0, float('-inf'))

        # softmax pour obtenir les poids
        weights = F.softmax(scores, dim=-1)  # (B,L_enc)

        # contexte pondéré
        h_prime = torch.bmm(weights.unsqueeze(1), encod_out).squeeze(1)  # (B,H)

        # ---- LSTM step ----
        lstm_in = torch.cat([embedded, h_prime], dim=-1).unsqueeze(1)  # (B,1,E+H)
        out, (h, c) = self.lstm(lstm_in, (h, c))

        logit = self.final(out.squeeze(1))  # (B, vocab_size)
        masked_logit = logit.masked_fill(self.mask, float("-inf"))

        return masked_logit, (h, c)

    def forward(self, x, encod_out, h, c, targets, teacher_forcing_ratio, mask_att, loss_fn=None):
      batch_size, max_len = targets.size()
      input_x = x[:, 0]  # <SOS>

      all_logits = []

      for t in range(max_len):
        out, (h, c) = self.forward_step(input_x, h, c, encod_out, mask_att)
        all_logits.append(out.unsqueeze(1))   # (B,1,V)

        # Teacher forcing
        if torch.rand(1).item() < teacher_forcing_ratio:
            input_x = targets[:, t]
        else:
            input_x = out.argmax(dim=-1)

      all_logits = torch.cat(all_logits, dim=1)  # (B, T, V)

      if loss_fn is None:
          return all_logits
      else:
        # reshape for CrossEntropyLoss
          loss = loss_fn(
              all_logits.reshape(-1, all_logits.size(-1)),
              targets.reshape(-1)
          )
          return loss

class DecoderAttention(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size, masked_mapping, len_max, num_layers=1):
        super(DecoderAttention, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.mask = masked_mapping
        
        self.embed = nn.Embedding(vocab_size, embedding_dim)
#        self.attention = BahdanauAttention(hidden_size)
        self.lstm = nn.LSTM(input_size = embedding_dim + hidden_size,
                            hidden_size = hidden_size,
                            num_layers = num_layers,
                            dropout = 0.2,
                            batch_first=True)
        
        self.final = nn.Linear(hidden_size,vocab_size)
        self.att_lin = nn.Linear(embedding_dim, len_max)
        self.dropout = nn.Dropout(0.1)

    def forward_step(self,x, h, c, encod_out,mask_att) :
        
        global L_enc
        _,L_enc,_ = encod_out.shape
        embedded = self.dropout(self.embed(x)) #(B,E)

        #Synthesizer part
        #Now Dense synthesizer from https://proceedings.mlr.press/v139/tay21a/tay21a.pdf
        #and https://iopscience.iop.org/article/10.1088/1742-6596/2580/1/012006/pdf
        
        emb_lin = self.att_lin(embedded)[:,:L_enc]
        emb_lin = emb_lin.masked_fill(mask_att == 0, float('-inf'))

        weights = F.softmax(emb_lin, dim=-1)
        h_prime = torch.bmm(weights.unsqueeze(1),encod_out).squeeze(1)

        #Use of attention by using the latest hidden
#        query = h[-1]
#        attn_context = self.attention(query,encod_out,mask_att)

        lstm_in = torch.cat([embedded, h_prime], dim=-1).unsqueeze(1) #(B,1,E+H)
        out, (h,c) = self.lstm(lstm_in, (h, c))
        
        logit = self.final(out.squeeze(1))      # (B, vocab_size)
        masked_logit = logit.masked_fill(self.mask, float("-inf"))
        
        return masked_logit, h

    def forward(self, x, encod_out, h, c, targets, teacher_forcing_ratio, mask_att, loss_fn=None): #On each batch

        batch_size, max_len = targets.size()
        input_x = x[:, 0]
        
        if loss_fn is None :
            all_logits = []
        total_loss = 0.0

        for t in range(0, max_len):
            out, h = self.forward_step(input_x, h, c, encod_out,mask_att)
            if loss_fn is None :
                all_logits.append(out.unsqueeze(1))
            else :
                total_loss += loss_fn(out, targets[:, t])

            if torch.rand(1).item() < teacher_forcing_ratio:
                input_x = targets[:, t]   # vérité
            else:
                input_x = out.argmax(dim=-1)  # prédiction

            del out

        if loss_fn is None :
            all_logits = torch.cat(all_logits, dim=1)
            return all_logits
        else :
            total_loss = total_loss / max_len  
            return total_loss

In [13]:
vocab_size = len(mapping)
embedding_size = 96
hidden_size = 512
num_epoch = 100
#accum_steps = 16
nb_step_test = len(test_ld)
nb_step_train = len(train_ld)
#nb_update_per_epoch = nb_step_train // accum_steps

#Some of the mapping are only for the encodeur so the decodeur can't produce them, we need to mask them from the loss
mapping_inverse = {i: ch for ch, i in mapping.items()}
masked_mapping = list(mapping_inverse.keys())[116:-1]

mask = torch.zeros(vocab_size, dtype=torch.bool, device=device)
mask[masked_mapping] = True

enco = Encodeur(vocab_size, embedding_size, hidden_size, num_layers=2).to(device)
deco = DecoderAttention(vocab_size, embedding_size, hidden_size, mask,num_layers=2).to(device)

loss_fn = nn.CrossEntropyLoss(ignore_index=0) #Just ignore the padding

params = list(enco.parameters()) + list(deco.parameters())

opti = torch.optim.AdamW(params, lr=0.002, betas=(0.9,0.999), weight_decay=1e-4)

sched_warm = torch.optim.lr_scheduler.LinearLR(opti, start_factor=0.2, end_factor=1.0, total_iters=nb_step_train*3)
sched_post = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opti, T_0=nb_step_train*40, T_mult=2, eta_min=0.0005) #1 epoch => 2 => 4 => 8

Recharge to previous level

checkpoint = torch.load("/content/model1", map_location=device,weights_only=False)

enco.load_state_dict(checkpoint["encoder_state_dict"])
deco.load_state_dict(checkpoint["decoder_state_dict"])
opti.load_state_dict(checkpoint["optimizer_state_dict"])
sched_post.load_state_dict(checkpoint["scheduler_state_dict"])
start_epoch = checkpoint["epoch"] + 1
best_val = checkpoint["val_loss"]

## Creation boucle entraînement Decodeur

In [14]:
def token_accuracy(logits, Y, pad_id=0):
    # logits: [B, T, V], Y: [B, T]
    preds = logits.argmax(dim=-1)                 # [B, T]
    mask  = (Y != pad_id)                         # [B, T]
    correct = (preds == Y) & mask
    return correct.sum().float() / mask.sum().float()

def topk_token_accuracy(logits, Y, k=5, pad_id=0):
    """
    logits : (B, T, V)
    Y      : (B, T)
    """
    # tronquer Y si jamais logits et Y diffèrent un peu
    Y = Y[:, :logits.size(1)]

    # top-k indices
    topk = logits.topk(k, dim=-1).indices  # (B, T, k)

    # mask pour ignorer les PAD
    mask = (Y != pad_id)

    # comparaison : Y est-il dans top-k ?
    correct = (topk == Y.unsqueeze(-1)).any(dim=-1) & mask  # (B, T)

    return correct.sum().float() / mask.sum().float()


In [None]:
l_tot = []
teacher_forcing_ratio = 1
#Early stopping
#early_stopping_count = 0
#patience = 5
best_val = float("inf")
begin_epoch = 0
scaler = torch.amp.GradScaler()
#accum_steps = 16   # number of mini-batches to accumulate
#global_step = 0

for epoch in range(num_epoch) :
    enco.train(); deco.train()

    l_train = 0.0
    l_test = 0.0

    for (context, X, Y), (length_cont,length_text) in iter(train_ld):
        X = X.to(device, non_blocking=True)
        Y = Y.to(device, non_blocking=True)
        context = context.to(device, non_blocking=True)

        # pas de zero_grad ici → on le fait seulement après accum_steps
        encod_out, length_out, h_enco, c_enco = enco(context,length_cont)

        arange = torch.arange(context.shape[1], device=length_out.device).unsqueeze(0)
        mask_attn = (arange < length_out.unsqueeze(1)).to(device)

        with torch.amp.autocast(device_type="cuda"):
            loss = deco(X, encod_out, h_enco, c_enco, Y, teacher_forcing_ratio , mask_attn, loss_fn=loss_fn)
            #loss = loss_fn(logits.reshape(-1, vocab_size), Y.reshape(-1))

        scaler.scale(loss).backward()
        scaler.unscale_(opti)
        torch.nn.utils.clip_grad_norm_(list(enco.parameters()) + list(deco.parameters()), 1)

        scaler.step(opti)
        scaler.update()

        step_scheduler = sched_warm if epoch < 3 else sched_post
        step_scheduler.step()

        del loss, X, Y, context
        torch.cuda.empty_cache()
        gc.collect()

        # scheduler
   # if epoch in [20, 40, 80]:
   #     sched_post.base_lrs[0] *= 0.9
   #     print(f"Decrease {sched_post.base_lrs[0]}, {sched_post.eta_min}")

#        step_scheduler = sched_warm if epoch < 2 else sched_post
#        step_scheduler.step()

    #Test data part

    enco.eval()
    deco.eval()

    acc_sum, tok_sum = 0.0, 0
    topk_acc_sum = 0.0

    with torch.inference_mode():
        with torch.amp.autocast("cuda"):
            for (context, X, Y), (length_cont,length_text) in iter(test_ld) :
                X = X.to(device, non_blocking = True)
                Y = Y.to(device, non_blocking = True)
                context = context.to(device, non_blocking = True)

                encod_out, length_out, h_enco, c_enco = enco(context,length_cont)
                arange = torch.arange(context.shape[1], device=length_out.device).unsqueeze(0)  # (1,S)
                mask_attn = arange < length_out.unsqueeze(1)
                mask_attn = mask_attn.to(device)

                logits = deco(X, encod_out, h_enco, c_enco, Y, 1, mask_attn)
                loss = loss_fn(logits.reshape(-1, vocab_size), Y.reshape(-1))

                #logits, (h_dec, c_dec) = deco(X,h_dec,c_dec,length_text)
                #loss = loss_fn(logits.reshape(-1,vocab_size),Y.reshape(-1))
                l_test += loss.item()
                acc = token_accuracy(logits, Y).item()

                # top-1 accuracy
                acc_sum += ((logits.argmax(-1) == Y) & (Y != 0)).sum().item()
                tok_sum += (Y != 0).sum().item()

                # top-k accuracy (par ex. k=5)
                topk_acc_sum += topk_token_accuracy(logits, Y, k=5).item()

        epoch_token_acc = acc_sum / tok_sum
        epoch_topk_acc  = topk_acc_sum / nb_step_test   # moyenne sur les batches
        print("\n", epoch, np.exp(l_test/nb_step_test), epoch_token_acc, epoch_topk_acc, "\n")
        print("".join([mapping_inverse[j.item()] for j in Y[0][:length_text[0].item()]]))
        print("".join([mapping_inverse[j.argmax().item()] for j in logits[0][:length_text[0].item()]]))

        #process = psutil.Process(os.getpid())
        #print("RAM used:", process.memory_info().rss / 1024**3, "GB")

        del logits, loss, X, Y, context
        torch.cuda.empty_cache()
        gc.collect()

        #Record the loss of the epoch
        l_tot.append(l_test);

        if l_test < best_val :
            best_val = l_test
#            early_stopping_count = 0
            torch.save({
                "epoch": epoch,
                "encoder_state_dict": enco.state_dict(),
                "decoder_state_dict": deco.state_dict(),
                "optimizer_state_dict": opti.state_dict(),
                "scheduler_state_dict": sched_post.state_dict(),
                "val_loss": l_test,
            }, "model1.pt")

print(f"Liste of offset used : {list_offset}")


 0 6.114624229603374 0.457869417886966 0.8001526954733296 

Maritime seine c'est la qu'ma team sème<EOL>Ses graines pour qu'un jour elles germent par dizainesEND
Joiseeoesdo l  doest le puaoa poruedore Jo  pranne  dour lu'jn pour lnle  prnse t das lesenne  

 1 5.076262533564633 0.5094245431441365 0.8302845233752404 

Maritime seine c'est la qu'ma team sème<EOL>Ses graines pour qu'un jour elles germent par dizainesEND
Jaicaeqe du l  qoest pe puauoivêmuedore Ji  caannss daur lu'in pour dsle  mrnrest das les rneENDEND

 2 4.843221971714835 0.5229486869175844 0.8376495404007994 

Maritime seine c'est la qu'ma team sème<EOL>Ses graines pour qu'un jour elles germent par dizainesEND
Jaiioéoe du g  doest la puauo vêmu dure Ji  caonsss dour lu'on tour dtle  danrestsdas les nne  

 3 4.253583034049412 0.5575154302311509 0.8549634777469399 

Maritime seine c'est la qu'ma team sème<EOL>Ses graines pour qu'un jour elles germent par dizainesEND
Jaicaaqe du g  doest pa puaào pêru duce Jis paansss d