# Práctica 2: Implementación de un mecanismo de atención en un modelo Seq2Seq con LSTMs

Partiendo del código del modelo seq2seq con feedback para tareas de Traducción Automática Neuronal (NMT) del notebook anterior, se debe implementar el modelo de atención de Bahdanau o Luong.

Objetivos de la práctica:
- Entender el funcionamiento de los modelos Seq2Seq con LSTMs.
- Comprender e implementar mecanismos de atención.

In [1]:
from torchtext.data.utils import get_tokenizer
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
import torchtext
import torch
import torch.nn as nn
import torch.optim as optim
from collections import defaultdict
import random
import warnings

from torch.utils.data import DataLoader, random_split
from attention.attention_factory import AttentionFactory

import wandb

Google Colab

In [2]:
# from google.colab import drive
# drive.mount('/content/drive')

# !cd '/content/drive/My Drive/LSTM_attention' && ls

# !pip uninstall torch torchtext -y
# !pip install torch==2.0.1 torchtext==0.15.2 --index-url https://download.pytorch.org/whl/cu118
# !pip install portalocker>=2.0.0

# !python -m spacy download es_core_news_md
# !python -m spacy download en_core_web_md

Conexión con *Weights & Biases*

In [3]:
# Start a Weights & Biases run. The values stored under `config` are read back
# through `wandb.config` by the later cells (learning_rate for the optimizer,
# batch_size for the DataLoaders, epochs for the training loop).
# NOTE(review): batch_size is 7 here, while the disabled parameter block and
# the shape comments in the Decoder mention 8 — confirm which one is intended.
wandb.init(project="LSTM-Attention", name="Dot-Product",
            config={
          "learning_rate": 0.001,
          "architecture": "LSTM",
          "epochs": 30,
          "batch_size": 7,
          })

