In [None]:
!pip install llama_index transformers unstructured pymilvus
!pip install llama-index-core
!pip install llama-index-extractors-entity
!pip install llama-index-vector-stores-milvus
!pip install llama-index-embeddings-huggingface
!pip install llama-index-llms-huggingface
!pip install llama-index-llms-dashscope
!pip install llama-index-extractors
!pip install pymilvus[milvus_lite]
!pip install unstructured[docx]
!pip install unstructured[doc]
!pip install unstructured[txt]
!pip install unstructured[md]
!pip install fitz frontend tools
!pip uninstall fitz pymupdf -y
!pip install pymupdf
!pip install -r requirements.txt

In [1]:
from llama_index.core import (VectorStoreIndex, SimpleDirectoryReader, load_index_from_storage
    , Document, Settings, StorageContext, PromptTemplate)
from llama_index.vector_stores.milvus import MilvusVectorStore
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.core.extractors import KeywordExtractor, SummaryExtractor
from llama_index.core.schema import MetadataMode
from llama_index.core.node_parser import SentenceSplitter
from llama_index.llms.dashscope import DashScope
from llama_index.llms.openai import OpenAI

from llama_index.extractors.entity import EntityExtractor
from llama_index.readers.file import UnstructuredReader,PyMuPDFReader,PDFReader

from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig

import os, re, asyncio
from tqdm.asyncio import tqdm_asyncio
from tqdm import tqdm
import json



In [2]:
#!python pdf2md.py

In [2]:
# Local filesystem path of the embedding model checkpoint.
embedding_model = "/root/autodl-tmp/Qwen3-Embedding-0.6B"
# Register the embedding model on the global Settings object; later cells read it
# back as Settings.embed_model when building and loading the index.
Settings.embed_model = HuggingFaceEmbedding(
    model_name=embedding_model,
    cache_folder=None,
    trust_remote_code=True,
    local_files_only=True
)

# Read the model config to get its hidden size; `dimension` is reused by later
# cells as the vector dimension recorded for / passed to the Milvus store.
config = AutoConfig.from_pretrained(embedding_model, trust_remote_code=True, local_files_only=True)
dimension = config.hidden_size
log(f"模型嵌入维度: {dimension}")

2025-10-16 14:08:04,518 - INFO - Load pretrained SentenceTransformer: /root/autodl-tmp/Qwen3-Embedding-0.6B
2025-10-16 14:08:05,591 - INFO - 1 prompt is loaded, with the key: query


模型嵌入维度: 1024


In [6]:
from llama_index.core.llms import (
    CustomLLM,
    CompletionResponse,
    LLMMetadata,
)
from llama_index.core.llms.callbacks import llm_completion_callback
from llama_index.core import Settings
from typing import Any
import requests
from datetime import datetime
import os


