In [2]:
from datasets import load_dataset
from sentence_transformers import InputExample
import random

In [3]:
hotpot = load_dataset("hotpot_qa", "distractor", split="train[:1000]")

In [4]:
train_examples = []
for sample in hotpot:
    q      = sample["question"].strip()
    titles = sample["context"]["title"]       # 段落标题列表
    sents  = sample["context"]["sentences"]   # 对应的句子列表

    # 拼回段落文本
    paras = [" ".join(ss).strip() for ss in sents]

    # 正例：supporting_facts 指向的标题
    sup_titles = sample["supporting_facts"]["title"]
    pos_paras  = [paras[titles.index(t)] for t in sup_titles if t in titles]

    # 负例：同条样本里其余段落
    neg_paras  = [p for i, p in enumerate(paras) if titles[i] not in sup_titles]

    # 构造 InputExample
    for pos in pos_paras:
        train_examples.append(InputExample(texts=[q, pos], label=1.0))
        if neg_paras:
            train_examples.append(InputExample(texts=[q, neg_paras[0]], label=0.0))

print(f"训练样本数：{len(train_examples)}")

训练样本数：4741


In [5]:
for ex in train_examples[:5]:
    print("Q:", ex.texts[0])
    print("P:", ex.texts[1][:80], "…", "| label:", ex.label)

Q: Which magazine was started first Arthur's Magazine or First for Women?
P: Arthur's Magazine (1844–1846) was an American literary periodical published in P … | label: 1.0
Q: Which magazine was started first Arthur's Magazine or First for Women?
P: Radio City is India's first private FM radio station and was started on 3 July 2 … | label: 0.0
Q: Which magazine was started first Arthur's Magazine or First for Women?
P: First for Women is a woman's magazine published by Bauer Media Group in the USA. … | label: 1.0
Q: Which magazine was started first Arthur's Magazine or First for Women?
P: Radio City is India's first private FM radio station and was started on 3 July 2 … | label: 0.0
Q: The Oberoi family is part of a hotel company that has a head office in what city?
P: The Oberoi family is an Indian family that is famous for its involvement in hote … | label: 1.0


In [5]:
from sentence_transformers import SentenceTransformer, InputExample, losses, models
from torch.utils.data import DataLoader
from datasets import load_dataset

In [6]:
base_model = "sentence-transformers/all-mpnet-base-v2"
model = SentenceTransformer(base_model)

In [None]:
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=8)

train_loss = losses.MultipleNegativesRankingLoss(model=model)

num_epochs = 2
warmup_steps = int(len(train_dataloader) * num_epochs * 0.1)
#默认学习率2e-5

In [11]:
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=2,
    warmup_steps=warmup_steps,
    output_path="fine_tuned_mpnet_base_v2",
    checkpoint_path="checkpoints/",
    checkpoint_save_steps=5,     
    checkpoint_save_total_limit=2,
    use_amp=True,  # 自动混合精度
)
print("微调完成，模型保存在 fine_tuned_mpnet_base_v2")

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Step,Training Loss
500,0.1249
1000,0.0356


微调完成，模型保存在 fine_tuned_mpnet_base_v2
