# Neural Machine Translation (English-to-German) using LSTMs & Attention

This project implements a Sequence-to-Sequence (Seq2Seq) model built from LSTMs and an attention mechanism to translate English sentences into German.


<a name="1"></a>
# Project Overview
The encoder and decoder use LSTM layers, whose gating mitigates the vanishing-gradient problem of plain Recurrent Neural Networks (RNNs). On top of that, an attention mechanism lets the decoder "focus" on the relevant parts of the input sentence at each generation step. Without it, the decoder would rely on a single fixed-size summary of the whole sentence, which becomes a bottleneck for long inputs.

The model is built from scratch using the Trax deep learning framework and trained on the Opus/Medical dataset.

<a name="0.1"></a>
#### Key Features:

- **Architecture:** LSTM encoder-decoder with multi-head scaled dot-product attention.
- **Framework:** Google Trax.
- **Decoding strategies:** greedy decoding, temperature sampling, and Minimum Bayes-Risk (MBR) decoding for improved translation quality.


<a name="1.1"></a>
## 1.1 Data Pipeline & Preprocessing

The dataset is the `opus/medical` corpus, which contains parallel English-German texts from the medical domain. The pipeline covers data ingestion, subword tokenization, length filtering, and bucketed batching.

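```python
# %pip installs into the environment of the running Jupyter kernel.
# Note: Trax 1.4.1 can fail to install on recent Python versions (e.g. 3.13);
# an older Python environment may be needed.
%pip install termcolor trax
```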

```python
import random
import numpy as np
from termcolor import colored

import trax
from trax import layers as tl
from trax.fastmath import numpy as fastnp
from trax.supervised import training

# 1.1 Data Ingestion
# Load the opus/medical English-German corpus via TensorFlow Datasets (TFDS);
# 1% of the data is held out for evaluation.
train_stream_fn = trax.data.TFDS('opus/medical',
                                 data_dir='./data/',
                                 keys=('en', 'de'),
                                 eval_holdout_size=0.01,
                                 train=True)

eval_stream_fn = trax.data.TFDS('opus/medical',
                                data_dir='./data/',
                                keys=('en', 'de'),
                                eval_holdout_size=0.01,
                                train=False)

# Calling the functions returns generators of (English, German) pairs
train_stream = train_stream_fn()
eval_stream = eval_stream_fn()

print(colored('Sample train data (en, de):', 'red'), next(train_stream))
```


<a name="1.2"></a>
## 1.2 Tokenization and Formatting
We use subword tokenization so that rare or unseen words can still be represented as sequences of known subword units. An End-of-Sentence (EOS) token is appended to every sequence. This teaches the model where a sentence ends and tells the decoder when to stop during inference.


```python
# Vocabulary constants
VOCAB_FILE = 'ende_32k.subword'  # subword vocabulary used for both languages
VOCAB_DIR = 'data/'
EOS = 1  # id of the end-of-sentence token

# Tokenization pipeline: text -> arrays of subword ids
tokenized_train_stream = trax.data.Tokenize(vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)(train_stream)
tokenized_eval_stream = trax.data.Tokenize(vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)(eval_stream)

# Helper to append the EOS token to both inputs and targets
def append_eos(stream):
    for (inputs, targets) in stream:
        inputs_with_eos = list(inputs) + [EOS]
        targets_with_eos = list(targets) + [EOS]
        yield np.array(inputs_with_eos), np.array(targets_with_eos)

tokenized_train_stream = append_eos(tokenized_train_stream)
tokenized_eval_stream = append_eos(tokenized_eval_stream)

# Drop pairs whose inputs or targets exceed 256 tokens (512 for eval) to limit memory usage
filtered_train_stream = trax.data.FilterByLength(max_length=256, length_keys=[0, 1])(tokenized_train_stream)
filtered_eval_stream = trax.data.FilterByLength(max_length=512, length_keys=[0, 1])(tokenized_eval_stream)
```
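Subword tokenization can be illustrated with a toy greedy longest-match segmenter. This is a minimal sketch under stated assumptions. `TOY_VOCAB`, `segment` and `encode_with_eos` are hypothetical names, and the vocabulary is made up. The real `ende_32k.subword` encoder is learned from data and also deals with spaces and escaping, which this sketch ignores. The example shows how an unseen word still maps to known pieces, and how the EOS id is appended as `append_eos` does.

```python
# Toy illustration only: the real ende_32k.subword encoder is learned from data
# and is not reproduced here.
EOS = 1  # same EOS id as in the notebook

# Hypothetical vocabulary: subword piece -> token id (0 is reserved for padding).
TOY_VOCAB = {"trans": 2, "lat": 3, "ion": 4, "s": 5, "t": 6, "r": 7, "a": 8,
             "n": 9, "l": 10, "i": 11, "o": 12, "e": 13}

def segment(word, vocab=TOY_VOCAB):
    """Splits `word` left to right, always taking the longest piece found in `vocab`."""
    pieces, start = [], 0
    while start < len(word):
        for end in range(len(word), start, -1):  # longest candidate first
            piece = word[start:end]
            if piece in vocab:
                pieces.append(piece)
                start = end
                break
        else:  # no piece matched, not even a single character
            raise ValueError(f"cannot segment {word!r} at position {start}")
    return pieces

def encode_with_eos(word, vocab=TOY_VOCAB):
    """Maps a word to subword ids and appends EOS, like append_eos does for sentences."""
    return [vocab[p] for p in segment(word, vocab)] + [EOS]

print(segment("translations"))          # ['trans', 'lat', 'ion', 's']
print(encode_with_eos("translations"))  # [2, 3, 4, 5, 1]
print(segment("rental"))                # falls back to single characters
```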

<a name="1.3"></a>
## 1.3 Helper Functions: Tokenize & Detokenize
These utilities convert between raw strings and token IDs. They are needed to feed sentences to the model and to turn its output back into text.

```python
def tokenize(input_str, vocab_file=None, vocab_dir=None):
    """Encodes a string to a (1, n) array of token ids ending with EOS."""
    EOS = 1
    inputs = next(trax.data.tokenize(iter([input_str]), vocab_file=vocab_file, vocab_dir=vocab_dir))
    inputs = list(inputs) + [EOS]
    # Add a batch dimension: shape (1, n)
    batch_inputs = np.reshape(np.array(inputs), [1, -1])
    return batch_inputs

def detokenize(integers, vocab_file=None, vocab_dir=None):
    """Decodes an array of token ids to a human-readable string, stopping at the first EOS."""
    # Flatten a (1, n) batch or a 1-D array into a plain list of ints
    integers = np.ravel(integers).tolist()
    EOS = 1
    if EOS in integers:
        integers = integers[:integers.index(EOS)]
    return trax.data.detokenize(integers, vocab_file=vocab_file, vocab_dir=vocab_dir)
```

<a name="1.4"></a>
## 1.4 Bucketing and Batching
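To speed up training, sequences of similar length are grouped (bucketed) together so that each batch needs little padding. Shorter buckets get larger batch sizes, so batches hold a comparable number of tokens.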

```python
# Bucket boundaries (in tokens) and the batch size for each bucket.
# batch_sizes needs one more entry than boundaries: the last one covers
# sequences longer than 512 tokens.
boundaries  = [8,   16,  32, 64, 128, 256, 512]
batch_sizes = [256, 128, 64, 32, 16,  8,   4,   2]

train_batch_stream = trax.data.BucketByLength(
    boundaries, batch_sizes,
    length_keys=[0, 1]
)(filtered_train_stream)

eval_batch_stream = trax.data.BucketByLength(
    boundaries, batch_sizes,
    length_keys=[0, 1]
)(filtered_eval_stream)

# Add loss weights: 0 for padding (id 0), 1 for real tokens
train_batch_stream = trax.data.AddLossWeights(id_to_mask=0)(train_batch_stream)
eval_batch_stream = trax.data.AddLossWeights(id_to_mask=0)(eval_batch_stream)

# Inspect one batch: (inputs, targets, loss weights)
input_batch, target_batch, mask_batch = next(train_batch_stream)
print("Input batch shape:", input_batch.shape)
print("Target batch shape:", target_batch.shape)
print("Mask batch shape:", mask_batch.shape)
```
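The bucketing step can be sketched in plain Python. This is a simplified illustration, not Trax's `BucketByLength` implementation, and the helper names are hypothetical. Two details are assumptions here: the exact rule at a bucket boundary, and the length Trax pads each batch to. The sketch shows how the longer side of each pair selects a bucket, how each bucket has its own batch size, and how padding ids turn into zero loss weights.

```python
import bisect

# Same values as in the notebook cell above.
BOUNDARIES = [8, 16, 32, 64, 128, 256, 512]
BATCH_SIZES = [256, 128, 64, 32, 16, 8, 4, 2]  # one more entry than BOUNDARIES

def bucket_index(example):
    """Bucket for an (inputs, targets) pair, keyed on the longer of the two.

    Assumption: a pair of length L goes to the first bucket whose boundary
    is >= L; anything longer than 512 goes to the last bucket.
    """
    length = max(len(example[0]), len(example[1]))
    return bisect.bisect_left(BOUNDARIES, length)

def pad_batch(seqs, pad_id=0):
    """Right-pads every sequence to the longest one in the batch (assumed padding rule)."""
    max_len = max(len(s) for s in seqs)
    return [list(s) + [pad_id] * (max_len - len(s)) for s in seqs]

def loss_weights(padded, id_to_mask=0):
    """1.0 for real tokens, 0.0 for padding, in the spirit of AddLossWeights(id_to_mask=0)."""
    return [[0.0 if tok == id_to_mask else 1.0 for tok in row] for row in padded]

pair = ([5, 6, 7, 1], [8, 9, 10, 11, 12, 13, 14, 15, 16, 1])  # lengths 4 and 10
idx = bucket_index(pair)
print(idx, BATCH_SIZES[idx])  # 1 128: the 9..16-token bucket

targets = pad_batch([[8, 9, 1], [4, 1]])
print(targets)                # [[8, 9, 1], [4, 1, 0]]
print(loss_weights(targets))  # [[1.0, 1.0, 1.0], [1.0, 1.0, 0.0]]
```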

<a name="2"></a>
# 2. Model Architecture
The model follows a standard Encoder-Decoder pattern augmented with Attention.


<a name="2.1"></a>
## 2.1 Input Encoder
The encoder converts input tokens into embeddings and processes them through LSTM layers to create a context-aware representation (Keys and Values for attention).

```python
def input_encoder_fn(input_vocab_size, d_model, n_encoder_layers):
    """Encodes input tokens into activations for attention Keys/Values."""
    input_encoder = tl.Serial(
        tl.Embedding(input_vocab_size, d_model),
        [tl.LSTM(d_model) for _ in range(n_encoder_layers)]
    )
    return input_encoder
```
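Each `tl.LSTM` layer carries a recurrent state, so the activation it emits at a position depends on all earlier tokens. That is what makes the encoder output context-aware. The toy below is a scalar LSTM (hidden size 1) with made-up weights `W`, a hypothetical stand-in for one `tl.LSTM(d_model)` layer. In the model, the layer works on vectors of size `d_model` with learned weights. The plain floats passed in play the role of embeddings. The sketch uses the textbook input/forget/output gate equations, and Trax's internal parameterization may differ in detail. The additive update `c = f * c + i * g` is the part of the design usually credited with easing the vanishing-gradient problem.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

# Hypothetical weights for hidden size 1: (input weight, recurrent weight, bias) per gate.
W = {
    "i": (0.5, 0.1, 0.0),   # input gate
    "f": (0.4, 0.2, 1.0),   # forget gate
    "o": (0.6, -0.1, 0.0),  # output gate
    "g": (1.0, 0.5, 0.0),   # candidate cell content
}

def lstm_step(x, h, c):
    """One LSTM step for a scalar input x, hidden state h and cell state c."""
    def pre(gate):
        w_x, w_h, b = W[gate]
        return w_x * x + w_h * h + b
    i, f, o = sigmoid(pre("i")), sigmoid(pre("f")), sigmoid(pre("o"))
    g = math.tanh(pre("g"))
    c = f * c + i * g      # keep part of the old memory, write new content
    h = o * math.tanh(c)   # expose a gated view of the memory
    return h, c

def encode(xs):
    """Runs the LSTM over a sequence and returns one hidden state per position."""
    h, c, outputs = 0.0, 0.0, []
    for x in xs:
        h, c = lstm_step(x, h, c)
        outputs.append(h)
    return outputs

# Same second and third inputs, different first input -> different later states.
print(encode([1.0, -1.0, 0.5]))
print(encode([0.0, -1.0, 0.5]))
```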

<a name="2.2"></a>
## 2.2 Pre-Attention Decoder
This component processes target tokens (shifted right for teacher forcing) to create the "Queries" for the attention mechanism.

```python
def pre_attention_decoder_fn(mode, target_vocab_size, d_model):
    """Decodes targets into activations for attention Queries."""
    pre_attention_decoder = tl.Serial(
        tl.ShiftRight(mode=mode),
        tl.Embedding(target_vocab_size, d_model),
        tl.LSTM(d_model)
    )
    return pre_attention_decoder
```
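Teacher forcing via `tl.ShiftRight` can be pictured on plain lists. In this minimal sketch, `shift_right` is a hypothetical helper that mimics what the layer is understood to do outside `'predict'` mode: insert a padding id at the front and drop the last token. Behaviour in `'predict'` mode is not modeled here. Because the decoder LSTM reads left to right, the activation at position `t` has seen only the targets before `t`.

```python
def shift_right(batch, pad_id=0):
    """Prepends pad_id to every sequence and drops its last token, keeping the length."""
    return [[pad_id] + list(seq[:-1]) for seq in batch]

targets = [[17, 42, 9, 1]]          # hypothetical target ids ending in EOS (1)
decoder_inputs = shift_right(targets)
print(decoder_inputs)               # [[0, 17, 42, 9]]

# Position t of the decoder input holds target token t-1, so the prediction
# for targets[t] can only use targets[:t].
for t, (seen, want) in enumerate(zip(decoder_inputs[0], targets[0])):
    print(f"step {t}: last input {seen} -> predict {want}")
```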

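<a name="2.3"></a>
## 2.3 Attention Input Preparation
This function arranges the encoder and decoder activations into the Query-Key-Value (QKV) format expected by the attention layer. The decoder activations become the queries, and the encoder activations serve as both keys and values. It also builds the mask that keeps attention away from padding positions.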


```python
def prepare_attention_input(encoder_activations, decoder_activations, inputs):
    """Prepares queries, keys, values, and mask for the Attention layer."""
    # Encoder activations act as keys and values; decoder activations as queries
    keys = encoder_activations
    values = encoder_activations
    queries = decoder_activations

    # Mask: True for real tokens, False for padding (id 0); shape (batch, input_len)
    mask = (inputs != 0)

    # Reshape to (batch, 1, 1, input_len): singleton axes for heads and decoder positions
    mask = fastnp.reshape(mask, (mask.shape[0], 1, 1, mask.shape[1]))

    # Broadcast to (batch, 1, decoder_len, input_len) by adding zeros
    mask = mask + fastnp.zeros((1, 1, decoder_activations.shape[1], 1))

    return queries, keys, values, mask
```
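The way the mask is consumed can be shown with a single-head attention sketch. It is a simplified, hypothetical stand-in for `tl.AttentionQKV`, and the helper names are invented. It omits multiple heads, dropout and the learned query/key/value projections that the Trax layer is expected to include. It builds a mask in the same (batch, 1, decoder_len, input_len) layout as `prepare_attention_input`, using nested lists instead of arrays. It also shows why padded encoder positions end up with zero attention weight.

```python
import math

def build_mask(inputs, decoder_len):
    """Mask [batch][1][decoder_len][input_len]: 1.0 for real tokens, 0.0 for padding (id 0)."""
    return [[[[1.0 if tok != 0 else 0.0 for tok in row] for _ in range(decoder_len)]]
            for row in inputs]

def softmax(scores):
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def masked_attention(queries, keys, values, mask_rows):
    """Single-head scaled dot-product attention for one example.

    queries: [decoder_len][d]; keys, values: [input_len][d];
    mask_rows: [decoder_len][input_len] with 1.0 for real tokens and 0.0 for padding.
    """
    d = len(keys[0])
    outputs = []
    for q, mask_row in zip(queries, mask_rows):
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
        # Padded positions get a huge negative score, hence (numerically) zero weight
        scores = [s if m else -1e9 for s, m in zip(scores, mask_row)]
        weights = softmax(scores)
        outputs.append([sum(w * v[j] for w, v in zip(weights, values))
                        for j in range(len(values[0]))])
    return outputs

inputs = [[4, 7, 0]]                      # one example; the last position is padding
mask = build_mask(inputs, decoder_len=2)  # shape [1][1][2][3]
keys = values = [[1.0, 0.0], [0.0, 1.0], [50.0, 50.0]]  # encoder activations
queries = [[0.0, 0.0], [2.0, 0.0]]        # decoder activations
out = masked_attention(queries, keys, values, mask[0][0])
print(out[0])  # [0.5, 0.5]: equal scores on the two real tokens, padding ignored
print(out[1])  # leans toward the first key, which matches the query better
```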

<a name="2.4"></a>
## 2.4 NMT Model Assembly
This function connects the encoder, the pre-attention decoder, the attention layer, and the post-attention LSTM decoder into the final Seq2Seq model.


```python
def NMTAttn(input_vocab_size=33300,
            target_vocab_size=33300,
            d_model=1024,
            n_encoder_layers=2,
            n_decoder_layers=2,
            n_attention_heads=4,
            attention_dropout=0.0,
            mode='train'):
    """Returns the full LSTM sequence-to-sequence model with attention."""

    # Build the two branches
    input_encoder = input_encoder_fn(input_vocab_size, d_model, n_encoder_layers)
    pre_attention_decoder = pre_attention_decoder_fn(mode, target_vocab_size, d_model)

    model = tl.Serial(
      # [inputs, targets] -> [inputs, targets, inputs, targets]
      tl.Select([0, 1, 0, 1]),

      # Encoder on inputs, pre-attention decoder on targets
      # -> [encoder_activations, decoder_activations, inputs, targets]
      tl.Parallel(input_encoder, pre_attention_decoder),

      # -> [queries, keys, values, mask, targets]
      tl.Fn('PrepareAttentionInput', prepare_attention_input, n_out=4),

      # Attention with a residual connection that adds the queries back
      # -> [activations, mask, targets]
      tl.Residual(tl.AttentionQKV(d_model, n_heads=n_attention_heads, dropout=attention_dropout, mode=mode)),

      # Drop the mask: -> [activations, targets]
      tl.Select([0, 2]),

      # Post-attention LSTM decoder layers
      [tl.LSTM(d_model) for _ in range(n_decoder_layers)],

      # Project to the target vocabulary and output log-probabilities
      tl.Dense(target_vocab_size),
      tl.LogSoftmax()
    )
    return model

# Build the model and print its layer structure
model = NMTAttn()
print(model)
```
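Trax layers communicate through a data stack: each layer pops as many items as it takes and pushes its outputs back. The sketch below is a plain-Python trace of how that stack evolves through `NMTAttn`. It uses hypothetical helper names and string labels in place of tensors, and it is not Trax code. It illustrates why the model ends with `[log-probabilities, targets]`, which keeps the targets next to the predictions for the loss computation.

```python
def run_serial(layers, stack):
    """Applies (name, n_in, fn) layers Trax-style: pop n_in items, push fn's outputs."""
    for name, n_in, fn in layers:
        args, rest = stack[:n_in], stack[n_in:]
        stack = list(fn(*args)) + rest
        print(f"{name:<24} {stack}")
    return stack

# Each lambda returns descriptive labels instead of tensors (hypothetical stand-ins).
NMT_LAYERS = [
    ("Select([0, 1, 0, 1])", 2, lambda inp, tgt: (inp, tgt, inp, tgt)),
    ("Parallel(enc, pre_dec)", 2, lambda inp, tgt: (f"enc({inp})", f"dec({tgt})")),
    ("PrepareAttentionInput", 3, lambda enc, dec, inp: (dec, enc, enc, f"mask({inp})")),
    # AttentionQKV maps (Q, K, V, mask) -> (activations, mask); Residual adds Q back.
    ("Residual(AttentionQKV)", 4, lambda q, k, v, m: (f"{q}+attn", m)),
    ("Select([0, 2])", 3, lambda act, m, tgt: (act, tgt)),
    ("LSTMs/Dense/LogSoftmax", 1, lambda act: (f"logprobs({act})",)),
]

final = run_serial(NMT_LAYERS, ["inputs", "targets"])
```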

<a name="3"></a>
# 3. Training Loop
The model is trained with the Adam optimizer and a learning-rate schedule that warms up over the first 1000 steps and then decays. `TrainTask` defines the training data, loss, optimizer, and schedule. `EvalTask` defines the evaluation data and metrics (`CrossEntropyLoss`, `Accuracy`).


```python
train_task = training.TrainTask(
    labeled_data=train_batch_stream,
    loss_layer=tl.CrossEntropyLoss(),
    optimizer=trax.optimizers.Adam(.01),
    # Warm up over 1000 steps to a peak learning rate of 0.01, then decay
    lr_schedule=trax.lr.warmup_and_rsqrt_decay(1000, .01),
    # Save a checkpoint every 10 steps
    n_steps_per_checkpoint=10,
)

eval_task = training.EvalTask(
    labeled_data=eval_batch_stream,
    metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],
)

# Checkpoints are written to this directory
output_dir = 'output_dir/'

training_loop = training.Loop(NMTAttn(mode='train'),
                              train_task,
                              eval_tasks=[eval_task],
                              output_dir=output_dir)

# Run 10 steps as a demonstration; usable translations need far longer training
training_loop.run(10)
```
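The learning-rate schedule passed to `TrainTask` first ramps up and then decays. The function below is a hypothetical sketch of that warm-up-then-inverse-square-root shape, using the notebook's arguments (1000 warm-up steps, peak 0.01). It is an assumption about the general form, not a copy of `trax.lr.warmup_and_rsqrt_decay`. A common motivation for warm-up is to avoid large updates while Adam's moment estimates are still unreliable. Under this shape, the 10-step demonstration run never leaves the warm-up phase: the learning rate at step 10 is 1e-4.

```python
import math

def warmup_rsqrt(step, n_warmup_steps=1000, max_value=0.01):
    """Linear warm-up to max_value, then decay proportional to 1/sqrt(step).

    Assumed shape only; Trax's warmup_and_rsqrt_decay may differ in details.
    """
    if step <= 0:
        return 0.0
    return max_value * min(step / n_warmup_steps, math.sqrt(n_warmup_steps / step))

for step in (10, 500, 1000, 4000, 16000):
    print(step, warmup_rsqrt(step))
```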

<a name="4"></a>
# 4. Inference & Decoding Strategies
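To generate translations, we implement decoding logic on top of the model. Ten training steps are far too few for usable translations, so the inference code loads pre-trained weights from `model.pkl.gz`.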

<a name="4.1"></a>
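## 4.1 Greedy Decoding & Sampling
Given the tokenized input and the tokens generated so far, `next_symbol` runs the model and returns the next token together with its log-probability. `sampling_decode` calls it repeatedly until EOS is produced. The `temperature` parameter controls how each token is chosen: 0.0 means greedy (argmax), and values above 0.0 sample stochastically, with more randomness as the value grows.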

```python
# Load pre-trained weights for evaluation
model = NMTAttn(mode='eval')
model.init_from_file("model.pkl.gz", weights_only=True)
model = tl.Accelerate(model)

def next_symbol(NMTAttn, input_tokens, cur_output_tokens, temperature):
    """Returns the index of the next token and its log-probability."""
    token_length = len(cur_output_tokens)
    # Pad the prefix to the next power of two (>= token_length + 1) so the model
    # sees only a few distinct input shapes (presumably to limit recompilation)
    padded_length = 2**int(np.ceil(np.log2(token_length + 1)))
    padded = cur_output_tokens + [0] * (padded_length - token_length)
    padded_with_batch = np.expand_dims(padded, axis=0)

    # The model returns (log_probs, targets); only the log-probs are needed
    output, _ = NMTAttn((input_tokens, padded_with_batch))
    # Because of ShiftRight, position token_length predicts the next token
    log_probs = output[0, token_length, :]
    symbol = int(tl.logsoftmax_sample(log_probs, temperature))

    return symbol, float(log_probs[symbol])

def sampling_decode(input_sentence, NMTAttn=None, temperature=0.0, vocab_file=None, vocab_dir=None):
    """Translates a sentence; returns (token ids, total log-probability, text)."""
    input_tokens = tokenize(input_sentence, vocab_file, vocab_dir)
    cur_output_tokens = []
    cur_output = 0
    total_log_prob = 0.0
    EOS = 1

    while cur_output != EOS:
        cur_output, log_prob = next_symbol(NMTAttn, input_tokens, cur_output_tokens, temperature)
        cur_output_tokens.append(cur_output)
        total_log_prob += log_prob  # log-probability of the whole sequence so far

    sentence = detokenize(cur_output_tokens, vocab_file, vocab_dir)
    return cur_output_tokens, total_log_prob, sentence

# Test greedy decoding
print("Translation:", sampling_decode("I love languages.", model, temperature=0.0, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)[2])
```
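Two details of the decoding helpers can be reproduced without the model. First, the generated prefix is padded to a power-of-two length. Second, the next token is read at index `len(prefix)`, because `ShiftRight` makes that position see exactly the prefix. The sketch below uses hypothetical helper names. `padded_length` computes the same value as `2**int(np.ceil(np.log2(token_length + 1)))`, but with integer arithmetic. `sample_from_log_probs` is a plain-Python stand-in for `tl.logsoftmax_sample`: temperature 0 gives argmax, and otherwise it samples from `softmax(log_probs / temperature)`. Trax may implement the sampling differently internally.

```python
import math
import random

def padded_length(n_tokens):
    """Smallest power of two >= n_tokens + 1, i.e. 2**ceil(log2(n_tokens + 1))."""
    return 1 << n_tokens.bit_length()

def pad_prefix(tokens, pad_id=0):
    """Pads the generated prefix so the model always sees a power-of-two length."""
    return tokens + [pad_id] * (padded_length(len(tokens)) - len(tokens))

def sample_from_log_probs(log_probs, temperature, rng=random):
    """temperature == 0: argmax (greedy); otherwise sample from softmax(log_probs / T)."""
    if temperature == 0.0:
        return max(range(len(log_probs)), key=lambda i: log_probs[i])
    scaled = [lp / temperature for lp in log_probs]
    top = max(scaled)
    weights = [math.exp(s - top) for s in scaled]  # unnormalized probabilities
    return rng.choices(range(len(log_probs)), weights=weights, k=1)[0]

prefix = [5, 9, 3]                 # hypothetical tokens generated so far
print(pad_prefix(prefix))          # [5, 9, 3, 0]; the next token is read at index 3
log_probs = [math.log(p) for p in (0.1, 0.7, 0.2)]
print(sample_from_log_probs(log_probs, 0.0))                    # 1 (greedy)
print(sample_from_log_probs(log_probs, 1.0, random.Random(0)))  # 0, 1 or 2
```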

<a name="4.2"></a>
## 4.2 Minimum Bayes-Risk (MBR) Decoding
Greedy decoding commits to the most likely token at each step, which does not guarantee the best overall translation. MBR decoding instead works in three steps:

1. Generate several candidate translations by sampling.
2. Score each candidate against all the others with a similarity metric (e.g., Jaccard similarity or ROUGE-1), weighting each comparison by the other sample's probability.
3. Select the "consensus" candidate, i.e. the one with the highest weighted average similarity to the others.
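The three steps can be traced on a toy example. The sketch uses hypothetical names (`mbr_pick`, `candidates`) and made-up token ids with equal probabilities, so the weighting reduces to a plain average. It follows the same probability-weighted averaging idea as the notebook's `weighted_avg_overlap`, but it is a standalone illustration rather than the notebook code. The outlier candidate shares little with the rest and scores lowest, while the candidate closest to the others wins.

```python
import math

def jaccard(a, b):
    """Jaccard similarity of two token lists, compared as sets."""
    sa, sb = set(a), set(b)
    return len(sa & sb) / len(sa | sb)

def mbr_pick(candidates, log_probs, similarity=jaccard):
    """Index and score of the candidate with the highest probability-weighted
    average similarity to all other candidates."""
    best_index, best_score = None, float("-inf")
    for i, cand in enumerate(candidates):
        total, weight_sum = 0.0, 0.0
        for j, (other, logp) in enumerate(zip(candidates, log_probs)):
            if i == j:
                continue  # a candidate is not compared with itself
            weight = math.exp(logp)
            total += weight * similarity(cand, other)
            weight_sum += weight
        score = total / weight_sum
        if score > best_score:
            best_index, best_score = i, score
    return best_index, best_score

# Hypothetical sampled translations as token ids (1 = EOS), all equally likely.
candidates = [
    [10, 11, 12, 1],
    [10, 11, 13, 1],
    [10, 11, 12, 14, 1],
    [20, 21, 1],          # an outlier that shares only EOS with the others
]
log_probs = [math.log(0.25)] * 4
print(mbr_pick(candidates, log_probs))  # index 0, score about 0.522
```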

```python
from collections import Counter

def generate_samples(sentence, n_samples, NMTAttn=None, temperature=0.6, vocab_file=None, vocab_dir=None):
    """Generates n_samples sampled translations and their log-probabilities."""
    samples, log_probs = [], []
    for _ in range(n_samples):
        sample, logp, _ = sampling_decode(sentence, NMTAttn, temperature, vocab_file=vocab_file, vocab_dir=vocab_dir)
        samples.append(sample)
        log_probs.append(logp)
    return samples, log_probs

def jaccard_similarity(candidate, reference):
    """Jaccard index (intersection over union) of the two token sets."""
    can_unigram_set, ref_unigram_set = set(candidate), set(reference)
    joint_elems = can_unigram_set.intersection(ref_unigram_set)
    all_elems = can_unigram_set.union(ref_unigram_set)
    return len(joint_elems) / len(all_elems)

def rouge1_similarity(system, reference):
    """ROUGE-1 F1 score based on clipped unigram counts."""
    sys_counter = Counter(system)
    ref_counter = Counter(reference)
    overlap = 0
    for token in sys_counter:
        # A token counts at most as often as it appears in both sequences
        overlap += min(sys_counter[token], ref_counter[token])

    precision = overlap / sum(sys_counter.values())
    recall = overlap / sum(ref_counter.values())

    if precision + recall != 0:
        return 2 * ((precision * recall) / (precision + recall))
    return 0

def weighted_avg_overlap(similarity_fn, samples, log_probs):
    """Scores each candidate by its probability-weighted mean similarity to the other samples."""
    scores = {}
    for index_candidate, candidate in enumerate(samples):
        overlap, weight_sum = 0.0, 0.0
        for index_sample, (sample, logp) in enumerate(zip(samples, log_probs)):
            if index_candidate == index_sample:
                continue  # do not compare a candidate with itself
            sample_p = float(np.exp(logp))
            weight_sum += sample_p
            overlap += sample_p * similarity_fn(candidate, sample)
        # Guard against an empty comparison set (e.g. n_samples == 1)
        scores[index_candidate] = overlap / weight_sum if weight_sum > 0 else 0.0
    return scores

def mbr_decode(sentence, n_samples, score_fn, similarity_fn, NMTAttn=None, temperature=0.6, vocab_file=None, vocab_dir=None):
    """Performs Minimum Bayes-Risk decoding; returns (translation, best index, scores)."""
    samples, log_probs = generate_samples(sentence, n_samples, NMTAttn, temperature, vocab_file, vocab_dir)
    # Score every candidate with the supplied scoring function
    scores = score_fn(similarity_fn, samples, log_probs)
    max_index = max(scores, key=scores.get)
    translated_sentence = detokenize(samples[max_index], vocab_file, vocab_dir)
    return translated_sentence, max_index, scores

# Final test
text = 'She speaks English and German.'
print(f"Input: {text}")
translation = mbr_decode(text, 4, weighted_avg_overlap, jaccard_similarity, model, 0.6, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR)[0]
print(f"MBR Translation: {translation}")
```
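The two similarity functions above can rank candidates differently, because ROUGE-1 counts repeated tokens while Jaccard compares sets. Below is a small standalone comparison with hypothetical names and token ids. `rouge1_f1` computes the same clipped-count ROUGE-1 F1 as `rouge1_similarity`, using `Counter` intersection to take the minimum count per shared token.

```python
from collections import Counter

def jaccard(candidate, reference):
    """Set overlap: repeated tokens count once."""
    a, b = set(candidate), set(reference)
    return len(a & b) / len(a | b)

def rouge1_f1(system, reference):
    """ROUGE-1 F1 on token counts: repeated tokens count as often as they match."""
    sys_counts, ref_counts = Counter(system), Counter(reference)
    overlap = sum((sys_counts & ref_counts).values())  # min count per shared token
    if overlap == 0:
        return 0.0
    precision = overlap / sum(sys_counts.values())
    recall = overlap / sum(ref_counts.values())
    return 2 * precision * recall / (precision + recall)

system, reference = [1, 2, 2, 3], [2, 2, 4]
print(jaccard(system, reference))    # 0.25: shared {2} out of {1, 2, 3, 4}
print(rouge1_f1(system, reference))  # about 0.571 (= 4/7), from P = 2/4 and R = 2/3
```

Passing `rouge1_similarity` instead of `jaccard_similarity` as the `similarity_fn` argument of `mbr_decode` switches the metric used for the consensus vote.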