# Multi-Head Attention 구현

PyTorch로 Scaled Dot-Product Attention과 Multi-Head Attention을 직접 구현한다.

## 학습 목표
- Attention 수식을 코드로 변환
- Multi-Head의 병렬 처리 이해
- Attention weights 시각화

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
import math

# Seed PyTorch's global random number generator so the random tensors and
# layer initialisations created below are reproducible between runs.
torch.manual_seed(42)

## 1. Scaled Dot-Product Attention

In [None]:
def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Compute scaled dot-product attention.

    Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V

    Args:
        Q: Query tensor (batch, ..., seq_len_q, d_k)
        K: Key tensor (batch, ..., seq_len_k, d_k)
        V: Value tensor (batch, ..., seq_len_k, d_v)
        mask: Optional tensor broadcastable to the score shape
            (batch, ..., seq_len_q, seq_len_k). Entries equal to 0
            (or False) mark key positions that must not be attended to.

    Returns:
        output: Attention output (batch, ..., seq_len_q, d_v)
        attention_weights: (batch, ..., seq_len_q, seq_len_k). Each row sums
            to 1, except rows whose keys are all masked: those rows are all
            zeros (and so is the matching output row) instead of NaN.
    """
    d_k = K.size(-1)

    # Steps 1-2: similarity scores Q @ K^T, scaled by sqrt(d_k).
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    # Step 3: apply the optional mask.
    fully_masked_rows = None
    if mask is not None:
        blocked = mask == 0
        scores = scores.masked_fill(blocked, float('-inf'))
        # A row made only of -inf would turn into NaN under softmax (and
        # poison the gradients), so give such rows a finite placeholder
        # score here and zero their weights after the softmax.
        fully_masked_rows = blocked.all(dim=-1, keepdim=True)
        scores = scores.masked_fill(fully_masked_rows, 0.0)

    # Step 4: normalise over the key axis.
    attention_weights = F.softmax(scores, dim=-1)
    if fully_masked_rows is not None:
        attention_weights = attention_weights.masked_fill(fully_masked_rows, 0.0)

    # Step 5: weighted sum of the values.
    output = torch.matmul(attention_weights, V)

    return output, attention_weights

In [None]:
# Smoke test: run scaled_dot_product_attention on small random inputs
batch_size, seq_len = 2, 4
d_k = d_v = 8

# Q, K and V are drawn in this order so the seeded random stream is
# consumed exactly as before.
Q, K, V = (
    torch.randn(batch_size, seq_len, feature_dim)
    for feature_dim in (d_k, d_k, d_v)
)

output, attn_weights = scaled_dot_product_attention(Q, K, V)

# Report the tensor shapes, one per line.
print(
    f"Q shape: {Q.shape}",
    f"K shape: {K.shape}",
    f"V shape: {V.shape}",
    f"Output shape: {output.shape}",
    f"Attention weights shape: {attn_weights.shape}",
    f"\nAttention weights (첫 번째 배치):",
    sep="\n",
)
first_batch_weights = attn_weights[0]
print(first_batch_weights)
# Each row is a softmax distribution, so every row sum should come out as 1.
print(f"\n각 행의 합: {attn_weights[0].sum(dim=-1)}")

## 2. Attention Weights 시각화

In [None]:
def visualize_attention(attention_weights, tokens=None, title="Attention Weights"):
    """
    Draw a 2-D attention-weight matrix as an annotated heatmap.

    Args:
        attention_weights: 2-D tensor of weights. Rows are plotted on the
            'Query' axis and columns on the 'Key' axis.
        tokens: Optional list of labels used for both axes. When omitted,
            "pos_i" labels are generated separately for rows and columns,
            so non-square matrices are labelled correctly.
        title: Title shown above the heatmap.
    """
    # Tensor.numpy() only works for CPU tensors, so copy to host memory first.
    weights = attention_weights.detach().cpu().numpy()

    if tokens is None:
        query_labels = [f"pos_{i}" for i in range(weights.shape[0])]
        key_labels = [f"pos_{i}" for i in range(weights.shape[1])]
    else:
        query_labels = key_labels = tokens

    plt.figure(figsize=(8, 6))

    sns.heatmap(
        weights,
        xticklabels=key_labels,
        yticklabels=query_labels,
        cmap='Blues',
        annot=True,
        fmt='.2f',
        square=True
    )
    plt.xlabel('Key')
    plt.ylabel('Query')
    plt.title(title)
    plt.tight_layout()
    plt.show()

# Plot the attention map of the first batch element with word labels
tokens = "I love deep learning".split()
visualize_attention(attn_weights[0], tokens=tokens)

## 3. Causal Mask (Decoder용)

In [None]:
def create_causal_mask(seq_len):
    """
    Build the mask used for causal (autoregressive) attention.

    Returns a (seq_len, seq_len) float tensor that is 1 on and below the
    main diagonal and 0 above it.
    """
    all_ones = torch.ones(seq_len, seq_len)
    return all_ones.tril()

# Build and display the causal mask for a length-4 sequence
causal_mask = create_causal_mask(4)
print("Causal Mask:", causal_mask, sep="\n")

# Re-run attention on the same Q, K, V with the causal mask applied
output_causal, attn_weights_causal = scaled_dot_product_attention(
    Q, K, V, mask=causal_mask
)

visualize_attention(attn_weights_causal[0], tokens, title="Causal Attention")

## 4. Multi-Head Attention

In [None]:
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention.

    MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O
    where head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)

    The per-head projections are stored as single d_model x d_model linear
    layers whose outputs are reshaped into num_heads slices of size d_k.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        """
        Args:
            d_model: Size of the input and output feature dimension.
            num_heads: Number of attention heads; must divide d_model.
            dropout: Dropout probability applied to the final projection.
        """
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Linear projections for queries, keys, values and the output.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def split_heads(self, x, batch_size):
        """
        (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
        """
        x = x.view(batch_size, -1, self.num_heads, self.d_k)
        return x.transpose(1, 2)

    def combine_heads(self, x, batch_size):
        """
        (batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)
        """
        # transpose() returns a non-contiguous view; view() needs contiguous memory.
        x = x.transpose(1, 2).contiguous()
        return x.view(batch_size, -1, self.d_model)

    @staticmethod
    def _expand_mask(mask):
        """
        Reshape a mask so it broadcasts against (batch, heads, seq_q, seq_k) scores.

        - rank <= 2, e.g. (seq_q, seq_k): leading batch and head axes are added
        - rank 3, (batch, seq_q, seq_k): a head axis is inserted at position 1
        - rank >= 4: returned unchanged
        """
        if mask.dim() <= 2:
            return mask.unsqueeze(0).unsqueeze(0)
        if mask.dim() == 3:
            return mask.unsqueeze(1)
        return mask

    def forward(self, Q, K, V, mask=None):
        """
        Args:
            Q: Query input (batch, seq_len_q, d_model)
            K: Key input (batch, seq_len_k, d_model)
            V: Value input (batch, seq_len_k, d_model)
            mask: Optional mask; entries equal to 0 are not attended to.
                Accepts a shared (seq_len_q, seq_len_k) mask, a per-sample
                (batch, seq_len_q, seq_len_k) mask, or a 4-D mask that
                already broadcasts against the per-head scores.

        Returns:
            output: (batch, seq_len_q, d_model)
            attn_weights: (batch, num_heads, seq_len_q, seq_len_k)
        """
        batch_size = Q.size(0)

        # 1. Linear projection
        Q = self.W_q(Q)
        K = self.W_k(K)
        V = self.W_v(V)

        # 2. Split into heads
        Q = self.split_heads(Q, batch_size)
        K = self.split_heads(K, batch_size)
        V = self.split_heads(V, batch_size)

        # 3. Scaled dot-product attention
        if mask is not None:
            mask = self._expand_mask(mask)

        attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)

        # 4. Combine heads
        attn_output = self.combine_heads(attn_output, batch_size)

        # 5. Final linear projection
        output = self.W_o(attn_output)
        output = self.dropout(output)

        return output, attn_weights

In [None]:
# Smoke test for the MultiHeadAttention module
d_model, num_heads = 64, 8
seq_len, batch_size = 4, 2

# The layer is built before the input is sampled, so the seeded random
# stream is consumed in the same order as before.
mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads)

# Random input batch
x = torch.randn(batch_size, seq_len, d_model)

# Self-attention: the same tensor serves as query, key and value
output, attn_weights = mha(Q=x, K=x, V=x)

print(
    f"Input shape: {x.shape}",
    f"Output shape: {output.shape}",
    f"Attention weights shape: {attn_weights.shape}",
    sep="\n",
)

In [None]:
# Plot the attention map of every head (first batch element) on a 2x4 grid
tokens = ["I", "love", "deep", "learning"]
fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(16, 8))

# Heatmap options shared by all panels
heatmap_style = dict(
    xticklabels=tokens,
    yticklabels=tokens,
    cmap='Blues',
    annot=True,
    fmt='.2f',
)

for head_idx in range(num_heads):
    # Four panels per row: quotient picks the row, remainder the column.
    grid_row, grid_col = divmod(head_idx, 4)
    ax = axes[grid_row, grid_col]

    head_weights = attn_weights[0, head_idx].detach().numpy()
    sns.heatmap(head_weights, ax=ax, **heatmap_style)
    ax.set_title(f'Head {head_idx}')

plt.suptitle('Multi-Head Attention Weights', fontsize=14)
plt.tight_layout()
plt.show()

## 5. 파라미터 수 확인

In [None]:
def count_parameters(model):
    """Return the total number of elements in the model's trainable parameters."""
    total = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad=False) are not counted.
        if param.requires_grad:
            total += param.numel()
    return total

# Compare the measured parameter count of the MHA module with the
# closed-form count: four d_model x d_model weight matrices plus four biases.
mha_params = count_parameters(mha)

report_lines = [
    f"Multi-Head Attention 파라미터 수: {mha_params:,}",
    f"\n이론적 계산:",
    f"  W_q: {d_model} x {d_model} = {d_model * d_model}",
    f"  W_k: {d_model} x {d_model} = {d_model * d_model}",
    f"  W_v: {d_model} x {d_model} = {d_model * d_model}",
    f"  W_o: {d_model} x {d_model} = {d_model * d_model}",
    f"  Biases: 4 x {d_model} = {4 * d_model}",
    f"  Total: {4 * d_model * d_model + 4 * d_model}",
]
print("\n".join(report_lines))

## 6. PyTorch 내장 구현과 비교

In [None]:
# PyTorch's built-in multi-head attention, configured for batch-first input
pytorch_mha = nn.MultiheadAttention(
    embed_dim=d_model,
    num_heads=num_heads,
    batch_first=True,
)

# Run it on the same input as self-attention; no gradients are needed here
with torch.no_grad():
    pytorch_output, pytorch_weights = pytorch_mha(query=x, key=x, value=x)

print(
    f"PyTorch MHA output shape: {pytorch_output.shape}",
    f"PyTorch MHA weights shape: {pytorch_weights.shape}",
    sep="\n",
)

## 7. Scaled Dot-Product의 스케일링 효과

In [None]:
# Visualise why the 1/sqrt(d_k) scaling matters: histogram the attention
# weights with and without scaling for several key dimensions.
d_k_values = [8, 64, 512]

fig, axes = plt.subplots(2, 3, figsize=(15, 8))

for idx, d_k in enumerate(d_k_values):
    # Q is sampled before K so the random stream is consumed as before.
    Q = torch.randn(1, 4, d_k)
    K = torch.randn(1, 4, d_k)

    # Raw scores, and the same scores divided by sqrt(d_k)
    scores_no_scale = Q @ K.transpose(-2, -1)
    scores_scaled = scores_no_scale / math.sqrt(d_k)

    attn_no_scale = F.softmax(scores_no_scale, dim=-1)
    attn_scaled = F.softmax(scores_scaled, dim=-1)

    # Top row: unscaled weights; bottom row: scaled weights.
    panels = (
        (axes[0, idx], attn_no_scale, f'd_k={d_k}, No Scaling'),
        (axes[1, idx], attn_scaled, f'd_k={d_k}, With Scaling'),
    )
    for panel_ax, weights, panel_title in panels:
        panel_ax.hist(weights[0].flatten().numpy(), bins=20)
        panel_ax.set_title(panel_title)
        panel_ax.set_xlabel('Attention Weight')

plt.suptitle('Effect of Scaling on Attention Distribution', fontsize=14)
plt.tight_layout()
plt.show()

print("스케일링 없이 d_k가 크면 softmax가 극단적으로 뾰족해짐 (대부분 0 또는 1)")

## 8. 정리

### 핵심 구현 포인트

1. **Scaled Dot-Product Attention**
   - `scores = Q @ K^T / sqrt(d_k)`
   - `attention = softmax(scores)`
   - `output = attention @ V`

2. **Multi-Head Attention**
   - Linear projection: `Q = xW_q`, `K = xW_k`, `V = xW_v`
   - Split into heads: reshape to `(batch, heads, seq_len, d_k)`
   - Apply attention per head
   - Combine and project: `output = concat(heads)W_o`

3. **Masking**
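   - Causal mask: `torch.tril()` builds a lower-triangular mask so each position attends only to itself and earlier positions (autoregressive decoding)
   - Padding mask: scores at positions where `mask == 0` are set to `-inf` before the softmax, so those positions receive zero attention weight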