class SiliconFlowLLM(CustomLLM):
    """Custom LlamaIndex LLM backed by the SiliconFlow chat-completions HTTP API.

    ``complete`` sends the prompt as a single user message to
    ``{api_base}/chat/completions`` and returns the first choice's content.
    When ``save_requests`` is true, each prompt/response pair is appended to
    ``{save_dir}/{save_filename}``.

    Attributes:
        model: Model identifier sent in the request body.
        api_key: Bearer token placed in the Authorization header.
        api_base: Base URL of the API (without the ``/chat/completions`` suffix).
        max_tokens: Default ``max_tokens`` for a request.
        temperature: Default sampling temperature for a request.
        timeout: Seconds to wait for the HTTP call before ``requests`` raises a
            timeout error. Without a timeout a stalled connection would block
            the kernel indefinitely.
        save_requests: Whether to append prompts/responses to the log file.
        save_dir: Directory of the log file (created on demand).
        save_filename: File name of the log file.
    """

    model: str = "Qwen/Qwen3-Next-80B-A3B-Instruct"
    api_key: str = ""
    api_base: str = "https://api.siliconflow.cn/v1"
    max_tokens: int = 4096
    temperature: float = 0.1
    # Generous default: with stream=False the server answers only once the whole
    # completion has been generated.
    timeout: float = 300.0

    # Request-logging switch and destination.
    save_requests: bool = True  # enabled by default; set to False to disable
    save_dir: str = "llm_requests"  # log directory
    save_filename: str = "requests_log.txt"  # log file name

    @property
    def metadata(self) -> LLMMetadata:
        """Return the LLM metadata reported to LlamaIndex."""
        return LLMMetadata(
            context_window=32768,  # adjust to the concrete model in use
            num_output=self.max_tokens,
            model_name=self.model,
        )

    def _save_request(self, prompt: str, response_text: str = None):
        """Append one prompt (and response, if non-empty) to the request log.

        Best effort: any failure while writing is logged and swallowed so that
        logging problems never break a completion call.
        """
        if not self.save_requests:
            return

        try:
            # Make sure the log directory exists.
            os.makedirs(self.save_dir, exist_ok=True)

            filepath = os.path.join(self.save_dir, self.save_filename)

            # Build the log entry.
            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            separator = "=" * 80

            content = f"\n{separator}\n"
            content += f"时间: {timestamp}\n"
            content += f"模型: {self.model}\n"
            content += f"{separator}\n"
            content += f"【请求内容】\n{prompt}\n"

            if response_text:
                content += f"\n【响应内容】\n{response_text}\n"

            content += f"{separator}\n"

            # Append to the log file.
            with open(filepath, 'a', encoding='utf-8') as f:
                f.write(content)

            log(f"✓ 请求已保存到: {filepath}")

        except Exception as e:
            log(f"✗ 保存请求失败: {str(e)}")

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        """Send ``prompt`` as a single user message and return the completion.

        Args:
            prompt: Text sent as the user message.
            **kwargs: Optional per-call overrides ``max_tokens``,
                ``temperature`` and ``timeout``.

        Returns:
            A ``CompletionResponse`` holding the first choice's message content.

        Raises:
            requests.HTTPError: If the API answers with an error status.
            requests.Timeout: If the API does not answer within ``timeout``.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

        data = {
            "model": self.model,
            "messages": [
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
            "temperature": kwargs.get("temperature", self.temperature),
            "stream": False
        }

        response = requests.post(
            f"{self.api_base}/chat/completions",
            headers=headers,
            json=data,
            # Bound the wait so a dead connection cannot hang the notebook.
            timeout=kwargs.get("timeout", self.timeout),
        )

        response.raise_for_status()
        result = response.json()

        response_text = result["choices"][0]["message"]["content"]

        # Record the prompt and the response.
        self._save_request(prompt, response_text)

        return CompletionResponse(
            text=response_text
        )

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any):
        """Pseudo-streaming: run the blocking ``complete`` and yield its single result."""
        response = self.complete(prompt, **kwargs)
        yield response

# Usage example
if __name__ == "__main__":
    # Never hard-code credentials in a notebook: read the key from the environment.
    # Set it before starting the kernel, e.g. `export SILICONFLOW_API_KEY=...`.
    siliconflow_api_key = os.environ.get("SILICONFLOW_API_KEY", "")
    if not siliconflow_api_key:
        raise RuntimeError(
            "SILICONFLOW_API_KEY environment variable is not set; "
            "export it before running this cell."
        )

    # 1. Create the custom LLM instance.
    llm = SiliconFlowLLM(
        model="Qwen/Qwen3-30B-A3B-Instruct-2507",  # other models can be used here
        api_key=siliconflow_api_key,
        api_base="https://api.siliconflow.cn/v1",
        max_tokens=1024,
        temperature=0.3,
        # Request logging controls
        save_requests=True,  # set to False to disable logging
        save_dir="llm_requests",  # log directory
        save_filename="requests_log.txt"  # log file name
    )

    # 2. Register it as the global LLM.
    Settings.llm = llm

    # 3. Smoke test.
    response = llm.complete("你好，请介绍一下你自己")
    log(response.text)

    # 4. To switch request logging off temporarily:
    # llm.save_requests = False

    # 5. Test again (nothing is logged):
    # response = llm.complete("再问一个问题")
    # log(response.text)


✓ 请求已保存到: llm_requests/requests_log.txt
你好！我是通义千问（Qwen），是阿里巴巴集团旗下的通义实验室自主研发的超大规模语言模型。我能够回答问题、创作文字，比如写故事、写公文、写邮件、写剧本、逻辑推理、编程等等，还能表达观点，玩游戏等。如果你有任何问题或需要帮助，欢迎随时告诉我！


In [7]:
# Location of the Milvus Lite database file used by this notebook.
milvus_dir = "./milvus_test"
milvus_db_path = os.path.join(milvus_dir, "milvus_lite.db")
abs_db_path = os.path.abspath(milvus_db_path)
log(f"绝对数据库路径: {abs_db_path}")

if not os.path.exists(milvus_dir):
    os.makedirs(milvus_dir)
    # Report the directory that was actually created (previously logged "./milvus").
    log(f"已创建 {milvus_dir} 目录")



# milvus_vector_store = MilvusVectorStore(
#     uri=f"{abs_db_path}",
#     collection_name="rag_collection",
#     dim=1024,
#     overwrite=True
# )
# storage_context = StorageContext.from_defaults(vector_store=milvus_vector_store)

绝对数据库路径: /root/marathon_rag/milvus_test/milvus_lite.db
已创建 ./milvus 目录


### 首次运行

In [8]:
# First run: create the Milvus Lite database and an empty collection.
milvus_dir = "./milvus_test"
milvus_db_path = os.path.join(milvus_dir, "milvus_lite.db")
abs_db_path = os.path.abspath(milvus_db_path)
log(f"绝对数据库路径: {abs_db_path}")

if not os.path.exists(milvus_dir):
    os.makedirs(milvus_dir)
    # Report the directory that was actually created (previously logged "./milvus").
    log(f"已创建 {milvus_dir} 目录")



milvus_vector_store = MilvusVectorStore(
    uri=f"{abs_db_path}",
    collection_name="rag_collection",
    # Use the dimension read from the embedding model's config instead of a
    # hard-coded 1024, so the collection stays in sync with the model and with
    # the save/load helpers, which also use `dimension`.
    dim=dimension,
    # NOTE(review): overwrite=True is intended for the first run only; re-running
    # this cell presumably replaces an existing "rag_collection" — confirm.
    overwrite=True
)
storage_context = StorageContext.from_defaults(vector_store=milvus_vector_store)

绝对数据库路径: /root/marathon_rag/milvus_test/milvus_lite.db


  from pkg_resources import DistributionNotFound, get_distribution


In [9]:
def clean_text(text: str) -> str:
    """Collapse each run of blank lines into a single blank line and trim the ends.

    Args:
        text: Raw document text.

    Returns:
        The text with every newline / optional-whitespace / newline(s) sequence
        replaced by exactly two newlines, with leading and trailing whitespace
        removed.
    """
    collapsed = re.sub(r'\n\s*\n+', '\n\n', text)
    # Further filters kept disabled, as in the original notebook:
    # text = re.sub(r'(\w+\s*){3,}\n', '', text)
    # text = re.sub(r'[^a-zA-Z0-9\u4e00-\u9fa5\s\.,!?]', '', text)  # strip special characters, keep Chinese/English
    return collapsed.strip()


In [10]:
async def generate_summary_async(text, max_words=30):
    """Ask the globally configured LLM (asynchronously) for a short summary of ``text``.

    Args:
        text: Text to summarise.
        max_words: Length limit that is interpolated into the prompt.

    Returns:
        The LLM's reply with surrounding whitespace stripped.
    """
    completion = await Settings.llm.acomplete(
        f"总结以下文本，不超过{max_words}字，直接回复结果：{text}"
    )
    summary = completion.text
    return summary.strip()

def generate_summary(text, max_words=30):
    """Ask the globally configured LLM (blocking call) for a short summary of ``text``.

    Args:
        text: Text to summarise.
        max_words: Length limit that is interpolated into the prompt.

    Returns:
        The LLM's reply with surrounding whitespace stripped.
    """
    completion = Settings.llm.complete(
        f"总结以下文本，不超过{max_words}字，直接回复结果：{text}"
    )
    summary = completion.text
    return summary.strip()

async def add_summaries_to_nodes_async(nodes_list):
    """Summarise all nodes concurrently and store each summary on its own node.

    The summaries are awaited with ``tqdm_asyncio.gather``, which returns the
    results in the same order as the submitted tasks. The previous version
    collected results with ``as_completed`` (completion order) and then zipped
    them with the nodes (submission order), which attached summaries to the
    wrong nodes whenever requests finished out of order.

    Args:
        nodes_list: Nodes exposing ``.text`` and a mutable ``.metadata`` dict.
            Each node's ``metadata["node_summary"]`` is set in place.
    """
    if not nodes_list:
        return

    tasks = [generate_summary_async(node.text) for node in nodes_list]

    # gather() keeps the progress bar but preserves task order in its result list.
    summaries = await tqdm_asyncio.gather(
        *tasks, total=len(tasks), desc="生成节点摘要进度"
    )

    for node, summary in zip(nodes_list, summaries):
        node.metadata["node_summary"] = summary
        
def add_summaries_to_nodes(nodes_list):
    """Summarise the nodes one after another and store each summary on its node.

    Args:
        nodes_list: Nodes exposing ``.text`` and a mutable ``.metadata`` dict.
            Each node's ``metadata["node_summary"]`` is set in place.
    """
    progress = tqdm(nodes_list, desc="生成摘要")
    for current in progress:
        current.metadata["node_summary"] = generate_summary(current.text)

In [15]:
# Tokenizer loaded from the local Qwen3-Reranker-4B checkpoint; its `tokenize` method is
# passed to SentenceSplitter in a later cell.
qwen_tokenizer = AutoTokenizer.from_pretrained("/root/autodl-tmp/Qwen3-Reranker-4B", trust_remote_code=True)
# NOTE(review): this value is reassigned to "./data" in a later cell before documents are loaded.
documents_dir = "./docs"

# Reader used per file extension; every listed type goes through UnstructuredReader.
# Files with any other extension are ignored by the loading helpers below.
file_extractor = {
    ".docx": UnstructuredReader(),
    ".doc": UnstructuredReader(),
    ".txt": UnstructuredReader(),
    ".md": UnstructuredReader(),
}


In [16]:
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

def load_single_file(file_path, file_extractor):
    """Load one file with the reader registered for its (lower-cased) extension.

    Args:
        file_path: Path of the file to load.
        file_extractor: Mapping from extension (e.g. ``".md"``) to a reader
            object exposing ``load_data(path)``.

    Returns:
        Whatever the reader's ``load_data`` returns, or an empty list when the
        extension has no reader or loading raised an exception (the failure is
        logged, not re-raised).
    """
    try:
        suffix = Path(file_path).suffix.lower()
        if suffix not in file_extractor:
            return []
        reader = file_extractor[suffix]
        log('loading:',file_path)
        return reader.load_data(file_path)
    except Exception as e:
        log(f"加载文件 {file_path} 失败: {e}")
        return []

def load_documents_parallel(documents_dir, file_extractor, max_workers=4):
    """Load every supported file under ``documents_dir`` using a thread pool.

    Args:
        documents_dir: Root directory, searched recursively.
        file_extractor: Mapping from extension to reader; its keys decide which
            files are picked up.
        max_workers: Size of the thread pool.

    Returns:
        A flat list of all loaded documents, appended in the order in which the
        individual file loads complete.
    """
    root = Path(documents_dir)
    candidate_paths = [
        path
        for suffix in file_extractor.keys()
        for path in root.rglob(f"*{suffix}")
    ]

    loaded = []
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = {
            pool.submit(load_single_file, str(path), file_extractor): path
            for path in candidate_paths
        }

        progress = tqdm(as_completed(pending), total=len(pending), desc="加载文件")
        for finished in progress:
            loaded.extend(finished.result())

    return loaded

In [17]:
def preprocess_long_documents(documents, max_length=100000, overlap=0):
    """Pre-split very long documents so the tokenizer never sees an oversized text.

    Documents whose text is at most ``max_length`` characters are passed through
    unchanged. Longer ones are cut into fixed-size character windows that start
    every ``max_length - overlap`` characters, so consecutive windows share
    ``overlap`` characters.

    Args:
        documents: Iterable of documents exposing ``.text`` and ``.metadata``.
        max_length: Maximum number of characters per chunk; must be positive.
        overlap: Number of characters shared by consecutive chunks; must satisfy
            ``0 <= overlap < max_length``.

    Returns:
        A new list containing the untouched short documents and, for each long
        document, its chunks as new ``Document`` objects whose metadata is a
        copy of the source metadata plus ``chunk_index``, ``total_chunks`` and
        ``is_chunked``.

    Raises:
        ValueError: If ``max_length`` is not positive or ``overlap`` is outside
            ``[0, max_length)``. Previously such values caused a division by
            zero or a loop that never terminated.
    """
    if max_length <= 0:
        raise ValueError(f"max_length must be > 0, got {max_length}")
    if overlap < 0 or overlap >= max_length:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < max_length, got overlap={overlap}, max_length={max_length}"
        )

    # Distance between the start positions of two consecutive chunks.
    step = max_length - overlap

    processed_docs = []
    for doc in documents:
        text_length = len(doc.text)

        # Short enough: keep the original document object.
        if text_length <= max_length:
            processed_docs.append(doc)
            continue

        log(f"检测到超长文档: {text_length} 字符，进行预切分")

        # Number of windows produced below (ceil(text_length / step)); computed
        # once per document instead of once per chunk.
        total_chunks = (text_length + step - 1) // step

        chunks = []
        for chunk_index, start in enumerate(range(0, text_length, step)):
            # Copy the metadata so every chunk gets its own dict with chunk info.
            new_metadata = doc.metadata.copy() if doc.metadata else {}
            new_metadata['chunk_index'] = chunk_index
            new_metadata['total_chunks'] = total_chunks
            new_metadata['is_chunked'] = True

            # Slicing clamps at the end of the text, so the last chunk may be shorter.
            chunk_text = doc.text[start:start + max_length]
            chunks.append(Document(text=chunk_text, metadata=new_metadata))

        processed_docs.extend(chunks)
        log(f"  切分为 {len(chunks)} 个块，每块最大 {max_length} 字符，重叠 {overlap} 字符")

    return processed_docs

In [18]:
# Usage: load every supported file under ./data, clean the text, then split into nodes.
documents_dir=  "./data"
documents = load_documents_parallel(documents_dir, file_extractor, max_workers=1)

# Rebuild each document with cleaned text while keeping its original metadata.
cleaned_documents = [Document(text=clean_text(doc.text), metadata=doc.metadata) 
                     for doc in documents]

# Optional step (currently disabled): pre-split very long documents.
# NOTE(review): the original note mentioned an overlap of 1000, but the call passes overlap=0.
# cleaned_documents = preprocess_long_documents(
#     cleaned_documents, 
#     max_length=100000, 
#     overlap=0
# )
documents = cleaned_documents

# Logs the number of loaded documents (not a size in bytes).
log(f"文件大小:{len(documents)}")

# Chunk sizes are measured with the Qwen tokenizer's `tokenize` function.
node_parser = SentenceSplitter(chunk_size=1024, chunk_overlap=100, tokenizer=qwen_tokenizer.tokenize)  
nodes = node_parser.get_nodes_from_documents(documents)
log(f"节点数量:{len(nodes)}")

loading: data/银联“云闪付”业务管理办法.docx


加载文件:  17%|█▋        | 2/12 [00:02<00:10,  1.09s/it]

loading: data/银联“云闪付”线下使用的常见问题解答.docx
loading: data/中国银联银行卡联网联合技术规范.doc


加载文件:  58%|█████▊    | 7/12 [00:02<00:01,  4.80it/s]

加载文件 data/中国银联银行卡联网联合技术规范.doc 失败: soffice command was not found. Please install libreoffice
on your system and try again.

- Install instructions: https://www.libreoffice.org/get-help/install-howto/
- Mac: https://formulae.brew.sh/cask/libreoffice
- Debian: https://wiki.debian.org/LibreOffice
loading: data/2025年第三季度新能源汽车销量公布，比亚迪继续领跑.txt
loading: data/AI手机出货量预计突破4亿部，端侧大模型成为下一代智能终端竞争核心.txt
loading: data/多地加码楼市优化政策，全力支持刚需与改善性需求，市场信心逐步修复.txt
loading: data/我国发布《人工智能伦理治理指南》，为企业研发划清“红线”与“护栏”.txt
loading: data/中国银联全渠道商户服务操作手册.md




loading: data/中国银联商户服务平台用户操作手册(机构版).md
loading: data/银联收单扣率分类与标准.md
loading: data/银联清算业务体系.md


加载文件:  92%|█████████▏| 11/12 [00:03<00:00,  5.66it/s]

loading: data/银行云闪付一键查卡业务说明.md


加载文件: 100%|██████████| 12/12 [00:03<00:00,  3.15it/s]


文件大小:11
节点数量:82


In [None]:
#add_summaries_to_nodes(nodes)

In [None]:
# def save_summaries_to_json(nodes_list, file_path="nodes_summaries_temp.json"):
#     summaries_dict = {}
#     for idx, node in enumerate(nodes_list):
#         summaries_dict[str(idx)] = node.metadata.get("node_summary", "")  # 获取摘要，若无则为空
    
#     # 保存到 JSON
#     with open(file_path, 'w', encoding='utf-8') as f:
#         json.dump(summaries_dict, f, ensure_ascii=False, indent=4)
    
#     log(f"节点摘要已保存到 {file_path}")

# def load_summaries_to_nodes(nodes_list, file_path="nodes_summaries.json"):
#     with open(file_path, 'r', encoding='utf-8') as f:
#         summaries_dict = json.load(f)
#     sorted_keys = sorted(summaries_dict.keys(), key=int)

#     for key in sorted_keys:
#         idx = int(key)
#         if idx < len(nodes_list):
#             nodes_list[idx].metadata["node_summary"] = summaries_dict[key]
#         else:
#             log(f"警告：索引 {idx} 超出节点列表长度，跳过。")
    
#     return nodes_list

In [None]:
# save_summaries_to_json(nodes)

In [21]:
import pickle
# ============ 保存 Nodes ============
def save_nodes(nodes, save_dir="./saved_nodes"):
    """Pickle the complete node list to ``<save_dir>/nodes.pkl``.

    Only the pickle format is written. The target directory is created
    (including parents) when it does not exist yet.

    Args:
        nodes: The node objects to persist.
        save_dir: Directory that receives ``nodes.pkl``.
    """
    target_dir = Path(save_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    # Store the full node objects in a single pickle file.
    pickle_file = target_dir / "nodes.pkl"
    with pickle_file.open('wb') as handle:
        pickle.dump(nodes, handle)
    log(f"Nodes已保存到: {pickle_file}")

def load_nodes(save_dir="./saved_data"):
    """Load the pickled node list from ``<save_dir>/nodes.pkl``.

    NOTE(review): the default directory here ("./saved_data") differs from the
    default used by ``save_nodes`` ("./saved_nodes"); pass ``save_dir``
    explicitly. Unpickling executes code embedded in the file, so only load
    files this notebook produced itself.

    Args:
        save_dir: Directory containing ``nodes.pkl``.

    Returns:
        The unpickled node list.

    Raises:
        FileNotFoundError: If ``nodes.pkl`` does not exist in ``save_dir``.
    """
    pickle_file = Path(save_dir) / "nodes.pkl"

    if not pickle_file.exists():
        raise FileNotFoundError(f"❌ 找不到节点文件: {pickle_file}")

    with pickle_file.open('rb') as handle:
        nodes = pickle.load(handle)

    log(f"✅ 已加载 {len(nodes)} 个节点")

    # Sanity report on what was loaded.
    log(f"📊 节点验证:")
    log(f"  - 总节点数: {len(nodes)}")
    return nodes


In [22]:
# Persist the split nodes so that a later session can reload them without re-parsing the documents.
save_nodes(nodes, save_dir="./saved_nodes")

Nodes已保存到: saved_nodes/nodes.pkl


In [23]:
# Earlier variant without explicit transformations, kept for reference:
# index = VectorStoreIndex.from_documents(
#     documents,
#     storage_context=storage_context,
#     embed_model=Settings.embed_model,
#     node_parser=node_parser,
#     store_nodes_override=True
# )



# Build the vector index from the cleaned documents, using the Milvus-backed storage context.
# NOTE(review): the documents are split again here, independently of the `nodes` list that was
# produced and pickled above — confirm whether the two sets of nodes are expected to match.
# NOTE(review): both node_parser= and transformations= are passed — confirm which one
# from_documents actually honours.
transformations = [node_parser]
index = VectorStoreIndex.from_documents(
    documents,
    storage_context=storage_context,
    embed_model=Settings.embed_model,
    node_parser=node_parser,
    transformations=transformations,
    store_nodes_override=True
)


In [24]:
from llama_index.core import StorageContext, load_index_from_storage
from pathlib import Path
# ============ 保存 Milvus 索引 ============
def save_milvus_index(index, persist_dir="./milvus_storage"):
    """Persist the index's storage context and write an ``index_info.json`` next to it.

    The info file records the collection name, the Milvus database path, the
    embedding dimension, the number of docstore entries and the index type.
    The database path and the dimension are read from the notebook globals
    ``abs_db_path`` and ``dimension``.

    Args:
        index: The index whose ``storage_context`` is persisted.
        persist_dir: Target directory; created (including parents) if missing.
    """
    target_dir = Path(persist_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    # Persist whatever the storage context holds into the target directory.
    index.storage_context.persist(persist_dir=persist_dir)
    log(f"✅ 索引已保存到: {persist_dir}")

    # Descriptive metadata that the loader can read back later.
    index_info = dict(
        collection_name='rag_collection',
        milvus_db_path=abs_db_path,
        embedding_dim=dimension,
        total_documents=len(index.docstore.docs),
        index_type='VectorStoreIndex',
    )

    info_file = target_dir / "index_info.json"
    with info_file.open('w', encoding='utf-8') as handle:
        json.dump(index_info, handle, ensure_ascii=False, indent=2)

    log(f"✅ 索引信息已保存到: {info_file}")
    log(f"📊 索引信息: {index_info}")

# ============ Usage ============
# Save the index
save_milvus_index(index, persist_dir="./milvus_storage")

# Load the index (load_milvus_index is defined in a later cell)
# index = load_milvus_index(persist_dir="./milvus_storage", milvus_db_path=abs_db_path)


✅ 索引已保存到: ./milvus_storage
✅ 索引信息已保存到: milvus_storage/index_info.json
📊 索引信息: {'collection_name': 'rag_collection', 'milvus_db_path': '/root/marathon_rag/milvus_test/milvus_lite.db', 'embedding_dim': 1024, 'total_documents': 82, 'index_type': 'VectorStoreIndex'}


### 非第一次运行 加载持久化运行

In [6]:
import pickle
import json
from pathlib import Path
def load_nodes(save_dir="./saved_data"):
    """Load the pickled node list from ``<save_dir>/nodes.pkl``.

    NOTE(review): this function is also defined in an earlier cell of the
    notebook; the definition executed last wins. Unpickling executes code
    embedded in the file, so only load files this notebook produced itself.

    Args:
        save_dir: Directory containing ``nodes.pkl``.

    Returns:
        The unpickled node list.

    Raises:
        FileNotFoundError: If ``nodes.pkl`` does not exist in ``save_dir``.
    """
    pickle_file = Path(save_dir) / "nodes.pkl"

    if not pickle_file.exists():
        raise FileNotFoundError(f"❌ 找不到节点文件: {pickle_file}")

    with pickle_file.open('rb') as handle:
        nodes = pickle.load(handle)

    log(f"✅ 已加载 {len(nodes)} 个节点")

    # Sanity report on what was loaded.
    log(f"📊 节点验证:")
    log(f"  - 总节点数: {len(nodes)}")
    return nodes

In [7]:
from llama_index.core import StorageContext, load_index_from_storage
from pathlib import Path
# ============ 加载 Milvus 索引 ============
def load_milvus_index(persist_dir="./storage", milvus_db_path=None):
    """Re-open a previously persisted index backed by a Milvus collection.

    Args:
        persist_dir: Directory written by ``save_milvus_index``.
        milvus_db_path: Path/URI of the Milvus database. When omitted, the
            value recorded in ``<persist_dir>/index_info.json`` is used.

    Returns:
        The index loaded via ``load_index_from_storage``.

    Raises:
        FileNotFoundError: If ``persist_dir`` does not exist.
        ValueError: If no database path was passed and none is recorded in
            ``index_info.json``. Previously ``uri=None`` was handed to the
            vector store in that situation.
    """
    persist_path = Path(persist_dir)

    if not persist_path.exists():
        raise FileNotFoundError(f"❌ 找不到索引目录: {persist_dir}")

    # Read the saved index information, if present.
    index_info = {}
    info_file = persist_path / "index_info.json"
    if info_file.exists():
        with open(info_file, 'r', encoding='utf-8') as f:
            index_info = json.load(f)
        log(f"📊 索引信息: {index_info}")
        milvus_db_path = milvus_db_path or index_info.get('milvus_db_path')

    # Fail early with a clear message instead of passing uri=None downstream.
    if not milvus_db_path:
        raise ValueError(
            f"milvus_db_path was not provided and is not recorded in {info_file}"
        )

    # Re-create the vector store on top of the existing collection. The
    # collection name comes from the saved info and falls back to the name
    # that save_milvus_index records.
    milvus_vector_store = MilvusVectorStore(
        uri=milvus_db_path,
        collection_name=index_info.get('collection_name', "rag_collection"),
        dim=dimension,
        overwrite=False  # keep the existing data
    )

    # Re-create the storage context from the persisted files plus the vector store.
    storage_context = StorageContext.from_defaults(
        vector_store=milvus_vector_store,
        persist_dir=persist_dir
    )

    # Load the index itself.
    index = load_index_from_storage(
        storage_context=storage_context,
        embed_model=Settings.embed_model
    )

    log(f"✅ 索引已加载")
    log(f"  - 文档数量: {len(index.docstore.docs)}")

    return index



In [8]:
# Subsequent runs: restore the pickled nodes and re-open the persisted index.
# `abs_db_path` must already be defined by the Milvus setup cell.
nodes = load_nodes(save_dir="./saved_nodes")
index = load_milvus_index(persist_dir="./milvus_storage", milvus_db_path=abs_db_path)
# NOTE(review): load_milvus_index already passes embed_model=Settings.embed_model — confirm this assignment is still needed.
index.embed_model=Settings.embed_model

✅ 已加载 2493 个节点
📊 节点验证:
  - 总节点数: 2493
📊 索引信息: {'collection_name': 'rag_collection', 'milvus_db_path': '/root/milvus_test/milvus_lite.db', 'embedding_dim': 1024, 'total_documents': 2493, 'index_type': 'VectorStoreIndex'}


  from pkg_resources import DistributionNotFound, get_distribution
2025-10-13 17:07:30,504 - INFO - Loading all indices.


Loading llama_index.core.storage.kvstore.simple_kvstore from ./milvus_storage/docstore.json.
Loading llama_index.core.storage.kvstore.simple_kvstore from ./milvus_storage/index_store.json.
✅ 索引已加载
  - 文档数量: 2493


In [25]:
# Quick one-off retrieval to sanity-check the index.
retriever = index.as_retriever(similarity_top_k=3)
nodes_test = retriever.retrieve("腾讯游戏 三角洲行动")


# Print score, source file name and the first 200 characters of each hit.
for i, node in enumerate(nodes_test, 1):
    log(f"\n[{i}] 分数: {node.score:.4f} | 文件: {node.metadata.get('file_name')}")
    log(f"内容: {node.text[:200]}")


[1] 分数: 0.2065 | 文件: None
内容: .............................................................................. 37 7.4 共享关系管理 ...........................................................................................................

[2] 分数: 0.1932 | 文件: None
内容: .......................................................... 19 5.1 交易查询 .................................................................................................................................

[3] 分数: 0.1927 | 文件: None
内容: ............................................................................................................ 40 10.1.3 操作页面 ............................................................................


## search and rerank

In [10]:
# !pip install llama-index-retrievers-bm25
# !pip install llama-index-packs-fusion-retriever

In [26]:
from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core.retrievers import VectorIndexRetriever,QueryFusionRetriever
from llama_index.core import get_response_synthesizer
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.packs.fusion_retriever import HybridFusionRetrieverPack

In [27]:
# 在代码开头添加这些调试函数
import json
from typing import List
from llama_index.core.schema import NodeWithScore

def print_retrieved_nodes(nodes: List[NodeWithScore], title="检索到的节点"):
    """Log a detailed, human-readable dump of retrieved nodes.

    For every node this logs the score, node id, file name, parent/child
    information (when ``is_child_node`` is set in the metadata), a preview of
    the first 200 characters of the text, the text length and the full metadata.

    Args:
        nodes: Retrieved nodes with scores.
        title: Heading printed above the list.
    """
    log(f"\n{'='*80}")
    log(f"{title} (共 {len(nodes)} 个)")
    log(f"{'='*80}")

    for i, node in enumerate(nodes, 1):
        # A node may carry no score; formatting None with ".4f" would raise.
        score_text = f"{node.score:.4f}" if node.score is not None else "N/A"

        log(f"\n[节点 {i}]")
        log(f"  分数: {score_text}")
        log(f"  节点ID: {node.node.node_id}")
        log(f"  文件名: {node.node.metadata.get('file_name', 'N/A')}")

        # For child nodes, also show the parent information.
        if node.node.metadata.get('is_child_node'):
            log(f"  父节点ID: {node.node.metadata.get('parent_node_id')}")
            log(f"  子节点索引: {node.node.metadata.get('chunk_index')}")

        # Preview the first 200 characters. The previous `text[100]` printed a
        # single character and raised IndexError for texts shorter than 101 chars.
        log(f"  内容预览: {node.node.text[:200]}")
        log(len(node.node.text))

        # Full metadata; default=str keeps non-JSON-serialisable values from crashing the dump.
        log(f"  Metadata: {json.dumps(node.node.metadata, ensure_ascii=False, indent=4, default=str)}")

    log(f"\n{'='*80}\n")

def print_prompt_to_llm(query: str, context: str, template_name=""):
    """Log the query and the context that are about to be sent to the LLM.

    Args:
        query: The user query.
        context: The context text that accompanies the query.
        template_name: Optional label appended to the heading.
    """
    banner = '=' * 80
    lines = (
        f"\n{banner}",
        f"发送给LLM的Prompt {template_name}",
        f"{banner}",
        f"\n【用户查询】\n{query}",
        f"\n【上下文信息】\n{context}",
        f"\n{banner}\n",
    )
    for line in lines:
        log(line)


In [28]:
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import NodeWithScore, QueryBundle, MetadataMode
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.callbacks import CBEventType, EventPayload
from typing import List, Optional, Any, Union
from pathlib import Path
import torch
import gc
from transformers import AutoTokenizer, AutoModelForCausalLM

class Qwen3Reranker(BaseNodePostprocessor):
    """
    Qwen3-Reranker for reranking nodes based on relevance to query.

    Each node is paired with the query, wrapped in a chat-style yes/no
    judgement prompt, and scored from the model's logits for the "yes" and
    "no" tokens at the last position. Nodes are then sorted by that score and
    the best ``top_n`` are returned.
    
    Args:
        model (str): Path to the Qwen3-Reranker model.
        top_n (int): Number of nodes to return sorted by score. Defaults to 5.
        device (str, optional): Device (like "cuda", "cpu") for computation. 
            If None, checks if a GPU can be used.
        max_length (int): Maximum sequence length. Defaults to 8192.
        instruction (str, optional): Custom instruction for reranking.
        keep_retrieval_score (bool, optional): Whether to keep the retrieval score 
            in metadata. Defaults to False.
        clear_cache_after_rerank (bool): Whether to clear GPU cache after reranking.
            Defaults to True.
        batch_size (int): Number of query-document pairs to process at once.
            Defaults to 5. Lower values use less memory but may be slower.
    """
    
    model: str = Field(description="Path to Qwen3-Reranker model.")
    top_n: int = Field(default=5, description="Number of nodes to return sorted by score.")
    device: Optional[str] = Field(default=None, description="Device for computation.")
    max_length: int = Field(default=8192, description="Maximum sequence length.")
    instruction: Optional[str] = Field(
        default=None, 
        description="Custom instruction for reranking."
    )
    keep_retrieval_score: bool = Field(
        default=False,
        description="Whether to keep the retrieval score in metadata.",
    )
    clear_cache_after_rerank: bool = Field(
        default=True,
        description="Whether to clear GPU cache after reranking.",
    )
    batch_size: int = Field(
        default=5,
        description="Number of query-document pairs to process at once.",
    )
    
    # Private attributes; all of them are populated in __init__.
    _tokenizer: Any = PrivateAttr()
    _model: Any = PrivateAttr()
    _device: str = PrivateAttr()
    _token_false_id: int = PrivateAttr()
    _token_true_id: int = PrivateAttr()
    _prefix: str = PrivateAttr()
    _suffix: str = PrivateAttr()
    _prefix_tokens: List[int] = PrivateAttr()
    _suffix_tokens: List[int] = PrivateAttr()
    
    def __init__(
        self,
        model: str,
        top_n: int = 5,
        device: Optional[str] = None,
        max_length: int = 8192,
        instruction: Optional[str] = None,
        keep_retrieval_score: bool = False,
        clear_cache_after_rerank: bool = True,
        batch_size: int = 5,
        **kwargs
    ):
        """Validate the settings, load tokenizer and model, and pre-tokenize the prompt frame.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        # Initialise the parent first, passing along every Field attribute.
        super().__init__(
            model=model,
            top_n=top_n,
            device=device,
            max_length=max_length,
            instruction=instruction,
            keep_retrieval_score=keep_retrieval_score,
            clear_cache_after_rerank=clear_cache_after_rerank,
            batch_size=batch_size,
            **kwargs
        )
        
        # Validate batch_size.
        if self.batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {self.batch_size}")
        
        # Fall back to the default instruction.
        if self.instruction is None:
            self.instruction = "Given a web search query, retrieve relevant passages that answer the query"
        
        # Infer the device when none was given.
        if self.device is None:
            self._device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self._device = self.device
        
        # Load the tokenizer. Left padding keeps each sequence's final token at
        # position -1, which is where the scoring code reads the logits.
        self._tokenizer = AutoTokenizer.from_pretrained(
            self.model, 
            padding_side='left'
        )
        
        # Load the model: try flash attention first, otherwise fall back.
        # NOTE(review): any exception in the first attempt (not only a missing
        # flash-attention install) triggers the fallback, and the two paths differ in
        # dtype: the first passes no torch_dtype, the fallback uses float16.
        try:
            self._model = AutoModelForCausalLM.from_pretrained(
                self.model,
                # torch_dtype=torch.float16,  # half precision (left disabled)
                attn_implementation="flash_attention_2"
            ).to(self._device).eval()
            log("✓ Using flash_attention_2")
        except Exception as e:
            log(f"⚠ Flash attention not available, using default: {e}")
            self._model = AutoModelForCausalLM.from_pretrained(
                self.model,
                torch_dtype=torch.float16,
            ).to(self._device).eval()
        
        # Token ids of "yes" / "no", used to read the judgement from the logits.
        self._token_false_id = self._tokenizer.convert_tokens_to_ids("no")
        self._token_true_id = self._tokenizer.convert_tokens_to_ids("yes")
        
        # Fixed prompt frame placed around every query-document pair.
        self._prefix = (
            "<|im_start|>system\n"
            "Judge whether the Document meets the requirements based on the Query and the Instruct provided. "
            "Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n"
            "<|im_start|>user\n"
        )
        self._suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
        # Tokenize the frame once; it is concatenated to every input in _process_inputs.
        self._prefix_tokens = self._tokenizer.encode(self._prefix, add_special_tokens=False)
        self._suffix_tokens = self._tokenizer.encode(self._suffix, add_special_tokens=False)
    
    @classmethod
    def class_name(cls) -> str:
        """Return the class name (used for serialization)."""
        return "Qwen3Reranker"
    
    def _clear_gpu_cache(self):
        """Release cached GPU memory; does nothing unless the device name starts with "cuda"."""
        if self._device.startswith("cuda"):
            # Clear the model's own cache if it exposes a `clear_cache` method.
            if hasattr(self._model, 'clear_cache'):
                self._model.clear_cache()
            
            # Empty the CUDA caching allocator.
            torch.cuda.empty_cache()
            
            # Force a garbage-collection pass.
            gc.collect()
            
            # Wait for outstanding CUDA work to finish.
            torch.cuda.synchronize()
    
    def _format_instruction(self, query: str, doc: str) -> str:
        """Build the text block combining instruction, query and document."""
        return f"<Instruct>: {self.instruction}\n<Query>: {query}\n<Document>: {doc}"
    
    def _process_inputs(self, pairs: List[str]):
        """Tokenize the formatted pairs, add the prompt frame, pad, and move to the device.

        Args:
            pairs: Formatted query-document strings.

        Returns:
            The padded tokenizer output with every tensor on ``self._device``.
        """
        # Truncate each pair so that prefix + pair + suffix stays within max_length.
        inputs = self._tokenizer(
            pairs, 
            padding=False, 
            truncation='longest_first',
            return_attention_mask=False, 
            max_length=self.max_length - len(self._prefix_tokens) - len(self._suffix_tokens)
        )
        
        # Wrap every token sequence in the prefix and suffix tokens.
        for i, ele in enumerate(inputs['input_ids']):
            inputs['input_ids'][i] = self._prefix_tokens + ele + self._suffix_tokens
        
        # Pad the batch and convert it to PyTorch tensors.
        inputs = self._tokenizer.pad(
            inputs, 
            padding=True, 
            return_tensors="pt", 
            max_length=self.max_length
        )
        
        # Move every tensor to the target device.
        for key in inputs:
            inputs[key] = inputs[key].to(self._device)
        
        return inputs
    
    @torch.no_grad()
    def _compute_scores_batch(self, pairs: List[str]) -> List[float]:
        """
        Compute the relevance scores for one batch of pairs.
        
        Args:
            pairs: Formatted query-document strings.
            
        Returns:
            One score per pair: the softmax weight of "yes" over the two
            logits ("no", "yes") at the last sequence position.
        """
        inputs = self._process_inputs(pairs)
        
        try:
            # Logits at the last position of every sequence.
            batch_scores = self._model(**inputs).logits[:, -1, :]
            true_vector = batch_scores[:, self._token_true_id]
            false_vector = batch_scores[:, self._token_false_id]
            # Normalise over just the two candidate tokens; column 1 is "yes".
            batch_scores = torch.stack([false_vector, true_vector], dim=1)
            batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
            scores = batch_scores[:, 1].exp().tolist()
            return scores
        finally:
            # Drop the input tensors right away to free GPU memory.
            del inputs
            if self._device.startswith("cuda"):
                torch.cuda.empty_cache()
    
    def _compute_scores(self, pairs: List[str]) -> List[float]:
        """
        Compute the relevance scores for all pairs, ``batch_size`` pairs at a time.
        
        Args:
            pairs: Formatted query-document strings.
            
        Returns:
            The scores of all pairs, in input order.
        """
        all_scores = []
        total_pairs = len(pairs)
        
        # Process the pairs batch by batch.
        for i in range(0, total_pairs, self.batch_size):
            batch_pairs = pairs[i:i + self.batch_size]
            batch_scores = self._compute_scores_batch(batch_pairs)
            all_scores.extend(batch_scores)
            
            # Progress is only reported when more than one batch is needed.
            if total_pairs > self.batch_size:
                processed = min(i + self.batch_size, total_pairs)
                log(f"Reranking progress: {processed}/{total_pairs} pairs processed")
        
        return all_scores
    
    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """
        Rerank the nodes against the query and keep the best ``top_n``.

        The passed nodes are modified in place: their ``score`` is replaced by
        the reranker score, and with ``keep_retrieval_score`` the previous score
        is stored in the node metadata under ``retrieval_score``.
        
        Args:
            nodes: Nodes to rerank.
            query_bundle: Query information; required.
            
        Returns:
            The reranked nodes, best first, at most ``top_n`` of them.

        Raises:
            ValueError: If ``query_bundle`` is None.
        """
        if query_bundle is None:
            raise ValueError("Missing query bundle in extra info.")
        
        if len(nodes) == 0:
            return []
        
        try:
            # Build the (query, document) pairs; the document text uses the EMBED metadata mode.
            query_str = query_bundle.query_str
            query_and_nodes = [
                (
                    query_str,
                    node.node.get_content(metadata_mode=MetadataMode.EMBED),
                )
                for node in nodes
            ]
            
            # Format every pair into the instruction text block.
            pairs = [
                self._format_instruction(query, doc) 
                for query, doc in query_and_nodes
            ]
            
            # Record the reranking step as a callback-manager event.
            with self.callback_manager.event(
                CBEventType.RERANKING,
                payload={
                    EventPayload.NODES: nodes,
                    EventPayload.MODEL_NAME: self.model,
                    EventPayload.QUERY_STR: query_str,
                    EventPayload.TOP_K: self.top_n,
                },
            ) as event:
                # Score all pairs in batches.
                scores = self._compute_scores(pairs)
                
                assert len(scores) == len(nodes), \
                    f"Score count mismatch: got {len(scores)} scores for {len(nodes)} nodes"
                
                # Overwrite each node's score with the reranker score.
                for node, score in zip(nodes, scores):
                    if self.keep_retrieval_score:
                        # Keep the original retrieval score in the metadata.
                        node.node.metadata["retrieval_score"] = node.score
                    node.score = float(score)
                
                # Sort descending via the negated score (a falsy score maps to key 0)
                # and keep the first top_n nodes.
                new_nodes = sorted(
                    nodes, 
                    key=lambda x: -x.score if x.score else 0
                )[: self.top_n]
                
                # Report the result to the event.
                event.on_end(payload={EventPayload.NODES: new_nodes})
            
            return new_nodes
        
        finally:
            # Clean up the GPU cache whether or not an error occurred.
            if self.clear_cache_after_rerank:
                self._clear_gpu_cache()