# Handle used by the rest of the notebook to access the run's hyperparameters.
config = wandb.config

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33msusanasrez[0m ([33mdata2023[0m). Use [1m`wandb login --relogin`[0m to force relogin


## 1. Cargar los datos

In [4]:
class Translation(Dataset):
    """Parallel English/Spanish sentence dataset backed by FastText vectors.

    Sentences are read from two line-aligned text files, lower-cased, and
    tokenized with spaCy on access. Each item exposes both the embedding
    tensors and the vocabulary indices of the two sentences. The constructor
    also creates ``train_dataset`` / ``test_dataset`` random splits of itself.
    """

    def __init__(self, source_file, target_file, train_size=0.9):
        """Load tokenizers and vocabularies, read the corpus and split it.

        Args:
            source_file: Path to the English file, one sentence per line.
            target_file: Path to the Spanish file, line-aligned with
                ``source_file``.
            train_size: Fraction of the samples assigned to the training
                split; the remainder goes to the test split.
        """
        self.ingles = []
        self.espanol = []
        self.tokenizer_es = get_tokenizer("spacy", language="es_core_news_md")
        self.tokenizer_en = get_tokenizer("spacy", language="en_core_web_md")
        self.vocab_es = torchtext.vocab.FastText(language='es', unk_init=torch.Tensor.normal_)
        self.vocab_en = torchtext.vocab.FastText(language='en', unk_init=torch.Tensor.normal_)

        # Extend both vocabularies with <sos>, <eos>, <pad> and <unk>.
        self.vocab_en = self.add_sos_eos_unk_pad(self.vocab_en)
        self.vocab_es = self.add_sos_eos_unk_pad(self.vocab_es)

        self.archivo_ingles = source_file
        self.archivo_espanol = target_file

        # Read the whole corpus into memory.
        for ingles, espanol in self.read_translation():
            self.ingles.append(ingles)
            self.espanol.append(espanol)

        # Split into training and test subsets.
        train_size = int(len(self) * train_size)
        test_size = len(self) - train_size
        self.train_dataset, self.test_dataset = random_split(self, [train_size, test_size])

    def add_sos_eos_unk_pad(self, vocabulary):
        """Append the special tokens to ``vocabulary`` and return it.

        The tokens are added in the order ``<sos>``, ``<eos>``, ``<pad>``,
        ``<unk>`` to ``itos``, ``stoi`` and ``vectors`` alike, so that the
        index of every special token points at its own embedding row.
        ``stoi`` is replaced by a ``defaultdict`` whose default is the last
        index, i.e. ``<unk>``.

        Args:
            vocabulary: Object exposing ``itos`` (list), ``stoi`` (dict),
                ``vectors`` (2-D tensor) and ``__len__``; it is modified in
                place.

        Returns:
            The same ``vocabulary`` object, extended with the four tokens.
        """
        words = vocabulary.itos
        vocab = vocabulary.stoi
        embedding_matrix = vocabulary.vectors

        # Special tokens, listed in the order in which they receive indices.
        sos_token = '<sos>'
        eos_token = '<eos>'
        pad_token = '<pad>'
        unk_token = '<unk>'

        # Constant-valued embeddings for the special tokens:
        # <sos> = 1s, <eos> = 2s, <pad> = 0s, <unk> = 3s.
        embedding_dim = embedding_matrix.shape[1]
        sos_vector = torch.full((1, embedding_dim), 1.)
        eos_vector = torch.full((1, embedding_dim), 2.)
        pad_vector = torch.zeros((1, embedding_dim))
        unk_vector = torch.full((1, embedding_dim), 3.)

        # Append the rows in the SAME order as the indices assigned below;
        # otherwise <pad> and <unk> would point at each other's vectors.
        embedding_matrix = torch.cat(
            (embedding_matrix, sos_vector, eos_vector, pad_vector, unk_vector), 0
        )

        # Register the special tokens in the vocabulary.
        for token in (sos_token, eos_token, pad_token, unk_token):
            vocab[token] = len(vocab)
            words.append(token)

        vocabulary.itos = words
        vocabulary.vectors = embedding_matrix

        # Any token missing from the vocabulary resolves to the last index,
        # which is <unk>. The length is evaluated lazily, after `vectors` has
        # been extended.
        vocabulary.stoi = defaultdict(lambda: len(vocabulary) - 1, vocab)

        return vocabulary

    def read_translation(self):
        """Yield ``(english, spanish)`` sentence pairs, stripped and lower-cased.

        Both files are read in lockstep; iteration stops at the end of the
        shorter one.
        """
        with open(self.archivo_ingles, 'r', encoding='utf-8') as f_ingles, open(self.archivo_espanol, 'r', encoding='utf-8') as f_espanol:
            for oracion_ingles, oracion_espanol in zip(f_ingles, f_espanol):
                yield oracion_ingles.strip().lower(), oracion_espanol.strip().lower()

    def __len__(self):
        """Return the number of sentence pairs."""
        return len(self.ingles)

    def __getitem__(self, idx):
        """Return the embeddings and indices of the ``idx``-th sentence pair.

        Returns:
            A 4-tuple ``(tensor_ingles, tensor_espanol, indices_ingles,
            indices_espanol)``. The English tokens end with ``<eos>``; the
            Spanish tokens are wrapped in ``<sos>`` ... ``<eos>``. Each index
            list carries one extra trailing ``<pad>`` index, which
            ``collate_fn`` reads to learn the padding id.
        """
        # Both token lists always contain at least <eos>, so they are never
        # empty and every item has the same 4-tuple structure.
        tokens_ingles = self.tokenizer_en(self.ingles[idx]) + ['<eos>']
        tokens_espanol = ['<sos>'] + self.tokenizer_es(self.espanol[idx]) + ['<eos>']

        tensor_ingles = self.vocab_en.get_vecs_by_tokens(tokens_ingles)
        tensor_espanol = self.vocab_es.get_vecs_by_tokens(tokens_espanol)

        indices_ingles = [self.vocab_en.stoi[token] for token in tokens_ingles] + [self.vocab_en.stoi['<pad>']]
        indices_espanol = [self.vocab_es.stoi[token] for token in tokens_espanol] + [self.vocab_es.stoi['<pad>']]

        return tensor_ingles, tensor_espanol, indices_ingles, indices_espanol
        
            
        
def collate_fn(batch):
    """Merge dataset items into one padded batch.

    Args:
        batch: Sequence of 4-tuples ``(src_vectors, tgt_vectors, src_indices,
            tgt_indices)`` as produced by the dataset.

    Returns:
        ``(src_batch, tgt_batch, src_indices, tgt_indices)`` where the two
        embedding batches are zero-padded tensors (batch first), the source
        index lists are returned untouched, and the target index lists are
        extended in place to a common length with the padding index.
    """
    src_vectors, tgt_vectors, src_ids, tgt_ids = zip(*batch)

    padded_src = pad_sequence(src_vectors, batch_first=True, padding_value=0)
    padded_tgt = pad_sequence(tgt_vectors, batch_first=True, padding_value=0)

    # The last element of a target index list is taken as the <pad> index.
    pad_id = tgt_ids[0][-1]
    longest = max(len(ids) for ids in tgt_ids)
    for ids in tgt_ids:
        missing = longest - len(ids)
        ids += [pad_id] * missing  # grows the caller's list in place

    return padded_src, padded_tgt, src_ids, tgt_ids

In [5]:
# Paths to the line-aligned parallel corpus (English source, Spanish target).
archivo_ingles = 'datasets_practice/mock/mock.en'
archivo_espanol = 'datasets_practice/mock/mock.es'

# Building the dataset loads the tokenizers and FastText vectors, reads both
# files and creates the train/test split (translation.train_dataset / test_dataset).
translation = Translation(archivo_ingles, archivo_espanol)

## 2. Definición del modelo

In [6]:
class Encoder(nn.Module):
    """Multi-layer, batch-first LSTM that encodes the source sequence."""

    def __init__(self, input_dim, hidden_dim, num_layers):
        """Build the recurrent layer.

        Args:
            input_dim: Size of each input embedding.
            hidden_dim: Size of the LSTM hidden state.
            num_layers: Number of stacked LSTM layers.
        """
        super().__init__()
        self.rnn = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )

    def forward(self, x):
        """Run the LSTM over ``x``.

        Returns:
            ``(output, (hidden, cell))`` exactly as produced by ``nn.LSTM``:
            the per-step outputs plus the final hidden and cell states.
        """
        encoder_states, final_state = self.rnn(x)
        final_hidden, final_cell = final_state
        return encoder_states, (final_hidden, final_cell)
    
class Decoder(nn.Module):
    """LSTM decoder step that predicts from an attention context vector."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, attention):
        """Build the decoder.

        Args:
            input_dim: Size of each input embedding.
            hidden_dim: Size of the LSTM hidden state.
            output_dim: Number of output logits produced by ``fc_out``.
            num_layers: Number of stacked LSTM layers.
            attention: Object exposing ``compute_score(decoder_output,
                encoder_outputs)``.
        """
        super().__init__()
        self.rnn = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc_out = nn.Linear(in_features=hidden_dim, out_features=output_dim)
        self.attention = attention

    def forward(self, x, hidden, cell, outputs_encoder):
        """Advance the decoder and attend over the encoder outputs.

        Args:
            x: Decoder input passed to the LSTM.
            hidden: Previous hidden state of the LSTM.
            cell: Previous cell state of the LSTM.
            outputs_encoder: Encoder outputs that are attended over.

        Returns:
            ``(logits, (hidden, cell))`` where ``logits`` is ``fc_out`` applied
            to the attention context and the state is the updated LSTM state.
        """
        decoder_state, (hidden, cell) = self.rnn(x, (hidden, cell))

        # The LSTM output is used only as the attention query.
        # Presumably scores has shape (batch, src_len) — TODO confirm against
        # the attention implementation.
        scores = self.attention.compute_score(decoder_state, outputs_encoder)

        # Normalize over dim 1 and add a trailing axis so the weights
        # broadcast against the encoder outputs.
        weights = torch.softmax(scores, dim=1).unsqueeze(-1)

        # Context vector: weighted sum of the encoder outputs along dim 1.
        context = (weights * outputs_encoder).sum(dim=1, keepdim=True)

        # Project the context to the output logits.
        return self.fc_out(context), (hidden, cell)

In [7]:
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper with teacher forcing and embedding feedback."""

    def __init__(self, encoder, decoder):
        """Store the sub-modules and build the feedback embedding matrix.

        Args:
            encoder: Module returning ``(outputs, (hidden, cell))``.
            decoder: Module called as ``decoder(x, hidden, cell,
                encoder_outputs)`` and returning ``(logits, (hidden, cell))``.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.es_embeddings = torchtext.vocab.FastText(language='es')
        # Spanish embedding matrix extended with four zero rows, one per
        # special token, so it has one row per output logit.
        # NOTE(review): the dataset gives the special tokens non-zero vectors,
        # while these rows are zeros — confirm this difference is intended.
        self.M = self.es_embeddings.vectors
        self.M = torch.cat((self.M, torch.zeros((4, self.M.shape[1]))), 0)

    def forward(self, source, target, teacher_forcing_ratio=0.5):
        """Decode ``target`` step by step from the encoded ``source``.

        Args:
            source: Source embeddings, passed unchanged to the encoder.
            target: Target embeddings indexed as ``[batch, time, feature]``;
                position 0 is fed as the first decoder input.
            teacher_forcing_ratio: Probability, drawn once per time step for
                the whole batch, of feeding the ground-truth embedding instead
                of the model's own prediction.

        Returns:
            Tensor of shape ``(batch, target_len, vocab_size)`` with the
            decoder logits; position 0 is left as zeros because no prediction
            is made for the first target token.
        """
        target_len = target.shape[1]
        batch_size = target.shape[0]

        # The logits are multiplied by self.M below, so their width must equal
        # the number of rows in self.M. Deriving the size from it (rather than
        # hard-coding one vocabulary size) keeps the buffer and feedback in sync.
        vocab_size = self.M.shape[0]

        # Buffer that collects the decoder outputs.
        outputs = torch.zeros(batch_size, target_len, vocab_size)

        # Encode the source sequence first.
        outputs_encoder, (hidden, cell) = self.encoder(source)

        # The first decoder input is the <sos> embedding.
        x = target[:, 0, :]

        for t in range(1, target_len):
            output, (hidden, cell) = self.decoder(x.unsqueeze(1), hidden, cell, outputs_encoder)
            outputs[:, t, :] = output.squeeze(1)

            teacher_force = random.random() < teacher_forcing_ratio
            if teacher_force:
                # Feed the ground-truth embedding of the current step.
                x = target[:, t, :]
            else:
                # Feed back the logits projected onto the embedding space.
                x = torch.matmul(output.squeeze(1), self.M)
        return outputs

## 3. Entrenamiento

In [8]:
# Hyperparameters
input_dim = 300
# One output logit per row of the extended Spanish embedding matrix.
output_dim = translation.vocab_es.vectors.shape[0]
hidden_dim = 512
num_layers = 2
# The string literal below is a disabled block: the live learning rate, number
# of epochs and batch size are read from the wandb `config` object instead.
"""
learning_rate = 0.001
num_epochs = 30
batch_size = 8
"""
num_workers = 0
# NOTE(review): `shuffle` is only referenced by the commented-out DataLoader
# below; the active loaders pass literal values.
shuffle = True

# Attention scoring strategy, selected by name through the project factory.
attention = AttentionFactory.initialize_attention("Dot-product")

# Initialize the model, the optimizer and the loss function
encoder = Encoder(input_dim, hidden_dim, num_layers)
decoder = Decoder(input_dim, hidden_dim, output_dim, num_layers, attention=attention)
model = Seq2Seq(encoder, decoder)
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
# NOTE(review): no ignore_index is set, so padded target positions are scored
# like any other class — confirm this is intended.
criterion = nn.CrossEntropyLoss()


#dataloader = DataLoader(translation, batch_size=config.batch_size, shuffle=shuffle, num_workers=num_workers, collate_fn=collate_fn)

# Separate loaders for the random train/test split; only training is shuffled.
train_loader = DataLoader(translation.train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=num_workers, collate_fn=collate_fn)
test_loader = DataLoader(translation.test_dataset, batch_size=config.batch_size, shuffle=False, num_workers=num_workers, collate_fn=collate_fn)

In [9]:
# Silence every warning for the remainder of the session.
warnings.filterwarnings("ignore")

# One pass over the training split followed by one pass over the test split
# per epoch; the summed losses of both passes are logged to wandb.
for epoch in range(config.epochs):

    model.train()
    total_loss = 0

    for batch_idx, (src, tgt, src_indices, tgt_indices) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(src, tgt)

        # collate_fn returns equal-length index lists; turn them into a tensor.
        tgt_indices = torch.tensor(tgt_indices, dtype=torch.long)
        # Sum the per-step cross-entropy, skipping position 0: the model makes
        # no prediction for the first target token.
        loss = 0
        for t in range(1, tgt.shape[1]):
            loss += criterion(output[:, t, :], tgt_indices[:, t])

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        if batch_idx % 5 == 0:
            print(f'Epoch [{epoch+1}/{config.epochs}], Step [{batch_idx+1}/{len(train_loader)}], Loss: {loss.item():.4f}')
        
    

    model.eval()
    test_loss = 0

    with torch.no_grad():
        for src, tgt, src_indices, tgt_indices in test_loader:
            # NOTE(review): the model is called with its default
            # teacher_forcing_ratio here too, so evaluation still feeds
            # ground-truth tokens at random — confirm this is intended.
            output = model(src, tgt)

            tgt_indices = torch.tensor(tgt_indices, dtype=torch.long)
            loss = 0
            for t in range(1, tgt.shape[1]):
                loss += criterion(output[:, t, :], tgt_indices[:, t])
            
            test_loss += loss.item()

    # wandb receives the summed losses; the print below reports per-batch averages.
    wandb.log({"Train loss": total_loss,"Test_loss": test_loss})

    print(f'Epoch [{epoch+1}/{config.epochs}], Average Train Loss: {total_loss / len(train_loader):.4f}, Average Test Loss: {test_loss / len(test_loader):.4f}')
    print('--------------------------------------------------------------------------------------------------------------')

Epoch [1/30], Step [1/2], Loss: 41.4131
Epoch [1/30], Average Train Loss: 34.4047, Average Test Loss: 26.8760
--------------------------------------------------------------------------------------------------------------
Epoch [2/30], Step [1/2], Loss: 40.1861
Epoch [2/30], Average Train Loss: 33.0279, Average Test Loss: 25.4209
--------------------------------------------------------------------------------------------------------------
Epoch [3/30], Step [1/2], Loss: 36.9579
Epoch [3/30], Average Train Loss: 35.6014, Average Test Loss: 23.4232
--------------------------------------------------------------------------------------------------------------
Epoch [4/30], Step [1/2], Loss: 31.6118
Epoch [4/30], Average Train Loss: 25.7344, Average Test Loss: 21.1108
--------------------------------------------------------------------------------------------------------------
Epoch [5/30], Step [1/2], Loss: 22.9395
Epoch [5/30], Average Train Loss: 23.8357, Average Test Loss: 18.7035
------

In [10]:
# Mark the Weights & Biases run opened with wandb.init() as finished.
wandb.finish()

0,1
Test_loss,█▇▆▄▃▂▁▁▁▁▂▂▃▃▄▄▃▃▃▃▃▃▃▃▄▄▄▄▄▄
Train loss,█▇█▆▅▃▃▃▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
Test_loss,20.17036
Train loss,9.90932
