In [None]:
# 导入必要的库
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import json
import time
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Dict, Tuple, Optional
from datetime import datetime
import logging
from tqdm import tqdm
import pandas as pd

# Matplotlib: prefer SimHei so Chinese labels render, fall back to DejaVu Sans
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'DejaVu Sans'],
    'axes.unicode_minus': False,  # avoid broken minus glyphs with CJK fonts
})

# INFO-level module logger, used by the helper classes defined later
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

print("📚 环境设置完成")


In [None]:
# 导入项目模块
from config import model_config, dialog_config
from models import model_manager
from embedding_compressor import embedding_compressor
from direct_embedding_compressor import direct_compressor

# Reaching this line means every project import above succeeded
print("🔧 项目模块导入完成", end="\n")


In [None]:
# Load the Qwen2.5-0.5B instruct checkpoint as the dialog model
print("🚀 初始化Qwen2.5-0.5B模型...")

# Pin the config to the 0.5B checkpoint before loading
model_config.model_name = "Qwen/Qwen2.5-0.5B-Instruct"

# load_models() is treated as returning a truthy value on success
success = model_manager.load_models()

if not success:
    print("❌ 模型加载失败")
else:
    print(f"✅ 模型加载完成: {model_config.model_name}")
    print("📊 模型参数量: ~0.5B")
    print(f"💾 设备: {model_manager.device}")


