In [None]:
# Generation Strategies & Parameter Tuning
# 生成策略與參數調優
# ================================================================

"""
Learning Goals (學習目標):
1. Master text generation strategies: Temperature, Top-k, Top-p, repetition penalty
2. Implement interactive parameter tuning with real-time feedback
3. Build evaluation framework for generation quality assessment
4. Support low-VRAM environments with 4-bit quantization
5. Compare optimal parameters for Chinese vs English generation
"""

# === Shared Cache Bootstrap (English comments only) ===
import os, torch, platform, pathlib

# Root directory for every cache below. Taken from the AI_CACHE_ROOT environment
# variable when it is set, otherwise "/mnt/ai/cache".
AI_CACHE_ROOT = os.getenv("AI_CACHE_ROOT", "/mnt/ai/cache")
# Environment-variable name -> cache directory, all located under AI_CACHE_ROOT.
paths = {
    "HF_HOME": f"{AI_CACHE_ROOT}/hf",
    "TRANSFORMERS_CACHE": f"{AI_CACHE_ROOT}/hf/transformers",
    "HF_DATASETS_CACHE": f"{AI_CACHE_ROOT}/hf/datasets",
    "HUGGINGFACE_HUB_CACHE": f"{AI_CACHE_ROOT}/hf/hub",
    "TORCH_HOME": f"{AI_CACHE_ROOT}/torch",
}
# Export each variable into this process's environment and create its directory
# (with any missing parents); existing directories are left untouched.
for k, v in paths.items():
    os.environ[k] = v
    pathlib.Path(v).mkdir(parents=True, exist_ok=True)

print("[Cache] Root:", AI_CACHE_ROOT)
# Report whether CUDA is available and the name of device 0 ("CPU" when it is not).
print(
    "[GPU]",
    torch.cuda.is_available(),
    torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU",
)

In [None]:
# === Cell 1: Environment Setup ===
import torch
import numpy as np
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    set_seed,
)
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
import re
import time
from typing import Dict, List, Optional, Tuple
import warnings

# Silence every warning category for the rest of the session.
# NOTE(review): this also hides deprecation and generation-config warnings that
# can explain unexpected generation behaviour — confirm that is intended.
warnings.filterwarnings("ignore")

# Set random seeds for reproducibility
# NOTE(review): transformers' set_seed presumably already seeds torch and numpy,
# which would make the two explicit calls below redundant — confirm.
set_seed(42)
torch.manual_seed(42)
np.random.seed(42)

print("🔧 Environment setup complete")

In [None]:
# === Cell 2: Model & Tokenizer Loading (Low-VRAM Friendly) ===


class GenerationTester:
    """Load a causal language model and its tokenizer, optionally 4-bit quantized.

    Attributes set by the constructor:
        model_name: Hugging Face model identifier that was requested.
        device: ``torch.device("cuda")`` when CUDA is available, else CPU.
        use_4bit: Whether 4-bit quantization is actually applied. This is
            always ``False`` without CUDA, even if ``use_4bit=True`` was requested.
        tokenizer / model: Objects returned by ``from_pretrained``.
    """

    def __init__(
        self,
        model_name: str = "Qwen/Qwen2.5-1.5B-Instruct",
        use_4bit: Optional[bool] = None,
        device_map: str = "auto",
    ):
        """Resolve the quantization setting, then load tokenizer and model.

        Args:
            model_name: Model identifier passed to ``from_pretrained``.
            use_4bit: ``True``/``False`` to force the setting, ``None`` to
                auto-detect (enabled when device 0 reports < 12 GB of memory).
            device_map: Forwarded to the model loader when CUDA is available.
        """
        self.model_name = model_name
        cuda_available = torch.cuda.is_available()
        self.device = torch.device("cuda" if cuda_available else "cpu")

        if use_4bit is None:
            # Auto-detect 4bit need based on VRAM
            if cuda_available:
                gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
                use_4bit = gpu_memory < 12  # Use 4bit if < 12GB VRAM
            else:
                use_4bit = False
        elif use_4bit and not cuda_available:
            # The loader only applies quantization when CUDA is present, so an
            # explicit request on CPU used to be dropped silently while the
            # attribute and the log line still claimed 4-bit was active.
            print("⚠️ 4-bit quantization requested but CUDA is unavailable; disabling it")
            use_4bit = False

        self.use_4bit = bool(use_4bit)
        self._load_model_and_tokenizer(device_map)

    def _load_model_and_tokenizer(self, device_map: str):
        """Populate ``self.tokenizer`` and ``self.model`` for ``self.model_name``.

        Args:
            device_map: Passed to the model loader on CUDA; ignored on CPU,
                where the model is moved to ``self.device`` explicitly.
        """
        print(f"📥 Loading {self.model_name}...")
        print(f"🔧 4-bit quantization: {self.use_4bit}")

        # NOTE(review): trust_remote_code=True allows code shipped with the
        # checkpoint to run locally — only use with model sources you trust.
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name, trust_remote_code=True, padding_side="left"
        )
        if self.tokenizer.pad_token is None:
            # Some tokenizers define no pad token; reuse EOS so padding works.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Half precision on GPU, full precision on CPU.
        model_kwargs = {
            "trust_remote_code": True,
            "torch_dtype": (
                torch.float16 if torch.cuda.is_available() else torch.float32
            ),
            "device_map": device_map if torch.cuda.is_available() else None,
        }

        if self.use_4bit and torch.cuda.is_available():
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
            )
            model_kwargs["quantization_config"] = quantization_config

        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name, **model_kwargs
        )

        if not torch.cuda.is_available():
            # No device_map was given on CPU, so place the model explicitly.
            self.model = self.model.to(self.device)

        print("✅ Model loaded successfully")
        if torch.cuda.is_available():
            print(f"💾 VRAM usage: {torch.cuda.memory_allocated()/1e9:.2f}GB")


