In [16]:
import concurrent.futures

from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.embeddings.dashscope import DashScopeEmbedding
import json
import os
from pathlib import Path
from typing import List, Dict, Union
from tqdm import tqdm
from dashscope import TextReRank
from llama_index.core import load_index_from_storage, StorageContext
from llama_index.core.schema import NodeWithScore, MetadataMode
from llama_index.core import Settings
from llama_index.retrievers.bm25 import BM25Retriever

class UnifiedEmbedding:
    """Thin factory that hides which embedding backend is in use.

    Depending on ``model_type`` the instance wraps either a
    ``DashScopeEmbedding`` or an ``OpenAIEmbedding``; the wrapped object is
    stored on ``self.embedder``.
    """

    def __init__(
        self,
        model_type: str = "dashscope",
        model_name: str = "text-embedding-v2",
        dashscope_text_type: str = "document",
        api_key: str = None,
        base_url: str = None
    ):
        """Build the backend embedding object.

        Args:
            model_type: Backend selector. ``"dashscope"`` picks DashScope;
                any other value falls through to OpenAI.
            model_name: Embedding model name. Only forwarded to DashScope.
            dashscope_text_type: DashScope ``text_type`` value; unused for
                OpenAI.
            api_key: API key handed to the selected backend.
            base_url: Service address, forwarded as ``api_base`` to OpenAI
                only; the DashScope branch does not use it.
        """
        self.model_type = model_type

        if model_type != "dashscope":
            # NOTE(review): ``model_name`` is intentionally NOT forwarded
            # here, so OpenAIEmbedding runs with its library default model
            # (the original author's note says text-embedding-ada-002 —
            # confirm against the installed llama-index version).
            self.embedder = OpenAIEmbedding(
                api_key=api_key,
                api_base=base_url
            )
            return

        self.embedder = DashScopeEmbedding(
            model_name=model_name,
            text_type=dashscope_text_type,
            api_key=api_key
        )

    def get_embed_model(self):
        """Return the wrapped backend embedding object."""
        return self.embedder

def load_indexes_from_folder(storage_dir: str) -> List:
    """Load every persisted index found under ``storage_dir``.

    If ``storage_dir`` itself contains at least one ``*.json`` file it is
    treated as a single persisted index. Otherwise each immediate
    subdirectory is treated as one persisted index, visited in sorted path
    order.

    Args:
        storage_dir: Directory holding one index, or a parent of several.

    Returns:
        A list with one loaded index per discovered directory (empty when
        nothing is found).
    """
    root = Path(storage_dir)

    holds_single_index = any(root.glob("*.json"))
    if holds_single_index:
        candidates = [storage_dir]
    else:
        candidates = sorted(child for child in root.iterdir() if child.is_dir())

    loaded = []
    for candidate in candidates:
        context = StorageContext.from_defaults(persist_dir=str(candidate))
        loaded.append(load_index_from_storage(context))
    return loaded

def process_retrieval(
    data: Dict,
    retrievers: List,
    use_rerank: bool = False,
    rerank_model: str = "gte-rerank-v2",
    rerank_top_n: int = 10,
) -> Dict:
    """Run one query through every retriever and optionally rerank the hits.

    Args:
        data: One query record. The keys ``"query"``, ``"answer"``,
            ``"question_type"`` and ``"evidence_list"`` are read.
        retrievers: Objects exposing ``retrieve(query)``; their results are
            concatenated.
        use_rerank: When true (and at least one node was retrieved), call
            ``TextReRank`` to reorder the nodes.
        rerank_model: Model name passed to ``TextReRank.call``.
        rerank_top_n: ``top_n`` passed to the reranker, i.e. how many
            reranked nodes are kept.

    Returns:
        A dict with the query fields plus ``"retrieval_original"`` (all nodes
        sorted by descending retrieval score), ``"retrieval_reranked"`` (the
        reranker's order, or a copy of the original order when reranking is
        off or failed) and ``"gold_list"``. Reranked entries still carry each
        node's retrieval score, not the reranker's score.
    """
    query = data["query"]

    # Query every index and pool the hits.
    nodes = []
    for retriever in retrievers:
        nodes.extend(retriever.retrieve(query))

    # Score-descending baseline. A node whose score is None sorts last
    # instead of raising a TypeError during comparison.
    original_sorted = sorted(
        nodes,
        key=lambda n: float("-inf") if n.score is None else n.score,
        reverse=True,
    )

    # Default: reranked output mirrors the baseline order.
    reranked_sorted = original_sorted.copy()

    if use_rerank and nodes:
        # The rerank call is a remote service; a failed request (for example
        # a response whose ``output`` is None) must not abort the whole
        # batch, so fall back to the baseline order and report it.
        try:
            response = TextReRank.call(
                model=rerank_model,
                query=query,
                documents=[n.get_content() for n in nodes],
                top_n=rerank_top_n,
                return_documents=False
            )
            # Each result's 'index' points back into ``nodes`` (the document
            # list sent above), not into ``original_sorted``.
            reranked_sorted = [nodes[r['index']] for r in response.output['results']]
        except Exception as exc:
            print(f"重排序失败，回退到原始排序: {query} - {exc}")

    return {
        "query": query,
        "answer": data["answer"],
        "question_type": data["question_type"],
        "retrieval_original": format_retrieval_results(original_sorted),
        "retrieval_reranked": format_retrieval_results(reranked_sorted),
        "gold_list": data["evidence_list"]
    }

