# 02 – Train Cross-Encoder Reranker

Fine-tune `cross-encoder/ms-marco-MiniLM-L-6-v2` trên `data/pairs.jsonl` (mỗi dòng là một JSON object gồm `query`, `positive` và danh sách `negatives`).
Chạy notebook này trên Colab/GPU sẽ nhanh hơn.

In [1]:
!pip install -U sentence-transformers transformers datasets scikit-learn

Collecting sentence-transformers
  Downloading sentence_transformers-5.1.0-py3-none-any.whl.metadata (16 kB)
Collecting transformers
  Downloading transformers-4.55.2-py3-none-any.whl.metadata (41 kB)
Collecting datasets
  Downloading datasets-4.0.0-py3-none-any.whl.metadata (19 kB)
Collecting scikit-learn
  Downloading scikit_learn-1.6.1-cp39-cp39-win_amd64.whl.metadata (15 kB)
Collecting tqdm (from sentence-transformers)
  Using cached tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)
Collecting torch>=1.11.0 (from sentence-transformers)
  Downloading torch-2.8.0-cp39-cp39-win_amd64.whl.metadata (30 kB)
Collecting scipy (from sentence-transformers)
  Downloading scipy-1.13.1-cp39-cp39-win_amd64.whl.metadata (60 kB)
Collecting huggingface-hub>=0.20.0 (from sentence-transformers)
  Using cached huggingface_hub-0.34.4-py3-none-any.whl.metadata (14 kB)
Collecting filelock (from transformers)
  Downloading filelock-3.19.1-py3-none-any.whl.metadata (2.1 kB)
Collecting regex!=2019.12.17 (from t

In [3]:
!pip install "accelerate>=0.26.0"

Collecting accelerate>=0.26.0
  Downloading accelerate-1.10.0-py3-none-any.whl.metadata (19 kB)
Downloading accelerate-1.10.0-py3-none-any.whl (374 kB)
Installing collected packages: accelerate
Successfully installed accelerate-1.10.0


In [None]:
import os, json, random, torch
from torch.utils.data import DataLoader
from sentence_transformers import CrossEncoder, InputExample
from sentence_transformers.cross_encoder.evaluation import CEBinaryClassificationEvaluator
from torch.nn import BCEWithLogitsLoss

SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)  # make torch-side randomness (e.g. DataLoader shuffling) reproducible too

# ==== 1) Load & split ====
# Each JSONL line is one record with keys 'query', 'positive' and a list 'negatives'.
# The context manager closes the file; blank lines (e.g. a trailing newline) are
# skipped because json.loads would reject them.
with open('data/pairs.jsonl', encoding='utf-8') as f:
    rows = [json.loads(line) for line in f if line.strip()]
random.shuffle(rows)

# 80/10/10 split over records, so one record's positive/negative pairs never straddle splits.
n_total = len(rows)
n_train = int(0.8 * n_total)
n_dev   = int(0.1 * n_total)

train_rows = rows[:n_train]
dev_rows   = rows[n_train:n_train+n_dev]
test_rows  = rows[n_train+n_dev:]
assert train_rows and dev_rows and test_rows, (
    f"Dataset too small for an 80/10/10 split: {n_total} records"
)

def to_pointwise(rows):
    """Flatten query records into pointwise (query, passage) training examples.

    For every record, emits one example labelled 1.0 for ``record['positive']``,
    then one example labelled 0.0 for each entry of ``record['negatives']``,
    preserving record order and negative order.
    """
    examples = []
    for record in rows:
        query = record['query']
        candidates = [(record['positive'], 1.0)] + [(neg, 0.0) for neg in record['negatives']]
        examples.extend(
            InputExample(texts=[query, passage], label=label)
            for passage, label in candidates
        )
    return examples

train_data = to_pointwise(train_rows)
dev_data   = to_pointwise(dev_rows)
test_data  = to_pointwise(test_rows)

# `collate_fn=list` keeps every batch as a plain list of InputExample objects.
# torch's default collate cannot batch arbitrary objects such as InputExample (it
# raises TypeError), which would break the manual evaluation loops below that
# iterate these loaders and read `ex.texts` / `ex.label`.
# NOTE(review): CrossEncoder.fit may install its own collate_fn on train_loader,
# depending on the sentence-transformers version — TODO confirm.
train_loader = DataLoader(train_data, batch_size=16, shuffle=True,  collate_fn=list)
dev_loader   = DataLoader(dev_data,   batch_size=64, shuffle=False, collate_fn=list)
test_loader  = DataLoader(test_data,  batch_size=64, shuffle=False, collate_fn=list)

