# Retrieve & Re-Rank

Official documentation: https://sbert.net/examples/sentence_transformer/applications/retrieve_rerank/README.html

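A two-stage retrieval pipeline:
1. **Retrieve**: a Bi-Encoder embeds the query and the documents independently, so corpus embeddings can be precomputed and the top-k candidates recalled quickly.
2. **Re-Rank**: a Cross-Encoder scores each (query, candidate) pair jointly. This is more accurate but too slow to run over the whole corpus, so it is applied only to the top-k candidates.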
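The control flow of the two stages can be sketched without any models. In this minimal sketch, assuming toy scorers, a word-overlap count stands in for the Bi-Encoder and Jaccard similarity stands in for the Cross-Encoder; `retrieve_and_rerank`, `word_overlap` and `jaccard` are hypothetical names, not part of sentence-transformers. The point it shows is that the expensive scorer runs only on the `top_k` survivors of the cheap stage.

```python
from typing import Callable, List, Sequence, Tuple

# A scorer maps (query, document) to a relevance score; higher means more relevant.
Scorer = Callable[[str, str], float]


def retrieve_and_rerank(
    query: str,
    corpus: Sequence[str],
    cheap_score: Scorer,
    expensive_score: Scorer,
    top_k: int = 5,
) -> List[Tuple[int, float]]:
    """Hypothetical two-stage search: cheap recall over all docs, expensive re-rank of top_k."""
    # Stage 1 (Retrieve): score every document cheaply and keep the top_k indices.
    recalled = sorted(
        range(len(corpus)),
        key=lambda i: cheap_score(query, corpus[i]),
        reverse=True,
    )[:top_k]
    # Stage 2 (Re-Rank): only the recalled candidates reach the expensive scorer.
    reranked = [(i, expensive_score(query, corpus[i])) for i in recalled]
    reranked.sort(key=lambda pair: pair[1], reverse=True)
    return reranked


def word_overlap(q: str, d: str) -> float:
    """Toy cheap scorer: number of shared lowercase words."""
    return float(len(set(q.lower().split()) & set(d.lower().split())))


def jaccard(q: str, d: str) -> float:
    """Toy 'expensive' scorer: Jaccard similarity of the two word sets."""
    a, b = set(q.lower().split()), set(d.lower().split())
    return len(a & b) / len(a | b) if a | b else 0.0


expensive_calls: List[str] = []


def counted_jaccard(q: str, d: str) -> float:
    expensive_calls.append(d)  # record every call to the expensive scorer
    return jaccard(q, d)


toy_docs = [
    "python is a programming language",
    "java is a programming language used in enterprise",
    "deep learning uses neural networks",
    "css styles web pages",
]
toy_query = "python programming language"

result = retrieve_and_rerank(toy_query, toy_docs, word_overlap, counted_jaccard, top_k=2)
print(result)
print(len(expensive_calls))  # 2: only the top_k candidates were re-scored
```

In the real pipeline, stage 1 is cheap for a different reason: the corpus embeddings are computed once, so each query needs only one encoder pass plus a similarity search.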

In [None]:
from sentence_transformers import SentenceTransformer, CrossEncoder, util
import torch

# Stage 1: Bi-Encoder retrieval
bi_encoder = SentenceTransformer("all-MiniLM-L6-v2")

corpus = [
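    "Python is an interpreted, high-level programming language",
    "Java is an object-oriented programming language",
    "Deep learning uses multi-layer neural networks",
    "Machine learning is a subfield of artificial intelligence",
    "HTML is the markup language for web pages",
    "CSS is used to style web pages",
    "JavaScript is the scripting language of web pages",
    "Databases are used to store and manage data",
]

corpus_embeddings = bi_encoder.encode(corpus, convert_to_tensor=True)

query = "Which programming language is suitable for artificial intelligence?"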
query_embedding = bi_encoder.encode(query, convert_to_tensor=True)

# Retrieve the top-5 candidates
hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=5)[0]
print("=== Bi-Encoder retrieval results ===")
for hit in hits:
    print(f"  [{hit['score']:.4f}] {corpus[hit['corpus_id']]}")
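For one query, `util.semantic_search` scores the query embedding against every corpus embedding (cosine similarity by default) and returns a list of `{'corpus_id', 'score'}` dicts sorted by descending score. A dependency-free sketch of that top-k step on toy 2-D vectors; `toy_semantic_search` is a hypothetical name, and it omits the query/corpus chunking the library uses for large inputs.

```python
import heapq
import math
from typing import Dict, List, Sequence


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity of two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def toy_semantic_search(
    query_emb: Sequence[float],
    corpus_embs: Sequence[Sequence[float]],
    top_k: int = 5,
) -> List[Dict]:
    """Score one query against every corpus vector and keep the top_k, best first."""
    scored = [(cosine(query_emb, emb), idx) for idx, emb in enumerate(corpus_embs)]
    best = heapq.nlargest(top_k, scored)  # already sorted by score, descending
    return [{"corpus_id": idx, "score": score} for score, idx in best]


toy_query_emb = [1.0, 0.0]
toy_corpus_embs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
toy_hits = toy_semantic_search(toy_query_emb, toy_corpus_embs, top_k=2)
for toy_hit in toy_hits:
    print(toy_hit)
```

`heapq.nlargest` avoids fully sorting all corpus scores when only `top_k` are needed.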

In [None]:
# Stage 2: Cross-Encoder re-ranking
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

# Build (query, document) pairs for the retrieved candidates
cross_input = [[query, corpus[hit['corpus_id']]] for hit in hits]
cross_scores = cross_encoder.predict(cross_input)

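print("=== Cross-Encoder re-ranked results ===")
# Re-rank by Cross-Encoder score, highest first (argsort is ascending, so reverse it)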
for idx in cross_scores.argsort()[::-1]:
    hit = hits[idx]
    print(f"  [CE: {cross_scores[idx]:.4f} | BE: {hit['score']:.4f}] {corpus[hit['corpus_id']]}")
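The argsort loop can also be written by attaching each Cross-Encoder score to its hit and sorting the list, which keeps both scores together for later use. The two scores come from different models and are not on a common scale, so only the Cross-Encoder score decides the order. A minimal sketch; `rerank_hits` and the `cross_score` key are hypothetical names, and the hit and score values below are invented placeholders, not real model outputs.

```python
from typing import Dict, List, Sequence


def rerank_hits(hits: List[Dict], cross_scores: Sequence[float]) -> List[Dict]:
    """Return copies of the hits with a 'cross_score' key, sorted by it (descending)."""
    if len(hits) != len(cross_scores):
        raise ValueError("need exactly one cross-encoder score per hit")
    # dict(hit, ...) copies each hit, so the caller's list is not mutated.
    merged = [dict(hit, cross_score=float(s)) for hit, s in zip(hits, cross_scores)]
    return sorted(merged, key=lambda h: h["cross_score"], reverse=True)


# Placeholder values shaped like util.semantic_search output and CrossEncoder scores.
example_hits = [
    {"corpus_id": 3, "score": 0.61},
    {"corpus_id": 0, "score": 0.55},
    {"corpus_id": 1, "score": 0.48},
]
example_scores = [-2.1, 4.7, 1.3]

reranked = rerank_hits(example_hits, example_scores)
for h in reranked:
    print(f"  [CE: {h['cross_score']:.4f} | BE: {h['score']:.4f}] corpus_id={h['corpus_id']}")
```

Because `float(s)` is applied to each score, the same function accepts the NumPy array returned by `cross_encoder.predict`.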