In [2]:
import os
import numpy as np
from pymilvus import MilvusClient, DataType, CollectionSchema, FieldSchema
from sentence_transformers import SentenceTransformer
import json
from datetime import datetime
import torch


class MilvusLiteRAGRetriever:
    def __init__(self, db_path="./milvus_data.db", model_path="./embedding-model"):
        self.db_path = db_path
        
        # 加載模型到對應設備
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"RAG檢索器使用設備: {self.device}")
        self.model = SentenceTransformer(model_path).to(self.device)
        self.dimension = 768
        
        self.client = MilvusClient(db_path)
        self.collection_name = "testchunks"
        self.vector_field_name = "vector"
        
        # 檢查集合是否存在
        collections = self.client.list_collections()
        if self.collection_name not in collections:
            raise ValueError(f"集合 '{self.collection_name}' 不存在！請先運行數據插入程序。")
        
        print(f"RAG檢索器初始化完成，連接到集合: {self.collection_name}")
    
    def retrieve_for_rag(self, query_text, top_k=10, score_threshold=None, filter_condition=None):
        """
        為 RAG 系統檢索相關文檔
        
        Args:
            query_text (str): 查詢文本
            top_k (int): 返回的結果數量，默認10
            score_threshold (float, optional): 分數閾值，低於此分數的結果將被過濾
            filter_condition (str, optional): 過濾條件，例如 'folder == "某個文件夾"'
        
        Returns:
            dict: 包含檢索結果的字典
                - hits: 檢索到的文檔列表
                - scores: 對應的相似度分數列表
                - query: 原始查詢文本
                - total_hits: 總命中數
        """
        try:
            print(f"正在檢索查詢: '{query_text[:50]}...' (top_k={top_k})")
            
            # 生成查詢嵌入向量
            with torch.no_grad():
                query_embedding = self.model.encode(
                    [query_text], 
                    device=self.device,
                    convert_to_numpy=True,
                    show_progress_bar=False
                )[0].tolist()
            
            # FLAT 索引搜尋參數
            search_params = {
                "params": {}
            }
            
            # 執行向量搜索
            search_results = self.client.search(
                collection_name=self.collection_name,
                anns_field=self.vector_field_name,
                data=[query_embedding],
                filter=filter_condition,
                limit=top_k,
                output_fields=["text", "folder", "file", "timestamp"],
                search_params=search_params
            )
            
            # 處理搜索結果
            hits = []
            scores = []
            
            for hit in search_results[0]:
                score = hit["distance"]  # L2距離，越小越相似
                
                # 如果設置了分數閾值，過濾低分結果
                if score_threshold is not None and score > score_threshold:
                    continue
                
                # 構建結果項
                hit_item = {
                    "id": hit["id"],
                    "text": hit["entity"]["text"],
                    "source": f"{hit['entity']['folder']}/{hit['entity']['file']}",
                    "folder": hit["entity"]["folder"],
                    "file": hit["entity"]["file"],
                    "timestamp": hit["entity"]["timestamp"],
                    "score": score
                }
                
                hits.append(hit_item)
                scores.append(score)
            
            # 構建返回結果
            result = {
                "hits": hits,
                "scores": scores,
                "query": query_text,
                "total_hits": len(hits),
                "top_k_requested": top_k
            }
            
            print(f"檢索完成，找到 {len(hits)} 個相關結果")
            return result
            
        except Exception as e:
            print(f"檢索過程中發生錯誤: {e}")
            return {
                "hits": [],
                "scores": [],
                "query": query_text,
                "total_hits": 0,
                "top_k_requested": top_k,
                "error": str(e)
            }
    
    def batch_retrieve(self, queries, top_k=10, score_threshold=None):
        """
        批量檢索多個查詢
        
        Args:
            queries (list): 查詢文本列表
            top_k (int): 每個查詢返回的結果數量
            score_threshold (float, optional): 分數閾值
        
        Returns:
            list: 每個查詢的檢索結果列表
        """
        results = []
        for i, query in enumerate(queries):
            print(f"處理查詢 {i+1}/{len(queries)}")
            result = self.retrieve_for_rag(query, top_k, score_threshold)
            results.append(result)
        return results
    
    def get_retriever_stats(self):
        """獲取檢索器統計信息"""
        try:
            stats = self.client.get_collection_stats(self.collection_name)
            return {
                "collection_name": self.collection_name,
                "stats": stats,
                "device": self.device,
                "model_dimension": self.dimension
            }
        except Exception as e:
            return {"error": str(e)}
    
    def close(self):
        """關閉檢索器"""
        self.client.close()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print("RAG檢索器已關閉")


# 使用示例和測試函數
def test_rag_retriever():
    """測試 RAG 檢索器"""
    try:
        # 初始化檢索器
        retriever = MilvusLiteRAGRetriever("./milvus_data.db", "./embedding-model")
        
        # 獲取統計信息
        stats = retriever.get_retriever_stats()
        print(f"檢索器統計: {stats}")
        
        # 測試單個查詢
        test_queries = [
            "人工智能的發展趨勢",
            "機器學習算法",
            "深度學習應用",
            "自然語言處理技術"
        ]
        
        for query in test_queries:
            print(f"\n{'='*60}")
            print(f"測試查詢: {query}")
            
            # 執行檢索
            result = retriever.retrieve_for_rag(query, top_k=10)
            
            if result["total_hits"] > 0:
                print(f"找到 {result['total_hits']} 個相關結果:")
                for i, hit in enumerate(result["hits"][:5], 1):  # 只顯示前5個
                    print(f"{i}. 分數: {hit['score']:.4f}")
                    print(f"   來源: {hit['source']}")
                    print(f"   內容: {hit['text'][:100]}...")
                    print()
            else:
                print("未找到相關結果")
        
        # 測試批量檢索
        print(f"\n{'='*60}")
        print("測試批量檢索:")
        batch_results = retriever.batch_retrieve(test_queries[:2], top_k=5)
        for i, result in enumerate(batch_results):
            print(f"查詢 {i+1}: 找到 {result['total_hits']} 個結果")
        
    except Exception as e:
        print(f"測試過程中發生錯誤: {e}")
        import traceback
        traceback.print_exc()
    
    finally:
        if 'retriever' in locals():
            retriever.close()


# 簡化的 RAG 檢索函數（可直接導入使用）
def rag_retrieve(query_text, top_k=10, db_path="./milvus_data.db", model_path="./embedding-model"):
    """
    簡化的 RAG 檢索函數，可以直接導入使用
    
    Args:
        query_text (str): 查詢文本
        top_k (int): 返回結果數量
        db_path (str): Milvus 數據庫路徑
        model_path (str): 嵌入模型路徑
    
    Returns:
        dict: 檢索結果
    """
    retriever = None
    try:
        retriever = MilvusLiteRAGRetriever(db_path, model_path)
        result = retriever.retrieve_for_rag(query_text, top_k)
        return result
    except Exception as e:
        return {
            "hits": [],
            "scores": [],
            "query": query_text,
            "total_hits": 0,
            "error": str(e)
        }
    finally:
        if retriever:
            retriever.close()


