In [None]:
# 用于“分库分测、禁止泄露” & “索引全段落”
from datasets import load_dataset
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from config import settings
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
import os
def build_index(index_path="embeddings/rag_train_full_paragraphs", split="train[:2%]"):
    """
    Build a FAISS vector store from unique SQuAD paragraphs and save it to disk.

    Each distinct ``context`` string in the selected split becomes one
    ``Document`` (the whole paragraph, not chunked), tagged with a
    ``train_paragraph_<i>`` source label in first-seen order.

    Args:
        index_path: Directory the index is written to; created if missing.
            Defaults to the location the driver code loads from.
        split: ``datasets`` split expression for the rows to index. It should
            stay a slice of the train split: the source labels and the log
            message both say "train", and evaluation uses the validation split.

    Returns:
        The in-memory FAISS vector store that was just saved.
    """
    train_ds = load_dataset("squad", split=split)
    # Many questions share one paragraph; dict.fromkeys drops the repeats
    # while keeping first-seen order, so paragraph ids stay stable.
    contexts = list(dict.fromkeys(item["context"] for item in train_ds))
    print(f"使用 Train split 构建索引，共 {len(contexts)} 段落。")

    docs = [Document(page_content=ctx, metadata={"source": f"train_paragraph_{i}"})
            for i, ctx in enumerate(contexts)]
    embeddings = OpenAIEmbeddings(model=settings.EMBEDDING_MODEL)
    db = FAISS.from_documents(docs, embeddings)
    os.makedirs(index_path, exist_ok=True)
    db.save_local(index_path)
    # Report the path actually used; identical to the old message for the default.
    print(f"向量库已保存：{index_path}")
    return db


def evaluate_retrieval(dev_ds, db, top_k=10):
    """
    Score retrieval with Hit@k and MRR@k using answer-text containment.

    A retrieved document counts as relevant when it contains any of the
    question's gold answer strings (case-insensitive substring match), rather
    than requiring the document to equal the question's original context.

    Args:
        dev_ds: Iterable of dict-like rows with a ``"question"`` string and an
            optional ``"answers"`` mapping whose ``"text"`` entry is a list of
            answer strings.
        db: Object exposing ``similarity_search(query, k=...)`` that returns
            documents with a ``page_content`` string, best match first.
        top_k: Number of documents to retrieve per question.

    Returns:
        ``(hit_rate, mrr)``: the mean over all questions of the hit indicator
        (0/1) and of the reciprocal rank of the first relevant document
        (0 when none is relevant). Both are NaN when ``dev_ds`` is empty.
    """
    hits, rr_list = [], []
    for item in dev_ds:
        query = item["question"]
        raw_answers = (item.get("answers") or {}).get("text") or []
        # Lower-case once per question. Blank answers are dropped because the
        # empty string is a substring of every passage and would always "hit".
        answers = [ans.lower() for ans in raw_answers if ans and ans.strip()]

        hit, rr = 0, 0.0
        if answers:  # nothing to match against -> counted as a miss, no search needed
            retrieved = db.similarity_search(query, k=top_k)
            for rank, doc in enumerate(retrieved, start=1):
                content = doc.page_content.lower()
                if any(ans in content for ans in answers):
                    hit = 1
                    rr = 1.0 / rank
                    break  # only the first relevant rank matters for both metrics
        hits.append(hit)
        rr_list.append(rr)

    if not hits:
        # np.mean([]) also yields NaN but emits a RuntimeWarning; be explicit.
        return float("nan"), float("nan")
    return np.mean(hits), np.mean(rr_list)


if __name__ == "__main__":
    # Driver: reuse the saved index when it is complete, otherwise rebuild it,
    # then report Hit@10 / MRR@10 on a slice of the SQuAD validation split.
    index_path = "embeddings/rag_train_full_paragraphs"
    # FAISS.save_local writes both files and load_local needs both, so a
    # partial save (e.g. an interrupted run) is treated as "no index".
    index_files = ("index.faiss", "index.pkl")
    has_index = all(
        os.path.isfile(os.path.join(index_path, name)) for name in index_files
    )
    if not has_index:
        print("检测不到已保存的向量库，开始构建...")
        db = build_index()
    else:
        embeddings = OpenAIEmbeddings(model=settings.EMBEDDING_MODEL)
        # SECURITY: this unpickles index.pkl. Only acceptable for an index
        # written locally by this notebook; never point it at untrusted files.
        db = FAISS.load_local(index_path, embeddings, allow_dangerous_deserialization=True)
        print("已加载已有向量库：", index_path)

    # Evaluation questions come from the validation split only, so none of
    # them were used to build the train-split index.
    dev_ds = load_dataset("squad", split="validation[:2%]")
    hit10, mrr10 = evaluate_retrieval(dev_ds, db, top_k=10)
    print(f"Hit@10: {hit10:.4f}, MRR@10: {mrr10:.4f}")


检测不到已保存的向量库，开始构建...


Generating train split: 100%|██████████| 87599/87599 [00:00<00:00, 880190.01 examples/s]
Generating validation split: 100%|██████████| 10570/10570 [00:00<00:00, 1003072.39 examples/s]


使用 Train split 构建索引，共 238 段落。
向量库已保存：embeddings/rag_train_full_paragraphs
Hit@10: 0.3460, MRR@10: 0.2260
