In [None]:
# nb05_lstm_text_generation.ipynb
# LSTM/GRU Text Generation - Character and Word Level

# %% Cell 1: Shared Cache Bootstrap
import os, pathlib, torch

AI_CACHE_ROOT = os.getenv("AI_CACHE_ROOT", "/mnt/ai/cache")
for k, v in {
    "HF_HOME": f"{AI_CACHE_ROOT}/hf",
    "TRANSFORMERS_CACHE": f"{AI_CACHE_ROOT}/hf/transformers",
    "HF_DATASETS_CACHE": f"{AI_CACHE_ROOT}/hf/datasets",
    "HUGGINGFACE_HUB_CACHE": f"{AI_CACHE_ROOT}/hf/hub",
    "TORCH_HOME": f"{AI_CACHE_ROOT}/torch",
}.items():
    os.environ[k] = v
    pathlib.Path(v).mkdir(parents=True, exist_ok=True)

print("[Cache]", AI_CACHE_ROOT, "| GPU:", torch.cuda.is_available())
if torch.cuda.is_available():
    print(
        f"[GPU] {torch.cuda.get_device_name(0)} | VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB"
    )

In [None]:
# %% Cell 2: Dependencies and Data Preparation
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import requests
from collections import Counter
import random


# Download Shakespeare text for character-level generation
def download_shakespeare():
    """Download and prepare Shakespeare text data"""
    url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        text = response.text
        print(f"[Data] Downloaded Shakespeare text: {len(text)} characters")
        return text
    except Exception as e:
        print(f"[Warning] Download failed: {e}")
        # Fallback sample text (mix of English + Chinese)
        return """To be, or not to be, that is the question:
Whether 'tis nobler in the mind to suffer
The slings and arrows of outrageous fortune,
Or to take arms against a sea of troubles.
人工智能的發展正在改變世界。
機器學習讓電腦能夠從數據中學習。
深度學習是機器學習的一個分支。
"""


# Prepare sample text data
text_data = download_shakespeare()
print(f"[Data] Text length: {len(text_data)} characters")
print("[Sample]", repr(text_data[:100]))

In [None]:
# %% Cell 3: Character-Level Data Processor
class CharTokenizer:
    """Character-level tokenizer for text generation"""

    def __init__(self, text, min_freq=1):
        # Count character frequencies
        char_counts = Counter(text)

        # Filter by minimum frequency
        self.chars = sorted(
            [char for char, count in char_counts.items() if count >= min_freq]
        )

        # Create mappings
        self.char_to_idx = {char: idx for idx, char in enumerate(self.chars)}
        self.idx_to_char = {idx: char for char, idx in self.char_to_idx.items()}

        self.vocab_size = len(self.chars)
        print(f"[Tokenizer] Vocab size: {self.vocab_size}")
        print(f"[Chars] {repr(''.join(self.chars[:20]))}...")

    def encode(self, text):
        """Convert text to token indices"""
        return [self.char_to_idx.get(char, 0) for char in text]

    def decode(self, indices):
        """Convert token indices back to text"""
        return "".join([self.idx_to_char.get(idx, "<UNK>") for idx in indices])


class TextDataset(Dataset):
    """Dataset for sequence-to-sequence text modeling"""

    def __init__(self, text, tokenizer, seq_length=50):
        self.tokenizer = tokenizer
        self.seq_length = seq_length

        # Encode entire text
        self.tokens = tokenizer.encode(text)
        print(f"[Dataset] Total tokens: {len(self.tokens)}")

        # Create sequences
        self.sequences = []
        for i in range(len(self.tokens) - seq_length):
            input_seq = self.tokens[i : i + seq_length]
            target_seq = self.tokens[i + 1 : i + seq_length + 1]
            self.sequences.append((input_seq, target_seq))

        print(f"[Dataset] Created {len(self.sequences)} sequences")

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        input_seq, target_seq = self.sequences[idx]
        return torch.tensor(input_seq, dtype=torch.long), torch.tensor(
            target_seq, dtype=torch.long
        )