In [30]:
# Build the custom Qwen3 reranker from the local 4B checkpoint; it keeps the 5 best nodes per query.
reranker = Qwen3Reranker(
    model="/root/autodl-tmp/Qwen3-Reranker-4B",
    top_n=5,
    device="cuda",
    max_length=8192,
    instruction="根据用户的问题，判断文档是否包含相关答案或信息",
    keep_retrieval_score=True,
    clear_cache_after_rerank=True,
    batch_size=5  # default value: score 5 pairs per batch
)

`torch_dtype` is deprecated! Use `dtype` instead!


⚠ Flash attention not available, using default: FlashAttention2 has been toggled on, but it cannot be used due to the following error: the package flash_attn seems to be not installed. Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [25]:
# reranker_model_path = "autodl-tmp/Qwen3-Reranker-4B"

# reranker = SentenceTransformerRerank(
#     model=reranker_model_path,
#     top_n=5,
#     device="cuda",
#     trust_remote_code=True
# )

# cross_encoder = reranker._model
# reranker_tokenizer = cross_encoder.tokenizer
# reranker_model = cross_encoder.model

# special_tokens = {'pad_token': '[PAD]'}
# num_added_tokens = reranker_tokenizer.add_special_tokens(special_tokens)

# reranker_model.resize_token_embeddings(len(reranker_tokenizer))

