In [None]:
# %% [markdown]
# # RAG系统性能评估框架
# 
# 基于提供的论文数据结构，实现完整的RAG系统评估

# %%
# 首先安装必要的库（如果还没安装）
import sys
!{sys.executable} -m pip install numpy==1.26.4 sentence-transformers scikit-learn pandas matplotlib seaborn tqdm --quiet

# 设置Matplotlib后端，避免matplotlib_inline冲突
import matplotlib
matplotlib.use('Agg')  # 使用非交互式后端

# 导入其他库
import hashlib
import json
import os
import re
import threading
import time
import warnings
from collections import defaultdict
from dataclasses import dataclass, field
from typing import List, Dict, Any, Tuple, Optional

import numpy as np
import pandas as pd
import requests
from sklearn.metrics.pairwise import cosine_similarity
from tqdm.auto import tqdm
warnings.filterwarnings('ignore')

# Try to import the plotting libraries; if they are missing, plotting is
# disabled via the VISUALIZATION_AVAILABLE flag instead of failing the run.
try:
    import matplotlib.pyplot as plt
    import seaborn as sns
    VISUALIZATION_AVAILABLE = True
    # Prefer fonts that can render Chinese labels, falling back to DejaVu Sans.
    plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans']
    # Render the minus sign with a regular hyphen so it shows up with these fonts.
    plt.rcParams['axes.unicode_minus'] = False
except ImportError:
    VISUALIZATION_AVAILABLE = False
    print("警告：无法导入matplotlib/seaborn，可视化功能将禁用")

# %%
# ==================== 配置区域 ====================
class Config:
    """Central configuration for the RAG evaluation run.

    Every setting is a class attribute; the module-level ``config`` instance
    created below is what the rest of the file reads.
    """
    # --- API settings ---
    # The key is read from the DEEPSEEK_API_KEY environment variable so that no
    # credential lives in the source file. The fallback is the placeholder value
    # that DeepSeekClient treats as "no key" (it then returns mock responses).
    # NOTE: a literal key used to be committed on this line; treat it as
    # compromised and rotate it.
    DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY", "your-deepseek-api-key")
    API_BASE_URL = "https://api.deepseek.com/v1"

    # --- Vector store settings ---
    EMBEDDING_MODEL = "all-MiniLM-L6-v2"  # small model to keep memory usage low
    TOP_K_RETRIEVAL = 5  # number of documents returned per search
    SIMILARITY_THRESHOLD = 0.7  # default similarity threshold passed to search

    # --- Generation / evaluation settings ---
    MAX_ANSWER_LENGTH = 1500  # used as max_tokens for chat completions
    TEMPERATURE = 0.1

    # --- File paths ---
    DATA_FILE = "cleaned_papers.jsonl"
    REPORT_FILE = "rag_evaluation_report.json"
    EVALUATION_CSV = "evaluation_results.csv"

    # --- Evaluation questions ---
    TEST_QUESTIONS = [
        {
            'question': '什么是医学影像中的开放词汇目标检测？它解决了什么问题？',
            'topic': 'computer vision, medical imaging',
            'difficulty': 'medium'
        },
        {
            'question': '最近在视频运动编辑方面有哪些新的技术突破？',
            'topic': 'computer vision, video editing',
            'difficulty': 'medium'
        },
        {
            'question': '对比学习在计算机视觉中有哪些应用？',
            'topic': 'machine learning, computer vision',
            'difficulty': 'medium'
        }
    ]


config = Config()

# %%
# ==================== 1. 数据加载和预处理 ====================

@dataclass
class Paper:
    """One paper record, mirroring the fields of the JSONL input.

    The list-valued fields (``authors``, ``categories``, ``basic_keywords``,
    ``domain_keywords``) are normalised in ``__post_init__``: a comma-separated
    string is split into a list, and any other non-list value becomes ``[]``.
    """
    paper_id: str
    title: str
    abstract: str
    authors: List[str]
    first_author: str
    topic: str
    categories: List[str]
    publish_date: str
    url: str
    # Pre-built text used for embedding; when empty, to_text() builds one.
    embedding_text: str
    # Read by get_quality_score() under the key 'overall_quality_score'.
    quality_scores: Dict[str, float]
    quality_tier: str
    basic_keywords: List[str]
    domain_keywords: List[str]
    update_date: str = ""

    @staticmethod
    def _as_str_list(value) -> List[str]:
        """Coerce a list-like field to a list of strings.

        A string is split on commas, with surrounding whitespace stripped and
        empty pieces dropped (so ``""`` gives ``[]`` rather than ``['']``).
        A list is returned unchanged; anything else becomes ``[]``.
        """
        if isinstance(value, str):
            return [part.strip() for part in value.split(',') if part.strip()]
        if isinstance(value, list):
            return value
        return []

    def __post_init__(self):
        """Normalise the list-valued fields after dataclass initialisation."""
        self.authors = self._as_str_list(self.authors)
        self.categories = self._as_str_list(self.categories)
        self.basic_keywords = self._as_str_list(self.basic_keywords)
        self.domain_keywords = self._as_str_list(self.domain_keywords)

    def to_text(self) -> str:
        """Return the text to embed for this paper.

        Uses ``embedding_text`` when it is non-empty; otherwise builds a text
        from the title, topic, first three authors, the first 500 characters of
        the abstract and, when present, the categories and first five keywords.
        """
        if self.embedding_text:
            return self.embedding_text

        # `or ''` keeps the slice safe when the record was built with abstract=None.
        text_parts = [
            f"Paper Title: {self.title}",
            f"Research Topic: {self.topic}",
            f"Authors: {', '.join(self.authors[:3])}",
            f"Abstract: {(self.abstract or '')[:500]}",
        ]

        if self.categories:
            text_parts.append(f"Categories: {', '.join(self.categories)}")

        if self.basic_keywords:
            text_parts.append(f"Keywords: {', '.join(self.basic_keywords[:5])}")

        return "\n".join(text_parts)

    def get_quality_score(self) -> float:
        """Return ``quality_scores['overall_quality_score']``, or 0.5 if absent."""
        scores = self.quality_scores
        if isinstance(scores, dict):
            score = scores.get('overall_quality_score')
            if score is not None:
                return score
        return 0.5

    @classmethod
    def from_dict(cls, data: Dict) -> 'Paper':
        """Build a Paper from a parsed JSONL record.

        Keys that are not Paper fields are ignored. Missing keys, and keys whose
        value is ``None`` (JSON ``null``), fall back to the defaults below so
        that later string operations never receive ``None``.
        """
        defaults = {
            'paper_id': '',
            'title': '',
            'abstract': '',
            'authors': [],
            'first_author': '',
            'topic': '',
            'categories': [],
            'publish_date': '',
            'url': '',
            'embedding_text': '',
            'quality_scores': {},
            'quality_tier': 'medium',
            'basic_keywords': [],
            'domain_keywords': [],
            'update_date': ''
        }

        merged_data = dict(defaults)
        merged_data.update(
            (key, value) for key, value in data.items()
            if key in defaults and value is not None
        )

        return cls(**merged_data)