# ==== 2) Model ====
# Pre-trained MS MARCO cross-encoder with a single-output relevance head;
# max_length=256 caps the tokenized (query, passage) pair length.
model = CrossEncoder(
    'cross-encoder/ms-marco-MiniLM-L-6-v2',
    max_length=256,
    num_labels=1,
)
bce = BCEWithLogitsLoss()  # takes raw logits; the sigmoid is applied inside the loss

def _batch_logits(_model, batch):
    """Return raw relevance logits (1-D float32 CPU tensor) for a list of InputExample.

    Tokenizes the (query, passage) pairs and calls the underlying HF model directly.
    ``CrossEncoder.predict`` is avoided on purpose for two reasons:
    - It applies the checkpoint's default activation, which may be sigmoid or identity.
    - With ``convert_to_numpy=False`` it may return a Python list rather than a tensor.
    Calling the model directly always yields raw logits.
    """
    firsts = [ex.texts[0] for ex in batch]
    seconds = [ex.texts[1] for ex in batch]
    device = next(_model.model.parameters()).device
    features = _model.tokenizer(
        firsts,
        seconds,
        padding=True,
        truncation="longest_first",
        max_length=getattr(_model, "max_length", None),
        return_tensors="pt",
    )
    features = {name: tensor.to(device) for name, tensor in features.items()}
    return _model.model(**features).logits.view(-1).float().cpu()


@torch.no_grad()
def _evaluate(_model, loader, max_batches=None):
    """Run the model over ``loader``; return (n_correct, summed BCE loss, n_examples).

    ``loader`` must yield lists of InputExample. ``max_batches`` limits how many
    batches are used; None or 0 means all batches.
    """
    was_training = _model.model.training
    _model.model.eval()
    correct, loss_sum, total = 0, 0.0, 0
    try:
        for i, batch in enumerate(loader):
            if max_batches and i >= max_batches:
                break
            if not batch:
                continue
            logits = _batch_logits(_model, batch)
            labels = torch.tensor([ex.label for ex in batch], dtype=torch.float32)
            preds = (logits > 0).float()  # sigmoid(logit) > 0.5  <=>  logit > 0
            correct += (preds == labels).sum().item()
            loss_sum += torch.nn.functional.binary_cross_entropy_with_logits(
                logits, labels, reduction="sum"
            ).item()
            total += labels.numel()
    finally:
        # These helpers run inside the training callback; restore the caller's mode
        # so dropout is not silently left disabled for the rest of training.
        if was_training:
            _model.model.train()
    return correct, loss_sum, total


@torch.no_grad()
def compute_accuracy(_model, loader, max_batches=None):
    """Binary accuracy (threshold: probability 0.5) of ``_model`` over ``loader``.

    Returns 0.0 for an empty loader. ``max_batches``: None or 0 means all batches.
    """
    correct, _, total = _evaluate(_model, loader, max_batches)
    return correct / max(1, total)


@torch.no_grad()
def compute_loss(_model, loader, max_batches=None):
    """Mean per-example BCE-with-logits loss of ``_model`` over ``loader``.

    Returns 0.0 for an empty loader. ``max_batches``: None or 0 means all batches.
    """
    _, loss_sum, total = _evaluate(_model, loader, max_batches)
    return loss_sum / max(1, total)

# ==== 3) Dev-set evaluator handed to model.fit ====
# fit() runs it every EVAL_EVERY steps and forwards its metric to `callback` as `score`.
evaluator_dev = CEBinaryClassificationEvaluator.from_input_examples(dev_data, name='dev')

# ==== 4) Train config ====
OUTPUT_DIR = "models/reranker_food"
os.makedirs(OUTPUT_DIR, exist_ok=True)

EPOCHS     = 3
EVAL_EVERY = 500   # evaluate every 500 training steps
PATIENCE   = 3     # early stopping: evaluations without Val_Acc improvement before stopping

# Early-stopping state, updated by `callback` through `global`.
best_val_acc, no_improve = 0.0, 0