In [None]:
class EmbeddingDialogProcessor:
    """Answer a query with the dialog model, conditioned on a history embedding.

    "Injection" happens at the prompt level only. Each strategy builds a text
    prompt (optionally influenced by the history embedding) and calls
    ``model_manager.generate_text``. No vector is fed into the model's
    embedding layer.
    """

    # Fallback hidden size (Qwen2.5-0.5B) used when the model config is unavailable.
    DEFAULT_EMBEDDING_DIM = 896
    # Generation budget shared by every injection strategy.
    max_new_tokens = 512
    # Cosine similarity above which attention fusion adds a "use history" hint.
    similarity_threshold = 0.5

    def __init__(self, model_manager):
        self.model_manager = model_manager
        # Prefer the loaded model's hidden size (assumes an HF-style
        # ``config.hidden_size``). Fall back to 896 when the model is not
        # loaded yet or exposes no such field.
        cfg = getattr(getattr(model_manager, "dialog_model", None), "config", None)
        hidden_size = getattr(cfg, "hidden_size", None)
        self.embedding_dim = (
            hidden_size if isinstance(hidden_size, int) else self.DEFAULT_EMBEDDING_DIM
        )

    def extract_text_embedding(self, text: str, layer_idx: int = -1) -> torch.Tensor:
        """Mean-pool one hidden layer of the dialog model over ``text``'s tokens.

        Args:
            text: Input text; the tokenizer truncates it to 512 tokens.
            layer_idx: Index into ``outputs.hidden_states`` (-1 = last entry).

        Returns:
            A 1-D float32 CPU tensor of the model's hidden size. A float32 zero
            vector of length ``self.embedding_dim`` is returned when the model
            or tokenizer is not loaded, ``text`` is empty, or extraction fails
            (the error is logged).
        """
        if not self.model_manager.dialog_model or not self.model_manager.tokenizer:
            return torch.zeros(self.embedding_dim)
        if not text:
            # An empty prompt may tokenize to zero tokens; pooling would divide by zero.
            return torch.zeros(self.embedding_dim)

        try:
            inputs = self.model_manager.tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=512,
                padding=True
            ).to(self.model_manager.device)

            with torch.no_grad():
                outputs = self.model_manager.dialog_model(
                    **inputs,
                    output_hidden_states=True
                )
                # [batch, seq_len, hidden_dim]; upcast so fp16/bf16 models yield
                # float32, matching the zero-vector fallback and CPU-side math.
                hidden_states = outputs.hidden_states[layer_idx].float()

                attention_mask = inputs.get("attention_mask")
                if attention_mask is None:
                    embedding = hidden_states.mean(dim=1)
                else:
                    # Masked mean: padding positions must not dilute the average.
                    mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
                    embedding = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

            return embedding.squeeze(0).cpu()  # [hidden_dim]

        except Exception as e:
            logger.error(f"提取embedding失败: {e}")
            return torch.zeros(self.embedding_dim)

    def inject_embedding_into_model(self,
                                  history_embedding: torch.Tensor,
                                  query_text: str,
                                  injection_method: str = "prefix") -> str:
        """Generate a reply to ``query_text`` using one history-injection strategy.

        Args:
            history_embedding: Vector summarising the dialog history.
            query_text: The new user message.
            injection_method: ``"prefix"``, ``"interpolation"`` or ``"attention_fusion"``.

        Returns:
            The generated reply, or ``""`` if the strategy failed (error is logged).

        Raises:
            ValueError: for an unknown ``injection_method``.
        """
        strategies = {
            "prefix": self._prefix_injection,
            "interpolation": self._interpolation_injection,
            "attention_fusion": self._attention_fusion_injection,
        }
        if injection_method not in strategies:
            raise ValueError(f"未知的注入方法: {injection_method}")
        return strategies[injection_method](history_embedding, query_text)

    def _generate(self, prompt: str) -> str:
        """Run the dialog model on ``prompt`` with the shared generation budget."""
        return self.model_manager.generate_text(
            model=self.model_manager.dialog_model,
            prompt=prompt,
            max_new_tokens=self.max_new_tokens
        )

    def _prefix_injection(self, history_embedding: torch.Tensor, query_text: str) -> str:
        """Prefix the prompt with ``<E..>`` markers derived from the embedding's top dims.

        NOTE(review): unless the tokenizer registers ``<HIST_EMB>``/``<E..>`` as
        special tokens, the model sees them as ordinary text.
        """
        try:
            embedding_tokens = self._embedding_to_tokens(history_embedding)
            prefixed_prompt = f"<HIST_EMB>{embedding_tokens}</HIST_EMB>\n\n用户: {query_text}\n助手:"
            return self._generate(prefixed_prompt)

        except Exception as e:
            logger.error(f"前缀注入失败: {e}")
            return ""

    def _interpolation_injection(self, history_embedding: torch.Tensor, query_text: str) -> str:
        """Blend history and query embeddings (history weight 0.3), then generate.

        NOTE(review): ``mixed_embedding`` is computed but never fed to
        generation. The prompt, and so the reply, depends only on
        ``query_text``. TODO: wire it in (e.g. via ``inputs_embeds``) or drop it.
        """
        try:
            query_embedding = self.extract_text_embedding(query_text)

            alpha = 0.3  # weight of the history embedding
            mixed_embedding = alpha * history_embedding + (1 - alpha) * query_embedding

            mixed_prompt = f"基于历史上下文的查询: {query_text}\n助手:"
            return self._generate(mixed_prompt)

        except Exception as e:
            logger.error(f"插值注入失败: {e}")
            return ""

    def _attention_fusion_injection(self, history_embedding: torch.Tensor, query_text: str) -> str:
        """Pick a prompt template based on history/query cosine similarity.

        If similarity exceeds ``similarity_threshold`` the prompt asks the model
        to reference history; otherwise a plain user/assistant prompt is used.
        """
        try:
            query_embedding = self.extract_text_embedding(query_text)

            # Cast both sides to float32 on CPU so mixed dtypes/devices cannot clash.
            attention_score = F.cosine_similarity(
                history_embedding.float().cpu().unsqueeze(0),
                query_embedding.float().cpu().unsqueeze(0)
            ).item()

            if attention_score > self.similarity_threshold:
                prompt = f"参考相关历史信息: {query_text}\n助手:"
            else:
                prompt = f"用户: {query_text}\n助手:"

            return self._generate(prompt)

        except Exception as e:
            logger.error(f"注意力融合注入失败: {e}")
            return ""

    def _embedding_to_tokens(self, embedding: torch.Tensor, num_tokens: int = 10) -> str:
        """Encode the ``num_tokens`` largest dimensions as ``<Ekk>`` markers.

        ``kk`` is the dimension index modulo 100, so distinct dimensions can
        collide. ``num_tokens`` is clamped to the vector length. An empty
        vector or non-positive ``num_tokens`` yields ``""``.
        """
        flat = embedding.reshape(-1).float()
        k = min(num_tokens, flat.numel())
        if k <= 0:
            return ""
        _, top_indices = torch.topk(flat, k)
        return " ".join(f"<E{index % 100:02d}>" for index in top_indices.tolist())