def format_retrieval_results(nodes: List[NodeWithScore]) -> List[Dict]:
    """Convert scored nodes into plain ``{"text", "score"}`` dicts.

    The text is the node content rendered with ``MetadataMode.LLM``; the
    score is copied from the node unchanged. Input order is preserved.
    """
    formatted = []
    for scored in nodes:
        text = scored.get_content(metadata_mode=MetadataMode.LLM)
        formatted.append({"text": text, "score": scored.score})
    return formatted

def split_and_save_results(results: List[Dict], output_dir: Union[str, Path]):
    """Split combined results into two JSON files (original vs. reranked).

    Writes ``retrieval_results_original.json`` and
    ``retrieval_results_reranked.json`` into ``output_dir``. Every entry has
    the keys ``query``, ``answer``, ``question_type``, ``gold_list`` and
    ``retrieval_list``; the two files differ only in which ranking is stored
    under ``retrieval_list``.

    Args:
        results: Per-query dicts containing the shared keys plus
            ``"retrieval_original"`` and ``"retrieval_reranked"``. Falsy
            entries (``None``, ``{}``) are skipped.
        output_dir: Destination directory; created (with parents) if it does
            not exist yet.
    """
    output_dir = Path(output_dir)
    # Create the destination so the function also works when called on its own.
    output_dir.mkdir(parents=True, exist_ok=True)

    shared_keys = ("query", "answer", "question_type", "gold_list")
    original_results = []
    reranked_results = []

    for res in results:
        if not res:
            continue

        base_entry = {key: res[key] for key in shared_keys}
        original_results.append(
            {**base_entry, "retrieval_list": res["retrieval_original"]}
        )
        reranked_results.append(
            {**base_entry, "retrieval_list": res["retrieval_reranked"]}
        )

    def save_data(data, filename):
        # Explicit encoding keeps the output independent of the platform
        # default locale.
        with open(output_dir / filename, 'w', encoding="utf-8") as f:
            json.dump(data, f, indent=2)

    save_data(original_results, "retrieval_results_original.json")
    save_data(reranked_results, "retrieval_results_reranked.json")

def run_bm25retriever_pipeline(
    index_dir: str,
    query_path: str,
    output_dir: str,
    top_k: int = 20,
):
    """Run every query through persisted BM25 retrievers and save the results.

    If ``index_dir`` directly contains a ``*.json`` file it is loaded as a
    single persisted retriever; otherwise each immediate subdirectory is
    loaded as one retriever, in sorted path order. Queries are processed
    sequentially and without reranking, so the two files written by
    ``split_and_save_results`` contain the same ranking.

    Args:
        index_dir: Persisted BM25 directory, or a parent of several.
        query_path: JSON file holding the list of query records.
        output_dir: Directory for the result files; created if missing.
        top_k: Value assigned to each retriever's ``similarity_top_k``.
    """
    os.makedirs(output_dir, exist_ok=True)

    storage_path = Path(index_dir)
    if any(storage_path.glob("*.json")):
        persist_dirs = [index_dir]
    else:
        persist_dirs = sorted(d for d in storage_path.iterdir() if d.is_dir())

    # Pass plain strings (as the vector-index loader does) and avoid reusing
    # the ``index_dir`` parameter name as the loop variable.
    retrievers = [
        BM25Retriever.from_persist_dir(path=str(persist_dir))
        for persist_dir in persist_dirs
    ]
    for retriever in retrievers:
        retriever.similarity_top_k = top_k
    print("加载索引完成...")

    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(query_path, 'r', encoding="utf-8") as f:
        query_data = json.load(f)

    results = [process_retrieval(data, retrievers) for data in tqdm(query_data)]

    split_and_save_results(results, output_dir)
        
