In [1]:
# Cell 1 — environment setup
# Imports the retrieval/rerank helpers, sets the two knobs a reader is expected
# to edit (COLLECTION, QUERY), and loads both models once so later cells can reuse them.
from scripts.build_index import load_index, search
# NOTE(review): _load_chunk_texts is an underscore-prefixed (private-by-convention)
# helper of scripts.reranker; consider exposing a public name for notebook use.
from scripts.reranker import load_reranker, merge_results, rerank, _load_chunk_texts
from sentence_transformers import SentenceTransformer

COLLECTION = "D"       # ← change to the collection you want to test
QUERY      = "What is the signature dish of Pamela's Diner?"   # ← change to your query

# Load the dense retrieval (embedding) model
model = SentenceTransformer("BAAI/bge-m3")

# Load the local reranker; the device is chosen inside load_reranker()
# (the recorded run below reports "mps")
reranker = load_reranker()

  from .autonotebook import tqdm as notebook_tqdm
Loading weights: 100%|██████████| 391/391 [00:00<00:00, 2131.13it/s, Materializing param=pooler.dense.weight]                               


  Reranker        : BAAI/bge-reranker-v2-m3  (本地推理)
  Device          : mps


`torch_dtype` is deprecated! Use `dtype` instead!
Loading weights: 100%|██████████| 393/393 [00:00<00:00, 1779.42it/s, Materializing param=roberta.encoder.layer.23.output.dense.weight]              


In [3]:
# Cell 2 — load the index and run dense retrieval
# Relies on COLLECTION, QUERY and model defined in Cell 1.
idx, meta = load_index(COLLECTION)
# Chunk texts for the same collection; passed to rerank() in Cell 3.
chunk_texts = _load_chunk_texts(COLLECTION)

# Ask for the top 20 dense hits; these become the candidate pool for reranking.
dense_hits = search(QUERY, idx, meta, model, top_k=20)
# Report how many dense candidates came back.
print(f"Dense 检索到 {len(dense_hits)} 条候选")

  Loaded index: dim=1024  ntotal=1434
Dense 检索到 20 条候选


In [4]:
# Cell 3 — rerank the candidates and print a score comparison table
# Dense-only run: an empty sparse list is passed, so only the dense hits are merged.
candidates = merge_results(dense_hits, sparse_hits=[])
# Keep the five best candidates after reranking.
results = rerank(QUERY, candidates, reranker, chunk_texts, top_n=5)

# NOTE(review): the header's rank column is 6 wide while each row uses 7
# ("  #" plus a 4-wide rank), so the headings sit one column left of the values.
header = f"\n{'Rank':<6} {'Rerank':>8}  {'Dense':>8}  chunk_id"
divider = "-" * 55
print(header)
print(divider)
for hit in results:
    # One row per reranked hit: rank, rerank score, dense score, chunk id.
    row = f"  #{hit['rerank_rank']:<4} {hit['rerank_score']:>8.4f}  {hit['dense_score']:>8.4f}  {hit['chunk_id']}"
    print(row)


Rank     Rerank     Dense  chunk_id
-------------------------------------------------------
  #1      0.0866    0.5442  D_D_pittsburgh_restaurants__0101
  #2      0.0482    0.4923  D_D_pittsburgh_restaurants__0083
  #3      0.0052    0.4873  D_D_pittsburgh_restaurants__0140
  #4      0.0042    0.4934  D_D_pittsburgh_restaurants__0103
  #5      0.0029    0.4670  D_bananasplitfest_0001__0000