# Single shared processor bound to the already-loaded model manager
embedding_processor = EmbeddingDialogProcessor(model_manager=model_manager)
print("🔧 Embedding处理器创建完成")


In [None]:
# 创建超长多轮对话测试数据
def create_long_dialog_data():
    """Build the fixed multi-turn dialog fixtures used by the experiments.

    Returns:
        list of dicts, one per dialog, each with:
          - "title": short label for the dialog (topic),
          - "history": list of {"role": "user" | "assistant", "content": str}
            messages, alternating user/assistant, 10 messages per dialog here,
          - "test_query": the follow-up user question to answer given the history.
    """
    
    dialogs = [
        # Dialog 1: Python programming tutoring
        {
            "title": "技术讨论 - Python编程",
            "history": [
                {"role": "user", "content": "你好，我想学习Python编程，应该从哪里开始？"},
                {"role": "assistant", "content": "你好！学习Python是个很好的选择。建议从基础语法开始，包括变量、数据类型、控制结构等。可以先安装Python环境，然后练习简单的程序。"},
                {"role": "user", "content": "我已经安装了Python 3.9，现在想了解数据结构，比如列表和字典。"},
                {"role": "assistant", "content": "很好！列表(list)是有序的可变序列，用[]表示，如[1,2,3]。字典(dict)是键值对的集合，用{}表示，如{'name':'Tom', 'age':25}。列表适合存储同类数据，字典适合存储结构化数据。"},
                {"role": "user", "content": "那么如何遍历这些数据结构呢？for循环是怎么工作的？"},
                {"role": "assistant", "content": "遍历列表可以用 for item in list:，遍历字典可以用 for key in dict: 或 for key, value in dict.items():。for循环会自动迭代容器中的每个元素，这是Python的强大特性之一。"},
                {"role": "user", "content": "我想写一个程序来处理文件，比如读取文本文件并统计单词数量。"},
                {"role": "assistant", "content": "文件处理是很实用的技能！可以用open()函数读取文件，建议使用 with open('file.txt', 'r') as f: 的方式，这样会自动关闭文件。统计单词可以用split()分割文本，然后计算列表长度。"},
                {"role": "user", "content": "如果我想要更复杂的文本分析，比如统计每个单词的出现频率怎么办？"},
                {"role": "assistant", "content": "可以使用字典来统计频率！遍历所有单词，如果单词已在字典中则计数+1，否则设为1。也可以使用collections.Counter类，它专门用于计数统计，使用起来更简便。"},
            ],
            "test_query": "现在我想学习面向对象编程，类和对象的概念是什么？"
        },
        # Dialog 2: climate change Q&A
        {
            "title": "科学讨论 - 气候变化",
            "history": [
                {"role": "user", "content": "最近总听说气候变化，这到底是什么意思？"},
                {"role": "assistant", "content": "气候变化指地球气候系统长期的变化趋势，主要表现为全球平均温度上升、极端天气频发、海平面上升等。主要原因是人类活动导致的温室气体排放增加。"},
                {"role": "user", "content": "温室气体都有哪些？它们是如何影响气候的？"},
                {"role": "assistant", "content": "主要的温室气体包括二氧化碳(CO2)、甲烷(CH4)、氮氧化物(N2O)等。它们能吸收地球表面发出的长波辐射，形成温室效应，导致大气温度升高。"},
                {"role": "user", "content": "那我们个人能做些什么来减少温室气体排放呢？"},
                {"role": "assistant", "content": "个人可以从多方面减排：减少开车、选择公共交通；节约用电；减少肉类消费；支持可再生能源；垃圾分类回收；购买环保产品等。虽然个人作用有限，但集体行动能产生巨大影响。"},
                {"role": "user", "content": "我听说森林砍伐也与气候变化有关，这是为什么？"},
                {"role": "assistant", "content": "森林是重要的碳汇，树木通过光合作用吸收CO2并储存碳。砍伐森林不仅减少了碳吸收能力，燃烧或腐烂的木材还会释放储存的碳。亚马逊雨林被称为'地球之肺'就是这个原因。"},
                {"role": "user", "content": "可再生能源有哪些类型？它们各有什么优缺点？"},
                {"role": "assistant", "content": "主要类型有太阳能、风能、水能、地热能、生物能等。太阳能和风能技术成熟、成本下降快，但有间歇性问题。水能稳定可靠，但受地理条件限制。地热能稳定但分布不均。生物能可再生但可能与粮食竞争土地。"},
            ],
            "test_query": "核能作为清洁能源的优缺点是什么？它在应对气候变化中的作用如何？"
        },
        # Dialog 3: health, nutrition and exercise advice
        {
            "title": "健康咨询 - 营养与运动", 
            "history": [
                {"role": "user", "content": "我想要开始健康的生活方式，但不知道从哪里开始。"},
                {"role": "assistant", "content": "很好的决定！健康生活方式主要包括均衡饮食、规律运动、充足睡眠和压力管理。建议先从小的改变开始，比如每天多喝水、增加蔬菜摄入、每天散步30分钟。"},
                {"role": "user", "content": "关于饮食，我应该如何安排三餐？有什么营养搭配的原则吗？"},
                {"role": "assistant", "content": "三餐应该均衡搭配：早餐要丰富(蛋白质+碳水+维生素)，午餐要充足，晚餐要清淡。遵循'彩虹饮食'原则，多吃不同颜色的蔬果。控制糖分和加工食品的摄入。"},
                {"role": "user", "content": "我之前很少运动，应该选择什么类型的运动比较好？"},
                {"role": "assistant", "content": "初学者建议从低强度有氧运动开始，如快走、游泳、骑车。每周3-4次，每次30-45分钟。可以逐渐加入力量训练来增强肌肉。最重要的是选择自己喜欢的运动，这样更容易坚持。"},
                {"role": "user", "content": "睡眠质量对健康有多重要？我经常失眠该怎么办？"},
                {"role": "assistant", "content": "睡眠极其重要！充足睡眠有助于身体修复、记忆整合、免疫力维持。成人需要7-9小时睡眠。改善失眠可以：保持规律作息、睡前1小时避免屏幕、创造舒适环境、避免咖啡因、尝试放松技巧如冥想。"},
                {"role": "user", "content": "压力管理方面，有什么有效的方法可以推荐吗？"},
                {"role": "assistant", "content": "压力管理方法很多：深呼吸练习、冥想、瑜伽、适量运动、与朋友聊天、培养爱好、合理安排时间、学会说'不'。关键是找到适合自己的方式，定期练习。必要时可以寻求专业心理帮助。"},
            ],
            "test_query": "我想了解更多关于营养补充剂的信息，比如维生素和蛋白粉，这些真的有必要吗？"
        }
    ]
    
    return dialogs

