# 阶段2：RNN/LSTM 可视化分析

本notebook用于可视化分析RNN/LSTM模型的训练过程和生成结果，包括：
- 训练曲线分析
- 生成文本示例
- 隐藏状态可视化
- 模型参数分析

In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.font_manager import FontProperties
import pickle
import json
from pathlib import Path
import sys

# 添加模块路径
sys.path.append('.')
from models.rnn import create_rnn_model
from models.lstm import SimpleLSTM
from utils.text_data import CharacterVocabulary, WordVocabulary
from generate import load_model_and_vocab, temperature_sampling

# Apply the base style FIRST: plt.style.use('default') resets rcParams to
# their default values, so calling it after the font setup (as was done
# before) silently discarded the font settings below.
plt.style.use('default')
sns.set_palette('husl')

# Fonts able to render the Chinese titles and labels used in the figures
plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
# Disable the Unicode minus glyph for negative tick labels
plt.rcParams['axes.unicode_minus'] = False

## 1. 训练历史可视化

分析训练过程中的损失和困惑度变化

In [None]:
def plot_training_history(history_path, save_path=None):
    """
    Plot the training-history curves stored in a JSON file.

    Draws a 2x2 grid: loss, perplexity, learning rate (log scale) and
    gradient norm. Only the 'train_losses' series is required; a panel whose
    data is absent from the history is hidden instead of being left as an
    empty set of axes.

    Args:
        history_path (str): Path to the JSON training-history file.
        save_path (str): Optional output path. When given, the figure is
            saved there (300 dpi) before being shown.

    Raises:
        KeyError: If the history has no 'train_losses' entry.
    """
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(history_path, 'r', encoding='utf-8') as f:
        history = json.load(f)

    # Look up the one mandatory series before creating a figure, so a
    # malformed history does not leave a stray empty figure behind.
    train_losses = history['train_losses']

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('RNN/LSTM 训练历史分析', fontsize=16, fontweight='bold')

    # Loss curves (validation curve is optional)
    loss_ax = axes[0, 0]
    loss_ax.plot(train_losses, label='训练损失', linewidth=2)
    if 'val_losses' in history:
        loss_ax.plot(history['val_losses'], label='验证损失', linewidth=2)
    loss_ax.set_title('损失曲线')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()
    loss_ax.grid(True, alpha=0.3)

    # Perplexity curves
    ppl_ax = axes[0, 1]
    if 'train_perplexities' in history:
        ppl_ax.plot(history['train_perplexities'], label='训练困惑度', linewidth=2)
        if 'val_perplexities' in history:
            ppl_ax.plot(history['val_perplexities'], label='验证困惑度', linewidth=2)
        ppl_ax.set_title('困惑度曲线')
        ppl_ax.set_xlabel('Epoch')
        ppl_ax.set_ylabel('Perplexity')
        ppl_ax.legend()
        ppl_ax.grid(True, alpha=0.3)
    else:
        ppl_ax.set_visible(False)

    # Learning-rate schedule, drawn on a log scale
    lr_ax = axes[1, 0]
    if 'learning_rates' in history:
        lr_ax.plot(history['learning_rates'], linewidth=2, color='orange')
        lr_ax.set_title('学习率变化')
        lr_ax.set_xlabel('Epoch')
        lr_ax.set_ylabel('Learning Rate')
        lr_ax.set_yscale('log')
        lr_ax.grid(True, alpha=0.3)
    else:
        lr_ax.set_visible(False)

    # Gradient norm
    grad_ax = axes[1, 1]
    if 'grad_norms' in history:
        grad_ax.plot(history['grad_norms'], linewidth=2, color='red')
        grad_ax.set_title('梯度范数')
        grad_ax.set_xlabel('Epoch')
        grad_ax.set_ylabel('Gradient Norm')
        grad_ax.grid(True, alpha=0.3)
    else:
        grad_ax.set_visible(False)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()

# 示例用法
# plot_training_history('checkpoints/training_history.json')