if __name__ == "__main__":
    # 運行測試
    test_rag_retriever()


RAG檢索器使用設備: cuda
RAG檢索器初始化完成，連接到集合: testchunks
檢索器統計: {'collection_name': 'testchunks', 'stats': {'row_count': 567}, 'device': 'cuda', 'model_dimension': 768}

測試查詢: 人工智能的發展趨勢
正在檢索查詢: '人工智能的發展趨勢...' (top_k=10)
檢索完成，找到 10 個相關結果
找到 10 個相關結果:
1. 分數: 0.8554
   來源: documents_dup_part_1_part_1_chunks/prod_documents_dup_part_1_part_1_path_merged_node_0_chunk_92_to_node_0_chunk_92.txt
   內容: Chunk path: ./chunks_output/documents_dup_part_1_part_1_chunks/prod_documents_dup_part_1_part_1_path...

2. 分數: 0.9175
   來源: documents_dup_part_1_part_1_chunks/prod_documents_dup_part_1_part_1_path_merged_node_0_chunk_8_to_node_0_chunk_8.txt
   內容: Chunk path: ./chunks_output/documents_dup_part_1_part_1_chunks/prod_documents_dup_part_1_part_1_path...

3. 分數: 0.9766
   來源: documents_dup_part_1_part_1_chunks/prod_documents_dup_part_1_part_1_path_merged_node_0_chunk_38_to_node_0_chunk_38.txt
   內容: Chunk path: ./chunks_output/documents_dup_part_1_part_1_chunks/prod_documents_dup_part_1_part_1_path...

4. 分數:

In [None]:
from elasticsearch import Elasticsearch
from elasticsearch.exceptions import RequestError, ConnectionError

class ElasticsearchSearcher:
    def __init__(self, 
                 es_url="https://localhost:9200", 
                 username="elastic", 
                 password="", 
                 index_name="chunk_documents"):
        """
        初始化搜索器
        :param es_url: Elasticsearch 地址
        :param username: 用户名（默认 elastic）
        :param password: 密码
        :param index_name: 要搜索的索引名（默认与导入时一致）
        """
        self.es_url = es_url
        self.username = username
        self.password = password
        self.index_name = index_name
        self.es = self._connect()

    def _connect(self):
        """建立与 Elasticsearch 的连接"""
        try:
            es = Elasticsearch(
                [self.es_url],
                basic_auth=(self.username, self.password),
                verify_certs=False,  # 保持与之前一致的 SSL 配置
                ssl_show_warn=False
            )
            # 验证连接
            if es.ping():
                print(f"✅ 已连接到 Elasticsearch：{self.es_url}")
                return es
            else:
                print("❌ 连接失败：Elasticsearch 未响应")
                return None
        except ConnectionError:
            print(f"❌ 无法连接到 {self.es_url}，请检查服务是否启动")
            return None
        except Exception as e:
            print(f"❌ 连接错误：{e}")
            return None

    def search(self, keyword, size=10, preview_length=200):
        """
        搜索包含关键词的文档
        :param keyword: 搜索关键词（字符串）
        :param size: 返回结果数量（默认10条）
        :param preview_length: 内容预览长度（默认200字符）
        :return: 格式化的搜索结果列表
        """
        if not self.es:
            print("⚠️ 未建立有效连接，请检查配置")
            return []

        try:
            # 构造搜索查询（在 content 字段中搜索关键词）
            query = {
                "query": {
                    "match": {
                        "content": keyword  # 搜索内容字段
                    }
                },
                "size": size
            }

            # 执行搜索
            response = self.es.search(index=self.index_name, body=query)
            hits = response["hits"]["hits"]
            total = response["hits"]["total"]["value"]

            print(f"\n🔍 搜索关键词：'{keyword}'，找到 {total} 条匹配结果（显示前 {len(hits)} 条）")

            # 格式化结果
            results = []
            for hit in hits:
                source = hit["_source"]
                results.append({
                    "id": hit["_id"],
                    "filename": source["filename"],
                    "folder": source["folder"],
                    "score": hit["_score"],  # 匹配得分（越高越相关）
                    "content_preview": source["content"][:preview_length] + "..." if len(source["content"]) > preview_length else source["content"],
                    "import_time": source["import_time"]
                })

            return results

        except RequestError as e:
            print(f"❌ 搜索请求错误：{e}")
            return []
        except Exception as e:
            print(f"❌ 搜索失败：{e}")
            return []

    def print_results(self, results):
        """格式化打印搜索结果"""
        for i, result in enumerate(results, 1):
            print(f"\n--- 结果 {i} ---")
            print(f"文件名：{result['filename']}")
            print(f"文件夹：{result['folder']}")
            print(f"相关度：{result['score']:.2f}")
            print(f"内容预览：{result['content_preview']}")
            print(f"导入时间：{result['import_time']}")


# 示例用法
if __name__ == "__main__":
    # 初始化搜索器（替换为你的密码）
    searcher = ElasticsearchSearcher(
        password="vSCQnhBXoox0sRo7-U1x",  # 你的 Elasticsearch 密码
        index_name="chunk_documents"       # 与导入时的索引名一致
    )

    # 搜索示例（替换为你想搜索的关键词）
    keyword = "人工智能"  # 例如 "人工智能"、"数据分析" 等
    search_results = searcher.search(keyword, size=5)  # 搜索并返回前5条结果

    # 打印结果
    searcher.print_results(search_results)


In [7]:
import os
import numpy as np
from typing import List, Dict, Any, Optional, Union
import torch
from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification
from pymilvus import MilvusClient
from elasticsearch import Elasticsearch
import json
from datetime import datetime
import logging

# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class HybridRAGSearcher:
    """
    混合检索与重排序系统
    结合 Milvus 向量搜索、Elasticsearch 关键词搜索，基于 transformers 加载 BCE 系列模型
    """
    
    def __init__(self, 
                 milvus_db_path="./milvus_data.db",
                 embedding_model_name_or_path="maidalun1020/bce-embedding-base_v1",
                 reranker_model_name_or_path="maidalun1020/bce-reranker-base_v1",
                 es_url="https://localhost:9200",
                 es_username="elastic",
                 es_password="",
                 es_index_name="chunk_documents",
                 milvus_collection_name="testchunks",
                 max_seq_length: int = 512):
        """初始化混合检索系统"""
        
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.max_seq_length = max_seq_length
        logger.info(f"使用设备: {self.device} | 最大序列长度: {self.max_seq_length}")
        
        # 初始化模型
        self.embedding_tokenizer, self.embedding_model = self._init_embedding_model(embedding_model_name_or_path)
        self.reranker_tokenizer, self.reranker_model = self._init_reranker_model(reranker_model_name_or_path)
        
        # 初始化 Milvus
        self._init_milvus(milvus_db_path, milvus_collection_name)
        
        # 初始化 Elasticsearch
        self._init_elasticsearch(es_url, es_username, es_password, es_index_name)
        
        logger.info("混合检索系统初始化完成")
    
    def _init_embedding_model(self, model_name_or_path: str):
        """初始化 BCE 嵌入模型"""
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
            model = AutoModel.from_pretrained(model_name_or_path)
            model = model.to(self.device)
            model.eval()
            logger.info(f"BCE嵌入模型加载成功: {model_name_or_path}")
            return tokenizer, model
        
        except Exception as e:
            logger.error(f"嵌入模型加载失败: {e}")
            raise
    
    def _init_reranker_model(self, model_name_or_path: str):
        """初始化 BCE 重排序模型"""
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
            model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path)
            model = model.to(self.device)
            model.eval()
            logger.info(f"BCE重排序模型加载成功: {model_name_or_path}")
            return tokenizer, model
        
        except Exception as e:
            logger.error(f"重排序模型加载失败: {e}")
            raise
    
    def _init_milvus(self, db_path, collection_name):
        """初始化 Milvus 连接"""
        try:
            self.milvus_client = MilvusClient(db_path)
            self.milvus_collection_name = collection_name
            
            collections = self.milvus_client.list_collections()
            if collection_name not in collections:
                raise ValueError(f"Milvus 集合 '{collection_name}' 不存在")
            
            logger.info(f"Milvus 连接成功: {collection_name}")
            
        except Exception as e:
            logger.error(f"Milvus 初始化失败: {e}")
            self.milvus_client = None
    
    def _init_elasticsearch(self, es_url, username, password, index_name):
        """初始化 Elasticsearch 连接"""
        try:
            self.es_client = Elasticsearch(
                [es_url],
                basic_auth=(username, password),
                verify_certs=False,
                ssl_show_warn=False
            )
            
            self.es_index_name = index_name
            
            if self.es_client.ping():
                logger.info(f"Elasticsearch 连接成功: {es_url}")
            else:
                logger.warning("Elasticsearch 连接失败")
                self.es_client = None
                
        except Exception as e:
            logger.error(f"Elasticsearch 初始化失败: {e}")
            self.es_client = None
    
    def _generate_embedding(self, texts: List[str]) -> np.ndarray:
        """生成文本嵌入"""
        with torch.no_grad():
            inputs = self.embedding_tokenizer(
                texts,
                padding=True,
                truncation=True,
                max_length=self.max_seq_length,
                return_tensors="pt"
            )
            
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            outputs = self.embedding_model(** inputs)
            cls_embeddings = outputs.last_hidden_state[:, 0, :]
            normalized_embeddings = cls_embeddings / cls_embeddings.norm(dim=1, keepdim=True)
            
            return normalized_embeddings.cpu().numpy()
    
    def _vector_search(self, query_text: str, top_k: int = 20) -> List[Dict]:
        """执行向量搜索"""
        if not self.milvus_client:
            return []
        
        try:
            query_embedding = self._generate_embedding([query_text])[0].tolist()
            
            search_results = self.milvus_client.search(
                collection_name=self.milvus_collection_name,
                anns_field="vector",
                data=[query_embedding],
                limit=top_k,
                output_fields=["text", "folder", "file", "timestamp"],
                search_params={"params": {}}
            )
            
            results = []
            for hit in search_results[0]:
                results.append({
                    "id": f"milvus_{hit['id']}",
                    "text": hit["entity"]["text"],
                    "source": f"{hit['entity']['folder']}/{hit['entity']['file']}",
                    "folder": hit["entity"]["folder"],
                    "file": hit["entity"]["file"],
                    "timestamp": hit["entity"]["timestamp"],
                    "score": 1.0 / (1.0 + hit["distance"]),
                    "search_type": "vector"
                })
            
            logger.info(f"向量搜索找到 {len(results)} 个结果")
            return results
            
        except Exception as e:
            logger.error(f"向量搜索失败: {e}")
            return []
    
    def _keyword_search(self, query_text: str, top_k: int = 20) -> List[Dict]:
        """执行关键词搜索"""
        if not self.es_client:
            return []
        
        try:
            query = {
                "query": {
                    "match": {
                        "content": query_text
                    }
                },
                "size": top_k
            }
            
            response = self.es_client.search(index=self.es_index_name, body=query)
            hits = response["hits"]["hits"]
            
            results = []
            for hit in hits:
                source = hit["_source"]
                results.append({
                    "id": f"es_{hit['_id']}",
                    "text": source["content"],
                    "source": f"{source['folder']}/{source['filename']}",
                    "folder": source["folder"],
                    "file": source["filename"],
                    "timestamp": source.get("import_time", ""),
                    "score": hit["_score"] / 100.0,
                    "search_type": "keyword"
                })
            
            logger.info(f"关键词搜索找到 {len(results)} 个结果")
            return results
            
        except Exception as e:
            logger.error(f"关键词搜索失败: {e}")
            return []
    
    def _merge_and_deduplicate(self, vector_results: List[Dict], keyword_results: List[Dict]) -> List[Dict]:
        """合并并去重搜索结果"""
        seen_texts = set()
        merged_results = []
        
        for result in vector_results:
            text_key = result["text"][:100]
            if text_key not in seen_texts:
                seen_texts.add(text_key)
                merged_results.append(result)
        
        for result in keyword_results:
            text_key = result["text"][:100]
            if text_key not in seen_texts:
                seen_texts.add(text_key)
                merged_results.append(result)
        
        logger.info(f"合并后共有 {len(merged_results)} 个唯一结果")
        return merged_results
    
    def _calculate_rerank_scores(self, query_text: str, texts: List[str]) -> List[float]:
        """计算重排序分数 - 修复维度不匹配问题"""
        with torch.no_grad():
            # 确保文本长度不会超过模型限制
            max_text_length = self.max_seq_length - len(self.reranker_tokenizer.tokenize(query_text)) - 3  # 预留空间给特殊标记
            if max_text_length < 10:  # 确保有足够空间
                max_text_length = 10
                
            # 处理文本，确保不会超过最大长度
            processed_texts = []
            for text in texts:
                # 截断文本以适应模型长度限制
                tokens = self.reranker_tokenizer.tokenize(text)
                if len(tokens) > max_text_length:
                    tokens = tokens[:max_text_length]
                    text = self.reranker_tokenizer.convert_tokens_to_string(tokens)
                processed_texts.append(text)
            
            # 构造 (query, doc) 对
            sentence_pairs = [[query_text, doc_text] for doc_text in processed_texts]
            
            # Tokenize处理
            inputs = self.reranker_tokenizer(
                sentence_pairs,
                padding=True,
                truncation=True,
                max_length=self.max_seq_length,
                return_tensors="pt"
            )
            
            # 移动到设备
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            
            # 模型前向传播
            outputs = self.reranker_model(** inputs)
            logits = outputs.logits.view(-1,).float()
            scores = torch.sigmoid(logits).cpu().tolist()
            
            return scores
    
    def _rerank_results(self, query_text: str, results: List[Dict], top_k: int = 10) -> List[Dict]:
        """使用重排序模型对结果进行重排序"""
        if not results or not self.reranker_model:
            return results[:top_k]
        
        try:
            texts_to_rerank = [result["text"] for result in results]
            rerank_scores = self._calculate_rerank_scores(query_text, texts_to_rerank)
            
            for i, result in enumerate(results):
                result["rerank_score"] = rerank_scores[i]
                result["original_score"] = result["score"]
            
            reranked_results = sorted(results, key=lambda x: x["rerank_score"], reverse=True)
            logger.info(f"重排序完成，返回前 {min(top_k, len(reranked_results))} 个结果")
            return reranked_results[:top_k]
            
        except Exception as e:
            logger.error(f"重排序失败: {e}")
            # 失败时尝试返回原始分数排序的结果
            return sorted(results, key=lambda x: x["score"], reverse=True)[:top_k]
    
    def hybrid_search(self, 
                     query_text: str, 
                     top_k: int = 10,
                     vector_weight: float = 0.6,
                     keyword_weight: float = 0.4,
                     enable_rerank: bool = True,
                     retrieval_size: int = 50) -> Dict[str, Any]:
        """执行混合检索"""
        logger.info(f"开始混合检索: '{query_text[:50]}...'")
        start_time = datetime.now()
        
        vector_results = self._vector_search(query_text, retrieval_size)
        keyword_results = self._keyword_search(query_text, retrieval_size)
        merged_results = self._merge_and_deduplicate(vector_results, keyword_results)
        
        if not enable_rerank:
            for result in merged_results:
                if result["search_type"] == "vector":
                    result["hybrid_score"] = result["score"] * vector_weight
                else:
                    result["hybrid_score"] = result["score"] * keyword_weight
            
            merged_results = sorted(merged_results, key=lambda x: x["hybrid_score"], reverse=True)
            final_results = merged_results[:top_k]
        else:
            final_results = self._rerank_results(query_text, merged_results, top_k)
        
        execution_time = (datetime.now() - start_time).total_seconds()
        
        search_result = {
            "query": query_text,
            "results": final_results,
            "total_found": len(final_results),
            "vector_results_count": len(vector_results),
            "keyword_results_count": len(keyword_results),
            "merged_results_count": len(merged_results),
            "execution_time_seconds": execution_time,
            "rerank_enabled": enable_rerank,
            "timestamp": datetime.now().isoformat(),
            "error": None  # 始终包含error键，默认为None
        }
        
        logger.info(f"混合检索完成，找到 {len(final_results)} 个结果，耗时 {execution_time:.2f} 秒")
        return search_result
    
    def print_search_results(self, search_result: Dict[str, Any], show_full_text: bool = False):
        """格式化打印搜索结果"""
        print(f"\n{'='*80}")
        print(f"查询: {search_result['query']}")
        print(f"执行时间: {search_result['execution_time_seconds']:.2f} 秒")
        print(f"向量搜索结果: {search_result['vector_results_count']} 个")
        print(f"关键词搜索结果: {search_result['keyword_results_count']} 个")
        print(f"合并后结果: {search_result['merged_results_count']} 个")
        print(f"最终返回: {search_result['total_found']} 个")
        print(f"重排序: {'已启用' if search_result['rerank_enabled'] else '未启用'}")
        print(f"{'='*80}")
        
        for i, result in enumerate(search_result['results'], 1):
            print(f"\n--- 结果 {i} ---")
            print(f"来源: {result['source']}")
            print(f"搜索类型: {result['search_type']}")
            
            if 'rerank_score' in result:
                print(f"重排序分数: {result['rerank_score']:.4f}")
                print(f"原始分数: {result['original_score']:.4f}")
            elif 'hybrid_score' in result:
                print(f"混合分数: {result['hybrid_score']:.4f}")
            else:
                print(f"分数: {result['score']:.4f}")
            
            text_preview = result['text'][:300] if not show_full_text else result['text']
            if len(result['text']) > 300 and not show_full_text:
                text_preview += "..."
            print(f"内容: {text_preview}")
    
    def close(self):
        """关闭所有连接"""
        if self.milvus_client:
            self.milvus_client.close()
        
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            logger.info("GPU缓存已清理")
        
        logger.info("混合检索系统已关闭")


