In [1]:
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, InputExample, losses, models
from torch.utils.data import DataLoader
import torch
import os

# Load the RAG QA data and turn every row whose question AND context are both
# non-blank strings into one positive (question, context) training example.
dataset = load_dataset("neural-bridge/rag-dataset-12000", split="train")


def _is_usable_text(value):
    """Return True when *value* is a str that contains non-whitespace text."""
    return isinstance(value, str) and bool(value.strip())


pairs = [
    (question, context)
    for question, context in zip(dataset["question"], dataset["context"])
    if _is_usable_text(question) and _is_usable_text(context)
]
train_samples = [
    InputExample(texts=[question, context]) for question, context in pairs
]

# Model = MiniLM token encoder -> pooling -> linear projection head.
# Only the projection head is meant to be trained (the encoder is frozen below).
BASE_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
PROJECTION_DIM = 512  # output width of the learned projection head

word_embedding_model = models.Transformer(BASE_MODEL_NAME)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())

# Derive the projection's input width from the pooling layer instead of
# hard-coding 384, so changing BASE_MODEL_NAME cannot silently mismatch shapes.
# An explicit Identity keeps the head purely linear; passing None is not
# handled by every sentence-transformers release (Dense calls the activation
# directly and records its class name in the saved config for reloading).
dense = models.Dense(
    in_features=pooling_model.get_sentence_embedding_dimension(),
    out_features=PROJECTION_DIM,
    activation_function=torch.nn.Identity(),
)

# Prefer CUDA, then Apple MPS, then CPU, so the reported device matches the
# best accelerator actually available on the machine.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"✅ Using device: {device}")

model = SentenceTransformer(modules=[word_embedding_model, pooling_model, dense])
model.to(device)

# Freeze the transformer encoder (model[0]) and the pooling stage (model[1]);
# gradients then flow only into the Dense projection head at model[2].
for frozen_module in (model[0].auto_model, model[1]):
    for weight in frozen_module.parameters():
        weight.requires_grad = False

# Each batch holds 32 shuffled (question, context) pairs. The ranking loss
# scores every question against all contexts in its batch, so the other
# contexts act as in-batch negatives.
train_dataloader = DataLoader(train_samples, batch_size=32, shuffle=True)
train_loss = losses.MultipleNegativesRankingLoss(model)

training_objective = (train_dataloader, train_loss)
model.fit(
    train_objectives=[training_objective],
    epochs=3,
    warmup_steps=10,
    show_progress_bar=True,
)

# Persist the whole pipeline (encoder + pooling + trained head) to disk.
model.save("Projection_head_model")
print("Model saved")

  from .autonotebook import tqdm as notebook_tqdm


✅ Using device: mps


 56%|█████▌    | 500/900 [04:08<03:09,  2.11it/s]                    

{'loss': 0.2649, 'grad_norm': 1.4082971811294556, 'learning_rate': 8.95152198421646e-06, 'epoch': 1.67}


100%|██████████| 900/900 [07:30<00:00,  2.00it/s]


{'train_runtime': 450.8889, 'train_samples_per_second': 63.861, 'train_steps_per_second': 1.996, 'train_loss': 0.25171715630425345, 'epoch': 3.0}
Model saved
