In [None]:
# nb06_attention_transformer.ipynb
# Attention Mechanism & Transformer Block Implementation
# 注意力機制與 Transformer Block 實作

# =============================================================================
# Cell 1: Environment & Cache Setup
# =============================================================================

# === Shared Cache Bootstrap (English comments only) ===
import os, pathlib, torch

# Root directory shared by all caches; override it with the AI_CACHE_ROOT env var.
AI_CACHE_ROOT = os.environ.get("AI_CACHE_ROOT", "/mnt/ai/cache")

# Maps each cache-related environment variable to its directory under the root.
paths = {
    env_name: f"{AI_CACHE_ROOT}/{subdir}"
    for env_name, subdir in (
        ("HF_HOME", "hf"),
        ("TRANSFORMERS_CACHE", "hf/transformers"),
        ("HF_DATASETS_CACHE", "hf/datasets"),
        ("HUGGINGFACE_HUB_CACHE", "hf/hub"),
        ("TORCH_HOME", "torch"),
    )
}

# Export every variable and make sure the directory it points to exists.
for k, v in paths.items():
    os.environ[k] = v
    pathlib.Path(v).mkdir(parents=True, exist_ok=True)

print("[Cache] Root:", AI_CACHE_ROOT)

# Report whether CUDA is usable and, if so, the name of the first device.
_cuda_ok = torch.cuda.is_available()
_device_label = torch.cuda.get_device_name(0) if _cuda_ok else "CPU"
print("[GPU]", _cuda_ok, _device_label)

In [None]:
# =============================================================================
# Cell 2: Dependencies & Imports
# =============================================================================

import math
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

# Seed the PyTorch and NumPy global RNGs so reruns produce the same numbers
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Pick the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# =============================================================================
# Cell 3: Scaled Dot-Product Attention Implementation
# =============================================================================


def scaled_dot_product_attention(query, key, value, mask=None, dropout=None):
    """
    Scaled dot-product attention: softmax(Q @ K^T / sqrt(d_k)) @ V.

    The last two dimensions are treated as (length, feature); any leading
    dimensions (batch, heads, ...) are carried through by ``torch.matmul``.

    Args:
        query: [..., q_len, d_k] query tensor.
        key: [..., k_len, d_k] key tensor.
        value: [..., k_len, d_v] value tensor.
        mask: Optional tensor broadcastable to the score shape
            [..., q_len, k_len]. Positions where ``mask == 0`` (or ``False``)
            are excluded from attention.
        dropout: Optional callable (e.g. an ``nn.Dropout`` module) applied to
            the attention weights.

    Returns:
        output: [..., q_len, d_v] attention output.
        attention_weights: [..., q_len, k_len] attention weights (after
            dropout when ``dropout`` is given).
    """
    d_k = query.size(-1)

    # Raw similarity scores, scaled so their variance does not grow with d_k.
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    if mask is not None:
        # Fill masked positions with the most negative finite value of the
        # score dtype. A hard-coded -1e9 is not representable in float16 and
        # makes the fill raise when running in half precision / autocast.
        fill_value = torch.finfo(scores.dtype).min
        scores = scores.masked_fill(mask == 0, fill_value)

    # Normalise over the key dimension.
    attention_weights = F.softmax(scores, dim=-1)

    if dropout is not None:
        attention_weights = dropout(attention_weights)

    # Weighted sum of the values.
    output = torch.matmul(attention_weights, value)

    return output, attention_weights


# Run the attention function once on random inputs and report the shapes
batch_size, seq_len, d_model = 2, 8, 64
qkv_shape = (batch_size, seq_len, d_model)
query = torch.randn(qkv_shape)
key = torch.randn(qkv_shape)
value = torch.randn(qkv_shape)

output, weights = scaled_dot_product_attention(query, key, value)
print(
    f"Input shape: {query.shape}",
    f"Output shape: {output.shape}",
    f"Attention weights shape: {weights.shape}",
    sep="\n",
)

In [None]:
# =============================================================================
# Cell 4: Multi-Head Attention Module
# =============================================================================


