https://pytorch.org/tutorials/beginner/translation_transformer.html

https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

import math, time
import unicodedata, re, random
import numpy as np

import os
from torch.utils.data import Dataset, TensorDataset, DataLoader, RandomSampler

In [10]:
class CONFIG_class:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.batch_size = 64         # Batch size = h , 64
        self.max_seq_length = 14
        self.src_vocab_size = 0
        self.tgt_vocab_size = 0
        
        # Hyperparameters for the Transformer model
        self.d_model = 256           # Embedding size for each word, 512
        self.num_heads = 8           # Number of attention heads, 8
        self.num_layers_encoder = 6  # Number of encoder layers, 6
        self.num_layers_decoder = 6  # Number of decoder layers, 6
        self.d_feedforward = 1024    # Dimension of the feedforward layer, 2048
        self.dropout = 0.1           # Dropout rate to prevent overfitting

CONFIG = CONFIG_class()

Lang1, Lang2 = "eng","fra"

In [11]:
SOS_token ,START_token, END_token, PADDING_token = 3, 1, 2, 0
SOS,START, END, PADDING = "[SOS]","[START]", "[END]", "[PADDING]"

class Lang:
    def __init__(self, name):
        self.name = name
        self.word2index = {PADDING: PADDING_token, START: START_token, END:END_token, SOS: SOS_token}
        self.index2word = {SOS_token: SOS, START_token: START, END_token: END, PADDING_token: PADDING}
        self.n_words = 4

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.index2word[self.n_words] = word
            self.n_words += 1

# Lowercase, trim, and remove non-letter characters
def normalizeString(s):
    s = s.lower().strip()
    s = ''.join(c for c in unicodedata.normalize('NFD', s) if unicodedata.category(c) != 'Mn')
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z!?]+", r" ", s)
    return START + " " + s.strip() + " " + END

    
def prepareData(lang1, lang2, reverse):
    print("Reading lines...")

    # Read the file and split into lines
    lines = open('data/%s-%s.txt' % (lang1, lang2), encoding='utf-8').read().strip().split('\n')

    # Split every line into pairs and normalize
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]

    # Reverse pairs, make Lang instances
    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        input_lang = Lang(lang2)
        output_lang = Lang(lang1)
    else:
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)
        
    print("Read %s sentence pairs" % len(pairs))
    pairs = [pair for pair in pairs if len(pair[0].split(' ')) < CONFIG.max_seq_length and len(pair[1].split(' ')) < CONFIG.max_seq_length ]
    print("Trimmed to %s sentence pairs" % len(pairs))
    print("Counting words...")
    for pair in pairs:
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])
    print("Counted words:")
    print(input_lang.name, input_lang.n_words)
    print(output_lang.name, output_lang.n_words)
    return input_lang, output_lang, pairs

def indexesFromSentence(lang, sentence):
    indexes = [lang.word2index[word] for word in sentence.split(' ')]
    return indexes

def tensorFromSentence(lang, sentence):
    indexes = indexesFromSentence(lang, sentence)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return torch.tensor(indexes, dtype=torch.long, device=device).view(1, -1)

def get_dataloader(config):
    input_lang, output_lang, pairs = prepareData(Lang1, Lang2, False)

    n = len(pairs)
    input_ids = np.zeros((n, config.max_seq_length), dtype=np.int32)
    target_ids = np.zeros((n, config.max_seq_length), dtype=np.int32)

    for idx, (inp, tgt) in enumerate(pairs):
        inp_ids = indexesFromSentence(input_lang, inp)
        tgt_ids = indexesFromSentence(output_lang, tgt)
        input_ids[idx, :len(inp_ids)] = inp_ids
        target_ids[idx, :len(tgt_ids)] = tgt_ids

    train_data = TensorDataset(torch.LongTensor(input_ids).to(config.device), torch.LongTensor(target_ids).to(config.device))
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=config.batch_size)
    return input_lang, output_lang, train_dataloader