def hybrid_rag_search(query_text: str, 
                     top_k: int = 10,
                     milvus_db_path: str = "./milvus_data.db",
                     embedding_model_name_or_path: str = "maidalun1020/bce-embedding-base_v1",
                     reranker_model_name_or_path: str = "maidalun1020/bce-reranker-base_v1",
                     es_url: str = "https://localhost:9200",
                     es_username: str = "elastic",
                     es_password: str = "",
                     es_index_name: str = "chunk_documents",
                     enable_rerank: bool = True,
                     max_seq_length: int = 512) -> Dict[str, Any]:
    """独立的混合检索服务函数"""
    searcher = None
    try:
        searcher = HybridRAGSearcher(
            milvus_db_path=milvus_db_path,
            embedding_model_name_or_path=embedding_model_name_or_path,
            reranker_model_name_or_path=reranker_model_name_or_path,
            es_url=es_url,
            es_username=es_username,
            es_password=es_password,
            es_index_name=es_index_name,
            max_seq_length=max_seq_length
        )
        
        result = searcher.hybrid_search(
            query_text=query_text,
            top_k=top_k,
            enable_rerank=enable_rerank
        )
        
        return result
        
    except Exception as e:
        logger.error(f"搜索服务错误: {e}")
        return {
            "query": query_text,
            "results": [],
            "total_found": 0,
            "error": str(e),
            "timestamp": datetime.now().isoformat()
        }
    
    finally:
        if searcher:
            searcher.close()


def test_hybrid_search():
    """测试混合检索系统"""
    
    config = {
        "milvus_db_path": "./milvus_data.db",
        "embedding_model_name_or_path": "./embedding-model",
        "reranker_model_name_or_path": "./reranker-model",
        "es_url": "https://localhost:9200",
        "es_username": "elastic",
        "es_password": "vSCQnhBXoox0sRo7-U1x",
        "es_index_name": "chunk_documents",
        "max_seq_length": 4096
    }
    
    test_queries = [
        "人工智能的发展趋势",
        "机器学习算法原理",
    ]
    
    try:
        searcher = HybridRAGSearcher(** config)
        
        for query in test_queries:
            print(f"\n{'='*100}")
            print(f"测试查询: {query}")
            
            print(f"\n【启用重排序】")
            result_with_rerank = searcher.hybrid_search(
                query_text=query,
                top_k=5,
                enable_rerank=True
            )
            searcher.print_search_results(result_with_rerank, show_full_text=False)
            
            print(f"\n【禁用重排序】")
            result_no_rerank = searcher.hybrid_search(
                query_text=query,
                top_k=3,
                enable_rerank=False,
                vector_weight=0.6,
                keyword_weight=0.4
            )
            for i, res in enumerate(result_no_rerank['results'], 1):
                print(f"{i}. 混合分数: {res['hybrid_score']:.4f} | 来源: {res['source']}")
        
        print(f"\n{'='*100}")
        print("测试独立搜索函数:")
        standalone_result = hybrid_rag_search(
            query_text="人工智能技术应用场景",
            top_k=3,
            **config
        )
        
        # 修复KeyError问题：检查error键是否存在且不为None
        if "error" in standalone_result and standalone_result["error"] is not None:
            print(f"独立函数错误: {standalone_result['error']}")
        else:
            print(f"独立函数搜索结果: 找到 {standalone_result['total_found']} 个结果")
            for i, res in enumerate(standalone_result['results'], 1):
                score = res.get('rerank_score', res.get('hybrid_score', res['score']))
                print(f"{i}. 分数: {score:.4f} | 来源: {res['source']}")
    
    except Exception as e:
        logger.error(f"测试过程中发生错误: {e}")
        import traceback
        traceback.print_exc()
    
    finally:
        if 'searcher' in locals():
            searcher.close()


