# 深入理解 Transformer (Understanding Transformer Internals)

本 notebook 對應李宏毅老師 2025 Spring ML HW3，深入探討 Transformer 的內部機制。

## 學習目標

1. 深入理解 Attention 機制
2. 視覺化 Attention Patterns
3. 了解 KV Cache 加速推理
4. 學習各種位置編碼方法
5. 探索 Transformer 的可解釋性

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import math
from typing import Optional, Tuple

# Seed PyTorch's global RNG so the random tensors created below are reproducible.
torch.manual_seed(42)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

## 1. Self-Attention 深入解析

### 1.1 Scaled Dot-Product Attention

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

In [None]:
def scaled_dot_product_attention(
    query: torch.Tensor,  # [batch, heads, seq_q, d_k]
    key: torch.Tensor,    # [batch, heads, seq_k, d_k]
    value: torch.Tensor,  # [batch, heads, seq_v, d_v]
    mask: Optional[torch.Tensor] = None,
    return_attention_weights: bool = False
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Compute softmax(Q K^T / sqrt(d_k)) V.

    Args:
        query: Query tensor; its last dimension is used as d_k for scaling.
        key: Key tensor; its last two dimensions are swapped before the product.
        value: Value tensor that the attention weights are applied to.
        mask: Optional tensor broadcastable to the score matrix. Positions
            where ``mask == 0`` are set to -inf before the softmax, so they
            receive zero weight.
        return_attention_weights: When True, the softmax weights are returned
            as the second element of the result.

    Returns:
        ``(output, attention_weights)`` where ``attention_weights`` is None
        unless ``return_attention_weights`` is True.
    """
    head_dim = query.size(-1)

    # Similarity logits, scaled by sqrt(d_k).
    logits = (query @ key.transpose(-2, -1)) / math.sqrt(head_dim)

    # Blocked positions get -inf so the softmax assigns them zero probability.
    if mask is not None:
        logits = logits.masked_fill(mask.eq(0), float('-inf'))

    # Normalise over the key axis, then take the weighted sum of the values.
    weights = logits.softmax(dim=-1)
    context = weights @ value

    return context, (weights if return_attention_weights else None)


# Smoke test: run the attention function on random Q/K/V tensors and check
# the shapes of the results.
batch, heads, seq_len, d_k = 2, 4, 8, 64
Q = torch.randn(batch, heads, seq_len, d_k)
K = torch.randn(batch, heads, seq_len, d_k)
V = torch.randn(batch, heads, seq_len, d_k)

# Request the weights as well so they can be inspected here.
output, attn_weights = scaled_dot_product_attention(Q, K, V, return_attention_weights=True)
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")
# The softmax is taken over the last axis, so each query row should sum to 1.
print(f"Attention weights sum (should be 1): {attn_weights[0, 0, 0].sum().item():.4f}")

In [None]:
# Visualize an attention pattern
def visualize_attention(attention_weights, tokens=None, title="Attention Pattern"):
    """Plot an attention matrix next to a randomly generated causal pattern.

    Args:
        attention_weights: Tensor indexed as ``[batch, head, query, key]``.
            Only the first batch element and first head are plotted. The
            tensor may live on any device and may require grad.
        tokens: Optional sequence of token labels. They are used as tick
            labels on the left heatmap for every axis whose length equals
            ``len(tokens)``; axes of a different length keep numeric ticks.
        title: Figure-level title.

    The right-hand panel is purely illustrative: it applies a lower-triangular
    mask to random scores to show what decoder-style attention looks like.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # First batch element, first head. `.cpu()` is required before `.numpy()`
    # when the weights were computed on a GPU.
    attn = attention_weights[0, 0].detach().cpu().numpy()

    # Heatmap of the supplied weights
    im = axes[0].imshow(attn, cmap='Blues')
    axes[0].set_title('Attention Weights')
    axes[0].set_xlabel('Key Position')
    axes[0].set_ylabel('Query Position')
    plt.colorbar(im, ax=axes[0])

    # Label the axes with the tokens when their count matches the axis length.
    if tokens is not None:
        labels = list(tokens)
        ticks = list(range(len(labels)))
        if len(labels) == attn.shape[1]:
            axes[0].set_xticks(ticks)
            axes[0].set_xticklabels(labels, rotation=90)
        if len(labels) == attn.shape[0]:
            axes[0].set_yticks(ticks)
            axes[0].set_yticklabels(labels)

    # Illustrative causal pattern: random scores with the upper triangle
    # masked out, so each query only attends to itself and earlier keys.
    seq_len = attn.shape[0]
    causal_mask = torch.tril(torch.ones(seq_len, seq_len))
    masked_scores = torch.randn(seq_len, seq_len)
    masked_scores = masked_scores.masked_fill(causal_mask == 0, float('-inf'))
    causal_attn = F.softmax(masked_scores, dim=-1).numpy()

    im2 = axes[1].imshow(causal_attn, cmap='Blues')
    axes[1].set_title('Causal (Decoder) Attention Pattern')
    axes[1].set_xlabel('Key Position')
    axes[1].set_ylabel('Query Position')
    plt.colorbar(im2, ax=axes[1])

    fig.suptitle(title)
    plt.tight_layout()
    plt.show()

# Plot the weights produced by the attention smoke test earlier in the notebook.
visualize_attention(attn_weights, title="Self-Attention Visualization")

## 2. Multi-Head Attention

In [None]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention with an explicit, step-by-step implementation.

    The input is projected to queries, keys and values, split into
    ``num_heads`` heads of size ``d_model // num_heads``, attended per head,
    merged back and passed through an output projection.

    Args:
        d_model: Model (embedding) dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights while
            the module is in training mode.

    Attributes set by ``forward``:
        attention_weights: Softmax weights of the most recent call, shaped
            ``[batch, heads, seq_q, seq_k]`` and stored before dropout so
            every row still sums to 1 (useful for visualization).
    """
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Projection layers
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.attention_weights = None  # kept for visualization

    def _split_heads(self, x, batch_size):
        """Reshape [batch, seq, d_model] into [batch, heads, seq, d_k]."""
        return x.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query, key, value: Tensors of shape [batch, seq, d_model]. Key and
                value must share a sequence length; the query may differ.
            mask: Optional tensor broadcastable to
                [batch, heads, seq_q, seq_k]; positions where ``mask == 0``
                are excluded from attention.

        Returns:
            Tensor of shape [batch, seq_q, d_model].
        """
        batch_size = query.size(0)

        # Linear projections followed by the split into heads.
        Q = self._split_heads(self.W_q(query), batch_size)
        K = self._split_heads(self.W_k(key), batch_size)
        V = self._split_heads(self.W_v(value), batch_size)
        # Now: [batch, heads, seq, d_k]

        # Scaled dot-product attention. It is computed inline so that dropout
        # can be applied to the weights before they are multiplied with V.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = F.softmax(scores, dim=-1)
        self.attention_weights = attn_weights

        # Dropout is the identity in eval mode, so inference is unaffected.
        output = torch.matmul(self.dropout(attn_weights), V)

        # Merge the heads back into a single d_model-wide vector per position.
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        # Final projection
        output = self.W_o(output)

        return output

# Smoke test: self-attention (query = key = value) on a random batch.
mha = MultiHeadAttention(d_model=256, num_heads=8)
x = torch.randn(2, 10, 256)
output = mha(x, x, x)
print(f"MHA output shape: {output.shape}")

## 3. KV Cache 加速推理

在自回歸生成時，KV Cache 避免重複計算已生成 token 的 Key 和 Value。

In [None]:
class MultiHeadAttentionWithKVCache(nn.Module):
    """
    Multi-head self-attention with a key/value cache for incremental decoding.

    Keys and values of earlier tokens are passed back in through ``kv_cache``
    and concatenated with those of the new tokens, so earlier tokens never
    have to be projected again.

    Args:
        d_model: Model (embedding) dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
    """
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, kv_cache=None, use_cache=False, is_causal=True):
        """
        Args:
            x: New input tokens, [batch, seq, d_model].
            kv_cache: ``(cached_k, cached_v)`` from a previous call, each
                [batch, heads, past_seq, d_k], or None.
            use_cache: Whether to also return the updated cache.
            is_causal: When True (default), each new token attends only to the
                cached tokens, itself and earlier new tokens. This makes a
                multi-token call give the same result as feeding the tokens
                one at a time. Pass False for unmasked attention over all keys.

        Returns:
            output, (new_k_cache, new_v_cache) if use_cache else output
        """
        batch_size, seq_len = x.size(0), x.size(1)

        # Project only the new tokens to Q, K, V: [batch, heads, seq, d_k]
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # Prepend the cached keys/values along the sequence axis.
        if kv_cache is not None:
            cached_k, cached_v = kv_cache
            K = torch.cat([cached_k, K], dim=2)
            V = torch.cat([cached_v, V], dim=2)

        # Attention scores: [batch, heads, seq, past_seq + seq]
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # With a single new token every key is already "in the past", so the
        # mask is only needed when several tokens are processed at once.
        if is_causal and seq_len > 1:
            total_len = K.size(2)
            past_len = total_len - seq_len
            # Absolute positions: new queries sit at past_len .. total_len - 1.
            query_pos = torch.arange(past_len, total_len, device=x.device).unsqueeze(1)
            key_pos = torch.arange(total_len, device=x.device).unsqueeze(0)
            # Every query can see at least itself, so no row is fully masked.
            scores = scores.masked_fill(key_pos > query_pos, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V)

        # Merge the heads
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.W_o(output)

        if use_cache:
            return output, (K, V)
        return output