# Create tokenizer and dataset
tokenizer = CharTokenizer(text_data, min_freq=2)
dataset = TextDataset(text_data, tokenizer, seq_length=50)

In [None]:
# %% Cell 4: LSTM Text Generation Model
class LSTMTextGenerator(nn.Module):
    """LSTM-based text generation model"""

    def __init__(
        self, vocab_size, embed_dim=128, hidden_dim=256, num_layers=2, dropout=0.3
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Layers
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(
            embed_dim, hidden_dim, num_layers, dropout=dropout, batch_first=True
        )
        self.dropout = nn.Dropout(dropout)
        self.output_proj = nn.Linear(hidden_dim, vocab_size)

        # Initialize weights
        self.init_weights()

    def init_weights(self):
        """Initialize model weights"""
        for name, param in self.named_parameters():
            if "weight" in name:
                nn.init.uniform_(param, -0.1, 0.1)
            elif "bias" in name:
                nn.init.zeros_(param)

    def forward(self, x, hidden=None):
        """Forward pass"""
        batch_size, seq_len = x.size()

        # Embedding
        embedded = self.embedding(x)  # (batch, seq, embed)

        # LSTM
        lstm_out, hidden = self.lstm(embedded, hidden)  # (batch, seq, hidden)

        # Dropout and projection
        output = self.dropout(lstm_out)
        logits = self.output_proj(output)  # (batch, seq, vocab)

        return logits, hidden

    def init_hidden(self, batch_size, device="cpu"):
        """Initialize hidden state"""
        h_0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(device)
        c_0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(device)
        return (h_0, c_0)


# Alternative GRU model for comparison
class GRUTextGenerator(nn.Module):
    """GRU-based text generation model"""

    def __init__(
        self, vocab_size, embed_dim=128, hidden_dim=256, num_layers=2, dropout=0.3
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.gru = nn.GRU(
            embed_dim, hidden_dim, num_layers, dropout=dropout, batch_first=True
        )
        self.dropout = nn.Dropout(dropout)
        self.output_proj = nn.Linear(hidden_dim, vocab_size)

        self.init_weights()

    def init_weights(self):
        for name, param in self.named_parameters():
            if "weight" in name:
                nn.init.uniform_(param, -0.1, 0.1)
            elif "bias" in name:
                nn.init.zeros_(param)

    def forward(self, x, hidden=None):
        embedded = self.embedding(x)
        gru_out, hidden = self.gru(embedded, hidden)
        output = self.dropout(gru_out)
        logits = self.output_proj(output)
        return logits, hidden

    def init_hidden(self, batch_size, device="cpu"):
        return torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(device)


# Initialize model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LSTMTextGenerator(
    vocab_size=tokenizer.vocab_size,
    embed_dim=64,  # Smaller for low VRAM
    hidden_dim=128,  # Smaller for low VRAM
    num_layers=2,
    dropout=0.2,
).to(device)

print(
    f"[Model] LSTM Generator - {sum(p.numel() for p in model.parameters())} parameters"
)
print(f"[Device] {device}")

In [None]:
# %% Cell 5: Training Loop with Gradient Accumulation
def train_text_generator(
    model,
    dataset,
    tokenizer,
    epochs=10,
    batch_size=32,
    learning_rate=0.001,
    gradient_accumulation_steps=1,
):
    """Train text generation model with gradient accumulation for low VRAM"""

    # Adjust batch size based on available memory
    if torch.cuda.is_available():
        available_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
        if available_memory < 6:
            batch_size = min(batch_size, 16)
            gradient_accumulation_steps = max(gradient_accumulation_steps, 2)
            print(
                f"[LowVRAM] Adjusted batch_size={batch_size}, grad_accum={gradient_accumulation_steps}"
            )

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    model.train()
    losses = []

    for epoch in range(epochs):
        epoch_loss = 0
        num_batches = 0

        # Reset gradients
        optimizer.zero_grad()

        for batch_idx, (input_seq, target_seq) in enumerate(dataloader):
            input_seq, target_seq = input_seq.to(device), target_seq.to(device)

            # Forward pass
            logits, _ = model(input_seq)

            # Reshape for loss calculation
            logits_flat = logits.view(-1, model.vocab_size)
            target_flat = target_seq.view(-1)

            # Calculate loss
            loss = criterion(logits_flat, target_flat)

            # Scale loss for gradient accumulation
            loss = loss / gradient_accumulation_steps
            loss.backward()

            # Update weights every gradient_accumulation_steps
            if (batch_idx + 1) % gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                optimizer.zero_grad()

            epoch_loss += loss.item() * gradient_accumulation_steps
            num_batches += 1

            # Memory cleanup
            if batch_idx % 50 == 0:
                torch.cuda.empty_cache() if torch.cuda.is_available() else None

        avg_loss = epoch_loss / num_batches
        losses.append(avg_loss)

        print(f"Epoch {epoch+1}/{epochs} | Loss: {avg_loss:.4f}")

        # Generate sample text every few epochs
        if (epoch + 1) % 3 == 0:
            sample = generate_text(
                model, tokenizer, prompt="To be", max_length=50, temperature=0.8
            )
            print(f"[Sample] {repr(sample)}")

    return losses

In [None]:
# %% Cell 6: Temperature Sampling Generation Function
def generate_text(
    model, tokenizer, prompt="", max_length=100, temperature=1.0, device=None
):
    """Generate text using temperature sampling"""

    if device is None:
        device = next(model.parameters()).device

    model.eval()

    # Handle empty prompt
    if not prompt:
        prompt = random.choice(tokenizer.chars[:10])

    # Encode prompt
    tokens = tokenizer.encode(prompt)
    input_seq = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(device)

    # Initialize hidden state
    if hasattr(model, "init_hidden"):
        hidden = model.init_hidden(1, device)
    else:
        hidden = None

    generated_tokens = tokens.copy()

    with torch.no_grad():
        for _ in range(max_length):
            # Forward pass
            logits, hidden = model(input_seq, hidden)

            # Get last timestep logits
            next_token_logits = logits[0, -1, :] / temperature

            # Apply softmax and sample
            probs = torch.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, 1).item()

            generated_tokens.append(next_token)

            # Update input for next iteration
            input_seq = torch.tensor([[next_token]], dtype=torch.long).to(device)

    return tokenizer.decode(generated_tokens)