if __name__ == "__main__":
    test_hybrid_search()


INFO:__main__:使用设备: cuda | 最大序列长度: 4096
INFO:__main__:BCE嵌入模型加载成功: ./embedding-model
INFO:__main__:BCE重排序模型加载成功: ./reranker-model
INFO:__main__:Milvus 连接成功: testchunks
INFO:elastic_transport.transport:HEAD https://localhost:9200/ [status:200 duration:0.015s]
INFO:__main__:Elasticsearch 连接成功: https://localhost:9200
INFO:__main__:混合检索系统初始化完成
INFO:__main__:开始混合检索: '人工智能的发展趋势...'
INFO:__main__:向量搜索找到 50 个结果
INFO:elastic_transport.transport:POST https://localhost:9200/chunk_documents/_search [status:200 duration:0.018s]
INFO:__main__:关键词搜索找到 50 个结果
INFO:__main__:合并后共有 6 个唯一结果
ERROR:__main__:重排序失败: The expanded size of the tensor (2962) must match the existing size (514) at non-singleton dimension 1.  Target sizes: [6, 2962].  Tensor sizes: [1, 514]
INFO:__main__:混合检索完成，找到 5 个结果，耗时 0.07 秒
INFO:__main__:开始混合检索: '人工智能的发展趋势...'
INFO:__main__:向量搜索找到 50 个结果
INFO:elastic_transport.transport:POST https://localhost:9200/chunk_documents/_search [status:200 duration:0.012s]
INFO:__main__:关键词搜索找到 50 个结


测试查询: 人工智能的发展趋势

【启用重排序】

查询: 人工智能的发展趋势
执行时间: 0.07 秒
向量搜索结果: 50 个
关键词搜索结果: 50 个
合并后结果: 6 个
最终返回: 5 个
重排序: 已启用

--- 结果 1 ---
来源: documents_dup_part_1_part_1_chunks/prod_documents_dup_part_1_part_1_path_merged_node_0_chunk_92_to_node_0_chunk_92.txt
搜索类型: vector
分数: 0.5393
内容: Chunk path: ./chunks_output/documents_dup_part_1_part_1_chunks/prod_documents_dup_part_1_part_1_path_merged_node_0_chunk_92_to_node_0_chunk_92.txt
Chunk name: prod_documents_dup_part_1_part_1_path_merged_node_0_chunk_92_to_node_0_chunk_92.txt
Chunk ...

--- 结果 2 ---
来源: test_ml_chunks/prod_test_ml_node_80_chunk_0.txt
搜索类型: vector
分数: 0.4966
内容: Chunk path: ./chunks_output/test_ml_chunks/prod_test_ml_node_80_chunk_0.txt
Chunk name: prod_test_ml_node_80_chunk_0.txt
Chunk ID: node_80_chunk_0
Source Node: 80
Chunk Index: 0
Path Titles: AI 与其他学科融合：与物理学（如量子机器学习，提升计算速度）、生物学（如生物信息学中的基因序列分析）、社会学（如社...

--- 结果 3 ---
来源: test_ml_chunks/prod_test_ml_path_merged_node_76_chunk_0_to_node_79_chunk_0.txt
搜索类型: vector
分数: 0.4914
内容:

ERROR:__main__:重排序失败: The expanded size of the tensor (3380) must match the existing size (514) at non-singleton dimension 1.  Target sizes: [28, 3380].  Tensor sizes: [1, 514]
INFO:__main__:混合检索完成，找到 5 个结果，耗时 0.11 秒
INFO:__main__:开始混合检索: '机器学习算法原理...'
INFO:__main__:向量搜索找到 50 个结果
INFO:elastic_transport.transport:POST https://localhost:9200/chunk_documents/_search [status:200 duration:0.016s]
INFO:__main__:关键词搜索找到 50 个结果
INFO:__main__:合并后共有 28 个唯一结果
INFO:__main__:混合检索完成，找到 3 个结果，耗时 0.04 秒
INFO:__main__:使用设备: cuda | 最大序列长度: 4096



查询: 机器学习算法原理
执行时间: 0.11 秒
向量搜索结果: 50 个
关键词搜索结果: 50 个
合并后结果: 28 个
最终返回: 5 个
重排序: 已启用

--- 结果 1 ---
来源: test_ml_chunks/prod_test_ml_path_merged_node_11_chunk_0_to_node_15_chunk_0.txt
搜索类型: vector
分数: 0.6455
内容: Chunk path: ./chunks_output/test_ml_chunks/prod_test_ml_path_merged_node_11_chunk_0_to_node_15_chunk_0.txt
Chunk name: prod_test_ml_path_merged_node_11_chunk_0_to_node_15_chunk_0.txt
Chunk ID: path_merged_node_11_chunk_0_to_node_15_chunk_0
Source No...

--- 结果 2 ---
来源: test_ml_chunks/prod_test_ml_node_26_chunk_0.txt
搜索类型: vector
分数: 0.5994
内容: Chunk path: ./chunks_output/test_ml_chunks/prod_test_ml_node_26_chunk_0.txt
Chunk name: prod_test_ml_node_26_chunk_0.txt
Chunk ID: node_26_chunk_0
Source Node: 26
Chunk Index: 0
Path Titles: 3. 机器学习的主要分类 > 3.1 监督学习（Supervised Learning） > 3.1.1 定义与特点...

--- 结果 3 ---
来源: test_ml_chunks/prod_test_ml_node_27_chunk_0.txt
搜索类型: vector
分数: 0.5955
内容: Chunk path: ./chunks_output/test_ml_chunks/prod_test_ml_node_27_chunk_0.txt
Chunk name: prod_tes