# Initialize model (fallback chain for compatibility)
# Candidates are tried in list order; the first one that loads is used.
model_candidates = [
    "Qwen/Qwen2.5-1.5B-Instruct",  # Lightweight, good for demos
    "microsoft/DialoGPT-medium",  # Fallback option
    "gpt2",  # Ultimate fallback
]

# `generator` is the module-level handle the later cells read; it stays None
# until one candidate loads.
generator = None
for model_name in model_candidates:
    try:
        generator = GenerationTester(model_name)
        print(f"🎯 Using model: {model_name}")
        break
    except Exception as e:
        # Deliberately broad: any load failure just moves on to the next candidate.
        print(f"❌ Failed to load {model_name}: {e}")
        continue

# Abort the notebook when every candidate failed to load.
if generator is None:
    raise RuntimeError("❌ No compatible model found")


In [None]:
# === Cell 3: Generation Strategy Functions ===


class GenerationStrategies:
    """Reference implementations of common next-token selection strategies."""

    @staticmethod
    def temperature_sampling(
        logits: torch.Tensor, temperature: float = 1.0
    ) -> torch.Tensor:
        """
        Pick one token index per row after temperature scaling.

        - temperature == 0: deterministic argmax (index kept as a trailing dim)
        - temperature < 1.0: sharper distribution, more conservative picks
        - temperature = 1.0: sample from the unmodified distribution
        - temperature > 1.0: flatter distribution, more random picks
        """
        if temperature == 0:
            return logits.argmax(dim=-1, keepdim=True)
        scaled_probs = torch.softmax(logits / temperature, dim=-1)
        return torch.multinomial(scaled_probs, 1)

    @staticmethod
    def top_k_filtering(logits: torch.Tensor, top_k: int = 50) -> torch.Tensor:
        """
        Keep the ``top_k`` largest logits along the last dim; the rest become -inf.

        ``top_k <= 0`` disables filtering and returns ``logits`` unchanged.
        """
        if top_k <= 0:
            return logits

        # Never ask for more entries than the vocabulary dimension holds.
        keep_count = min(top_k, logits.size(-1))
        kept_values, kept_positions = logits.topk(keep_count, dim=-1)

        # Start from an all -inf tensor and write the survivors back in place.
        filtered = torch.full_like(logits, float("-inf"))
        return filtered.scatter_(-1, kept_positions, kept_values)

    @staticmethod
    def nucleus_sampling(logits: torch.Tensor, top_p: float = 0.9) -> torch.Tensor:
        """
        Nucleus (top-p) filtering: keep the highest-probability tokens up to and
        including the one whose cumulative probability first exceeds ``top_p``;
        all others become -inf. ``top_p >= 1.0`` returns ``logits`` unchanged.
        """
        if top_p >= 1.0:
            return logits

        ranked_logits, ranking = logits.sort(dim=-1, descending=True)
        cumulative = torch.softmax(ranked_logits, dim=-1).cumsum(dim=-1)

        # Shift the "over budget" flags right by one so the token that crosses
        # the threshold is kept and the top-ranked token is never dropped.
        over_budget = cumulative > top_p
        drop = torch.zeros_like(over_budget)
        drop[..., 1:] = over_budget[..., :-1]

        ranked_logits = ranked_logits.masked_fill(drop, float("-inf"))

        # Scatter the filtered values back to their original vocabulary positions.
        restored = torch.full_like(logits, float("-inf"))
        restored.scatter_(-1, ranking, ranked_logits)
        return restored