class MultiHeadAttention(nn.Module):
    """
    Multi-head attention.

    Projects query/key/value into ``num_heads`` subspaces of size
    ``d_model // num_heads``, runs scaled dot-product attention in each
    subspace, then concatenates the heads and applies an output projection.

    Args:
        d_model: Model dimension (embedding size).
        num_heads: Number of attention heads; must divide ``d_model``.
        dropout_rate: Dropout rate applied to the attention weights.
    """

    def __init__(self, d_model, num_heads, dropout_rate=0.1):
        super().__init__()
        assert (
            d_model % num_heads == 0
        ), f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Linear projections for Q, K, V
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)

        # Output projection
        self.w_o = nn.Linear(d_model, d_model)

        # Dropout for attention weights
        self.dropout = nn.Dropout(dropout_rate)

    def _split_heads(self, projected, batch_size):
        """Reshape [batch, length, d_model] -> [batch, num_heads, length, d_k]."""
        # The length is read from the tensor itself so that keys/values may
        # have a different length than the queries (cross-attention).
        return projected.view(
            batch_size, projected.size(1), self.num_heads, self.d_k
        ).transpose(1, 2)

    @staticmethod
    def _expand_mask(mask):
        """
        Make ``mask`` broadcastable to [batch, num_heads, q_len, k_len].

        Accepted layouts:
            2-D [batch, k_len]          padding mask, shared by heads and queries
            3-D [batch, q_len, k_len]   per-query mask, shared by heads
            4-D [batch, heads, q_len, k_len]  used as-is

        Singleton dimensions are inserted instead of repeating the mask per
        head; broadcasting gives the same result without the extra memory.

        Raises:
            ValueError: If ``mask`` has an unsupported number of dimensions.
        """
        if mask.dim() == 2:
            return mask[:, None, None, :]
        if mask.dim() == 3:
            return mask[:, None, :, :]
        if mask.dim() == 4:
            return mask
        raise ValueError(
            f"mask must have 2, 3 or 4 dimensions, got {mask.dim()}"
        )

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: [batch, q_len, d_model]
            key: [batch, k_len, d_model]
            value: [batch, k_len, d_model]
            mask: Optional mask; see ``_expand_mask`` for accepted layouts.
                Positions equal to 0 are not attended to. A 2-D mask is
                always interpreted as a [batch, k_len] padding mask.

        Returns:
            output: [batch, q_len, d_model]
            attention_weights: [batch, num_heads, q_len, k_len]
        """
        batch_size = query.size(0)

        # Project and split into heads.
        Q = self._split_heads(self.w_q(query), batch_size)
        K = self._split_heads(self.w_k(key), batch_size)
        V = self._split_heads(self.w_v(value), batch_size)

        if mask is not None:
            mask = self._expand_mask(mask)

        # Attention runs independently for every head.
        attention_output, attention_weights = scaled_dot_product_attention(
            Q, K, V, mask=mask, dropout=self.dropout
        )

        # Merge heads: [batch, num_heads, q_len, d_k] -> [batch, q_len, d_model]
        attention_output = (
            attention_output.transpose(1, 2)
            .contiguous()
            .view(batch_size, query.size(1), self.d_model)
        )

        output = self.w_o(attention_output)

        return output, attention_weights


# Exercise MultiHeadAttention in self-attention mode (query = key = value)
d_model, num_heads = 512, 8
mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads)

# Random batch-first input: [batch, seq, d_model]
batch_size, seq_len = 2, 10
x = torch.randn((batch_size, seq_len, d_model))

output, attn_weights = mha(x, x, x)
print(
    f"MHA Input shape: {x.shape}",
    f"MHA Output shape: {output.shape}",
    f"Attention weights shape: {attn_weights.shape}",
    sep="\n",
)

In [None]:
# =============================================================================
# Cell 5: Position Encoding (Sinusoidal)
# =============================================================================


class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding.

    Adds a fixed sin/cos position signal to the input and applies dropout.

    Args:
        d_model: Model dimension (odd values are supported).
        max_seq_length: Maximum sequence length the table is built for.
        dropout_rate: Dropout rate applied after adding the encoding.
        batch_first: If True (default) inputs are [batch, seq, d_model];
            if False they are [seq, batch, d_model]. The layout is taken
            from this flag rather than guessed from the tensor sizes, because
            a size-based guess is ambiguous when batch size == sequence length.
    """

    def __init__(self, d_model, max_seq_length=5000, dropout_rate=0.1, batch_first=True):
        super().__init__()
        self.batch_first = batch_first
        self.dropout = nn.Dropout(dropout_rate)

        # Table of shape [max_seq_length, d_model], filled column-wise below.
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)

        # Frequencies 1 / 10000^(2i / d_model), one per sin/cos pair.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term

        # Sin on even feature indices, cos on odd ones. With an odd d_model
        # there is one fewer odd index than even, so the cos columns are
        # truncated to match.
        pe[:, 0::2] = torch.sin(angles)
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Stored as [max_seq_length, 1, d_model]; registered as a buffer so it
        # follows the module across devices but is not a trainable parameter.
        self.register_buffer("pe", pe.unsqueeze(1))

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, d_model] when ``batch_first`` is True,
               otherwise [seq_len, batch_size, d_model].

        Returns:
            Tensor of the same shape as ``x`` with positions added and
            dropout applied.
        """
        if self.batch_first:
            # [seq, 1, d_model] -> [1, seq, d_model], broadcast over the batch.
            x = x + self.pe[: x.size(1)].transpose(0, 1)
        else:
            x = x + self.pe[: x.size(0)]
        return self.dropout(x)


# Apply the positional encoding to a random batch-first sample
pos_enc = PositionalEncoding(d_model=512)
sample_input = torch.randn((2, 10, 512))  # [batch, seq, dim]
encoded = pos_enc(sample_input)
print(
    f"Positional encoding input: {sample_input.shape}",
    f"Positional encoding output: {encoded.shape}",
    sep="\n",
)

# Heat-map of the encoding table: first 50 positions x first 64 dimensions
pe_vis = pos_enc.pe[:50, 0, :64].numpy()
fig, ax = plt.subplots(figsize=(12, 8))
sns.heatmap(pe_vis.T, cmap="RdYlBu", center=0, ax=ax)
ax.set_title("Positional Encoding Pattern (位置編碼模式)")
ax.set_xlabel("Position (位置)")
ax.set_ylabel("Dimension (維度)")
plt.show()

In [None]:
# =============================================================================
# Cell 6: Transformer Block (Self-Attention + FFN)
# =============================================================================


class TransformerBlock(nn.Module):
    """
    Transformer encoder block (post-norm).

    Consists of:
    - Multi-head self-attention
    - Position-wise feed-forward network (Linear -> ReLU -> Linear)
    - A residual connection followed by LayerNorm after each sub-layer

    Args:
        d_model: Model dimension.
        num_heads: Number of attention heads.
        d_ff: Hidden size of the feed-forward network.
        dropout_rate: Dropout rate used for attention weights and for the
            output of each sub-layer.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout_rate=0.1):
        super().__init__()

        # Multi-Head Attention
        self.multi_head_attention = MultiHeadAttention(d_model, num_heads, dropout_rate)

        # Position-wise Feed-Forward Network
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model)
        )

        # Layer Normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Dropout
        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)

    def forward(self, x, mask=None, use_checkpoint=False):
        """
        Args:
            x: [batch_size, seq_len, d_model] input tensor.
            mask: Optional attention mask, forwarded to the attention layer.
            use_checkpoint: Recompute activations during backward instead of
                storing them. Only takes effect in training mode.

        Returns:
            (output, attention_weights) as produced by ``_forward_impl``.
        """
        if use_checkpoint and self.training:
            # Non-reentrant checkpointing is requested explicitly: the
            # reentrant variant yields no parameter gradients when none of
            # the inputs require grad (e.g. a raw tensor fed to this block),
            # and PyTorch warns when the argument is left unspecified.
            return checkpoint(self._forward_impl, x, mask, use_reentrant=False)
        return self._forward_impl(x, mask)

    def _forward_impl(self, x, mask=None):
        """Run attention and feed-forward sub-layers; see ``forward``."""
        # Self-attention sub-layer: residual add, then LayerNorm.
        attn_output, attn_weights = self.multi_head_attention(x, x, x, mask)
        x1 = self.norm1(x + self.dropout1(attn_output))

        # Feed-forward sub-layer: residual add, then LayerNorm.
        ff_output = self.feed_forward(x1)
        x2 = self.norm2(x1 + self.dropout2(ff_output))

        return x2, attn_weights