# KV cache demo: feed three tokens one at a time and watch the cache grow.
print("KV Cache 示範")
print("="*50)

mha_cache = MultiHeadAttentionWithKVCache(d_model=256, num_heads=8)

# First token (no cache yet)
x1 = torch.randn(1, 1, 256)  # [batch=1, seq=1, d_model]
out1, cache = mha_cache(x1, kv_cache=None, use_cache=True)
print(f"Token 1 - Output: {out1.shape}, Cache K: {cache[0].shape}")

# Second token (reuses the cache returned by the first call)
x2 = torch.randn(1, 1, 256)
out2, cache = mha_cache(x2, kv_cache=cache, use_cache=True)
print(f"Token 2 - Output: {out2.shape}, Cache K: {cache[0].shape}")

# Third token
x3 = torch.randn(1, 1, 256)
out3, cache = mha_cache(x3, kv_cache=cache, use_cache=True)
print(f"Token 3 - Output: {out3.shape}, Cache K: {cache[0].shape}")

## 4. 位置編碼 (Positional Encoding)

In [None]:
class SinusoidalPositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding from the original Transformer.

    Even feature columns hold ``sin(pos * w_i)`` and odd columns hold
    ``cos(pos * w_i)`` with ``w_i = 10000 ** (-2i / d_model)``. The table is
    precomputed once and stored as the non-trainable buffer ``pe`` of shape
    ``[1, max_len, d_model]``.

    Args:
        d_model: Feature dimension of the inputs (even or odd).
        max_len: Number of positions to precompute; inputs longer than this
            are not supported.
    """
    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()

        # One frequency per (sin, cos) column pair, computed in log space.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]

        pe[:, 0::2] = torch.sin(angles)
        # An odd d_model has one fewer odd column than even columns, so the
        # last frequency has no cosine partner and is dropped here.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Add the encoding of positions ``0 .. seq-1`` to ``x`` ([batch, seq, d_model])."""
        return x + self.pe[:, :x.size(1)]