print("🧩 Generation strategies defined")

In [None]:
# === Cell 4: Interactive Generation Tester ===


def generate_with_strategies(
    prompt: str,
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 0.9,
    repetition_penalty: float = 1.1,
    max_length: int = 100,
    num_return_sequences: int = 3,
    do_sample: bool = True,
    no_repeat_ngram_size: int = 2,
) -> List[str]:
    """
    Generate continuations of ``prompt`` with the module-level ``generator``.

    Args:
        prompt: Text to continue.
        temperature: Randomness control (0.1=conservative, 1.0=balanced,
            2.0=creative). ``0`` selects greedy decoding.
        top_k: Candidate-token limit (0=unlimited, 50=moderate, 10=strict).
        top_p: Cumulative-probability threshold (0.5=conservative,
            0.9=balanced, 0.95=loose).
        repetition_penalty: 1.0=no penalty, 1.1=light, 1.5=strict.
        max_length: Maximum number of NEW tokens to generate (prompt excluded).
        num_return_sequences: Number of samples to draw. Greedy decoding is
            deterministic, so it always yields exactly one sequence.
        do_sample: Sample (True) or decode greedily (False).
        no_repeat_ngram_size: Size of n-grams that may not repeat; 0 disables.

    Returns:
        The generated continuations, stripped, without the prompt.

    Raises:
        ValueError: If ``temperature`` is negative.
    """
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")

    tokenizer = generator.tokenizer
    model = generator.model

    # A zero temperature is the greedy limit; sampling with it is rejected by
    # the library, so switch decoding mode instead of failing.
    if do_sample and temperature == 0:
        do_sample = False
    # Greedy decoding cannot return several sequences (generate() raises).
    if not do_sample and num_return_sequences != 1:
        print("ℹ️ Greedy decoding is deterministic; returning a single sequence")
        num_return_sequences = 1

    # Calling the tokenizer (instead of .encode) also yields the attention
    # mask, which cannot be inferred reliably when pad and EOS share an id.
    encoded = tokenizer(prompt, return_tensors="pt")
    target_device = getattr(model, "device", generator.device)
    input_ids = encoded["input_ids"].to(target_device)
    attention_mask = encoded["attention_mask"].to(target_device)
    input_length = input_ids.shape[1]

    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    generation_kwargs = {
        "attention_mask": attention_mask,
        # Budget only the new tokens so the prompt length never eats into it.
        "max_new_tokens": max_length,
        "num_return_sequences": num_return_sequences,
        "do_sample": do_sample,
        "temperature": temperature if do_sample else 1.0,
        "top_k": top_k if do_sample else 0,
        "top_p": top_p if do_sample else 1.0,
        "repetition_penalty": repetition_penalty,
        "pad_token_id": pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "no_repeat_ngram_size": no_repeat_ngram_size,
    }

    with torch.no_grad():
        start_time = time.perf_counter()
        outputs = model.generate(input_ids, **generation_kwargs)
        generation_time = time.perf_counter() - start_time

    # Drop the echoed prompt tokens, then decode every sequence in one call.
    continuations = outputs[:, input_length:]
    generated_texts = [
        text.strip()
        for text in tokenizer.batch_decode(continuations, skip_special_tokens=True)
    ]

    print(f"⏱️ Generation time: {generation_time:.2f}s")
    return generated_texts


# Test with different prompts
# Keys follow "<language>_<style>": English/Chinese crossed with creative/factual.
# NOTE(review): nothing visible in this notebook reads test_prompts — presumably
# kept for manual experimentation; confirm before removing.
test_prompts = {
    "en_creative": "Once upon a time in a magical forest,",
    "en_factual": "The process of photosynthesis involves",
    "zh_creative": "從前有一個神奇的森林，",
    "zh_factual": "光合作用的過程包括",
}

print("🎮 Interactive generation tester ready")

In [None]:
# === Cell 5: Parameter Comparison Demo ===


