# Cross-Attention

## Paper Reference

**Vaswani, A., et al. (2017).** *Attention Is All You Need.* NeurIPS. [arXiv:1706.03762](https://arxiv.org/abs/1706.03762)

---

## Mathematical Derivation

### Core Concept

In cross-attention, the queries are computed from one sequence (typically the decoder's hidden states), while the keys and values are computed from a different sequence (typically the encoder's output):

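$$\text{CrossAttention}(Q_{dec}, K_{enc}, V_{enc}) = \text{softmax}\left(\frac{Q_{dec}K_{enc}^T}{\sqrt{d_k}}\right)V_{enc}$$

Here $Q_{dec} = X_{dec}W^Q$, $K_{enc} = X_{enc}W^K$ and $V_{enc} = X_{enc}W^V$ are learned linear projections of the decoder states $X_{dec}$ and encoder output $X_{enc}$. With $h$ heads, $d_k = d_{model}/h$. The formula is evaluated independently per head, and the concatenated head outputs are projected once more by $W^O$.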

### Use Cases

1. **Machine Translation**: Decoder attends to source sentence
2. **Image Captioning**: Text decoder attends to image features
3. **Speech Recognition**: Text decoder attends to audio features

### Dimensions

- Query: $(batch, n, d_{model})$ where $n$ = decoder sequence length
- Key/Value: $(batch, m, d_{model})$ where $m$ = encoder sequence length
- Output: $(batch, n, d_{model})$
- Attention weights: $(batch, h, n, m)$ where $h$ = number of heads

### Complexity

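- Time: $O(n \cdot m \cdot d_{model})$ for the score matrix and the weighted sum, plus $O((n + m) \cdot d_{model}^2)$ for the linear projections
- Space: $O(h \cdot n \cdot m)$ per batch element for the attention weights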

In [None]:
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns

# Apply a white-grid style to every figure in this notebook.
# The 'seaborn-v0_8-*' style names were introduced in matplotlib 3.6; older
# releases only know the pre-rename name, and an unknown style raises OSError.
try:
    plt.style.use('seaborn-v0_8-whitegrid')
except OSError:
    plt.style.use('seaborn-whitegrid')

In [None]:
class CrossAttention(nn.Module):
    """Multi-head cross-attention for encoder-decoder architectures.

    Queries are projected from the decoder states; keys and values are
    projected from the encoder output. The two sequences may have different
    lengths (``n`` for the decoder, ``m`` for the encoder). The encoder
    features may also have a different width than the decoder (``kv_dim``).

    Args:
        d_model: Width of the decoder states and of the module output.
        num_heads: Number of attention heads; must be positive and divide
            ``d_model``.
        dropout: Dropout probability applied to the attention weights.
        kv_dim: Width of the encoder output. Defaults to ``d_model``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int = 8,
        dropout: float = 0.0,
        kv_dim: Optional[int] = None,
    ) -> None:
        super().__init__()
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        if kv_dim is None:
            kv_dim = d_model

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.scale = 1.0 / math.sqrt(self.d_k)

        # Creation order is kept as q, k, v, o so seeded initialisation and
        # state_dict key order are unchanged for the default kv_dim.
        self.w_q = nn.Linear(d_model, d_model)  # projects decoder states
        self.w_k = nn.Linear(kv_dim, d_model)  # projects encoder output
        self.w_v = nn.Linear(kv_dim, d_model)  # projects encoder output
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape ``(batch, seq, d_model)`` into ``(batch, num_heads, seq, d_k)``."""
        batch_size, seq_len = x.size(0), x.size(1)
        return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def forward(
        self,
        decoder_hidden: torch.Tensor,
        encoder_output: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute cross-attention from decoder queries to encoder keys/values.

        Args:
            decoder_hidden: Decoder states of shape ``(batch, n, d_model)``.
            encoder_output: Encoder outputs of shape ``(batch, m, kv_dim)``.
            mask: Optional mask in which ``True`` (or non-zero) marks encoder
                positions to IGNORE. Accepted shapes:
                ``(batch, m)`` -- padding mask shared by all queries and heads;
                ``(batch, n, m)`` -- per-query mask shared by all heads;
                4-D -- anything broadcastable to ``(batch, num_heads, n, m)``.
                For a mask shared across the batch, pass ``(1, n, m)``.

        Returns:
            ``(output, attn_weights)``. ``output`` has shape
            ``(batch, n, d_model)`` and ``attn_weights`` has shape
            ``(batch, num_heads, n, m)``. The weights are returned after
            dropout, so they equal the softmax output in eval mode or when
            ``dropout == 0``. A query whose encoder positions are all masked
            receives all-zero weights instead of NaN.
        """
        batch_size = decoder_hidden.size(0)
        seq_len_q = decoder_hidden.size(1)

        q = self._split_heads(self.w_q(decoder_hidden))  # (batch, h, n, d_k)
        k = self._split_heads(self.w_k(encoder_output))  # (batch, h, m, d_k)
        v = self._split_heads(self.w_v(encoder_output))  # (batch, h, m, d_k)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale  # (batch, h, n, m)

        if mask is not None:
            mask = mask.bool()
            if mask.dim() == 2:
                # (batch, m) -> (batch, 1, 1, m). Left as-is, broadcasting would
                # align the mask's batch axis with the query axis.
                mask = mask[:, None, None, :]
            elif mask.dim() == 3:
                # (batch, n, m) -> (batch, 1, n, m), shared across heads.
                mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask, float("-inf"))

        attn_weights = F.softmax(scores, dim=-1)
        if mask is not None:
            # softmax over a row that is entirely -inf is NaN; such a query has
            # nothing to attend to, so give it zero weight everywhere instead.
            attn_weights = attn_weights.masked_fill(mask.all(dim=-1, keepdim=True), 0.0)
        attn_weights = self.dropout(attn_weights)

        attended = torch.matmul(attn_weights, v)  # (batch, h, n, d_k)
        # Merge heads: (batch, h, n, d_k) -> (batch, n, h, d_k) -> (batch, n, d_model).
        attended = attended.transpose(1, 2).reshape(batch_size, seq_len_q, self.d_model)

        return self.w_o(attended), attn_weights

In [None]:
# Simulate translation: the source (encoder) and target (decoder) sequences
# deliberately have different lengths, m = encoder_len and n = decoder_len.
torch.manual_seed(42)

batch_size = 1
d_model = 64
num_heads = 4
encoder_len = 10  # m: source length
decoder_len = 6   # n: target length

# Draw the encoder tensor first, then the decoder tensor, so the seeded random
# stream is consumed in a fixed order and the results are reproducible.
encoder_output = torch.randn(batch_size, encoder_len, d_model)
decoder_hidden = torch.randn(batch_size, decoder_len, d_model)

cross_attn = CrossAttention(d_model=d_model, num_heads=num_heads)
output, weights = cross_attn(decoder_hidden, encoder_output)

shape_report = (
    f"Encoder output shape: {encoder_output.shape}",
    f"Decoder hidden shape: {decoder_hidden.shape}",
    f"Cross-attention output: {output.shape}",
    f"Attention weights: {weights.shape}",
)
print(*shape_report, sep="\n")

In [None]:
# Visualize cross-attention: rows are decoder (target) positions, columns are
# encoder (source) positions, and each cell is the weight averaged over heads.
enc_labels = [f'Enc_{i}' for i in range(encoder_len)]
dec_labels = [f'Dec_{i}' for i in range(decoder_len)]

# First batch element, mean over the head axis: (h, n, m) -> (n, m).
avg_weights = weights[0].mean(dim=0).detach().numpy()

fig, ax = plt.subplots(figsize=(10, 6))
sns.heatmap(
    avg_weights,
    ax=ax,
    cmap='viridis',
    annot=True,
    fmt='.2f',
    xticklabels=enc_labels,
    yticklabels=dec_labels,
)
ax.set(
    xlabel='Encoder Position (Source)',
    ylabel='Decoder Position (Target)',
    title='Cross-Attention: How Decoder Attends to Encoder',
)

plt.tight_layout()
plt.show()

## When to Use Cross-Attention

| Task | Cross-Attention Role |
|------|---------------------|
| Translation | Align target words to source |
| Summarization | Focus on important source parts |
| Image Captioning | Attend to image regions |
| Visual Question Answering (VQA) | Question tokens attend to the image regions relevant to the question |