def compare_temperatures(
    model, tokenizer, prompt="The", temperatures=[0.5, 0.8, 1.0, 1.2]
):
    """Compare generation quality at different temperatures"""
    print(f"\n[Temperature Comparison] Prompt: '{prompt}'")
    print("-" * 60)

    for temp in temperatures:
        generated = generate_text(
            model, tokenizer, prompt, max_length=80, temperature=temp
        )
        print(f"Temp {temp:3.1f}: {repr(generated[:100])}...")
        print()

In [None]:
# %% Cell 7: Model Training and Checkpoints
print("[Training] Starting LSTM text generator training...")

# Create smaller dataset for quick training if low VRAM
if torch.cuda.is_available() and torch.cuda.get_device_properties(0).total_memory < 6e9:
    # Use subset of data for low VRAM
    subset_size = min(len(dataset), 1000)
    dataset = torch.utils.data.Subset(dataset, range(subset_size))
    print(f"[LowVRAM] Using subset of {subset_size} samples")

# Train model
train_losses = train_text_generator(
    model=model,
    dataset=dataset,
    tokenizer=tokenizer,
    epochs=8,
    batch_size=16,
    learning_rate=0.002,
    gradient_accumulation_steps=2,
)

# Plot training loss
plt.figure(figsize=(8, 4))
plt.plot(train_losses)
plt.title("LSTM Training Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.grid(True)
plt.show()

# Save model checkpoint
checkpoint_path = f"{AI_CACHE_ROOT}/models/lstm_text_gen_checkpoint.pt"
pathlib.Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)
torch.save(
    {
        "model_state_dict": model.state_dict(),
        "tokenizer_chars": tokenizer.chars,
        "model_config": {
            "vocab_size": tokenizer.vocab_size,
            "embed_dim": 64,
            "hidden_dim": 128,
            "num_layers": 2,
        },
    },
    checkpoint_path,
)
print(f"[Checkpoint] Saved to {checkpoint_path}")