In [12]:
class DiskDataset(Dataset):
    def __init__(self, input_file, target_file, length, max_seq_length):
        self.input_file = input_file
        self.target_file = target_file
        self.length = length
        self.max_seq_length = max_seq_length

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        with open(self.input_file, 'rb') as f_input, open(self.target_file, 'rb') as f_target:
            # Calcul de la position du batch à lire
            f_input.seek(idx * self.max_seq_length * 4)  # 4 bytes par int32
            f_target.seek(idx * self.max_seq_length * 4)

            # Lecture des données du fichier
            input_data = np.fromfile(f_input, dtype=np.int32, count=self.max_seq_length)
            target_data = np.fromfile(f_target, dtype=np.int32, count=self.max_seq_length)
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return torch.LongTensor(input_data).to(device), torch.LongTensor(target_data).to(device)

def save_data_to_disk(input_ids, target_ids, input_file, target_file):
    input_ids.tofile(input_file)
    target_ids.tofile(target_file)

def get_dataloader(config,reverse):
    input_lang, output_lang, pairs = prepareData(Lang1, Lang2, reverse)

    n = len(pairs)
    input_ids = np.zeros((n, config.max_seq_length), dtype=np.int32)
    target_ids = np.zeros((n, config.max_seq_length), dtype=np.int32)

    for idx, (inp, tgt) in enumerate(pairs):
        inp_ids = indexesFromSentence(input_lang, inp)
        tgt_ids = indexesFromSentence(output_lang, tgt)
        input_ids[idx, :len(inp_ids)] = inp_ids
        target_ids[idx, :len(tgt_ids)] = tgt_ids

    # Sauvegarder les données sur le disque
    input_file = 'input_ids.bin'
    target_file = 'target_ids.bin'
    save_data_to_disk(input_ids, target_ids, input_file, target_file)

    # Créer un Dataset personnalisé qui charge les données depuis le disque
    train_data = DiskDataset(input_file, target_file, n, config.max_seq_length)
    train_dataloader = DataLoader(train_data, batch_size=config.batch_size, shuffle=True)

    return input_lang, output_lang, train_dataloader


In [13]:
input_lang, output_lang, train_dataloader = get_dataloader(CONFIG, reverse = True)

CONFIG.src_vocab_size = input_lang.n_words  # Taille du vocabulaire pour la langue source
CONFIG.tgt_vocab_size = output_lang.n_words  # Taille du vocabulaire pour la langue cible

Reading lines...
Read 406476 sentence pairs
Trimmed to 358902 sentence pairs
Counting words...
Counted words:
fra 36065
eng 25844


In [14]:
# Helper Module that adds positional encoding to the token embedding
# to introduce a notion of word order.
class PositionalEncoding(nn.Module):
    def __init__(self, config):
        super(PositionalEncoding, self).__init__()
        # Compute the positional encodings once in log space.
        
        den = torch.exp(-torch.arange(0, config.d_model, 2) * (math.log(10000) / config.d_model))
        pos = torch.arange(config.max_seq_length).unsqueeze(1)  # shape: (max_seq_length, 1)
        pos_embedding = torch.zeros(config.max_seq_length, config.d_model)
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)

        pos_embedding = pos_embedding.unsqueeze(0).repeat(config.batch_size, 1, 1)

        # Store the positional embedding in a buffer (a tensor that is not a parameter)
        self.register_buffer('pos_embedding', pos_embedding)
        
        # Define dropout layer to be applied to the embeddings
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, token_embedding):
        # Add positional encoding to token embeddings and apply dropout
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :token_embedding.size(1)])


class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, d_model: int):
        super(TokenEmbedding, self).__init__()
        
        # Create an embedding layer that maps each token index to an embedding vector
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model

    def forward(self, tokens):
        return self.embedding(tokens.long()) * math.sqrt(self.d_model)