# reranker_tokenizer.pad_token = '[PAD]'
# reranker_tokenizer.pad_token_id = reranker_tokenizer.convert_tokens_to_ids('[PAD]')
# reranker_model.config.pad_token_id = reranker_tokenizer.pad_token_id

# log(f"Pad token: {reranker_tokenizer.pad_token}")
# log(f"Pad token ID: {reranker_tokenizer.pad_token_id}")
# log(f"Model config pad_token_id: {reranker_model.config.pad_token_id}")

# custom_instruction = "根据用户的问题，判断文档是否包含相关答案或信息"  # 中文场景
# reranker = Qwen3Reranker(
#     model_path="autodl-tmp/Qwen3-Reranker-4B",
#     top_n=5,
#     device="cuda",
#     max_length=8192, #最长输入
#     instruction=None,
#     keep_retrieval_score=True  # 如果想保留原始检索分数
# )


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of Qwen3ForSequenceClassification were not initialized from the model checkpoint at autodl-tmp/Qwen3-Reranker-4B and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Pad token: [PAD]
Pad token ID: 151669
Model config pad_token_id: 151669


In [31]:
# Tokenizer instance shared by all reranker_tokenize() calls; loaded on first use.
_RERANK_TOKENIZER = None


def reranker_tokenize(text):
    """Tokenize ``text`` with the Qwen3 reranker tokenizer.

    The tokenizer is loaded from disk only once (on the first non-blank input)
    and then reused, instead of being re-created on every call.

    Args:
        text: String to tokenize.

    Returns:
        The list of tokens, or an empty list when ``text`` is empty or
        whitespace-only (in which case the tokenizer is not touched at all).
    """
    global _RERANK_TOKENIZER

    if not text.strip():
        return []

    if _RERANK_TOKENIZER is None:
        # NOTE(review): this is a relative path to the 8B checkpoint, while the
        # reranker itself is loaded from /root/autodl-tmp/Qwen3-Reranker-4B —
        # confirm that this is the intended tokenizer source.
        _RERANK_TOKENIZER = AutoTokenizer.from_pretrained("autodl-tmp/Qwen3-Reranker-8B", padding_side='left')

    return _RERANK_TOKENIZER.tokenize(text)

