# 03 · Train Bi-Encoder Retriever

## Purpose

Fine-tune a sentence-transformer bi-encoder for semantic retrieval over controls.

## Inputs

- `data/processed/pairs/train.jsonl` for training instances.
- `data/processed/pairs/dev.jsonl` for validation metrics.

## Outputs

- `models/bi_encoder/` directory with fine-tuned weights, tokenizer, and training config.
- Optional precomputed control embeddings array (e.g., `controls_embeddings.npy`).

## Steps

1. Load the base checkpoint `sentence-transformers/multi-qa-mpnet-base-dot-v1`.
2. Set up MultipleNegativesRankingLoss with batch size 64 (gradient accumulation if needed).
3. Train for 3–5 epochs with learning rate 2e-5 and 10% warmup, logging progress.
4. Evaluate on the dev split to compute MRR@10 and select the best checkpoint.
5. Persist the tuned model and optionally encode the control catalog for faster inference.

## Acceptance Checks

- `models/bi_encoder/` contains the saved model artifacts.
- Validation MRR@10 is computed and reported for the best checkpoint.

In [10]:
import pandas as pd
import numpy as np
import json
from pathlib import Path
from collections import defaultdict
import torch
from sentence_transformers import SentenceTransformer, InputExample, losses, evaluation
from torch.utils.data import DataLoader

# Set random seeds
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Check device
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")

Using device: mps


## 1. Load training pairs

In [11]:
def load_pairs(jsonl_path):
    """Load pairs from JSONL file"""
    pairs = []
    with open(jsonl_path, "r") as f:
        for line in f:
            pairs.append(json.loads(line))
    return pairs

# Load train and dev pairs
train_pairs = load_pairs("../data/processed/pairs/train.jsonl")
dev_pairs = load_pairs("../data/processed/pairs/dev.jsonl")

print(f"✓ Loaded {len(train_pairs)} train pairs")
print(f"✓ Loaded {len(dev_pairs)} dev pairs")

# Filter to only positive pairs for training (MultipleNegativesRankingLoss uses in-batch negatives)
train_positives = [p for p in train_pairs if p["label"] == 1]
print(f"\n✓ Filtered to {len(train_positives)} positive training pairs")
print(f"  (MultipleNegativesRankingLoss will use in-batch negatives)")

✓ Loaded 59373 train pairs
✓ Loaded 11683 dev pairs

✓ Filtered to 3146 positive training pairs
  (MultipleNegativesRankingLoss will use in-batch negatives)


## 2. Prepare training data for sentence-transformers

In [12]:
# Convert positive pairs to InputExample format for MultipleNegativesRankingLoss
# Format: InputExample(texts=[query, passage])
train_examples = [
    InputExample(texts=[pair["artifact_text"], pair["control_text"]])
    for pair in train_positives
]

print(f"✓ Created {len(train_examples)} training examples")
print(f"\n  Sample example:")
print(f"    Query: {train_examples[0].texts[0][:80]}...")
print(f"    Passage: {train_examples[0].texts[1][:80]}...")

# Convert to Dataset format for the trainer
from datasets import Dataset as HFDataset

train_dataset_dict = {
    "sentence1": [ex.texts[0] for ex in train_examples],
    "sentence2": [ex.texts[1] for ex in train_examples],
}
train_dataset = HFDataset.from_dict(train_dataset_dict)
print(f"\n✓ Converted to HuggingFace Dataset format")
print(f"  Dataset size: {len(train_dataset)}")

✓ Created 3146 training examples

  Sample example:
    Query: User 'svc-api' failed login 11 times in 2 minutes; account was not automatically...
    Passage: Audit Review, Analysis, and Reporting. Regularly review and analyze audit logs; ...

✓ Converted to HuggingFace Dataset format
  Dataset size: 3146


## 3. Load base model and prepare training

In [13]:
# Load base model
BASE_MODEL = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
print(f"Loading base model: {BASE_MODEL}")

model = SentenceTransformer(BASE_MODEL, device=device)
print(f"✓ Model loaded on {device}")
print(f"  Max sequence length: {model.max_seq_length}")

# Create data loader
BATCH_SIZE = 16  # Reduced from 64 due to memory constraints
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=BATCH_SIZE)
print(f"\n✓ Created DataLoader with batch size {BATCH_SIZE}")
print(f"  Total batches per epoch: {len(train_dataloader)}")

Loading base model: sentence-transformers/multi-qa-mpnet-base-dot-v1
✓ Model loaded on mps
  Max sequence length: 512

✓ Created DataLoader with batch size 16
  Total batches per epoch: 197


In [14]:
# Set up loss function
train_loss = losses.MultipleNegativesRankingLoss(model)
print(f"✓ Loss function: MultipleNegativesRankingLoss")
print(f"  This uses in-batch negatives for contrastive learning")