# Materialise the fixture dialogs and print a one-line summary of each
test_dialogs = create_long_dialog_data()
print(f"📋 创建了 {len(test_dialogs)} 个测试对话")
for i, dialog in enumerate(test_dialogs):
    title, turn_count = dialog['title'], len(dialog['history'])
    print(f"  {i+1}. {title} - {turn_count} 轮对话")


In [None]:
class ExperimentRunner:
    """Compare context strategies (recent history, none, embedding injection) on one query.

    Each ``run_*`` method returns a flat dict of measurements for one variant.
    ``run_comparative_experiment`` runs all of them for a dialog and attaches
    heuristic quality scores.
    """

    # Generation budget shared by every experiment variant.
    max_new_tokens = 512

    def __init__(self, embedding_processor, model_manager):
        self.embedding_processor = embedding_processor
        self.model_manager = model_manager
        # Kept for backward compatibility; nothing in this class appends to it.
        self.results = []

    @staticmethod
    def _format_history(dialog_history: List[Dict]) -> str:
        """Render messages as one ``"用户: ..."`` / ``"助手: ..."`` line each.

        Any role other than ``'user'`` is rendered as the assistant.
        """
        return "".join(
            f"{'用户' if turn['role'] == 'user' else '助手'}: {turn['content']}\n"
            for turn in dialog_history
        )

    def _generate(self, prompt: str) -> str:
        """Generate a reply to ``prompt`` with the dialog model."""
        return self.model_manager.generate_text(
            model=self.model_manager.dialog_model,
            prompt=prompt,
            max_new_tokens=self.max_new_tokens
        )

    def run_baseline_experiment(self, dialog_history: List[Dict], query: str,
                                max_history_messages: Optional[int] = 6) -> Dict:
        """Baseline: prepend raw recent history text to the query.

        Despite the ``'baseline_full_context'`` label, only the trailing
        ``max_history_messages`` messages are included.

        Args:
            dialog_history: ``{'role', 'content'}`` messages, most recent last.
            query: The new user message.
            max_history_messages: How many trailing messages to keep (default 6,
                i.e. three user/assistant exchanges). ``None`` keeps all of
                them; ``0`` or less keeps none.

        Returns:
            dict with method, response, latency (seconds), input/output token
            counts and ``context_length`` (characters of history included).
        """
        start_time = time.perf_counter()

        if max_history_messages is None:
            recent_turns = dialog_history
        elif max_history_messages <= 0:
            # dialog_history[-0:] would return everything, so handle 0 explicitly
            recent_turns = []
        else:
            recent_turns = dialog_history[-max_history_messages:]

        full_context = self._format_history(recent_turns)
        full_prompt = f"{full_context}用户: {query}\n助手:"
        response = self._generate(full_prompt)

        latency = time.perf_counter() - start_time

        return {
            'method': 'baseline_full_context',
            'response': response,
            'latency': latency,
            'input_tokens': self.model_manager.count_tokens(full_prompt),
            'output_tokens': self.model_manager.count_tokens(response),
            'context_length': len(full_context)
        }

    def run_embedding_experiment(self, dialog_history: List[Dict], query: str,
                               injection_method: str = "prefix") -> Dict:
        """Compress the full history into one embedding, then answer via injection.

        Latency covers embedding extraction plus generation.
        """
        start_time = time.perf_counter()

        history_text = self._format_history(dialog_history)
        history_embedding = self.embedding_processor.extract_text_embedding(history_text)
        response = self.embedding_processor.inject_embedding_into_model(
            history_embedding, query, injection_method
        )

        latency = time.perf_counter() - start_time
        embedding_dim = history_embedding.shape[0]

        return {
            'method': f'embedding_{injection_method}',
            'response': response,
            'latency': latency,
            # NOTE: counts only the raw query. The real prompt also carries a
            # template (and, for 'prefix', the <HIST_EMB> markers), so this undercounts.
            'input_tokens': self.model_manager.count_tokens(query),
            'output_tokens': self.model_manager.count_tokens(response),
            'embedding_dim': embedding_dim,
            # history characters per embedding dimension (mixed units; rough ratio only)
            'compression_ratio': len(history_text) / embedding_dim
        }

    def run_no_context_experiment(self, query: str) -> Dict:
        """Answer the query with no history at all (lower bound)."""
        start_time = time.perf_counter()

        prompt = f"用户: {query}\n助手:"
        response = self._generate(prompt)

        latency = time.perf_counter() - start_time

        return {
            'method': 'no_context',
            'response': response,
            'latency': latency,
            'input_tokens': self.model_manager.count_tokens(prompt),
            'output_tokens': self.model_manager.count_tokens(response)
        }

    def evaluate_response_quality(self, response: str, query: str, context: str) -> Dict:
        """Cheap heuristic scores, each in [0, 1].

        ``query`` and ``context`` are currently unused.

        - relevance: response length / 100, capped at 1 (a length proxy only).
        - coherence: 1 minus 0.1 per "。。" occurrence, floored at 0.
        - informativeness: distinct whitespace-separated tokens / 50, capped at 1.
          NOTE: Chinese text has few spaces, so this stays near 0 for CJK replies.
        - overall_score: mean of the three scores as reported (after clamping).
        """
        relevance_score = min(1.0, len(response) / 100)
        coherence_score = max(0.0, 1.0 - response.count("。。") * 0.1)
        informativeness = min(1.0, len(set(response.split())) / 50)

        return {
            'relevance': relevance_score,
            'coherence': coherence_score,
            'informativeness': informativeness,
            # Uses the clamped coherence so overall_score agrees with the reported parts.
            'overall_score': (relevance_score + coherence_score + informativeness) / 3
        }

    def _with_quality(self, result: Dict, query: str, context: str) -> Dict:
        """Merge quality scores for ``result['response']`` into ``result`` and return it."""
        result.update(self.evaluate_response_quality(result['response'], query, context))
        return result

    def run_comparative_experiment(self, test_dialog: Dict) -> Dict:
        """Run baseline, no-context and all three embedding variants on one dialog.

        Args:
            test_dialog: dict with ``'history'``, ``'test_query'`` and ``'title'``.

        Returns:
            ``{'dialog_title', 'test_query', 'results'}``, where ``results`` maps
            variant name to its measurement dict plus quality scores.
        """
        history = test_dialog['history']
        query = test_dialog['test_query']
        title = test_dialog['title']

        print(f"\n🧪 运行实验: {title}")
        print(f"   历史轮数: {len(history)}")
        print(f"   测试查询: {query[:50]}...")

        results = {}

        print("   🔵 运行基线实验...")
        results['baseline'] = self._with_quality(
            self.run_baseline_experiment(history, query), query, str(history)
        )

        print("   ⚪ 运行无上下文实验...")
        results['no_context'] = self._with_quality(
            self.run_no_context_experiment(query), query, ""
        )

        for method in ['prefix', 'interpolation', 'attention_fusion']:
            print(f"   🟡 运行embedding实验 ({method})...")
            results[f'embedding_{method}'] = self._with_quality(
                self.run_embedding_experiment(history, query, method), query, "embedding_context"
            )

        return {
            'dialog_title': title,
            'test_query': query,
            'results': results
        }