# BM25 retriever built from `nodes`, using reranker_tokenize as its tokenizer (similarity_top_k=25).
bm25_retriever = BM25Retriever.from_defaults(
    nodes=nodes, 
    similarity_top_k=25,
    tokenizer=reranker_tokenize)

2025-10-16 14:13:50,856 - DEBUG - Building index from IDs objects


In [32]:
# Vector retriever derived from `index` (similarity_top_k=25).
vector_retriever = index.as_retriever(similarity_top_k=25)

In [33]:
# Combine the vector and BM25 retrievers into one hybrid retriever using the "reciprocal_rerank" fusion mode.
hybrid_retriever = QueryFusionRetriever(
    [vector_retriever, bm25_retriever],
    similarity_top_k=50,
    num_queries=1,  # set this to 1 to disable query generation
    mode="reciprocal_rerank",
    use_async=False,
    verbose=True
)

In [34]:
# Set the shared reranker to keep 5 nodes. Other cells change this attribute,
# so the value in effect depends on cell execution order.
reranker.top_n=5

In [35]:
# QA prompt (Chinese): shows the context, then asks for a direct answer in Chinese
# drawn only from that context. Placeholders: {context_str}, {query_str}.
text_qa_template_str = (
    "上下文信息如下：\n"
    "{context_str}\n"
    "基于提供的上下文，用中文直接回答查询，答案只能从上下文知识中获取，不要自己发挥。\n"
    "查询：{query_str}\n"
    "回答："
)
text_qa_template = PromptTemplate(text_qa_template_str)