class DataLoader:
    """Static helpers for reading paper records and summarising them."""

    @staticmethod
    def load_from_jsonl(file_path: str) -> List[Paper]:
        """Read one Paper per non-blank line of a JSONL file.

        Lines that fail to parse or convert are reported and skipped. If the
        file does not exist, the built-in sample papers are returned instead.
        """
        loaded = []
        try:
            with open(file_path, 'r', encoding='utf-8') as handle:
                for line_num, raw_line in enumerate(handle, 1):
                    stripped = raw_line.strip()
                    if not stripped:
                        continue
                    try:
                        record = json.loads(stripped)
                        loaded.append(Paper.from_dict(record))
                    except json.JSONDecodeError as e:
                        print(f"第{line_num}行JSON解析错误: {e}")
                    except Exception as e:
                        print(f"第{line_num}行数据转换错误: {e}")

            print(f"✓ 从 {file_path} 成功加载 {len(loaded)} 篇论文")

            # Only summarise when something was actually loaded.
            if loaded:
                DataLoader._print_statistics(loaded)

        except FileNotFoundError:
            print(f"✗ 文件 {file_path} 不存在")
            # Fall back to a small built-in dataset so the run can continue.
            return DataLoader.create_sample_data()

        return loaded

    @staticmethod
    def create_sample_data() -> List[Paper]:
        """Return two hard-coded example papers for testing."""
        print("创建示例数据用于测试...")

        specs = [
            dict(
                paper_id="2511.20650",
                title="MedROV: Towards Real-Time Open-Vocabulary Detection Across Diverse Medical Imaging Modalities",
                abstract="Traditional object detection models in medical imaging operate within a closed-set paradigm...",
                authors=["Tooba Tehreem Sheikh", "Jean Lahoud", "Rao Muhammad Anwer"],
                first_author="Tooba Tehreem Sheikh",
                topic="artificial intelligence",
                categories=["cs.CV", "cs.AI"],
                publish_date="2025-11-25",
                url="http://arxiv.org/abs/2511.20650",
                embedding_text="Paper Title: MedROV: Towards Real-Time Open-Vocabulary Detection...",
                quality_scores={"overall_quality_score": 0.975},
                quality_tier="high",
                basic_keywords=["detection", "medical", "imaging"],
                domain_keywords=["object detection", "medical imaging"],
            ),
            dict(
                paper_id="2511.20640",
                title="MotionV2V: Editing Motion in a Video",
                abstract="While generative video models have achieved remarkable fidelity and consistency...",
                authors=["Ryan Burgert", "Charles Herrmann", "Forrester Cole"],
                first_author="Ryan Burgert",
                topic="artificial intelligence",
                categories=["cs.CV", "cs.AI", "cs.GR"],
                publish_date="2025-11-25",
                url="http://arxiv.org/abs/2511.20640",
                embedding_text="Paper Title: MotionV2V: Editing Motion in a Video...",
                quality_scores={"overall_quality_score": 0.9},
                quality_tier="high",
                basic_keywords=["video", "motion", "editing"],
                domain_keywords=["video editing", "motion control"],
            ),
        ]
        sample_papers = [Paper(**spec) for spec in specs]

        print(f"✓ 创建了 {len(sample_papers)} 篇示例论文")
        return sample_papers

    @staticmethod
    def _print_statistics(papers: List[Paper]):
        """Print counts of papers, topics, quality tiers, authors and the latest date."""
        total = len(papers)
        print("\n数据集统计:")
        print("-" * 40)
        print(f"论文总数: {total}")

        # Number of distinct topics.
        distinct_topics = {p.topic for p in papers}
        print(f"主题数量: {len(distinct_topics)}")

        # Share of papers in each quality tier.
        tiers = [p.quality_tier for p in papers]
        distinct_tiers = set(tiers)
        print("\n质量等级分布:")
        for tier in distinct_tiers:
            count = tiers.count(tier)
            print(f"  {tier}: {count} 篇 ({count/total*100:.1f}%)")

        # Number of distinct author names across all papers.
        distinct_authors = {name for p in papers for name in p.authors}
        print(f"作者总数: {len(distinct_authors)}")

        # Largest publish_date string among papers that have one.
        known_dates = [p.publish_date for p in papers if p.publish_date]
        if known_dates:
            print(f"最新论文日期: {max(known_dates)}")

    @staticmethod
    def analyze_topic_distribution(papers: List[Paper]) -> pd.DataFrame:
        """Return a DataFrame of (topic, count) rows sorted by count, descending."""
        tally = {}
        for paper in papers:
            if paper.topic in tally:
                tally[paper.topic] += 1
            else:
                tally[paper.topic] = 1

        frame = pd.DataFrame(list(tally.items()), columns=['topic', 'count'])
        return frame.sort_values('count', ascending=False)

# %%
# ==================== 2. 嵌入和向量数据库 ====================