INFO:__main__:BCE嵌入模型加载成功: ./embedding-model
INFO:__main__:BCE重排序模型加载成功: ./reranker-model
INFO:__main__:Milvus 连接成功: testchunks
INFO:elastic_transport.transport:HEAD https://localhost:9200/ [status:200 duration:0.015s]
INFO:__main__:Elasticsearch 连接成功: https://localhost:9200
INFO:__main__:混合检索系统初始化完成
INFO:__main__:开始混合检索: '人工智能技术应用场景...'
INFO:__main__:向量搜索找到 50 个结果
INFO:elastic_transport.transport:POST https://localhost:9200/chunk_documents/_search [status:200 duration:0.017s]
INFO:__main__:关键词搜索找到 50 个结果
INFO:__main__:合并后共有 13 个唯一结果
ERROR:__main__:重排序失败: The expanded size of the tensor (2963) must match the existing size (514) at non-singleton dimension 1.  Target sizes: [13, 2963].  Tensor sizes: [1, 514]
INFO:__main__:混合检索完成，找到 3 个结果，耗时 0.07 秒
INFO:__main__:GPU缓存已清理
INFO:__main__:混合检索系统已关闭
INFO:__main__:GPU缓存已清理
INFO:__main__:混合检索系统已关闭


独立函数搜索结果: 找到 3 个结果
1. 分数: 0.5325 | 来源: documents_dup_part_1_part_1_chunks/prod_documents_dup_part_1_part_1_path_merged_node_0_chunk_92_to_node_0_chunk_92.txt
2. 分数: 0.5137 | 来源: test_ml_chunks/prod_test_ml_node_66_chunk_0.txt
3. 分数: 0.5115 | 来源: test_ml_chunks/prod_test_ml_node_80_chunk_0.txt


In [2]:
import os
import numpy as np
import torch
from typing import List, Dict, Any, Optional
from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification
from pymilvus import MilvusClient
from elasticsearch import Elasticsearch
from elasticsearch.exceptions import RequestError, ConnectionError
import json
import requests
import time
from datetime import datetime
import logging

# -------------------------- 1. 基础配置（用户可根据实际环境调整） --------------------------
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# RAG核心配置
RAG_CONFIG = {
    "milvus_db_path": "./milvus_data.db",
    "milvus_collection_name": "testchunks",
    "embedding_model_path": "./embedding-model",
    "reranker_model_path": "./reranker-model",
    "es_url": "https://localhost:9200",
    "es_username": "elastic",
    "es_password": "vSCQnhBXoox0sRo7-U1x",
    "es_index_name": "chunk_documents",
    "max_seq_length": 512,
    "rag_top_k": 3,  # 检索返回的top相关结果数
    "doc_preview_length": 1000  # 文档内容预览长度（固定300字）
}

# DeepSeek-V3.1大模型配置
LLM_CONFIG = {
    "api_url": "https://api.siliconflow.cn/v1/chat/completions",
    "api_key": "sk-ionsbeieleeekwlstqotkyrmictdzshgnbaytavcudxkixcs",
    "model_name": "deepseek-ai/DeepSeek-V3.1",
    "max_tokens": 2000,
    "temperature": 0.7,
    "max_retries": 3,
    "retry_delay": 5
}

# -------------------------- 2. 混合RAG检索系统（不变，复用之前稳定版本） --------------------------
class HybridRAGSearcher:
    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        logger.info(f"RAG检索系统使用设备: {self.device}")

        self.embedding_tokenizer, self.embedding_model = self._init_embedding_model()
        self.reranker_tokenizer, self.reranker_model = self._init_reranker_model()
        self.milvus_client = self._init_milvus()
        self.es_client = self._init_elasticsearch()

        logger.info("混合RAG检索系统初始化完成")

    def _init_embedding_model(self):
        try:
            tokenizer = AutoTokenizer.from_pretrained(self.config["embedding_model_path"])
            model = AutoModel.from_pretrained(self.config["embedding_model_path"]).to(self.device)
            model.eval()
            logger.info("BCE嵌入模型加载成功")
            return tokenizer, model
        except Exception as e:
            logger.error(f"嵌入模型加载失败: {e}")
            raise

    def _init_reranker_model(self):
        try:
            tokenizer = AutoTokenizer.from_pretrained(self.config["reranker_model_path"])
            model = AutoModelForSequenceClassification.from_pretrained(self.config["reranker_model_path"]).to(self.device)
            model.eval()
            logger.info("BCE重排序模型加载成功")
            return tokenizer, model
        except Exception as e:
            logger.error(f"重排序模型加载失败: {e}")
            raise

    def _init_milvus(self):
        try:
            client = MilvusClient(self.config["milvus_db_path"])
            collections = client.list_collections()
            if self.config["milvus_collection_name"] not in collections:
                raise ValueError(f"Milvus集合 '{self.config['milvus_collection_name']}' 不存在")
            logger.info("Milvus连接成功")
            return client
        except Exception as e:
            logger.error(f"Milvus初始化失败: {e}")
            raise

    def _init_elasticsearch(self):
        try:
            client = Elasticsearch(
                [self.config["es_url"]],
                basic_auth=(self.config["es_username"], self.config["es_password"]),
                verify_certs=False,
                ssl_show_warn=False
            )
            if client.ping():
                logger.info("Elasticsearch连接成功")
                return client
            else:
                raise ConnectionError("Elasticsearch ping失败")
        except Exception as e:
            logger.error(f"Elasticsearch初始化失败: {e}")
            raise

    def _generate_embedding(self, texts: List[str]) -> np.ndarray:
        with torch.no_grad():
            inputs = self.embedding_tokenizer(
                texts, padding=True, truncation=True, max_length=self.config["max_seq_length"], return_tensors="pt"
            ).to(self.device)
            outputs = self.embedding_model(**inputs)
            cls_emb = outputs.last_hidden_state[:, 0, :]
            return (cls_emb / cls_emb.norm(dim=1, keepdim=True)).cpu().numpy()

    def _vector_search(self, query: str) -> List[Dict]:
        try:
            query_emb = self._generate_embedding([query])[0].tolist()
            results = self.milvus_client.search(
                collection_name=self.config["milvus_collection_name"],
                anns_field="vector",
                data=[query_emb],
                limit=self.config["rag_top_k"] * 2,
                output_fields=["text", "folder", "file", "timestamp"]
            )
            return [
                {
                    "id": f"milvus_{hit['id']}",
                    "text": hit["entity"]["text"],
                    "source": f"{hit['entity']['folder']}/{hit['entity']['file']}",  # 完整路径（文件夹/文件名）
                    "full_source_path": os.path.abspath(f"{hit['entity']['folder']}/{hit['entity']['file']}"),  # 绝对路径
                    "score": 1.0 / (1.0 + hit["distance"]),
                    "search_type": "向量检索",
                    "timestamp": hit["entity"].get("timestamp", "")
                }
                for hit in results[0]
            ]
        except Exception as e:
            logger.error(f"向量搜索失败: {e}")
            return []

    def _keyword_search(self, query: str) -> List[Dict]:
        try:
            response = self.es_client.search(
                index=self.config["es_index_name"],
                body={"query": {"match": {"content": query}}, "size": self.config["rag_top_k"] * 2}
            )
            return [
                {
                    "id": f"es_{hit['_id']}",
                    "text": hit["_source"]["content"],
                    "source": f"{hit['_source']['folder']}/{hit['_source']['filename']}",  # 完整路径（文件夹/文件名）
                    "full_source_path": os.path.abspath(f"{hit['_source']['folder']}/{hit['_source']['filename']}"),  # 绝对路径
                    "score": hit["_score"] / 100.0,
                    "search_type": "关键词检索",
                    "timestamp": hit["_source"].get("import_time", "")
                }
                for hit in response["hits"]["hits"]
            ]
        except Exception as e:
            logger.error(f"关键词搜索失败: {e}")
            return []

    def _merge_deduplicate(self, vector_res: List[Dict], keyword_res: List[Dict]) -> List[Dict]:
        seen = set()
        merged = []
        for res in vector_res + keyword_res:
            text_key = res["text"][:100]
            if text_key not in seen:
                seen.add(text_key)
                merged.append(res)
        return sorted(merged, key=lambda x: x["score"], reverse=True)[:self.config["rag_top_k"] * 3]

    def _rerank(self, query: str, results: List[Dict]) -> List[Dict]:
        if not results:
            return []
        try:
            max_doc_len = self.config["max_seq_length"] - len(self.reranker_tokenizer.tokenize(query)) - 3
            pairs = [[query, res["text"][:max_doc_len]] for res in results]
            
            with torch.no_grad():
                inputs = self.reranker_tokenizer(
                    pairs, padding=True, truncation=True, max_length=self.config["max_seq_length"], return_tensors="pt"
                ).to(self.device)
                scores = torch.sigmoid(self.reranker_model(**inputs).logits.view(-1,)).cpu().tolist()
            
            for res, score in zip(results, scores):
                res["rerank_score"] = score
            return sorted(results, key=lambda x: x["rerank_score"], reverse=True)[:self.config["rag_top_k"]]
        except Exception as e:
            logger.error(f"重排序失败: {e}")
            return sorted(results, key=lambda x: x["score"], reverse=True)[:self.config["rag_top_k"]]

    def search(self, query: str) -> List[Dict]:
        logger.info(f"开始RAG检索，查询: {query[:50]}...")
        start = datetime.now()
        vector_res = self._vector_search(query)
        keyword_res = self._keyword_search(query)
        merged_res = self._merge_deduplicate(vector_res, keyword_res)
        final_res = self._rerank(query, merged_res)
        
        # 补充检索耗时（每个文档都带，便于后续展示）
        retrieval_time = (datetime.now() - start).total_seconds()
        for res in final_res:
            res["retrieval_time"] = retrieval_time
        logger.info(f"RAG检索完成，获取 {len(final_res)} 条相关结果，耗时 {retrieval_time:.2f} 秒")
        return final_res

    def close(self):
        if self.milvus_client:
            self.milvus_client.close()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info("RAG检索系统资源已释放")

