# SBERT Hyperparameter Tuning Experiment 2

**Experiment focus:** Tuning epochs and warmup steps for SBERT fine-tuning.

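This notebook explores how the number of training epochs and the number of warmup steps affect SBERT performance on scientific claim source retrieval (matching a tweet to the paper it refers to). Learning rate and batch size are held fixed.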

In [1]:
import pandas as pd
import pickle
from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader
import torch
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


## Data Loading

Load the train and dev datasets, and prepare InputExample objects for SBERT.

In [None]:
# Load train and dev query tweets.
# NOTE(review): passing `names=` without `header=0` makes pandas treat the first
# line of each TSV as a data row -- confirm the files really have no header row.
query_columns = ['post_id', 'tweet_text', 'cord_uid']
train_df = pd.read_csv('../subtask4b_query_tweets_train.tsv', sep='\t', names=query_columns)
dev_df = pd.read_csv('../subtask4b_query_tweets_dev.tsv', sep='\t', names=query_columns)

# NOTE(security): pickle.load can execute arbitrary code; only load a
# collection file that comes from a trusted source.
with open('../subtask4b_collection_data.pkl', 'rb') as f:
    papers_df = pickle.load(f)

# A missing title or abstract would turn the whole concatenation into NaN,
# which then ends up in the training pairs and in model.encode(); treat
# missing parts as empty strings instead.
papers_df['text'] = papers_df['title'].fillna('') + '. ' + papers_df['abstract'].fillna('')
paper_dict = dict(zip(papers_df['cord_uid'], papers_df['text']))


def build_pair_examples(query_df, paper_texts):
    """Build (tweet, paper text) InputExample pairs for SBERT.

    Args:
        query_df: DataFrame with 'tweet_text' and 'cord_uid' columns.
        paper_texts: dict mapping cord_uid to the paper text.

    Returns:
        A list with one InputExample per row whose cord_uid is a key of
        `paper_texts`; rows with an unknown cord_uid are skipped.
    """
    return [
        InputExample(texts=[tweet, paper_texts[uid]])
        for tweet, uid in zip(query_df['tweet_text'], query_df['cord_uid'])
        if uid in paper_texts
    ]


# Prepare training and dev examples with the same pairing logic.
train_samples = build_pair_examples(train_df, paper_dict)
dev_samples = build_pair_examples(dev_df, paper_dict)

## Hyperparameter Grid

Define the epochs and warmup steps to try.

In [None]:
# Search space: the training loop runs once per (epochs, warmup_steps) pair.
epochs_list = [2, 5, 8]
warmup_steps_list = [50, 90, 120]

# Settings held constant for every run in this experiment.
model_name = 'multi-qa-mpnet-base-cos-v1'
batch_size = 16
learning_rate = 2e-5

## Training and Evaluation Loop

For each (epochs, warmup steps) combination, fine-tune a fresh copy of the SBERT model on the training pairs and evaluate it on the dev set with MRR@5.

In [None]:
def evaluate_mrr(model, dev_df, papers_df, top_k=5):
    """Compute the mean reciprocal rank (MRR@top_k) of `model` on `dev_df`.

    Each tweet is ranked against every paper by cosine similarity. A query
    scores 1/rank if its gold paper is among the `top_k` highest-ranked
    papers, and 0 otherwise.

    Args:
        model: object with an `encode(texts, show_progress_bar=...,
            convert_to_tensor=...)` method that returns a 2-D tensor with one
            row per input text.
        dev_df: DataFrame with 'tweet_text' and 'cord_uid' (gold id) columns.
            Its index may be arbitrary; rows are matched by position.
        papers_df: DataFrame with 'cord_uid' and 'text' columns.
        top_k: number of top-ranked papers considered per query.

    Returns:
        The mean score over all dev rows as a float, or 0.0 when there are no
        queries or no papers.
    """
    queries = dev_df['tweet_text'].tolist()
    gold_ids = dev_df['cord_uid'].tolist()
    paper_ids = papers_df['cord_uid'].tolist()
    if not queries or not paper_ids:
        return 0.0

    # Encode dev queries and the paper collection.
    query_embeddings = model.encode(queries, show_progress_bar=True, convert_to_tensor=True)
    paper_embeddings = model.encode(papers_df['text'].tolist(), show_progress_bar=True, convert_to_tensor=True)

    # L2-normalise both sides so that the dot product is the cosine similarity.
    query_norm = torch.nn.functional.normalize(query_embeddings, p=2, dim=1)
    paper_norm = torch.nn.functional.normalize(paper_embeddings, p=2, dim=1)
    k = min(top_k, len(paper_ids))

    scores = []
    # Pair queries with gold ids by position, not by DataFrame index label, so
    # a filtered or re-indexed dev_df is scored correctly.
    for gold, query_vec in zip(gold_ids, query_norm):
        # (n_papers, dim) @ (dim,) always gives a 1-D vector, even when the
        # collection holds a single paper.
        similarity = torch.matmul(paper_norm, query_vec)
        top_indices = torch.topk(similarity, k=k).indices.tolist()
        preds = [paper_ids[j] for j in top_indices]
        scores.append(1.0 / (preds.index(gold) + 1) if gold in preds else 0.0)
    return float(np.mean(scores))

In [None]:
# Grid search over (epochs, warmup_steps): fine-tune a fresh model for every
# combination and record its dev MRR@5.
seed = 42  # fixed so that runs differ only in the hyperparameters under test
results = []
for epochs in epochs_list:
    for warmup_steps in warmup_steps_list:
        print(f'Training {model_name} | epochs={epochs} | warmup={warmup_steps}')
        # Reset the RNGs before every run: the DataLoader shuffles the training
        # pairs, and without a common seed each configuration (and each rerun
        # of the notebook) would see a different batch order.
        torch.manual_seed(seed)
        np.random.seed(seed)
        # Start from the pretrained checkpoint each time so runs are independent.
        model = SentenceTransformer(model_name)
        train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=batch_size)
        train_loss = losses.MultipleNegativesRankingLoss(model)
        model.fit(
            train_objectives=[(train_dataloader, train_loss)],
            epochs=epochs,
            warmup_steps=warmup_steps,
            show_progress_bar=True,
            optimizer_params={'lr': learning_rate}
        )
        mrr = evaluate_mrr(model, dev_df, papers_df, top_k=5)
        results.append({
            'model': model_name,
            'epochs': epochs,
            'warmup_steps': warmup_steps,
            'learning_rate': learning_rate,
            'batch_size': batch_size,
            'dev_mrr': mrr
        })
        print(f'Result: MRR={mrr}')

## Save Results

Save the results to a CSV file for later analysis.

In [None]:
# Persist one row per trained configuration; the index is omitted from the CSV.
results_table = pd.DataFrame(results)
results_table.to_csv('hyperparam_results_2.csv', index=False)
print('Results saved to hyperparam_results_2.csv')