class EmbeddingModel:
    """Turns texts into embedding vectors.

    Three back-ends are tried, in this order of preference:

    1. the remote embeddings endpoint (when ``use_api`` and ``api_key`` are set),
    2. a local ``sentence_transformers`` model,
    3. a dependency-free hashed bag-of-words fallback.
    """

    # Width of the hashed bag-of-words fallback. Being fixed, vectors produced by
    # separate embed() calls (documents vs. query) share one vector space.
    SIMPLE_EMBEDDING_DIM = 100
    # Maximum number of texts sent in one embeddings API request.
    API_BATCH_SIZE = 10
    # Number of texts encoded per call to the local model.
    LOCAL_BATCH_SIZE = 32

    def __init__(self, model_name: str = None, use_api: bool = False, api_key: str = None):
        """
        Initialise the embedding back-end.

        Args:
            model_name: Local model name; defaults to ``config.EMBEDDING_MODEL``.
            use_api: Whether to use the remote embeddings API.
            api_key: API key; the API is only used when this is also set.
        """
        self.use_api = use_api
        self.api_key = api_key
        self.model = None

        if use_api and api_key:
            print("✓ 使用DeepSeek API进行嵌入")
        else:
            print(f"✓ 使用本地嵌入模型")
            try:
                # Imported lazily so the class still works without the package.
                from sentence_transformers import SentenceTransformer
                model_name = model_name or config.EMBEDDING_MODEL
                self.model = SentenceTransformer(model_name)
                print(f"  模型: {model_name}")
                print(f"  维度: {self.model.get_sentence_embedding_dimension()}")
            except ImportError:
                print("警告：无法导入sentence_transformers，使用简单嵌入")
                self.model = None

    def embed(self, texts: List[str]) -> np.ndarray:
        """Embed one text or a list of texts.

        Args:
            texts: A list of strings; a single string is treated as a
                one-element list.

        Returns:
            A 2-D array with one row per text, or an empty 1-D array when
            ``texts`` is empty.
        """
        if isinstance(texts, str):
            texts = [texts]

        if not texts:
            return np.array([])

        if self.use_api and self.api_key:
            return self._embed_api(texts)
        if self.model is not None:
            return self._embed_local(texts)
        return self._embed_simple(texts)

    def _embed_local(self, texts: List[str]) -> np.ndarray:
        """Embed with the local model, in batches to bound memory use.

        Falls back to the simple embedding if encoding raises.
        """
        try:
            embeddings = [
                self.model.encode(texts[i:i + self.LOCAL_BATCH_SIZE], show_progress_bar=False)
                for i in range(0, len(texts), self.LOCAL_BATCH_SIZE)
            ]
            return np.vstack(embeddings) if embeddings else np.array([])
        except Exception as e:
            print(f"本地嵌入失败: {e}")
            return self._embed_simple(texts)

    def _embed_api(self, texts: List[str]) -> np.ndarray:
        """Embed via the remote embeddings endpoint.

        Texts are sent in requests of at most ``API_BATCH_SIZE`` items so that
        every input gets a vector; previously everything after the first 10
        texts was silently dropped, leaving the result shorter than the input.
        On any failure the whole input is re-embedded with a fallback back-end,
        so the returned rows always come from a single back-end.
        """
        try:
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }

            embeddings = []
            for start in range(0, len(texts), self.API_BATCH_SIZE):
                data = {
                    "model": "text-embedding-3-small",
                    "input": texts[start:start + self.API_BATCH_SIZE],
                    "encoding_format": "float"
                }

                response = requests.post(
                    f"{config.API_BASE_URL}/embeddings",
                    headers=headers,
                    json=data,
                    timeout=30
                )
                response.raise_for_status()

                result = response.json()
                embeddings.extend(item["embedding"] for item in result["data"])

            return np.array(embeddings)

        except Exception as e:
            print(f"API嵌入失败: {e}")
            if self.model is not None:
                return self._embed_local(texts)
            return self._embed_simple(texts)

    def _embed_simple(self, texts: List[str]) -> np.ndarray:
        """Hashed bag-of-words embedding (fallback).

        Each lower-cased, whitespace-separated token is mapped to one of
        ``SIMPLE_EMBEDDING_DIM`` buckets with a stable hash and counted; rows
        are L2-normalised (all-zero rows are left as zeros). A stable hash is
        used instead of a per-call vocabulary so that vectors from different
        calls have the same width and the same token always lands in the same
        bucket.
        """
        print("使用简单嵌入方法...")
        dim = self.SIMPLE_EMBEDDING_DIM
        embeddings = np.zeros((len(texts), dim))

        for row, text in enumerate(texts):
            for word in text.lower().split():
                # hashlib rather than hash(): str hashes are salted per process.
                digest = hashlib.sha256(word.encode('utf-8')).digest()
                embeddings[row, int.from_bytes(digest[:8], 'big') % dim] += 1
            norm = np.linalg.norm(embeddings[row])
            if norm > 0:
                embeddings[row] /= norm

        return embeddings

    def get_dimension(self) -> int:
        """Return the embedding width of the back-end that ``embed`` will use."""
        if self.use_api and self.api_key:
            return 1536  # width assumed for the "text-embedding-3-small" API model
        if self.model is not None:
            return self.model.get_sentence_embedding_dimension()
        return self.SIMPLE_EMBEDDING_DIM

class VectorStore:
    """In-memory vector index over paper texts.

    Holds three parallel collections: ``documents`` (embedded texts),
    ``metadata`` (one dict per document) and ``embeddings`` (one row per
    document, or ``None`` until something has been added).
    """

    def __init__(self, embedder: "EmbeddingModel"):
        """Create an empty store that embeds text with ``embedder``."""
        self.embedder = embedder
        self.documents: List[str] = []
        self.metadata: List[Dict] = []
        self.embeddings: Optional[np.ndarray] = None

    def add_papers(self, papers: List["Paper"], use_embedding_text: bool = True):
        """Add papers to the store and rebuild the embedding matrix.

        Args:
            papers: Papers to index. An empty list is a no-op.
            use_embedding_text: Use ``paper.embedding_text`` when it is
                non-empty; otherwise (or when False) use ``paper.to_text()``.
        """
        print(f"开始处理 {len(papers)} 篇论文...")

        if not papers:
            # Nothing to embed. Returning here avoids embedding an empty list,
            # which yields a 1-D array and used to crash on `.shape[1]` below.
            print("警告：没有可添加的论文")
            return

        for paper in tqdm(papers, desc="添加论文"):
            if use_embedding_text and paper.embedding_text:
                doc_text = paper.embedding_text
            else:
                doc_text = paper.to_text()

            self.documents.append(doc_text)
            self.metadata.append({
                'paper_id': paper.paper_id,
                'title': paper.title,
                'authors': paper.authors,
                'first_author': paper.first_author,
                'topic': paper.topic,
                'categories': paper.categories,
                'publish_date': paper.publish_date,
                'quality_score': paper.get_quality_score(),
                'quality_tier': paper.quality_tier
            })

        # All documents (old and new) are embedded together in one call.
        print("生成嵌入向量...")
        self.embeddings = self.embedder.embed(self.documents)

        print(f"✓ 向量库构建完成")
        print(f"  文档数量: {len(self.documents)}")
        print(f"  嵌入维度: {self._embedding_dimension()}")

    def search(self, query: str, top_k: int = None, threshold: float = None) -> List[Dict]:
        """Return the ``top_k`` documents most similar to ``query``.

        Args:
            query: Query text.
            top_k: Maximum number of results; defaults to
                ``config.TOP_K_RETRIEVAL``. Values <= 0 give an empty list.
            threshold: Accepted for interface compatibility but does not
                exclude any result. NOTE(review): the original filter
                (``similarity < threshold and len(results) >= top_k``) could
                never fire for ``top_k >= 1``, so hits below the threshold have
                always been returned. That behaviour is kept here; enforce it
                deliberately if low-similarity hits should be dropped.

        Returns:
            A list of dicts with keys ``paper_id``, ``document``, ``metadata``,
            ``similarity`` and ``rank`` (1-based), ordered by decreasing
            cosine similarity.
        """
        if top_k is None:
            top_k = config.TOP_K_RETRIEVAL

        if len(self.documents) == 0 or self.embeddings is None:
            print("警告：向量库为空")
            return []

        if top_k <= 0:
            return []

        query_embedding = self.embedder.embed(query)
        if query_embedding.size == 0:
            return []
        if query_embedding.ndim == 1:
            query_embedding = query_embedding.reshape(1, -1)

        similarities = cosine_similarity(query_embedding, self.embeddings)[0]

        # Indices of the top_k highest similarities, best first.
        best = np.argsort(similarities)[::-1][:top_k]

        return [
            {
                'paper_id': self.metadata[idx]['paper_id'],
                'document': self.documents[idx],
                'metadata': self.metadata[idx],
                'similarity': float(similarities[idx]),
                'rank': rank
            }
            for rank, idx in enumerate(best, 1)
        ]

    def get_stats(self) -> Dict:
        """Return document count, embedding width and number of distinct topics."""
        if len(self.documents) == 0:
            return {'total_documents': 0}

        return {
            'total_documents': len(self.documents),
            'embedding_dimension': self._embedding_dimension(),
            'unique_topics': len(set(m['topic'] for m in self.metadata))
        }

    def _embedding_dimension(self) -> int:
        """Return the width of the embedding matrix, or 0 if it is not 2-D."""
        if getattr(self.embeddings, 'ndim', 0) == 2:
            return self.embeddings.shape[1]
        return 0

# %%
# ==================== 3. RAG系统 ====================

