In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import math

# Seed both random number generators used by this notebook so that
# repeated runs draw the same values.
np.random.seed(42)
torch.manual_seed(42)

# Pick the GPU when PyTorch reports one as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

print("Ready to implement cross-attention!")


In [None]:
class CrossAttention(nn.Module):
    """Multi-head cross-attention layer for encoder-decoder communication.

    Queries come from one sequence (the decoder) while keys and values come
    from another (the encoder). Each is linearly projected, split into
    ``num_heads`` heads of width ``d_model // num_heads``, attended with
    scaled dot-product attention, re-joined and projected once more.

    Args:
        d_model: Feature width of the inputs and of the output.
        num_heads: Number of attention heads; must divide ``d_model`` evenly.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, d_model, num_heads=8):
        super().__init__()
        # Fail fast: an indivisible d_model would otherwise only surface later
        # as an opaque shape error from .view() inside forward().
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Linear projections for Q, K, V and the output
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Compute softmax(Q K^T / sqrt(d_k)) V.

        Args:
            Q: Queries, [..., tgt_len, d_k].
            K: Keys, [..., src_len, d_k].
            V: Values, [..., src_len, d_k].
            mask: Optional tensor broadcastable to the score tensor
                [..., tgt_len, src_len]; positions where ``mask == 0`` are
                excluded from attention. When the scores are 4D and the mask
                is 3D, the mask is read as [batch_size, tgt_len, src_len] and
                shared across all heads.

        Returns:
            Tuple ``(output, attention_weights)`` where ``output`` is
            [..., tgt_len, d_k] and ``attention_weights`` is
            [..., tgt_len, src_len].
        """
        # Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Apply mask if provided
        if mask is not None:
            if mask.dim() == 3 and scores.dim() == 4:
                # Insert a head axis; without it broadcasting would line the
                # mask's batch axis up with the scores' head axis.
                mask = mask.unsqueeze(1)
            # The most negative finite value of the score dtype is always
            # representable (a literal -1e9 is not in float16) and still
            # yields zero weight after softmax.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        # Apply softmax over the source positions
        attention_weights = F.softmax(scores, dim=-1)

        # Apply attention to values
        output = torch.matmul(attention_weights, V)

        return output, attention_weights

    def forward(self, query, key, value, mask=None):
        """
        Cross-attention forward pass

        Args:
            query: [batch_size, tgt_len, d_model] - from decoder
            key: [batch_size, src_len, d_model] - from encoder
            value: [batch_size, src_len, d_model] - from encoder
            mask: optional attention mask, either
                [batch_size, tgt_len, src_len] (applied to every head) or any
                shape broadcastable to
                [batch_size, num_heads, tgt_len, src_len]; entries equal to 0
                mark positions that must not be attended to

        Returns:
            output: [batch_size, tgt_len, d_model]
            attention_weights: [batch_size, num_heads, tgt_len, src_len]
        """
        batch_size, tgt_len, _ = query.size()
        src_len = key.size(1)

        # Linear projections
        Q = self.W_q(query)  # [batch_size, tgt_len, d_model]
        K = self.W_k(key)    # [batch_size, src_len, d_model]
        V = self.W_v(value)  # [batch_size, src_len, d_model]

        # Reshape for multi-head attention: [batch_size, num_heads, len, d_k]
        Q = Q.view(batch_size, tgt_len, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, src_len, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, src_len, self.num_heads, self.d_k).transpose(1, 2)

        # Apply attention
        attention_output, attention_weights = self.scaled_dot_product_attention(Q, K, V, mask)

        # Concatenate heads; transpose leaves the tensor non-contiguous, so
        # .contiguous() is needed before .view()
        attention_output = attention_output.transpose(1, 2).contiguous().view(
            batch_size, tgt_len, self.d_model)

        # Final linear projection
        output = self.W_o(attention_output)

        return output, attention_weights

# Test cross-attention layer
def test_cross_attention():
    """Smoke-test CrossAttention on random tensors and check its invariants.

    Builds a random encoder output (used as keys and values) and a random
    decoder input (used as queries), runs one unmasked forward pass, prints
    the tensor shapes and the statistics of the attention row sums, and then
    asserts that the shapes are as expected and that every attention row sums
    to 1.

    Returns:
        tuple: ``(output, attention_weights, cross_attn)`` - the layer output,
        the per-head attention weights, and the layer instance that produced
        them.

    Raises:
        AssertionError: If an output shape is wrong or an attention row does
            not sum to 1 within tolerance.
    """

    batch_size = 2
    src_len = 6  # encoder sequence length
    tgt_len = 4  # decoder sequence length
    d_model = 128
    num_heads = 8

    # Create test tensors
    encoder_output = torch.randn(batch_size, src_len, d_model)  # Keys and Values
    decoder_input = torch.randn(batch_size, tgt_len, d_model)   # Queries

    # Initialize cross-attention layer
    cross_attn = CrossAttention(d_model, num_heads)

    # Forward pass
    output, attention_weights = cross_attn(decoder_input, encoder_output, encoder_output)

    print("CROSS-ATTENTION TEST")
    print("=" * 30)
    print(f"Encoder output (K,V): {encoder_output.shape}")
    print(f"Decoder input (Q): {decoder_input.shape}")
    print(f"Cross-attention output: {output.shape}")
    print(f"Attention weights: {attention_weights.shape}")
    print()

    # Check attention weight properties
    attn_sum = attention_weights.sum(dim=-1)
    print("Attention weights sum (should be ~1.0):")
    print(f"  Min: {attn_sum.min().item():.6f}")
    print(f"  Max: {attn_sum.max().item():.6f}")
    print(f"  Mean: {attn_sum.mean().item():.6f}")

    # Enforce the invariants instead of only printing them, so a regression
    # in the layer fails this cell rather than passing silently.
    assert output.shape == (batch_size, tgt_len, d_model), \
        f"unexpected output shape {tuple(output.shape)}"
    assert attention_weights.shape == (batch_size, num_heads, tgt_len, src_len), \
        f"unexpected attention shape {tuple(attention_weights.shape)}"
    assert torch.allclose(attn_sum, torch.ones_like(attn_sum), atol=1e-5), \
        "attention rows do not sum to 1"

    return output, attention_weights, cross_attn

# Run the test when this cell executes and keep its three return values
# (layer output, attention weights, layer instance) as module-level names.
output, attn_weights, cross_attn_layer = test_cross_attention()