# Wire the processor and model manager into a single experiment driver
experiment_runner = ExperimentRunner(
    embedding_processor=embedding_processor,
    model_manager=model_manager,
)
print("🔬 实验运行器创建完成")


In [None]:
# Run every comparative experiment. A failure in one dialog is reported and
# logged with its traceback, then skipped so the remaining dialogs still run.
print("🚀 开始运行所有实验...")

all_experiment_results = []

for i, dialog in enumerate(test_dialogs):
    print(f"\n{'='*60}")
    print(f"实验 {i+1}/{len(test_dialogs)}: {dialog['title']}")

    try:
        result = experiment_runner.run_comparative_experiment(dialog)
    except Exception as e:
        print(f"❌ 实验失败: {e}")
        # Keep the full traceback; the one-line message alone hides where it failed.
        logger.exception("实验失败: %s", dialog['title'])
        continue

    all_experiment_results.append(result)
    print("✅ 实验完成")

print(f"\n🎉 所有实验完成！共完成 {len(all_experiment_results)} 个实验")


In [None]:
# 分析实验结果
def analyze_experiment_results(results, methods: Optional[List[str]] = None):
    """Average per-method metrics across experiments.

    Args:
        results: List of experiment dicts, each with a ``'results'`` mapping of
            method name -> metrics dict (latency, input_tokens, output_tokens,
            relevance, coherence, informativeness, overall_score).
        methods: Methods to summarise, in output order. ``None`` (default)
            uses every method found in ``results``, in first-seen order.

    Returns:
        dict of method -> ``avg_<metric>`` floats plus ``sample_size``. Methods
        with no data are omitted; an empty ``results`` gives ``{}``.
    """
    if methods is None:
        # dict.fromkeys de-duplicates while preserving first-seen order
        methods = list(dict.fromkeys(
            method for experiment in results for method in experiment['results']
        ))

    metrics_data = {method: [] for method in methods}
    for experiment in results:
        for method in methods:
            if method in experiment['results']:
                metrics_data[method].append(experiment['results'][method])

    metric_fields = ('latency', 'input_tokens', 'output_tokens', 'relevance',
                     'coherence', 'informativeness', 'overall_score')

    summary_stats = {}
    for method, data in metrics_data.items():
        if not data:
            continue
        # plain floats keep printing/JSON output free of numpy scalar reprs
        stats = {f'avg_{field}': float(np.mean([d[field] for d in data])) for field in metric_fields}
        stats['sample_size'] = len(data)
        summary_stats[method] = stats

    return summary_stats