## 2. 文本生成示例对比

比较不同采样策略的生成效果

In [None]:
def compare_generation_strategies(checkpoint_path, vocab_path, start_text="The", max_length=200):
    """
    Generate text at several sampling temperatures and print each result.

    The model and vocabulary are loaded once via load_model_and_vocab, then
    temperature_sampling is called for every configured temperature. Note
    that the entry labelled as greedy decoding is sampled at temperature
    0.1 rather than with an argmax.

    Args:
        checkpoint_path (str): Path to the model checkpoint.
        vocab_path (str): Path to the vocabulary file.
        start_text (str): Prompt the generation starts from.
        max_length (int): Length passed to the sampler.

    Returns:
        dict: Maps each strategy label to the text generated for it.
    """
    model, vocab, device = load_model_and_vocab(checkpoint_path, vocab_path)

    # Strategy label -> sampling temperature, in display order
    temperature_by_strategy = {
        '贪心解码': 0.1,
        '温度采样 (T=0.5)': 0.5,
        '温度采样 (T=0.8)': 0.8,
        '温度采样 (T=1.2)': 1.2,
        '高温采样 (T=1.5)': 1.5,
    }

    print(f"起始文本: '{start_text}'")
    print("=" * 80)

    outputs = {}
    for label, temperature in temperature_by_strategy.items():
        print(f"\n📝 {label}:")
        print("-" * 60)

        sample = temperature_sampling(
            model, vocab, device, start_text, max_length,
            temperature=temperature
        )
        outputs[label] = sample
        print(sample)

    return outputs

# 示例用法
# results = compare_generation_strategies(
#     'checkpoints/best_lstm_char.pt',
#     'checkpoints/char_vocabulary.pkl',
#     start_text="Once upon a time"
# )

## 3. 隐藏状态可视化

使用 t-SNE 或 PCA 将模型的隐藏状态降维到二维并可视化

In [None]:
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

def extract_hidden_states(model, vocab, device, texts, max_seq_len=50):
    """
    Collect one hidden-state vector per input text.

    Each text is truncated to ``max_seq_len`` units (characters for a
    CharacterVocabulary, whitespace-separated words otherwise), converted to
    indices and run through the model without gradient tracking. Texts that
    yield fewer than 5 indices are skipped, so the result can have fewer
    rows than ``texts``.

    Args:
        model: Trained model, put into eval mode by this function.
        vocab: Vocabulary providing ``text_to_indices``.
        device: Device the input tensors are created on.
        texts (list): Input texts.
        max_seq_len (int): Maximum number of characters/words used per text.

    Returns:
        numpy.array: Array built from the collected hidden states.
    """
    model.eval()
    collected = []
    char_level = isinstance(vocab, CharacterVocabulary)

    with torch.no_grad():
        for text in texts:
            # Truncate, then map to vocabulary indices
            units = text[:max_seq_len] if char_level else text.split()[:max_seq_len]
            indices = vocab.text_to_indices(units)

            # Skip sequences that are too short
            if len(indices) < 5:
                continue

            # Batch of one sequence
            batch = torch.tensor(indices, dtype=torch.long, device=device).unsqueeze(0)

            if hasattr(model, 'get_hidden_states'):
                # Model-provided hook; its output shape is not visible here
                # NOTE(review): confirm it returns one fixed-size vector per call
                state = model.get_hidden_states(batch)
            else:
                # Fall back to running the embedding and recurrent layer by hand
                embedded = model.embedding(batch)
                if isinstance(model, SimpleLSTM):
                    _, (h_n, _) = model.lstm(embedded)
                else:
                    _, h_n = model.rnn(embedded)
                # Last layer, first (only) batch element
                state = h_n[-1, 0, :]

            collected.append(state.cpu().numpy())

    return np.array(collected)

def visualize_hidden_states(hidden_states, labels=None, method='tsne', save_path=None):
    """
    Project hidden states to 2-D and draw them as a scatter plot.

    With ``method='tsne'`` the data is first reduced with PCA when it has
    more than 50 features, then embedded with t-SNE. Any other value of
    ``method`` uses a plain 2-component PCA.

    Args:
        hidden_states (numpy.array): 2-D matrix, one row per sample.
        labels (list): Optional label per row; each distinct label gets its
            own colour and legend entry, in order of first appearance.
        method (str): Reduction method ('tsne' or 'pca').
        save_path (str): Optional path; when given the figure is saved there
            (300 dpi) before being shown.

    Raises:
        ValueError: If ``hidden_states`` is not a 2-D matrix with at least
            two rows, or if ``labels`` does not have one entry per row.
    """
    hidden_states = np.asarray(hidden_states)
    if hidden_states.ndim != 2 or hidden_states.shape[0] < 2:
        raise ValueError(
            "hidden_states must be a 2-D array with at least 2 samples, "
            f"got shape {hidden_states.shape}"
        )
    n_samples = hidden_states.shape[0]

    label_array = None
    if labels is not None:
        label_array = np.array(labels)
        if len(label_array) != n_samples:
            raise ValueError(
                f"got {len(label_array)} labels for {n_samples} hidden states"
            )

    if method == 'tsne':
        # Pre-reduce wide inputs with PCA before t-SNE. PCA cannot produce
        # more components than there are samples, so cap the target size
        # (a fixed 50 fails whenever fewer than 50 samples are given).
        if hidden_states.shape[1] > 50:
            pca = PCA(n_components=min(50, n_samples))
            hidden_states = pca.fit_transform(hidden_states)

        # t-SNE needs perplexity < n_samples
        reducer = TSNE(n_components=2, random_state=42, perplexity=min(30, n_samples - 1))
        embedded = reducer.fit_transform(hidden_states)
        title = 't-SNE 隐藏状态可视化'
    else:
        reducer = PCA(n_components=2)
        embedded = reducer.fit_transform(hidden_states)
        title = 'PCA 隐藏状态可视化'

    plt.figure(figsize=(12, 8))

    if label_array is not None:
        # dict.fromkeys keeps first-appearance order, so colours and legend
        # entries are the same on every run (set() ordering is arbitrary).
        unique_labels = list(dict.fromkeys(labels))
        colors = plt.cm.tab10(np.linspace(0, 1, len(unique_labels)))

        for i, label in enumerate(unique_labels):
            mask = label_array == label
            plt.scatter(embedded[mask, 0], embedded[mask, 1],
                        c=[colors[i]], label=label, alpha=0.7, s=50)

        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    else:
        plt.scatter(embedded[:, 0], embedded[:, 1], alpha=0.7, s=50)

    plt.title(title, fontsize=14, fontweight='bold')
    plt.xlabel('Component 1')
    plt.ylabel('Component 2')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()

# 示例用法
# texts = ["Hello world", "Good morning", "How are you", ...]
# hidden_states = extract_hidden_states(model, vocab, device, texts)
# visualize_hidden_states(hidden_states, method='tsne')

## 4. 模型参数分析

分析模型权重的分布与统计信息（各层参数数量、均值、标准差及取值范围）