class DeepSeekClient:
    """Thin client for the DeepSeek chat-completions API.

    When no usable API key is configured, canned mock answers and mock
    evaluations are returned instead of calling the network.
    """

    # Placeholder value that means "no key configured".
    _PLACEHOLDER_KEY = "your-deepseek-api-key"

    def __init__(self, api_key: str):
        """Store the key and start the running token counter at zero."""
        self.api_key = api_key
        self.base_url = config.API_BASE_URL
        self.total_tokens = 0

    def _has_real_key(self) -> bool:
        """Return True when a non-empty, non-placeholder API key is set."""
        return bool(self.api_key) and self.api_key != self._PLACEHOLDER_KEY

    def generate_response(self, prompt: str, temperature: float = None, 
                         max_tokens: int = None, model: str = "deepseek-chat") -> str:
        """Send ``prompt`` as a single user message and return the reply text.

        Args:
            prompt: The user message.
            temperature: Sampling temperature; defaults to ``config.TEMPERATURE``.
            max_tokens: Completion limit; defaults to ``config.MAX_ANSWER_LENGTH``.
            model: Model name sent to the API.

        Returns:
            The reply text. On a request failure or a malformed response, a
            string starting with ``"错误："`` is returned instead of raising.
        """
        if temperature is None:
            temperature = config.TEMPERATURE
        if max_tokens is None:
            max_tokens = config.MAX_ANSWER_LENGTH

        if not self._has_real_key():
            print("警告：使用模拟API响应（请设置正确的API密钥）")
            return self._mock_response(prompt)

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

        data = {
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": temperature,
            "max_tokens": max_tokens,
            "stream": False
        }

        try:
            response = requests.post(
                f"{self.base_url}/chat/completions",
                headers=headers,
                json=data,
                timeout=60
            )
            response.raise_for_status()

            result = response.json()
            content = result["choices"][0]["message"]["content"]

            # Token bookkeeping is best-effort: a missing or odd `usage` field
            # must not discard an otherwise valid answer.
            usage = result.get('usage') or {}
            self.total_tokens += usage.get('total_tokens') or 0

            return content

        except requests.exceptions.RequestException as e:
            print(f"API请求失败: {e}")
            return f"错误：无法生成回答 ({str(e)})"
        except (KeyError, IndexError, TypeError, AttributeError, ValueError) as e:
            # Covers a non-JSON body, a non-dict payload, an empty `choices`
            # list and missing keys.
            print(f"API响应解析失败: {e}")
            return "错误：响应格式不正确"

    def _mock_response(self, prompt: str) -> str:
        """Return a canned answer chosen by keywords found in ``prompt``."""
        time.sleep(0.5)  # simulate network latency

        if "医学影像" in prompt or "medical" in prompt.lower():
            return """医学影像中的开放词汇目标检测（Open-Vocabulary Object Detection, OVOD）是一种能够检测训练时未见过的目标类别的技术。

它解决了传统目标检测的以下问题：
1. 封闭词汇限制：传统方法只能检测训练集中出现的类别
2. 数据稀缺问题：医学影像标注数据难以获取
3. 泛化能力：能够识别新的病变或解剖结构

MedROV是这方面的最新研究，实现了实时开放词汇检测[论文1]。"""
        elif "视频运动编辑" in prompt or "video" in prompt.lower():
            return """最近在视频运动编辑方面的技术突破包括：

1. MotionV2V模型：通过编辑稀疏轨迹来修改视频运动[论文2]
2. 运动反事实生成：创建内容相同但运动不同的视频对
3. 时间戳控制：可以从任何时间点开始编辑并自然传播
4. 用户研究表明，MotionV2V在对比测试中获得超过65%的偏好率"""
        else:
            return "这是一个模拟回答。实际使用需要设置正确的DeepSeek API密钥。"

    def evaluate_answer(self, question: str, answer: str) -> Dict:
        """Ask the model to grade ``answer`` and return the parsed JSON verdict.

        Returns:
            The JSON object found in the model's reply. If the reply has no
            JSON object, the JSON is malformed, or the call fails, a neutral
            result (every score 3, reason "无法解析评估结果") is returned.
            The optimistic mock evaluation is only used when no API key is
            configured, so a real failure is never reported as a good grade.
        """
        if not self._has_real_key():
            return self._mock_evaluation(question, answer)

        prompt = f"""请评估以下回答的质量：

问题：{question}

回答：{answer}

请从以下维度给出1-5分的评分（5分为最佳）：
1. 准确性：回答内容是否准确无误
2. 完整性：是否全面回答了问题
3. 相关性：是否与问题紧密相关
4. 具体性：是否包含具体细节和例子
5. 整体质量：综合评分

请以JSON格式返回结果，包含每个维度的分数和简要理由，以及一个总分（各项得分的平均值）。"""

        try:
            response = self.generate_response(prompt, temperature=0.1, max_tokens=800)

            # Take everything from the first "{" to the last "}" as the JSON part.
            json_match = re.search(r'\{.*\}', response, re.DOTALL)
            if json_match:
                evaluation = json.loads(json_match.group())
                if isinstance(evaluation, dict):
                    return evaluation
        except Exception as e:
            print(f"评估失败: {e}")

        return self._unparsed_evaluation()

    @staticmethod
    def _unparsed_evaluation() -> Dict:
        """Return the neutral result used when no verdict could be parsed."""
        reason = "无法解析评估结果"
        return {
            "accuracy": {"score": 3, "reason": reason},
            "completeness": {"score": 3, "reason": reason},
            "relevance": {"score": 3, "reason": reason},
            "specificity": {"score": 3, "reason": reason},
            "overall_quality": {"score": 3, "reason": reason},
            "total_score": 3.0
        }

    def _mock_evaluation(self, question: str, answer: str) -> Dict:
        """Return a fixed evaluation used when no API key is configured."""
        return {
            "accuracy": {"score": 4, "reason": "回答内容基本准确"},
            "completeness": {"score": 3, "reason": "回答了主要问题但不够全面"},
            "relevance": {"score": 4, "reason": "与问题高度相关"},
            "specificity": {"score": 3, "reason": "包含一些具体信息"},
            "overall_quality": {"score": 3.5, "reason": "整体质量良好"},
            "total_score": 3.5
        }