# Summarise per-dialog results into per-method averages and print a report
if not all_experiment_results:
    print("⚠️ 没有实验结果可分析")
else:
    summary_stats = analyze_experiment_results(all_experiment_results)

    print("📊 实验结果统计摘要:")
    print("="*80)

    # (label, stats key, format spec) for each line of the per-method report
    report_fields = [
        ("平均延迟", 'avg_latency', "{:.3f}s"),
        ("平均输入tokens", 'avg_input_tokens', "{:.1f}"),
        ("平均输出tokens", 'avg_output_tokens', "{:.1f}"),
        ("相关性得分", 'avg_relevance', "{:.3f}"),
        ("连贯性得分", 'avg_coherence', "{:.3f}"),
        ("信息量得分", 'avg_informativeness', "{:.3f}"),
        ("综合得分", 'avg_overall_score', "{:.3f}"),
        ("样本数量", 'sample_size', "{}"),
    ]

    for method, stats in summary_stats.items():
        print(f"\n🔹 {method.upper()}:")
        for label, key, spec in report_fields:
            print(f"   {label}: {spec.format(stats[key])}")


In [None]:
# 可视化实验结果
def create_comparison_plots(summary_stats):
    """Draw a 2x2 figure comparing the methods in ``summary_stats``.

    Panels: average latency, average input tokens, the three quality metrics
    (grouped bars), and the average overall score.

    Args:
        summary_stats: Mapping method name -> stats dict as produced by
            ``analyze_experiment_results`` (reads avg_latency, avg_input_tokens,
            avg_relevance, avg_coherence, avg_informativeness, avg_overall_score).

    Returns:
        The matplotlib Figure, or None when ``summary_stats`` is empty.
    """
    if not summary_stats:
        print("⚠️ 没有数据可以可视化")
        return None

    methods = list(summary_stats.keys())
    x = np.arange(len(methods))
    # One colour per method; cycles when there are more than five methods.
    palette = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']
    colors = [palette[i % len(palette)] for i in range(len(methods))]

    def values(key):
        """Collect ``key`` from every method's stats, in ``methods`` order."""
        return [summary_stats[method][key] for method in methods]

    def label_methods(ax, tick_positions):
        """Put rotated method names under the given x positions."""
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(methods, rotation=45, ha='right')

    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    fig.suptitle('Embedding压缩技术实验结果对比', fontsize=16, fontweight='bold')

    # 1. Latency
    axes[0, 0].bar(x, values('avg_latency'), color=colors)
    axes[0, 0].set_title('平均延迟对比')
    axes[0, 0].set_ylabel('延迟 (秒)')
    label_methods(axes[0, 0], x)

    # 2. Input token usage
    axes[0, 1].bar(x, values('avg_input_tokens'), color=colors)
    axes[0, 1].set_title('平均输入Token数量')
    axes[0, 1].set_ylabel('Token数量')
    label_methods(axes[0, 1], x)

    # 3. Quality metrics, grouped per method
    width = 0.25
    quality_metrics = [('avg_relevance', '相关性'),
                       ('avg_coherence', '连贯性'),
                       ('avg_informativeness', '信息量')]
    for offset, (metric, label) in enumerate(quality_metrics):
        axes[1, 0].bar(x + offset * width, values(metric), width, label=label)
    axes[1, 0].set_title('质量指标对比')
    axes[1, 0].set_ylabel('得分')
    label_methods(axes[1, 0], x + width)  # centre ticks on the middle bar of each group
    axes[1, 0].legend()

    # 4. Overall score as bars. A pie chart would show each method's share of
    # the summed scores (meaningless for independent scores), and matplotlib's
    # pie rejects negative values.
    axes[1, 1].bar(x, values('avg_overall_score'), color=colors)
    axes[1, 1].set_title('综合得分对比')
    axes[1, 1].set_ylabel('得分')
    label_methods(axes[1, 1], x)

    plt.tight_layout()
    plt.show()
    return fig

