# Cross-Encoder Training in Practice

Official documentation: https://sbert.net/docs/cross_encoder/training_overview.html

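The core difference between a Cross-Encoder and a Sentence Transformer (Bi-Encoder):
- **Bi-Encoder**: the two sentences are encoded separately into vectors, which are then compared with cosine similarity. Fast, suited to large-scale retrieval.
- **Cross-Encoder**: the two sentences are concatenated and passed through the model together, which outputs a relevance score. More accurate, suited to reranking.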
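The difference in data flow can be illustrated without any model. The toy sketch below assumes a hypothetical bag-of-words `encode` function and a hypothetical `joint_score` function as stand-ins for the transformer; neither is how real models score text. It only shows that a Bi-Encoder compares vectors produced independently (so document vectors can be cached), while a Cross-Encoder has to see both texts together for every pair.

```python
import math
from collections import Counter

# Hypothetical Bi-Encoder stand-in: turns ONE text into a vector (bag-of-words counts) on its own.
def encode(text: str) -> Counter:
    return Counter(text.lower().split())

def cosine(u: Counter, v: Counter) -> float:
    dot = sum(u[t] * v[t] for t in u)
    norm = math.sqrt(sum(c * c for c in u.values())) * math.sqrt(sum(c * c for c in v.values()))
    return dot / norm if norm else 0.0

# Hypothetical Cross-Encoder stand-in: looks at BOTH texts together and returns one score
# (here simply the Jaccard overlap of their word sets).
def joint_score(query: str, doc: str) -> float:
    q, d = set(query.lower().split()), set(doc.lower().split())
    return len(q & d) / len(q | d) if q | d else 0.0

docs = ["berlin has many museums", "berlin population is large", "paris is in france"]
query = "population of berlin"

# Bi-Encoder: document vectors can be computed once and cached; the query is encoded once.
doc_vecs = [encode(d) for d in docs]
query_vec = encode(query)
bi_scores = [cosine(query_vec, v) for v in doc_vecs]

# Cross-Encoder: one joint call per (query, doc) pair; nothing can be precomputed per document.
cross_scores = [joint_score(query, d) for d in docs]

print("bi-encoder:", bi_scores)
print("cross-encoder:", cross_scores)
```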

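## Training components

As with Sentence Transformer training, there are six components:
1. **CrossEncoder**: the model
2. **Dataset**: the training data
3. **Loss**: the loss function (note: import it from `sentence_transformers.cross_encoder.losses`)
4. **CrossEncoderTrainingArguments**: the training arguments
5. **Evaluator**: the evaluator (optional)
6. **CrossEncoderTrainer**: the trainer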

## Choosing a loss function

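| Scenario | Loss | Data format |
|------|------|----------|
| **Binary classification (relevant / not relevant)** | BinaryCrossEntropyLoss | (query, doc, 0/1) |
| **Multi-class classification** | CrossEntropyLoss | (sent1, sent2, class_id) |
| **Ranking (no labels)** | MultipleNegativesRankingLoss | (query, positive) pairs or (query, positive, negative) triplets |
| **Ranking (with labels)** | LambdaLoss / ListNetLoss | (query, [doc1, doc2, ...], [score1, score2, ...]) |
| **Distillation** | MSELoss / MarginMSELoss | MSELoss: (query, doc, teacher_score); MarginMSELoss: (query, positive, negative, teacher_pos_score - teacher_neg_score) |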
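For the "ranking (with labels)" row, a listwise loss compares the whole list of model scores for one query against the list of gold labels. Below is a minimal sketch of the classic ListNet top-one formulation in plain Python: cross-entropy between softmax(labels) and softmax(scores). The helper `listnet_loss` is hypothetical; `ListNetLoss` in sentence-transformers also deals with batching, padding of variable-length lists and activation functions, so treat this as an illustration of the idea rather than its implementation.

```python
import math

def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def listnet_loss(scores, labels):
    """Hypothetical helper: ListNet top-one loss for ONE query's document list."""
    target = softmax(labels)  # gold distribution over the documents
    pred = softmax(scores)    # model distribution over the documents
    return -sum(t * math.log(p) for t, p in zip(target, pred))

labels = [3.0, 1.0, 0.0]                        # graded relevance of doc1, doc2, doc3
good = listnet_loss([2.5, 0.5, -1.0], labels)   # model agrees with the label order
bad = listnet_loss([-1.0, 0.5, 2.5], labels)    # model reverses the order
print(good, bad)
```

The loss is smallest when the score distribution matches the label distribution, so a model that reverses the order is penalized much more.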

Most common: **BinaryCrossEntropyLoss** (training rerankers) and **CrossEntropyLoss** (NLI classification).

## Example 1: NLI classification (CrossEntropyLoss)

Train a Cross-Encoder on the AllNLI dataset for three-way classification (entailment / neutral / contradiction).
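Under the hood, a three-class Cross-Encoder outputs three logits per sentence pair, and CrossEntropyLoss is softmax cross-entropy against the gold class id. A plain-Python sketch of that computation, using the label ids of the AllNLI pair-class subset below (0 = entailment, 1 = neutral, 2 = contradiction); the logit values are made up for illustration:

```python
import math

ID2LABEL = {0: "entailment", 1: "neutral", 2: "contradiction"}

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

def cross_entropy(logits, class_id):
    # negative log-probability the model assigns to the gold class
    return -math.log(softmax(logits)[class_id])

logits = [2.0, 0.5, -1.0]  # made-up model output for one sentence pair
probs = softmax(logits)
pred_id = max(range(len(probs)), key=probs.__getitem__)  # argmax over the 3 classes
print(ID2LABEL[pred_id], [round(p, 3) for p in probs])
print(cross_entropy(logits, 0), cross_entropy(logits, 2))
```

The loss is low when the gold class already has the largest logit and high otherwise; at inference time only the argmax is needed.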

```python
from sentence_transformers import CrossEncoder
from sentence_transformers.cross_encoder.trainer import CrossEncoderTrainer
from sentence_transformers.cross_encoder.training_args import CrossEncoderTrainingArguments
from sentence_transformers.cross_encoder.losses import CrossEntropyLoss
from sentence_transformers.cross_encoder.evaluation import CrossEncoderClassificationEvaluator
from datasets import load_dataset
```

```python
# 1. Load the model (num_labels=3: entailment / neutral / contradiction)
model = CrossEncoder("distilroberta-base", num_labels=3)
print(model)
```

```python
# 2. Load the AllNLI dataset (pair-class subset; columns: premise, hypothesis, label)
# label: 0 = entailment, 1 = neutral, 2 = contradiction
train_dataset = load_dataset("sentence-transformers/all-nli", "pair-class", split="train").select(range(10000))
eval_dataset = load_dataset("sentence-transformers/all-nli", "pair-class", split="dev").select(range(1000))

print(f"Training set: {len(train_dataset)} rows")
print("Sample:", train_dataset[0])
```

```python
# 3. Loss: CrossEntropyLoss (multi-class)
train_loss = CrossEntropyLoss(model)
```

```python
# 4. Evaluator: classification accuracy on the dev pairs
dev_evaluator = CrossEncoderClassificationEvaluator(
    sentence_pairs=list(zip(eval_dataset["premise"], eval_dataset["hypothesis"])),
    labels=list(eval_dataset["label"]),
    name="nli-dev",
)
print("Evaluation before training:")
print(dev_evaluator(model))
```

```python
# 5. Training arguments
args = CrossEncoderTrainingArguments(
    output_dir="output/nli-cross-encoder",
    num_train_epochs=1,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    warmup_ratio=0.1,
    fp16=True,  # mixed precision requires a CUDA GPU; set to False when training on CPU
    eval_strategy="steps",
    eval_steps=50,
    save_strategy="steps",
    save_steps=50,
    save_total_limit=2,
    logging_steps=50,
)
```
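With these arguments it is worth checking how often evaluation actually runs. A back-of-the-envelope sketch, assuming a single device, no gradient accumulation, the last partial batch kept (the Hugging Face default) and warmup steps rounded up from `warmup_ratio × total steps`:

```python
import math

num_examples = 10_000  # size of train_dataset selected above
batch_size = 64        # per_device_train_batch_size
epochs = 1             # num_train_epochs
warmup_ratio = 0.1
eval_steps = 50

steps_per_epoch = math.ceil(num_examples / batch_size)  # last partial batch is kept
total_steps = steps_per_epoch * epochs
warmup_steps = math.ceil(total_steps * warmup_ratio)
eval_points = list(range(eval_steps, total_steps + 1, eval_steps))  # steps where evaluation runs

print(total_steps, warmup_steps, eval_points)
```

So this run evaluates (and saves) only three times; lower `eval_steps` if you want a finer learning curve.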

```python
# 6. Train
trainer = CrossEncoderTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=train_loss,
    evaluator=dev_evaluator,
)
trainer.train()
```

```python
# 7. Inference test
pairs = [
    ("A man is eating pizza", "A man eats something"),      # entailment
    ("A black race car starts up", "A man is driving"),     # neutral
    ("A woman is sleeping", "A man is running"),            # contradiction
]
predictions = model.predict(pairs)  # one row of 3 logits per pair
label_map = {0: "entailment", 1: "neutral", 2: "contradiction"}
for pair, pred in zip(pairs, predictions):
    label = label_map[int(pred.argmax())]
    print(f"  {pair[0]} | {pair[1]} → {label} {pred}")
```

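## Example 2: Reranker training (BinaryCrossEntropyLoss)

Train a search reranker on the GooAQ dataset (question-answer pairs). BinaryCrossEntropyLoss expects data in the form (query, passage, 0/1 label).

This is the most important use case for a Cross-Encoder.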
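In a search system the reranker usually sits behind a fast first stage (for example a Bi-Encoder) that recalls a small candidate set; the Cross-Encoder then rescores only those candidates. A toy sketch of that control flow, assuming two hypothetical scoring functions passed in as arguments; the score tables in the usage part are made up and stand in for real models:

```python
from typing import Callable, List, Tuple

def retrieve_then_rerank(
    query: str,
    corpus: List[str],
    first_stage: Callable[[str, str], float],  # hypothetical cheap scorer (e.g. a Bi-Encoder)
    reranker: Callable[[str, str], float],     # hypothetical expensive pair scorer (a Cross-Encoder)
    recall_k: int = 3,
    top_k: int = 2,
) -> List[Tuple[int, float]]:
    # Stage 1: cheap scores for the whole corpus, keep the ids of the recall_k best documents
    candidates = sorted(range(len(corpus)), key=lambda i: first_stage(query, corpus[i]), reverse=True)[:recall_k]
    # Stage 2: the expensive pair scorer runs on those candidates only
    rescored = [(i, reranker(query, corpus[i])) for i in candidates]
    rescored.sort(key=lambda item: item[1], reverse=True)
    return rescored[:top_k]

# Usage with made-up scores
corpus = ["doc A", "doc B", "doc C", "doc D"]
cheap = {"doc A": 0.9, "doc B": 0.8, "doc C": 0.7, "doc D": 0.1}
expensive = {"doc A": 0.2, "doc B": 0.95, "doc C": 0.6, "doc D": 0.99}
reranker_calls = []

def expensive_scorer(query: str, doc: str) -> float:
    reranker_calls.append(doc)  # record every call to show how few pairs reach stage 2
    return expensive[doc]

result = retrieve_then_rerank("q", corpus, lambda q, d: cheap[d], expensive_scorer)
print(result, len(reranker_calls))
```

Note that "doc D" would get the highest reranker score but never reaches stage 2: a reranker can only reorder what the first stage recalls.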

```python
from sentence_transformers.cross_encoder.losses import BinaryCrossEntropyLoss
from sentence_transformers.cross_encoder.evaluation import CrossEncoderRerankingEvaluator

# 1. Load the reranker model (num_labels=1: outputs a single relevance score)
reranker = CrossEncoder("distilroberta-base", num_labels=1)
print(reranker)
```

```python
# 2. Load the GooAQ dataset (question, answer columns: positive pairs only, no label column)
# GooAQ only has a train split, so rows 10000-10999 are held out for evaluation
train_dataset = load_dataset("sentence-transformers/gooaq", split="train").select(range(10000))
eval_dataset = load_dataset("sentence-transformers/gooaq", split="train").select(range(10000, 11000))

print(f"Training set: {len(train_dataset)} rows")
print("Sample:", train_dataset[0])
```

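```python
# 3. Loss: BinaryCrossEntropyLoss
# BCE does NOT create in-batch negatives: it needs (query, passage, label) rows with 0/1 labels.
# GooAQ only provides positive (question, answer) pairs, so negative pairs have to be
# added to the data before it reaches the trainer.
reranker_loss = BinaryCrossEntropyLoss(reranker)
```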
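For each (query, passage, label) row, BinaryCrossEntropyLoss treats the single output of the `num_labels=1` model as a logit and compares it with the 0/1 label. A plain-Python sketch of binary cross-entropy with logits in its numerically stable form; the helper name and the logit values are made up for illustration, and averaging over the batch mirrors PyTorch's default `reduction="mean"`:

```python
import math

def bce_with_logits(logit: float, label: float) -> float:
    # numerically stable form of -[y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))

rows = [
    ("How many people live in Berlin?", "Berlin has a population of 3,520,031 registered inhabitants.", 1.0),
    ("How many people live in Berlin?", "Germany is a country in Europe.", 0.0),
]
logits = [2.0, -1.5]  # made-up model outputs, one per (query, passage) row

losses = [bce_with_logits(x, label) for x, (_, _, label) in zip(logits, rows)]
mean_loss = sum(losses) / len(losses)  # mean over the batch
print(losses, mean_loss)
```

Both rows already have logits on the "right side" of zero, so each loss is below log(2), the loss of an undecided model that outputs 0.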

```python
# 4. Evaluator: reranking (reports MRR@10, NDCG@10 and MAP)
# Each query gets its own answer as the single positive and the answers
# of the next 9 eval questions as negatives
questions = list(eval_dataset["question"])
answers = list(eval_dataset["answer"])
samples = [
    {
        "query": questions[i],
        "positive": [answers[i]],
        "negative": answers[i + 1 : i + 10],  # 9 answers that belong to other questions
    }
    for i in range(100)  # evaluate on the first 100 queries
]
dev_evaluator = CrossEncoderRerankingEvaluator(samples=samples, name="gooaq-dev")
print("Evaluation before training:")
print(dev_evaluator(reranker))
```
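MRR@10 is the mean, over queries, of 1/rank of the first relevant document within the top 10 (0 if none appears there). A plain-Python sketch of that definition; `mrr_at_k` is a hypothetical helper and the rankings are made up, so this illustrates the metric rather than reproducing the evaluator's code:

```python
def mrr_at_k(ranked_relevance, k=10):
    """ranked_relevance: one list per query, True where the document at that rank is relevant."""
    total = 0.0
    for flags in ranked_relevance:
        for rank, is_relevant in enumerate(flags[:k], start=1):
            if is_relevant:
                total += 1.0 / rank
                break  # only the first relevant document counts
    return total / len(ranked_relevance)

rankings = [
    [True, False, False],   # positive ranked 1st: contributes 1/1
    [False, False, True],   # positive ranked 3rd: contributes 1/3
    [False] * 10 + [True],  # positive ranked 11th: outside the top 10, contributes 0
]
print(mrr_at_k(rankings))
```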

```python
# 5. Training arguments
reranker_args = CrossEncoderTrainingArguments(
    output_dir="output/gooaq-reranker",
    num_train_epochs=1,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    warmup_ratio=0.1,
    fp16=True,  # requires a CUDA GPU
    eval_strategy="steps",
    eval_steps=50,
    save_strategy="steps",
    save_steps=50,
    save_total_limit=2,
    logging_steps=50,
)
```

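```python
# 6. Train
# BCE needs a "label" column: pair each question with its own answer (1.0) and a random other answer (0.0)
import random
from datasets import Dataset

def to_labeled_pairs(ds, seed=42):
    rng, q, a = random.Random(seed), list(ds["question"]), list(ds["answer"])
    neg = [a[(i + rng.randrange(1, len(a))) % len(a)] for i in range(len(a))]  # never row i's own answer
    return Dataset.from_dict({"question": q + q, "answer": a + neg, "label": [1.0] * len(a) + [0.0] * len(a)})

reranker_trainer = CrossEncoderTrainer(
    model=reranker,
    args=reranker_args,
    train_dataset=to_labeled_pairs(train_dataset),
    eval_dataset=to_labeled_pairs(eval_dataset),
    loss=reranker_loss,
    evaluator=dev_evaluator,
)
reranker_trainer.train()
```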

```python
# 7. Check how the trained reranker orders the documents
query = "How many people live in Berlin?"
documents = [
    "Berlin has a population of 3,520,031 registered inhabitants.",
    "Berlin is well known for its museums.",
    "Germany is a country in Europe.",
    "New York City is the most populous city in the United States.",
]
rankings = reranker.rank(query, documents)  # list of {"corpus_id", "score"} dicts, highest score first
print(f"Query: {query}\n")
for r in rankings:
    print(f"  [{r['score']:.4f}] {documents[r['corpus_id']]}")
```
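`rank()` amounts to scoring every (query, document) pair and sorting by score. A sketch of that post-processing step in plain Python, starting from already computed scores (made-up values) and producing the same `{"corpus_id", "score"}` dictionaries the loop above reads; the helper name is hypothetical, and the real method has further options such as `top_k` and `return_documents`:

```python
from typing import Dict, List, Optional, Union

def rank_from_scores(scores: List[float], top_k: Optional[int] = None) -> List[Dict[str, Union[int, float]]]:
    # pair every score with its document index, then sort by score, highest first
    results = [{"corpus_id": i, "score": s} for i, s in enumerate(scores)]
    results.sort(key=lambda r: r["score"], reverse=True)
    return results[:top_k] if top_k is not None else results

scores = [0.41, 0.97, 0.02, 0.08]  # made-up scores, one per document in input order
for r in rank_from_scores(scores, top_k=2):
    print(r["corpus_id"], r["score"])
```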

```python
# 8. Save the model
reranker.save("output/gooaq-reranker/final")
```

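## Cross-Encoder vs Bi-Encoder: training comparison

| Aspect | Bi-Encoder (SentenceTransformer) | Cross-Encoder |
|--------|----------------------------------|---------------|
| **Loss import path** | `sentence_transformers.losses` | `sentence_transformers.cross_encoder.losses` |
| **Trainer** | `SentenceTransformerTrainer` | `CrossEncoderTrainer` |
| **TrainingArgs** | `SentenceTransformerTrainingArguments` | `CrossEncoderTrainingArguments` |
| **num_labels** | Not needed | Classification: number of classes; regression / ranking: 1 |
| **Inference** | `model.encode()` → vectors | `model.predict()` → scores (`model.rank()` for one query over many documents) |
| **Typical use** | Large-scale retrieval, clustering | Reranking, NLI classification |
| **Speed** | Fast (each text encoded independently, embeddings can be cached) | Slow (every pair needs its own forward pass) |
| **Accuracy** | Usually lower | Usually higher |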
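The speed row follows from how many forward passes each design needs. A back-of-the-envelope sketch, assuming one forward pass per text (Bi-Encoder) or per pair (Cross-Encoder) and ignoring batching and sequence length; the corpus and query sizes are illustrative numbers only:

```python
def forward_passes(num_docs: int, num_queries: int, rerank_k: int = 100) -> dict:
    return {
        # documents embedded once offline, then one pass per query
        "bi_encoder": num_docs + num_queries,
        # every (query, document) pair has to go through the model
        "cross_encoder_only": num_docs * num_queries,
        # Bi-Encoder recall, then the Cross-Encoder rescores the top rerank_k per query
        "retrieve_then_rerank": num_docs + num_queries + num_queries * min(rerank_k, num_docs),
    }

costs = forward_passes(num_docs=1_000_000, num_queries=1_000)
for name, n in costs.items():
    print(f"{name:>22}: {n:,}")
```

The two-stage setup costs only slightly more than the Bi-Encoder alone while keeping Cross-Encoder accuracy on the final ordering, which is why Cross-Encoders are almost always used as rerankers.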