class RAGSystem:
    """Retrieval-augmented question answering on top of a vector store and an LLM."""

    def __init__(self, vector_store: VectorStore, llm_client: DeepSeekClient):
        self.vector_store = vector_store
        self.llm = llm_client

    def query(self, question: str, top_k: int = None, include_context: bool = True) -> Dict:
        """Answer ``question``, with retrieved context when requested and available.

        Retrieval always runs. The RAG prompt is used only when
        ``include_context`` is true and at least one document came back;
        otherwise the plain baseline prompt is used.

        Returns:
            A dict with keys ``question``, ``answer``, ``retrieved_docs``,
            ``method`` ("RAG" or "Baseline") and ``prompt_preview`` (the prompt,
            cut to 200 characters plus "..." when longer).
        """
        # Step 1: retrieve candidate documents.
        hits = self.vector_store.search(question, top_k=top_k)

        # Step 2: pick the prompt style.
        use_rag = bool(include_context and hits)
        if use_rag:
            prompt = self._build_rag_prompt(question, self._build_context(hits))
        else:
            prompt = self._build_baseline_prompt(question)

        # Step 3: generate the answer.
        reply = self.llm.generate_response(prompt)

        return {
            'question': question,
            'answer': reply,
            'retrieved_docs': hits,
            'method': "RAG" if use_rag else "Baseline",
            'prompt_preview': self._truncate_text(prompt, 200),
        }

    def _build_context(self, docs: List[Dict]) -> str:
        """Format retrieved documents as numbered entries separated by blank lines."""
        sections = ["基于以下研究论文信息："]

        for number, hit in enumerate(docs, start=1):
            info = hit['metadata']
            # Use the metadata abstract if present, else the stored document text.
            summary = self._truncate_text(info.get('abstract', hit['document']), 200)
            entry_lines = [
                f"[论文{number}] {info['title']}",
                f"作者: {info['first_author']}等",
                f"摘要: {summary}",
                f"相关度: {hit['similarity']:.3f}",
            ]
            sections.append("\n".join(entry_lines))

        return "\n\n".join(sections)

    def _build_rag_prompt(self, question: str, context: str) -> str:
        """Return the prompt that asks for an answer grounded in ``context``."""
        prompt = f"""你是一个AI研究助手，请基于提供的学术文献回答用户的问题。

可用文献：
{context}

用户问题：{question}

请按照以下要求回答：
1. 主要基于提供的文献信息进行回答
2. 在回答中引用相关文献，格式为[论文1]、[论文2]等
3. 如果文献信息不足，可以适当补充相关知识
4. 保持学术严谨性

请提供详细、准确的回答："""
        return prompt

    def _build_baseline_prompt(self, question: str) -> str:
        """Return the prompt used without any retrieved context."""
        prompt = f"""你是一个AI研究助手，请回答以下学术问题。

问题：{question}

请提供详细、准确的回答："""
        return prompt

    @staticmethod
    def _truncate_text(text: str, max_length: int) -> str:
        """Return ``text`` unchanged if short enough, else its prefix plus "..."."""
        if len(text) > max_length:
            return text[:max_length] + "..."
        return text

# %%
# ==================== 4. 双评估系统 ====================

class AutoEvaluationMetrics:
    """Heuristic, model-free metrics for a generated answer."""

    # Terms whose presence in an answer counts towards the "technical" score.
    _TECHNICAL_TERMS = ('模型', '算法', '检测', '学习', '训练', '精度', '准确率')
    _WORD_RE = re.compile(r'\w+')
    # A citation is ONE bracket pair containing at least one digit, e.g. "[论文1]".
    # Brackets are excluded from the inner text so a match cannot run from one
    # "[...]" into the next; the old pattern `\[.*?\d+.*?\]` counted
    # "[a] 1 [b]" as a citation. Newlines are excluded, as "." was before.
    _CITATION_RE = re.compile(r'\[[^\[\]\n]*\d[^\[\]\n]*\]')

    def evaluate_response(self, 
                         question: str,
                         answer: str,
                         retrieved_docs: List[Dict] = None,
                         baseline_answer: str = None) -> Dict:
        """Compute all metrics for one answer.

        Args:
            question: The question that was asked.
            answer: The answer to score.
            retrieved_docs: Retrieved documents, each with a ``similarity``
                key; retrieval metrics are added only when non-empty.
            baseline_answer: Answer to compare lengths against; comparison
                metrics are added only when non-empty.

        Returns:
            A flat dict of metrics including ``auto_score`` in [0, 1].
        """
        metrics = {
            'answer_length': len(answer),
            'word_count': len(self._WORD_RE.findall(answer)),
            'has_error': 1 if "错误：" in answer or "error" in answer.lower() else 0,
        }

        if retrieved_docs:
            metrics.update(self._calculate_retrieval_metrics(retrieved_docs))

        metrics.update(self._calculate_content_metrics(answer, question))
        metrics.update(self._calculate_citation_metrics(answer))

        if baseline_answer:
            metrics.update(self._calculate_comparison_metrics(answer, baseline_answer))

        # Must come last: the overall score reads the metrics gathered above.
        metrics['auto_score'] = self._calculate_overall_score(metrics)

        return metrics

    def _calculate_retrieval_metrics(self, retrieved_docs: List[Dict]) -> Dict:
        """Return count, mean and max of the documents' ``similarity`` values."""
        if not retrieved_docs:
            return {}

        similarities = [doc['similarity'] for doc in retrieved_docs]
        return {
            'retrieved_docs_count': len(retrieved_docs),
            'avg_similarity': np.mean(similarities),
            'max_similarity': max(similarities),
        }

    def _calculate_content_metrics(self, answer: str, question: str) -> Dict:
        """Return technical-term count, question/answer word overlap and a citation flag."""
        tech_term_count = sum(1 for term in self._TECHNICAL_TERMS if term in answer)

        # Overlap of lower-cased `\w+` tokens between question and answer.
        question_words = set(self._WORD_RE.findall(question.lower()))
        answer_words = set(self._WORD_RE.findall(answer.lower()))

        return {
            'technical_terms': tech_term_count,
            'keyword_matches': len(question_words & answer_words),
            'has_citation': 1 if '[' in answer and ']' in answer else 0,
        }

    def _calculate_citation_metrics(self, answer: str) -> Dict:
        """Return the number of bracketed numeric citations and of distinct ones."""
        citations = self._CITATION_RE.findall(answer)
        return {
            'citation_count': len(citations),
            'unique_citations': len(set(citations)),
        }

    def _calculate_comparison_metrics(self, rag_answer: str, baseline_answer: str) -> Dict:
        """Return the length ratio and difference of the RAG answer vs. the baseline."""
        rag_length = len(rag_answer)
        baseline_length = len(baseline_answer)

        return {
            # max(..., 1) avoids division by zero for an empty baseline.
            'length_ratio': rag_length / max(baseline_length, 1),
            'length_difference': rag_length - baseline_length,
        }

    def _calculate_overall_score(self, metrics: Dict) -> float:
        """Combine the metrics into a single score capped at 1.0.

        Contributions: answer length (0.3 for 200-800 characters, else 0.2 if
        above 50), technical terms (0.1 each, at most 0.3), citations (0.2
        each, at most 0.3), mean similarity (x0.1) and 0.1 when ``has_error``
        is 0.
        """
        score = 0.0

        length = metrics.get('answer_length', 0)
        if 200 <= length <= 800:
            score += 0.3
        elif length > 50:
            score += 0.2

        score += min(metrics.get('technical_terms', 0) * 0.1, 0.3)
        score += min(metrics.get('citation_count', 0) * 0.2, 0.3)
        score += metrics.get('avg_similarity', 0) * 0.1

        # A missing flag is treated as "has error", so no bonus is given.
        if metrics.get('has_error', 1) == 0:
            score += 0.1

        return min(score, 1.0)