In [None]:
# %% Cell 8: Text Generation Testing
print("\n" + "=" * 60)
print("TEXT GENERATION TESTING")
print("=" * 60)

# Test different temperature settings
compare_temperatures(
    model, tokenizer, prompt="To be", temperatures=[0.3, 0.7, 1.0, 1.5]
)

# Test with Chinese characters (if available in vocab)
chinese_chars = [char for char in tokenizer.chars if "\u4e00" <= char <= "\u9fff"]
if chinese_chars:
    print(f"[Chinese] Found {len(chinese_chars)} Chinese characters in vocab")
    chinese_prompt = chinese_chars[0] if chinese_chars else "人"
    print(f"\n[Chinese Generation] Prompt: '{chinese_prompt}'")
    chinese_text = generate_text(
        model, tokenizer, prompt=chinese_prompt, max_length=50, temperature=0.8
    )
    print(f"Generated: {repr(chinese_text)}")

# Creative prompts
creative_prompts = ["The future", "Once upon", "In the beginning", "人工智能"]
print(f"\n[Creative Generation]")
for prompt in creative_prompts:
    # Check if prompt characters are in vocab
    if all(char in tokenizer.char_to_idx for char in prompt):
        generated = generate_text(
            model, tokenizer, prompt, max_length=60, temperature=0.9
        )
        print(f"'{prompt}' → {repr(generated[:80])}...")
    else:
        print(f"'{prompt}' → [Skipped - characters not in vocab]")

In [None]:
# %% Cell 9: GRU vs LSTM Comparison
print("\n" + "=" * 60)
print("GRU vs LSTM COMPARISON")
print("=" * 60)

# Create GRU model for comparison
gru_model = GRUTextGenerator(
    vocab_size=tokenizer.vocab_size,
    embed_dim=64,
    hidden_dim=128,
    num_layers=2,
    dropout=0.2,
).to(device)

print(f"[GRU Model] {sum(p.numel() for p in gru_model.parameters())} parameters")

# Quick training for GRU (fewer epochs)
print("[Training] GRU model (quick training)...")
gru_losses = train_text_generator(
    model=gru_model,
    dataset=dataset,
    tokenizer=tokenizer,
    epochs=4,  # Fewer epochs for comparison
    batch_size=16,
    learning_rate=0.002,
    gradient_accumulation_steps=2,
)

# Compare generation quality
test_prompt = "To be"
print(f"\n[Model Comparison] Prompt: '{test_prompt}'")
print("-" * 50)

lstm_output = generate_text(
    model, tokenizer, test_prompt, max_length=80, temperature=0.8
)
gru_output = generate_text(
    gru_model, tokenizer, test_prompt, max_length=80, temperature=0.8
)

print(f"LSTM: {repr(lstm_output[:70])}...")
print(f"GRU:  {repr(gru_output[:70])}...")

