**PART 3**

**ADVANCED DEEP NETWORKS FOR COMPLEX PROBLEMS**

---

**CHAPTER 12 - Sequence-to-sequence learning: Part 2**

---

In the previous chapter, we built an English-to-German machine translator using a standard Encoder-Decoder architecture. We used teacher forcing for training and evaluated it using the BLEU score. Finally, we repurposed the model for inference using a recursive decoder.

In this chapter, we improve the model's accuracy by implementing the **attention mechanism**. Attention lets the decoder access rich representations from every time step of the input sequence, rather than relying solely on the final context vector. We will also visualize the learned attention weights to gain insight into the model's behavior.

### **12.1 Eyeballing the past: Improving our model with attention**

Standard seq2seq models suffer from a bottleneck: the encoder must compress the entire input sequence into a single fixed-size context vector. This is often insufficient for long sequences.

**Bahdanau attention** addresses this by letting the decoder look at *all* of the encoder's outputs at every time step. At each decoding step, the model computes a weighted sum of the encoder outputs (the context vector), where each weight reflects how relevant that input position is to the output word currently being generated.

For each decoding step $i$, the process involves:
1.  **Energy computation**: A small fully connected network with weights $W_a$, $U_a$, and $v_a$ scores how well the previous decoder state $s_{i-1}$ matches each encoder output $h_j$: $e_{ij} = v_a^\top \tanh(W_a h_j + U_a s_{i-1})$.
2.  **Normalization**: A softmax over the input positions converts the energies into attention weights: $\alpha_{ij} = \exp(e_{ij}) / \sum_k \exp(e_{ik})$.
3.  **Context vector**: The weighted sum of the encoder outputs, $c_i = \sum_j \alpha_{ij} h_j$, is passed to the decoder together with its input for that step.
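The three steps above can be traced in NumPy for a single decoder step. This is a minimal sketch, not the layer used later: it assumes a row-vector convention (so $W_a h_j$ becomes `h @ W_a`, matching the `K.dot` calls in the layer) and assumes the encoder outputs and the decoder state share the same width `d`. The helper name `bahdanau_attention_step` is hypothetical, and the weights are random rather than trained.

```python
import numpy as np

def bahdanau_attention_step(encoder_outputs, prev_state, W_a, U_a, v_a):
    """One additive (Bahdanau-style) attention step for a batch.

    encoder_outputs: (batch, en_len, d), the h_j for every input position
    prev_state:      (batch, d), the decoder state s_{i-1}
    W_a, U_a:        (d, d); v_a: (d,)
    Returns the context vector (batch, d) and attention weights (batch, en_len).
    """
    # 1. Energies e_ij = v_a . tanh(W_a h_j + U_a s_{i-1}); the state term is
    #    broadcast across all input positions
    hidden = np.tanh(encoder_outputs @ W_a + (prev_state @ U_a)[:, None, :])
    energies = hidden @ v_a                                  # (batch, en_len)
    # 2. Softmax over input positions (shifted by the max for numerical stability)
    shifted = energies - energies.max(axis=1, keepdims=True)
    alpha = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
    # 3. Context vector: attention-weighted sum of the encoder outputs
    context = (encoder_outputs * alpha[:, :, None]).sum(axis=1)
    return context, alpha

rng = np.random.default_rng(0)
batch, en_len, d = 2, 5, 4
h = rng.normal(size=(batch, en_len, d))
s_prev = rng.normal(size=(batch, d))
context, alpha = bahdanau_attention_step(
    h, s_prev, rng.normal(size=(d, d)), rng.normal(size=(d, d)), rng.normal(size=d))
print(context.shape, alpha.shape)  # (2, 4) (2, 5)
print(alpha.sum(axis=1))           # each row sums to 1
```

Because the decoder-state term has a singleton position axis, one matrix expression scores every input position at once instead of looping over them.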

#### **12.1.1 Implementing Bahdanau attention in TensorFlow**

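We implement a custom layer, `DecoderRNNAttentionWrapper`, using Keras subclassing. Keras does ship an `AdditiveAttention` layer (Bahdanau-style scoring), but it expects a precomputed query sequence; it cannot compute attention from the decoder's hidden state as that state evolves step by step inside the recurrence, which is what this architecture needs.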

* **`__init__`**: Initializes the layer with a standard RNN cell (e.g., `GRUCell`).
* **`build`**: Defines the trainable weight matrices $W_a$, $U_a$, and $V_a$.
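* **`call`**: Uses `K.rnn` to iterate over the decoder inputs, passing the full encoder output sequence as a constant. Its `_step` function computes the attention energies, the attention weights, and the context vector for the current step, then feeds the decoder input concatenated with the context vector to the RNN cell.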
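The `call` method leans on how `K.rnn` drives a step function. In TensorFlow 2.x, `K.rnn` calls the step function once per time step with the carried states followed by any `constants`, and returns a three-element tuple `(last_output, outputs, new_states)` with the per-step outputs stacked along the time axis. The pure-NumPy loop below is a simplified stand-in written only to illustrate that contract; `toy_rnn` and `running_sum_step` are hypothetical names, and this is not TensorFlow's implementation. It shows why `_step` can read the encoder outputs from `states[-1]` and return a tuple `(s, a_i)` as its output.

```python
import numpy as np

def toy_rnn(step_function, inputs, initial_states, constants=()):
    """Simplified NumPy stand-in for the K.rnn time loop (illustration only).

    inputs: array of shape (batch, time, features).
    step_function(x_t, states) must return (output, new_states); the `states`
    it receives are the carried states followed by `constants`.
    Returns (last_output, outputs, new_states); outputs are stacked along
    axis 1, component by component when the step returns a tuple.
    """
    states = list(initial_states)
    per_step = []
    for t in range(inputs.shape[1]):
        output, states = step_function(inputs[:, t], list(states) + list(constants))
        per_step.append(output)
    if isinstance(per_step[0], tuple):
        # Nested outputs: stack each component over time separately
        outputs = tuple(np.stack(parts, axis=1) for parts in zip(*per_step))
        last_output = tuple(o[:, -1] for o in outputs)
    else:
        outputs = np.stack(per_step, axis=1)
        last_output = outputs[:, -1]
    return last_output, outputs, list(states)

def running_sum_step(x_t, states):
    """Hypothetical step: carries a running sum; the constant sits in states[-1]."""
    prev_sum, offset = states[0], states[-1]
    new_sum = prev_sum + x_t
    # Nested output, like `_step` returning (s, a_i); only the carried state is returned
    return (new_sum, new_sum + offset), [new_sum]

x = np.ones((2, 3, 1))  # batch of 2, 3 time steps, 1 feature
_, (sums, shifted), final = toy_rnn(
    running_sum_step, x, initial_states=[np.zeros((2, 1))],
    constants=[np.full((2, 1), 10.0)])
print(sums[0, :, 0])     # [1. 2. 3.]
print(shifted[0, :, 0])  # [11. 12. 13.]
```

Note that the function returns three values, so a caller that only wants the stacked per-step outputs must unpack all three.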

```python
import tensorflow as tf
import tensorflow.keras.backend as K

class DecoderRNNAttentionWrapper(tf.keras.layers.Layer):
    def __init__(self, cell_fn, units, **kwargs):
        super(DecoderRNNAttentionWrapper, self).__init__(**kwargs)
        self.cell_fn = cell_fn  # RNN cell run at every decoder step (e.g., GRUCell)
        self.units = units

    def build(self, input_shape):
        # input_shape[0] is the encoder output shape: (batch, en_len, en_dim)
        self.W_a = self.add_weight(name='W_a',
            shape=tf.TensorShape((input_shape[0][2], input_shape[0][2])),
            initializer='uniform', trainable=True)
        self.U_a = self.add_weight(name='U_a',
            shape=tf.TensorShape((self.cell_fn.units, self.cell_fn.units)),
            initializer='uniform', trainable=True)
        self.V_a = self.add_weight(name='V_a',
            shape=tf.TensorShape((input_shape[0][2], 1)),
            initializer='uniform', trainable=True)
        super(DecoderRNNAttentionWrapper, self).build(input_shape)

    def call(self, inputs, initial_state, training=False):
        encoder_outputs, decoder_inputs = inputs

        def _step(inputs, states):
            # Step function for a single decoder time step.
            # K.rnn appends the constants to the states, so states is
            # [previous decoder state, full encoder output sequence]
            encoder_full_seq = states[-1]

            # Compute energy scores: tanh(h.W_a + s.U_a)
            # (the sum requires en_dim == cell units)
            W_a_dot_h = K.dot(encoder_full_seq, self.W_a)
            U_a_dot_s = K.expand_dims(K.dot(states[0], self.U_a), 1)
            Wh_plus_Us = K.tanh(W_a_dot_h + U_a_dot_s)

            # Energies (batch, en_len) and attention weights (alpha)
            e_i = K.squeeze(K.dot(Wh_plus_Us, self.V_a), axis=-1)
            a_i = K.softmax(e_i)

            # Compute weighted sum (context vector)
            c_i = K.sum(encoder_full_seq * K.expand_dims(a_i, -1), axis=1)

            # Concatenate input and context vector, then pass them to the RNN
            # cell together with the recurrent state only (not the constant)
            s, new_states = self.cell_fn(
                K.concatenate([inputs, c_i], axis=-1), [states[0]])
            return (s, a_i), new_states

        # K.rnn returns (last_output, outputs, new_states); keep the per-step
        # outputs, which are stacked along the time axis
        _, attn_outputs, _ = K.rnn(
            step_function=_step, inputs=decoder_inputs,
            initial_states=[initial_state], constants=[encoder_outputs]
        )

        # attn_out: (batch, de_len, cell units); attn_weights: (batch, de_len, en_len)
        attn_out, attn_weights = attn_outputs
        return attn_out, attn_weights
```

#### **12.1.2 Defining the final model**

We integrate the attention wrapper into the final seq2seq model.

* **Encoder**: Remains a Bidirectional GRU.
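* **Decoder**: Now uses `DecoderRNNAttentionWrapper` around a `GRUCell`. Unlike the standard decoder, it receives the encoder's **full** output sequence (`en_states`, one vector per input time step) to compute attention scores, while its initial state is still the concatenation of the final forward and backward encoder states.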

The helper functions `get_vectorizer` and `get_encoder` (not shown here) are defined much as in the previous chapter.

```python
import tensorflow as tf

def get_final_seq2seq_model_with_attention(n_vocab, encoder, vectorizer):
    """ Define the final encoder-decoder model with attention """
    e_inp = tf.keras.Input(shape=(1,), dtype=tf.string, name='e_input_final')
    # The encoder is assumed to return its final forward state, final backward
    # state, and full output sequence (every time step), in that order
    fwd_state, bwd_state, en_states = encoder(e_inp)

    d_inp = tf.keras.Input(shape=(1,), dtype=tf.string, name='d_input')
    d_vectorized_out = vectorizer(d_inp)

    d_emb_layer = tf.keras.layers.Embedding(
        n_vocab+2, 128, mask_zero=True, name='d_embedding'
    )
    d_emb_out = d_emb_layer(d_vectorized_out)

    # Initial decoder state: concatenated final forward/backward encoder states
    d_init_state = tf.keras.layers.Concatenate(axis=-1)([fwd_state, bwd_state])

    # Attention mechanism; `units` is stored by the wrapper but not used
    gru_cell = tf.keras.layers.GRUCell(256)
    attn_out, _ = DecoderRNNAttentionWrapper(
        cell_fn=gru_cell, units=512, name="d_attention"
    )([en_states, d_emb_out], initial_state=d_init_state)

    d_dense_layer_1 = tf.keras.layers.Dense(512, activation='relu', name='d_dense_1')
    d_dense1_out = d_dense_layer_1(attn_out)

    d_final_layer = tf.keras.layers.Dense(
        n_vocab+2, activation='softmax', name='d_dense_final'
    )
    d_final_out = d_final_layer(d_dense1_out)

    seq2seq = tf.keras.models.Model(
        inputs=[e_inp, d_inp], outputs=d_final_out,
        name='final_seq2seq_with_attention'
    )
    return seq2seq
```
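The layer sizes in this model are not independent. `d_init_state` concatenates the forward and backward states, so its width must equal the `GRUCell`'s 256 units. Inside the attention layer, `W_a_dot_h` (as wide as `en_states`) is added to `U_a_dot_s` (as wide as the cell), so the encoder's output width must also be 256, i.e., 128 units per direction if the bidirectional layer uses Keras's default `merge_mode='concat'`. The NumPy sketch below reproduces that addition with zero-filled arrays to show when the shapes line up; `attention_sum_shape` is a hypothetical helper, not part of the model.

```python
import numpy as np

def attention_sum_shape(batch, en_len, en_dim, cell_units):
    """Shape of W_a.h + U_a.s from the attention layer, using zero-filled arrays."""
    h = np.zeros((batch, en_len, en_dim))      # en_states: every encoder time step
    s = np.zeros((batch, cell_units))          # previous decoder (GRU cell) state
    W_a = np.zeros((en_dim, en_dim))
    U_a = np.zeros((cell_units, cell_units))
    W_a_dot_h = h @ W_a                        # (batch, en_len, en_dim)
    U_a_dot_s = (s @ U_a)[:, None, :]          # (batch, 1, cell_units)
    return (W_a_dot_h + U_a_dot_s).shape       # broadcasts only if en_dim == cell_units

print(attention_sum_shape(32, 20, 256, 256))   # (32, 20, 256)
try:
    attention_sum_shape(32, 20, 256, 512)
except ValueError as err:
    print("Mismatch:", err)
```

If you change the encoder or decoder size, change the other to match; otherwise the layer fails when it first builds the graph.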

#### **12.1.3 Training the model**

The training process is identical to the one used for the standard seq2seq model. We use the custom training loop `train_model` (defined in Chapter 11), which trains with teacher forcing and evaluates the model with the BLEU score.

With the same training setup, the attention-based model reaches a markedly higher BLEU score than the standard model (roughly double), which shows how much the decoder gains from direct access to every encoder output.

```python
# Example training call (assuming data and helper functions are prepared)
# epochs = 5
# batch_size = 128
# train_model(final_model_with_attention, de_vectorizer, train_df, valid_df,
#             test_df, epochs, batch_size)
```
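Teacher forcing means the decoder is fed the ground-truth previous token at each step rather than its own prediction, so the decoder input and the training target are the same German sentence offset by one position. A minimal sketch, assuming the target text is wrapped in start and end markers; the `sos`/`eos` token names and the `make_teacher_forcing_pair` helper are hypothetical here, and the Chapter 11 pipeline may prepare its data differently. The decoder input is returned as a string because `d_inp` in the model above is a raw string that the vectorizer tokenizes.

```python
def make_teacher_forcing_pair(target_text, start_token="sos", end_token="eos"):
    """Split one target sentence into a decoder input and a next-token target.

    At step t the decoder sees token t of `decoder_input` and is trained to
    predict token t of `decoder_target`, i.e. the next ground-truth token.
    """
    tokens = [start_token] + target_text.split() + [end_token]
    decoder_input = " ".join(tokens[:-1])   # drop the end marker
    decoder_target = tokens[1:]             # drop the start marker
    return decoder_input, decoder_target

dec_in, dec_target = make_teacher_forcing_pair("ich bin müde")
print(dec_in)       # sos ich bin müde
print(dec_target)   # ['ich', 'bin', 'müde', 'eos']
```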

### **12.2 Visualizing the attention**

One of the major advantages of attention is **interpretability**. The attention weights ($\alpha$, the softmax-normalized energies) show which words in the source sentence the model weighted most heavily when generating each word of the target sentence.

To visualize this, we define a function, `attention_visualizer`, that loads the trained model and rebuilds it so that it returns the attention weights (`attn_states`) alongside the standard predictions. We then use `matplotlib` to plot these weights as a heatmap.

![Figure 12.1 Attention patterns visualized for an input English text](./12.Chapter-12/Figure12-1.jpg)

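A strong diagonal pattern in the heatmap indicates that the model has learned a largely monotonic alignment between the two languages (e.g., the first English word corresponds to the first German word). Departures from the diagonal are expected where word order differs, for instance because German places the verb at the end of subordinate clauses.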
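One way to read such a heatmap numerically is to take, for each target token, the source position with the largest weight (a "hard" alignment) and check whether those positions increase monotonically. The sketch below does this in NumPy; `hard_alignment` is a hypothetical helper, and the 3×3 weight matrix and token lists are hand-made toy values, not output from the trained model.

```python
import numpy as np

def hard_alignment(attention, src_tokens, tgt_tokens):
    """Pair each target token with the source token that receives the most attention.

    attention: (tgt_len, src_len) matrix whose rows are attention weights.
    Returns the (target, source) pairs and whether the alignment is monotonic.
    """
    best_src = attention.argmax(axis=1)                 # one source index per target row
    pairs = [(tgt, src_tokens[j]) for tgt, j in zip(tgt_tokens, best_src)]
    monotonic = bool(np.all(np.diff(best_src) >= 0))   # True for a diagonal-like pattern
    return pairs, monotonic

# Toy, hand-made weights: each row sums to 1
attention = np.array([[0.8, 0.1, 0.1],
                      [0.1, 0.7, 0.2],
                      [0.1, 0.2, 0.7]])
pairs, monotonic = hard_alignment(attention, ["i", "am", "tired"], ["ich", "bin", "müde"])
print(pairs)      # [('ich', 'i'), ('bin', 'am'), ('müde', 'tired')]
print(monotonic)  # True
```

A non-monotonic result is not necessarily an error; it may reflect a genuine reordering between English and German.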

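```python
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

def attention_visualizer(save_path):
    """ Define the attention visualizer model """
    model = tf.keras.models.load_model(save_path)
    # ... (Code to retrace layers and capture attention states) ...
    # Returns a model that outputs [d_final_out, attn_states, e_vec_out, d_vec_out]
    return visualizer_model

def visualize_attention(visualizer_model, en_vocabulary, de_vocabulary,
                        sample_en_text, sample_de_text, fig_savepath):
    """ Visualize the attention patterns """
    d_pred, attention_weights, e_out, d_out = visualizer_model.predict(
        [np.array([sample_en_text]), np.array([sample_de_text])]
    )

    # Take the single sample in the batch:
    # attention_weights -> (de_len, en_len); e_ids -> (en_len,); d_ids -> (de_len,)
    attention_weights = attention_weights[0]
    e_ids, d_ids = e_out[0], d_out[0]

    # Filtering: drop padded positions (token ID 0) from both axes
    e_keep, d_keep = e_ids > 0, d_ids > 0
    attention_weights_filtered = attention_weights[d_keep][:, e_keep]

    # Assumes each vocabulary is a list indexed by token ID
    # (e.g., the output of TextVectorization.get_vocabulary())
    x_ticklabels = [en_vocabulary[i] for i in e_ids[e_keep]]
    y_ticklabels = [de_vocabulary[i] for i in d_ids[d_keep]]

    fig, ax = plt.subplots(figsize=(14, 14))
    im = ax.imshow(attention_weights_filtered)

    # Fix tick positions before setting labels so each label sits on its cell
    ax.set_xticks(np.arange(len(x_ticklabels)))
    ax.set_yticks(np.arange(len(y_ticklabels)))
    ax.set_xticklabels(x_ticklabels, rotation=90)
    ax.set_yticklabels(y_ticklabels)

    plt.colorbar(im)
    plt.savefig(fig_savepath)
```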