In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap
import random

# Set random seeds for reproducibility (Python, NumPy and PyTorch each keep their own RNG)
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Enhanced plotting setup.
# Matplotlib 3.6 renamed the bundled seaborn style to 'seaborn-v0_8'; older releases
# only ship 'seaborn'. Use whichever this installation provides instead of raising
# on an unknown style name (and keep the default style if neither exists).
_plot_style = next((name for name in ('seaborn-v0_8', 'seaborn') if name in plt.style.available), None)
if _plot_style is not None:
    plt.style.use(_plot_style)
sns.set_palette("husl")

# Custom colormap for attention visualization: low weights render white,
# increasingly high weights go through light blue and blue to dark blue.
colors = ['white', 'lightblue', 'blue', 'darkblue']
attention_cmap = LinearSegmentedColormap.from_list('attention', colors)

print("Ready to explore attention mechanisms!")


In [None]:
# Visualize the information bottleneck problem
def visualize_information_bottleneck(sentences=None, context_vector_size=256, units_per_word=50):
    """Demonstrate how information gets compressed and lost in a fixed-size context vector.

    Toy model: every word is assumed to carry ``units_per_word`` "information
    units", and the encoder's context vector can hold at most
    ``context_vector_size`` of them; anything beyond that capacity counts as
    lost. Draws two bar charts (input information vs. capacity, and percentage
    retained) and prints a per-sentence analysis.

    Args:
        sentences: Sentences to analyse. Defaults to five examples of increasing length.
        context_vector_size: Capacity of the fixed context vector (typical: 256-512 dimensions).
        units_per_word: Assumed information content of a single word.
    """
    if sentences is None:
        # Simulate different sentence lengths and information content
        sentences = [
            "Hi",
            "Hello there",
            "How are you doing today?",
            "I am currently working on a machine learning project that involves natural language processing",
            "The field of artificial intelligence has been rapidly advancing in recent years with breakthrough developments in deep learning"
        ]

    sentence_lengths = [len(sent.split()) for sent in sentences]
    information_content = [length * units_per_word for length in sentence_lengths]
    # An empty sentence carries no information, so nothing is lost (and we avoid dividing by zero).
    retention_percentages = [
        min(100, 100 * context_vector_size / info) if info > 0 else 100
        for info in information_content
    ]
    # Classify each sentence once, with a single set of >= thresholds, so the bar
    # colours and the printed verdicts always agree (including at exactly 90% / 50%).
    levels = ['good' if r >= 90 else 'moderate' if r >= 50 else 'severe'
              for r in retention_percentages]
    level_colors = {'good': 'green', 'moderate': 'orange', 'severe': 'red'}
    level_messages = {
        'good': "  ✅ Good information preservation",
        'moderate': "  ⚠️  Moderate information loss",
        'severe': "  ⚠️  SEVERE INFORMATION LOSS!",
    }

    positions = range(len(sentences))
    tick_labels = [f'S{i+1}' for i in positions]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # Left panel: information content vs context vector capacity
    ax1.bar(positions, information_content, alpha=0.7, color='skyblue', label='Input Information')
    ax1.axhline(y=context_vector_size, color='red', linestyle='--', linewidth=2, label='Context Vector Capacity')
    # Shade the representable region across the full width of every bar (also works for one sentence).
    ax1.fill_between([-0.5, len(sentences) - 0.5], 0, context_vector_size, alpha=0.3, color='red')
    ax1.set_xlabel('Sentence Index')
    ax1.set_ylabel('Information Units')
    ax1.set_title('Information Bottleneck Problem')
    ax1.legend()
    ax1.set_xticks(positions)
    ax1.set_xticklabels(tick_labels)

    # Right panel: information retention percentage
    bars = ax2.bar(positions, retention_percentages, alpha=0.7,
                   color=[level_colors[level] for level in levels])
    ax2.set_xlabel('Sentence Index')
    ax2.set_ylabel('Information Retained (%)')
    ax2.set_title('Information Loss by Sentence Length')
    # Headroom above 100% so the labels of fully-retained bars stay inside the axes.
    ax2.set_ylim(0, 110)
    ax2.set_xticks(positions)
    ax2.set_xticklabels(tick_labels)

    # Add percentage labels on bars
    for bar, pct in zip(bars, retention_percentages):
        ax2.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + 1,
                 f'{pct:.1f}%', ha='center', va='bottom', fontweight='bold')

    plt.tight_layout()
    plt.show()

    # Print detailed analysis
    print("INFORMATION BOTTLENECK ANALYSIS")
    print("=" * 50)
    for i, (sent, length, info, retention, level) in enumerate(
            zip(sentences, sentence_lengths, information_content, retention_percentages, levels)):
        print(f"\nSentence {i+1}: \"{sent[:50]}{'...' if len(sent) > 50 else ''}\"")
        print(f"  Length: {length} words")
        print(f"  Estimated information: {info} units")
        print(f"  Information retained: {retention:.1f}%")
        print(level_messages[level])