# -------------------------- 3. 新增：文档信息格式化（重点：300字内容+完整路径） --------------------------
def format_detailed_documents(rag_results: List[Dict], preview_length: int = 300) -> str:
    """
    格式化检索文档详情：包含完整路径（相对+绝对）、检索类型、分数、300字内容
    Args:
        rag_results: RAG检索结果列表
        preview_length: 文档内容预览长度（默认300字）
    Returns:
        结构化的文档详情字符串
    """
    if not rag_results:
        return "【未检索到相关文档】\n"
    
    detailed_str = "### 详细检索文档（共{}条）\n".format(len(rag_results))
    for idx, doc in enumerate(rag_results, 1):
        # 1. 处理文档内容：固定截取300字，不足则全显，超过加省略号
        doc_text = doc["text"].strip()
        if len(doc_text) <= preview_length:
            preview_text = doc_text
            ellipsis = ""
        else:
            # 截取前preview_length字，避免截断在中间（简单处理：以中文句末符号结束）
            preview_text = doc_text[:preview_length]
            # 若最后一个字符不是句末符号，找最近的句末符号截断
            end_symbols = ["。", "！", "？", "；", "】", ")", "}"]
            last_symbol_idx = max([preview_text.rfind(s) for s in end_symbols if s in preview_text], default=-1)
            if last_symbol_idx != -1 and last_symbol_idx > preview_length * 0.7:  # 确保保留70%以上内容
                preview_text = preview_text[:last_symbol_idx + 1]
            ellipsis = "..."
        
        # 2. 拼接文档详情（路径分相对+绝对，便于用户定位文件）
        detailed_str += f"""
#### 文档{idx}
- **来源路径（相对路径）**: {doc["source"]}
- **完整路径（绝对路径）**: {doc["full_source_path"]}
- **检索类型**: {doc["search_type"]}
- **相关性分数**: {doc.get("rerank_score", doc["score"]):.4f}（分数越高越相关）
- **文档时间戳**: {doc["timestamp"] if doc["timestamp"] else "未记录"}
- **300字内容预览**: {preview_text}{ellipsis}

{"-"*50}  # 分隔线，区分不同文档
"""
    return detailed_str

# -------------------------- 4. 大模型调用与结果整合（新增文档详情展示） --------------------------
def format_rag_for_llm(rag_results: List[Dict]) -> str:
    """给大模型的参考文档（简洁版，不影响回答生成）"""
    if not rag_results:
        return "【无相关参考文档】"
    llm_ref = "【相关参考文档】\n"
    for i, res in enumerate(rag_results, 1):
        preview = res["text"]
        llm_ref += f"{i}. 来源：{res['source']} | 内容：{preview}\n"
    return llm_ref

def call_deepseek_v31(prompt: str, system_prompt: str) -> str:
    """调用DeepSeek-V3.1大模型生成回答"""
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {LLM_CONFIG['api_key']}"
    }
    payload = {
        "model": LLM_CONFIG["model_name"],
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt}
        ],
        "max_tokens": LLM_CONFIG["max_tokens"],
        "temperature": LLM_CONFIG["temperature"],
        "stream": False
    }

    for retry in range(LLM_CONFIG["max_retries"]):
        try:
            response = requests.post(LLM_CONFIG["api_url"], json=payload, headers=headers, timeout=60)
            response.raise_for_status()
            result = response.json()
            if "choices" in result and result["choices"]:
                return result["choices"][0]["message"]["content"].strip()
            else:
                raise ValueError(f"大模型返回格式异常: {result}")
        except Exception as e:
            logger.error(f"大模型调用失败（第{retry+1}/{LLM_CONFIG['max_retries']}次）: {e}")
            if retry < LLM_CONFIG["max_retries"] - 1:
                time.sleep(LLM_CONFIG["retry_delay"])
    return "抱歉，大模型调用多次失败，请稍后重试。"