In [None]:
def analyze_model_parameters(checkpoint_path):
    """
    Summarise and plot the weights stored in a model checkpoint.

    Loads the checkpoint on CPU, computes per-tensor statistics for every
    entry of its 'model_state_dict', draws a 2x2 overview figure and prints
    the total parameter count together with the number of entries.

    Args:
        checkpoint_path (str): Path to the model checkpoint.

    Returns:
        dict: Maps each state-dict entry name to a dict with the keys
        'shape', 'mean', 'std', 'min', 'max' and 'num_params'.
    """
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    weights = checkpoint['model_state_dict']

    # Per-tensor statistics, keyed by state-dict entry name
    param_stats = {}
    for name, tensor in weights.items():
        values = tensor.detach().numpy()
        param_stats[name] = {
            'shape': values.shape,
            'mean': np.mean(values),
            'std': np.std(values),
            'min': np.min(values),
            'max': np.max(values),
            'num_params': values.size
        }
    total_params = sum(entry['num_params'] for entry in param_stats.values())

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('模型参数分析', fontsize=16, fontweight='bold')
    (ax_counts, ax_means), (ax_stds, ax_range) = axes

    # Parameter count per layer: entries are grouped by the first component
    # of their dotted name.
    layer_param_counts = {}
    for name, entry in param_stats.items():
        layer = name.split('.')[0]
        layer_param_counts[layer] = layer_param_counts.get(layer, 0) + entry['num_params']

    ax_counts.bar(layer_param_counts.keys(), layer_param_counts.values())
    ax_counts.set_title('各层参数数量')
    ax_counts.set_ylabel('参数数量')
    ax_counts.tick_params(axis='x', rotation=45)

    # Histogram of the per-tensor means
    ax_means.hist([entry['mean'] for entry in param_stats.values()],
                  bins=20, alpha=0.7, edgecolor='black')
    ax_means.set_title('参数均值分布')
    ax_means.set_xlabel('参数均值')
    ax_means.set_ylabel('频次')

    # Histogram of the per-tensor standard deviations
    ax_stds.hist([entry['std'] for entry in param_stats.values()],
                 bins=20, alpha=0.7, edgecolor='black', color='orange')
    ax_stds.set_title('参数标准差分布')
    ax_stds.set_xlabel('参数标准差')
    ax_stds.set_ylabel('频次')

    # Value range of the first 10 entries only
    shown_names = list(param_stats)[:10]
    positions = np.arange(len(shown_names))
    upper = [param_stats[name]['max'] for name in shown_names]
    lower = [param_stats[name]['min'] for name in shown_names]

    ax_range.bar(positions, upper, alpha=0.7, label='最大值')
    ax_range.bar(positions, lower, alpha=0.7, label='最小值')
    ax_range.set_title('参数范围 (前10层)')
    ax_range.set_xlabel('层名')
    ax_range.set_ylabel('参数值')
    ax_range.set_xticks(positions)
    ax_range.set_xticklabels([name.split('.')[0] for name in shown_names], rotation=45)
    ax_range.legend()

    plt.tight_layout()
    plt.show()

    # Printed summary
    print(f"\n📊 模型参数总结:")
    print(f"总参数数量: {total_params:,}")
    print(f"层数: {len(param_stats)}")

    return param_stats

# 示例用法
# param_stats = analyze_model_parameters('checkpoints/best_lstm_char.pt')

## 5. 生成质量评估

评估生成文本的质量指标

In [None]:
import re
from collections import Counter

def calculate_text_metrics(text):
    """
    Compute simple surface-level quality metrics for a piece of text.

    Words are whitespace-separated tokens of the lower-cased text.
    Sentences are the non-blank pieces obtained by splitting on runs of
    '.', '!' or '?', while the sentence count is the number of such
    punctuation runs.

    Args:
        text (str): Generated text.

    Returns:
        dict: Character, word and sentence counts, vocabulary diversity
        (unique words / words), unique-word count, average word length,
        average sentence length (in words) and the 5 most common words.
        Both averages are 0 when there is nothing to average.
    """
    # Basic counts
    char_count = len(text)
    words = text.lower().split()
    word_count = len(words)
    sentence_count = len(re.findall(r'[.!?]+', text))

    # Vocabulary diversity; max() avoids dividing by zero for empty text
    unique_words = len(set(words))
    vocabulary_diversity = unique_words / max(word_count, 1)

    # Repetition: the most frequent words
    most_common = Counter(words).most_common(5)

    # Average word length
    avg_word_length = np.mean([len(word) for word in words]) if words else 0

    # Average sentence length. re.split never returns an empty list, so the
    # emptiness check has to be made AFTER blank pieces are dropped;
    # otherwise text without any sentence content yields np.mean([]) == nan.
    sentence_lengths = [
        len(sentence.split())
        for sentence in re.split(r'[.!?]+', text)
        if sentence.strip()
    ]
    avg_sentence_length = np.mean(sentence_lengths) if sentence_lengths else 0

    return {
        '字符数': char_count,
        '词数': word_count,
        '句数': sentence_count,
        '词汇多样性': vocabulary_diversity,
        '独特词数': unique_words,
        '平均词长': avg_word_length,
        '平均句长': avg_sentence_length,
        '最常见词': most_common
    }