visualize_information_bottleneck()


In [None]:
# Simulate human translation process vs encoder-decoder
def compare_translation_processes():
    """Contrast step-by-step human translation with a basic encoder-decoder model.

    Prints a walkthrough for one fixed English -> French sentence pair: first how
    a human consults a different source word for every output word, then how a
    plain encoder-decoder reuses one context vector for every output word.

    Returns:
        tuple[list[str], list[str]]: the whitespace-split source and target words.
    """
    english = "The quick brown fox jumps over the lazy dog"
    french = "Le renard brun rapide saute par-dessus le chien paresseux"
    english_tokens = english.split()
    french_tokens = french.split()

    print("TRANSLATION PROCESS COMPARISON")
    print("=" * 60)
    print(f"Source: {english}")
    print(f"Target: {french}")
    print()

    # Which source word a human looks at while producing each French word
    # (note the adjective/noun reordering between the two languages).
    human_steps = (
        ("Le", "Looking at 'The' - translating article"),
        ("renard", "Looking at 'fox' - main subject"),
        ("brun", "Looking back at 'brown' - adjective for fox"),
        ("rapide", "Looking at 'quick' - another adjective"),
        ("saute", "Looking at 'jumps' - main verb"),
        ("par-dessus", "Looking at 'over' - preposition"),
        ("le", "Looking at 'the' - second article"),
        ("chien", "Looking at 'dog' - second noun"),
        ("paresseux", "Looking at 'lazy' - final adjective"),
    )

    print("🧠 HUMAN TRANSLATOR PROCESS:")
    print("-" * 30)
    for step_no, (produced, reasoning) in enumerate(human_steps, start=1):
        print(f"Step {step_no}: Generate '{produced}' - {reasoning}")

    for remark in ("\n✅ Human translators look back at different parts of the source",
                   "✅ They don't memorize everything at once",
                   "✅ They focus on relevant words for each output word"):
        print(remark)
    print()

    print("🤖 BASIC ENCODER-DECODER PROCESS:")
    print("-" * 35)
    for stage in ("Step 1: Encoder reads entire source sentence",
                  "Step 2: Encoder creates single context vector",
                  "Step 3: Decoder generates target using ONLY context vector"):
        print(stage)
    print()
    print("Decoder steps:")
    for step_no, produced in enumerate(french_tokens, start=1):
        print(f"  Step {step_no}: Generate '{produced}' - using same context vector")

    for remark in ("\n❌ Same context vector for all output words",
                   "❌ No looking back at specific source words",
                   "❌ Information bottleneck problem"):
        print(remark)

    return english_tokens, french_tokens