# Render the comparison dashboard only when there is summarised data to show
if not (all_experiment_results and summary_stats):
    print("⚠️ 无法生成图表，缺少实验数据")
else:
    print("📈 生成可视化图表...")
    create_comparison_plots(summary_stats)


In [None]:
# 保存实验结果和展示样例回复
def _json_default(obj):
    """Fallback converter for ``json.dump`` when a value is not natively serializable.

    Scalars exposing ``.item()`` (numpy / torch) are unwrapped. Multi-element
    arrays/tensors raise on ``.item()`` and are converted with ``.tolist()``
    instead. Anything else is stringified, so one odd value cannot abort saving.
    """
    item = getattr(obj, "item", None)
    if callable(item):
        try:
            return item()
        except (ValueError, RuntimeError):
            pass
    tolist = getattr(obj, "tolist", None)
    if callable(tolist):
        return tolist()
    return str(obj)


def save_and_display_results(results, filename="embedding_experiment_results.json",
                             preview_chars: int = 200):
    """Save experiment results as JSON and print the first experiment's replies.

    Args:
        results: List of experiment dicts (``dialog_title``, ``test_query``,
            ``results`` -> per-method dicts with latency, input_tokens,
            overall_score and response).
        filename: Output JSON path (overwritten).
        preview_chars: Maximum reply characters shown per method; "..." is
            appended only when a reply is actually truncated.
    """
    if not results:
        print("⚠️ 没有结果可保存")
        return

    with open(filename, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2, default=_json_default)

    print(f"💾 实验结果已保存到: {filename}")

    first_experiment = results[0]
    print(f"\n📝 展示样例回复 - {first_experiment['dialog_title']}")
    print(f"🔍 测试查询: {first_experiment['test_query']}")
    print("\n" + "="*80)

    for method, result in first_experiment['results'].items():
        response = result['response']
        preview = response[:preview_chars]
        if len(response) > preview_chars:
            preview += "..."
        print(f"\n🔹 {method.upper()}:")
        print(f"延迟: {result['latency']:.3f}s | 输入tokens: {result['input_tokens']} | 综合得分: {result['overall_score']:.3f}")
        print(f"回复: {preview}")
        print("-" * 60)