def compare_generation_quality(generated_texts):
    """
    Compare text-quality metrics across generation strategies.

    Computes calculate_text_metrics for every text, draws one bar chart per
    metric in a 2x3 grid and prints the metrics for each strategy.

    Args:
        generated_texts (dict): Maps a strategy name to its generated text.

    Returns:
        dict: Maps each strategy name to its metrics dict.
    """
    metrics_comparison = {
        strategy: calculate_text_metrics(text)
        for strategy, text in generated_texts.items()
    }

    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    fig.suptitle('生成文本质量对比', fontsize=16, fontweight='bold')

    strategies = list(metrics_comparison.keys())

    # One row per panel, in row-major grid order:
    # (metric key, panel title, y-axis label, extra bar() keyword arguments)
    panels = [
        ('词汇多样性', '词汇多样性', '多样性得分', {}),
        ('平均词长', '平均词长', '字符数', {'color': 'orange'}),
        ('平均句长', '平均句长', '词数', {'color': 'green'}),
        ('词数', '总词数', '词数', {'color': 'red'}),
        ('独特词数', '独特词数', '词数', {'color': 'purple'}),
        ('句数', '句数', '句数', {'color': 'brown'}),
    ]
    for ax, (metric_key, title, ylabel, bar_kwargs) in zip(axes.flat, panels):
        heights = [metrics_comparison[s][metric_key] for s in strategies]
        ax.bar(strategies, heights, **bar_kwargs)
        ax.set_title(title)
        ax.set_ylabel(ylabel)
        ax.tick_params(axis='x', rotation=45)

    plt.tight_layout()
    plt.show()

    # Detailed textual comparison
    print("\n📊 详细质量指标对比:")
    print("=" * 80)
    for strategy, metrics in metrics_comparison.items():
        print(f"\n🎯 {strategy}:")
        for metric, value in metrics.items():
            # The most-common-words list is printed separately below
            if metric == '最常见词':
                continue
            shown = f"{value:.3f}" if isinstance(value, float) else value
            print(f"  {metric}: {shown}")
        print(f"  最常见词: {metrics['最常见词'][:3]}")

    return metrics_comparison

# 示例用法
# generated_texts = {
#     '贪心': "This is a sample text generated by greedy decoding.",
#     '温度采样': "This is another sample text with more diversity and creativity."
# }
# quality_metrics = compare_generation_quality(generated_texts)

## 6. 使用示例

以下是完整的分析流程示例

In [None]:
# Locations of the training artifacts
checkpoint_path = 'checkpoints/best_lstm_char.pt'
vocab_path = 'checkpoints/char_vocabulary.pkl'
history_path = 'checkpoints/training_history.json'

from pathlib import Path

# Nothing can be analysed without a trained checkpoint
if not Path(checkpoint_path).exists():
    print("❌ 未找到模型文件，请先训练模型")
    print("运行: python train.py --model_type lstm --vocab_type char --epochs 10")
else:
    print("✅ 找到模型文件，开始分析...")

    # 1. Training history (only when a history file was saved)
    has_history = Path(history_path).exists()
    if has_history:
        plot_training_history(history_path)

    # 2. + 3. Generation comparison and quality metrics need the vocabulary
    has_vocab = Path(vocab_path).exists()
    if has_vocab:
        results = compare_generation_strategies(
            checkpoint_path,
            vocab_path,
            start_text="The quick brown fox",
            max_length=150,
        )
        quality_metrics = compare_generation_quality(results)

    # 4. Parameter analysis only needs the checkpoint itself
    param_stats = analyze_model_parameters(checkpoint_path)