# Visualize the attention concept
def visualize_attention_concept():
    """Plot a hand-crafted attention heatmap for the running English -> French example.

    Each row is one generated French word and each column one English source
    word; a cell holds how strongly that target word "looks at" that source
    word. The weights are illustrative (not learned), but like real softmax
    attention every row is a probability distribution summing to 1.

    Returns:
        np.ndarray: the (target words x source words) attention matrix that was plotted.
    """
    source_words = ["The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"]
    target_words = ["Le", "renard", "brun", "rapide", "saute", "par-dessus", "le", "chien", "paresseux"]

    # Simulated attention weights (what words to focus on for each target word)
    attention_patterns = [
        [0.9, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.0, 0.0],  # "Le" attends to "The"
        [0.0, 0.0, 0.0, 0.9, 0.0, 0.0, 0.0, 0.0, 0.1],  # "renard" attends to "fox"
        [0.0, 0.0, 0.9, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0],  # "brun" attends to "brown"
        [0.0, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1],  # "rapide" attends to "quick"
        [0.0, 0.0, 0.0, 0.0, 0.9, 0.0, 0.0, 0.0, 0.1],  # "saute" attends to "jumps"
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.9, 0.0, 0.0, 0.1],  # "par-dessus" attends to "over"
        [0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9, 0.0, 0.0],  # "le" attends to "the"
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.9],  # "chien" attends to "dog" (a little to "lazy")
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9, 0.1],  # "paresseux" attends to "lazy"
    ]

    attention_matrix = np.array(attention_patterns)
    # Sanity checks: one row per target word, one column per source word, and each
    # target word distributes exactly 100% of its attention.
    assert attention_matrix.shape == (len(target_words), len(source_words))
    assert np.allclose(attention_matrix.sum(axis=1), 1.0), "each attention row must sum to 1"

    # Create attention heatmap
    fig, ax = plt.subplots(figsize=(12, 8))
    # Pin the colour scale to the full [0, 1] probability range instead of letting
    # imshow stretch it to this matrix's own max.
    im = ax.imshow(attention_matrix, cmap=attention_cmap, aspect='auto', vmin=0.0, vmax=1.0)

    # Set ticks and labels
    ax.set_xticks(range(len(source_words)))
    ax.set_yticks(range(len(target_words)))
    ax.set_xticklabels(source_words, rotation=45, ha='right')
    ax.set_yticklabels(target_words)

    # Annotate only the dominant weights: cells at 0.1 or below are near-white in
    # this colormap, so white text would be unreadable there.
    for i in range(len(target_words)):
        for j in range(len(source_words)):
            if attention_matrix[i, j] > 0.1:
                ax.text(j, i, f'{attention_matrix[i, j]:.1f}',
                        ha="center", va="center", color="white", fontweight='bold')

    ax.set_xlabel('Source Words (English)', fontsize=12, fontweight='bold')
    ax.set_ylabel('Target Words (French)', fontsize=12, fontweight='bold')
    ax.set_title('Attention Mechanism: Where Each Target Word "Looks"', fontsize=14, fontweight='bold')

    # Add colorbar
    cbar = plt.colorbar(im, ax=ax, shrink=0.8)
    cbar.set_label('Attention Weight', rotation=270, labelpad=20)

    plt.tight_layout()
    plt.show()

    return attention_matrix

# Run the comparisons
source, target = compare_translation_processes()
print("\n" + "="*60)
attention_matrix = visualize_attention_concept()


In [None]:
# Explain attention mechanism components step by step
def _print_attention_overview():
    """Print the query/key/value framing and the three-step attention recipe (sections 1 and 2)."""
    print("ATTENTION MECHANISM COMPONENTS")
    print("=" * 50)

    # Step 1: Query, Key, Value concept
    print("1. QUERY-KEY-VALUE FRAMEWORK")
    print("-" * 30)
    print("🔍 QUERY (Q): 'What am I trying to generate right now?'")
    print("   - Current decoder state")
    print("   - Represents what we want to focus on")
    print("   - Example: When generating 'renard', Q represents the current decoding step")
    print()

    print("🔑 KEY (K): 'What information is available to attend to?'")
    print("   - All encoder hidden states")
    print("   - Represents available information")
    print("   - Example: Hidden states for ['The', 'quick', 'brown', 'fox', ...]")
    print()

    print("💎 VALUE (V): 'What information do I extract?'")
    print("   - Same as keys (encoder hidden states)")
    print("   - The actual information we combine")
    print("   - Example: Actual vector representations of input words")
    print()

    # Step 2: Attention computation
    print("2. ATTENTION COMPUTATION PROCESS")
    print("-" * 35)
    print("Step 1: Compute similarity scores between Query and all Keys")
    print("        scores = Q · K (dot product)")
    print()
    print("Step 2: Convert scores to probabilities (attention weights)")
    print("        attention_weights = softmax(scores)")
    print()
    print("Step 3: Compute weighted average of Values")
    print("        context = Σ(attention_weights × Values)")
    print()