# Seq2Seq Network using Transformer architecture
class Seq2SeqTransformer(nn.Module):
    def __init__(self, config):
        super(Seq2SeqTransformer, self).__init__()
        
        # Define the Transformer model with encoder and decoder layers
        self.transformer = nn.Transformer(d_model=config.d_model, nhead=config.num_heads, num_encoder_layers=config.num_layers_encoder, 
                                          num_decoder_layers=config.num_layers_decoder, dim_feedforward=config.d_feedforward, dropout=config.dropout,
                                          activation = "gelu", norm_first = False, batch_first = True, device = config.device)
        
        # Linear layer to project the transformer output to the target vocabulary size
        self.generator = nn.Linear(config.d_model, config.tgt_vocab_size)
        
        # Token embedding layers for source and target sequences
        self.src_token_emb = TokenEmbedding(config.src_vocab_size, config.d_model)
        self.tgt_token_emb = TokenEmbedding(config.tgt_vocab_size, config.d_model)
        
        # Positional encoding layer to add positional information to embeddings
        self.positional_encoding = PositionalEncoding(config)

    def forward(self, src, trg, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask):

        src_emb = self.positional_encoding(self.src_token_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_token_emb(trg))
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None, src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        #forward(src, tgt, src_mask, tgt_mask, memory_mask, src_key_padding_mask, tgt_key_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src, src_mask):
        # Encode the source sequence using the Transformer encoder
        return self.transformer.encoder(self.positional_encoding(self.src_token_emb(src)), src_mask)

    def decode(self, tgt, memory, tgt_mask):
        # Decode the target sequence using the Transformer decoder with memory from the encoder
        return self.transformer.decoder(self.positional_encoding(self.tgt_token_emb(tgt)), memory, tgt_mask)


def create_mask(src, tgt):
    # Get the length of source and target sequences
    src_seq_len, tgt_seq_len = src.shape[1], tgt.shape[1]

    # Generate a target mask for self-attention to prevent attending to future tokens
    tgt_mask = torch.triu(torch.ones(tgt_seq_len, tgt_seq_len, dtype=torch.bool, device=CONFIG.device), diagonal=1)
    
    # Create a zero source mask (no masking for self-attention in encoder)
    src_mask = torch.zeros((src_seq_len, src_seq_len), device=CONFIG.device).type(torch.bool)

    # Create padding masks for the source and target sequences
    src_padding_mask, tgt_padding_mask = (src == PADDING_token), (tgt == PADDING_token)
    
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

transformer = Seq2SeqTransformer(CONFIG)


for parameter in transformer.parameters():
    if parameter.dim() > 1: nn.init.xavier_uniform_(parameter)

transformer = transformer.to(CONFIG.device)
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PADDING_token)
optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
min_loss = float("+inf")

In [15]:
class Translator:
    def __init__(self, model, input_lang, output_lang, config):
        self.model = model
        self.input_lang = input_lang
        self.output_lang = output_lang
        self.config = config
        
    def save(self, path):
        torch.save({
            'model_state_dict': self.model.state_dict(),  # Enregistrer uniquement les poids du modèle
            'input_lang': self.input_lang,
            'output_lang': self.output_lang,
            'config': self.config
        }, path)
    
    @classmethod
    def load(cls, path):
        checkpoint = torch.load(path, weights_only=False)
        input_lang = checkpoint['input_lang']
        output_lang = checkpoint['output_lang']
        config = checkpoint['config']
        model = Seq2SeqTransformer(config)
        model.load_state_dict(checkpoint['model_state_dict'])
        model = model.to(config.device)
        return cls(model, input_lang, output_lang, config)


    def evaluate(self, src_sentence):
        self.model.eval()
        src = tensorFromSentence(self.input_lang, normalizeString(src_sentence))
        num_tokens = src.shape[1]
        src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
        
        src = src.to(CONFIG.device)
        src_mask = src_mask.to(self.config.device)
    
        memory = self.model.encode(src, src_mask)
        ys = torch.ones(1, 1).fill_(START_token).type(torch.long).to(self.config.device)
        
        for i in range(CONFIG.max_seq_length):
            memory = memory.to(self.config.device)
            tgt_mask = torch.triu(torch.ones(ys.size(1), ys.size(1), dtype=torch.bool, device=self.config.device), diagonal=1)
            out = self.model.decode(ys, memory, tgt_mask)
            prob = self.model.generator(out[:, -1])
            _, next_word = torch.max(prob, dim=1)
            next_word = next_word.item()
            ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
            if next_word == END_token:
                break
                
        tgt_tokens = ys.flatten()
        return " ".join([self.output_lang.index2word[token.item()] for token in tgt_tokens])