# Plot comparison
plt.figure(figsize=(10, 4))
plt.plot(train_losses, label="LSTM", marker="o")
plt.plot(gru_losses, label="GRU", marker="s")
plt.title("LSTM vs GRU Training Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.grid(True)
plt.show()

In [None]:
# %% Cell 10: Smoke Test - Verification
print("\n" + "=" * 60)
print("SMOKE TEST - VERIFICATION")
print("=" * 60)


def smoke_test():
    """Quick verification that everything works"""
    try:
        # Test 1: Model can generate text
        sample = generate_text(model, tokenizer, "Test", max_length=20, temperature=1.0)
        assert len(sample) > 4, "Generated text too short"
        print("✓ Text generation works")

        # Test 2: Different temperatures produce different outputs
        out1 = generate_text(model, tokenizer, "Hello", max_length=30, temperature=0.1)
        out2 = generate_text(model, tokenizer, "Hello", max_length=30, temperature=1.5)
        print("✓ Temperature sampling works")

        # Test 3: Model can handle different prompt lengths
        short_gen = generate_text(model, tokenizer, "A", max_length=10)
        long_gen = generate_text(model, tokenizer, "The quick brown", max_length=10)
        print("✓ Variable prompt lengths work")

        # Test 4: Tokenizer encode/decode consistency
        test_text = "Hello world!"
        if all(char in tokenizer.char_to_idx for char in test_text):
            encoded = tokenizer.encode(test_text)
            decoded = tokenizer.decode(encoded)
            assert decoded == test_text, "Encode/decode mismatch"
            print("✓ Tokenizer consistency verified")

        # Test 5: Model checkpoint saving worked
        assert pathlib.Path(
            f"{AI_CACHE_ROOT}/models/lstm_text_gen_checkpoint.pt"
        ).exists()
        print("✓ Model checkpoint saved successfully")

        print("\n🎉 All smoke tests passed!")
        return True

    except Exception as e:
        print(f"❌ Smoke test failed: {e}")
        return False


# Run smoke test
smoke_test()

print(
    f"\n[Memory] Peak VRAM used: {torch.cuda.max_memory_allocated()/1e6:.1f}MB"
    if torch.cuda.is_available()
    else "[Memory] CPU only"
)

# %% Summary Cell
print("\n" + "=" * 60)
print("NOTEBOOK SUMMARY")
print("=" * 60)

print(
    """
✅ 完成項目 (Completed Items):
   • 實作字符級 LSTM/GRU 文字生成模型
   • 支援溫度採樣與不同生成策略
   • 低 VRAM 友善訓練（梯度累積、小 batch）
   • LSTM vs GRU 架構比較
   • 模型檢查點保存與載入

🔑 核心概念 (Core Concepts):
   • RNN 序列建模：LSTM 解決梯度消失問題
   • 字符級 vs 詞彙級生成的權衡
   • 溫度採樣控制生成多樣性
   • 梯度累積減少記憶體需求

⚠️  常見坑 (Common Pitfalls):
   • 序列長度過長導致 VRAM 不足
   • 溫度設置不當影響生成品質
   • 詞彙表過大影響訓練效率
   • 沒有梯度裁剪導致訓練不穩定

🚀 下一步建議 (Next Steps):
   • 進入 Part B: Transformer 架構學習
   • 對比 RNN vs Transformer 的生成品質
   • 學習 attention 機制改善長序列建模
   • 探索更大規模的預訓練模型
"""
)


## 6. 本章小結

### ✅ 完成項目
• **LSTM/GRU 文字生成**：實作完整的字符級序列模型
• **低 VRAM 優化**：梯度累積、動態 batch size、記憶體清理
• **溫度採樣**：多種生成策略比較
• **模型比較**：LSTM vs GRU 架構效能對比
• **中英文支援**：混合語言文本生成能力

### 🔑 核心原理要點
• **序列建模基礎**：理解 RNN 家族處理序列資料的方式
• **LSTM 閘控機制**：遺忘閘、輸入閘、輸出閘控制資訊流
• **文字生成策略**：貪婪解碼 vs 隨機採樣 vs 溫度控制
• **訓練技巧**：梯度裁剪、學習率調度、檢查點保存

### 🚀 下一步建議
**立即行動**：進入 **nb06_attention_transformer.ipynb**，學習 Transformer 架構
**技能建構**：對比 RNN 與 Transformer 在序列建模上的差異
**實用擴展**：嘗試詞彙級生成、更大語料庫、多語言支援
**效能優化**：探索模型蒸餾、量化等進一步的 VRAM 節省技巧

---

**何時使用這套技術**：
- 需要理解 RNN 序列建模基礎時
- 資源受限但需要文字生成功能時  
- 作為 Transformer 學習前的打底準備
- 進行序列模型架構比較研究時