def _plot_attention_steps(scores, attention_weights, values, context):
    """Draw the 2x2 walkthrough figure: scores, softmax weights, weighted values, context vector.

    ``values`` is expected to have one row per encoder state and (for panel 3)
    at least two columns; ``context`` is plotted as 'Dim 1' / 'Dim 2'.
    """
    key_labels = [f'Key {k}' for k in range(1, len(scores) + 1)]
    state_labels = [f'State {k}' for k in range(1, len(values) + 1)]
    state_colors = ['red', 'green', 'blue']  # one colour per key in the weights panel
    weighted_values = attention_weights.reshape(-1, 1) * values
    positions = np.arange(len(values))

    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 10))

    # Plot 1: Attention scores
    ax1.bar(key_labels, scores, color='lightblue', alpha=0.7)
    ax1.set_title('Step 1: Attention Scores (Q·K)')
    ax1.set_ylabel('Score')
    ax1.grid(True, alpha=0.3)

    # Plot 2: Attention weights
    bars = ax2.bar(key_labels, attention_weights, color=state_colors, alpha=0.7)
    ax2.set_title('Step 2: Attention Weights (Softmax)')
    ax2.set_ylabel('Weight')
    # Headroom above 1.0 so the label of a dominant weight is not pushed out of the axes.
    ax2.set_ylim(0, 1.1)
    for bar, weight in zip(bars, attention_weights):
        ax2.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + 0.01,
                 f'{weight:.3f}', ha='center', va='bottom', fontweight='bold')
    ax2.grid(True, alpha=0.3)

    # Plot 3: Values weighted by attention, stacked per encoder state. Each dimension
    # gets its own colour so the two stacked segments (and the legend) are distinguishable.
    ax3.bar(positions, weighted_values[:, 0], alpha=0.7, color='steelblue', label='Dimension 1')
    ax3.bar(positions, weighted_values[:, 1], bottom=weighted_values[:, 0],
            alpha=0.7, color='darkorange', label='Dimension 2')
    ax3.set_title('Step 3: Weighted Values')
    ax3.set_xlabel('Encoder State')
    ax3.set_ylabel('Weighted Value')
    ax3.set_xticks(positions)
    ax3.set_xticklabels(state_labels)
    ax3.legend()

    # Plot 4: Final context vector
    ax4.bar(['Dim 1', 'Dim 2'], context, color='purple', alpha=0.7)
    ax4.set_title('Step 4: Context Vector (Weighted Sum)')
    ax4.set_ylabel('Value')
    for i, val in enumerate(context):
        ax4.text(i, val + 0.01, f'{val:.3f}', ha='center', va='bottom', fontweight='bold')
    ax4.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()


def explain_attention_components():
    """Break down the attention mechanism into understandable components.

    Prints a conceptual overview, then works a tiny 2-D example: dot-product
    scores of one query against three keys, softmax weights, and the context
    vector as the weighted sum of the values (here identical to the keys).
    Finally plots each step.

    Returns:
        tuple[np.ndarray, np.ndarray]: the attention weights (one per key,
        summing to 1) and the 2-D context vector.
    """
    _print_attention_overview()

    # Step 3: Mathematical example
    print("3. SIMPLIFIED MATHEMATICAL EXAMPLE")
    print("-" * 38)
    print("Assume we have 3 encoder states and 1 decoder state:")

    # Dummy vectors (simplified to 2D for visualization)
    query = np.array([1.0, 0.5])
    keys = np.array([[0.8, 0.2],   # Key 1
                     [0.1, 0.9],   # Key 2
                     [0.9, 0.1]])  # Key 3
    values = keys.copy()  # Values same as keys in basic attention

    print(f"Query (current decoder state): {query}")
    for k, key in enumerate(keys, start=1):
        print(f"Key {k} (encoder state {k}): {key}")
    print()

    # Compute attention scores
    scores = np.dot(keys, query)
    print(f"Attention scores (Q·K): {scores}")

    # Compute attention weights. Softmax is shift-invariant, so subtracting the max
    # score gives the same result while keeping np.exp from overflowing on large scores.
    exp_scores = np.exp(scores - np.max(scores))
    attention_weights = exp_scores / np.sum(exp_scores)
    print(f"Attention weights (softmax): {attention_weights}")
    print(f"Sum of weights: {np.sum(attention_weights):.3f} (should be 1.0)")
    print()

    # Compute context vector as the attention-weighted sum of the values
    context = np.sum(attention_weights.reshape(-1, 1) * values, axis=0)
    print(f"Context vector: {context}")
    print()

    # Visualize the process
    _plot_attention_steps(scores, attention_weights, values, context)

    return attention_weights, context

# Run the explanation
weights, context = explain_attention_components()


In [None]:
# Demonstrate information bottleneck with different sequence lengths
def demonstrate_bottleneck_problem(sentence_pairs=None, context_capacity=256, units_per_word=40):
    """Show how information gets lost in longer sequences, using English/French pairs.

    Toy model: each English word is assumed to carry ``units_per_word``
    information units, and the fixed context vector holds at most
    ``context_capacity`` of them. Plots information vs. capacity and the
    retention rate, then prints a per-pair analysis.

    Args:
        sentence_pairs: ``(english, french)`` tuples. Defaults to five pairs of increasing length.
        context_capacity: Capacity of the fixed context vector (typical hidden dimension).
        units_per_word: Assumed information content of one English word.
    """
    if sentence_pairs is None:
        # Example sentences of increasing length
        sentence_pairs = [
            ("Hi", "Salut"),
            ("Hello there", "Bonjour là"),
            ("How are you doing today", "Comment allez-vous aujourd'hui"),
            ("I would like to reserve a table for two people at eight o'clock tonight",
             "Je voudrais réserver une table pour deux personnes à huit heures ce soir"),
            ("The weather forecast says it will be sunny tomorrow with a high temperature of twenty-five degrees",
             "Les prévisions météorologiques disent qu'il fera ensoleillé demain avec une température maximale de vingt-cinq degrés")
        ]

    input_lengths = [len(eng.split()) for eng, _ in sentence_pairs]
    # Estimate information content (simplified)
    input_info = [length * units_per_word for length in input_lengths]
    # An empty input carries no information, so nothing is lost (and we avoid dividing by zero).
    retention_rates = [
        min(100, 100 * context_capacity / info) if info > 0 else 100
        for info in input_info
    ]
    # Classify each pair once, with a single set of >= thresholds, so the bar
    # colours and the printed verdicts always agree (including at exactly 80% / 50%).
    levels = ['good' if r >= 80 else 'moderate' if r >= 50 else 'severe'
              for r in retention_rates]
    level_colors = {'good': 'green', 'moderate': 'orange', 'severe': 'red'}
    level_messages = {
        'good': "  ✅ GOOD RETENTION - Should translate well",
        'moderate': "  ⚠️  MODERATE INFORMATION LOSS - Some details may be lost",
        'severe': "  ⚠️  SEVERE INFORMATION LOSS - Translation quality will suffer!",
    }

    positions = range(len(sentence_pairs))
    tick_labels = [f'Sent {i+1}\n({l} words)' for i, l in enumerate(input_lengths)]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))

    # Plot 1: Information vs Context Capacity
    ax1.bar(positions, input_info, alpha=0.7, color='skyblue', label='Input Information')
    ax1.axhline(y=context_capacity, color='red', linestyle='--', linewidth=2, label='Context Vector Capacity')
    # Shade the representable region across the full width of every bar (also works for one pair).
    ax1.fill_between([-0.5, len(sentence_pairs) - 0.5], 0, context_capacity,
                     alpha=0.3, color='red', label='Representable Information')
    ax1.set_xlabel('Sentence Complexity')
    ax1.set_ylabel('Information Units')
    ax1.set_title('Information Bottleneck Problem')
    ax1.set_xticks(positions)
    ax1.set_xticklabels(tick_labels)
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Plot 2: Information Retention Rate
    bars = ax2.bar(positions, retention_rates,
                   color=[level_colors[level] for level in levels], alpha=0.7)
    ax2.set_xlabel('Sentence Complexity')
    ax2.set_ylabel('Information Retained (%)')
    ax2.set_title('Information Retention Rate')
    ax2.set_xticks(positions)
    ax2.set_xticklabels(tick_labels)
    # Headroom above 100% so the labels of fully-retained bars stay inside the axes.
    ax2.set_ylim(0, 110)
    ax2.grid(True, alpha=0.3)

    # Add percentage labels on bars
    for bar, rate in zip(bars, retention_rates):
        ax2.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 2,
                 f'{rate:.1f}%', ha='center', va='bottom', fontweight='bold')

    plt.tight_layout()
    plt.show()

    # Print detailed analysis
    print("BOTTLENECK ANALYSIS")
    print("=" * 50)
    for i, ((eng, fr), n_words, info, retention, level) in enumerate(
            zip(sentence_pairs, input_lengths, input_info, retention_rates, levels)):
        print(f"\nSentence {i+1}:")
        print(f"  English: '{eng}' ({n_words} words)")
        print(f"  French:  '{fr}' ({len(fr.split())} words)")
        print(f"  Information: {info} units")
        print(f"  Retention: {retention:.1f}%")
        print(level_messages[level])

demonstrate_bottleneck_problem()