class CombinedEvaluator:
    """Combined evaluator: heuristic automatic metrics plus LLM (API) grading.

    For each question the evaluator queries the RAG system twice (with and
    without retrieved context), scores both answers with
    ``AutoEvaluationMetrics.evaluate_response`` and with
    ``llm_client.evaluate_answer``, and appends one result dict per question
    to ``self.results``.
    """

    # Annotations are quoted (forward references) so that defining this class
    # does not require RAGSystem / DeepSeekClient to be bound at that moment.
    def __init__(self, rag_system: "RAGSystem", llm_client: "DeepSeekClient"):
        self.rag_system = rag_system
        self.llm_client = llm_client
        self.auto_evaluator = AutoEvaluationMetrics()
        self.results = []

    def evaluate_question(self, question_data: Dict) -> Dict:
        """Evaluate one question with and without retrieval.

        Args:
            question_data: Dict with a required ``'question'`` key and an
                optional ``'topic'`` key.

        Returns:
            The result dict for this question (also appended to
            ``self.results``), holding both answers, the retrieved document
            summaries, the automatic and API evaluations, and the
            RAG-minus-baseline score differences.
        """
        question = question_data['question']
        topic = question_data.get('topic', '')

        print(f"\n评估问题: {question}")

        # RAG answer (with retrieved context) and baseline answer (without)
        rag_response = self.rag_system.query(question, include_context=True)
        baseline_response = self.rag_system.query(question, include_context=False)

        # Automatic metrics for the RAG answer, compared against the baseline
        auto_metrics_rag = self.auto_evaluator.evaluate_response(
            question=question,
            answer=rag_response['answer'],
            retrieved_docs=rag_response['retrieved_docs'],
            baseline_answer=baseline_response['answer']
        )

        # Automatic metrics for the baseline: no retrieved documents and
        # nothing to compare against
        auto_metrics_baseline = self.auto_evaluator.evaluate_response(
            question=question,
            answer=baseline_response['answer'],
            retrieved_docs=None,
            baseline_answer=None
        )

        # LLM-based grading of both answers
        api_evaluation_rag = self.llm_client.evaluate_answer(question, rag_response['answer'])
        api_evaluation_baseline = self.llm_client.evaluate_answer(question, baseline_response['answer'])

        # Missing scores default to 0 so the arithmetic below never fails
        auto_score_rag = auto_metrics_rag.get('auto_score', 0)
        auto_score_baseline = auto_metrics_baseline.get('auto_score', 0)
        api_score_rag = api_evaluation_rag.get('total_score', 0)
        api_score_baseline = api_evaluation_baseline.get('total_score', 0)

        result = {
            'question_id': len(self.results) + 1,
            'question': question,
            'topic': topic,

            # RAG output
            'rag_answer': rag_response['answer'],
            'rag_method': rag_response['method'],
            'rag_retrieved_docs': [
                {
                    'title': doc['metadata']['title'],
                    'similarity': float(doc['similarity']),
                    'first_author': doc['metadata']['first_author']
                }
                for doc in rag_response['retrieved_docs']
            ],

            # Baseline output
            'baseline_answer': baseline_response['answer'],
            'baseline_method': baseline_response['method'],

            # Automatic evaluation (RAG)
            'auto_evaluation_rag': auto_metrics_rag,
            'auto_score_rag': auto_score_rag,

            # Automatic evaluation (baseline)
            'auto_evaluation_baseline': auto_metrics_baseline,
            'auto_score_baseline': auto_score_baseline,

            # API evaluation
            'api_evaluation_rag': api_evaluation_rag,
            'api_evaluation_baseline': api_evaluation_baseline,
            'api_score_rag': api_score_rag,
            'api_score_baseline': api_score_baseline,

            # RAG minus baseline
            'auto_improvement': auto_score_rag - auto_score_baseline,
            'api_improvement': api_score_rag - api_score_baseline,
        }

        self.results.append(result)

        # Short per-question summary
        print(f"  自动评分: RAG={auto_score_rag:.2f}, 基线={auto_score_baseline:.2f}")
        print(f"  API评分: RAG={api_score_rag:.2f}, 基线={api_score_baseline:.2f}")
        print(f"  检索文档: {len(rag_response['retrieved_docs'])}篇")

        return result

    def evaluate_all(self, questions: List[Dict]) -> List[Dict]:
        """Evaluate every question in order and return ``self.results``."""
        print(f"开始评估 {len(questions)} 个问题...")

        for question_data in questions:
            self.evaluate_question(question_data)

        print(f"\n✓ 完成所有评估")
        return self.results

    @staticmethod
    def _safe_correlation(xs: List[float], ys: List[float]) -> float:
        """Return the Pearson correlation of two score lists, or 0.0 if undefined.

        The coefficient is undefined for fewer than two samples or when either
        series is constant (zero variance). ``np.corrcoef`` returns NaN for a
        constant series, and NaN would end up in the JSON report as the
        non-standard token ``NaN``.
        """
        if len(xs) < 2 or min(xs) == max(xs) or min(ys) == max(ys):
            return 0.0
        value = float(np.corrcoef(xs, ys)[0, 1])
        return value if np.isfinite(value) else 0.0

    def generate_report(self) -> Dict:
        """Aggregate ``self.results`` into a report dict.

        Returns:
            ``{}`` when nothing has been evaluated yet; otherwise a dict with
            a ``'summary'`` of averages, improvement rates (share of questions
            with a strictly positive improvement), the auto-vs-API score
            correlation and the client's token count, plus
            ``'detailed_results'`` (the per-question result dicts).
        """
        if not self.results:
            return {}

        auto_scores_rag = [r['auto_score_rag'] for r in self.results]
        auto_scores_baseline = [r['auto_score_baseline'] for r in self.results]
        api_scores_rag = [r['api_score_rag'] for r in self.results]
        api_scores_baseline = [r['api_score_baseline'] for r in self.results]
        auto_improvements = [r['auto_improvement'] for r in self.results]
        api_improvements = [r['api_improvement'] for r in self.results]

        # Agreement between the heuristic score and the LLM score (RAG answers)
        correlation = self._safe_correlation(auto_scores_rag, api_scores_rag)

        report = {
            'summary': {
                'total_questions': len(self.results),
                'avg_auto_score_rag': float(np.mean(auto_scores_rag)),
                'avg_auto_score_baseline': float(np.mean(auto_scores_baseline)),
                'avg_api_score_rag': float(np.mean(api_scores_rag)),
                'avg_api_score_baseline': float(np.mean(api_scores_baseline)),
                'avg_auto_improvement': float(np.mean(auto_improvements)),
                'avg_api_improvement': float(np.mean(api_improvements)),
                'auto_improvement_rate': sum(1 for imp in auto_improvements if imp > 0) / len(auto_improvements),
                'api_improvement_rate': sum(1 for imp in api_improvements if imp > 0) / len(api_improvements),
                'correlation_auto_vs_api': float(correlation),
                'total_tokens_used': self.llm_client.total_tokens
            },
            'detailed_results': self.results
        }

        return report

    def save_results(self):
        """Write the JSON report and the per-question CSV.

        File paths come from the module-level ``config`` object
        (``REPORT_FILE`` and ``EVALUATION_CSV``).

        Returns:
            The ``pandas.DataFrame`` that was written to the CSV file.
        """
        # JSON report
        report = self.generate_report()
        with open(config.REPORT_FILE, 'w', encoding='utf-8') as f:
            json.dump(report, f, ensure_ascii=False, indent=2)
        print(f"✓ 评估报告已保存到 {config.REPORT_FILE}")

        # One flat CSV row per question
        df_data = []
        for result in self.results:
            row = {
                'question_id': result['question_id'],
                'question': result['question'],
                'topic': result['topic'],
                'rag_answer_length': len(result['rag_answer']),
                'baseline_answer_length': len(result['baseline_answer']),
                'auto_score_rag': result['auto_score_rag'],
                'auto_score_baseline': result['auto_score_baseline'],
                'api_score_rag': result['api_score_rag'],
                'api_score_baseline': result['api_score_baseline'],
                'auto_improvement': result['auto_improvement'],
                'api_improvement': result['api_improvement'],
                'retrieved_docs_count': len(result['rag_retrieved_docs']),
            }

            # Flatten the numeric automatic metrics into prefixed columns;
            # non-numeric values are left out of the CSV
            for key, value in result['auto_evaluation_rag'].items():
                if isinstance(value, (int, float)):
                    row[f'auto_rag_{key}'] = value

            for key, value in result['auto_evaluation_baseline'].items():
                if isinstance(value, (int, float)):
                    row[f'auto_baseline_{key}'] = value

            df_data.append(row)

        df = pd.DataFrame(df_data)
        # utf-8-sig prepends a BOM (presumably so spreadsheet tools detect the
        # encoding of the Chinese text)
        df.to_csv(config.EVALUATION_CSV, index=False, encoding='utf-8-sig')
        print(f"✓ 详细结果已保存到 {config.EVALUATION_CSV}")

        return df