# Persist the raw results, then flatten them into a table for a per-method overview
if not all_experiment_results:
    print("⚠️ 没有实验结果可保存")
else:
    save_and_display_results(all_experiment_results)

    # Metric columns copied from each per-method result, in display order
    metric_keys = ('latency', 'input_tokens', 'output_tokens',
                   'relevance', 'coherence', 'informativeness', 'overall_score')
    df_results = [
        {'dialog': exp['dialog_title'], 'method': method,
         **{key: result[key] for key in metric_keys}}
        for exp in all_experiment_results
        for method, result in exp['results'].items()
    ]

    df = pd.DataFrame(df_results)
    print("\n📊 结果DataFrame概览:")
    print(df.groupby('method')[['latency', 'input_tokens', 'overall_score']].mean())


In [None]:
# 导入并测试增强版embedding注入器
from enhanced_embedding_injector import AdvancedEmbeddingInjector

def _report_strategy(strategy, result):
    """Print one strategy's outcome as returned by compare_injection_strategies()."""
    if 'error' in result:
        print(f"\n❌ {strategy}: {result['error']}")
        return

    print(f"\n🔹 {strategy.upper()}:")
    print(f"   生成时间: {result['generation_time']:.3f}s")
    print(f"   输入tokens: {result['input_tokens']}")
    print(f"   输出tokens: {result['output_tokens']}")

    if 'metadata' in result:
        metadata = result['metadata']
        print(f"   相似度: {metadata['embedding_similarity']:.3f}")
        print(f"   融合策略: {metadata['injection_strategy']}")

    print(f"   回复: {result['response'][:150]}...")
    print("-" * 60)


# Build the advanced injector and compare its strategies on the first fixture dialog
if model_manager.dialog_model and model_manager.tokenizer:
    advanced_injector = AdvancedEmbeddingInjector(model_manager)
    print("🚀 增强版Embedding注入器创建完成")

    if test_dialogs:
        test_dialog = test_dialogs[0]

        # Same "用户:/助手:" line format the experiment runner uses
        history_text = "".join(
            f"{'用户' if turn['role'] == 'user' else '助手'}: {turn['content']}\n"
            for turn in test_dialog['history']
        )
        query_text = test_dialog['test_query']

        print("\n🧪 测试高级注入策略:")
        print(f"   对话主题: {test_dialog['title']}")
        print(f"   历史长度: {len(history_text)} 字符")
        print(f"   测试查询: {query_text[:80]}...")

        print("\n🔄 运行多策略对比...")
        advanced_results = advanced_injector.compare_injection_strategies(
            history_text, query_text
        )

        print("\n📊 高级注入策略结果:")
        print("="*80)
        for strategy, result in advanced_results.items():
            _report_strategy(strategy, result)

        with open("advanced_injection_results.json", 'w', encoding='utf-8') as f:
            json.dump(advanced_results, f, ensure_ascii=False, indent=2)
        print("\n💾 高级注入实验结果已保存")

else:
    print("⚠️ 模型未加载，无法测试增强版注入器")
