# 🎯 Notebook 05: Attention Mechanisms in Depth

**Week 3-4: Deep Learning & NLP Foundations**  
**Gen AI Masters Program**

---

## 📋 Objectives

By the end of this notebook, you will master:
1. ✅ Why attention improves sequence models
2. ✅ Additive (Bahdanau) vs Multiplicative (Luong) attention
3. ✅ Self-attention vs cross-attention
4. ✅ Visualizing attention weight distributions
5. ✅ Implementing attention-enhanced seq2seq models in PyTorch
6. ✅ Applying attention to manufacturing incident reports

**Estimated Time:** 3-4 hours

---

## 🔍 Why Attention?

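A plain RNN encoder-decoder compresses the whole input sequence into a single fixed-length vector, which becomes a bottleneck for long inputs. Attention lets the decoder **focus** on the most relevant input positions at every output step, recomputing that focus dynamically.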

Use cases:
- 🧠 Neural machine translation
- 📝 Document summarization
- 🏭 Incident triage (our focus)
- 🔄 Time-series alignment

Let's explore how attention mechanisms work under the hood!

In [None]:
# Imports
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from typing import Tuple, List
from collections import Counter

sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (14, 6)

torch.manual_seed(42)
np.random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'✅ Using device: {device}')
print(f'PyTorch version: {torch.__version__}')

## 1️⃣ Additive Attention (Bahdanau)

Introduced in 2014 for neural machine translation.

### Formula

Given decoder hidden state $s_t$ and encoder outputs $h_i$:

$$ e_{t,i} = v_a^T \tanh(W_a s_t + U_a h_i) $$
$$ \alpha_{t,i} = \text{softmax}_i(e_{t,i}) = \frac{\exp(e_{t,i})}{\sum_j \exp(e_{t,j})} $$
$$ c_t = \sum_i \alpha_{t,i} h_i $$

### Intuition
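- Learnable scoring function
- Uses a small MLP (a $\tanh$ layer followed by $v_a$) to combine the decoder state with each encoder output
- Encoder and decoder states may have different sizes, because $W_a$ and $U_a$ project both into a shared attention space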
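To make the three equations concrete before wrapping them in a module, here is a NumPy sketch of a single decoder step. The dimensions, the random matrices standing in for the learned $W_a$, $U_a$, $v_a$, and the padding mask are all arbitrary assumptions chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, enc_dim, dec_dim, attn_dim = 4, 3, 3, 5

h = rng.normal(size=(seq_len, enc_dim))    # encoder outputs h_i, one row per position
s_t = rng.normal(size=dec_dim)             # decoder hidden state s_t
W_a = rng.normal(size=(attn_dim, dec_dim)) # stand-ins for learned parameters
U_a = rng.normal(size=(attn_dim, enc_dim))
v_a = rng.normal(size=attn_dim)

# e_{t,i} = v_a^T tanh(W_a s_t + U_a h_i), computed for every i at once
e = np.tanh(W_a @ s_t + h @ U_a.T) @ v_a   # shape (seq_len,)

# Treat the last position as padding: a score of -inf gets zero weight after softmax
mask = np.array([1, 1, 1, 0], dtype=bool)
e = np.where(mask, e, -np.inf)

# alpha_{t,i} = softmax over i (subtract the max for numerical stability)
alpha = np.exp(e - e[mask].max())
alpha /= alpha.sum()

# c_t = sum_i alpha_{t,i} h_i
c_t = alpha @ h
print(alpha.round(3), c_t.shape)
```

Setting padded scores to $-\infty$ before the softmax is the same trick `BahdanauAttention` applies with `masked_fill`.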

In [None]:
class BahdanauAttention(nn.Module):
    def __init__(self, encoder_hidden_dim: int, decoder_hidden_dim: int, attention_dim: int):
        super().__init__()
        self.W_a = nn.Linear(decoder_hidden_dim, attention_dim, bias=False)
        self.U_a = nn.Linear(encoder_hidden_dim, attention_dim, bias=False)
        self.v_a = nn.Linear(attention_dim, 1, bias=False)

    def forward(self, decoder_hidden: torch.Tensor, encoder_outputs: torch.Tensor, mask=None) -> Tuple[torch.Tensor, torch.Tensor]:
        # decoder_hidden: (batch, decoder_hidden_dim)
        # encoder_outputs: (batch, seq_len, encoder_hidden_dim)
        decoder_hidden = decoder_hidden.unsqueeze(1)  # (batch, 1, hidden)
        score = torch.tanh(self.W_a(decoder_hidden) + self.U_a(encoder_outputs))
        attention = self.v_a(score).squeeze(-1)  # (batch, seq_len)

        if mask is not None:
            attention = attention.masked_fill(mask == 0, float('-inf'))

        attn_weights = F.softmax(attention, dim=-1)
        context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs).squeeze(1)

        return context, attn_weights


# Demo with random tensors
batch_size, seq_len, enc_dim, dec_dim = 2, 5, 16, 16
encoder_outputs = torch.randn(batch_size, seq_len, enc_dim)
decoder_hidden = torch.randn(batch_size, dec_dim)
mask = torch.ones(batch_size, seq_len)

attn = BahdanauAttention(enc_dim, dec_dim, attention_dim=32)
context, weights = attn(decoder_hidden, encoder_outputs, mask)

print('🧠 Bahdanau Attention Demo')
print('='*60)
print(f'Context shape: {context.shape}')
print(f'Weights shape: {weights.shape}')

In [None]:
plt.figure(figsize=(6, 4))
sns.heatmap(weights.detach().numpy(), annot=True, cmap='YlOrRd', cbar=True)
plt.title('Additive Attention Weights (Untrained, Random Inputs)', fontweight='bold')
plt.xlabel('Encoder Time Steps')
plt.ylabel('Batch Index')
plt.tight_layout()
plt.show()

print('✅ Additive attention learns alignment scores via a small neural network!')

## 2️⃣ Multiplicative Attention (Luong)

Introduced in 2015. The dot and general variants replace the additive MLP with (bilinear) dot products, which are cheaper to compute.

### Variants
- **Dot**: $e_{t,i} = s_t^T h_i$
- **General**: $e_{t,i} = s_t^T W_a h_i$
- **Concat**: $e_{t,i} = v_a^T \tanh(W_a [s_t; h_i])$, close to the additive form but applied to the concatenated states

### Advantages
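- Faster: dot and general scores come from a single batched matrix product, with no hidden $\tanh$ layer
- The dot variant requires matching encoder/decoder dimensions; the general variant relaxes this through $W_a$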
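The three scoring functions differ only in how $s_t$ and $h_i$ are combined. The NumPy sketch below computes each for one decoder step; random matrices stand in for the learned parameters, and the sizes are assumptions. The encoder and decoder sizes deliberately differ, so the dot score is expected to fail while general and concat still work.

```python
import numpy as np

rng = np.random.default_rng(1)
seq_len, enc_dim, dec_dim, attn_dim = 4, 6, 3, 5
h = rng.normal(size=(seq_len, enc_dim))  # encoder outputs h_i
s_t = rng.normal(size=dec_dim)           # decoder state s_t (different size from h_i)

def dot_score(s, h):
    # e_i = s^T h_i: only defined when s and h_i have the same size
    if s.shape[-1] != h.shape[-1]:
        raise ValueError('dot attention needs matching encoder/decoder dims')
    return h @ s

def general_score(s, h, W_a):
    # e_i = s^T W_a h_i, with W_a of shape (dec_dim, enc_dim)
    return h @ W_a.T @ s

def concat_score(s, h, W_c, v_a):
    # e_i = v_a^T tanh(W_c [s; h_i]), with W_c of shape (attn_dim, dec_dim + enc_dim)
    s_rep = np.repeat(s[None, :], h.shape[0], axis=0)
    return np.tanh(np.concatenate([s_rep, h], axis=1) @ W_c.T) @ v_a

W_a = rng.normal(size=(dec_dim, enc_dim))
W_c = rng.normal(size=(attn_dim, dec_dim + enc_dim))
v_a = rng.normal(size=attn_dim)

print(general_score(s_t, h, W_a).shape)
print(concat_score(s_t, h, W_c, v_a).shape)
try:
    dot_score(s_t, h)
except ValueError as err:
    print('dot:', err)
```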

In [None]:
class LuongAttention(nn.Module):
    def __init__(self, encoder_hidden_dim: int, decoder_hidden_dim: int, mode: str = 'general'):
        super().__init__()
        assert mode in ['dot', 'general'], 'Unsupported attention mode'
        if mode == 'dot':
            assert encoder_hidden_dim == decoder_hidden_dim, 'dot mode needs matching dims'
        self.mode = mode
        if mode == 'general':
            # Projects encoder outputs into the decoder space: score = s_t^T W_a h_i
            self.W_a = nn.Linear(encoder_hidden_dim, decoder_hidden_dim, bias=False)

    def forward(self, decoder_hidden: torch.Tensor, encoder_outputs: torch.Tensor, mask=None) -> Tuple[torch.Tensor, torch.Tensor]:
        # decoder_hidden: (batch, decoder_hidden_dim)
        # encoder_outputs: (batch, seq_len, encoder_hidden_dim)
        # Keys are used only for scoring; the original encoder outputs form the context
        keys = self.W_a(encoder_outputs) if self.mode == 'general' else encoder_outputs

        query = decoder_hidden.unsqueeze(2)  # (batch, decoder_hidden_dim, 1)
        scores = torch.bmm(keys, query).squeeze(-1)  # (batch, seq_len)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        # Context is a weighted sum of the unprojected encoder outputs
        context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs).squeeze(1)

        return context, attn_weights


# Demo
attn_luong = LuongAttention(enc_dim, dec_dim, mode='general')
context_luong, weights_luong = attn_luong(decoder_hidden, encoder_outputs, mask)

plt.figure(figsize=(6, 4))
sns.heatmap(weights_luong.detach().numpy(), annot=True, cmap='BuGn', cbar=True)
plt.title('Multiplicative Attention Weights (Untrained, Random Inputs)', fontweight='bold')
plt.xlabel('Encoder Time Steps')
plt.ylabel('Batch Index')
plt.tight_layout()
plt.show()

print('⚡ Multiplicative attention relies on efficient matrix products!')

### Comparing Additive vs Multiplicative
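Before timing anything, a rough multiply-add count suggests what to expect. The hypothetical helper `score_macs` below is a back-of-the-envelope estimate for scoring one decoder state, assuming the modules work as in this notebook (additive projects every encoder position with $U_a$ on each call; general Luong projects every encoder position with $W_a$). It ignores the softmax, the $\tanh$, and the context sum.

```python
def score_macs(seq_len, enc_dim, dec_dim, attn_dim=None, mode='additive'):
    """Approximate multiply-adds to score one decoder state against seq_len encoder outputs."""
    if mode == 'additive':
        # W_a s_t once, U_a h_i for every position, then v_a^T for every position
        return dec_dim * attn_dim + seq_len * enc_dim * attn_dim + seq_len * attn_dim
    if mode == 'general':
        # W_a h_i for every position, then one dot product with s_t per position
        return seq_len * enc_dim * dec_dim + seq_len * dec_dim
    if mode == 'dot':
        # One dot product per position (requires enc_dim == dec_dim)
        return seq_len * enc_dim
    raise ValueError(f'unknown mode: {mode}')

# Sizes used in the demo cells of this notebook
for mode in ['additive', 'general', 'dot']:
    print(mode, score_macs(5, 16, 16, attn_dim=32, mode=mode))
```

At these tiny sizes, wall-clock time is likely dominated by Python and kernel-launch overhead, so the measured speed ratio may be much smaller than the ratio of operation counts.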

In [None]:
# Compare runtime (rough wall-clock timing; tiny CPU tensors are dominated by Python overhead)
import time

def benchmark_attention(module, iterations=1000, warmup=10):
    with torch.no_grad():  # inference only, so skip autograd bookkeeping
        for _ in range(warmup):
            module(decoder_hidden, encoder_outputs, mask)
        start = time.perf_counter()
        for _ in range(iterations):
            _ = module(decoder_hidden, encoder_outputs, mask)
        end = time.perf_counter()
    return end - start

iterations = 500
bahdanau_time = benchmark_attention(attn, iterations)
luong_time = benchmark_attention(attn_luong, iterations)

print('⏱️ Attention Runtime Comparison')
print('='*60)
print(f'Bahdanau (Additive):    {bahdanau_time:.4f} seconds')
print(f'Luong (Multiplicative): {luong_time:.4f} seconds')
print(f'→ Bahdanau/Luong time ratio: {bahdanau_time/luong_time:.2f} (>1 means Luong was faster in this run)')

## 3️⃣ Self-Attention vs Cross-Attention

| Type | Query From | Key/Value From | Use Case |
|------|------------|----------------|----------|
| **Self-attention** | Same sequence | Same sequence | Encoder layers, capturing contextual dependencies |
| **Cross-attention** | Decoder state | Encoder outputs | Decoder focuses on encoder tokens |

Cross-attention is critical in seq2seq tasks like translating sensor logs into recommended actions.
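The only difference between the two rows of the table is where the queries come from. The NumPy sketch below uses one dot-product attention function for both cases; the $1/\sqrt{d}$ scaling is the Transformer convention (an assumption here; drop it to recover Luong's dot score), and random matrices stand in for learned query/key/value projections.

```python
import numpy as np

def attention(queries, keys, values):
    """Scaled dot-product attention; each row of the result is a context vector."""
    scores = queries @ keys.T / np.sqrt(keys.shape[-1])  # (n_queries, n_keys)
    scores -= scores.max(axis=-1, keepdims=True)         # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ values, weights

rng = np.random.default_rng(2)
d = 8
encoder_states = rng.normal(size=(6, d))  # 6 source tokens
decoder_states = rng.normal(size=(3, d))  # 3 target tokens generated so far
W_q, W_k, W_v = (rng.normal(size=(d, d)) for _ in range(3))  # stand-ins for learned projections

# Self-attention: queries, keys, and values all come from the encoder sequence
_, self_w = attention(encoder_states @ W_q, encoder_states @ W_k, encoder_states @ W_v)

# Cross-attention: queries from the decoder, keys/values from the encoder
_, cross_w = attention(decoder_states @ W_q, encoder_states @ W_k, encoder_states @ W_v)

print(self_w.shape, cross_w.shape)  # (6, 6) (3, 6)
```

Self-attention yields a square `n × n` weight matrix over one sequence, while cross-attention yields one row per decoder position over the encoder tokens, which is exactly the shape of the alignment heatmaps in this notebook.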

## 4️⃣ Manufacturing Case Study: Incident Explanation

Goal: Generate a brief explanation (sequence-to-sequence) for a maintenance alert.

### Dataset
We'll create synthetic pairs of **sensor observations → explanations**.

In [None]:
sensor_sequences = [
    'temperature spike with coolant loss',
    'vibration rise near motor bearings',
    'hydraulic pressure drop and leakage',
    'conveyor belt speed oscillation detected',
    'voltage surge hitting robotic arm',
    'persistent sensor outage on line three',
    'steam valve stuck partially open',
    'gearbox overheating under heavy load',
    'unexpected torque variance in assembly',
    'coolant contamination alert triggered'
]

explanations = [
    'reduce load and inspect coolant loop',
    'check bearing lubrication and alignment',
    'isolate leak and restore fluid pressure',
    'calibrate drive rollers and tension system',
    'disconnect arm and test power regulators',
    'swap redundant sensor and trace wiring',
    'cycle valve manually and service actuator',
    'pause line and evaluate gear lubrication',
    'tighten fasteners and recalibrate wrench',
    'flush coolant system and replace filters'
]

df_seq2seq = pd.DataFrame({'sensor': sensor_sequences, 'explanation': explanations})
print(df_seq2seq)

### Tokenization & Vocabulary Building

In [None]:
def tokenize(text):
    return text.lower().replace('-', ' ').split()


SRC_PAD, SRC_UNK, SRC_SOS, SRC_EOS = '<pad>', '<unk>', '<sos>', '<eos>'
TRG_PAD, TRG_UNK, TRG_SOS, TRG_EOS = '<pad>', '<unk>', '<sos>', '<eos>'

# Build source vocab
source_tokens = [token for sent in sensor_sequences for token in tokenize(sent)]
source_vocab = {SRC_PAD: 0, SRC_UNK: 1, SRC_SOS: 2, SRC_EOS: 3}
for token, _ in Counter(source_tokens).most_common():
    source_vocab[token] = len(source_vocab)

# Build target vocab
target_tokens = [token for sent in explanations for token in tokenize(sent)]
target_vocab = {TRG_PAD: 0, TRG_UNK: 1, TRG_SOS: 2, TRG_EOS: 3}
for token, _ in Counter(target_tokens).most_common():
    target_vocab[token] = len(target_vocab)

print(f'Source vocab size: {len(source_vocab)} | Target vocab size: {len(target_vocab)}')

inv_target_vocab = {idx: token for token, idx in target_vocab.items()}

### Encoding Sequences

In [None]:
MAX_SRC_LEN = 12
MAX_TRG_LEN = 12


def encode_sequence(text, vocab, max_len, sos_token=None, eos_token=None,
                    pad_token='<pad>', unk_token='<unk>'):
    tokens = tokenize(text)
    token_ids = []
    if sos_token:
        token_ids.append(vocab[sos_token])
    # Unknown words map to the <unk> id (looked up by name, not by insertion order)
    token_ids.extend(vocab.get(tok, vocab[unk_token]) for tok in tokens)
    if eos_token:
        token_ids.append(vocab[eos_token])

    # Pad to max_len, or truncate if the sequence is too long
    if len(token_ids) < max_len:
        token_ids += [vocab[pad_token]] * (max_len - len(token_ids))
    else:
        token_ids = token_ids[:max_len]

    return token_ids


encoded_src = [encode_sequence(sent, source_vocab, MAX_SRC_LEN, sos_token=SRC_SOS, eos_token=SRC_EOS) for sent in sensor_sequences]
encoded_trg = [encode_sequence(sent, target_vocab, MAX_TRG_LEN, sos_token=TRG_SOS, eos_token=TRG_EOS) for sent in explanations]

print(np.array(encoded_src).shape, np.array(encoded_trg).shape)

### Seq2Seq Dataset

In [None]:
class IncidentExplanationDataset(Dataset):
    def __init__(self, src_sequences, trg_sequences):
        self.src = torch.tensor(src_sequences, dtype=torch.long)
        self.trg = torch.tensor(trg_sequences, dtype=torch.long)

    def __len__(self):
        return len(self.src)

    def __getitem__(self, idx):
        return self.src[idx], self.trg[idx]


dataset = IncidentExplanationDataset(encoded_src, encoded_trg)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

for src_batch, trg_batch in loader:
    print('Source batch shape:', src_batch.shape)
    print('Target batch shape:', trg_batch.shape)
    break

## 5️⃣ Encoder-Decoder with Bahdanau Attention

We'll build a compact seq2seq model using GRU encoder/decoder with additive attention.
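With the hyperparameters chosen further down (64-dim embeddings, 128-dim GRU states), the tensors flowing through one decoder step can be traced by hand. The pure-Python sketch below only does that bookkeeping; `decoder_step_shapes` is a hypothetical helper, not part of the model, the vocabulary size of 50 is a placeholder, and it assumes a single-layer, unidirectional encoder as in the `Encoder` class.

```python
def decoder_step_shapes(batch, src_len, embed_dim, hidden_dim, vocab_size):
    """Trace tensor shapes through one step of the attention decoder (bookkeeping only)."""
    return {
        'encoder_outputs': (batch, src_len, hidden_dim),
        'embedded': (batch, 1, embed_dim),                  # current target token
        'attn_weights': (batch, src_len),                   # one weight per source position
        'context': (batch, hidden_dim),                     # weighted sum of encoder outputs
        'gru_input': (batch, 1, embed_dim + hidden_dim),    # [embedding; context]
        'fc_out_input': (batch, hidden_dim * 2),            # [GRU output; context]
        'prediction': (batch, vocab_size),                  # logits over the target vocabulary
    }

for name, shape in decoder_step_shapes(batch=4, src_len=12, embed_dim=64,
                                       hidden_dim=128, vocab_size=50).items():
    print(f'{name:16s} {shape}')
```

The `gru_input` and `fc_out_input` widths explain the `embed_dim + hidden_dim` and `hidden_dim * 2` sizes passed to `nn.GRU` and `nn.Linear` in `DecoderWithAttention`.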

In [None]:
class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers=1, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.gru = nn.GRU(embed_dim, hidden_dim, num_layers=num_layers, batch_first=True, dropout=dropout if num_layers > 1 else 0)

    def forward(self, src):
        embedded = self.embedding(src)
        outputs, hidden = self.gru(embedded)
        return outputs, hidden


class DecoderWithAttention(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, attention_dim, num_layers=1, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.gru = nn.GRU(embed_dim + hidden_dim, hidden_dim, num_layers=num_layers, batch_first=True, dropout=dropout if num_layers > 1 else 0)
        self.attention = BahdanauAttention(hidden_dim, hidden_dim, attention_dim)
        self.fc_out = nn.Linear(hidden_dim * 2, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_token, hidden, encoder_outputs, mask=None):
        embedded = self.dropout(self.embedding(input_token.unsqueeze(1)))  # (batch, 1, embed)
        context, attn_weights = self.attention(hidden[-1], encoder_outputs, mask)
        gru_input = torch.cat([embedded, context.unsqueeze(1)], dim=2)
        output, hidden = self.gru(gru_input, hidden)
        output = torch.cat([output.squeeze(1), context], dim=-1)
        prediction = self.fc_out(output)
        return prediction, hidden, attn_weights


class Seq2SeqAttention(nn.Module):
    def __init__(self, encoder, decoder, pad_idx):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.pad_idx = pad_idx

    def create_mask(self, src):
        return (src != self.pad_idx)

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        batch_size, trg_len = trg.shape
        vocab_size = self.decoder.fc_out.out_features

        outputs = torch.zeros(batch_size, trg_len, vocab_size).to(src.device)

        encoder_outputs, hidden = self.encoder(src)
        mask = self.create_mask(src)

        input_token = trg[:, 0]  # <sos>

        attention_history = []
        for t in range(1, trg_len):
            output, hidden, attn_weights = self.decoder(input_token, hidden, encoder_outputs, mask)
            outputs[:, t] = output
            attention_history.append(attn_weights.detach().cpu())

            teacher_force = np.random.rand() < teacher_forcing_ratio
            top1 = output.argmax(dim=1)
            input_token = trg[:, t] if teacher_force else top1

        return outputs, attention_history


# Instantiate model
ENC_EMBED_DIM = 64
DEC_EMBED_DIM = 64
HIDDEN_DIM = 128
ATTN_DIM = 64

encoder = Encoder(len(source_vocab), ENC_EMBED_DIM, HIDDEN_DIM)
decoder = DecoderWithAttention(len(target_vocab), DEC_EMBED_DIM, HIDDEN_DIM, ATTN_DIM)
seq2seq_model = Seq2SeqAttention(encoder, decoder, pad_idx=source_vocab[SRC_PAD]).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=target_vocab[TRG_PAD])
optimizer = optim.Adam(seq2seq_model.parameters(), lr=1e-3)

print(seq2seq_model)

### Training Loop
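In the loop below, `outputs[:, 0]` is never filled (decoding starts from `<sos>` at `t = 1`), so predictions and targets are both sliced with `[:, 1:]` before flattening. Remaining `<pad>` positions are skipped through `ignore_index`; with the default `reduction='mean'`, `nn.CrossEntropyLoss` averages only over the non-ignored targets. The pure-Python sketch mimics that averaging on made-up log-probabilities; `masked_nll` is a hypothetical helper, not a PyTorch API.

```python
import math

PAD_IDX = 0  # matches target_vocab['<pad>'] in this notebook

def masked_nll(log_probs, targets, ignore_index=PAD_IDX):
    """Mean negative log-likelihood over positions whose target is not ignore_index."""
    kept = [-row[t] for row, t in zip(log_probs, targets) if t != ignore_index]
    return sum(kept) / len(kept)

# Three time steps over a toy 3-token vocabulary; the last target is padding
log_probs = [
    [math.log(0.1), math.log(0.7), math.log(0.2)],
    [math.log(0.2), math.log(0.2), math.log(0.6)],
    [math.log(0.9), math.log(0.05), math.log(0.05)],
]
targets = [1, 2, PAD_IDX]

print(round(masked_nll(log_probs, targets), 4))
```

Because the padded step is dropped from both the numerator and the denominator, short explanations padded to `MAX_TRG_LEN` are not rewarded for easy padding predictions.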

In [None]:
def train_seq2seq(model, dataloader, optimizer, criterion, epochs=200):
    model.train()
    losses = []
    for epoch in range(1, epochs + 1):
        epoch_loss = 0
        for src_batch, trg_batch in dataloader:
            src_batch = src_batch.to(device)
            trg_batch = trg_batch.to(device)

            optimizer.zero_grad()
            output, _ = model(src_batch, trg_batch, teacher_forcing_ratio=0.75)

            output_dim = output.shape[-1]
            output = output[:, 1:].reshape(-1, output_dim)
            trg = trg_batch[:, 1:].reshape(-1)

            loss = criterion(output, trg)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(dataloader)
        losses.append(avg_loss)
        if epoch % 40 == 0 or epoch == 1:
            print(f'Epoch {epoch:03d} | Loss: {avg_loss:.4f}')
    return losses


print('🚀 Training seq2seq model with additive attention...')
losses = train_seq2seq(seq2seq_model, loader, optimizer, criterion, epochs=200)
print('✅ Training complete!')

In [None]:
plt.figure(figsize=(12, 4))
plt.plot(losses, color='blue', linewidth=2)
plt.title('Seq2Seq Attention Training Loss', fontweight='bold')
plt.xlabel('Epoch')
plt.ylabel('Cross-Entropy Loss')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

### Evaluation & Attention Visualization

In [None]:
def translate_sentence(model, src_sentence: str):
    model.eval()
    src_encoded = encode_sequence(src_sentence, source_vocab, MAX_SRC_LEN, sos_token=SRC_SOS, eos_token=SRC_EOS)
    src_tensor = torch.tensor(src_encoded, dtype=torch.long).unsqueeze(0).to(device)

    trg_indices = [target_vocab[TRG_SOS]]
    max_len = MAX_TRG_LEN
    attention_maps = []

    with torch.no_grad():
        encoder_outputs, hidden = model.encoder(src_tensor)
        mask = model.create_mask(src_tensor)
        input_token = torch.tensor([target_vocab[TRG_SOS]], device=device)

        for t in range(1, max_len):
            output, hidden, attn_weights = model.decoder(input_token, hidden, encoder_outputs, mask)
            top1 = output.argmax(1).item()
            trg_indices.append(top1)
            attention_maps.append(attn_weights.squeeze().cpu().numpy())
            if top1 == target_vocab[TRG_EOS]:
                break
            input_token = torch.tensor([top1], device=device)

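    trg_ids = trg_indices[1:]  # drop <sos>
    if trg_ids and trg_ids[-1] == target_vocab[TRG_EOS]:
        trg_ids = trg_ids[:-1]  # drop <eos> only if the decoder actually emitted it
    trg_tokens = [inv_target_vocab.get(idx, TRG_UNK) for idx in trg_ids]
    return trg_tokens, np.array(attention_maps)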


example_sentence = 'hydraulic pressure drop and leakage'
prediction, attention_map = translate_sentence(seq2seq_model, example_sentence)

print('📝 Input :', example_sentence)
print('💡 Output:', ' '.join(prediction))

In [None]:
# Visualize alignment
src_tokens = tokenize(example_sentence)
src_tokens = [SRC_SOS] + src_tokens + [SRC_EOS]

plt.figure(figsize=(10, 6))
sns.heatmap(attention_map[:len(prediction), :len(src_tokens)], cmap='magma', annot=True, fmt='.2f')
plt.xlabel('Source Tokens', fontweight='bold')
plt.ylabel('Generated Tokens', fontweight='bold')
plt.xticks(np.arange(len(src_tokens)) + 0.5, src_tokens, rotation=45, ha='right')
plt.yticks(np.arange(len(prediction)) + 0.5, prediction, rotation=0)
plt.title('Attention Alignment: Maintenance Explanation', fontweight='bold')
plt.tight_layout()
plt.show()

## 6️⃣ Attention Variants Recap

| Variant | Highlights | Pros | Cons |
|---------|-----------|------|------|
| **Bahdanau (Additive)** | MLP scoring | Flexible, good for small dims | Slightly slower |
| **Luong (Multiplicative)** | Dot-product scoring | Fast, efficient | Dot variant requires matching dims |
| **Self-Attention** | Intra-sequence focus | Captures global context | Quadratic cost in sequence length |
| **Cross-Attention** | Decoder ↔ Encoder | Enables seq2seq alignment | Needs encoder states |
| **Multi-Head** | Multiple subspaces | Diverse patterns | Higher compute |

Attention mechanisms underpin modern Transformers and large language models.
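Multi-head attention is the one row of the table not implemented in this notebook. As a rough NumPy sketch (random matrices stand in for learned projections; the head count and sizes are arbitrary assumptions), multi-head self-attention splits the model dimension into smaller subspaces, attends within each, and concatenates the results:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_self_attention(x, W_q, W_k, W_v, W_o, num_heads):
    """x: (seq_len, d_model); all weight matrices: (d_model, d_model)."""
    seq_len, d_model = x.shape
    d_head = d_model // num_heads

    def split(t):
        # (seq_len, d_model) -> (num_heads, seq_len, d_head)
        return t.reshape(seq_len, num_heads, d_head).transpose(1, 0, 2)

    q, k, v = split(x @ W_q), split(x @ W_k), split(x @ W_v)
    weights = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(d_head))  # (heads, seq, seq)
    heads = weights @ v                                            # (heads, seq, d_head)
    concat = heads.transpose(1, 0, 2).reshape(seq_len, d_model)    # join the heads
    return concat @ W_o, weights

rng = np.random.default_rng(3)
d_model, num_heads, seq_len = 16, 4, 5
x = rng.normal(size=(seq_len, d_model))
W_q, W_k, W_v, W_o = (rng.normal(size=(d_model, d_model)) for _ in range(4))

out, w = multi_head_self_attention(x, W_q, W_k, W_v, W_o, num_heads)
print(out.shape, w.shape)  # (5, 16) (4, 5, 5)
```

Each head produces its own `seq_len × seq_len` weight matrix, which is where both the "diverse patterns" and the extra compute in the table come from.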

## 🎉 Summary

Great job mastering attention!

### Key Takeaways
- ✅ Attention alleviates RNN bottlenecks
- ✅ Additive vs multiplicative trade-offs
- ✅ Self vs cross-attention roles
- ✅ Applying attention-based seq2seq models to manufacturing logs
- ✅ Visualizing alignments for explainability

### What You Built
1. 🧮 Bahdanau and Luong attention modules
2. ⚖️ Runtime benchmarking for attention variants
3. 🏭 Synthetic incident explanation dataset
4. 🔄 Seq2seq model with additive attention
5. 🔍 Attention heatmap highlighting critical tokens

### Manufacturing Insights
- Attention weights show which observations in a maintenance log the decoder focused on for each generated word
- Alignment heatmaps can serve as a starting point for root-cause analysis
- Generated explanations could help technicians triage alerts faster
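One way to turn an alignment heatmap into a short technician-facing summary is to report, for each generated word, the source token it attended to most. The helper below is a hypothetical sketch that accepts the `attention_map`, `prediction`, and source-token list built earlier in this notebook; the toy weights in the example are made up.

```python
def top_source_tokens(attention_map, generated_tokens, source_tokens):
    """Pair each generated token with the source token that received the highest weight."""
    pairs = []
    for row, gen_tok in zip(attention_map, generated_tokens):
        weights = list(row[:len(source_tokens)])  # ignore weights on padding positions
        best = max(range(len(weights)), key=weights.__getitem__)
        pairs.append((gen_tok, source_tokens[best], weights[best]))
    return pairs

# Toy example with made-up weights (rows: generated tokens, columns: source tokens)
toy_map = [
    [0.05, 0.70, 0.15, 0.10],
    [0.10, 0.10, 0.60, 0.20],
]
for gen, src, w in top_source_tokens(toy_map, ['restore', 'leak'],
                                     ['<sos>', 'pressure', 'leakage', '<eos>']):
    print(f'{gen:>8s} <- {src} ({w:.2f})')
```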

### Next Steps
Proceed to **Notebook 06: Embeddings** to understand vector representations that power attention models.

<div align="center">
<b>Attention unlocked! Time to dive into embeddings. 🧲🚀</b>
</div>