def rag_qa_pipeline(query: str, rag_searcher: HybridRAGSearcher) -> str:
    """端到端RAG问答流程：检索→生成回答→附加详细文档"""
    # 1. RAG检索获取结果
    rag_results = rag_searcher.search(query)
    # 2. 生成大模型所需的简洁参考
    llm_ref = format_rag_for_llm(rag_results)
    # 3. 生成用户所需的详细文档详情（300字+路径）
    detailed_docs = format_detailed_documents(
        rag_results, 
        preview_length=RAG_CONFIG["doc_preview_length"]
    )

    # 4. 大模型Prompt（引导基于参考回答）
    system_prompt = """你是基于检索增强（RAG）的专业问答助手，严格遵循：
1. 必须优先使用【相关参考文档摘要】中的信息回答，每个结论需标注对应文档编号（如【参考1】）；
2. 若参考文档信息不足，补充知识时需标注“注：以下内容基于模型自身知识补充”；
3. 回答逻辑清晰，分点说明（适用时），不编造信息，无法回答则直接说明。"""
    
    user_prompt = f"""用户问题：{query}

{llm_ref}

请基于上述参考文档，回答用户问题。"""

    # 5. 调用大模型生成回答
    logger.info("调用DeepSeek-V3.1生成回答...")
    answer = call_deepseek_v31(user_prompt, system_prompt)

    # 6. 整合最终输出（回答 + 详细文档）
    final_output = f"""
# RAG问答结果
## 一、用户问题
{query}

## 二、AI回答
{answer}

## 三、检索统计信息
- 检索到相关文档数量：{len(rag_results)} 条
- 总检索耗时：{rag_results[0]["retrieval_time"]:.2f} 秒（含向量检索、关键词检索、重排序）
- 检索类型：向量检索（语义匹配）+ 关键词检索（字面匹配）

## 四、{detailed_docs}  # 插入详细文档（300字+路径）

> 注：详细文档中的“完整路径（绝对路径）”可直接复制到文件管理器打开，查看文档全文。
"""
    return final_output

# -------------------------- 5. 主函数（交互入口） --------------------------
def main():
    rag_searcher = None
    try:
        rag_searcher = HybridRAGSearcher(RAG_CONFIG)
        print("="*100)
        print("===== 带详细文档展示的RAG问答系统（基于DeepSeek-V3.1） =====")
        print("说明：输入问题后，将返回AI回答+检索文档详情（含300字内容+文件路径）")
        print("输入'退出'或'quit'可结束程序")
        print("="*100)
        
        while True:
            user_query = input("\n请输入你的问题：").strip()
            if user_query.lower() in ["退出", "quit", "exit"]:
                print("\n感谢使用，程序已退出！")
                break
            if not user_query:
                print("请输入有效的问题，不能为空！")
                continue
            
            # 执行问答流程并打印结果
            print("\n" + "="*50)
            print(f"正在处理问题：{user_query}")
            print("步骤1/2：RAG检索相关文档...")
            print("步骤2/2：调用大模型生成回答...")
            print("="*50)
            
            final_result = rag_qa_pipeline(user_query, rag_searcher)
            print("\n" + "="*100)
            print("最终结果：")
            print(final_result)
            print("="*100)
            
    except Exception as e:
        logger.error(f"系统初始化失败: {e}")
        print(f"\n错误：系统初始化失败，原因：{str(e)}")
        print("请检查Milvus/Elasticsearch服务是否启动，或配置参数是否正确。")
    finally:
        if rag_searcher:
            rag_searcher.close()

if __name__ == "__main__":
    main()

2025-09-19 11:38:51,405 - INFO - RAG检索系统使用设备: cuda
2025-09-19 11:38:52,584 - INFO - BCE嵌入模型加载成功
2025-09-19 11:38:53,817 - INFO - BCE重排序模型加载成功
2025-09-19 11:38:53,822 - INFO - Milvus连接成功
2025-09-19 11:38:53,870 - INFO - HEAD https://localhost:9200/ [status:200 duration:0.027s]
2025-09-19 11:38:53,871 - INFO - Elasticsearch连接成功
2025-09-19 11:38:53,871 - INFO - 混合RAG检索系统初始化完成


===== 带详细文档展示的RAG问答系统（基于DeepSeek-V3.1） =====
说明：输入问题后，将返回AI回答+检索文档详情（含300字内容+文件路径）
输入'退出'或'quit'可结束程序



请输入你的问题： 机器学习


2025-09-19 11:38:57,686 - INFO - 开始RAG检索，查询: 机器学习...
2025-09-19 11:38:57,712 - INFO - POST https://localhost:9200/chunk_documents/_search [status:200 duration:0.008s]
2025-09-19 11:38:57,747 - INFO - RAG检索完成，获取 3 条相关结果，耗时 0.06 秒
2025-09-19 11:38:57,748 - INFO - 调用DeepSeek-V3.1生成回答...



正在处理问题：机器学习
步骤1/2：RAG检索相关文档...
步骤2/2：调用大模型生成回答...

最终结果：

# RAG问答结果
## 一、用户问题
机器学习

## 二、AI回答
根据提供的参考文档，以下是关于机器学习的回答：

### 1. **机器学习的定义**
机器学习（Machine Learning, ML）是人工智能（AI）的核心分支之一，通过设计算法使计算机从数据中自动学习规律、优化模型，并在未明确编程的情况下完成预测、分类或决策等任务。其核心目标是让计算机像人类一样从经验中学习，性能随数据量和训练次数的增加而提升【参考1】。

### 2. **机器学习的主要分类**
机器学习根据数据是否包含标签以及学习方式的差异，可分为三大类（具体类别未在文档中详细列出，但提到了分类依据）【参考3】。

### 3. **应用与意义**
机器学习以“数据驱动”的方式实现了从“规则编程”到“自主学习”的突破，在计算机视觉、自然语言处理、金融、医疗等领域有深远影响。尽管面临数据隐私和模型可解释性等挑战，它仍在提升智能化水平、解决复杂问题和推动产业升级中发挥重要作用，是未来数字社会的核心基础设施之一【参考2】。

注：以上回答严格基于提供的参考文档，未补充额外信息。文档中未涉及机器学习的具体算法或分类细节，因此仅概括性说明。如需更详细内容，建议提供更多相关文档。

## 三、检索统计信息
- 检索到相关文档数量：3 条
- 总检索耗时：0.06 秒（含向量检索、关键词检索、重排序）
- 检索类型：向量检索（语义匹配）+ 关键词检索（字面匹配）

## 四、### 详细检索文档（共3条）

#### 文档1
- **来源路径（相对路径）**: test_ml_chunks/prod_test_ml_node_3_chunk_0.txt
- **完整路径（绝对路径）**: /root/doc_processor/test_ml_chunks/prod_test_ml_node_3_chunk_0.txt
- **检索类型**: 向量检索
- **相关性分数**: 0.5348（分数越高越相关）
- **文档时间戳**: 2025-09-19T09:26:50.183964
- **300字内容预览**: Chunk path: ./chunks_output/te


请输入你的问题： 退出


2025-09-19 11:39:44,521 - INFO - RAG检索系统资源已释放



感谢使用，程序已退出！