# Refine prompt (Chinese): asks to refine an existing answer with new context, core
# answer first and explanation after. Placeholders: {query_str}, {existing_answer}, {context_msg}.
refine_template_str = (
    "原始查询是：{query_str}\n"
    "我们已有回答：{existing_answer}\n"
    "基于以下新上下文，用中文精炼现有回答，问题的核心回答要放在最前边，然后是解释，确保完整性和准确性：\n"
    "{context_msg}\n"
    "精炼后的回答："
)
refine_template = PromptTemplate(refine_template_str)

In [36]:
# Response synthesizer wired to the custom QA and refine prompts, in "compact" response mode.
response_synthesizer = get_response_synthesizer(
    text_qa_template=text_qa_template,
    refine_template=refine_template,
    response_mode="compact"
)

In [37]:
from llama_index.core.postprocessor import LongContextReorder
# LongContextReorder instance, used as a node postprocessor by `query_engine`.
longcontextreorder=LongContextReorder()

In [38]:
# Query engine: hybrid retriever, then node postprocessors in the listed order
# (reranker, LongContextReorder), then the response synthesizer.
query_engine = RetrieverQueryEngine(
    retriever=hybrid_retriever,
    response_synthesizer=response_synthesizer,
    node_postprocessors=[reranker,longcontextreorder]
)

In [29]:
# Test question (Chinese): a compliance scenario with three transactions; asks which
# provisions of the payment-settlement rules the first transaction may violate.
query = """
您是一家大型商业银行的首席合规官。您的一位客户是腾讯的一位高管，由于其家庭关系，他也被列为“外国政要”。在腾讯2025年第二季度财报发布后的一周内，他通过贵行进行了以下交易：

他使用腾讯发行的单位卡购买了一件价值11万元人民币的艺术品，摆放在办公室。

他将8万元人民币现金存入个人账户，并注明这笔资金来自个人股息。

他作为付款人签署了一张金额为600万元人民币的商业承兑汇票，付款期限为90天。该草案旨在为一家3D打印公司提供新的融资，该公司将使用腾讯的“混元3D模型”人工智能服务，该服务在最近的财报中被重点提及。

您的任务：

对于这第一笔交易，请分别找出任何可能违反“支付结算办法”的情况"""
# Run the full pipeline (retrieve -> postprocess -> synthesize) and log the answer.
response = query_engine.query(query)
log(response)

Reranking progress: 5/38 pairs processed
Reranking progress: 10/38 pairs processed
Reranking progress: 15/38 pairs processed
Reranking progress: 20/38 pairs processed
Reranking progress: 25/38 pairs processed
Reranking progress: 30/38 pairs processed
Reranking progress: 35/38 pairs processed
Reranking progress: 38/38 pairs processed
✓ 请求已保存到: llm_requests/requests_log.txt
根据《支付结算办法》的相关规定，针对该客户使用腾讯发行的单位卡购买一件价值11万元人民币的艺术品的交易，可能存在以下违反“支付结算办法”的情况：

1. **单位卡用于大额商品交易**：  
   根据《支付结算办法》第142条：“单位卡不得用于10万元以上的商品交易、劳务供应款项的结算。”  
   该交易金额为11万元人民币，超过了10万元的限额，因此**违反了该条规定**。