# %%
# ==================== 5. 可视化和主程序 ====================

def create_visualizations(results: List[Dict], save_dir: str = "."):
    """Render a 2x2 summary figure of the evaluation and save it as a PNG.

    Panels: API score comparison, automatic score comparison, API improvement
    per question, automatic improvement per question. The image is written to
    ``{save_dir}/evaluation_visualization.png``.

    Args:
        results: Per-question result dicts; each must provide
            ``question_id``, the four ``*_score_*`` values, the two
            ``*_improvement`` values and ``rag_retrieved_docs``.
        save_dir: Directory for the output image.

    Plotting is best-effort: any failure is reported on stdout instead of
    being raised, and the figure is always closed so it is not leaked.
    """
    if not VISUALIZATION_AVAILABLE:
        print("警告：可视化库不可用，跳过图表生成")
        return

    fig = None
    try:
        df = pd.DataFrame([{
            'question_id': r['question_id'],
            'auto_score_rag': r['auto_score_rag'],
            'auto_score_baseline': r['auto_score_baseline'],
            'api_score_rag': r['api_score_rag'],
            'api_score_baseline': r['api_score_baseline'],
            'auto_improvement': r['auto_improvement'],
            'api_improvement': r['api_improvement'],
            'retrieved_docs': len(r['rag_retrieved_docs'])
        } for r in results])

        fig, axes = plt.subplots(2, 2, figsize=(12, 10))

        x = range(len(df))
        width = 0.35

        # Top row: grouped bars, baseline vs RAG
        comparison_panels = [
            (axes[0, 0], 'api_score_baseline', 'api_score_rag', 'API评分', 'RAG vs 基线 API评分对比'),
            (axes[0, 1], 'auto_score_baseline', 'auto_score_rag', '自动评分', 'RAG vs 基线 自动评分对比'),
        ]
        for ax, baseline_col, rag_col, ylabel, title in comparison_panels:
            ax.bar([i - width/2 for i in x], df[baseline_col], width, label='基线', alpha=0.7)
            ax.bar([i + width/2 for i in x], df[rag_col], width, label='RAG', alpha=0.7)
            ax.set_xlabel('问题编号')
            ax.set_ylabel(ylabel)
            ax.set_title(title)
            ax.legend()
            ax.set_xticks(x)
            ax.set_xticklabels(df['question_id'])

        # Bottom row: per-question improvement. All bars of a panel share one
        # color, chosen by the sign of the mean improvement.
        improvement_panels = [
            (axes[1, 0], 'api_improvement', 'green', 'API改进分数', 'RAG API改进分数分布'),
            (axes[1, 1], 'auto_improvement', 'blue', '自动评估改进分数', 'RAG 自动评估改进分数分布'),
        ]
        for ax, column, positive_color, ylabel, title in improvement_panels:
            ax.bar(x, df[column], color=positive_color if df[column].mean() > 0 else 'red')
            ax.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
            ax.set_xlabel('问题编号')
            ax.set_ylabel(ylabel)
            ax.set_title(title)
            ax.set_xticks(x)
            ax.set_xticklabels(df['question_id'])

        plt.tight_layout()
        plt.savefig(f'{save_dir}/evaluation_visualization.png', dpi=150, bbox_inches='tight')
        print(f"✓ 可视化图表已保存到 {save_dir}/evaluation_visualization.png")

    except Exception as e:
        # Deliberate best-effort boundary: a plotting failure must not abort
        # the evaluation run
        print(f"可视化生成失败: {e}")
        print("跳过可视化生成")
    finally:
        # Close the figure on failure as well, so it does not stay registered
        # with pyplot
        if fig is not None:
            plt.close(fig)

def print_detailed_report(report: Dict):
    """Print the evaluation summary and the per-question details to stdout.

    Args:
        report: Report dict with a ``'summary'`` dict and a
            ``'detailed_results'`` list. An empty or summary-less report
            (what an evaluator with no results produces) prints a warning
            and returns instead of raising ``KeyError``.
    """
    summary = report.get('summary') if report else None
    if not summary:
        print("警告：评估报告为空，没有可打印的结果")
        return

    print("\n" + "="*60)
    print("RAG系统评估报告")
    print("="*60)

    print(f"\n总体统计:")
    print(f"  评估问题总数: {summary['total_questions']}")
    print(f"  平均自动评估分数 (RAG): {summary['avg_auto_score_rag']:.3f}")
    print(f"  平均自动评估分数 (基线): {summary['avg_auto_score_baseline']:.3f}")
    print(f"  平均API评估分数 (RAG): {summary['avg_api_score_rag']:.3f}")
    print(f"  平均API评估分数 (基线): {summary['avg_api_score_baseline']:.3f}")
    print(f"  平均自动评估改进: {summary['avg_auto_improvement']:.3f}")
    print(f"  平均API评估改进: {summary['avg_api_improvement']:.3f}")
    print(f"  自动评估改进率: {summary['auto_improvement_rate']:.2%}")
    print(f"  API评估改进率: {summary['api_improvement_rate']:.2%}")
    print(f"  自动评估与API评估相关性: {summary['correlation_auto_vs_api']:.3f}")
    print(f"  总token使用量: {summary['total_tokens_used']}")

    # Per-question details
    print(f"\n详细结果:")
    for result in report.get('detailed_results', []):
        print(f"\n问题 {result['question_id']}: {result['question'][:50]}...")
        print(f"  自动评分: RAG={result['auto_score_rag']:.3f}, 基线={result['auto_score_baseline']:.3f}")
        print(f"  API评分: RAG={result['api_score_rag']:.3f}, 基线={result['api_score_baseline']:.3f}")
        print(f"  改进: 自动={result['auto_improvement']:.3f}, API={result['api_improvement']:.3f}")
        print(f"  检索文档: {len(result['rag_retrieved_docs'])}篇")

        # Show only the first two retrieved documents to keep the output short
        if result['rag_retrieved_docs']:
            print(f"  相关文档:")
            for doc in result['rag_retrieved_docs'][:2]:
                print(f"    - {doc['title'][:50]}... (相似度: {doc['similarity']:.3f})")