class RotaryPositionalEmbedding(nn.Module):
    """RoPE (Rotary Position Embedding), as used by LLaMA-style models.

    Each query/key vector is rotated by an angle proportional to its
    position, pairing feature ``i`` with feature ``i + d_model / 2``. The
    cos/sin tables for positions ``0 .. max_len-1`` are precomputed and stored
    as buffers.

    Args:
        d_model: Size of the last dimension of the q/k tensors passed to
            ``forward`` (the per-head dimension); must be even.
        max_len: Number of positions to precompute.
        base: Base of the geometric frequency progression.
    """
    def __init__(self, d_model: int, max_len: int = 5000, base: int = 10000):
        super().__init__()
        # Features are rotated in pairs; an odd size would produce cos/sin
        # tables one column wider than the inputs.
        assert d_model % 2 == 0, "d_model must be even for rotary embeddings"

        inv_freq = 1.0 / (base ** (torch.arange(0, d_model, 2).float() / d_model))
        self.register_buffer('inv_freq', inv_freq)
        self.max_len = max_len

        # Precompute angle tables: rows are positions, columns are features.
        t = torch.arange(max_len).float()
        freqs = torch.einsum('i,j->ij', t, inv_freq)
        # Duplicate so both halves of a rotated pair share the same angle.
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer('cos_cache', emb.cos())
        self.register_buffer('sin_cache', emb.sin())

    def _rotate_half(self, x):
        """Map the halves (x1, x2) of the last dimension to (-x2, x1)."""
        x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
        return torch.cat([-x2, x1], dim=-1)

    def forward(self, q, k, positions=None):
        """
        Args:
            q, k: Tensors of shape [batch, heads, seq, d_model].
            positions: Optional integer positions of the tokens, shaped
                [seq] or [batch, seq], with values below ``max_len``. When
                omitted, positions ``0 .. seq-1`` are used. Pass explicit
                positions when decoding incrementally so that new tokens are
                not all treated as position 0.

        Returns:
            ``(q_embed, k_embed)`` with the same shapes as ``q`` and ``k``.
        """
        if positions is None:
            seq_len = q.size(2)
            cos = self.cos_cache[:seq_len]
            sin = self.sin_cache[:seq_len]
        else:
            positions = torch.as_tensor(
                positions, dtype=torch.long, device=self.cos_cache.device
            )
            cos = self.cos_cache[positions]
            sin = self.sin_cache[positions]

        # [seq, d] -> [1, 1, seq, d];  [batch, seq, d] -> [batch, 1, seq, d]
        if cos.dim() == 2:
            cos = cos.unsqueeze(0)
            sin = sin.unsqueeze(0)
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)

        q_embed = (q * cos) + (self._rotate_half(q) * sin)
        k_embed = (k * cos) + (self._rotate_half(k) * sin)

        return q_embed, k_embed


