In [1]:
# rag_eval_ms_marco_v1_1.py
# 基于 MS MARCO Passage Ranking v1.1 数据集 测试检索 Hit@k & MRR@k
# 数据量约 168.7 MB（压缩），train=82k queries，validation=10k queries

import os
import numpy as np
from datasets import load_dataset
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from config import settings
from tqdm import tqdm
import time

INDEX_DIR = "embeddings/ms_marco_v1_1_passages"
TOP_K = 10

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def build_index_with_speed(batch_size=200, cache_dir=None):
    """构建向量库并显示当前嵌入速率，支持断点续传"""
    print("Loading MS MARCO v1.1 train split passages...")
    train_ds = load_dataset("ms_marco", "v1.1", split="train", cache_dir=cache_dir)

    # 使用字典而非集合来处理去重，同时保存ID信息
    passage_dict = {}
    for item in tqdm(train_ds, desc="Collecting passages"):
        p = item["passages"]
        for i, text in enumerate(p["passage_text"]):
            if text not in passage_dict:
                passage_dict[text] = f"msmarco_passage_{len(passage_dict)}"
    
    passages = list(passage_dict.keys())
    ids = list(passage_dict.values())
    print(f"Collected {len(passages)} unique passages.")
    
    # 确保目录存在
    os.makedirs(INDEX_DIR, exist_ok=True)
    
    # 实现断点续传机制
    checkpoint_file = os.path.join(INDEX_DIR, "embedding_checkpoint.npz")
    start_idx = 0
    all_embeddings = []
    
    # 检查检查点文件是否存在
    if os.path.exists(checkpoint_file):
        try:
            checkpoint = np.load(checkpoint_file, allow_pickle=True)
            all_embeddings = checkpoint['embeddings'].tolist()
            start_idx = len(all_embeddings)
            print(f"Resuming from checkpoint with {start_idx} embeddings")
        except Exception as e:
            print(f"Error loading checkpoint: {e}. Starting from scratch.")
            all_embeddings = []
            start_idx = 0
    else:
        print("No checkpoint found. Starting from scratch.")
    
    # 按批次进行嵌入
    remaining = passages[start_idx:]
    batches = [remaining[i:i+batch_size] for i in range(0, len(remaining), batch_size)]
    
    embedder = OpenAIEmbeddings(model=settings.EMBEDDING_MODEL)

    start = time.time()
    pbar = tqdm(batches, desc="Embedding batches")
    try:
        for i, batch in enumerate(pbar):
            # 添加重试逻辑
            max_retries = 3
            for attempt in range(max_retries):
                try:
                    embs = embedder.embed_documents(batch)
                    all_embeddings.extend(embs)
                    break
                except Exception as e:
                    if attempt < max_retries - 1:
                        wait_time = 2 ** attempt  # 指数退避
                        print(f"Embedding failed: {e}. Retrying in {wait_time}s...")
                        time.sleep(wait_time)
                    else:
                        raise
            
            # 定期保存检查点
            if (i + 1) % 10 == 0:
                np.savez(checkpoint_file, embeddings=np.array(all_embeddings, dtype=object))
            
            elapsed = time.time() - start
            done = len(all_embeddings)
            rate = done / elapsed if elapsed > 0 else 0.0
            pbar.set_postfix({"docs/sec": f"{rate:.1f}"})
    except KeyboardInterrupt:
        print("Operation interrupted. Saving progress...")
    finally:
        if all_embeddings:
            np.savez(checkpoint_file, embeddings=np.array(all_embeddings, dtype=object))
    
    # 构建并保存 FAISS 索引
    docs = [Document(page_content=text, metadata={"source": id_})
            for text, id_ in zip(passages[:len(all_embeddings)], ids[:len(all_embeddings)])]
    
    db = FAISS.from_embeddings(all_embeddings, docs)
    db.save_local(INDEX_DIR)
    print(f"Index saved to {INDEX_DIR}")
    return db

In [3]:
def load_index():
    """
    加载已有索引；若不存在则构建带速率的索引
    """
    if not os.path.isdir(INDEX_DIR) or not os.path.exists(os.path.join(INDEX_DIR, "index.faiss")):
        print("Index not found, building new index...")
        return build_index_with_speed(batch_size=200)
    print(f"Loading existing index from {INDEX_DIR}")
    embeddings = OpenAIEmbeddings(model=settings.EMBEDDING_MODEL)
    return FAISS.load_local(INDEX_DIR, embeddings)

In [4]:
def evaluate_retrieval(dev_ds, db, top_k=TOP_K, batch_size=32):
    """
    针对 validation split，计算 Hit@k 与 MRR@k，使用批处理提高效率
    """
    # 构建 qrels：query_id -> 正例 passage_text 集合（使用集合加速查找）
    qrels = {}
    for item in tqdm(dev_ds, desc="Building qrels"):
        qid = item["query_id"]
        p = item["passages"]
        positives = set(txt for txt, sel in zip(p["passage_text"], p["is_selected"]) if sel == 1)
        qrels[qid] = positives
    
    results = {"hits": [], "rrs": []}
    
    # 批处理评估
    for i in range(0, len(dev_ds), batch_size):
        batch = dev_ds[i:i+batch_size]
        queries = [item["query"] for item in batch]
        qids = [item["query_id"] for item in batch]
        
        # 批量检索
        batch_results = []
        for query in queries:
            retrieved = db.similarity_search(query, k=top_k)
            batch_results.append([doc.page_content for doc in retrieved])
        
        for qid, texts in zip(qids, batch_results):
            positives = qrels.get(qid, set())
            
            # 精确匹配而非子字符串匹配
            hit = int(any(txt in positives for txt in texts))
            results["hits"].append(hit)
            
            # MRR@k
            rr = 0
            for rank, txt in enumerate(texts):
                if txt in positives:
                    rr = 1.0 / (rank + 1)
                    break
            results["rrs"].append(rr)
    
    return np.mean(results["hits"]), np.mean(results["rrs"])

In [5]:
if __name__ == "__main__":
    db = load_index()
    print("Loading MS MARCO v1.1 validation split for evaluation...")
    dev_ds = load_dataset("ms_marco", "v1.1", split="validation")
    hit10, mrr10 = evaluate_retrieval(dev_ds, db)
    print(f"MS MARCO v1.1 Retrieval => Hit@{TOP_K}: {hit10:.4f}, MRR@{TOP_K}: {mrr10:.4f}")

Index not found, building new index...
Loading MS MARCO v1.1 train split passages...


Collecting passages: 100%|██████████| 82326/82326 [00:06<00:00, 12470.05it/s]


Collected 626907 unique passages.
Error loading checkpoint: No data left in file. Starting from scratch.


Embedding batches:   1%|          | 18/3135 [00:41<2:00:52,  2.33s/it, docs/sec=88.0]


Operation interrupted. Saving progress...


ValueError: too many values to unpack (expected 2)