✓ Loss function: MultipleNegativesRankingLoss
  This uses in-batch negatives for contrastive learning


## 4. Prepare dev set evaluator (MRR@10)

In [15]:
# Prepare dev set for evaluation
# Group by artifact_id to create queries with relevant controls
dev_queries = {}
dev_corpus = {}

for pair in dev_pairs:
    artifact_id = pair["artifact_id"]
    control_id = pair["control_id"]
    
    # Add to corpus
    if control_id not in dev_corpus:
        dev_corpus[control_id] = pair["control_text"]
    
    # Add to queries
    if artifact_id not in dev_queries:
        dev_queries[artifact_id] = {
            "query": pair["artifact_text"],
            "relevant": []
        }
    
    # Mark as relevant if positive
    if pair["label"] == 1:
        dev_queries[artifact_id]["relevant"].append(control_id)

# Create evaluator format
queries = {aid: q["query"] for aid, q in dev_queries.items()}
relevant_docs = {aid: set(q["relevant"]) for aid, q in dev_queries.items()}

print(f"✓ Prepared dev evaluation set:")
print(f"  Queries: {len(queries)}")
print(f"  Corpus: {len(dev_corpus)}")
print(f"  Avg relevant docs per query: {np.mean([len(r) for r in relevant_docs.values()]):.2f}")

# Create InformationRetrievalEvaluator
ir_evaluator = evaluation.InformationRetrievalEvaluator(
    queries=queries,
    corpus=dev_corpus,
    relevant_docs=relevant_docs,
    name="dev_mrr",
    show_progress_bar=True,
    mrr_at_k=[10]
)
print(f"✓ Created InformationRetrievalEvaluator for MRR@10")

✓ Prepared dev evaluation set:
  Queries: 365
  Corpus: 34
  Avg relevant docs per query: 1.59
✓ Created InformationRetrievalEvaluator for MRR@10


## 5. Train the model

In [16]:
# Training configuration
NUM_EPOCHS = 3
WARMUP_STEPS = int(len(train_dataloader) * NUM_EPOCHS * 0.1)
OUTPUT_PATH = "../models/bi_encoder"

print(f"Training configuration:")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Learning rate: 2e-5")
print(f"  Warmup steps: {WARMUP_STEPS} (10% of total)")
print(f"  Output path: {OUTPUT_PATH}")

# Create output directory
Path(OUTPUT_PATH).mkdir(parents=True, exist_ok=True)

print(f"\n{'='*60}")
print(f"Starting training...")
print(f"{'='*60}\n")

# Use the SentenceTransformerTrainer
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments

# Training arguments
training_args = SentenceTransformerTrainingArguments(
    output_dir=OUTPUT_PATH,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    warmup_steps=WARMUP_STEPS,
    learning_rate=2e-5,
    logging_steps=10,
    save_strategy="epoch",
    save_total_limit=1,
    load_best_model_at_end=True,
    metric_for_best_model="eval_dev_mrr_dot_mrr@10",  # Use MRR@10 as the metric
    greater_is_better=True,  # Higher MRR is better
    eval_strategy="epoch",
    report_to="none",
)

# Create trainer
trainer = SentenceTransformerTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,  # Use the HF Dataset
    loss=train_loss,
    evaluator=ir_evaluator,
)

# Train
trainer.train()

print(f"\n{'='*60}")
print(f"✓ Training complete!")
print(f"{'='*60}")

Training configuration:
  Epochs: 3
  Batch size: 16
  Learning rate: 2e-5
  Warmup steps: 59 (10% of total)
  Output path: ../models/bi_encoder

Starting training...



Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]



Epoch,Training Loss,Validation Loss,Dev Mrr Dot Accuracy@1,Dev Mrr Dot Accuracy@3,Dev Mrr Dot Accuracy@5,Dev Mrr Dot Accuracy@10,Dev Mrr Dot Precision@1,Dev Mrr Dot Precision@3,Dev Mrr Dot Precision@5,Dev Mrr Dot Precision@10,Dev Mrr Dot Recall@1,Dev Mrr Dot Recall@3,Dev Mrr Dot Recall@5,Dev Mrr Dot Recall@10,Dev Mrr Dot Ndcg@10,Dev Mrr Dot Mrr@10,Dev Mrr Dot Map@100
1,0.9793,No log,0.838356,0.964384,0.983562,0.989041,0.838356,0.481279,0.302466,0.156438,0.584018,0.915068,0.954795,0.981279,0.908415,0.900441,0.874341
2,0.9711,No log,0.852055,0.978082,0.983562,0.991781,0.852055,0.494977,0.307945,0.15726,0.594977,0.940639,0.968493,0.986301,0.924503,0.913751,0.89646
3,0.9613,No log,0.860274,0.978082,0.986301,0.991781,0.860274,0.49863,0.306301,0.15726,0.599087,0.946119,0.964384,0.986301,0.926926,0.918995,0.899015


