Implement a basic additive (Bahdanau-style) attention mechanism in TensorFlow 2 that computes weighted context vectors for sequence-to-sequence tasks (e.g., translation or summarization).

In [1]:
import tensorflow as tf
import numpy as np

In [2]:
# Simulated encoder outputs (batch_size, time_steps, hidden_size) and the
# decoder's current hidden state (batch_size, hidden_size).
BATCH_SIZE, TIME_STEPS, HIDDEN_SIZE = 1, 4, 8
SEED = 42

# Seed TensorFlow's global RNG so the simulated tensors (and therefore the
# printed attention weights) are reproducible on Restart & Run All.
tf.random.set_seed(SEED)
encoder_outputs = tf.random.normal([BATCH_SIZE, TIME_STEPS, HIDDEN_SIZE])
decoder_hidden_state = tf.random.normal([BATCH_SIZE, HIDDEN_SIZE])

In [4]:
# Define a basic additive (Bahdanau-style) attention layer
class BasicAttention(tf.keras.layers.Layer):
    """Additive attention that scores each encoder time step against a decoder state.

    For encoder outputs ``h_t`` and decoder state ``s``:

        score_t = V(tanh(W1(h_t) + W2(s)))
        weights = softmax(score) over the time axis
        context = sum_t weights_t * h_t

    Args:
        units: Output width of the ``W1`` / ``W2`` projections inside the
            scoring MLP. Defaults to 8, the value this layer originally
            hard-coded. It is independent of the encoder/decoder hidden size.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, units=8, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.W1 = tf.keras.layers.Dense(units)  # projects encoder outputs
        self.W2 = tf.keras.layers.Dense(units)  # projects the decoder state
        self.V = tf.keras.layers.Dense(1)       # collapses to one score per time step

    def call(self, encoder_outputs, decoder_hidden, mask=None):
        """Compute the attention-weighted context vector.

        Args:
            encoder_outputs: Tensor shaped ``(batch, time_steps, enc_dim)``.
            decoder_hidden: Tensor shaped ``(batch, dec_dim)``.
            mask: Optional tensor shaped ``(batch, time_steps)``; positions that
                are ``False``/0 (e.g. padding) receive (near-)zero weight.

        Returns:
            ``(context_vector, attention_weights)`` shaped ``(batch, enc_dim)``
            and ``(batch, time_steps, 1)``.
        """
        # Add a time axis so the projected decoder state broadcasts over all steps.
        decoder_hidden_exp = tf.expand_dims(decoder_hidden, 1)
        score = self.V(tf.nn.tanh(self.W1(encoder_outputs) + self.W2(decoder_hidden_exp)))
        if mask is not None:
            # Replace masked scores with the dtype's most negative value so softmax
            # assigns them ~0. tf.where (rather than adding -1e9 * (1 - mask)) avoids
            # 0 * -inf = NaN when -1e9 overflows in float16.
            keep = tf.expand_dims(tf.cast(mask, tf.bool), -1)
            score = tf.where(keep, score, tf.cast(score.dtype.min, score.dtype))
        # Normalise over the time axis (axis=1), not the trailing singleton axis.
        attention_weights = tf.nn.softmax(score, axis=1)
        context_vector = tf.reduce_sum(attention_weights * encoder_outputs, axis=1)
        return context_vector, attention_weights

    def get_config(self):
        """Return the constructor config so the layer can be re-created from it."""
        config = super().get_config()
        config.update({"units": self.units})
        return config

In [5]:
# Build the layer, then attend over every encoder time step using the current
# decoder state. The layer returns (context vector, attention weights).
attention = BasicAttention()
context_vector, attention_weights = attention(
    encoder_outputs,
    decoder_hidden_state,
)

In [6]:
# Inspect the shapes and the attention distribution.
print("Encoder outputs shape:", encoder_outputs.shape)
print("Decoder hidden state shape:", decoder_hidden_state.shape)
print("Context vector shape:", context_vector.shape)
print("Attention weights shape:", attention_weights.shape)
print("Attention weights:", tf.squeeze(attention_weights).numpy())

# Invariants: one weight per encoder time step, and one context vector per
# batch item with the encoder's feature width.
batch_size, time_steps, hidden_size = encoder_outputs.shape
assert attention_weights.shape == (batch_size, time_steps, 1), attention_weights.shape
assert context_vector.shape == (batch_size, hidden_size), context_vector.shape
# Softmax over the time axis => weights form a probability distribution per example.
np.testing.assert_allclose(tf.reduce_sum(attention_weights, axis=1).numpy(), 1.0, rtol=1e-5)

Encoder outputs shape: (1, 4, 8)
Decoder hidden state shape: (1, 8)
Context vector shape: (1, 8)
Attention weights shape: (1, 4, 1)
Attention weights: [0.24032415 0.42041123 0.16460535 0.17465934]
