# Seq2Seq NMT with Bahdanau Attention (English → French)

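This notebook demonstrates a sequence-to-sequence (Seq2Seq) neural machine translation model enhanced with **Bahdanau-style attention**. Unlike Luong attention, which scores the encoder states against the *current* decoder hidden state (after the RNN step), Bahdanau attention scores them against the *previous* decoder hidden state using a learned *alignment model*. The resulting context vector is then fed into the RNN step that produces the next hidden state.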
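
In the usual notation, with encoder states $h_j$ ($j = 1, \dots, T_x$), previous decoder state $s_{i-1}$ and previous output token $y_{i-1}$, one Bahdanau decoding step can be written as follows. The alignment model $a$ is the only part that changes between the alignment strategies compared below.

$$
\begin{aligned}
e_{ij} &= a(s_{i-1}, h_j) && \text{alignment score for output step } i \text{ and source position } j \\
\alpha_{ij} &= \frac{\exp(e_{ij})}{\sum_{k=1}^{T_x} \exp(e_{ik})} && \text{attention weights (softmax over the source)} \\
c_i &= \sum_{j=1}^{T_x} \alpha_{ij} h_j && \text{context vector} \\
s_i &= f(s_{i-1}, y_{i-1}, c_i) && \text{RNN update that already sees } c_i
\end{aligned}
$$

Because the score uses $s_{i-1}$ rather than $s_i$, the context $c_i$ is available as an input to the RNN update, which is what "attention before the decoder hidden state" means here.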

We evaluate two alignment strategies:
- **Concat (Original Bahdanau)**
- **Additive (Optimized MLP variant)**
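
The notebook's `models` module is not shown, so the exact scoring functions are an assumption. A minimal sketch using the common definitions: *concat* scores $v^\top \tanh(W[s_{i-1}; h_j])$ with one projection over the concatenated vectors, while *additive* scores $v^\top \tanh(W_s s_{i-1} + W_h h_j)$ with separate projections. All layer names below are hypothetical. Splitting $W$ column-wise into $[W_s, W_h]$ shows the two give identical scores, and that the additive form lets $W_h h_j$ be computed once per sentence.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_size, alignment_size, src_len = 8, 4, 5

# Hypothetical layers (not the notebook's `models` module): one projection W
# over the concatenation [s; h], plus the scoring vector v.
W = nn.Linear(2 * hidden_size, alignment_size, bias=False)
v = nn.Linear(alignment_size, 1, bias=False)

# Additive form: split W column-wise into a decoder part W_s and an encoder part W_h.
W_s = W.weight[:, :hidden_size]   # (alignment_size, hidden_size)
W_h = W.weight[:, hidden_size:]   # (alignment_size, hidden_size)


def concat_scores(s_prev, enc_outputs):
    """Concat form: e_j = v^T tanh(W [s_prev; h_j]) for every source position j."""
    s_rep = s_prev.unsqueeze(0).expand(enc_outputs.size(0), -1)  # (src_len, hidden)
    return v(torch.tanh(W(torch.cat([s_rep, enc_outputs], dim=-1)))).squeeze(-1)


def additive_scores(s_prev, enc_proj):
    """Additive form: e_j = v^T tanh(W_s s_prev + W_h h_j).

    enc_proj = enc_outputs @ W_h.T is computed once per source sentence
    and reused at every decoder step.
    """
    return v(torch.tanh(enc_proj + s_prev @ W_s.T)).squeeze(-1)


with torch.no_grad():
    enc_outputs = torch.randn(src_len, hidden_size)  # stand-in encoder states h_j
    s_prev = torch.randn(hidden_size)                # stand-in decoder state s_{i-1}
    enc_proj = enc_outputs @ W_h.T                   # precomputed once per sentence

    e_concat = concat_scores(s_prev, enc_outputs)
    e_additive = additive_scores(s_prev, enc_proj)
    alpha = torch.softmax(e_additive, dim=0)         # attention weights over the source
    context = alpha @ enc_outputs                    # context vector c_i, shape (hidden,)

print(torch.allclose(e_concat, e_additive, atol=1e-5))  # True
```

With these definitions both forms use the same `2 * hidden_size * alignment_size` projection weights; the difference is computational, since the additive form avoids re-projecting the encoder states at every decoder step. The notebook's own implementation may parameterize the two modes differently.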


We use a subset of the **English-French** dataset from [Tatoeba (ManyThings.org)](https://www.manythings.org/anki/).

---


# 1. Imports and Setup

In [None]:
import os
import random
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from dataloader_generator import (
    normalizeString,
    prepareData,
    DatasetEngFra,
    collate_batch,
    PAD_token,
    SOS_token,
    EOS_token
)

from models import EncoderWithBahdanauAttention, DecoderWithBahdanauAttention
from utils import (
    train_bahdanau_attention,
    translate_with_attention,
    plot_attention
)

# Set device and random seeds (Python's `random` is also used later to sample pairs)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seed = 13
random.seed(seed)
torch.manual_seed(seed)


# 2. Hyperparameters & Dataset

## Dataset and Hyperparameters

- Based on the [Tatoeba English–French dataset](https://www.manythings.org/anki/)
- Sentence pairs limited to a maximum length of 10 tokens (`MAX_LENGTH`)
- Preprocessing: Unicode normalization, tokenization, lowercasing
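
`normalizeString` and `prepareData` live in `dataloader_generator`, which is not shown. A minimal sketch of one common way to implement this preprocessing, assuming accents are stripped after Unicode NFD decomposition and that over-long pairs are dropped rather than cut; `normalize_string` and `filter_pairs` are hypothetical names, and the real helpers may behave differently.

```python
import re
import unicodedata

MAX_LENGTH = 10  # same limit as the notebook's MAX_LENGTH


def unicode_to_ascii(s: str) -> str:
    """Decompose accented characters (NFD) and drop the combining marks."""
    return "".join(
        c for c in unicodedata.normalize("NFD", s)
        if unicodedata.category(c) != "Mn"
    )


def normalize_string(s: str) -> str:
    """Lowercase, strip accents, split off .!? and replace other symbols with spaces."""
    s = unicode_to_ascii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)       # "journey." -> "journey ."
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)   # any other run of characters -> one space
    return s.strip()


def filter_pairs(pairs, max_length=MAX_LENGTH):
    """Keep only pairs whose sides both have at most max_length tokens."""
    return [
        (src, tgt) for src, tgt in pairs
        if len(src.split()) <= max_length and len(tgt.split()) <= max_length
    ]


pairs = [
    (normalize_string("Life is often compared to a journey."),
     normalize_string("La vie est souvent comparée à un voyage.")),
]
print(pairs[0])
# ('life is often compared to a journey .', 'la vie est souvent comparee a un voyage .')
```

Stripping accents shrinks the French vocabulary but merges words such as *a* and *à*; keeping them is an equally valid choice if the tokenizer handles non-ASCII letters.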

We will train with:
- **Epochs**: 20
- **Hidden size**: 256
- **Alignment size**: 128
- **Batch size**: 32
- **Learning rate**: 1e-3 (Adam)
- **Bahdanau alignment modes**: `["concat", "additive"]`


In [None]:
# Parameters
num_epochs = 20
hidden_size = 256
alignment_size = 128
BATCH_SIZE = 32
learning_rate = 1e-3
MAX_LENGTH = 10
alignment_modes = ['concat', 'additive']  # the two Bahdanau alignment modes compared in this notebook

# Download if not exists
if not os.path.exists('fra.txt'):
    os.system('wget -q https://www.manythings.org/anki/fra-eng.zip')
    os.system('unzip -oq fra-eng.zip')

# Load and preprocess: each line is "English<TAB>French<TAB>CC-BY attribution"
text_pairs = []
with open('fra.txt', 'r', encoding='utf-8') as f:
    for line in f:
        parts = line.rstrip('\n').split('\t')
        if len(parts) < 2:
            continue
        eng, fra = parts[0], parts[1]
        text_pairs.append((normalizeString(eng), normalizeString(fra)))

input_lang, output_lang, pairs = prepareData('eng', 'fra', text_pairs)
dataset = DatasetEngFra(pairs, input_lang, output_lang)
train_dl = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
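
`collate_batch` is imported from `dataloader_generator` and its code is not shown. A minimal sketch of what a padding collate function typically does, assuming each dataset item is a `(source_ids, target_ids)` pair of 1-D `LongTensor`s and that the padding id is 0 (both assumptions; `collate_pad` is a hypothetical name):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

PAD_TOKEN = 0  # assumption: the notebook's PAD_token value is not shown


def collate_pad(batch):
    """Hypothetical stand-in for `collate_batch`: pad (src, tgt) id sequences.

    batch: list of (src_ids, tgt_ids) pairs, each a 1-D LongTensor of any length.
    Returns (batch, max_src_len) and (batch, max_tgt_len) tensors plus source lengths.
    """
    srcs, tgts = zip(*batch)
    src_lens = torch.tensor([len(s) for s in srcs])
    src = pad_sequence(list(srcs), batch_first=True, padding_value=PAD_TOKEN)
    tgt = pad_sequence(list(tgts), batch_first=True, padding_value=PAD_TOKEN)
    return src, tgt, src_lens


# Arbitrary token ids, for illustration only.
batch = [
    (torch.tensor([5, 6, 7, 1]), torch.tensor([8, 9, 1])),
    (torch.tensor([4, 1]), torch.tensor([3, 2, 9, 1])),
]
src, tgt, src_lens = collate_pad(batch)
# src -> [[5, 6, 7, 1], [4, 1, 0, 0]]; tgt -> [[8, 9, 1, 0], [3, 2, 9, 1]]
```

A function with this signature plugs into `DataLoader(dataset, ..., collate_fn=collate_pad)` the same way `collate_batch` does; the padded positions are later skipped by `nn.NLLLoss(ignore_index=PAD_token)`.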


# 3. Training Loop Across Modes
## Training and Evaluation

We now train two models using the Bahdanau attention mechanism with different alignment modes:
- **Concat attention** (classic Bahdanau formulation)
- **Additive attention** (optimized MLP-based scoring)

For each model, we track:
- Training loss
- Sample translations
- Attention visualizations
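
`DecoderWithBahdanauAttention` and `train_bahdanau_attention` come from the `models` and `utils` modules, which are not shown. The sketch below is a hypothetical single-layer GRU decoder step (`BahdanauDecoderStep` is an illustrative name, using the additive scorer) that follows the Bahdanau ordering: attend with the previous hidden state, feed the context into the GRU, then predict. It is followed by a teacher-forced loss over one padded batch; the token ids and shapes are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
PAD_TOKEN, SOS_TOKEN = 0, 1  # assumed ids; the notebook's PAD_token/SOS_token may differ


class BahdanauDecoderStep(nn.Module):
    """Hypothetical single-layer GRU decoder step with additive attention."""

    def __init__(self, vocab_size, hidden_size, alignment_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=PAD_TOKEN)
        self.W_s = nn.Linear(hidden_size, alignment_size, bias=False)
        self.W_h = nn.Linear(hidden_size, alignment_size, bias=False)
        self.v = nn.Linear(alignment_size, 1, bias=False)
        self.gru = nn.GRU(2 * hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, vocab_size)

    def forward(self, y_prev, s_prev, enc_outputs, src_mask=None):
        # y_prev: (B,) previous tokens, s_prev: (1, B, H), enc_outputs: (B, T, H)
        # 1) Score each encoder state against the PREVIOUS decoder state s_{i-1}.
        scores = self.v(torch.tanh(
            self.W_h(enc_outputs) + self.W_s(s_prev[-1]).unsqueeze(1)
        )).squeeze(-1)                                              # (B, T)
        if src_mask is not None:                                    # True = real source token
            scores = scores.masked_fill(~src_mask, float("-inf"))
        alpha = F.softmax(scores, dim=-1)                           # (B, T)
        context = torch.bmm(alpha.unsqueeze(1), enc_outputs)        # (B, 1, H)
        # 2) The GRU consumes [embedding; context] and produces s_i.
        emb = self.embedding(y_prev).unsqueeze(1)                   # (B, 1, H)
        output, s_new = self.gru(torch.cat([emb, context], dim=-1), s_prev)
        # 3) Predict the next token from s_i.
        log_probs = F.log_softmax(self.out(output.squeeze(1)), dim=-1)  # (B, V)
        return log_probs, s_new, alpha


# Teacher-forced loss over one padded target batch.
B, T, H, V = 2, 4, 16, 12
step = BahdanauDecoderStep(V, H, alignment_size=8)
enc_outputs = torch.randn(B, T, H)   # stand-in for encoder outputs
s = torch.zeros(1, B, H)             # stand-in for the encoder's final hidden state
targets = torch.tensor([[5, 7, 2, PAD_TOKEN],
                        [3, 9, 4, 2]])
loss_fn = nn.NLLLoss(ignore_index=PAD_TOKEN)

y_prev = torch.full((B,), SOS_TOKEN, dtype=torch.long)
step_log_probs = []
for t in range(targets.size(1)):
    log_probs, s, alpha = step(y_prev, s, enc_outputs)
    step_log_probs.append(log_probs)
    y_prev = targets[:, t]           # teacher forcing: feed the gold token, not the prediction

all_log_probs = torch.stack(step_log_probs, dim=1)                 # (B, steps, V)
loss = loss_fn(all_log_probs.reshape(-1, V), targets.reshape(-1))  # PAD positions ignored
loss.backward()
```

Stacking all steps before calling the loss averages over every non-padding target token in the batch, rather than averaging per time step.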


In [None]:
for alignment_mode in alignment_modes:
    print(f"\n{'=' * 40} {alignment_mode.upper()} MODE {'=' * 40}")

    encoder = EncoderWithBahdanauAttention(input_lang.n_words, hidden_size).to(device)
    decoder = DecoderWithBahdanauAttention(output_lang.n_words, hidden_size, alignment_size=alignment_size,
                                           alignment_mode=alignment_mode).to(device)

    loss_fn = nn.NLLLoss(ignore_index=PAD_token)
    encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=learning_rate)

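    train_loss = train_bahdanau_attention(encoder, decoder, train_dl, num_epochs, loss_fn,
                                          encoder_optimizer, decoder_optimizer)

    # Switch to inference mode (disables dropout, if any) before translating
    encoder.eval()
    decoder.eval()
    print("\nSample Predictions:")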
    for _ in range(10):
        eng, fra = random.choice(text_pairs)
        print("Input:", eng)
        print("Target:", fra)
        pred, attentions = translate_with_attention(encoder, decoder, eng, input_lang, output_lang)
        print("Predicted:", pred)
        print("#" * 80)

    # Visualize attention on one sample
    sample_input = normalizeString("Life is often compared to a journey")  # same preprocessing as the training data
    correct_translation = "La vie est souvent comparée à un voyage."
    output_sentence, attention_weights = translate_with_attention(encoder, decoder, sample_input, input_lang, output_lang)
    print("\nPredicted translation:", output_sentence)
    print("Correct translation:", correct_translation)
    plot_attention(sample_input, output_sentence, attention_weights)

    # Plot training loss
    plt.plot(train_loss)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f"Training Loss - {alignment_mode.upper()} Attention")
    plt.grid(True)
    plt.show()
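
`translate_with_attention` is imported from `utils` and not shown. One common way such a function works is greedy decoding: start from the start-of-sentence token, repeatedly feed back the highest-scoring token, and stop at end-of-sentence or a length limit, collecting one attention row per step. A minimal sketch, assuming a single-step decoder callable and hypothetical token ids (`greedy_decode` and `scripted_step` are illustrative names):

```python
import torch

SOS_TOKEN, EOS_TOKEN = 1, 2  # assumed ids


@torch.no_grad()
def greedy_decode(step_fn, s0, enc_outputs, max_length=10):
    """Greedy decoding: feed back the argmax token until EOS or max_length.

    step_fn(y_prev, s_prev, enc_outputs) -> (log_probs (1, V), s_new, alpha (1, T))
    can be any single-step decoder.
    """
    y_prev = torch.tensor([SOS_TOKEN])
    s = s0
    tokens, attn_rows = [], []
    for _ in range(max_length):
        log_probs, s, alpha = step_fn(y_prev, s, enc_outputs)
        y_prev = log_probs.argmax(dim=-1)      # most likely next token, shape (1,)
        attn_rows.append(alpha.squeeze(0))     # attention over the source for this step
        token = y_prev.item()
        if token == EOS_TOKEN:
            break
        tokens.append(token)
    return tokens, torch.stack(attn_rows)      # (steps, T) matrix for plotting


def scripted_step(y_prev, s, enc_outputs, script=(5, 6, EOS_TOKEN), vocab=8):
    """Toy decoder for illustration: the "state" is just a step counter."""
    t = int(s)
    log_probs = torch.full((1, vocab), -10.0)
    log_probs[0, script[t]] = 0.0
    T = enc_outputs.size(1)
    return log_probs, s + 1, torch.full((1, T), 1.0 / T)


tokens, attn = greedy_decode(scripted_step, torch.tensor(0), torch.randn(1, 4, 16))
print(tokens, attn.shape)  # [5, 6] torch.Size([3, 4])
```

The returned matrix has one row per generated step (including the step that produced EOS) and one column per source position, which is the shape an attention heatmap needs.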


## Conclusion

This notebook implemented **Bahdanau-style attention** for neural machine translation.

### Observations:
- **Concat attention** follows Bahdanau's original formulation. It produced fluent translations but trained more slowly.
- **Additive attention** (a more parameter-efficient variant) trained faster and generally provided similar or better attention focus.

Both models produced coherent translations, especially for shorter sequences.

---

### Key Takeaways
- Bahdanau attention computes the context vector from the *previous* decoder state, so the decoder can use the relevant encoder states when updating its hidden state and generating the next word.
- **Alignment models** (concat/additive) influence convergence speed and context representation.
- Attention heatmaps show which source words the decoder attends to at each output step, which makes misalignments easy to spot.
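
`plot_attention` comes from `utils` and is not shown. A minimal sketch of such a heatmap with Matplotlib, assuming the attention matrix has one row per output token and one column per source token; `plot_attention_sketch` is a hypothetical name and the weights below are toy values, not model output.

```python
import matplotlib.pyplot as plt
import numpy as np


def plot_attention_sketch(src_tokens, out_tokens, attn):
    """Heatmap of attention weights: one row per output token, one column per source token."""
    attn = np.asarray(attn)
    fig, ax = plt.subplots(figsize=(6, 5))
    im = ax.imshow(attn, cmap="viridis", aspect="auto", vmin=0.0, vmax=1.0)
    ax.set_xticks(list(range(len(src_tokens))))
    ax.set_xticklabels(src_tokens, rotation=90)
    ax.set_yticks(list(range(len(out_tokens))))
    ax.set_yticklabels(out_tokens)
    ax.set_xlabel("source")
    ax.set_ylabel("output")
    fig.colorbar(im, ax=ax)
    fig.tight_layout()
    return fig, ax


# Toy weights for illustration only (not model output): each row sums to 1.
src = "life is a journey .".split()
out = "la vie est un voyage .".split()
rng = np.random.default_rng(0)
weights = rng.random((len(out), len(src)))
weights /= weights.sum(axis=1, keepdims=True)
fig, ax = plot_attention_sketch(src, out, weights)
```

Fixing `vmin=0.0` and `vmax=1.0` keeps colors comparable across sentences and across the two alignment modes.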