def compare_generation_strategies():
    """Run one fixed prompt through four sampling presets and print each result.

    Returns:
        Dict mapping every preset label to the list of texts generated for it.
    """

    def _preset(temperature, top_k, top_p, repetition_penalty):
        # Bundle the four tunables under the keyword names the generator expects.
        return {
            "temperature": temperature,
            "top_k": top_k,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
        }

    prompt = "The future of artificial intelligence will"

    # Presets ordered from most to least conservative, plus a near-greedy one.
    configs = {
        "Conservative (保守)": _preset(0.3, 10, 0.8, 1.2),
        "Balanced (平衡)": _preset(0.7, 50, 0.9, 1.1),
        "Creative (創意)": _preset(1.2, 100, 0.95, 1.05),
        "Greedy (貪婪)": _preset(0.1, 1, 1.0, 1.0),
    }

    print(f"🎯 Prompt: '{prompt}'\n")
    print("=" * 80)

    results = {}
    for config_name in configs:
        params = configs[config_name]

        print(f"\n📊 {config_name} Configuration:")
        print(
            f"   Temperature: {params['temperature']}, Top-k: {params['top_k']}, "
            f"Top-p: {params['top_p']}, Rep. Penalty: {params['repetition_penalty']}"
        )
        print("-" * 40)

        samples = generate_with_strategies(
            prompt=prompt, num_return_sequences=2, max_length=50, **params
        )
        results[config_name] = samples

        for position, sample in enumerate(samples, start=1):
            print(f"   {position}. {sample}")
        print()

    return results


# Run comparison
comparison_results = compare_generation_strategies()

In [None]:
# === Cell 6: Generation Quality Metrics ===


class GenerationMetrics:
    """Heuristic quality metrics for generated text (English and CJK).

    All metrics share one tokenizer: whitespace-delimited words for spaced
    scripts, and one token per character for Han ideographs and kana, which
    are written without spaces. Plain ``str.split()`` would turn a whole
    Chinese sentence into a single token and make every metric meaningless.
    """

    # Hiragana/Katakana, CJK Extension A, CJK Unified, CJK Compatibility ideographs.
    _CJK_CHARS = "\u3040-\u30ff\u3400-\u4dbf\u4e00-\u9fff\uf900-\ufaff"
    # Either one CJK character, or a run of non-space, non-CJK characters.
    _TOKEN_RE = re.compile(f"[{_CJK_CHARS}]|[^\\s{_CJK_CHARS}]+")
    # Sentence terminators, ASCII and full-width.
    _SENTENCE_END_RE = re.compile(r"[.!?。！？]+")
    # Punctuation counted for the coherence heuristic, ASCII and full-width.
    _PUNCTUATION_RE = re.compile(r"[.!?,:;。！？，、：；]")

    @staticmethod
    def tokenize(text: str) -> List[str]:
        """Lower-case ``text`` and split it into word / CJK-character tokens."""
        return GenerationMetrics._TOKEN_RE.findall(text.lower())

    @staticmethod
    def calculate_diversity(texts: List[str]) -> Dict[str, float]:
        """Compute distinct-1/2/3: unique n-grams divided by total n-grams.

        N-grams are collected per text and then pooled, so no n-gram spans
        the boundary between two different texts.

        Returns:
            Dict with keys ``distinct_1``, ``distinct_2``, ``distinct_3``;
            all zero when there are no texts or no tokens.
        """
        empty = {"distinct_1": 0.0, "distinct_2": 0.0, "distinct_3": 0.0}
        if not texts:
            return empty

        token_lists = [GenerationMetrics.tokenize(text) for text in texts]
        if not any(token_lists):
            return empty

        def distinct(n: int) -> float:
            ngrams = [
                tuple(tokens[i : i + n])
                for tokens in token_lists
                for i in range(len(tokens) - n + 1)
            ]
            return len(set(ngrams)) / len(ngrams) if ngrams else 0.0

        return {
            "distinct_1": distinct(1),
            "distinct_2": distinct(2),
            "distinct_3": distinct(3),
        }

    @staticmethod
    def calculate_repetition_score(text: str) -> float:
        """Fraction of adjacent token pairs that are identical (lower is better).

        Returns 0.0 when the text has fewer than two tokens.
        """
        tokens = GenerationMetrics.tokenize(text)
        if len(tokens) < 2:
            return 0.0

        repetitions = sum(1 for left, right in zip(tokens, tokens[1:]) if left == right)
        return repetitions / (len(tokens) - 1)

    @staticmethod
    def calculate_coherence_score(text: str) -> float:
        """Rough coherence proxy in [0, 1] from sentence length and punctuation.

        Sentences are split on ASCII or full-width terminators; fragments of
        five characters or fewer are ignored. The score is
        ``min(1, 0.7 * avg_tokens_per_sentence / 20 + 0.3 * 10 * punctuation_ratio)``.
        Returns 0.0 when no fragment qualifies as a sentence.
        """
        fragments = GenerationMetrics._SENTENCE_END_RE.split(text.strip())
        complete_sentences = [s.strip() for s in fragments if len(s.strip()) > 5]

        if not complete_sentences:
            return 0.0

        avg_length = np.mean(
            [len(GenerationMetrics.tokenize(s)) for s in complete_sentences]
        )
        punctuation_ratio = (
            len(GenerationMetrics._PUNCTUATION_RE.findall(text)) / len(text)
            if text
            else 0
        )

        # Normalize score (this is a simplified metric)
        coherence = min(1.0, (avg_length / 20) * 0.7 + punctuation_ratio * 10 * 0.3)
        return float(coherence)