def run_thread_retrieval_pipeline(
    index_dir: str,
    query_path: str,
    output_dir: str,
    model_type: str = "dashscope",
    model_name: str = "text-embedding-v2",
    api_key: str = None,
    base_url: str = None,
    use_rerank: bool = False,
    rerank_model: str = "gte-rerank-v2",
    max_workers: int = 16,
    top_k: int = 20,
):
    """Run the vector-retrieval pipeline over all queries with a thread pool.

    Side effect: assigns the constructed embedding model to the global
    ``Settings.embed_model``.

    Args:
        index_dir: Directory passed to ``load_indexes_from_folder``.
        query_path: JSON file holding the list of query records.
        output_dir: Directory for the result files; created if missing.
        model_type: Embedding backend selector for ``UnifiedEmbedding``.
        model_name: Embedding model name for ``UnifiedEmbedding``.
        api_key: API key for the embedding backend.
        base_url: Service address for the embedding backend.
        use_rerank: Forwarded to ``process_retrieval``.
        rerank_model: Forwarded to ``process_retrieval``.
        max_workers: Thread-pool size.
        top_k: ``similarity_top_k`` used for every retriever.

    Raises:
        Any exception raised while processing a query propagates out of
        ``future.result()`` and aborts the run before anything is saved.
    """
    Settings.embed_model = UnifiedEmbedding(
        model_type=model_type,
        model_name=model_name,
        api_key=api_key,
        base_url=base_url
    ).get_embed_model()

    # Set up output location and retrievers.
    os.makedirs(output_dir, exist_ok=True)
    indexes = load_indexes_from_folder(index_dir)
    retrievers = [index.as_retriever(similarity_top_k=top_k) for index in indexes]
    print("加载索引完成...")

    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(query_path, 'r', encoding="utf-8") as f:
        query_data = json.load(f)

    # Fan the queries out to worker threads. Each future remembers its input
    # position so the saved output follows the query file's order rather than
    # the (nondeterministic) completion order.
    ordered = [None] * len(query_data)
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        positions = {
            executor.submit(process_retrieval, data, retrievers, use_rerank, rerank_model): pos
            for pos, data in enumerate(query_data)
        }
        for future in tqdm(concurrent.futures.as_completed(positions), total=len(query_data)):
            ordered[positions[future]] = future.result()

    # Drop empty results, as before.
    results = [result for result in ordered if result]

    split_and_save_results(results, output_dir)

In [18]:
# DashScope run: vector retrieval (top_k=30 per index) with reranking enabled,
# using the key from the DASHSCOPE_API_KEY environment variable.
# NOTE(review): base_url is accepted here, but UnifiedEmbedding only forwards
# it on the OpenAI branch, so it has no effect for model_type="dashscope".
run_thread_retrieval_pipeline(
    index_dir="./embeddings/dashscope/balance_3/index_group_0",
    query_path="dataset/MultiHopRAG.json",
    output_dir="./rerank/dashscope/with_rerank_balance_3_0_30",
    use_rerank=True,
    rerank_model="gte-rerank-v2",
    model_type="dashscope",
    model_name="text-embedding-v2",
    api_key=os.getenv("DASHSCOPE_API_KEY"),
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
    top_k=30,
    max_workers=16,  # tune to the machine's capacity
)

加载索引完成...


100%|██████████| 2556/2556 [16:36<00:00,  2.57it/s]


