# Attention Mechanism - Basics

The **Attention Mechanism** allows a model to focus on the most relevant parts of the input sequence when making predictions.

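🔹 Traditional RNN/LSTM encoder-decoder models compress the entire input sequence into a single fixed-length hidden state, which becomes a bottleneck for long sequences.
🔹 Attention instead builds a **weighted combination** of all hidden states, recomputing the weights at every output step, so the model learns which parts of the sequence matter most for each prediction.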
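As a toy illustration in plain Python (the hidden states and weights below are made-up numbers, not learned values), a fixed-length encoding might keep only the last hidden state, whereas attention blends every state according to its weight:

```python
# Three hypothetical 2-dimensional hidden states, one per input time step
hidden_states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

# Fixed-length approach: keep only the final hidden state
last_state = hidden_states[-1]

# Attention approach: hypothetical weights (non-negative, summing to 1)
weights = [0.7, 0.2, 0.1]

# Context vector = weighted sum of all hidden states, dimension by dimension
context = [
    sum(w * h[d] for w, h in zip(weights, hidden_states))
    for d in range(len(hidden_states[0]))
]

print("Last state:", last_state)
print("Context:   ", context)
```

Because the first time step carries weight 0.7, the context (about `[0.8, 0.3]`) is pulled toward the first hidden state, even though the last state alone would have been `[1.0, 1.0]`.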

📌 Widely used in:
- Machine Translation
- Text Summarization
- Speech Recognition
- Transformers (BERT, GPT, etc.)

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Layer
import numpy as np

## Implement a Simple Attention Layer

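We create a custom Keras layer that implements (unscaled) dot-product attention in three steps: score every value vector against the query with a dot product, normalize the scores with softmax, and return the weighted sum of the values (the context vector) together with the weights.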
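In matrix form, for one batch element with query matrix $Q \in \mathbb{R}^{L_q \times d}$ and values matrix $V \in \mathbb{R}^{L_v \times d}$, the layer computes the following. Note that $V$ plays the role of both keys and values here, and there is no $1/\sqrt{d}$ factor as in the scaled dot-product attention used by Transformers.

$$
S = Q V^\top \in \mathbb{R}^{L_q \times L_v}, \qquad
A = \operatorname{softmax}_{\text{row-wise}}(S), \qquad
C = A V \in \mathbb{R}^{L_q \times d}
$$

Each row of $A$ is non-negative and sums to 1, so each row of $C$ is a convex combination of the rows of $V$.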

In [None]:
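class SimpleAttention(Layer):
    """Unscaled dot-product attention: softmax(query @ values^T) @ values."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, query, values):
        # query:  [batch, query_len, hidden_size]
        # values: [batch, seq_len, hidden_size]

        # Attention scores via dot product -> [batch, query_len, seq_len]
        scores = tf.matmul(query, values, transpose_b=True)

        # Softmax over seq_len so each query's weights sum to 1
        weights = tf.nn.softmax(scores, axis=-1)

        # Weighted sum of values -> [batch, query_len, hidden_size]
        context = tf.matmul(weights, values)
        return context, weights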
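To see exactly what `call` computes without TensorFlow, here is a minimal plain-Python sketch of the same unscaled dot-product attention for a single batch element, using nested lists. The function name `dot_product_attention` and the sample numbers are illustrative only; they are not part of the layer above.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)  # subtracting the max avoids overflow in exp
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def dot_product_attention(query, values):
    """query: [query_len][hidden], values: [seq_len][hidden].

    Returns (context [query_len][hidden], weights [query_len][seq_len]),
    mirroring SimpleAttention.call for one batch element.
    """
    hidden = len(values[0])
    context, weights = [], []
    for q in query:
        # Dot product of this query with every value vector
        scores = [sum(qi * vi for qi, vi in zip(q, v)) for v in values]
        w = softmax(scores)
        # Weighted sum of the value vectors
        c = [sum(wj * v[d] for wj, v in zip(w, values)) for d in range(hidden)]
        context.append(c)
        weights.append(w)
    return context, weights


query = [[1.0, 0.0]]                           # 1 query, hidden_size = 2
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 0.0]]  # seq_len = 3
context, weights = dot_product_attention(query, values)
print("weights:", weights)
print("context:", context)
```

The scores here are `[1, 0, 2]`, so the third value, which is most aligned with the query, receives the largest weight (roughly 0.67).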

## Test Attention Layer with Dummy Data

In [None]:
# Dummy query and values
query = tf.random.normal(shape=(1, 1, 8))   # [batch, query_len, hidden_size]
values = tf.random.normal(shape=(1, 5, 8))  # [batch, seq_len, hidden_size]

attention = SimpleAttention()
context, weights = attention(query, values)

print("Query shape:", query.shape)
print("Values shape:", values.shape)
print("Context shape:", context.shape)
print("Attention weights:", weights.numpy())
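Each row of the printed weights is a softmax output, so its entries are positive and sum to 1 across the 5 time steps. Softmax is also shift-invariant: adding the same constant to every score leaves the weights unchanged, so only the relative scores matter. A quick plain-Python check of both properties on hypothetical scores (made up, not the random tensors above):

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


scores = [0.5, -1.2, 2.0, 0.0, 0.3]  # hypothetical scores for 5 time steps
w = softmax(scores)
shifted = softmax([s + 10.0 for s in scores])  # same scores plus a constant

print("weights:", [round(x, 3) for x in w])
print("sum:", sum(w))
print("shift-invariant:", all(abs(a - b) < 1e-12 for a, b in zip(w, shifted)))
```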

## Visualization of Attention Weights
Let’s plot the attention weights to see how much focus is placed on each time step.
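When matplotlib is unavailable (for example in a plain terminal), the same per-time-step focus can be read off as text. This is a minimal sketch; `text_attention_bars` is a hypothetical helper, and with the tensors above you could pass `weights[0, 0].numpy()` (the single query's row).

```python
def text_attention_bars(weights, width=40):
    """Render one row of attention weights as text bars, one line per time step."""
    lines = []
    for t, w in enumerate(weights):
        w = float(w)  # accept NumPy scalars as well as Python floats
        bar = "#" * int(round(w * width))
        lines.append(f"t={t}  {w:.3f}  {bar}")
    return "\n".join(lines)


# Hypothetical weights for a single query over 5 time steps
row = [0.10, 0.05, 0.60, 0.15, 0.10]
print(text_attention_bars(row))
```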

In [None]:
import matplotlib.pyplot as plt

plt.matshow(weights[0].numpy(), cmap='viridis')
plt.colorbar()
plt.title("Attention Weights")
plt.xlabel("Time Steps")
plt.ylabel("Query")
plt.show()