def callback(score, epoch, steps):
    """Evaluation hook passed to ``model.fit(callback=...)``.

    Computes accuracy/loss on up to 100 batches of ``train_loader`` and on the whole
    ``dev_loader``, then prints them. If dev accuracy improves, saves the model to
    OUTPUT_DIR. After PATIENCE consecutive evaluations without improvement, raises
    ``KeyboardInterrupt("early_stopping_triggered")``.

    Args:
        score: evaluator_dev's metric for this evaluation (not used here).
        epoch: current epoch, only printed.
        steps: current training step, only printed.
    """
    global best_val_acc, no_improve

    # Train metrics on a 100-batch subset to keep this cheap; dev is small, so use all of it.
    metrics = {
        "train_acc":  compute_accuracy(model, train_loader, max_batches=100),
        "train_loss": compute_loss(model, train_loader, max_batches=100),
        "val_acc":    compute_accuracy(model, dev_loader),
        "val_loss":   compute_loss(model, dev_loader),
    }
    val_acc = metrics["val_acc"]

    print(f"[EVAL] epoch={epoch} step={steps} | "
          f"Train_Acc={metrics['train_acc']:.4f} Train_Loss={metrics['train_loss']:.4f} | "
          f"Val_Acc={val_acc:.4f} Val_Loss={metrics['val_loss']:.4f}")

    # Model selection is by Val_Acc: a new best resets patience and is checkpointed.
    if val_acc > best_val_acc:
        best_val_acc, no_improve = val_acc, 0
        model.save(OUTPUT_DIR)
        print(f"[BEST] Val_Acc={val_acc:.4f} → saved to {OUTPUT_DIR}")
        return

    no_improve += 1
    print(f"[NO IMPROVE {no_improve}/{PATIENCE}] best={best_val_acc:.4f}")
    if no_improve >= PATIENCE:
        print("[EARLY STOPPING] Không cải thiện trên val set.")
        raise KeyboardInterrupt("early_stopping_triggered")

# ==== 5) Train ====
# fit() gets its own sub-directory. With an output_path and an evaluator, fit()
# saves its own "best" model (selected by evaluator_dev's metric, save_best_model
# defaults to True). If it wrote to OUTPUT_DIR, it would overwrite the
# best-by-Val_Acc model that `callback` saves there.
FIT_OUTPUT_DIR = os.path.join(OUTPUT_DIR, "fit_output")
try:
    model.fit(
        train_dataloader=train_loader,
        epochs=EPOCHS,
        warmup_steps=100,
        evaluator=evaluator_dev,            # lets fit() evaluate every EVAL_EVERY steps
        evaluation_steps=EVAL_EVERY,
        output_path=FIT_OUTPUT_DIR,         # fit()'s own outputs (evaluator CSV, its checkpoint)
        show_progress_bar=True,
        use_amp=torch.cuda.is_available(),  # mixed precision targets CUDA; disable it on CPU-only runs
        callback=callback,                  # manual train/val metrics, best-by-Val_Acc saving, early stopping
    )
except KeyboardInterrupt as e:
    # `callback` signals early stopping with a KeyboardInterrupt carrying this marker;
    # any other KeyboardInterrupt (e.g. a manual stop) is re-raised.
    if "early_stopping_triggered" in str(e):
        print("Đã dừng train sớm do early stopping.")
    else:
        raise

# ==== 6) Final results ====
# Evaluate the checkpoint saved to OUTPUT_DIR (the best-by-Val_Acc model written by
# `callback`), not the in-memory model. After early stopping, the in-memory model is
# PATIENCE evaluations past the best, and after a full run it may not be the best.
# If nothing was saved there, fall back to the in-memory model.
if os.path.isfile(os.path.join(OUTPUT_DIR, "config.json")):
    final_model = CrossEncoder(OUTPUT_DIR, num_labels=1, max_length=256)
else:
    final_model = model

final_train_acc = compute_accuracy(final_model, train_loader)
final_val_acc   = compute_accuracy(final_model, dev_loader)
final_test_acc  = compute_accuracy(final_model, test_loader)
print("\n== KẾT QUẢ CUỐI CÙNG ==")
print(f"Train Accuracy: {final_train_acc:.4f}")
print(f"Val   Accuracy: {final_val_acc:.4f}")
print(f"Test  Accuracy: {final_test_acc:.4f}")
print(f"Best model tại: {os.path.abspath(OUTPUT_DIR)}")


Token indices sequence length is longer than the specified maximum sequence length for this model (322 > 256). Running this sequence through the model will result in indexing errors


Step,Training Loss,Validation Loss


KeyboardInterrupt: 