# Práctica 2: Implementación de un mecanismo de atención en un modelo Seq2Seq con LSTMs

Partiendo del código del modelo seq2seq con feedback para tareas de Traducción Automática Neuronal (NMT) del notebook anterior, se debe implementar el modelo de atención de Bahdanau o Luong.

Objetivos de la práctica:
- Entender el funcionamiento de los modelos Seq2Seq con LSTMs.
- Comprender e implementar mecanismos de atención.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import warnings

from torch.utils.data import DataLoader
from attention.attention_factory import AttentionFactory
from translation import Translation, collate_fn
from seq2seq.encoder import Encoder
from seq2seq.decoder import Decoder
from seq2seq.seq2seq import Seq2Seq

import wandb

### Dudas
entregables: BAG of Words, LSTM Glove, LSTM-attention --> con readmes


### TODO:
- Modelo Loung 
- Modelo Badanauh:
    - son lstm bidireccionales con la segunda entrada en reverse : OJO
    - 

Conexión con *Weights & Biases*

In [2]:
wandb.init(project="LSTM-Attention", name="Dot-product",
            config={
          "learning_rate": 0.001,
          "architecture": "LSTM",
          "epochs": 30,
          "batch_size": 7,
          })

config = wandb.config

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33msusanasrez[0m ([33mdata2023[0m). Use [1m`wandb login --relogin`[0m to force relogin


## 1. Cargar los datos

In [3]:
archivo_ingles = 'datasets_practice/mock/mock.en'
archivo_espanol = 'datasets_practice/mock/mock.es'

translation = Translation(archivo_ingles, archivo_espanol)

## 2. Entrenamiento

In [4]:
# Parámetros
input_dim = 300
output_dim = translation.vocab_es.vectors.shape[0]
hidden_dim = 512
num_layers = 1
num_workers = 0
shuffle = True

attention = AttentionFactory.initialize_attention("Multi-Layer Perceptron", hidden_dim, hidden_dim)

# Inicializa el modelo, el optimizador y la función de pérdida
encoder = Encoder(input_dim, hidden_dim, num_layers)
decoder = Decoder(input_dim, hidden_dim, output_dim, num_layers, attention=attention)
model = Seq2Seq(encoder, decoder)
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
criterion = nn.CrossEntropyLoss()


#dataloader = DataLoader(translation, batch_size=config.batch_size, shuffle=shuffle, num_workers=num_workers, collate_fn=collate_fn)

train_loader = DataLoader(translation.train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=num_workers, collate_fn=collate_fn)
test_loader = DataLoader(translation.test_dataset, batch_size=config.batch_size, shuffle=False, num_workers=num_workers, collate_fn=collate_fn)

In [None]:
warnings.filterwarnings("ignore")

for epoch in range(config.epochs):

    model.train()
    total_loss = 0

    for batch_idx, (src, tgt, src_indices, tgt_indices) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(src, tgt)

        tgt_indices = torch.tensor(tgt_indices, dtype=torch.long)
        loss = 0
        for t in range(1, tgt.shape[1]):
            loss += criterion(output[:, t, :], tgt_indices[:, t])

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        if batch_idx % 5 == 0:
            print(f'Epoch [{epoch+1}/{config.epochs}], Step [{batch_idx+1}/{len(train_loader)}], Loss: {loss.item():.4f}')
        
    

    model.eval()
    test_loss = 0

    with torch.no_grad():
        for src, tgt, src_indices, tgt_indices in test_loader:
            output = model(src, tgt)

            tgt_indices = torch.tensor(tgt_indices, dtype=torch.long)
            loss = 0
            for t in range(1, tgt.shape[1]):
                loss += criterion(output[:, t, :], tgt_indices[:, t])
            
            test_loss += loss.item()

    wandb.log({"Train loss": total_loss,"Test_loss": test_loss})

    print(f'Epoch [{epoch+1}/{config.epochs}], Average Train Loss: {total_loss / len(train_loader):.4f}, Average Test Loss: {test_loss / len(test_loader):.4f}')
    print('--------------------------------------------------------------------------------------------------------------')

Epoch [1/30], Step [1/2], Loss: 41.3668
Epoch [1/30], Average Train Loss: 40.9661, Average Test Loss: 36.8135
--------------------------------------------------------------------------------------------------------------
Epoch [2/30], Step [1/2], Loss: 36.5829
Epoch [2/30], Average Train Loss: 28.8277, Average Test Loss: 26.5496
--------------------------------------------------------------------------------------------------------------
Epoch [3/30], Step [1/2], Loss: 16.2330
Epoch [3/30], Average Train Loss: 21.3267, Average Test Loss: 22.8025
--------------------------------------------------------------------------------------------------------------
Epoch [4/30], Step [1/2], Loss: 21.1832
Epoch [4/30], Average Train Loss: 15.7309, Average Test Loss: 18.9899
--------------------------------------------------------------------------------------------------------------
Epoch [5/30], Step [1/2], Loss: 14.4041
Epoch [5/30], Average Train Loss: 11.5986, Average Test Loss: 15.1740
------

In [None]:
torch.save(model.state_dict(), './models/dot_product.pth')

In [None]:
wandb.finish()

0,1
Test_loss,█▇▆▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂
Train loss,▇█▆▆▆▄▃▄▃▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
Test_loss,20.80693
Train loss,6.80205


In [None]:
model = torch.load('./models/dot_product.pth')

In [None]:
model.eval()

sentence = "tiger"

# Convertir a vectores
tokens = translation.tokenizer_en(sentence)
tokens = tokens + ['<eos>']
text_tensor = translation.vocab_en.get_vecs_by_tokens(tokens)
text_tensor = text_tensor.unsqueeze(0)

with torch.no_grad():
    encoder_outputs, (hidden, cell) = model.encoder(text_tensor)

outputs = []

input_token = torch.tensor(translation.vocab_es.stoi['<sos>']).unsqueeze(0)
input_token = translation.vocab_es.vectors[input_token].unsqueeze(0)
    

for _ in range(5):
    with torch.no_grad():
        output, (hidden, cell) = model.decoder(input_token, hidden, cell) # teacher_forcing_ratio=0.0
        
    # Obtener el token con la probabilidad más alta
    best_guess = output.argmax(2).squeeze(0)
    outputs.append(best_guess.item())
        
    # Si el token es <eos>, terminar la traducción
    if best_guess == translation.vocab_es.stoi['<eos>']:
        break
        
    # Utilizar la palabra predicha como la siguiente entrada al decoder
    input_token = translation.vocab_es.vectors[best_guess].unsqueeze(0)
        
# Convertir los índices de salida a palabras
translated_sentence = [translation.vocab_es.itos[idx] for idx in outputs]
    
result = ' '.join(translated_sentence)

print(result)