In [1]:
from datasets import load_dataset
import pandas as pd
from tqdm import tqdm
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from config import settings
import os

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
### 下载 Natural Questions Lite 数据集
dataset = load_dataset("squad", split="validation")
print("样本数：", len(dataset))

样本数： 10570


In [None]:
### 下载 Natural Questions Lite 数据集
dataset = load_dataset("squad", split="train[:5%]")
print("样本数：", len(dataset))


样本数： 4380


In [None]:
print(dataset[0])

In [None]:
### STEP 2: 构建知识库
print("\n 正在构建文档索引...")
docs = []

for item in dataset:
    passage = item["context"]
    pid = item["id"]
    doc = Document(page_content=passage, metadata={"doc_id": pid})
    docs.append(doc)


embeddings = OpenAIEmbeddings(model=settings.EMBEDDING_MODEL)
db = FAISS.from_documents(docs, embeddings)
db.save_local("embeddings/squad")
print(" 向量库已保存")

In [3]:
embeddings = OpenAIEmbeddings(model=settings.EMBEDDING_MODEL)
db = FAISS.load_local("embeddings/squad", embeddings, allow_dangerous_deserialization=True)

In [4]:
from src.retriever import get_topk_docs

# 评估
def evaluate(dataset, top_k=10, db=None):
    records = []
    for item in tqdm(dataset):
        query = item["question"]
        gt_answer = item["answers"]["text"][0]

        docs = get_topk_docs(query, k=top_k, db=db)   # <== 加上 db
        retrieved_texts = [doc.page_content for doc in docs]

        hit = int(any(gt_answer.lower() in passage.lower() for passage in retrieved_texts))
        rr = 0
        for i, passage in enumerate(retrieved_texts):
            if gt_answer.lower() in passage.lower():
                rr = 1 / (i + 1)
                break

        records.append({"question": query, "hit": hit, "rr": rr})

    df = pd.DataFrame(records)
    print(f" Hit@{top_k}: {df['hit'].mean():.4f}, MRR@{top_k}: {df['rr'].mean():.4f}")
    return df

df_result = evaluate(dataset, top_k=10, db=db)  # 传db进去

df_result.to_csv("squad_eval_result.csv", index=False)

100%|██████████| 4380/4380 [29:30<00:00,  2.47it/s]  

 Hit@10: 0.8096, MRR@10: 0.7290



