# ⚡ Notebook 04: Transformers & Self-Attention

**Week 3-4: Deep Learning & NLP Foundations**  
**Gen AI Masters Program**

---

## 📋 Objectives

By the end of this notebook, you will master:
1. ✅ Transformer architecture and encoder-decoder pipeline
2. ✅ Scaled dot-product self-attention
3. ✅ Multi-head attention and positional encoding
4. ✅ Building transformer blocks from scratch in PyTorch
5. ✅ Applying transformers to manufacturing maintenance logs
6. ✅ Visualizing attention weights and interpreting model focus

**Estimated Time:** 4-5 hours

---

## 🌟 Why Transformers?

RNNs struggle with long-range dependencies and parallelization. Transformers solve this with **self-attention**, enabling:
- ⚡ Parallel processing of entire sequences
- 🧠 Global context capture
- 🏆 State-of-the-art performance in NLP, CV, Speech
- 🏭 Enhanced understanding of manufacturing logs and procedures

Let's dive deep into the architecture that powers modern Generative AI! 🚀

In [None]:
# Core libraries
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from collections import Counter
from typing import List, Tuple, Dict

# Visualization style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (14, 6)

# Reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'✅ Using device: {device}')
print(f'PyTorch version: {torch.__version__}')

## 1️⃣ Limitations of RNN-based Seq2Seq

Traditional sequence-to-sequence models rely on **encoder-decoder RNNs**:
- Encoder compresses entire input into a single vector
- Decoder generates outputs step-by-step

**Problems:**
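- ❌ Bottleneck: the entire input is squeezed into one fixed-size vector
- ❌ Sequential computation: each step depends on the previous hidden state, so time steps cannot run in parallel
- ❌ Long-range dependencies: information from early tokens fades over many recurrent steps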

Transformers replace recurrence with **attention** to overcome these issues.

## 2️⃣ Scaled Dot-Product Attention

Given query $Q$, key $K$, and value $V$ matrices:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

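- $d_k$: dimension of the key (and query) vectors
- Dividing by $\sqrt{d_k}$ keeps dot products from growing with the dimension, which would otherwise saturate the softmax and shrink its gradients
- Softmax over the key axis turns each row of scores into attention weights that sum to 1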
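
To make the effect of the $\sqrt{d_k}$ scaling concrete, here is a small dependency-free sketch that compares softmax over raw and scaled scores. The score values and `d_k = 64` are illustrative assumptions, not numbers taken from this notebook.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

d_k = 64                         # assumed key dimension
raw_scores = [8.0, 16.0, 24.0]   # hypothetical unscaled dot products for one query
scaled_scores = [s / math.sqrt(d_k) for s in raw_scores]

raw_weights = softmax(raw_scores)        # almost all weight lands on the largest score
scaled_weights = softmax(scaled_scores)  # weight is spread more evenly

print([round(w, 4) for w in raw_weights])
print([round(w, 4) for w in scaled_weights])
```

Without scaling, the largest score takes nearly all of the weight, so the other positions contribute almost nothing to the output or to the gradient.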

We'll implement it step-by-step.

In [None]:
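def scaled_dot_product_attention(query: torch.Tensor,
                                 key: torch.Tensor,
                                 value: torch.Tensor,
                                 mask: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return (attention output, attention weights); key positions where mask == 0 are ignored."""
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))  # QK^T / sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))  # block masked key positions
    attn_weights = F.softmax(scores, dim=-1)  # each query row sums to 1
    return torch.matmul(attn_weights, value), attn_weights

# Toy self-attention check; the (1, 4, 8) shape is an assumption chosen for illustration
toy_input = torch.randn(1, 4, 8)
attn_output, attn_weights = scaled_dot_product_attention(toy_input, toy_input, toy_input)
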
In [None]:
# Visualize attention weights
plt.figure(figsize=(6, 4))
weights_matrix = attn_weights.squeeze().detach().numpy()
sns.heatmap(weights_matrix, annot=True, cmap='Blues', cbar=True)
plt.title('Attention Weight Distribution', fontweight='bold')
plt.xlabel('Key Positions')
plt.ylabel('Query Positions')
plt.tight_layout()
plt.show()

print(f'Attention weights shape: {weights_matrix.shape}')

## 3️⃣ Multi-Head Attention

Instead of a single attention function, Transformers use **multiple heads** to capture diverse relationships.

### Steps:
1. Project $Q, K, V$ into multiple subspaces.
2. Apply attention in each head independently.
3. Concatenate and project back.

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O$$
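
The three steps above can be traced by hand in a dependency-free sketch. It is a simplification under stated assumptions: identity projections stand in for the learned $W^Q$, $W^K$, $W^V$ and $W^O$ matrices, the input `X` is made up, and the helper names are hypothetical.

```python
import math

def softmax(row):
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    """Scaled dot-product attention on lists of vectors (seq_len x dim)."""
    d_k = len(K[0])
    out = []
    for q in Q:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in K]
        weights = softmax(scores)
        out.append([sum(w * v[j] for w, v in zip(weights, V)) for j in range(len(V[0]))])
    return out

def split_heads(X, num_heads):
    """Slice each token vector into num_heads contiguous chunks."""
    head_dim = len(X[0]) // num_heads
    return [[tok[h * head_dim:(h + 1) * head_dim] for tok in X] for h in range(num_heads)]

def multi_head_self_attention(X, num_heads):
    # Assumption: identity projections replace the learned linear layers.
    heads = [attention(Xh, Xh, Xh) for Xh in split_heads(X, num_heads)]
    # Concatenate the head outputs token by token.
    return [sum((head[t] for head in heads), []) for t in range(len(X))]

X = [[1.0, 0.0, 0.0, 2.0],
     [0.0, 1.0, 2.0, 0.0],
     [1.0, 1.0, 1.0, 1.0]]   # hypothetical input: 3 tokens, embed_dim = 4

out = multi_head_self_attention(X, num_heads=2)
print(len(out), len(out[0]))  # output keeps the input shape: 3 tokens x 4 dims
```

Each head only sees its own slice of the embedding, which is why `embed_dim` has to be divisible by `num_heads`.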

In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Linear projections for Q, K, V
        self.q_linear = nn.Linear(embed_dim, embed_dim)
        self.k_linear = nn.Linear(embed_dim, embed_dim)
        self.v_linear = nn.Linear(embed_dim, embed_dim)

        # Output projection
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key, value, mask=None, need_weights=False):
        batch_size, seq_length, embed_dim = query.size()

        # Linear projections
        Q = self.q_linear(query)
        K = self.k_linear(key)
        V = self.v_linear(value)

        # Split into heads: (batch, num_heads, seq_len, head_dim)
        Q = Q.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

        # Compute attention per head
        attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)

        # Concatenate heads
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_length, embed_dim)

        # Final linear projection
        output = self.out_proj(attn_output)

        if need_weights:
            # Average attention weights across heads
            attn_weights = attn_weights.mean(dim=1)
            return output, attn_weights
        return output, None


# Test multi-head attention
test_embed = torch.randn(2, 6, 32)  # (batch=2, seq=6, embed=32)
mha = MultiHeadAttention(embed_dim=32, num_heads=4)
out, weights = mha(test_embed, test_embed, test_embed, need_weights=True)

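print('Multi-head attention test')
print('=' * 60)
print(f'Output shape: {out.shape} | Averaged attention weights shape: {weights.shape}')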

## 4️⃣ Positional Encoding

Self-attention has no built-in notion of token order, so position information must be added explicitly. The sinusoidal scheme adds a fixed vector to each token embedding:

$$PE_{(pos,\, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \qquad PE_{(pos,\, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)$$

- $pos$: position of the token in the sequence
- $i$: index of the sine/cosine pair
- $d_{\text{model}}$: embedding dimension (`embed_dim` in the code)


In [None]:
class PositionalEncoding(nn.Module):
    def __init__(self, embed_dim: int, max_len: int = 5000):
        super().__init__()

        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * -(math.log(10000.0) / embed_dim))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, embed_dim)

        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, seq_len, embed_dim)
        seq_len = x.size(1)
        x = x + self.pe[:, :seq_len, :]
        return x


# Visualize positional encodings
pos_encoder = PositionalEncoding(embed_dim=32, max_len=100)
PE = pos_encoder.pe.squeeze().detach().numpy()[:100, :16]  # First 100 positions, 16 dims

plt.figure(figsize=(12, 4))
plt.imshow(PE, aspect='auto', cmap='coolwarm')
plt.colorbar(label='Encoding Value')
plt.title('Sinusoidal Positional Encoding (Dim 0-15)', fontweight='bold')
plt.xlabel('Embedding Dimension')
plt.ylabel('Position Index')
plt.tight_layout()
plt.show()

print(f'Positional encoding slice shape: {PE.shape}')
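
The values produced by `PositionalEncoding` can be checked by hand for a single position. The sketch below recomputes the same sine/cosine pattern in plain Python; `sinusoidal_encoding` is a hypothetical helper and `embed_dim = 4` is an arbitrary small size chosen for readability.

```python
import math

def sinusoidal_encoding(position, embed_dim):
    """Return the sinusoidal encoding vector for one position (embed_dim must be even)."""
    vec = []
    for i in range(0, embed_dim, 2):
        # i is the even dimension index, matching torch.arange(0, embed_dim, 2) in the class
        angle = position / (10000 ** (i / embed_dim))
        vec.extend([math.sin(angle), math.cos(angle)])
    return vec

pe0 = sinusoidal_encoding(0, 4)
pe1 = sinusoidal_encoding(1, 4)
print(pe0)  # at position 0 every sine term is 0 and every cosine term is 1
print([round(v, 4) for v in pe1])
```

Low dimensions oscillate quickly with position while high dimensions change slowly, which is the banded pattern visible in the heatmap.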

## 5️⃣ Transformer Encoder Block

A single encoder layer consists of:
1. Multi-head self-attention + residual + layer norm
2. Position-wise feed-forward network + residual + layer norm
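
The "residual + layer norm" step shared by both sub-layers can be sketched without PyTorch. This is a minimal illustration that assumes no learned scale or shift and uses made-up vectors; `layer_norm` and `add_and_norm` are hypothetical helper names.

```python
import math

def layer_norm(vec, eps=1e-5):
    """Normalize one token vector to zero mean and unit variance (no learned parameters)."""
    mean = sum(vec) / len(vec)
    var = sum((v - mean) ** 2 for v in vec) / len(vec)
    return [(v - mean) / math.sqrt(var + eps) for v in vec]

def add_and_norm(x, sublayer_out):
    """Post-norm residual step: LayerNorm(x + Sublayer(x))."""
    return layer_norm([a + b for a, b in zip(x, sublayer_out)])

x = [1.0, 2.0, 3.0, 4.0]               # hypothetical token representation
sublayer_out = [0.5, -0.5, 0.5, -0.5]  # hypothetical attention or feed-forward output

y = add_and_norm(x, sublayer_out)
print([round(v, 3) for v in y])
```

The residual path lets the input pass through unchanged when the sub-layer contributes little, and the normalization keeps activations on a comparable scale from layer to layer.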

We'll build it from scratch.

In [None]:
class TransformerEncoderBlock(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, ff_hidden_dim: int, dropout: float = 0.1):
        super().__init__()

        self.self_attn = MultiHeadAttention(embed_dim, num_heads)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ff_hidden_dim, embed_dim)
        )

        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        attn_output, attn_weights = self.self_attn(x, x, x, mask=mask, need_weights=True)
        x = x + self.dropout(attn_output)
        x = self.norm1(x)

        ff_output = self.ff(x)
        x = x + self.dropout(ff_output)
        x = self.norm2(x)

        return x, attn_weights


# Test block
block = TransformerEncoderBlock(embed_dim=64, num_heads=8, ff_hidden_dim=256)
sample_input = torch.randn(4, 10, 64)
out, attn = block(sample_input)

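print('Transformer encoder block test')
print('=' * 60)
print(f'Output shape: {out.shape} | Attention weights shape: {attn.shape}')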

## 6️⃣ Manufacturing Use Case: Maintenance Log Classification

We'll classify maintenance log entries into severity levels:
- 🟢 **Normal**
- 🟠 **Warning**
- 🔴 **Critical**

### Pipeline
1. Create synthetic maintenance sentences
2. Tokenize and build vocabulary
3. Encode sequences + positional encodings
4. Train transformer encoder classifier
5. Visualize attention to interpret model focus

In [None]:
# Synthetic manufacturing maintenance log dataset
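# NOTE: the sentences below are illustrative placeholders (an assumption),
# written to match the three severity levels described above.
NORMAL_SENTENCES = [
    'Routine inspection completed with no issues found',
    'Conveyor belt running at normal speed',
    'Lubrication applied to spindle bearings as scheduled',
    'Temperature readings within normal operating range',
    'Filter replaced during scheduled maintenance',
    'Hydraulic pressure stable after calibration',
    'Machine restarted and operating normally',
    'Coolant level checked and found adequate',
    'Vibration levels within acceptable limits',
    'Preventive maintenance completed on schedule'
]

WARNING_SENTENCES = [
    'Slight vibration detected in motor housing',
    'Coolant level dropping below recommended threshold',
    'Bearing temperature rising above normal range',
    'Unusual noise reported from gearbox',
    'Hydraulic pressure fluctuating during operation',
    'Belt tension loose and needs adjustment',
    'Minor oil leak observed near pump seal',
    'Sensor readings intermittent on line three',
    'Tool wear approaching replacement limit',
    'Cycle time slowing on press machine'
]

CRITICAL_SENTENCES = [
    'Motor overheating and emergency shutdown triggered',
    'Hydraulic line ruptured causing major fluid loss',
    'Spindle bearing failure and production halted',
    'Electrical fault caused smoke in control cabinet',
    'Conveyor belt snapped and line stopped immediately',
    'Severe vibration damaged gearbox housing',
    'Coolant pump failed and machine overheated',
    'Safety interlock failure detected on press',
    'Main drive shaft fractured during operation',
    'Fire alarm triggered near welding station'
]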

labels_map = {'normal': 0, 'warning': 1, 'critical': 2}

all_sentences = NORMAL_SENTENCES + WARNING_SENTENCES + CRITICAL_SENTENCES
all_labels = ([labels_map['normal']] * len(NORMAL_SENTENCES) +
              [labels_map['warning']] * len(WARNING_SENTENCES) +
              [labels_map['critical']] * len(CRITICAL_SENTENCES))

# Create DataFrame for analysis
df_logs = pd.DataFrame({
    'sentence': all_sentences,
    'label': all_labels
})

df_logs['label_name'] = df_logs['label'].map({v: k.title() for k, v in labels_map.items()})
print(df_logs.sample(6, random_state=42))

plt.figure(figsize=(6, 4))
sns.countplot(data=df_logs, x='label_name', palette='viridis')
plt.title('Class Distribution of Maintenance Logs', fontweight='bold')
plt.xlabel('Severity Level')
plt.ylabel('Frequency')
plt.tight_layout()
plt.show()

### Tokenization & Vocabulary

In [None]:
def tokenize(text: str) -> List[str]:
    tokens = text.lower().replace('-', ' ').split()
    return tokens


# Build vocabulary
all_tokens = [token for sentence in all_sentences for token in tokenize(sentence)]
vocab_counter = Counter(all_tokens)

# Reserve special tokens
PAD_TOKEN = '<pad>'
UNK_TOKEN = '<unk>'
CLS_TOKEN = '<cls>'

# Build vocab dictionary
vocab = {PAD_TOKEN: 0, UNK_TOKEN: 1, CLS_TOKEN: 2}
for token, freq in vocab_counter.most_common():
    vocab[token] = len(vocab)

inv_vocab = {idx: token for token, idx in vocab.items()}

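print(f'Vocabulary size: {len(vocab)}')


# Encode sentences: prepend <cls>, map tokens to ids, pad to a common length.
# Assumed scheme, chosen to match the CLS-at-index-0 pooling and padding mask used later.
def encode(text: str, max_len: int) -> List[int]:
    ids = [vocab[CLS_TOKEN]] + [vocab.get(tok, vocab[UNK_TOKEN]) for tok in tokenize(text)]
    ids = ids[:max_len]
    return ids + [vocab[PAD_TOKEN]] * (max_len - len(ids))


MAX_SEQ_LEN = max(len(tokenize(s)) for s in all_sentences) + 1  # +1 for <cls>
encoded_sentences = [encode(s, MAX_SEQ_LEN) for s in all_sentences]
encoded_array = np.array(encoded_sentences)
print(f'Encoded array shape: {encoded_array.shape}')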

### PyTorch Dataset & DataLoader

In [None]:
class MaintenanceLogDataset(Dataset):
    def __init__(self, encoded_sentences: np.ndarray, labels: List[int]):
        self.encoded_sentences = torch.tensor(encoded_sentences, dtype=torch.long)
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.encoded_sentences[idx], self.labels[idx]


# Train-test split
indices = np.arange(len(encoded_sentences))
np.random.shuffle(indices)
split = int(0.8 * len(indices))
train_indices, test_indices = indices[:split], indices[split:]

train_dataset = MaintenanceLogDataset(encoded_array[train_indices], df_logs['label'].iloc[train_indices].tolist())
test_dataset = MaintenanceLogDataset(encoded_array[test_indices], df_logs['label'].iloc[test_indices].tolist())

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

print(f'Train samples: {len(train_dataset)} | Test samples: {len(test_dataset)}')

## 7️⃣ Transformer-Based Maintenance Classifier

Architecture:
- Token Embedding
- Positional Encoding
- Stacked Transformer Encoder Blocks
- Pooling via the `<cls>` token representation (index 0)
- Classification head

In [None]:
class MaintenanceTransformerClassifier(nn.Module):
    def __init__(self, vocab_size: int, embed_dim: int, num_heads: int, ff_hidden_dim: int,
                 num_layers: int, num_classes: int, max_len: int, dropout: float = 0.1):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=vocab[PAD_TOKEN])
        self.positional_encoding = PositionalEncoding(embed_dim, max_len=max_len)

        self.layers = nn.ModuleList([
            TransformerEncoderBlock(embed_dim, num_heads, ff_hidden_dim, dropout)
            for _ in range(num_layers)
        ])

        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim, num_classes)
        )

    def forward(self, x):
        # x: (batch_size, seq_len)
        mask = (x != vocab[PAD_TOKEN]).unsqueeze(1).unsqueeze(2)  # (batch, 1, 1, seq_len)

        embeddings = self.embedding(x)  # (batch, seq_len, embed_dim)
        embeddings = self.positional_encoding(embeddings)

        attn_weights_collection = []
        out = embeddings
        for layer in self.layers:
            out, attn_weights = layer(out, mask=mask)
            attn_weights_collection.append(attn_weights)

        # Use CLS token representation (index 0)
        cls_repr = out[:, 0, :]
        logits = self.classifier(self.dropout(cls_repr))

        return logits, attn_weights_collection


# Instantiate model
model = MaintenanceTransformerClassifier(
    vocab_size=len(vocab),
    embed_dim=128,
    num_heads=4,
    ff_hidden_dim=256,
    num_layers=2,
    num_classes=len(labels_map),
    max_len=MAX_SEQ_LEN,
    dropout=0.2
).to(device)

print(model)
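
The padding mask built in `forward` marks which key positions hold real tokens. A common way to apply such a mask is to set the scores of padded positions to negative infinity before the softmax, so they receive zero weight. The sketch below shows that idea in plain Python; the token ids and scores are made up, and `masked_softmax` is a hypothetical helper.

```python
import math

PAD_ID = 0  # same index the notebook assigns to <pad>

def masked_softmax(scores, token_ids):
    """Softmax over key positions, giving padded positions zero weight."""
    masked = [s if tok != PAD_ID else float('-inf') for s, tok in zip(scores, token_ids)]
    peak = max(masked)
    exps = [math.exp(s - peak) for s in masked]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]

token_ids = [2, 7, 9, 0, 0]         # hypothetical: <cls>, two tokens, two pads
scores = [0.2, 1.0, 0.5, 3.0, 3.0]  # hypothetical raw attention scores for one query

weights = masked_softmax(scores, token_ids)
print([round(w, 3) for w in weights])
```

Even though the padded positions have the largest raw scores here, they end up with no influence on the output.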

### Training Utilities

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-5)


def train_epoch(model, dataloader, optimizer, criterion):
    model.train()
    epoch_loss = 0
    correct = 0
    total = 0

    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits, _ = model(inputs)
        loss = criterion(logits, labels)

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item() * inputs.size(0)
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    return epoch_loss / total, correct / total


def evaluate(model, dataloader, criterion):
    model.eval()
    epoch_loss = 0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            logits, _ = model(inputs)
            loss = criterion(logits, labels)

            epoch_loss += loss.item() * inputs.size(0)
            preds = logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    return epoch_loss / total, correct / total, all_preds, all_labels
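
Both functions multiply each batch loss by the batch size before dividing by the total sample count. The reason is that `nn.CrossEntropyLoss` returns a per-batch mean by default, so a plain average over batches would over-weight a smaller final batch. A small sketch with made-up numbers (`epoch_average` is a hypothetical helper):

```python
def epoch_average(batch_losses, batch_sizes):
    """Sample-weighted mean of per-batch mean losses."""
    total = sum(batch_sizes)
    return sum(loss * n for loss, n in zip(batch_losses, batch_sizes)) / total

# Hypothetical epoch whose last batch is smaller than the others
batch_losses = [1.0, 1.0, 4.0]
batch_sizes = [8, 8, 2]

weighted = epoch_average(batch_losses, batch_sizes)  # every sample counts equally
naive = sum(batch_losses) / len(batch_losses)        # every batch counts equally

print(round(weighted, 4), round(naive, 4))
```

The weighted value is the true mean loss per sample, which is what the loss curves are meant to show.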

### Train the Transformer Classifier

In [None]:
EPOCHS = 30
train_losses, test_losses = [], []
train_accuracies, test_accuracies = [], []

print('🔄 Training Transformer Classifier...')
print('=' * 60)

for epoch in range(1, EPOCHS + 1):
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion)
    test_loss, test_acc, _, _ = evaluate(model, test_loader, criterion)

    train_losses.append(train_loss)
    test_losses.append(test_loss)
    train_accuracies.append(train_acc)
    test_accuracies.append(test_acc)

    if epoch % 5 == 0 or epoch == 1:
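        print(f'Epoch {epoch:02d} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.3f} | '
              f'Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.3f}')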

In [None]:
# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(16, 5))

axes[0].plot(range(1, EPOCHS + 1), train_losses, label='Train Loss', color='blue', linewidth=2)
axes[0].plot(range(1, EPOCHS + 1), test_losses, label='Test Loss', color='red', linewidth=2)
axes[0].set_title('Loss Curves', fontweight='bold')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Cross-Entropy Loss')
axes[0].legend()

axes[1].plot(range(1, EPOCHS + 1), train_accuracies, label='Train Acc', color='green', linewidth=2)
axes[1].plot(range(1, EPOCHS + 1), test_accuracies, label='Test Acc', color='orange', linewidth=2)
axes[1].set_title('Accuracy Curves', fontweight='bold')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy')
axes[1].legend()

plt.tight_layout()
plt.show()

print('📈 Visualization complete!')

## 8️⃣ Inspect Predictions

In [None]:
# Map indices to labels
idx_to_label = {v: k.title() for k, v in labels_map.items()}

model.eval()
inputs, true_labels = next(iter(test_loader))
inputs = inputs.to(device)
true_labels = true_labels.to(device)

with torch.no_grad():
    logits, attn_history = model(inputs)
    preds = logits.argmax(dim=1)

for i in range(min(len(inputs), 5)):
    token_ids = inputs[i].cpu().numpy()
    tokens = [inv_vocab.get(idx, UNK_TOKEN) for idx in token_ids if idx != vocab[PAD_TOKEN]]
    sentence = ' '.join(tokens).replace(CLS_TOKEN + ' ', '')
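    print(f'{sentence} | True: {idx_to_label[true_labels[i].item()]} | '
          f'Predicted: {idx_to_label[preds[i].item()]}')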