# %%
# ==================== 主程序 ====================

def main():
    """Run the end-to-end RAG evaluation pipeline and print its results.

    Steps: load the papers, build the embedding model and the vector store,
    wire up the RAG system, evaluate every configured test question, then
    save, visualize and print the report. Returns early (``None``) when no
    papers could be loaded.
    """
    banner = "=" * 60
    print(banner, "RAG系统评估框架", banner, sep="\n")

    # Step 1: load the paper corpus
    print("\n1. 加载论文数据...")
    paper_records = DataLoader.load_from_jsonl(config.DATA_FILE)
    if not paper_records:
        print("错误：没有加载到论文数据")
        return

    # Step 2: embedding model (local model, to avoid API limits)
    print("\n2. 初始化嵌入模型...")
    embedding_model = EmbeddingModel(
        model_name=config.EMBEDDING_MODEL,
        use_api=False,
        api_key=config.DEEPSEEK_API_KEY
    )

    # Step 3: vector store, followed by its statistics
    print("\n3. 构建向量数据库...")
    store = VectorStore(embedding_model)
    store.add_papers(paper_records)

    store_stats = store.get_stats()
    print(f"  文档总数: {store_stats['total_documents']}")
    print(f"  主题数量: {store_stats['unique_topics']}")

    # Step 4: RAG system on top of the store and the LLM client
    print("\n4. 初始化RAG系统...")
    client = DeepSeekClient(config.DEEPSEEK_API_KEY)
    rag = RAGSystem(store, client)

    # Step 5: evaluate every configured question
    print("\n5. 运行评估...")
    combined = CombinedEvaluator(rag, client)
    eval_results = combined.evaluate_all(config.TEST_QUESTIONS)

    # Step 6: aggregate report
    print("\n6. 生成报告...")
    final_report = combined.generate_report()

    # Step 7: persist the JSON report and the CSV table
    print("\n7. 保存结果...")
    results_df = combined.save_results()

    # Step 8: charts, only when plotting is available and there is data
    if VISUALIZATION_AVAILABLE and eval_results:
        print("\n8. 创建可视化图表...")
        create_visualizations(eval_results)

    # Step 9: console report
    print("\n9. 评估报告:")
    print_detailed_report(final_report)

    # Step 10: compact score table
    print("\n10. 结果DataFrame:")
    score_columns = [
        'question_id',
        'auto_score_rag',
        'auto_score_baseline',
        'api_score_rag',
        'api_score_baseline',
        'auto_improvement',
        'api_improvement',
    ]
    print(results_df[score_columns].to_string())

    print("\n" + banner, "评估完成！", banner, sep="\n")

# %%
# Run the full evaluation pipeline when this module is executed as "__main__"
if __name__ == "__main__":
    main()

# %%
# 快速测试函数（可选）
def quick_test():
    """Smoke-test the pipeline with one question on the sample data.

    Builds a vector store from ``DataLoader.create_sample_data()``, asks a
    single question through the RAG system, runs the automatic metrics on
    the answer and prints a short summary.

    Returns:
        The response dict returned by ``RAGSystem.query``.
    """
    print("运行快速测试...")

    # Sample corpus -> embeddings -> vector store
    sample_papers = DataLoader.create_sample_data()
    store = VectorStore(EmbeddingModel(use_api=False))
    store.add_papers(sample_papers)

    rag = RAGSystem(store, DeepSeekClient(config.DEEPSEEK_API_KEY))

    # Ask a single question
    test_question = "什么是医学影像中的开放词汇目标检测？"
    print(f"\n测试问题: {test_question}")

    rag_reply = rag.query(test_question)
    answer_text = rag_reply['answer']
    print(f"\nRAG回答预览: {answer_text[:200]}...")

    # Automatic metrics only; no baseline answer is supplied
    auto_metrics = AutoEvaluationMetrics().evaluate_response(
        question=test_question,
        answer=answer_text,
        retrieved_docs=rag_reply['retrieved_docs']
    )

    print("\n自动评估结果:")
    print(f"  回答长度: {auto_metrics['answer_length']}")
    print(f"  技术术语: {auto_metrics.get('technical_terms', 0)}")
    print(f"  引用数量: {auto_metrics.get('citation_count', 0)}")
    print(f"  自动评分: {auto_metrics.get('auto_score', 0):.2f}")

    return rag_reply

# 运行快速测试（取消注释以下行）
# quick_test_response = quick_test()

RAG系统评估框架

1. 加载论文数据...
✓ 从 cleaned_papers.jsonl 成功加载 22 篇论文

数据集统计:
----------------------------------------
论文总数: 22
主题数量: 3

质量等级分布:
  high: 21 篇 (95.5%)
  medium: 1 篇 (4.5%)
作者总数: 128
最新论文日期: 2025-11-25

2. 初始化嵌入模型...
✓ 使用本地嵌入模型
  模型: all-MiniLM-L6-v2
  维度: 384

3. 构建向量数据库...
开始处理 22 篇论文...


添加论文: 100%|██████████| 22/22 [00:00<?, ?it/s]

生成嵌入向量...





✓ 向量库构建完成
  文档数量: 22
  嵌入维度: 384
  文档总数: 22
  主题数量: 3

4. 初始化RAG系统...

5. 运行评估...
开始评估 3 个问题...

评估问题: 什么是医学影像中的开放词汇目标检测？它解决了什么问题？
警告：使用模拟API响应（请设置正确的API密钥）
警告：使用模拟API响应（请设置正确的API密钥）
  自动评分: RAG=0.80, 基线=0.80
  API评分: RAG=3.50, 基线=3.50
  检索文档: 5篇

评估问题: 最近在视频运动编辑方面有哪些新的技术突破？
警告：使用模拟API响应（请设置正确的API密钥）
警告：使用模拟API响应（请设置正确的API密钥）
  自动评分: RAG=0.60, 基线=0.60
  API评分: RAG=3.50, 基线=3.50
  检索文档: 5篇

评估问题: 对比学习在计算机视觉中有哪些应用？
警告：使用模拟API响应（请设置正确的API密钥）
警告：使用模拟API响应（请设置正确的API密钥）
  自动评分: RAG=0.10, 基线=0.10
  API评分: RAG=3.50, 基线=3.50
  检索文档: 5篇

✓ 完成所有评估

6. 生成报告...

7. 保存结果...
✓ 评估报告已保存到 rag_evaluation_report.json
✓ 详细结果已保存到 evaluation_results.csv

8. 创建可视化图表...
✓ 可视化图表已保存到 ./evaluation_visualization.png

9. 评估报告:

RAG系统评估报告

总体统计:
  评估问题总数: 3
  平均自动评估分数 (RAG): 0.502
  平均自动评估分数 (基线): 0.500
  平均API评估分数 (RAG): 3.500
  平均API评估分数 (基线): 3.500
  平均自动评估改进: 0.002
  平均API评估改进: 0.000
  自动评估改进率: 100.00%
  API评估改进率: 0.00%
  自动评估与API评估相关性: nan
  总token使用量: 0

详细结果:

问题 1: 什么是医学影像中的开放词汇目标检测？它解决了什么问题？...
  自动评分: RAG=