In [16]:
name = 'translator_model_fra_eng2.pth'
loaded_translator = Translator.load(name)
transformer = loaded_translator.model
input_lang = loaded_translator.input_lang
output_lang = loaded_translator.output_lang
CONFIG = loaded_translator.config

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PADDING_token)
optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
min_loss = float("+inf")

FileNotFoundError: [Errno 2] No such file or directory: 'translator_model_fra_eng2.pth'

In [24]:
transformer.train()

n_epochs = 10
min_loss = float("+inf")
name = 'translator_model_fra_eng2.pth'

start = time.time()
for epoch in range(1,n_epochs+1):
    losses = 0
    with tqdm(total=len(train_dataloader), desc=f'Epoch {epoch}/{n_epochs}', unit='batch') as pbar:
        for batch_idx, (src, tgt) in enumerate(train_dataloader):
            tgt_input = tgt[:, :-1]

            src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
            logits = transformer(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)
            
            optimizer.zero_grad()
    
            tgt_out = tgt[:,1:]
            loss = loss_fn(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))
            losses += loss.item()
            
            loss.backward()
            optimizer.step()
            
            # Mise à jour de la barre de progression
            pbar.set_postfix({'loss': loss.item(), 'mean loss' : losses/(batch_idx+1)})
            pbar.update(1)
        if min_loss > losses:
            min_loss = losses
            translator = Translator(transformer, input_lang, output_lang,CONFIG)
            translator.save(name)

Epoch 1/10: 100%|████████████████████████████████████| 5608/5608 [46:22<00:00,  2.02batch/s, loss=2.19, mean loss=2.09]
Epoch 2/10: 100%|████████████████████████████████████| 5608/5608 [47:41<00:00,  1.96batch/s, loss=1.56, mean loss=1.67]
Epoch 3/10: 100%|████████████████████████████████████| 5608/5608 [48:55<00:00,  1.91batch/s, loss=1.37, mean loss=1.47]
Epoch 4/10: 100%|███████████████████████████████████| 5608/5608 [49:43<00:00,  1.88batch/s, loss=0.895, mean loss=1.34]
Epoch 5/10: 100%|████████████████████████████████████| 5608/5608 [49:32<00:00,  1.89batch/s, loss=1.12, mean loss=1.25]
Epoch 6/10: 100%|████████████████████████████████████| 5608/5608 [49:28<00:00,  1.89batch/s, loss=1.66, mean loss=1.18]
Epoch 7/10: 100%|████████████████████████████████████| 5608/5608 [49:31<00:00,  1.89batch/s, loss=1.02, mean loss=1.14]
Epoch 8/10: 100%|████████████████████████████████████| 5608/5608 [49:31<00:00,  1.89batch/s, loss=1.27, mean loss=1.11]
Epoch 9/10: 100%|███████████████████████

In [22]:
translator = Translator(transformer, input_lang, output_lang,CONFIG)
translator.save(name)

In [34]:
loaded_translator = Translator.load(name)
src_sentence = "Le monde est vert. mais est t'il bleu ?"
translation = loaded_translator.evaluate(src_sentence)
print("Translation:", translation)

Translation: [START] is the world green but blue ? [END]