# Visualize the positional encodings
def visualize_positional_encodings():
    """Show heatmaps of the sinusoidal encoding table and the RoPE sine table.

    Left panel: the first 50 positions and first 64 dimensions of a
    128-dimensional sinusoidal encoding. Right panel: the first 50 positions
    and first 32 columns of the sine table of a 64-dimensional RoPE module.
    """
    sinusoidal = SinusoidalPositionalEncoding(d_model=128, max_len=100)
    rotary = RotaryPositionalEmbedding(d_model=64, max_len=100)

    # (panel title, values to plot), drawn left to right
    panels = [
        ('Sinusoidal Positional Encoding', sinusoidal.pe[0, :50, :64].numpy()),
        ('RoPE Sin Component', rotary.sin_cache[:50, :32].numpy()),
    ]

    _, axes = plt.subplots(1, 2, figsize=(14, 5))
    for ax, (panel_title, values) in zip(axes, panels):
        image = ax.imshow(values, cmap='RdBu', aspect='auto')
        ax.set_title(panel_title)
        ax.set_xlabel('Dimension')
        ax.set_ylabel('Position')
        plt.colorbar(image, ax=ax)

    plt.tight_layout()
    plt.show()

# Render both positional-encoding heatmaps.
visualize_positional_encodings()

## 5. 練習題

### 練習 1：實作 Flash Attention 的簡化版本

In [None]:
# Exercise 1: understand the block-wise computation behind Flash Attention
def chunked_attention(Q, K, V, chunk_size=32):
    """
    TODO: implement chunked attention (a simplified take on Flash Attention).

    Core ideas of Flash Attention:
    1. Split Q, K and V into small blocks.
    2. Compute attention one block at a time.
    3. Keep the softmax numerically stable across blocks.

    Hint: track the running max and the running sum of every block so the
    softmax can be normalised correctly.

    This placeholder is intentionally unimplemented and returns None.
    """
    return None

# Reminder shown when the cell runs; the function above is left as an exercise.
print("練習 1：實作 chunked_attention 函數")

### 練習 2：Attention 可解釋性分析

In [None]:
# Exercise 2: analyse attention patterns
def analyze_attention_patterns(attention_weights, tokens):
    """
    TODO: analyse the attention patterns of different layers and heads.

    Things to look for:
    1. "Global" heads that attend to a specific position such as [CLS].
    2. "Local" heads that only attend to nearby tokens.
    3. "Syntactic" heads that may attend to related words.

    This placeholder is intentionally unimplemented and returns None.
    """
    return None

# Reminder shown when the cell runs; the function above is left as an exercise.
print("練習 2：實作 analyze_attention_patterns 函數")

## 6. 總結

### 關鍵概念

| 概念 | 說明 |
|------|------|
| Self-Attention | 序列中每個位置可以關注所有位置 |
| Multi-Head | 多個 attention 頭學習不同的關注模式 |
| KV Cache | 快取已計算的 K, V 加速自回歸生成 |
| 位置編碼 | 為 attention 提供位置資訊 |
| RoPE | 依位置旋轉 Q、K，使注意力分數只取決於相對位置；搭配頻率縮放等技巧可外推到更長序列 |

In [None]:
# Closing banner followed by a recap of the topics covered in this notebook.
print(
    "=" * 60,
    "深入理解 Transformer - 學習完成！",
    "=" * 60,
    "\n你已經學會：",
    "✓ Attention 機制的深入理解",
    "✓ Multi-Head Attention 實作",
    "✓ KV Cache 加速推理",
    "✓ 各種位置編碼方法",
    sep="\n",
)