def evaluate_generation_quality(texts: List[str]) -> Dict[str, float]:
    """Aggregate diversity, repetition, coherence and length statistics.

    Args:
        texts: Generated texts to score together.

    Returns:
        Dict with ``distinct_1/2/3``, ``avg_repetition``, ``avg_coherence``,
        ``text_length_avg`` and ``text_length_std`` (lengths are token counts
        from ``GenerationMetrics.tokenize``). Every value is 0.0 for an empty
        ``texts`` list instead of NaN.
    """
    diversity = GenerationMetrics.calculate_diversity(texts)

    if not texts:
        # np.mean([]) would yield NaN plus a runtime warning.
        return {
            **diversity,
            "avg_repetition": 0.0,
            "avg_coherence": 0.0,
            "text_length_avg": 0.0,
            "text_length_std": 0.0,
        }

    repetition_scores = [
        GenerationMetrics.calculate_repetition_score(text) for text in texts
    ]
    coherence_scores = [
        GenerationMetrics.calculate_coherence_score(text) for text in texts
    ]
    lengths = [len(GenerationMetrics.tokenize(text)) for text in texts]

    return {
        **diversity,
        "avg_repetition": float(np.mean(repetition_scores)),
        "avg_coherence": float(np.mean(coherence_scores)),
        "text_length_avg": float(np.mean(lengths)),
        "text_length_std": float(np.std(lengths)),
    }


# Score every preset from the comparison cell and print one report per preset.
print("📊 Quality Evaluation Results:")
print("=" * 50)

for config_name, generated_texts in comparison_results.items():
    metrics = evaluate_generation_quality(generated_texts)
    # One joined write per preset; the emitted text matches line-by-line printing.
    print(
        "\n".join(
            [
                f"\n{config_name}:",
                f"  Distinct-1: {metrics['distinct_1']:.3f}",
                f"  Distinct-2: {metrics['distinct_2']:.3f}",
                f"  Repetition: {metrics['avg_repetition']:.3f}",
                f"  Coherence:  {metrics['avg_coherence']:.3f}",
                f"  Avg Length: {metrics['text_length_avg']:.1f} words",
            ]
        )
    )

In [None]:
# === Cell 7: Parameter Sensitivity Analysis ===


def parameter_sensitivity_analysis():
    """Sweep the sampling temperature and score the generations at each value.

    All other generation settings are held fixed. Three samples are drawn per
    temperature and evaluated together.

    Returns:
        List of dicts, one per temperature in sweep order, with the keys
        ``temperature``, ``distinct_1``, ``repetition`` and ``coherence``.
    """
    base_prompt = "The benefits of renewable energy include"
    fixed_settings = dict(top_k=50, top_p=0.9, repetition_penalty=1.1, max_length=40)
    sweep = (0.1, 0.3, 0.5, 0.7, 1.0, 1.3, 1.6, 2.0)

    print("🌡️ Temperature Sensitivity Analysis:")
    print("-" * 40)

    temp_results = []
    for temp in sweep:
        samples = generate_with_strategies(
            prompt=base_prompt,
            num_return_sequences=3,
            **fixed_settings,
            temperature=temp,
        )
        metrics = evaluate_generation_quality(samples)

        record = {
            "temperature": temp,
            "distinct_1": metrics["distinct_1"],
            "repetition": metrics["avg_repetition"],
            "coherence": metrics["avg_coherence"],
        }
        temp_results.append(record)

        print(
            f"T={temp:.1f}: Distinct={metrics['distinct_1']:.3f}, "
            f"Rep={metrics['avg_repetition']:.3f}, Coh={metrics['avg_coherence']:.3f}"
        )

    return temp_results


# Run sensitivity analysis
sensitivity_results = parameter_sensitivity_analysis()

In [None]:
# === Cell 8: Visualization of Results ===


