# Masked Causal Attention

## Paper Reference

**Vaswani, A., et al. (2017).** *Attention Is All You Need.* NeurIPS. [arXiv:1706.03762](https://arxiv.org/abs/1706.03762)

---

## Mathematical Derivation

### Causal Masking

For autoregressive generation, each position may attend only to itself and to earlier positions, never to later ones:

$$\text{Mask}_{ij} = \begin{cases} 0 & \text{if } j \leq i \\ -\infty & \text{if } j > i \end{cases}$$

$$\text{CausalAttention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + \text{Mask}\right)V$$
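To make the effect of the additive mask concrete, here is a minimal dependency-free sketch of the softmax step only (the multiplication by $V$ is omitted). The helper names `causal_mask`, `softmax`, and `causal_attention_weights`, as well as the score values, are illustrative assumptions and not part of the notebook's implementation.

```python
import math

NEG_INF = float("-inf")

def causal_mask(n):
    """Additive mask: 0 where j <= i, -inf where j > i."""
    return [[0.0 if j <= i else NEG_INF for j in range(n)] for i in range(n)]

def softmax(row):
    """Numerically stable softmax; math.exp(-inf) evaluates to 0.0."""
    m = max(row)  # always finite here, because the diagonal is never masked
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def causal_attention_weights(scores):
    """Add the mask to already-scaled scores and normalise each row."""
    n = len(scores)
    mask = causal_mask(n)
    return [softmax([s + m for s, m in zip(scores[i], mask[i])]) for i in range(n)]

# Hypothetical values standing in for QK^T / sqrt(d_k) on a 3-token sequence.
scores = [[0.5, 1.0, -0.3],
          [0.2, 0.2, 2.0],
          [1.0, 0.0, 0.0]]
weights = causal_attention_weights(scores)
for row in weights:
    print([round(w, 3) for w in row])
```

Each masked score becomes $-\infty$, so its exponential is exactly zero and the remaining weights in the row still sum to one. In this example the first query puts all of its weight on position 0, and the second splits its weight evenly between positions 0 and 1 despite the large score at position 2.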

### Lower Triangular Mask

```
Position:  0  1  2  3
Query 0:  [1  0  0  0]  <- can only see position 0
Query 1:  [1  1  0  0]  <- can see positions 0, 1
Query 2:  [1  1  1  0]  <- can see positions 0, 1, 2
Query 3:  [1  1  1  1]  <- can see all positions
```
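The pattern above can be generated for any sequence length. The following plain-Python sketch is illustrative only, and `visibility_matrix` is a hypothetical helper name. It also derives the complementary "blocked" pattern, which has the same layout as the boolean mask that the implementation later in this notebook builds with `torch.triu(..., diagonal=1)`.

```python
def visibility_matrix(n):
    """Return an n x n lower-triangular 0/1 matrix: 1 where key j <= query i."""
    return [[1 if j <= i else 0 for j in range(n)] for i in range(n)]

pattern = visibility_matrix(4)
# Complement: 1 marks a future position that must be masked out.
blocked = [[1 - v for v in row] for row in pattern]

for i, row in enumerate(pattern):
    visible = [j for j, allowed in enumerate(row) if allowed]
    print(f"Query {i}: {row}  visible keys: {visible}")
```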

### Complexity

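- Time: $O(n^2 \cdot d)$, the same asymptotic cost as unmasked attention. Only $n(n+1)/2$ of the $n^2$ scores survive the mask, so a kernel that skips masked entries does roughly half the work; the implementation below computes the full score matrix and then masks it.
- Space: $O(n^2)$ per head for the attention matrix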
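The claim about roughly half the work can be checked by counting how many query-key pairs survive the mask. This is a small arithmetic sketch, with `score_counts` as a hypothetical helper name:

```python
def score_counts(n):
    """Return (total scores, unmasked scores) for a sequence of length n."""
    total = n * n
    unmasked = n * (n + 1) // 2  # lower triangle, diagonal included
    return total, unmasked

for n in (4, 8, 2048):
    total, unmasked = score_counts(n)
    print(f"n={n}: {unmasked}/{total} scores kept ({unmasked / total:.1%})")
```

The kept fraction is $(n+1)/(2n)$, which approaches one half as $n$ grows. For short sequences the diagonal makes it noticeably larger.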

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns

plt.style.use('seaborn-v0_8-whitegrid')

In [None]:
class CausalAttention(nn.Module):
    """Causal (Masked) Attention for autoregressive models."""
    
    def __init__(self, d_model: int, num_heads: int = 8, max_seq_len: int = 2048) -> None:
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.d_model = d_model
        self.scale = 1.0 / math.sqrt(self.d_k)
        
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        
        # Precompute causal mask
        mask = torch.triu(torch.ones(max_seq_len, max_seq_len), diagonal=1).bool()
        self.register_buffer('causal_mask', mask)
    
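    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # x: (batch, seq_len, d_model)
        batch_size, seq_len = x.size(0), x.size(1)
        if seq_len > self.causal_mask.size(0):
            raise ValueError("seq_len exceeds max_seq_len of the precomputed causal mask")

        # Project, then split into heads: (batch, num_heads, seq_len, d_k)
        q = self.w_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        k = self.w_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        v = self.w_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product scores: (batch, num_heads, seq_len, seq_len)
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        # True entries (j > i) are set to -inf; the mask broadcasts over batch and heads
        scores = scores.masked_fill(self.causal_mask[:seq_len, :seq_len], float("-inf"))

        # Masked positions receive exactly zero weight after the softmax
        attn_weights = F.softmax(scores, dim=-1)

        # Weighted sum of values, then merge heads back to (batch, seq_len, d_model)
        attended = torch.matmul(attn_weights, v)
        attended = attended.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

        return self.w_o(attended), attn_weights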

In [None]:
# Demonstrate causal masking
torch.manual_seed(42)

batch_size, seq_len, d_model, num_heads = 1, 8, 64, 4
causal_attn = CausalAttention(d_model=d_model, num_heads=num_heads)

x = torch.randn(batch_size, seq_len, d_model)
output, weights = causal_attn(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")

In [None]:
# Visualize causal attention pattern
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Show the causal mask
causal_mask = causal_attn.causal_mask[:seq_len, :seq_len].float().numpy()
sns.heatmap(1 - causal_mask, ax=axes[0], cmap='Blues', square=True, cbar=False,
            annot=True, fmt='.0f')
axes[0].set_title('Causal Mask (1=attend, 0=masked)')
axes[0].set_xlabel('Key Position')
axes[0].set_ylabel('Query Position')

# Show actual attention weights (averaged across heads)
avg_weights = weights[0].mean(dim=0).detach().numpy()
sns.heatmap(avg_weights, ax=axes[1], cmap='viridis', square=True, annot=True, fmt='.2f')
axes[1].set_title('Causal Attention Weights')
axes[1].set_xlabel('Key Position')
axes[1].set_ylabel('Query Position')

plt.tight_layout()
plt.show()

In [None]:
# Verify that every weight strictly above the diagonal (a future position) is exactly zero
for i in range(seq_len):
    for j in range(i + 1, seq_len):
        assert weights[0, :, i, j].sum() == 0, f"Position ({i}, {j}) should be masked"
print("Verified: All future positions are properly masked.")

## Use Cases for Causal Attention

| Model | Purpose |
|-------|--------|
| GPT | Language modeling |
| Decoder-only LLMs | Text generation |
| Autoregressive models | Any sequential prediction |