# Push one random batch through a single Transformer block
transformer_block = TransformerBlock(d_model=512, num_heads=8, d_ff=2048)

# Random batch-first input: [batch, seq, d_model]
batch_size, seq_len = 2, 10
x = torch.randn((batch_size, seq_len, 512))

output, attn_weights = transformer_block(x)
print(
    f"Transformer Block Input: {x.shape}",
    f"Transformer Block Output: {output.shape}",
    f"Attention Weights: {attn_weights.shape}",
    sep="\n",
)

In [None]:
# =============================================================================
# Cell 7: Mini Transformer Model
# =============================================================================


class MiniTransformer(nn.Module):
    """
    Small encoder-style Transformer for demonstration purposes.

    Pipeline: token embedding (scaled by sqrt(d_model)) -> positional
    encoding -> ``num_layers`` Transformer blocks -> LayerNorm -> linear
    projection to vocabulary logits.

    Args:
        vocab_size: Number of tokens in the vocabulary.
        d_model: Model dimension.
        num_heads: Attention heads per block.
        num_layers: Number of stacked Transformer blocks.
        d_ff: Hidden size of each feed-forward network.
        max_seq_length: Maximum sequence length for the positional encoding.
        dropout_rate: Dropout rate used throughout the model.
    """

    def __init__(
        self,
        vocab_size,
        d_model=512,
        num_heads=8,
        num_layers=6,
        d_ff=2048,
        max_seq_length=1000,
        dropout_rate=0.1,
    ):
        super().__init__()

        self.d_model = d_model

        # Token and positional embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(
            d_model, max_seq_length, dropout_rate
        )

        # Stack of transformer blocks
        self.transformer_blocks = nn.ModuleList(
            [
                TransformerBlock(d_model, num_heads, d_ff, dropout_rate)
                for _ in range(num_layers)
            ]
        )

        # Output layer
        self.layer_norm = nn.LayerNorm(d_model)
        self.output_projection = nn.Linear(d_model, vocab_size)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise one sub-module; called for every module via ``apply``."""
        if isinstance(module, nn.Linear):
            # Normal(0, 0.02) weights, zero bias.
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            # Identity affine transform.
            torch.nn.init.ones_(module.weight)
            torch.nn.init.zeros_(module.bias)

    def forward(self, input_ids, attention_mask=None, use_checkpoint=False):
        """
        Args:
            input_ids: [batch_size, seq_len] token indices.
            attention_mask: Optional mask where 0 marks positions that must
                not be attended to. Either a [batch_size, seq_len] padding
                mask or an already expanded [batch_size, seq_len, seq_len]
                mask.
            use_checkpoint: Use gradient checkpointing in every block
                (effective in training mode only).

        Returns:
            logits: [batch_size, seq_len, vocab_size]
            attention_weights_list: one attention-weight tensor per block.
        """
        # Token embeddings, scaled by sqrt(d_model)
        embeddings = self.token_embedding(input_ids) * math.sqrt(self.d_model)

        # Add positional encoding
        x = self.positional_encoding(embeddings)

        if attention_mask is not None and attention_mask.dim() == 2:
            # [batch, seq] -> [batch, 1, seq]: the key-padding mask then
            # broadcasts over the query dimension inside every block. Passing
            # the 2-D mask through unchanged only works for batch size 1.
            attention_mask = attention_mask.unsqueeze(1)

        # Pass through transformer blocks, collecting attention weights
        attention_weights_list = []
        for transformer_block in self.transformer_blocks:
            x, attn_weights = transformer_block(x, attention_mask, use_checkpoint)
            attention_weights_list.append(attn_weights)

        # Final layer norm and output projection
        x = self.layer_norm(x)
        logits = self.output_projection(x)

        return logits, attention_weights_list


# Build a small model and run one forward pass on random token ids
vocab_size = 1000
model = MiniTransformer(vocab_size, d_model=256, num_heads=4, num_layers=2)

# Random token indices in [0, vocab_size): [batch, seq]
batch_size, seq_len = 2, 8
input_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len))

logits, all_attn_weights = model(input_ids)
print(
    f"Mini Transformer Input: {input_ids.shape}",
    f"Mini Transformer Output: {logits.shape}",
    f"Number of attention weight matrices: {len(all_attn_weights)}",
    sep="\n",
)

In [None]:
# =============================================================================
# Cell 8: Attention Visualization
# =============================================================================


def visualize_attention(attention_weights, tokens=None, layer_idx=0, head_idx=0):
    """
    Plot one attention head of one layer as an annotated heat-map.

    Only the first sample of the batch is shown.

    Args:
        attention_weights: Sequence of per-layer attention tensors, each
            indexed as [batch, head, query_pos, key_pos].
        tokens: Optional list of token labels for both axes. When omitted,
            seaborn's automatic tick labels are used.
        layer_idx: Which layer to visualize.
        head_idx: Which attention head to visualize.
    """
    # Move to CPU before converting: .numpy() raises on CUDA tensors, so
    # without this the function fails whenever the model runs on a GPU.
    attn = attention_weights[layer_idx][0, head_idx].detach().cpu().numpy()

    # seaborn documents "auto", bool, int or a list for tick labels; None is
    # not among them, so the default is mapped to "auto" explicitly.
    tick_labels = "auto" if tokens is None else tokens

    plt.figure(figsize=(10, 8))
    sns.heatmap(
        attn,
        annot=True,
        fmt=".2f",
        cmap="Blues",
        xticklabels=tick_labels,
        yticklabels=tick_labels,
    )
    plt.title(f"Attention Weights - Layer {layer_idx}, Head {head_idx}")
    plt.xlabel("Keys (被關注的位置)")
    plt.ylabel("Queries (查詢位置)")
    plt.show()


# Placeholder labels, one per position of the most recent model input
sample_tokens = ["token_" + str(i) for i in range(seq_len)]

# Show head 0 of layer 0
visualize_attention(
    attention_weights=all_attn_weights,
    tokens=sample_tokens,
    layer_idx=0,
    head_idx=0,
)

In [None]:
# =============================================================================
# Cell 9: Memory Optimization & Low-VRAM Tips
# =============================================================================


def get_model_memory_usage(model, input_shape):
    """
    Estimate the memory footprint of ``model`` for one input batch.

    Parameter memory is exact for the model's current dtypes (it uses each
    parameter's real element size, so half/bfloat16 models are not reported
    as if they were float32). Activation memory is only a rough heuristic:
    ten tensors of shape [batch, seq, d_model].

    Args:
        model: Module exposing ``parameters()`` and a ``d_model`` attribute.
        input_shape: Sequence whose first two entries are
            (batch_size, seq_len).

    Returns:
        dict with keys ``parameters`` (count), ``param_memory_mb``,
        ``estimated_activation_mb`` and ``total_estimated_mb``.
    """
    total_params = 0
    param_bytes = 0
    bytes_per_activation = None
    for p in model.parameters():
        total_params += p.numel()
        param_bytes += p.numel() * p.element_size()
        if bytes_per_activation is None:
            # Assume activations share the dtype of the first parameter.
            bytes_per_activation = p.element_size()
    if bytes_per_activation is None:
        # Parameter-free model: fall back to float32.
        bytes_per_activation = 4

    param_memory = param_bytes / (1024**2)  # in MB

    # Estimate activation memory (rough approximation)
    batch_size, seq_len = input_shape[:2]
    activation_memory = (
        batch_size * seq_len * model.d_model * bytes_per_activation / (1024**2) * 10
    )

    return {
        "parameters": total_params,
        "param_memory_mb": param_memory,
        "estimated_activation_mb": activation_memory,
        "total_estimated_mb": param_memory + activation_memory,
    }


# Report the estimated memory footprint of the model built above
memory_info = get_model_memory_usage(model, (2, 8))
print("Memory Usage Analysis:")
# Loop variables are deliberately not called key/value, so the attention
# tensors of the same names defined in an earlier cell are left untouched.
for metric, amount in memory_info.items():
    # Counts print as integers; megabyte figures keep two decimals so that
    # sub-megabyte estimates are not rounded down to 0.
    formatted = f"{amount:,}" if isinstance(amount, int) else f"{amount:,.2f}"
    print(f"  {metric}: {formatted}")


# Low-VRAM configuration example
def create_low_vram_transformer(vocab_size):
    """Build a MiniTransformer with reduced sizes to lower memory use."""
    low_vram_config = {
        "d_model": 128,  # Smaller model dimension
        "num_heads": 4,  # Fewer attention heads
        "num_layers": 2,  # Fewer layers
        "d_ff": 512,  # Smaller feed-forward dimension
        "dropout_rate": 0.1,
    }
    return MiniTransformer(vocab_size=vocab_size, **low_vram_config)


# Build the reduced configuration and report its estimated footprint
low_vram_model = create_low_vram_transformer(vocab_size)
low_vram_memory = get_model_memory_usage(low_vram_model, (2, 8))
print("\nLow-VRAM Model Memory Usage:")
# Loop variables are deliberately not called key/value, so the attention
# tensors of the same names defined in an earlier cell are left untouched.
for metric, amount in low_vram_memory.items():
    # Counts print as integers; megabyte figures keep two decimals so that
    # sub-megabyte estimates are not rounded down to 0.
    formatted = f"{amount:,}" if isinstance(amount, int) else f"{amount:,.2f}"
    print(f"  {metric}: {formatted}")


# Gradient checkpointing example
def forward_with_checkpointing(model, input_ids):
    """Call ``model`` on ``input_ids`` with ``use_checkpoint=True`` and return its result."""
    checkpointed_outputs = model(input_ids, use_checkpoint=True)
    return checkpointed_outputs


# Checklist of memory-saving techniques for training
memory_tips = (
    "1. Use gradient checkpointing: set use_checkpoint=True",
    "2. Reduce batch size: smaller batches = less memory",
    "3. Use mixed precision: torch.cuda.amp.autocast()",
    "4. Accumulate gradients: effective larger batch without memory cost",
    "5. Use smaller model dimensions: reduce d_model, num_heads, num_layers",
)
print("\n=== Memory Optimization Tips ===")
print(*memory_tips, sep="\n")

In [None]:
# =============================================================================
# Cell 10: Smoke Test & Validation
# =============================================================================


def run_transformer_smoke_test():
    """
    Run four shape checks, from the attention function up to the full model.

    Each stage prints a pass line; the first failing stage prints its error
    and stops the run.

    Returns:
        True if every stage passed, False as soon as one stage fails.
    """
    print("Running Transformer Smoke Test...")

    def check_attention():
        # Stage 1: plain scaled dot-product attention
        q = k = v = torch.randn(1, 4, 8)
        out, weights = scaled_dot_product_attention(q, k, v)
        assert out.shape == (1, 4, 8), f"Expected (1,4,8), got {out.shape}"
        assert weights.shape == (1, 4, 4), f"Expected (1,4,4), got {weights.shape}"

    def check_multi_head():
        # Stage 2: multi-head self-attention
        mha = MultiHeadAttention(d_model=64, num_heads=4)
        x = torch.randn(1, 8, 64)
        out, weights = mha(x, x, x)
        assert out.shape == (1, 8, 64), f"Expected (1,8,64), got {out.shape}"

    def check_block():
        # Stage 3: one Transformer block
        block = TransformerBlock(d_model=64, num_heads=4, d_ff=128)
        x = torch.randn(1, 8, 64)
        out, weights = block(x)
        assert out.shape == (1, 8, 64), f"Expected (1,8,64), got {out.shape}"

    def check_full_model():
        # Stage 4: end-to-end model
        model = MiniTransformer(vocab_size=100, d_model=64, num_heads=4, num_layers=2)
        input_ids = torch.randint(0, 100, (1, 8))
        logits, all_weights = model(input_ids)
        assert logits.shape == (1, 8, 100), f"Expected (1,8,100), got {logits.shape}"
        assert (
            len(all_weights) == 2
        ), f"Expected 2 attention weight matrices, got {len(all_weights)}"

    # (check function, message on success, message prefix on failure)
    stages = (
        (check_attention, "✓ Scaled dot-product attention works", "✗ Attention test failed"),
        (check_multi_head, "✓ Multi-head attention works", "✗ Multi-head attention test failed"),
        (check_block, "✓ Transformer block works", "✗ Transformer block test failed"),
        (check_full_model, "✓ Full transformer model works", "✗ Full model test failed"),
    )

    for run_check, ok_message, fail_prefix in stages:
        try:
            run_check()
            print(ok_message)
        except Exception as err:
            print(f"{fail_prefix}: {err}")
            return False

    print("🎉 All tests passed! Transformer implementation is working correctly.")
    return True


# Execute the smoke test and, on success, list suggested follow-ups
success = run_transformer_smoke_test()

if success:
    print(
        "\n=== Next Steps ===",
        "1. Try modifying attention patterns with custom masks",
        "2. Experiment with different positional encoding schemes",
        "3. Test on real text data with tokenization",
        "4. Move to next notebook: HF Datasets Pipeline (nb07)",
        sep="\n",
    )

In [None]:
# Quick validation test
def quick_test():
    """Minimal end-to-end check: one tiny model, one forward pass, shapes asserted."""
    vocab = 50
    tiny_model = MiniTransformer(vocab_size=vocab, d_model=32, num_heads=2, num_layers=1)
    input_ids = torch.randint(0, vocab, (1, 4))
    logits, weights = tiny_model(input_ids)
    # Expect [batch=1, seq=4, vocab=50] logits and one attention tensor (one layer).
    logits_ok = logits.shape == (1, 4, 50)
    assert logits_ok and len(weights) == 1
    print("✅ nb06 attention & transformer working!")


# Run the quick validation; an AssertionError here means the shape check failed.
quick_test()

### **本章小結**

**✅ 完成項目**
- 從零實作 Scaled Dot-Product Attention 縮放點積注意力機制
- 構建 Multi-Head Attention 多頭注意力，支援並行計算多個表示子空間  
- 實現 Sinusoidal Positional Encoding 正弦位置編碼與視覺化
- 組裝完整 Transformer Block，包含殘差連接與層正規化
- 提供記憶體優化方案，支援梯度檢查點與低顯存配置

**🧠 核心原理要點**
- **注意力機制本質**：Query-Key-Value 三元組讓模型動態關注相關資訊
- **多頭注意力優勢**：並行處理多個表示子空間，捕獲不同類型的關係模式
- **位置編碼必要性**：Transformer 缺乏順序歸納偏置，需要顯式位置資訊
- **殘差連接重要性**：緩解深層網路的梯度消失問題，並穩定訓練過程

**⚠️ 常見陷阱**
- 注意力權重記憶體消耗：注意力矩陣的大小為 O(n²)（n 為序列長度），處理長序列時需留意記憶體／顯存用量
- 維度必須整除：`d_model` 必須被 `num_heads` 整除
- 遮罩維度對齊：多頭注意力時，遮罩必須能正確廣播到 `[batch, num_heads, seq_len, seq_len]` 的注意力分數形狀
- 梯度爆炸：深層模型需要適當的學習率與初始化

**🚀 下一步建議**
1. **立即行動**：進入 nb07 學習 HF Datasets 多模態資料處理管線
2. **延伸實驗**：嘗試 Rotary Position Embedding (RoPE) 替代方案
3. **效能優化**：測試 Flash Attention 等高效實作
4. **應用場景**：準備將此基礎應用到實際 LLM 模型載入與推理

這個 notebook 為後續的 HF 模型操作與 LLM 應用打下了堅實的理論與實作基礎！