def plot_parameter_effects(results: List[Dict]):
    """Plot diversity, repetition and coherence against temperature.

    Args:
        results: Records with the keys ``temperature``, ``distinct_1``,
            ``repetition`` and ``coherence`` (one record per temperature).

    Shows a three-panel figure, then prints the temperature whose distinct-1
    score is highest as the suggested setting.
    """
    temperatures = [row["temperature"] for row in results]
    series = {
        key: [row[key] for row in results]
        for key in ("distinct_1", "repetition", "coherence")
    }

    # (series key, line/marker format, y-axis label, panel title), left to right.
    panels = (
        ("distinct_1", "b-o", "Distinct-1 Score (多樣性)", "Diversity vs Temperature"),
        ("repetition", "r-s", "Repetition Score (重複度)", "Repetition vs Temperature"),
        ("coherence", "g-^", "Coherence Score (連貫性)", "Coherence vs Temperature"),
    )

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    for ax, (key, line_format, y_label, title) in zip(axes, panels):
        ax.plot(temperatures, series[key], line_format, linewidth=2, markersize=6)
        ax.set_xlabel("Temperature (溫度)")
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # The suggestion considers diversity only, as the printed note states.
    best_index = np.argmax(series["distinct_1"])
    optimal_temp = temperatures[best_index]
    print(f"\n🎯 Suggested optimal temperature: {optimal_temp}")
    print(f"   (Based on highest diversity score)")


# Create visualization
plot_parameter_effects(sensitivity_results)

In [None]:
# === Cell 9: Chinese vs English Parameter Optimization ===


def compare_language_parameters():
    """Run three sampling presets on an English and a Chinese prompt.

    Returns:
        Dict keyed by ``"english"`` / ``"chinese"``; each value is a list with
        one entry per preset holding ``config`` (preset name), ``metrics``
        (quality scores) and ``sample`` (first generation, cut to 100
        characters with a trailing ellipsis when longer).
    """
    prompts = {
        "english": "The key advantages of machine learning are",
        "chinese": "機器學習的主要優勢包括",
    }

    # Preset name plus the sampling knobs forwarded to the generator.
    configs = [
        {"name": "Low Temp", "temperature": 0.3, "top_k": 20, "top_p": 0.8},
        {"name": "Med Temp", "temperature": 0.7, "top_k": 50, "top_p": 0.9},
        {"name": "High Temp", "temperature": 1.2, "top_k": 100, "top_p": 0.95},
    ]

    results = {}
    for lang in prompts:
        print(f"\n🌍 Testing {lang.upper()} generation:")
        print("=" * 40)

        lang_results = []
        for config in configs:
            # Everything except the display name is a generation argument.
            sampling_args = {
                key: config[key] for key in config if key != "name"
            }
            generated = generate_with_strategies(
                prompt=prompts[lang],
                num_return_sequences=3,
                max_length=30,
                repetition_penalty=1.1,
                **sampling_args,
            )
            metrics = evaluate_generation_quality(generated)

            first_text = generated[0]
            if len(first_text) > 100:
                sample = first_text[:100] + "..."
            else:
                sample = first_text

            lang_results.append(
                {"config": config["name"], "metrics": metrics, "sample": sample}
            )

            print(
                f"{config['name']}: Distinct={metrics['distinct_1']:.3f}, "
                f"Coherence={metrics['avg_coherence']:.3f}"
            )
            print(f"  Sample: {first_text[:80]}...")
            print()

        results[lang] = lang_results

    return results


# Run language comparison
language_results = compare_language_parameters()

In [None]:
# === Cell 10: Best Practices & Recommendations ===


