In [4]:
from datasets import load_dataset, Dataset
from ragas import evaluate
from ragas.metrics import context_recall, context_precision, answer_correctness
from src.retriever import get_relevant_chunks
import os
import pandas as pd

In [5]:
# Load the validation split of the SQuAD dataset. The first 100 distinct
# contexts are picked out of it in the following cell.
full = load_dataset("squad", split="validation")

In [6]:
# Walk the contexts in dataset order, remembering the row index at which each
# distinct context first appears, and stop as soon as 100 distinct contexts
# have been collected. Dicts keep insertion order, so the row indices come out
# in ascending order.
first_row = {}
for i, ctx in enumerate(full["context"]):
    first_row.setdefault(ctx, i)  # keeps the earliest index for a repeated context
    if len(first_row) == 100:
        break

seen = set(first_row)               # the distinct context strings
indices = list(first_row.values())  # row index of each context's first occurrence

In [7]:
# Restrict the dataset to the selected rows: one question per distinct context.
raw = full.select(indices)
questions = raw["question"]

# Use the first reference answer as the ground truth; fall back to an empty
# string when a row has no answer text.
ground_truths = []
for ans in raw["answers"]:
    texts = ans["text"]
    ground_truths.append(texts[0] if texts else "")

print("选取到", len(raw), "条唯一 context")

选取到 100 条唯一 context


In [None]:
# Dump every selected context into its own UTF-8 text file under
# documents/squad/, named by zero-padded row position (000.txt, 001.txt, ...).
# The printed message asks the reader to build the vector store afterwards
# (see the README).
out_dir = "documents/squad"
os.makedirs(out_dir, exist_ok=True)

for i, ctx in enumerate(raw["context"]):
    target = f"{out_dir}/{i:03d}.txt"
    with open(target, mode="w", encoding="utf-8") as f:
        f.write(ctx)

print("已写入到", out_dir, ",请构建向量库，参考readme")


已写入到 documents/squad 请构建向量库，参考readme


In [12]:
# Top-k retrieval with a manual hit check.
#
# For every question, fetch the TOP_K most relevant chunks and record whether
# the ground-truth answer string occurs verbatim in any of them.
TOP_K = 1  # number of chunks retrieved per question; the labels below follow it

contexts_list = []  # per question: the list of retrieved chunks
manual_hits   = []  # per question: True when the answer string was found

print(f"=== Top-{TOP_K} 检索 & 手动检查 ===")
for i, (q, gt) in enumerate(zip(questions, ground_truths)):
    retrieved = get_relevant_chunks(q, k=TOP_K)
    contexts_list.append(retrieved)
    # `"" in chunk` is always True, so a question without a ground-truth answer
    # must not be counted as a hit. Every retrieved chunk is checked (not only
    # the first one) so that the result stays meaningful when TOP_K > 1.
    hit = bool(gt) and any(gt in chunk for chunk in (retrieved or []))
    manual_hits.append(hit)
    print(f"[{i:03d}] hit={hit}  retrieved_len={len(retrieved)}")

# Guard the division so an empty question list reports 0.000 instead of raising.
recall_at_k = sum(manual_hits) / len(manual_hits) if manual_hits else 0.0
print(f"\nManual Recall@{TOP_K}: {recall_at_k:.3f}")

=== Top-1 检索 & 手动检查 ===
[000] hit=True  retrieved_len=1
[001] hit=True  retrieved_len=1
[002] hit=False  retrieved_len=1
[003] hit=True  retrieved_len=1
[004] hit=True  retrieved_len=1
[005] hit=True  retrieved_len=1
[006] hit=True  retrieved_len=1
[007] hit=True  retrieved_len=1
[008] hit=True  retrieved_len=1
[009] hit=True  retrieved_len=1
[010] hit=True  retrieved_len=1
[011] hit=True  retrieved_len=1
[012] hit=True  retrieved_len=1
[013] hit=True  retrieved_len=1
[014] hit=True  retrieved_len=1
[015] hit=True  retrieved_len=1
[016] hit=True  retrieved_len=1
[017] hit=True  retrieved_len=1
[018] hit=True  retrieved_len=1
[019] hit=True  retrieved_len=1
[020] hit=True  retrieved_len=1
[021] hit=True  retrieved_len=1
[022] hit=True  retrieved_len=1
[023] hit=True  retrieved_len=1
[024] hit=True  retrieved_len=1
[025] hit=True  retrieved_len=1
[026] hit=True  retrieved_len=1
[027] hit=True  retrieved_len=1
[028] hit=True  retrieved_len=1
[029] hit=True  retrieved_len=1
[030] hit=True 

In [13]:
# Assemble the RAGAS dataset and run the evaluation.
#
# The "answer" column is built directly from the retrieved chunks: each
# question's chunks are joined with single spaces and cut to 500 characters.
answers = []
for chunks in contexts_list:
    joined = " ".join(chunks)
    answers.append(joined[:500])

data = {
    "question":     questions,
    "contexts":     contexts_list,
    "answer":       answers,
    "ground_truth": ground_truths,
}
# All four columns must have the same length before building the Dataset.
column_sizes = {name: len(column) for name, column in data.items()}
print("Data lengths:", column_sizes)

eval_ds = Dataset.from_dict(data)
ragas_metrics = [context_recall, context_precision, answer_correctness]
scores = evaluate(dataset=eval_ds, metrics=ragas_metrics)
df = scores.to_pandas()

print("\n=== RAGAS Evaluation Results ===")
print(df)

# Column-wise mean of the three per-sample metric scores.
metric_columns = ["context_recall", "context_precision", "answer_correctness"]
avg = df[metric_columns].mean()
print("\n=== Average Metrics ===")
print(avg.to_string())

Data lengths: {'question': 100, 'contexts': 100, 'answer': 100, 'ground_truth': 100}


Evaluating:   0%|          | 0/300 [00:00<?, ?it/s]


=== RAGAS Evaluation Results ===
                                           user_input  \
0   Which NFL team represented the AFC at Super Bo...   
1   Which Carolina Panthers player was named Most ...   
2                      Who was the Super Bowl 50 MVP?   
3   Which network broadcasted Super Bowl 50 in the...   
4         Who was the NFL Commissioner in early 2012?   
..                                                ...   
95  What type of city has Warsaw been for as long ...   
96  What is the basic unit of territorial division...   
97  Who in Warsaw has the power of legislative act...   
98                What is the mayor of Warsaw called?   
99  What is the city centre of Warsaw called in Po...   

                                   retrieved_contexts  \
0   [Super Bowl 50 was an American football game t...   
1   [The Panthers finished the regular season with...   
2   [Super Bowl 50 was an American football game t...   
3   [CBS broadcast Super Bowl 50 in the U.S., and ...

In [17]:
# Inspect a single sample: its question, ground truth, retrieved chunks and
# the answer text that was passed to the evaluation.
n = 0  # <- change this to the index you want to inspect

print(f"=== Sample {n} ===")
for label, value in (("Q :", questions[n]), ("GT:", ground_truths[n])):
    print(label, value)

print("\nRetrieved Contexts:")
for idx, ctx in enumerate(contexts_list[n]):
    print(f"  [{idx}] {ctx}\n")

# One call, newline-separated: prints the header line, then the answer.
print("Answer:", answers[n], sep="\n")


=== Sample 0 ===
Q : Which NFL team represented the AFC at Super Bowl 50?
GT: Denver Broncos

Retrieved Contexts:
  [0] Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated

Answer:
Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated
