In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import math

# Seed NumPy's and PyTorch's global RNGs so a fresh Restart & Run All
# draws the same random tensors every time.
np.random.seed(42)
torch.manual_seed(42)

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Global plotting defaults: apply the seaborn-flavoured matplotlib style first,
# then override its colour cycle with the "husl" palette (order matters).
plt.style.use(style='seaborn-v0_8')
sns.set_palette(palette="husl")

print("Ready to implement scaled dot-product attention!")


In [None]:
def scaled_dot_product_attention(query, key, value, mask=None):
    """
    Scaled Dot-Product Attention as defined in "Attention Is All You Need":
    softmax(Q K^T / sqrt(d_k)) V.

    The two matmuls use ``torch.matmul``, so extra leading dimensions
    (e.g. a heads axis) broadcast as well as the 3-D layout shown below.

    Args:
        query: [batch_size, seq_len_q, d_k]
        key: [batch_size, seq_len_k, d_k]
        value: [batch_size, seq_len_k, d_v]
        mask: optional, broadcastable to [batch_size, seq_len_q, seq_len_k].
            Positions where ``mask == 0`` (or ``False`` for a bool mask)
            receive (numerically) zero attention weight.

    Returns:
        output: [batch_size, seq_len_q, d_v]
        attention_weights: [batch_size, seq_len_q, seq_len_k]
    """
    d_k = query.size(-1)

    # Similarity scores Q K^T, scaled by 1/sqrt(d_k) as in the paper so the
    # softmax input does not grow with the key dimension.
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    if mask is not None:
        # Fill with the most negative *finite* value of the scores' dtype.
        # A hard-coded -1e9 overflows float16 (max magnitude 65504) and makes
        # masked_fill raise; a finite value (rather than -inf) also keeps
        # fully-masked rows from turning into NaN — they become uniform.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

    # Normalise over the key axis, then take the weighted sum of the values.
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attention_weights, value)

    return output, attention_weights

# Test the implementation
def test_scaled_dot_product_attention(batch_size=2, seq_len=4, d_k=8, d_v=8):
    """Smoke-test ``scaled_dot_product_attention`` on random inputs.

    Prints the input/output shapes and the min/max/mean of each query row's
    attention-weight sum. It then asserts the same invariants, so a regression
    fails loudly under Restart & Run All instead of only showing up as a
    printed number.

    Args:
        batch_size: batch dimension of the random test tensors.
        seq_len: sequence length used for both queries and keys.
        d_k: query/key feature size.
        d_v: value feature size.
        (Defaults reproduce the original 2 x 4 x 8 setup.)

    Returns:
        (output, attention_weights) from the unmasked forward pass.

    Raises:
        AssertionError: if output shapes are wrong, weight rows do not sum
            to 1, or a causal mask leaves weight on future positions.
    """
    # Random inputs drawn from torch's global RNG.
    query = torch.randn(batch_size, seq_len, d_k)
    key = torch.randn(batch_size, seq_len, d_k)
    value = torch.randn(batch_size, seq_len, d_v)

    output, attention_weights = scaled_dot_product_attention(query, key, value)

    print("SCALED DOT-PRODUCT ATTENTION TEST")
    print("=" * 40)
    print("Input shapes:")
    print(f"  Query: {query.shape}")
    print(f"  Key: {key.shape}")
    print(f"  Value: {value.shape}")
    print()
    print("Output shapes:")
    print(f"  Output: {output.shape}")
    print(f"  Attention weights: {attention_weights.shape}")
    print()

    # Each query's weights are a softmax over the keys, so every row sums to 1.
    attention_sums = attention_weights.sum(dim=-1)
    print("Attention weights sum (should be 1.0):")
    print(f"  Min: {attention_sums.min().item():.6f}")
    print(f"  Max: {attention_sums.max().item():.6f}")
    print(f"  Mean: {attention_sums.mean().item():.6f}")

    assert output.shape == (batch_size, seq_len, d_v), output.shape
    assert attention_weights.shape == (batch_size, seq_len, seq_len), attention_weights.shape
    assert torch.allclose(attention_sums, torch.ones_like(attention_sums), atol=1e-5), \
        "attention weights must sum to 1 along the key axis"

    # A lower-triangular (causal) mask must put zero weight above the diagonal.
    causal_mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0)
    _, causal_weights = scaled_dot_product_attention(query, key, value, mask=causal_mask)
    assert torch.all(torch.triu(causal_weights, diagonal=1) == 0), \
        "causal mask leaked attention onto future positions"

    return output, attention_weights

# Execute the smoke test and bind the returned tensors as notebook globals.
[output, weights] = test_scaled_dot_product_attention()