In [3]:
# Sync (original author's note; the call below still goes through the
# thread-pool pipeline, with 8 workers).
# OpenAI run: vector retrieval (top_k=20 per index) with reranking disabled.
# NOTE(review): UnifiedEmbedding does not forward model_name on the OpenAI
# branch, so "text-embedding-3-large" is not actually applied — confirm which
# model the index was built with.
# NOTE(review): output_dir is named "with_rerank_..." although use_rerank is
# False, so both saved files will hold the un-reranked order.
run_thread_retrieval_pipeline(
    index_dir="./embeddings/openai/balance_1",
    query_path="dataset/MultiHopRAG.json",
    output_dir="./rerank/openai/with_rerank_balance_1_20",
    use_rerank=False,
    rerank_model="gte-rerank-v2",
    model_type="openai",
    model_name="text-embedding-3-large",
    api_key=os.getenv("OPENAI_API_KEY_COST"),
    base_url=os.getenv("OPENAI_API_BASE"),
    max_workers=8,
    top_k=20,
)

加载索引完成...


100%|██████████| 2556/2556 [44:13<00:00,  1.04s/it] 


In [17]:
# BM25 run: keyword retrieval with top_k=30 per retriever.
# NOTE(review): the BM25 pipeline never enables reranking, so despite the
# "with_rerank" directory name both output files contain the same ranking.
run_bm25retriever_pipeline(
    index_dir="embeddings/bm25/balance_3/index_group_0",
    query_path="dataset/MultiHopRAG.json",
    output_dir="rerank/bm25/with_rerank_balance_3_0_30",
    top_k=30,
)

加载索引完成...


100%|██████████| 2556/2556 [00:10<00:00, 249.27it/s]


In [4]:
from concurrent.futures import ThreadPoolExecutor

def rerank_existing_file(
    input_path: str,
    output_path: str,
    rerank_model: str = "gte-rerank-v2",
    max_workers: int = 64,
    top_n: int = 10,
):
    """Rerank an already-saved retrieval result file.

    Each entry's ``retrieval_list`` is sent to ``TextReRank`` and replaced by
    the reranker's top ``top_n`` items. Entries are written in the same order
    as they appear in the input file.

    Args:
        input_path: Input JSON file; every entry must have ``query`` and
            ``retrieval_list`` (items with ``text`` and ``score``).
        output_path: Output JSON file; its parent directory is created.
        rerank_model: Model name passed to ``TextReRank.call``.
        max_workers: Thread-pool size.
        top_n: ``top_n`` passed to the reranker.
    """
    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(input_path, 'r', encoding="utf-8") as f:
        original_data = json.load(f)

    def process_rerank(item: Dict) -> Dict:
        """Rerank one entry; on any failure return the entry unchanged."""
        try:
            nodes = item['retrieval_list']
            if not nodes:
                return item

            documents = [n['text'] for n in nodes]

            rerank_result = TextReRank.call(
                model=rerank_model,
                query=item['query'],
                documents=documents,
                top_n=top_n,
                return_documents=False
            )

            # Each result's 'index' refers to a position in ``documents``.
            sorted_nodes = [nodes[r['index']] for r in rerank_result.output['results']]

            # Keep the entry's other fields; only the list is replaced.
            return {
                **item,
                "retrieval_list": [
                    {
                        "text": node['text'],
                        # The original retrieval score is kept; the rerank
                        # score could be stored here instead.
                        "score": node['score']
                    } for node in sorted_nodes
                ]
            }
        except Exception as e:
            # Best effort: report the failure and keep the un-reranked entry.
            print(f"处理失败: {item['query']} - {str(e)}")
            return item

    # Fan out to worker threads. Each future remembers its input position so
    # the output keeps the input file's order instead of completion order.
    processed_data = [None] * len(original_data)
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        positions = {
            executor.submit(process_rerank, item): pos
            for pos, item in enumerate(original_data)
        }

        # Progress bar advances as individual requests finish.
        for future in tqdm(concurrent.futures.as_completed(positions),
                          total=len(original_data),
                          desc="Reranking"):
            processed_data[positions[future]] = future.result()

    os.makedirs(Path(output_path).parent, exist_ok=True)
    with open(output_path, 'w', encoding="utf-8") as f:
        json.dump(processed_data, f, indent=2)

In [None]:
# Rerank a previously saved, un-reranked DashScope result file.
# NOTE(review): the input is named "retrieval_results.json", which is not one
# of the filenames split_and_save_results writes in this notebook — presumably
# produced by an earlier version; verify the file exists.
rerank_existing_file(
    input_path="rerank/dashscope/without_rerank_balance_2/retrieval_results.json",
    output_path="rerank/dashscope/with_rerank_balance_2/retrieval_results.json",
    rerank_model="gte-rerank-v2",
    max_workers=64,  # recommended: set according to the API quota
)