Batches:   0%|          | 0/12 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:00<00:00,  1.45it/s]


Batches:   0%|          | 0/12 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:00<00:00,  5.04it/s]


Batches:   0%|          | 0/12 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:00<00:00,  4.88it/s]



✓ Training complete!


## 6. Pre-encode control embeddings for inference

In [17]:
# Load the best model
best_model = SentenceTransformer(OUTPUT_PATH, device=device)
print(f"✓ Loaded best model from {OUTPUT_PATH}")

# Load enhanced controls and encode them
controls = pd.read_csv("../data/processed/controls_enhanced.csv", dtype=str)
# index_text is already created in controls_enhanced.csv

print(f"\n✓ Encoding {len(controls)} controls...")
control_embeddings = best_model.encode(
    controls["index_text"].tolist(),
    convert_to_numpy=True,
    show_progress_bar=True
)

# Save embeddings
embeddings_path = Path(OUTPUT_PATH) / "control_embeddings.npy"
np.save(embeddings_path, control_embeddings)

print(f"✓ Saved control embeddings to {embeddings_path}")
print(f"  Shape: {control_embeddings.shape}")
print(f"  Size: {embeddings_path.stat().st_size / 1024:.1f} KB")

✓ Loaded best model from ../models/bi_encoder

✓ Encoding 34 controls...


Batches:   0%|          | 0/2 [00:00<?, ?it/s]

✓ Saved control embeddings to ../models/bi_encoder/control_embeddings.npy
  Shape: (34, 768)
  Size: 102.1 KB


## 7. Final evaluation and acceptance checks

In [18]:
# Run final evaluation
print("="*60)
print("FINAL EVALUATION")
print("="*60)

final_metrics = ir_evaluator(best_model, output_path=OUTPUT_PATH)

# Extract MRR@10
mrr_at_10 = final_metrics.get("dev_mrr_dot_mrr@10", 0.0)
print(f"\n✓ Final dev MRR@10: {mrr_at_10:.4f}")
print(f"\nAll metrics:")
for key, value in final_metrics.items():
    if isinstance(value, float):
        print(f"  {key}: {value:.4f}")

# Acceptance checks
print("\n" + "="*60)
print("ACCEPTANCE CHECKS")
print("="*60)

# Check 1: Model saved
model_path = Path(OUTPUT_PATH)
check1 = model_path.exists() and (model_path / "config.json").exists()
print(f"\n✓ Check 1: Model saved at {OUTPUT_PATH}")
print(f"  Result: {'PASS' if check1 else 'FAIL'}")
if check1:
    files = list(model_path.glob("*"))
    print(f"  Files: {[f.name for f in files[:5]]}...")

# Check 2: MRR@10 computed
check2 = mrr_at_10 is not None and mrr_at_10 > 0
print(f"\n✓ Check 2: Dev MRR@10 computed and reported")
print(f"  MRR@10: {mrr_at_10:.4f}")
print(f"  Result: {'PASS' if check2 else 'FAIL'}")

# Overall
all_checks_passed = check1 and check2
print("\n" + "="*60)
if all_checks_passed:
    print("✅ ALL ACCEPTANCE CHECKS PASSED")
else:
    print("❌ SOME ACCEPTANCE CHECKS FAILED")
print("="*60)

FINAL EVALUATION


Batches:   0%|          | 0/12 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 1/1 [00:00<00:00,  5.50it/s]


✓ Final dev MRR@10: 0.8017

All metrics:
  dev_mrr_dot_accuracy@1: 0.7123
  dev_mrr_dot_accuracy@3: 0.8712
  dev_mrr_dot_accuracy@5: 0.9178
  dev_mrr_dot_accuracy@10: 0.9562
  dev_mrr_dot_precision@1: 0.7123
  dev_mrr_dot_precision@3: 0.4082
  dev_mrr_dot_precision@5: 0.2712
  dev_mrr_dot_precision@10: 0.1466
  dev_mrr_dot_recall@1: 0.4881
  dev_mrr_dot_recall@3: 0.7790
  dev_mrr_dot_recall@5: 0.8571
  dev_mrr_dot_recall@10: 0.9201
  dev_mrr_dot_ndcg@10: 0.8021
  dev_mrr_dot_mrr@10: 0.8017
  dev_mrr_dot_map@100: 0.7465

ACCEPTANCE CHECKS

✓ Check 1: Model saved at ../models/bi_encoder
  Result: PASS
  Files: ['model.safetensors', 'checkpoint-591', '1_Pooling', 'tokenizer_config.json', 'special_tokens_map.json']...

✓ Check 2: Dev MRR@10 computed and reported
  MRR@10: 0.8017
  Result: PASS

✅ ALL ACCEPTANCE CHECKS PASSED