def print_generation_best_practices():
    """Print a static cheat sheet of generation-parameter recommendations.

    The output is a banner followed by six titled sections of four bullet
    points each. Nothing is returned.
    """
    # (section heading, bullet points) in display order.
    sections = (
        (
            "🎯 Task-Specific Recommendations",
            (
                "Creative Writing: temperature=1.0-1.5, top_p=0.95, top_k=100+",
                "Factual Q&A: temperature=0.3-0.7, top_p=0.8-0.9, top_k=20-50",
                "Code Generation: temperature=0.1-0.5, top_p=0.8, repetition_penalty=1.2",
                "Dialogue: temperature=0.7-1.0, top_p=0.9, top_k=50, repetition_penalty=1.1",
            ),
        ),
        (
            "🌡️ Temperature Guidelines",
            (
                "0.1-0.3: Deterministic, focused (good for factual tasks)",
                "0.5-0.8: Balanced creativity and consistency",
                "1.0-1.5: Creative, diverse (good for brainstorming)",
                "1.5+: Very random (experimental, may lack coherence)",
            ),
        ),
        (
            "🔢 Top-k/Top-p Balance",
            (
                "High diversity: top_k=100+, top_p=0.95+",
                "Moderate diversity: top_k=50, top_p=0.9",
                "Conservative: top_k=10-20, top_p=0.7-0.8",
                "Use both together: top_k filters noise, top_p ensures quality",
            ),
        ),
        (
            "🚫 Repetition Control",
            (
                "repetition_penalty=1.0: No penalty (natural repetition allowed)",
                "repetition_penalty=1.1: Light penalty (recommended default)",
                "repetition_penalty=1.2-1.5: Strong penalty (for repetitive models)",
                "no_repeat_ngram_size=2-3: Prevent exact phrase repetition",
            ),
        ),
        (
            "🌍 Language-Specific Tips",
            (
                "Chinese: Slightly higher temperature (0.8-1.2) for natural flow",
                "English: Standard parameters work well across tasks",
                "Code: Lower temperature (0.1-0.5) for syntax correctness",
                "Multilingual: Test parameters per language separately",
            ),
        ),
        (
            "⚡ Performance Optimization",
            (
                "Batch multiple sequences: num_return_sequences=3-5",
                "Use caching: past_key_values for multi-turn conversations",
                "Memory management: Clear cache between long generations",
                "4-bit quantization: Minimal quality loss, 4x memory savings",
            ),
        ),
    )

    banner = "=" * 60

    print("📚 GENERATION STRATEGIES BEST PRACTICES")
    print(banner)

    for heading, tips in sections:
        print(f"\n{heading}")
        print("-" * 40)
        for tip in tips:
            print(f"  • {tip}")

    print("\n" + banner)


print_generation_best_practices()

In [None]:
# === Cell 11: Smoke Test & Validation ===


def run_smoke_test():
    """Run four quick checks of generation and metrics; stop at the first failure.

    Checks, in order: a single basic generation, generation at three
    temperatures, presence of the core metric keys, and a GPU memory readout
    around one generation (reported only, never asserted).

    Returns:
        True when every check passes, False as soon as one fails.
    """
    print("🧪 Running Generation Strategies Smoke Test...")
    print("-" * 50)

    test_prompt = "AI technology will"

    # Check 1: one sequence comes back and it is not empty.
    try:
        first_result = generate_with_strategies(
            prompt=test_prompt, temperature=0.7, max_length=20, num_return_sequences=1
        )
        assert len(first_result) == 1
        assert len(first_result[0]) > 0
        print("✅ Basic generation: PASS")
    except Exception as err:
        print(f"❌ Basic generation: FAIL - {err}")
        return False

    # Check 2: low, neutral and high temperatures all yield non-empty text.
    try:
        for temp in (0.1, 1.0, 1.5):
            texts = generate_with_strategies(
                prompt=test_prompt,
                temperature=temp,
                max_length=15,
                num_return_sequences=1,
            )
            assert len(texts[0]) > 0
        print("✅ Temperature variations: PASS")
    except Exception as err:
        print(f"❌ Temperature variations: FAIL - {err}")
        return False

    # Check 3: the evaluation helper reports the core metric keys.
    try:
        scores = evaluate_generation_quality(
            ["Hello world test", "Another test sentence here"]
        )
        missing_keys = [
            key
            for key in ("distinct_1", "avg_repetition", "avg_coherence")
            if key not in scores
        ]
        assert not missing_keys
        print("✅ Quality metrics: PASS")
    except Exception as err:
        print(f"❌ Quality metrics: FAIL - {err}")
        return False

    # Check 4: report allocated-memory growth across one longer generation.
    try:
        if not torch.cuda.is_available():
            print("✅ Memory efficiency: PASS (CPU mode)")
        else:
            memory_before = torch.cuda.memory_allocated()
            generate_with_strategies(
                prompt=test_prompt, max_length=50, num_return_sequences=2
            )
            memory_after = torch.cuda.memory_allocated()
            memory_increase = (memory_after - memory_before) / 1e6  # MB
            print(f"✅ Memory efficiency: PASS (increase: {memory_increase:.1f}MB)")
    except Exception as err:
        print(f"❌ Memory efficiency: FAIL - {err}")
        return False

    print("\n🎉 All smoke tests passed!")
    return True


# Run smoke test
smoke_test_result = run_smoke_test()

# === Summary Cell ===
# Prints four sections: completed items, key concepts, pitfalls and next
# steps, followed by the final memory readout.
print("\n" + "=" * 60)
print("📋 NOTEBOOK COMPLETION SUMMARY")
print("=" * 60)