2. **单位卡资金来源与用途不符**：  
   根据《支付结算办法》第137条：“单位卡帐户的资金一律从其基本存款帐户转帐存入，不得交存现金，不得将销货收入的款项存入其帐户。”  
   该交易是通过单位卡支付，但购买艺术品属于非经营性支出，且未说明该资金是否从基本存款账户转帐而来。若该资金来源于非基本账户或为现金存入，则违反了单位卡资金来源的管理规定。

3. **单位卡用于非经营性用途**：  
   单位卡的使用应限于与单位经营相关的业务活动。购买艺术品用于办公室装饰，属于非经营性支出，不符合单位卡的合规使用范围，可能构成**滥用单位卡**，违反《支付结算办法》关于单位卡使用范围的限制。

综上，该笔交易**违反了《支付结算办法》第142条关于单位卡不得用于10万元以上商品交易的规定**，并可能涉及单位卡资金来源和用途的违规问题。


In [20]:
# Same compliance scenario as the previous query cell, with the added instructions to
# answer only from the provided documents and to cite specific article numbers.
query = """
您是一家大型商业银行的首席合规官。您的一位客户是腾讯的一位高管，由于其家庭关系，他也被列为“外国政要”。在腾讯2025年第二季度财报发布后的一周内，他通过贵行进行了以下交易：

他使用腾讯发行的单位卡购买了一件价值11万元人民币的艺术品，摆放在办公室。

他将8万元人民币现金存入个人账户，并注明这笔资金来自个人股息。

他作为付款人签署了一张金额为600万元人民币的商业承兑汇票，付款期限为90天。该草案旨在为一家3D打印公司提供新的融资，该公司将使用腾讯的“混元3D模型”人工智能服务，该服务在最近的财报中被重点提及。

您的任务：

仅根据提供的文件，回答以下问题。

对于这第一笔交易，请分别找出任何可能违反“支付结算办法”的情况。请引用文件中的具体条款编号来支持您的发现。"""

# Method 1: call the retriever directly to inspect the retrieval results
log("\n🔍 【步骤1：混合检索】")
retrieved_nodes = hybrid_retriever.retrieve(query)
# print_retrieved_nodes(retrieved_nodes, "混合检索结果")

# Method 2: inspect the results after reranking
log("\n🎯 【步骤2：Rerank重排序】")
from llama_index.core.schema import QueryBundle
query_bundle = QueryBundle(query_str=query)
reranked_nodes = reranker.postprocess_nodes(retrieved_nodes, query_bundle)
# print_retrieved_nodes(reranked_nodes, "Rerank后的节点")

# Method 3 (disabled): inspect the final content sent to the LLM
# log("\n📝 【步骤3：生成回答】")
# # Build the context manually in order to inspect it
# context_str = "\n\n".join([node.node.get_content() for node in reranked_nodes])
# print_prompt_to_llm(query, context_str, "(text_qa_template)")

# Run the query (disabled)
# response = query_engine.query(query)
# log(f"\n✅ 【最终回答】\n{response}")


🔍 【步骤1：混合检索】


You're using a Qwen2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.



🎯 【步骤2：Rerank重排序】




Reranking progress: 5/38 pairs processed
Reranking progress: 10/38 pairs processed
Reranking progress: 15/38 pairs processed
Reranking progress: 20/38 pairs processed
Reranking progress: 25/38 pairs processed
Reranking progress: 30/38 pairs processed
Reranking progress: 35/38 pairs processed
Reranking progress: 38/38 pairs processed


In [None]:
import copy
import time
from typing import Any, Dict, List, Optional

from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode

# ==================== 1. 节点分割器 ====================
class NodeSplitter:
    """Split long nodes into fixed-size character chunks that keep a link to their parent.

    Chunking is done on the characters of ``node.node.text``. Every child chunk
    inherits the parent's metadata and additionally records ``parent_node_id``,
    ``chunk_index``, ``is_child_node``, ``parent_text_length``, ``chunk_start``
    and ``chunk_end``.
    """

    def __init__(self, chunk_size: int = 512, overlap_ratio: float = 0.1):
        """
        Args:
            chunk_size: Target length of a child chunk, in characters. Must be positive.
            overlap_ratio: Fraction of ``chunk_size`` shared by consecutive chunks
                (0.1 means 10%). Must satisfy ``0 <= overlap_ratio < 1``.

        Raises:
            ValueError: If ``chunk_size`` or ``overlap_ratio`` is out of range.
                Such values would keep the splitting window from advancing
                (endless loop) or leave gaps between chunks.
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        if not 0 <= overlap_ratio < 1:
            raise ValueError(f"overlap_ratio must be in [0, 1), got {overlap_ratio}")

        self.chunk_size = chunk_size
        self.overlap_size = int(chunk_size * overlap_ratio)

    def split_node(self, node: NodeWithScore, parent_id: Optional[str] = None) -> List[NodeWithScore]:
        """
        Split one node into child chunks.

        Args:
            node: The node to split.
            parent_id: Parent node id to record; falls back to ``node.node.node_id``
                when not given.

        Returns:
            A list of child nodes carrying the parent's score. When the text fits
            into a single chunk, the original node object is returned instead,
            with ``parent_node_id`` and ``is_child_node=False`` written into its
            metadata in place.

        Raises:
            ValueError: If ``chunk_size - overlap_size`` is not positive.
        """
        text = node.node.text
        text_length = len(text)
        parent_node_id = parent_id or node.node.node_id

        # Short text: nothing to split, just tag the original node.
        if text_length <= self.chunk_size:
            node.node.metadata['parent_node_id'] = parent_node_id
            node.node.metadata['is_child_node'] = False
            return [node]

        # Distance between consecutive chunk starts. Re-checked here because the
        # attributes can be changed after construction.
        step = self.chunk_size - self.overlap_size
        if step <= 0:
            raise ValueError(
                f"overlap_size ({self.overlap_size}) must be smaller than "
                f"chunk_size ({self.chunk_size})"
            )

        child_nodes = []
        for chunk_index, start in enumerate(range(0, text_length, step)):
            end = min(start + self.chunk_size, text_length)

            child_node = TextNode(
                text=text[start:end],
                metadata={
                    **node.node.metadata,  # inherit the parent's metadata
                    'parent_node_id': parent_node_id,
                    'chunk_index': chunk_index,
                    'is_child_node': True,
                    'parent_text_length': text_length,
                    'chunk_start': start,
                    'chunk_end': end
                },
                excluded_embed_metadata_keys=node.node.excluded_embed_metadata_keys,
                excluded_llm_metadata_keys=node.node.excluded_llm_metadata_keys,
            )

            # Children start out with the parent's score.
            child_nodes.append(NodeWithScore(node=child_node, score=node.score))

            # Once a chunk reaches the end of the text, any further chunk would
            # lie entirely inside the overlap already covered by this one.
            if end >= text_length:
                break

        return child_nodes

    def split_nodes(self, nodes: List[NodeWithScore]) -> tuple[List[NodeWithScore], Dict[str, NodeWithScore]]:
        """
        Split a batch of nodes.

        Args:
            nodes: Nodes to split.

        Returns:
            ``(child_nodes, parent_node_map)`` where ``parent_node_map`` maps each
            input node's id to that original node.
        """
        all_child_nodes = []
        parent_node_map = {}  # parent_node_id -> original parent node

        for node in nodes:
            parent_id = node.node.node_id
            parent_node_map[parent_id] = node

            all_child_nodes.extend(self.split_node(node, parent_id))

        return all_child_nodes, parent_node_map


# ==================== 2. 子节点到父节点的后处理器 ====================
class ChildToParentPostprocessor(BaseNodePostprocessor):
    """
    Collapse reranked child chunks back into their parent nodes.

    Children that share a ``parent_node_id`` are merged into a single result whose
    score is the best score among those children. Nodes without a score are
    tolerated and rank below every scored node.
    """

    # Pydantic field declarations.
    parent_node_map: Dict[str, Any] = {}  # parent node id -> original parent node
    keep_top_k: int = 5  # number of parent nodes returned

    def __init__(self, parent_node_map: Dict[str, NodeWithScore], keep_top_k: int = 5, **kwargs):
        """
        Args:
            parent_node_map: Mapping from parent node id to the parent node.
            keep_top_k: Number of parent nodes to keep in the result.
        """
        # Fields are set through the Pydantic base-class initializer.
        super().__init__(
            parent_node_map=parent_node_map,
            keep_top_k=keep_top_k,
            **kwargs
        )

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None
    ) -> List[NodeWithScore]:
        """
        Replace child nodes by their parents.

        Args:
            nodes: Child (or unsplit) nodes, typically already reranked.
            query_bundle: Unused.

        Returns:
            Up to ``keep_top_k`` nodes, best score first. Each restored parent is a
            deep copy of the entry in ``parent_node_map`` with its score set to the
            best child score and a ``matched_children`` metadata entry describing
            the children that matched.
        """
        def rank(score: Optional[float]) -> float:
            # A missing score compares lower than any real score instead of
            # raising TypeError inside max()/sort().
            return float("-inf") if score is None else score

        best_scores: Dict[str, Optional[float]] = {}
        children_by_parent: Dict[str, List[NodeWithScore]] = {}
        standalone_ids = set()  # ids of nodes that carry no parent reference

        for node in nodes:
            parent_id = node.node.metadata.get('parent_node_id')

            if not parent_id:
                # No parent id: treat the node as its own group.
                own_id = node.node.node_id
                standalone_ids.add(own_id)
                best_scores[own_id] = node.score
                children_by_parent[own_id] = [node]
                continue

            if parent_id not in best_scores:
                best_scores[parent_id] = node.score
                children_by_parent[parent_id] = [node]
            else:
                # Keep the highest child score for this parent.
                if rank(node.score) > rank(best_scores[parent_id]):
                    best_scores[parent_id] = node.score
                children_by_parent[parent_id].append(node)

        parent_nodes = []
        for parent_id, score in best_scores.items():
            if parent_id in self.parent_node_map:
                # Copy so the stored parent is not modified.
                parent_node = copy.deepcopy(self.parent_node_map[parent_id])
                parent_node.score = score

                # Record which children matched.
                parent_node.node.metadata['matched_children'] = [
                    {
                        'chunk_index': n.node.metadata.get('chunk_index'),
                        'score': n.score,
                        'text_preview': n.node.text[:100]
                    }
                    for n in children_by_parent[parent_id]
                ]

                parent_nodes.append(parent_node)
            else:
                if parent_id not in standalone_ids:
                    # A child referenced a parent that is missing from the map.
                    log(f"警告: 找不到父节点 {parent_id}, 使用子节点代替")
                parent_nodes.append(children_by_parent[parent_id][0])

        # Best score first, then cut to keep_top_k.
        parent_nodes.sort(key=lambda x: rank(x.score), reverse=True)
        return parent_nodes[:self.keep_top_k]

    class Config:
        arbitrary_types_allowed = True  # allow arbitrary (non-Pydantic) field types

# ==================== 3. 自定义检索器包装器 ====================
class SplitNodeRetriever(BaseRetriever):
    """
    Retriever wrapper: delegates to another retriever, then splits every returned
    node with a ``NodeSplitter``.
    """

    def __init__(
        self,
        base_retriever: BaseRetriever,
        chunk_size: int = 512,
        overlap_ratio: float = 0.1
    ):
        """
        Args:
            base_retriever: Retriever whose results are split.
            chunk_size: Child chunk size handed to ``NodeSplitter``.
            overlap_ratio: Overlap ratio handed to ``NodeSplitter``.
        """
        super().__init__()
        self.base_retriever = base_retriever
        self.node_splitter = NodeSplitter(chunk_size, overlap_ratio)
        # Filled by the most recent _retrieve() call.
        self.parent_node_map = {}

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """
        Retrieve with the wrapped retriever and return the split child nodes.

        The parent map of this call replaces ``self.parent_node_map``.
        """
        retrieved = self.base_retriever.retrieve(query_bundle)

        split_result = self.node_splitter.split_nodes(retrieved)
        child_nodes, self.parent_node_map = split_result

        log(f"原始节点数: {len(retrieved)}, 分割后子节点数: {len(child_nodes)}")

        return child_nodes

    def get_parent_node_map(self) -> Dict[str, NodeWithScore]:
        """Return the parent map stored by the latest retrieval, for use by a postprocessor."""
        return self.parent_node_map


def create_parent_postprocessor(retriever: SplitNodeRetriever, keep_top_k: int = 5):
    """Build a ChildToParentPostprocessor from the retriever's current parent map.

    Args:
        retriever: Retriever providing ``get_parent_node_map()``.
        keep_top_k: Number of parent nodes the postprocessor keeps.

    Returns:
        A new ``ChildToParentPostprocessor``.
    """
    current_parent_map = retriever.get_parent_node_map()
    return ChildToParentPostprocessor(
        parent_node_map=current_parent_map,
        keep_top_k=keep_top_k,
    )


class DynamicQueryEngine:
    """Query pipeline whose parent-restoring postprocessor is rebuilt on every query.

    Steps: retrieve -> rerank -> restore parents or truncate -> optional reorder ->
    synthesize. Per-step timings are written through ``log``.
    """

    def __init__(
        self,
        retriever,
        response_synthesizer,
        reranker,
        keep_top_k=5,
        use_parent_nodes=True,
        reorder=None
    ):
        """
        Args:
            retriever: Object with ``retrieve(query_bundle)``. When
                ``use_parent_nodes`` is true it must also provide
                ``get_parent_node_map()``.
            response_synthesizer: Object with ``synthesize(query=..., nodes=...)``.
            reranker: Object with ``postprocess_nodes(nodes, query_bundle)``.
            keep_top_k: Number of nodes passed on to synthesis.
            use_parent_nodes: If true, map reranked children back to their parents;
                otherwise keep the first ``keep_top_k`` reranked nodes.
            reorder: If truthy, apply ``longcontext_postprocess_nodes`` before synthesis.
        """
        self.retriever = retriever
        self.response_synthesizer = response_synthesizer
        self.reranker = reranker
        self.keep_top_k = keep_top_k
        self.use_parent_nodes = use_parent_nodes
        self.reorder = reorder

    def longcontext_postprocess_nodes(
        self,
        nodes: List[NodeWithScore]
    ) -> List[NodeWithScore]:
        """Reorder nodes so that the lowest scores end up in the middle of the list.

        Nodes are sorted by ascending score (``None`` counts as 0) and then placed
        alternately at the front and at the back of the result.
        """
        reordered_nodes: List[NodeWithScore] = []
        ordered_nodes: List[NodeWithScore] = sorted(
            nodes, key=lambda x: x.score if x.score is not None else 0
        )
        for i, node in enumerate(ordered_nodes):
            if i % 2 == 0:
                reordered_nodes.insert(0, node)
            else:
                reordered_nodes.append(node)
        return reordered_nodes

    def query(self, query_str: str):
        """Run the pipeline for ``query_str`` and return the synthesizer's response.

        Args:
            query_str: The user query.

        Returns:
            Whatever ``response_synthesizer.synthesize`` returns.
        """
        # perf_counter() is monotonic, so durations cannot be distorted by
        # system clock adjustments.
        total_start = time.perf_counter()
        timing_stats: Dict[str, float] = {}

        # 1. Retrieve
        step_start = time.perf_counter()
        query_bundle = QueryBundle(query_str=query_str)
        nodes = self.retriever.retrieve(query_bundle)
        timing_stats['检索'] = time.perf_counter() - step_start

        # 2. Rerank
        step_start = time.perf_counter()
        reranked_nodes = self.reranker.postprocess_nodes(nodes, query_bundle)
        timing_stats['Rerank'] = time.perf_counter() - step_start

        # 3. Restore parent nodes, or simply truncate
        step_start = time.perf_counter()
        if self.use_parent_nodes:
            # Built per query because the retriever's parent map changes with
            # every retrieval.
            parent_postprocessor = create_parent_postprocessor(
                self.retriever,
                keep_top_k=self.keep_top_k
            )
            final_nodes = parent_postprocessor.postprocess_nodes(reranked_nodes, query_bundle)
            timing_stats['还原父节点'] = time.perf_counter() - step_start
        else:
            final_nodes = reranked_nodes[:self.keep_top_k]
            timing_stats['截取节点'] = time.perf_counter() - step_start

        # 4. Reorder (if enabled)
        if self.reorder:
            step_start = time.perf_counter()
            final_nodes = self.longcontext_postprocess_nodes(final_nodes)
            timing_stats['Reorder'] = time.perf_counter() - step_start

        # 5. Synthesize the answer
        step_start = time.perf_counter()
        response = self.response_synthesizer.synthesize(
            query=query_str,
            nodes=final_nodes
        )
        timing_stats['生成回答'] = time.perf_counter() - step_start

        total = time.perf_counter() - total_start
        timing_stats['总耗时'] = total

        # Timing report
        log("\n" + "="*50)
        log("⏱️  耗时统计:")
        log("="*50)
        for step, duration in timing_stats.items():
            if step != '总耗时':
                # Guard against a zero total on very coarse timers.
                percentage = (duration / total) * 100 if total > 0 else 0.0
                log(f"{step:12s}: {duration:6.3f}秒 ({percentage:5.1f}%)")
        log("-"*50)
        log(f"{'总耗时':12s}: {timing_stats['总耗时']:6.3f}秒 (100.0%)")
        log("="*50 + "\n")

        return response

In [None]:
# Wrap the hybrid retriever so that every retrieved node is split into chunks.
split_retriever = SplitNodeRetriever(
    base_retriever=hybrid_retriever,
    chunk_size=256,      # split into chunks of 256 characters
    overlap_ratio=0    # no overlap between chunks
)


# NOTE: these assignments mutate the shared `reranker` instance, so `query_engine`
# (which holds the same object) also sees top_n=10 / batch_size=10 from here on.
reranker.top_n=10
reranker.batch_size=10

dynamic_query_engine = DynamicQueryEngine(
    retriever=split_retriever,
    response_synthesizer=response_synthesizer,
    reranker=reranker,
    keep_top_k=10,
    use_parent_nodes=False,  # skip parent restoration: use the top keep_top_k (10) reranked child chunks directly
    reorder=True
)


In [None]:
# Test question (Chinese) about the responsibilities of a head-office department in the
# UnionPay QuickPass business; answered with the split-node pipeline.
query = """
在云闪付业务中，总行网络金融事业部的主要职责是什么"""
response = dynamic_query_engine.query(query)
log(response)