# NLI 微调实战 — MultipleNegativesRankingLoss

官方示例：https://github.com/UKPLab/sentence-transformers/blob/master/examples/sentence_transformer/training/nli/training_nli_v2.py

使用 AllNLI 三元组数据集微调模型。MNRL（MultipleNegativesRankingLoss）是训练 embedding 模型最常用的 loss 之一，在检索与语义相似度任务上通常效果很好。

核心思想：蕴含关系作为正样本对，矛盾关系作为硬负样本，同时利用 batch 内其他样本作为负样本。

In [None]:
from sentence_transformers import SentenceTransformer, losses
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from sentence_transformers.similarity_functions import SimilarityFunction
from sentence_transformers.trainer import SentenceTransformerTrainer
from sentence_transformers.training_args import BatchSamplers, SentenceTransformerTrainingArguments
from datasets import load_dataset

In [None]:
# 1. Load the base model
# Naming the checkpoint id separately makes it easy to swap in another
# backbone without touching the construction call.
base_checkpoint = "distilroberta-base"
model = SentenceTransformer(base_checkpoint)
# Print the model object so its structure is visible in the cell output.
print(model)

In [None]:
# 2. Load the AllNLI triplet data (only a subset is kept to speed up training)
nli_repo = "sentence-transformers/all-nli"
nli_config = "triplet"

# Training portion: first 10,000 rows of the "train" split.
nli_train_full = load_dataset(nli_repo, nli_config, split="train")
train_dataset = nli_train_full.select(range(10000))

# Evaluation portion: first 1,000 rows of the "dev" split.
nli_dev_full = load_dataset(nli_repo, nli_config, split="dev")
eval_dataset = nli_dev_full.select(range(1000))

# Report the training-set size and show one example row.
print(f"训练集: {len(train_dataset)} 条")
print("样本:", train_dataset[0])

In [None]:
# 3. Loss function: MultipleNegativesRankingLoss (MNRL)
# For each triplet, anchor + positive form the positive pair; the other
# positives and the negatives in the same batch act as negatives.
train_loss = losses.MultipleNegativesRankingLoss(model=model)

In [None]:
# 4. Evaluate embedding quality on the STS benchmark (validation split)
stsb_eval = load_dataset("sentence-transformers/stsb", split="validation")

# Pull out the three columns the evaluator consumes.
dev_first_sentences = stsb_eval["sentence1"]
dev_second_sentences = stsb_eval["sentence2"]
dev_gold_scores = stsb_eval["score"]

dev_evaluator = EmbeddingSimilarityEvaluator(
    sentences1=dev_first_sentences,
    sentences2=dev_second_sentences,
    scores=dev_gold_scores,
    main_similarity=SimilarityFunction.COSINE,
    name="sts-dev",
)

# Baseline run before any fine-tuning. The call is left as the cell's last
# expression so the notebook echoes whatever the evaluator returns.
print("训练前评估:")
dev_evaluator(model)

In [None]:
# 5. Training arguments
import torch

# fp16 mixed precision needs a CUDA-capable GPU; hard-coding fp16=True makes
# the arguments fail validation on CPU-only machines. Enable it only when CUDA
# is actually available and fall back to full precision otherwise.
use_fp16 = torch.cuda.is_available()

args = SentenceTransformerTrainingArguments(
    output_dir="output/nli-distilroberta",
    num_train_epochs=1,
    per_device_train_batch_size=128,  # MNRL: larger batches give more in-batch negatives
    per_device_eval_batch_size=128,
    warmup_ratio=0.1,
    fp16=use_fp16,
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # avoid duplicate samples within a batch
    # Evaluate, checkpoint and log on the same 10-step cadence; only the two
    # most recent checkpoints are kept on disk.
    eval_strategy="steps",
    eval_steps=10,
    save_strategy="steps",
    save_steps=10,
    save_total_limit=2,
    logging_steps=10,
)

In [None]:
# 6. Train
# Gather everything the trainer needs in one place: the model, the training
# arguments, the NLI train/eval subsets, the MNRL loss and the STS evaluator.
trainer_config = {
    "model": model,
    "args": args,
    "train_dataset": train_dataset,
    "eval_dataset": eval_dataset,
    "loss": train_loss,
    "evaluator": dev_evaluator,
}
trainer = SentenceTransformerTrainer(**trainer_config)

# Kept as the cell's last expression so the notebook shows the training result.
trainer.train()

In [None]:
# 7. Evaluate on the STS benchmark test split
stsb_test = load_dataset("sentence-transformers/stsb", split="test")

# Pull out the three columns the evaluator consumes.
test_first_sentences = stsb_test["sentence1"]
test_second_sentences = stsb_test["sentence2"]
test_gold_scores = stsb_test["score"]

test_evaluator = EmbeddingSimilarityEvaluator(
    sentences1=test_first_sentences,
    sentences2=test_second_sentences,
    scores=test_gold_scores,
    main_similarity=SimilarityFunction.COSINE,
    name="sts-test",
)

# Left as the cell's last expression so the notebook echoes the result.
test_evaluator(model)

In [None]:
# 8. Save the model
# The final model goes into a "final" subfolder under the training output path.
final_model_dir = "output/nli-distilroberta/final"
model.save(final_model_dir)