In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import math
from matplotlib.colors import LinearSegmentedColormap

# Seed every random source once, up front, so the randomly generated demo
# attention matrices are identical on each Restart Kernel -> Run All.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# The bundled seaborn style sheets were renamed to 'seaborn-v0_8' in
# matplotlib 3.6; older releases only know the style as 'seaborn'. Try the
# current name first so the notebook runs on either side of the rename.
try:
    plt.style.use('seaborn-v0_8')
except OSError:
    plt.style.use('seaborn')
sns.set_palette("husl")

# Custom colormap for attention visualization: white (low weight) through
# progressively darker blues (high weight).
colors = ['white', 'lightblue', 'blue', 'darkblue', 'navy']
attention_cmap = LinearSegmentedColormap.from_list('attention', colors)

print("Ready to explore cross-attention mechanisms!")


In [None]:
def _print_attention_overview(english, french):
    """Print the textual self- vs cross-attention comparison for both sentences."""
    print("SELF-ATTENTION vs CROSS-ATTENTION COMPARISON")
    print("=" * 55)
    print(f"English: {english}")
    print(f"French:  {french}")
    print()

    # Self-attention explanation
    print("🔄 SELF-ATTENTION (within same sequence)")
    print("-" * 40)
    print("English Self-Attention:")
    print("  - How does 'cat' relate to 'sits', 'mat', etc.?")
    print("  - Captures syntactic and semantic relationships")
    print("  - Q, K, V all come from English sentence")
    print()

    print("French Self-Attention:")
    print("  - How does 'chat' relate to 'assis', 'tapis', etc.?")
    print("  - Captures French grammar and word relationships")
    print("  - Q, K, V all come from French sentence")
    print()

    # Cross-attention explanation
    print("🔀 CROSS-ATTENTION (between different sequences)")
    print("-" * 45)
    print("English-to-French Cross-Attention:")
    print("  - When generating 'chat', which English words to focus on?")
    print("  - Q comes from French (what we're generating)")
    print("  - K, V come from English (what we're translating from)")
    print("  - Enables translation alignment")
    print()


def _row_normalize(weights):
    """Scale each row to sum to 1 so it reads like an attention distribution."""
    return weights / weights.sum(axis=1, keepdims=True)


def _draw_attention_heatmap(fig, ax, weights, query_words, key_words,
                            title, xlabel, ylabel):
    """Draw one labelled attention heatmap (rows = queries, columns = keys) on ``ax``."""
    image = ax.imshow(weights, cmap=attention_cmap, aspect='auto')
    ax.set_title(title)
    ax.set_xticks(range(len(key_words)))
    ax.set_yticks(range(len(query_words)))
    ax.set_xticklabels(key_words, rotation=45)
    ax.set_yticklabels(query_words)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    # The seaborn style applied at notebook setup enables axis grid lines,
    # which would be drawn on top of the heatmap cells.
    ax.grid(False)
    # Each panel is normalized independently, so give each its own scale.
    fig.colorbar(image, ax=ax)
    return image


def _draw_information_flow(ax):
    """Fill ``ax`` with the text-only summary of where Q, K and V come from."""
    ax.text(0.5, 0.8, 'INFORMATION FLOW', ha='center', va='center',
            fontsize=16, fontweight='bold', transform=ax.transAxes)

    ax.text(0.1, 0.6, 'Self-Attention:\nQ, K, V from\nsame sequence',
            ha='left', va='center', fontsize=12, transform=ax.transAxes,
            bbox=dict(boxstyle="round,pad=0.3", facecolor="lightblue"))

    ax.text(0.1, 0.3, 'Cross-Attention:\nQ from target\nK, V from source',
            ha='left', va='center', fontsize=12, transform=ax.transAxes,
            bbox=dict(boxstyle="round,pad=0.3", facecolor="lightcoral"))

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.axis('off')


def compare_attention_types():
    """Compare self-attention and cross-attention conceptually.

    Prints a short explanation of both mechanisms for a fixed English/French
    sentence pair, then shows a 2x2 figure: English self-attention, French
    self-attention, French-to-English cross-attention, and a text panel
    summarising where Q, K and V come from.

    The matrices are illustrative dummy data drawn from NumPy's global random
    state, not the output of a model. Every row is normalized to sum to 1.

    Returns:
        tuple: ``(en_self_attention, fr_self_attention, cross_attention)`` as
        NumPy arrays of shape (6, 6), (7, 7) and (7, 6) respectively, where
        rows index the attending (query) words and columns the attended
        (key) words.
    """
    # Example sentences
    english = "The cat sits on the mat"
    french = "Le chat est assis sur le tapis"

    en_words = english.split()
    fr_words = french.split()

    _print_attention_overview(english, french)

    # Create conceptual attention matrices
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))

    # The three random draws happen in the order English, French, cross so the
    # global random stream is consumed in a fixed, reproducible sequence.

    # English self-attention (dummy data for visualization)
    en_self_attention = _row_normalize(
        np.random.rand(len(en_words), len(en_words)))
    _draw_attention_heatmap(
        fig, ax1, en_self_attention, en_words, en_words,
        title='English Self-Attention\n(English words attending to English words)',
        xlabel='Attended Words (Keys)',
        ylabel='Attending Words (Queries)')

    # French self-attention
    fr_self_attention = _row_normalize(
        np.random.rand(len(fr_words), len(fr_words)))
    _draw_attention_heatmap(
        fig, ax2, fr_self_attention, fr_words, fr_words,
        title='French Self-Attention\n(French words attending to French words)',
        xlabel='Attended Words (Keys)',
        ylabel='Attending Words (Queries)')

    # Cross-attention (French attending to English).
    # Keep the random background strictly below the smallest hand-set weight
    # (0.6): unscaled noise in [0, 1) could exceed the "aligned" entries and
    # hide the very alignment this panel is meant to illustrate.
    background_scale = 0.3
    cross_attention = background_scale * np.random.rand(len(fr_words), len(en_words))
    # Make it more realistic - similar words have higher attention
    cross_attention[1, 1] = 0.8  # chat -> cat
    cross_attention[0, 0] = 0.7  # Le -> The
    cross_attention[2, 2] = 0.6  # est -> sits (approximate)
    cross_attention = _row_normalize(cross_attention)
    _draw_attention_heatmap(
        fig, ax3, cross_attention, fr_words, en_words,
        title='Cross-Attention\n(French words attending to English words)',
        xlabel='English Words (Keys/Values)',
        ylabel='French Words (Queries)')

    # Information flow diagram
    _draw_information_flow(ax4)

    plt.tight_layout()
    plt.show()

    return en_self_attention, fr_self_attention, cross_attention

# Run the comparison: prints the explanation, shows the 2x2 figure, and binds
# the three returned row-normalized demo matrices (English self-attention,
# French self-attention, French-to-English cross-attention) to notebook-level
# names.
en_self, fr_self, cross = compare_attention_types()
