# Attention Visualization in Transformer Models

## Objective
This notebook demonstrates how to visualize and interpret attention mechanisms in transformer models. We'll:

1. Load a pre-trained model (DistilBERT)
2. Extract attention weights for sample text
3. Create various visualizations of attention patterns
4. Analyze what the model is focusing on

## Key Concepts
- **Attention Mechanism**: How models decide which parts of the input to focus on
- **Multi-Head Attention**: Multiple attention patterns learned in parallel
- **Attention Heads**: Individual attention mechanisms that may capture different linguistic patterns
- **Attention Maps**: Visualizations showing which tokens attend to which other tokens
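The attention weights visualized in this notebook are softmax-normalized scores from scaled dot-product attention, softmax(Q Kᵀ / sqrt(d_k)) V (Vaswani et al., 2017). The minimal NumPy sketch below shows why every row of an attention map sums to 1. Its matrices are random toy inputs, not model weights.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V):
    """Single-head attention: softmax(Q K^T / sqrt(d_k)) V.

    Q, K: (seq_len, d_k); V: (seq_len, d_v).
    Returns (output, weights); weights[i, j] is how much query i attends to key j.
    """
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)      # (seq_len, seq_len) similarity scores
    weights = softmax(scores, axis=-1)   # each row is a distribution over keys
    return weights @ V, weights

rng = np.random.default_rng(0)
seq_len, d_k = 5, 8
Q, K, V = (rng.standard_normal((seq_len, d_k)) for _ in range(3))
out, weights = scaled_dot_product_attention(Q, K, V)
print(weights.shape)          # (5, 5)
print(weights.sum(axis=-1))   # every row sums to 1 (up to float rounding)
```

Each row of `weights` corresponds to one row of the heatmaps drawn later: one query token's distribution over all key tokens.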

In [None]:
# Install required packages
!pip install transformers torch matplotlib seaborn bertviz ipywidgets
!pip install datasets  # For additional text examples

In [None]:
# Import required libraries
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer, AutoModel, AutoConfig
from bertviz import head_view, model_view
import warnings
warnings.filterwarnings('ignore')

# Set up plotting
plt.style.use('default')
sns.set_palette("husl")
%matplotlib inline

print("Libraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"GPU available: {torch.cuda.is_available()}")

## 1. Model Setup

We'll use **DistilBERT**, a distilled version of BERT that is about 40% smaller and 60% faster while retaining roughly 97% of BERT's language-understanding performance (Sanh et al., 2019). That efficiency makes it convenient for visualization.

### Why DistilBERT?
- **Lightweight**: 6 layers vs. BERT-base's 12 layers
- **Fast inference**: Good for interactive exploration
- **Familiar attention design**: Same multi-head self-attention as BERT, so findings from BERT attention studies (e.g. Clark et al., 2019) are a useful reference point
- **Good for visualization**: Manageable number of attention heads (12 heads per layer, 72 in total)
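"Multi-head" means the following in DistilBERT-base. The 768-dimensional hidden state is split into 12 heads of 64 dimensions each, and each head produces its own seq_len × seq_len attention map. The sketch below reproduces the per-layer shape the model returns. Its projection matrices are random stand-ins for trained weights, and biases and the value/output projections are omitted.

```python
import numpy as np

# DistilBERT-base dimensions; the projection matrices below are random
# stand-ins for trained weights, and biases/value projections are omitted.
hidden, n_heads, n_layers = 768, 12, 6
head_dim = hidden // n_heads  # 64 dimensions per head

def multi_head_attention_maps(x, Wq, Wk, n_heads):
    """Return one attention map per head, shape (n_heads, seq_len, seq_len)."""
    seq_len, d = x.shape
    hd = d // n_heads
    # Project, then split the hidden dimension into n_heads chunks of size hd
    q = (x @ Wq).reshape(seq_len, n_heads, hd).transpose(1, 0, 2)  # (heads, seq, hd)
    k = (x @ Wk).reshape(seq_len, n_heads, hd).transpose(1, 0, 2)  # (heads, seq, hd)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(hd)                # (heads, seq, seq)
    scores = scores - scores.max(axis=-1, keepdims=True)           # numerical stability
    w = np.exp(scores)
    return w / w.sum(axis=-1, keepdims=True)                       # softmax over keys

rng = np.random.default_rng(0)
seq_len = 9  # e.g. [CLS] the cat sat on the mat . [SEP]
x = rng.standard_normal((seq_len, hidden))
Wq = rng.standard_normal((hidden, hidden)) / np.sqrt(hidden)
Wk = rng.standard_normal((hidden, hidden)) / np.sqrt(hidden)

maps = multi_head_attention_maps(x, Wq, Wk, n_heads)
print(maps.shape)          # (12, 9, 9): one map per head, like attentions[layer][0]
print(n_layers * n_heads)  # 72 maps across the whole model
```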

In [None]:
# Load DistilBERT model and tokenizer
model_name = "distilbert-base-uncased"

print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name, output_attentions=True)
model = AutoModel.from_pretrained(model_name, config=config)

# Set model to evaluation mode
model.eval()

print(f"Model loaded: {model_name}")
print(f"Number of layers: {config.n_layers}")
print(f"Number of attention heads per layer: {config.n_heads}")
print(f"Hidden size: {config.dim}")
print(f"Vocabulary size: {config.vocab_size}")

## 2. Text Processing and Attention Extraction

Let's define some sample texts to analyze and create functions to extract attention weights.

In [None]:
def extract_attention(text, model, tokenizer):
    """
    Extract attention weights for a given text.
    
    Args:
        text (str): Input text to analyze
        model: Pre-trained transformer model
        tokenizer: Corresponding tokenizer
    
    Returns:
        tuple: (tokens, attention_weights, input_ids)
    """
    # Tokenize the input
    inputs = tokenizer(text, return_tensors="pt", add_special_tokens=True)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    
    # Get model outputs with attention
    with torch.no_grad():
        outputs = model(**inputs)
    
    # Extract attention weights: a tuple with one tensor per layer,
    # each of shape (batch_size, num_heads, seq_len, seq_len)
    attention_weights = outputs.attentions
    
    return tokens, attention_weights, inputs["input_ids"]

# Test the function with sample text
sample_text = "The cat sat on the mat."
tokens, attention_weights, input_ids = extract_attention(sample_text, model, tokenizer)

print(f"Sample text: '{sample_text}'")
print(f"Tokens: {tokens}")
print(f"Number of layers: {len(attention_weights)}")
print(f"Attention shape per layer: {attention_weights[0].shape}")
print("(batch_size, num_heads, seq_len, seq_len)")
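`outputs.attentions` is a tuple with one tensor per layer, so statistics across layers need a loop. One option is to stack it into a single array shaped (layers, heads, seq, seq). The helper `stack_attentions` below is a hypothetical convenience, not part of Transformers. It is demonstrated on synthetic data so that it runs without the model.

```python
import numpy as np

def stack_attentions(attentions, batch_idx=0):
    """Stack a per-layer tuple of attention tensors into one array.

    Each element is assumed to be shaped (batch, heads, seq, seq), like
    outputs.attentions; the result is shaped (layers, heads, seq, seq).
    np.asarray accepts NumPy arrays and CPU torch tensors that do not require grad.
    """
    return np.stack([np.asarray(layer)[batch_idx] for layer in attentions])

# Synthetic stand-in for outputs.attentions: 6 layers, 1 sentence, 12 heads, 9 tokens
rng = np.random.default_rng(0)
fake_attentions = tuple(rng.random((1, 12, 9, 9)) for _ in range(6))
stacked = stack_attentions(fake_attentions)
print(stacked.shape)  # (6, 12, 9, 9)
```

In the notebook, `stack_attentions(attention_weights)[layer, head]` should hold the same values as `attention_weights[layer][0, head].numpy()`.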

In [None]:
def plot_attention_head(attention_weights, tokens, layer_idx, head_idx, figsize=(10, 8)):
    """
    Plot attention weights for a specific layer and head.
    
    Args:
        attention_weights: Attention weights from model
        tokens: List of tokens
        layer_idx: Which layer to visualize
        head_idx: Which attention head to visualize
        figsize: Figure size for the plot
    """
    # Extract attention for specific layer and head
    # Shape: (seq_len, seq_len)
    attention = attention_weights[layer_idx][0, head_idx].numpy()
    
    plt.figure(figsize=figsize)
    
    # Create heatmap
    sns.heatmap(
        attention,
        xticklabels=tokens,
        yticklabels=tokens,
        cmap='Blues',
        annot=True,
        fmt='.2f',
        square=True,
        cbar_kws={'label': 'Attention Weight'}
    )
    
    plt.title(f'Attention Head {head_idx}, Layer {layer_idx}')
    plt.xlabel('Keys (attending to)')
    plt.ylabel('Queries (attending from)')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()

def plot_layer_attention_summary(attention_weights, tokens, layer_idx, figsize=(12, 8)):
    """
    Plot average attention across all heads for a specific layer.
    """
    # Average across all heads for this layer
    layer_attention = attention_weights[layer_idx][0].mean(dim=0).numpy()
    
    plt.figure(figsize=figsize)
    
    sns.heatmap(
        layer_attention,
        xticklabels=tokens,
        yticklabels=tokens,
        cmap='Reds',
        annot=True,
        fmt='.2f',
        square=True,
        cbar_kws={'label': 'Average Attention Weight'}
    )
    
    plt.title(f'Layer {layer_idx} - Average Attention Across All Heads')
    plt.xlabel('Keys (attending to)')
    plt.ylabel('Queries (attending from)')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()

## 3. Basic Attention Visualization

Let's start with visualizing attention patterns for our sample text: "The cat sat on the mat."

In [None]:
# Visualize attention for specific heads
print("Visualizing different attention heads...")

# Plot attention for the first layer, first head
plot_attention_head(attention_weights, tokens, layer_idx=0, head_idx=0)

# Plot attention for a middle layer, different head
plot_attention_head(attention_weights, tokens, layer_idx=2, head_idx=5)

# Plot attention for the last layer, last head
plot_attention_head(attention_weights, tokens, layer_idx=5, head_idx=11)

In [None]:
# Visualize average attention per layer
print("\nVisualizing average attention across all heads per layer...")

# Plot average attention for first, middle, and last layers
for layer_idx in [0, 2, 5]:
    plot_layer_attention_summary(attention_weights, tokens, layer_idx)

## 4. Interactive Attention Visualization with BertViz

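BertViz provides interactive views for exploring attention. `head_view` shows the heads of one layer at a time (with a layer selector), and `model_view` gives a bird's-eye view of every head in every layer.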

In [None]:
# Interactive visualization with BertViz
from bertviz import head_view, model_view

# Prepare inputs for BertViz
text = "The cat sat on the mat."
inputs = tokenizer(text, return_tensors='pt', add_special_tokens=True)
input_ids = inputs['input_ids']

# Run the model without tracking gradients and read the attentions by name
with torch.no_grad():
    outputs = model(**inputs)
attention = outputs.attentions  # tuple of per-layer tensors, the format BertViz expects

# Convert token IDs back to strings for the axis labels
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

print("Generating interactive attention visualization...")
print("Note: The visualization will appear below. Click on different heads and layers to explore!")

# Head view - shows attention for individual heads
head_view(attention, tokens)

In [None]:
# Model view - shows attention across all layers and heads
print("\nModel view - Overview of all layers and heads:")
model_view(attention, tokens)

## 5. Analyzing Different Text Types

Let's explore how attention patterns change with different types of text and linguistic structures.

In [None]:
# Define different types of sentences to analyze
test_sentences = [
    "The cat sat on the mat.",                    # Simple sentence
    "The big red cat sat on the soft mat.",      # Sentence with adjectives
    "The cat that I love sat on the mat.",       # Sentence with relative clause
    "When the cat sat down, the mat moved.",     # Complex sentence with subordinate clause
    "The cat and the dog played together.",      # Sentence with coordination
    "She said that the cat was sleeping."        # Sentence with reported speech
]

def analyze_sentence_attention(sentence, layer_idx=2, head_idx=5, top_k=5):
    """Plot one head's attention for a sentence and list its strongest connections."""
    tokens, attention_weights, _ = extract_attention(sentence, model, tokenizer)
    
    print(f"\n{'='*60}")
    print(f"Analyzing: '{sentence}'")
    print(f"Tokens: {tokens}")
    print(f"{'='*60}")
    
    # Plot attention for specified layer and head
    plot_attention_head(attention_weights, tokens, layer_idx, head_idx, figsize=(12, 10))
    
    # Extract attention matrix for analysis: rows are queries, columns are keys
    attention_matrix = attention_weights[layer_idx][0, head_idx].numpy()
    
    # Collect all query -> key connections, skipping self-attention (i == j)
    connections = []
    for i in range(len(tokens)):
        for j in range(len(tokens)):
            if i != j:
                connections.append((tokens[i], tokens[j], attention_matrix[i, j]))
    
    # Sort by weight and show only the strongest connections
    connections.sort(key=lambda x: x[2], reverse=True)
    print(f"\nTop {top_k} attention connections (excluding self-attention):")
    for src, tgt, weight in connections[:top_k]:
        print(f"  {src} -> {tgt}: {weight:.3f}")

# Analyze the first few sentences
for sentence in test_sentences[:3]:
    analyze_sentence_attention(sentence)
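Building and sorting a Python list of every token pair works, but the same ranking can be done in a few vectorized NumPy calls. The sketch below is an alternative, not part of the original notebook, and `top_k_connections` is a hypothetical helper name. It masks the diagonal and sorts all entries at once.

```python
import numpy as np

def top_k_connections(attention_matrix, tokens, k=5):
    """Return the k strongest off-diagonal (query -> key) weights, largest first."""
    a = np.array(attention_matrix, dtype=float)   # copy, so the caller's matrix is untouched
    np.fill_diagonal(a, -np.inf)                  # exclude self-attention
    k = min(k, a.size - a.shape[0])               # at most n*(n-1) off-diagonal entries
    flat = np.argsort(a, axis=None)[::-1][:k]     # flat indices of the largest entries
    rows, cols = np.unravel_index(flat, a.shape)
    return [(tokens[i], tokens[j], float(a[i, j])) for i, j in zip(rows, cols)]

# Toy 4-token example (rows are queries, columns are keys)
toy_tokens = ["[CLS]", "cat", "sat", "[SEP]"]
toy_attention = np.array([
    [0.70, 0.10, 0.10, 0.10],
    [0.05, 0.50, 0.40, 0.05],
    [0.10, 0.60, 0.20, 0.10],
    [0.10, 0.20, 0.30, 0.40],
])
for src, tgt, w in top_k_connections(toy_attention, toy_tokens, k=3):
    print(f"{src} -> {tgt}: {w:.3f}")
```

With the model loaded, it can be applied to a real head, for example `top_k_connections(attention_weights[2][0, 5].numpy(), tokens)`.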

In [None]:
# Function to compare attention patterns across layers
def compare_layers_attention(sentence, head_idx=5):
    """Compare attention patterns across different layers for the same sentence."""
    tokens, attention_weights, _ = extract_attention(sentence, model, tokenizer)
    
    num_layers = len(attention_weights)  # 6 for DistilBERT
    fig, axes = plt.subplots(2, (num_layers + 1) // 2, figsize=(18, 12))
    axes = axes.flatten()
    
    for layer_idx in range(num_layers):
        attention = attention_weights[layer_idx][0, head_idx].numpy()
        
        sns.heatmap(
            attention,
            xticklabels=tokens,
            yticklabels=tokens,
            cmap='Blues',
            annot=True,
            fmt='.2f',
            square=True,
            ax=axes[layer_idx],
            cbar=False
        )
        
        axes[layer_idx].set_title(f'Layer {layer_idx}')
        axes[layer_idx].set_xlabel('Keys')
        axes[layer_idx].set_ylabel('Queries')
        axes[layer_idx].tick_params(axis='x', rotation=45)
    
    plt.suptitle(f'Attention Patterns Across Layers\nSentence: "{sentence}"\nHead: {head_idx}',
                 fontsize=16)
    plt.tight_layout()
    plt.show()

# Compare layers for a complex sentence
complex_sentence = "The cat that I love sat on the mat."
compare_layers_attention(complex_sentence)

## 6. Attention Pattern Analysis & Observations

Based on the visualizations above, let's document key observations about what attention appears to be focusing on.

In [None]:
# Systematic analysis of attention patterns
def analyze_attention_patterns(sentence, threshold=0.1):
    """
    Roughly categorize the strong attention connections in a sentence.

    The categories are simple heuristics based on token position and identity;
    they do not verify that a connection is truly syntactic or semantic.
    """
    tokens, attention_weights, _ = extract_attention(sentence, model, tokenizer)
    print(f"\nAnalyzing patterns in: '{sentence}'")
    print(f"Tokens: {tokens}")
    
    special = ['[CLS]', '[SEP]']
    patterns = {
        'positional': [],      # Attention to adjacent positions
        'syntactic': [],       # Determiner -> other token (a crude syntactic proxy)
        'semantic': [],        # Everything else (catch-all, not checked for meaning)
        'special_tokens': []   # Attention involving [CLS] or [SEP]
    }
    
    # Analyze each layer
    for layer_idx in range(len(attention_weights)):
        layer_attention = attention_weights[layer_idx][0].mean(dim=0).numpy()  # Average across heads
        
        # Find strong off-diagonal attention connections above the threshold
        strong_connections = []
        for i in range(len(tokens)):
            for j in range(len(tokens)):
                if layer_attention[i, j] > threshold and i != j:
                    strong_connections.append((i, j, tokens[i], tokens[j], layer_attention[i, j]))
        
        print(f"\nLayer {layer_idx} - Strong connections (>{threshold}):")
        for i, j, token_i, token_j, weight in strong_connections:
            print(f"  {token_i} -> {token_j}: {weight:.3f}")
            
            # Categorize the connection; special tokens are checked first so that,
            # e.g., '.' -> '[SEP]' is not counted as a positional pattern
            if token_i in special or token_j in special:
                patterns['special_tokens'].append((layer_idx, token_i, token_j, weight))
            elif abs(i - j) == 1:
                patterns['positional'].append((layer_idx, token_i, token_j, weight))
            elif token_i in ['the', 'a', 'an']:
                patterns['syntactic'].append((layer_idx, token_i, token_j, weight))
            else:
                patterns['semantic'].append((layer_idx, token_i, token_j, weight))
    
    return patterns

# Analyze patterns for different sentence types
sentences_to_analyze = [
    "The cat sat on the mat.",
    "The big red cat sat on the soft mat.",
    "The cat that I love sat on the mat."
]

all_patterns = {}
for sentence in sentences_to_analyze:
    all_patterns[sentence] = analyze_attention_patterns(sentence)
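The per-layer printouts are long, so a compact summary can make the sentence types easier to compare. The sketch below counts connections per category and per layer. It assumes the dictionary format returned by `analyze_attention_patterns` (category mapped to a list of `(layer_idx, token_i, token_j, weight)` tuples). `summarize_patterns` is a hypothetical helper, shown on a small hand-made dictionary.

```python
from collections import Counter

def summarize_patterns(patterns):
    """Count connections per category, and per layer within each category.

    `patterns` maps category -> list of (layer_idx, token_i, token_j, weight).
    """
    by_category = {cat: len(items) for cat, items in patterns.items()}
    by_layer = {cat: Counter(item[0] for item in items) for cat, items in patterns.items()}
    return by_category, by_layer

# Hand-made example in the same format (not real model output)
toy_patterns = {
    "positional": [(0, "the", "cat", 0.30), (1, "the", "cat", 0.20)],
    "syntactic": [],
    "semantic": [(4, "cat", "mat", 0.15)],
    "special_tokens": [(5, "cat", "[SEP]", 0.60)],
}
counts, per_layer = summarize_patterns(toy_patterns)
print(counts)  # {'positional': 2, 'syntactic': 0, 'semantic': 1, 'special_tokens': 1}
```

In the notebook, `for s, p in all_patterns.items(): print(s, summarize_patterns(p)[0])` would print one line of counts per sentence.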

## 7. Key Insights and Observations

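The patterns below are commonly reported for BERT-style models (see, e.g., "What Does BERT Look At?" in Further Reading). Treat them as hypotheses to check against the visualizations above, not as guarantees for every head or sentence: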

### 📍 **Positional Patterns**
- **Adjacent Token Attention**: Some heads attend mostly to the immediately preceding or following token
- **Distance Decay**: In many heads, attention tends to fall off as the distance between query and key grows
- **Position-specific Behavior**: Positional preferences differ from layer to layer
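These positional claims can be checked directly rather than assumed. The sketch below averages a head's weights along each diagonal of its attention map, that is, by how far the key is from the query. A "previous-token" head then shows up as a spike at offset -1. `attention_by_offset` is a hypothetical helper name.

```python
import numpy as np

def attention_by_offset(attention_matrix, max_offset=3):
    """Average attention weight as a function of key offset (j - i).

    Returns {offset: mean weight} for offsets -max_offset..max_offset, averaged
    over all query positions i that have a key at position i + offset.
    """
    a = np.asarray(attention_matrix, dtype=float)
    n = a.shape[0]
    profile = {}
    for off in range(-max_offset, max_offset + 1):
        if abs(off) < n:
            # np.diagonal with this offset picks the entries a[i, i + off]
            profile[off] = float(np.diagonal(a, offset=off).mean())
    return profile

# Toy "previous-token" head: each query attends fully to the token before it
n = 6
prev_head = np.zeros((n, n))
prev_head[0, 0] = 1.0                            # the first token has no predecessor
prev_head[np.arange(1, n), np.arange(n - 1)] = 1.0
profile = attention_by_offset(prev_head, max_offset=2)
print(profile[-1])  # 1.0: every query after the first attends to its previous token
```

With the model loaded, `attention_by_offset(attention_weights[0][0, 0].numpy())` profiles layer 0, head 0. Special tokens at both ends can distort the profile on short sentences.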

### 🔤 **Syntactic Patterns**
- **Determiner-Noun Relationships**: In some heads, "the", "a", and "an" attend strongly to the nouns they modify
- **Subject-Verb Connections**: Some heads link subjects and their corresponding verbs
- **Modifier Relationships**: Adjectives may attend to the nouns they modify

### 💭 **Semantic Patterns**
- **Coreference**: Certain heads attend from pronouns to their antecedents
- **Semantic Similarity**: Words with related meanings sometimes attend to each other
- **Thematic Roles**: Verbs and their arguments (subject, object) can show structured attention patterns

### 🎯 **Special Token Behavior**
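- **[CLS] Token**: Often treated as a "summary" token (its final hidden state is what classification heads typically use); in some heads it attends broadly across the sentence
- **[SEP] Token**: A boundary marker that receives a large share of attention in many heads; Clark et al. (2019) suggest this can act as a "no-op" when a head's pattern does not apply
- **Sentence-level Information**: Special tokens are thought to help aggregate sentence-level representations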
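Clark et al. (2019) report that in BERT, a large fraction of the attention in some layers lands on [SEP]. The sketch below measures the same quantity here: the share of each head's attention that falls on special-token keys. It assumes the per-layer `(batch, heads, seq, seq)` format of `outputs.attentions`. `special_token_mass` is a hypothetical helper, demonstrated on uniform toy attention.

```python
import numpy as np

def special_token_mass(attentions, tokens, special=("[CLS]", "[SEP]"), batch_idx=0):
    """Share of each head's attention that lands on special-token keys.

    attentions: sequence of per-layer arrays shaped (batch, heads, seq, seq).
    Returns an array shaped (layers, heads), averaged over query positions.
    """
    stacked = np.stack([np.asarray(layer)[batch_idx] for layer in attentions])
    cols = [j for j, t in enumerate(tokens) if t in special]
    # Sum the weights on special-token key columns, then average over queries
    return stacked[..., cols].sum(axis=-1).mean(axis=-1)

toy_tokens = ["[CLS]", "hello", "world", "[SEP]"]
toy_attentions = [np.full((1, 3, 4, 4), 0.25) for _ in range(2)]  # 2 layers, 3 heads, uniform rows
mass = special_token_mass(toy_attentions, toy_tokens)
print(mass.shape)  # (2, 3); every entry is 0.5 because 2 of 4 keys are special
```

In the notebook, `special_token_mass(attention_weights, tokens).mean(axis=1)` gives one number per layer. For a short sentence, even uniform attention puts a noticeable share on the two special tokens, so compare against 2 / len(tokens) as a baseline.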

### 🏗️ **Layer-wise Evolution**
- **Early Layers**: Tend to focus more on local, positional, and shallow syntactic relationships
- **Middle Layers**: Tend to capture more complex syntactic structures
- **Later Layers**: Are often associated with higher-level semantic and discourse relationships
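A simple locality statistic can test whether early layers really are more local for a given sentence. The mean attention distance is the expected |i - j| under each head's attention, averaged over query positions. The sketch below uses the hypothetical helper `mean_attention_distance` and assumes the per-layer `(batch, heads, seq, seq)` format.

```python
import numpy as np

def mean_attention_distance(attentions, batch_idx=0):
    """Expected |i - j| under each head's attention, averaged over queries.

    attentions: sequence of per-layer arrays shaped (batch, heads, seq, seq).
    Returns an array shaped (layers, heads); smaller values mean more local heads.
    """
    stacked = np.stack([np.asarray(layer)[batch_idx] for layer in attentions])
    n = stacked.shape[-1]
    idx = np.arange(n)
    dist = np.abs(idx[:, None] - idx[None, :])   # (seq, seq) |query - key| distances
    # Weight each distance by its attention, sum over keys, average over queries
    return (stacked * dist).sum(axis=-1).mean(axis=-1)

n = 4
local = np.eye(n)[None, None]               # (1, 1, 4, 4): every token attends to itself
uniform = np.full((1, 1, n, n), 1.0 / n)    # (1, 1, 4, 4): attention spread evenly
print(mean_attention_distance([local, uniform]))  # layer 0 -> 0.0, layer 1 -> 1.25
```

`mean_attention_distance(attention_weights).mean(axis=1)` gives one value per layer for the current sentence. Heavy attention to [CLS] or [SEP] inflates the distance, so this is a rough signal rather than a clean measure of syntax.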

## 8. Explore Your Own Examples

Use the space below to experiment with your own sentences and observe attention patterns!

In [None]:
# Try your own sentences!
# Modify the sentence below and run the cell to see attention patterns

your_sentence = "Your custom sentence goes here."

# Quick analysis function
def quick_attention_analysis(sentence, layer=2, head=5):
    """Quick visualization and analysis of a sentence."""
    tokens, attention_weights, _ = extract_attention(sentence, model, tokenizer)
    
    print(f"Analyzing: '{sentence}'")
    print(f"Tokens: {tokens}")
    
    # Plot the attention
    plot_attention_head(attention_weights, tokens, layer, head)
    
    # Show the interactive BertViz view, reusing the attention extracted above
    head_view(attention_weights, tokens)

# Uncomment and modify the line below to try your own sentence:
# quick_attention_analysis("The dog chased the cat up the tree.")

print("Ready for exploration! Modify 'your_sentence' above and call quick_attention_analysis(your_sentence)")

## 9. Conclusion and Next Steps

### 🎯 **What We've Accomplished**
- ✅ Loaded and configured DistilBERT for attention visualization
- ✅ Created custom visualization functions for attention heatmaps
- ✅ Used BertViz for interactive attention exploration
- ✅ Analyzed attention patterns across different sentence types
- ✅ Identified key linguistic patterns captured by attention
- ✅ Documented systematic observations about attention behavior

### 🔍 **Key Takeaways**
1. **Attention is Partly Interpretable**: Some heads show patterns that line up with linguistic structure, although attention weights alone do not fully explain a model's predictions
2. **Layer Specialization**: Different layers tend to capture different types of relationships
3. **Head Diversity**: Different attention heads appear to specialize in different phenomena
4. **Structural Awareness**: Models appear to learn aspects of syntactic and semantic structure without explicit supervision

### 🚀 **Possible Extensions**
- **Try Different Models**: Compare BERT, RoBERTa, GPT-2 attention patterns
- **Multilingual Analysis**: Explore attention in different languages
- **Fine-tuned Models**: See how attention changes after task-specific training
- **Probing Tasks**: Test if attention correlates with specific linguistic phenomena
- **Attention Flow**: Track how information flows through layers
- **Attention Rollout**: Combine attention across layers for end-to-end analysis
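Attention rollout (Abnar & Zuidema, 2020) combines layers in three steps:

- average each layer's heads;
- mix in the identity matrix to account for residual connections;
- multiply the resulting matrices from the first layer to the last.

The sketch below assumes the equal 0.5/0.5 weighting of attention and identity described in that paper, and the per-layer `(batch, heads, seq, seq)` format of `outputs.attentions`. `attention_rollout` is a hypothetical helper name.

```python
import numpy as np

def attention_rollout(attentions, batch_idx=0):
    """Attention rollout, sketched.

    attentions: sequence of per-layer arrays shaped (batch, heads, seq, seq),
    ordered from the first layer to the last.
    Returns a (seq, seq) matrix; row i estimates how much each input token
    contributes to position i after the final layer.
    """
    rollout = None
    for layer in attentions:
        a = np.asarray(layer)[batch_idx].mean(axis=0)     # average over heads
        a = 0.5 * a + 0.5 * np.eye(a.shape[-1])           # account for the residual connection
        a = a / a.sum(axis=-1, keepdims=True)             # keep rows normalized
        rollout = a if rollout is None else a @ rollout   # later layers multiply on the left
    return rollout

# Synthetic, row-normalized attention for 6 layers, 12 heads, 9 tokens
rng = np.random.default_rng(0)
raw = rng.random((6, 1, 12, 9, 9))
fake_attentions = [r / r.sum(axis=-1, keepdims=True) for r in raw]
R = attention_rollout(fake_attentions)
print(R.shape)  # (9, 9)
```

With the model loaded, `attention_rollout(attention_weights)[0]` would give the rolled-out attention from the [CLS] position to every input token.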

### 📚 **Further Reading**
- [BertViz Documentation](https://github.com/jessevig/bertviz)
- [Attention is All You Need (Original Transformer Paper)](https://arxiv.org/abs/1706.03762)
- [What Does BERT Look At?](https://arxiv.org/abs/1906.04341)
- [A Primer on Neural Network Models for Natural Language Processing](https://arxiv.org/abs/1510.00726)

---

**Happy exploring! 🧠✨**