# Scaled Dot-Product Attention

## Paper Reference

**Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, L., & Polosukhin, I. (2017).** *Attention Is All You Need.* Advances in Neural Information Processing Systems (NeurIPS). [arXiv:1706.03762](https://arxiv.org/abs/1706.03762)

---

## Mathematical Derivation

### Core Formula

The scaled dot-product attention is defined as:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Where:
- $Q \in \mathbb{R}^{n \times d_k}$: Query matrix
- $K \in \mathbb{R}^{m \times d_k}$: Key matrix
- $V \in \mathbb{R}^{m \times d_v}$: Value matrix
- $d_k$: Dimension of keys/queries
- $n$: Number of queries
- $m$: Number of keys/values

### Step-by-Step Derivation

**Step 1: Compute Attention Scores**

$$S = QK^T \in \mathbb{R}^{n \times m}$$

Each element $S_{ij}$ represents the dot product similarity between query $i$ and key $j$:

$$S_{ij} = q_i \cdot k_j = \sum_{l=1}^{d_k} q_{il} \cdot k_{jl}$$

**Step 2: Apply Scaling**

$$\tilde{S} = \frac{S}{\sqrt{d_k}}$$

The scaling factor $\sqrt{d_k}$ is crucial. Without it, for large $d_k$, the dot products grow large in magnitude, pushing the softmax into regions with extremely small gradients.

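**Why scaling?** Assume the components of $q$ and $k$ are independent random variables with mean 0 and variance 1. Then for the dot product $q \cdot k = \sum_{l=1}^{d_k} q_l k_l$:
- $\mathbb{E}[q \cdot k] = 0$
- $\text{Var}[q \cdot k] = d_k$

Dividing by $\sqrt{d_k}$ rescales the variance to 1, so the typical magnitude of the scores no longer grows with $d_k$.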
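
The variance argument can be checked empirically. The sketch below is illustrative only: the helper name `dot_product_variance`, the sample count, and the chosen values of $d_k$ are assumptions, and the components are drawn from a standard normal distribution to match the mean-0, variance-1 assumption.

```python
import math
from typing import Tuple

import torch

torch.manual_seed(0)


def dot_product_variance(d_k: int, num_samples: int = 20_000) -> Tuple[float, float]:
    """Empirical variance of q.k before and after dividing by sqrt(d_k)."""
    # Components are i.i.d. with mean 0 and variance 1
    q = torch.randn(num_samples, d_k)
    k = torch.randn(num_samples, d_k)
    dots = (q * k).sum(dim=-1)          # one dot product per sample
    scaled = dots / math.sqrt(d_k)      # apply the 1/sqrt(d_k) scaling
    return dots.var().item(), scaled.var().item()


for d_k in (16, 64, 256):
    raw_var, scaled_var = dot_product_variance(d_k)
    print(f"d_k={d_k:4d}  Var[q.k]={raw_var:8.2f}  Var[q.k/sqrt(d_k)]={scaled_var:.3f}")
```

The unscaled variance should land close to $d_k$ and the scaled variance close to 1, up to sampling noise.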

**Step 3: Apply Softmax**

$$A = \text{softmax}(\tilde{S})$$

Where softmax is applied row-wise:

$$A_{ij} = \frac{\exp(\tilde{S}_{ij})}{\sum_{l=1}^{m} \exp(\tilde{S}_{il})}$$

**Step 4: Weighted Sum of Values**

$$\text{Output} = AV \in \mathbb{R}^{n \times d_v}$$

Each output row is a weighted combination of value vectors:

$$\text{Output}_i = \sum_{j=1}^{m} A_{ij} \cdot v_j$$
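
The four steps map directly onto a few tensor operations. This is a minimal, unbatched sketch with arbitrarily chosen sizes (`n`, `m`, `d_k`, `d_v` below are illustrative values, not taken from the paper); the variable names mirror the symbols used in the derivation.

```python
import math

import torch

torch.manual_seed(0)

n, m, d_k, d_v = 3, 5, 4, 2           # illustrative sizes only
Q = torch.randn(n, d_k)
K = torch.randn(m, d_k)
V = torch.randn(m, d_v)

S = Q @ K.T                           # Step 1: raw scores, shape (n, m)
S_tilde = S / math.sqrt(d_k)          # Step 2: scale by 1/sqrt(d_k)
A = torch.softmax(S_tilde, dim=-1)    # Step 3: row-wise softmax over keys
output = A @ V                        # Step 4: weighted sum, shape (n, d_v)

print(S.shape, A.shape, output.shape)
print(A.sum(dim=-1))                  # each row of A sums to 1
```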

### Gradient Derivation

Let $L$ be the loss function. We derive gradients for backpropagation.

**Gradient w.r.t. V:**

$$\frac{\partial L}{\partial V} = A^T \frac{\partial L}{\partial \text{Output}}$$

**Gradient w.r.t. A:**

$$\frac{\partial L}{\partial A} = \frac{\partial L}{\partial \text{Output}} V^T$$

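**Gradient through softmax:**

For each row $i$, let $a = A_i$ (the attention weights for query $i$) and $\tilde{s} = \tilde{S}_i$ (the corresponding scaled scores). The softmax Jacobian is:

$$\frac{\partial a_j}{\partial \tilde{s}_k} = a_j(\delta_{jk} - a_k)$$

where $\delta_{jk}$ is the Kronecker delta. Applying the chain rule over $j$ gives the gradient with respect to the scaled scores:

$$\frac{\partial L}{\partial \tilde{S}_{ik}} = A_{ik}\left(\frac{\partial L}{\partial A_{ik}} - \sum_{j=1}^{m} \frac{\partial L}{\partial A_{ij}} A_{ij}\right)$$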

**Gradient w.r.t. Q and K:**

$$\frac{\partial L}{\partial Q} = \frac{\partial L}{\partial \tilde{S}} \cdot \frac{K}{\sqrt{d_k}}$$

$$\frac{\partial L}{\partial K} = \left(\frac{\partial L}{\partial \tilde{S}}\right)^T \cdot \frac{Q}{\sqrt{d_k}}$$
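
One way to sanity-check these expressions is to compare them against autograd. The sketch below assumes small, arbitrary tensor sizes and a random upstream gradient standing in for $\partial L / \partial \text{Output}$; `float64` is used so the comparison is not dominated by rounding error.

```python
import math

import torch

torch.manual_seed(0)

n, m, d_k, d_v = 3, 5, 4, 2           # illustrative sizes only
Q = torch.randn(n, d_k, dtype=torch.float64, requires_grad=True)
K = torch.randn(m, d_k, dtype=torch.float64, requires_grad=True)
V = torch.randn(m, d_v, dtype=torch.float64, requires_grad=True)

# Forward pass
S_tilde = Q @ K.T / math.sqrt(d_k)
A = torch.softmax(S_tilde, dim=-1)
output = A @ V

# Arbitrary upstream gradient, standing in for dL/dOutput
grad_output = torch.randn(n, d_v, dtype=torch.float64)
output.backward(grad_output)          # autograd reference gradients

with torch.no_grad():
    grad_V = A.T @ grad_output                            # dL/dV = A^T dL/dOutput
    grad_A = grad_output @ V.T                            # dL/dA = dL/dOutput V^T
    # Softmax backward: Jacobian a_j (delta_jk - a_k) contracted with dL/dA, row by row
    grad_S_tilde = A * (grad_A - (grad_A * A).sum(dim=-1, keepdim=True))
    grad_Q = grad_S_tilde @ K / math.sqrt(d_k)            # dL/dQ
    grad_K = grad_S_tilde.T @ Q / math.sqrt(d_k)          # dL/dK

print("dQ matches:", torch.allclose(grad_Q, Q.grad))
print("dK matches:", torch.allclose(grad_K, K.grad))
print("dV matches:", torch.allclose(grad_V, V.grad))
```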

---

## Complexity Analysis

| Operation | Time Complexity | Space Complexity |
|-----------|-----------------|------------------|
| $QK^T$ | $O(n \cdot m \cdot d_k)$ | $O(n \cdot m)$ |
| Softmax | $O(n \cdot m)$ | $O(n \cdot m)$ |
| $AV$ | $O(n \cdot m \cdot d_v)$ | $O(n \cdot d_v)$ |
| **Total** | $O(n \cdot m \cdot (d_k + d_v))$ | $O(n \cdot m + n \cdot d_v)$ |

For self-attention where $n = m$, writing $d = d_k = d_v$:
- **Time**: $O(n^2 \cdot d)$
- **Space**: $O(n^2)$ for storing attention weights
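
To make the quadratic memory cost concrete, the following back-of-the-envelope sketch estimates the size of the $n \times n$ attention matrix alone. The helper name `attention_matrix_bytes` and the default of 4 bytes per element (float32) are assumptions for illustration; real implementations also store activations, gradients, and other buffers.

```python
def attention_matrix_bytes(n: int, bytes_per_element: int = 4, batch: int = 1) -> int:
    """Bytes needed to store a (batch, n, n) attention matrix."""
    return batch * n * n * bytes_per_element


for n in (1_024, 4_096, 16_384):
    mib = attention_matrix_bytes(n) / 2**20
    print(f"n={n:6d}: {mib:8.1f} MiB")
```

Under these assumptions, each 4x increase in sequence length multiplies the attention matrix by 16x: 4 MiB at $n = 1024$, 64 MiB at $n = 4096$, and 1024 MiB at $n = 16384$.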

## PyTorch Implementation

In [None]:
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns

# Set style
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('viridis')

In [None]:
class ScaledDotProductAttention(nn.Module):
    """Scaled Dot-Product Attention mechanism.
    
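    Implements: Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
    
    Attributes:
        d_model: Model dimension. In this single-head setting d_k = d_model,
            so it is used for the 1/sqrt(d_k) scaling.
        scale: Precomputed scaling factor 1/sqrt(d_model).
        dropout: Dropout layer applied to the attention weights.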
    """
    
    def __init__(self, d_model: int, dropout: float = 0.0) -> None:
        """Initialize scaled dot-product attention.
        
        Args:
            d_model: The dimensionality of the model; must match the last
                dimension of the queries and keys (d_k) for correct scaling.
            dropout: Dropout probability.
        """
        super().__init__()
        self.d_model = d_model
        self.scale = 1.0 / math.sqrt(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute scaled dot-product attention.
        
        Args:
            query: Query tensor (batch, seq_len_q, d_model).
            key: Key tensor (batch, seq_len_k, d_model).
            value: Value tensor (batch, seq_len_k, d_model).
            mask: Optional mask, broadcastable to (batch, seq_len_q, seq_len_k),
                where True indicates masked positions.
        
        Returns:
            output: Attended values (batch, seq_len_q, d_model).
            attention_weights: Attention distribution (batch, seq_len_q, seq_len_k).
        """
        # Step 1 & 2: Compute scaled attention scores
        # scores = Q @ K^T / sqrt(d_k)
        scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale
        
        # Apply mask if provided
        if mask is not None:
            scores = scores.masked_fill(mask.bool(), float("-inf"))
        
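        # Step 3: Apply softmax over the key dimension
        attention_weights = F.softmax(scores, dim=-1)
        # A row whose keys are all masked is all -inf, so softmax yields NaN;
        # replace those entries with 0 so the row contributes nothing
        attention_weights = torch.nan_to_num(attention_weights, nan=0.0)
        
        # Apply dropout (the returned weights are the post-dropout values)
        attention_weights = self.dropout(attention_weights)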
        
        # Step 4: Weighted sum of values
        output = torch.matmul(attention_weights, value)
        
        return output, attention_weights
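
The mask convention (True means "masked out") can be illustrated with a causal mask. This is a self-contained functional sketch rather than a call into the class above, so that it runs on its own; the names `causal_mask`, `seq_len`, and the tensor sizes are illustrative assumptions.

```python
import math

import torch

torch.manual_seed(0)

seq_len, d_k = 4, 8                   # illustrative sizes only
q = torch.randn(1, seq_len, d_k)
k = torch.randn(1, seq_len, d_k)

# True above the diagonal: each query may not attend to later positions
causal_mask = torch.triu(
    torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1
)

scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
scores = scores.masked_fill(causal_mask, float("-inf"))   # mask broadcasts over batch
weights = torch.softmax(scores, dim=-1)

print(weights[0])   # upper triangle is exactly zero; each row still sums to 1
```

Because the mask follows the same True-is-masked convention, a tensor built this way should also be usable as the `mask` argument of `forward`.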

## Demonstration

In [None]:
# Set random seed for reproducibility
torch.manual_seed(42)

# Create sample data
batch_size = 1
seq_len = 8
d_model = 64

# Initialize attention module
attention = ScaledDotProductAttention(d_model=d_model)

# Create sample Q, K, V tensors
query = torch.randn(batch_size, seq_len, d_model)
key = torch.randn(batch_size, seq_len, d_model)
value = torch.randn(batch_size, seq_len, d_model)

# Forward pass
output, attention_weights = attention(query, key, value)

print(f"Query shape: {query.shape}")
print(f"Key shape: {key.shape}")
print(f"Value shape: {value.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attention_weights.shape}")

## Visualization

In [None]:
# Visualize attention weights
fig, ax = plt.subplots(figsize=(8, 6))

weights = attention_weights[0].detach().numpy()

sns.heatmap(
    weights,
    ax=ax,
    cmap='viridis',
    annot=True,
    fmt='.2f',
    square=True,
    cbar_kws={'label': 'Attention Weight'}
)

ax.set_xlabel('Key Position')
ax.set_ylabel('Query Position')
ax.set_title('Scaled Dot-Product Attention Weights')

plt.tight_layout()
plt.show()

In [None]:
# Verify attention weights sum to 1
weight_sums = attention_weights.sum(dim=-1)
print(f"Attention weight row sums (should be 1.0):\n{weight_sums[0]}")

## Effect of Scaling Factor

In [None]:
# Demonstrate importance of scaling
d_k_values = [1, 16, 64, 256, 512]
fig, axes = plt.subplots(1, len(d_k_values), figsize=(16, 3))

for idx, d_k in enumerate(d_k_values):
    q = torch.randn(1, 4, d_k)
    k = torch.randn(1, 4, d_k)
    
    # Scores with 1/sqrt(d_k) scaling
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    weights = F.softmax(scores, dim=-1)
    
    sns.heatmap(
        weights[0].detach().numpy(),
        ax=axes[idx],
        cmap='viridis',
        vmin=0, vmax=1,
        square=True,
        cbar=False
    )
    axes[idx].set_title(f'd_k = {d_k}')
    axes[idx].set_xlabel('Key')
    if idx == 0:
        axes[idx].set_ylabel('Query')

plt.suptitle('Attention Weights with Proper Scaling (1/sqrt(d_k))')
plt.tight_layout()
plt.show()
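
For contrast, the same quantity can be compared with and without scaling. The sketch below is a rough numerical illustration, not a benchmark: the helper `mean_max_weight`, the number of queries and keys, and the use of the average largest attention weight as a saturation measure are all assumptions made for this example.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)


def mean_max_weight(d_k: int, scaled: bool, num_queries: int = 1_000, num_keys: int = 4) -> float:
    """Average of the largest attention weight per query (1.0 means one-hot)."""
    q = torch.randn(num_queries, d_k)
    k = torch.randn(num_keys, d_k)
    scores = q @ k.T
    if scaled:
        scores = scores / math.sqrt(d_k)
    weights = F.softmax(scores, dim=-1)
    return weights.max(dim=-1).values.mean().item()


for d_k in (1, 16, 64, 256, 512):
    unscaled = mean_max_weight(d_k, scaled=False)
    scaled = mean_max_weight(d_k, scaled=True)
    print(f"d_k={d_k:4d}  unscaled={unscaled:.3f}  scaled={scaled:.3f}")
```

Without scaling, the largest weight should drift toward 1 as $d_k$ grows (the softmax saturates), while the scaled version should stay roughly constant across $d_k$.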

## Comparison: When to Use Scaled Dot-Product Attention

| Aspect | Scaled Dot-Product | Alternative |
|--------|-------------------|-------------|
| **Simplicity** | Single-head, straightforward | Multi-head for more capacity |
| **Efficiency** | Baseline $O(n^2)$ in sequence length | Linear attention for long sequences |
| **Use Case** | Building block for other attention | Direct use in simple models |
| **Memory** | Full attention matrix | Flash attention for memory efficiency |

### Key Takeaways

1. **Foundation**: Scaled dot-product attention is the core operation inside multi-head attention and most Transformer attention variants.

2. **Scaling is Critical**: Without the $1/\sqrt{d_k}$ scaling, dot products grow with $d_k$, the softmax saturates, and its gradients become vanishingly small.

3. **Quadratic Complexity**: The $O(n^2)$ space complexity for storing attention weights becomes prohibitive for very long sequences.