completed_items = [
    "✅ Loaded generation-capable model with low-VRAM support",
    "✅ Implemented core generation strategies (temperature, top-k, top-p)",
    "✅ Built interactive parameter comparison system",
    "✅ Created quality evaluation metrics (diversity, repetition, coherence)",
    "✅ Performed parameter sensitivity analysis",
    "✅ Compared Chinese vs English generation parameters",
    "✅ Established best practices for different generation tasks",
    "✅ Validated implementation with comprehensive smoke tests",
]

# This list used to be built but never printed, leaving the summary header
# with no content underneath it.
for item in completed_items:
    print(f"  {item}")

print("\n📊 Key Learnings & Concepts:")
key_concepts = [
    "🌡️ Temperature controls randomness: lower=focused, higher=creative",
    "🔢 Top-k limits vocabulary size, top-p uses probability mass",
    "🚫 Repetition penalty prevents monotonous outputs",
    "📈 Quality metrics: diversity, coherence, repetition scores",
    "🌍 Language-specific parameter optimization needed",
    "⚡ 4-bit quantization enables low-VRAM generation",
]

for concept in key_concepts:
    print(f"  {concept}")

print("\n⚠️ Common Pitfalls & Solutions:")
pitfalls = [
    "High temperature (>1.5) → incoherent text → use 0.7-1.2 range",
    "Low top-k (<10) → repetitive output → combine with top-p",
    "Excessive repetition penalty → unnatural flow → keep ≤1.3",
    "Ignoring language differences → poor quality → test per language",
    "VRAM overflow → crashes → use 4-bit quantization",
    "No evaluation metrics → can't optimize → implement quality scoring",
]

for pitfall in pitfalls:
    print(f"  ⚠️ {pitfall}")

print("\n🚀 Next Steps & Recommendations:")
next_steps = [
    "1. Implement custom sampling strategies (contrastive search, typical sampling)",
    "2. Add beam search for deterministic high-quality generation",
    "3. Create parameter auto-tuning based on task classification",
    "4. Integrate with instruction-tuned models for better control",
    "5. Build real-time generation monitoring dashboard",
    "6. Experiment with mixture of experts for dynamic parameter selection",
]

for step in next_steps:
    print(f"  {step}")

# Allocated GPU memory in GB when CUDA is available, otherwise a CPU notice.
print(
    f"\n💾 Final VRAM usage: {torch.cuda.memory_allocated()/1e9:.2f}GB"
    if torch.cuda.is_available()
    else "\n💾 Running on CPU mode"
)
print("=" * 60)

In [None]:
# === FINAL SMOKE TEST ===
def final_acceptance_test():
    """Acceptance check: three sampling presets plus a metrics sanity check.

    Unlike an early-exit test, every check runs even after a failure, so the
    printed report is always complete.

    Returns:
        True only when every check passed.
    """
    print("🏁 Final Acceptance Test - Generation Strategies")
    print("=" * 55)

    test_prompt = "The future of technology"

    # 1. Each preset must produce a continuation longer than five characters.
    strategies = [
        {"name": "Conservative", "temperature": 0.3, "top_k": 20, "top_p": 0.8},
        {"name": "Balanced", "temperature": 0.7, "top_k": 50, "top_p": 0.9},
        {"name": "Creative", "temperature": 1.2, "top_k": 100, "top_p": 0.95},
    ]

    all_passed = True
    for strategy in strategies:
        # Everything except the display name is forwarded to the generator.
        sampling_args = {key: strategy[key] for key in strategy if key != "name"}
        try:
            texts = generate_with_strategies(
                prompt=test_prompt,
                max_length=25,
                num_return_sequences=1,
                **sampling_args,
            )
            assert len(texts[0]) > 5, f"Output too short for {strategy['name']}"
            print(f"✅ {strategy['name']} strategy: PASS")
        except Exception as err:
            print(f"❌ {strategy['name']} strategy: FAIL - {err}")
            all_passed = False

    # 2. Metric values must fall inside their valid ranges.
    try:
        scores = evaluate_generation_quality(
            ["AI will transform society", "Technology advances rapidly"]
        )
        assert 0 <= scores["distinct_1"] <= 1, "Invalid distinct-1 score"
        assert scores["avg_repetition"] >= 0, "Invalid repetition score"
        print("✅ Quality metrics: PASS")
    except Exception as err:
        print(f"❌ Quality metrics: FAIL - {err}")
        all_passed = False

    return all_passed


# Run final test
if final_acceptance_test():
    print("\n🎉 ALL ACCEPTANCE TESTS PASSED!")
    print("🎓 Ready to proceed to next notebook")
else:
    print("\n❌ Some tests failed - please review implementation")