In [5]:
# Load data
import pandas as pd
# Read from the JSON file
def load_series_from_json(filename):
    """Parse the JSON file at ``filename`` with ``pd.read_json`` and return the result.

    Args:
        filename: Path (or any other input accepted by ``pd.read_json``).

    Returns:
        Whatever ``pd.read_json`` produces for the file.
    """
    return pd.read_json(filename)

# Ranked tweet tables for the train and dev splits (each carries a 'tfidf_topk' candidate column).
filename = 'ranked_train'
ranked_train = load_series_from_json(filename)

filename = 'ranked_dev'
ranked_dev = load_series_from_json(filename)

# Import references
# NOTE(review): read_pickle can execute arbitrary code while loading; only use it on a trusted file.
PATH_COLLECTION_DATA = '../subtask4b_collection_data.pkl'
df_collection = pd.read_pickle(PATH_COLLECTION_DATA)

# Paper metadata indexed by cord_uid, so a paper can be looked up as paper_info[column][cord_uid].
paper_info = df_collection.set_index('cord_uid')[['title', 'abstract', 'authors', 'journal']]

# .copy() makes these independent frames: columns are added to tweet_info_dev later
# ('reranked_uids'), which would otherwise write into a slice of ranked_dev and
# trigger pandas' SettingWithCopyWarning.
tweet_info_train = ranked_train[["tweet_text", "cord_uid", "tfidf_topk"]].copy()
tweet_info_dev = ranked_dev[["post_id", "tweet_text", "cord_uid", "tfidf_topk"]].copy()

In [2]:
# Count where the correct paper sits in the first-stage (TF-IDF) ranking of each dev tweet.
# The three buckets are disjoint: rank 1, ranks 2-5 and ranks 6-10. The cumulative
# "in top 5" figure is therefore in_top_1 + in_top_5, and so on.
in_top_1 = 0
in_top_5 = 0
in_top_10 = 0

for index, entry in tweet_info_dev.iterrows():
    correct_uid = entry['cord_uid']
    # Slicing tolerates candidate lists shorter than 10 (indexing position by position would raise).
    top_10 = list(entry['tfidf_topk'][:10])

    if correct_uid not in top_10:
        continue

    # Only the first occurrence counts, so a duplicated uid cannot be counted twice.
    rank = top_10.index(correct_uid) + 1
    if rank == 1:
        in_top_1 += 1
    elif rank <= 5:
        in_top_5 += 1
    else:
        in_top_10 += 1
print(f"In top 1: {in_top_1}")
print(f"In top 5: {in_top_5}")
print(f"In top 10: {in_top_10}")

In top 1: 805
In top 5: 167
In top 10: 65


In [3]:
from sentence_transformers import SentenceTransformer

def load_sentence_transformer_model(name):
    """Load a SentenceTransformer model by name and report the outcome.

    Args:
        name: Model identifier passed straight to ``SentenceTransformer``.

    Returns:
        The loaded model.

    Raises:
        Exception: Whatever ``SentenceTransformer(name)`` raised. The failure is
            printed first and then re-raised, so the caller sees the real cause
            (with its traceback) instead of an ``UnboundLocalError`` from
            returning a model that was never created.
    """
    try:
        model = SentenceTransformer(name)
    except Exception as e:
        print(f"Failed to load the model. Error: {e}")
        raise
    print(f"Successfully loaded {name}")
    return model

from sentence_transformers import (
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.evaluation import RerankingEvaluator

In [6]:
# Function for generating the document string
def get_document_string(paper_info, cord_uid):
    """Build the tagged text representation of one paper.

    Args:
        paper_info: Mapping-like object supporting ``paper_info[field][cord_uid]``
            for the fields 'title', 'abstract', 'authors' and 'journal'.
        cord_uid: Key of the paper to render.

    Returns:
        ``'[TITLE]: ... [AUTHORS]: ... [JOURNAL]: ... [ABSTRACT]: ...'``, or the
        sentinel string ``'Abstract missing!'`` when the abstract is not a
        non-blank string. Title, authors and journal are rendered as empty
        strings when they are not non-blank strings.
    """
    def _text_or_blank(value):
        # Anything that is not a non-blank string (None, NaN, whitespace) becomes ''.
        if isinstance(value, str) and value.strip():
            return value
        return ''

    # All four fields are looked up before any validation, in this order.
    raw = {field: paper_info[field][cord_uid] for field in ('title', 'abstract', 'authors', 'journal')}

    abstract = raw['abstract']
    if not isinstance(abstract, str) or not abstract.strip():
        return 'Abstract missing!'

    title = _text_or_blank(raw['title'])
    authors = _text_or_blank(raw['authors'])
    journal = _text_or_blank(raw['journal'])

    return f"[TITLE]: {title} [AUTHORS]: {authors} [JOURNAL]: {journal} [ABSTRACT]: {abstract}"

In [None]:
from datasets import Dataset

train_data_list_of_dicts = []

up_to_top = 5  # How many (excluding the correct one) of the top x should be taken, i.e. 10 means that (excluding the correct one) the top 10 documents will be taken.

# For hyperparameter tuning, we reduced the number of tweets to have faster learning.
# The sample goes into its own variable so that re-running this cell (or running later
# cells that read tweet_info_train) does not keep shrinking the training data.
perc_n_tweets = 0.1  # fraction of the training tweets to use (1 = all of them)
if perc_n_tweets < 1:
    # perc_n_tweets is a fraction, so format it as a percentage (0.1 -> "10%").
    print(f"sampling {perc_n_tweets:.0%} of the tweets ...")
    tweet_info_train_sample = tweet_info_train.sample(frac=perc_n_tweets, random_state=42)
else:
    tweet_info_train_sample = tweet_info_train

# Build one (anchor, positive, negative) triplet per usable negative of every tweet
for index, row in tweet_info_train_sample.iterrows():
    tweet_text = row["tweet_text"]
    correct_cord_uid = row["cord_uid"]

    if correct_cord_uid not in paper_info.index:
        # The correct paper is not part of paper_info at all
        print(f"Warning: Correct data not found for {correct_cord_uid} for tweet at index {index}, skipping tweet.")
        continue

    document_string_pos = get_document_string(paper_info, correct_cord_uid)
    if document_string_pos == 'Abstract missing!':
        print(f"Warning: Abstract missing for {correct_cord_uid} in paper_info, skipping tweet.")
        continue

    # Negatives are the highest-ranked first-stage candidates that are not the correct paper.
    negative_cord_uids = [elem for elem in row['tfidf_topk'] if elem != correct_cord_uid]

    negatives = []
    # Slicing (instead of indexing 0..up_to_top-1) tolerates rows with fewer candidates.
    for negative_cord_uid in negative_cord_uids[:up_to_top]:
        if negative_cord_uid not in paper_info.index:
            print(f"Warning: Negative {negative_cord_uid} not found in paper_info, skipping this negative.")
            continue

        document_string_neg = get_document_string(paper_info, negative_cord_uid)
        if document_string_neg == 'Abstract missing!':
            # Only this negative is dropped; the tweet keeps its remaining negatives.
            print(f"Warning: Abstract missing for negative {negative_cord_uid} in paper_info, skipping this negative.")
            continue

        negatives.append(document_string_neg)

    if not tweet_text or not negatives:
        print(f"Warning: Empty tweet_text or no usable negatives for tweet at index {index}, skipping tweet.")
        continue

    train_data_list_of_dicts.extend(
        {
            'anchor': tweet_text,
            'positive': document_string_pos,
            'negative': negative
        }
        for negative in negatives
    )

print(f"Created {len(train_data_list_of_dicts)} training examples as dictionaries.")

# Convert the list of dictionaries into a datasets.Dataset
train_dataset = Dataset.from_list(train_data_list_of_dicts)

print(f"Converted to Hugging Face Triplet Dataset with {len(train_dataset)} rows and columns: {train_dataset.column_names}")
print(train_dataset[0])

Created 64265 training examples as dictionaries.
Converted to Hugging Face Triplet Dataset with 64265 rows and columns: ['anchor', 'positive', 'negative']
{'anchor': 'Oral care in rehabilitation medicine: oral vulnerability, oral muscle wasting, and hospital-associated oral issues', 'positive': '[TITLE]: Oral Management in Rehabilitation Medicine: Oral Frailty, Oral Sarcopenia, and Hospital-Associated Oral Problems [AUTHORS]: Shiraishi, A.; Wakabayashi, Hidetaka; Yoshimura, Y. [JOURNAL]: J Nutr Health Aging [ABSTRACT]: Oral health is a crucial but often neglected aspect of rehabilitation medicine. Approximately 71% of hospitalized rehabilitation patients and 91% of hospitalized acute care patients have impaired oral health. Poor oral condition in hospitalized patients can be attributed to factors such as age, physical dependency, cognitive decline, malnutrition, low skeletal muscle mass and strength, and multimorbidity. Another major factor is a lack of knowledge and interest in oral p

In [6]:
# Inspect the negative document of the first training triplet
print(train_dataset[0]['negative'])

[TITLE]: High expression of ACE2 receptor of 2019-nCoV on the epithelial cells of oral mucosa [AUTHORS]: Xu, Hao; Zhong, Liang; Deng, Jiaxin; Peng, Jiakuan; Dan, Hongxia; Zeng, Xin; Li, Taiwen; Chen, Qianming [JOURNAL]: Int J Oral Sci [ABSTRACT]: It has been reported that ACE2 is the main host cell receptor of 2019-nCoV and plays a crucial role in the entry of virus into the cell to cause the final infection. To investigate the potential route of 2019-nCov infection on the mucosa of oral cavity, bulk RNA-seq profiles from two public databases including The Cancer Genome Atlas (TCGA) and Functional Annotation of The Mammalian Genome Cap Analysis of Gene Expression (FANTOM5 CAGE) dataset were collected. RNA-seq profiling data of 13 organ types with para-carcinoma normal tissues from TCGA and 14 organ types with normal tissues from FANTOM5 CAGE were analyzed in order to explore and validate the expression of ACE2 on the mucosa of oral cavity. Further, single-cell transcriptomes from an in

In [7]:
# Generate Evaluation Examples
# Each example has the shape {"query": str, "positive": [str], "negative": [str, ...]}.
eval_examples = []
n_correct_cord_uid_not_in_top_k = 0

# Iterate through the rows of the validation tweet_info DataFrame
for index, row in tweet_info_dev.iterrows():
    query_text = row['tweet_text']
    correct_cord_uid = row['cord_uid']  # This is the ID of the correct paper
    top_k_candidate_uids = row['tfidf_topk']  # This is the list of UIDs from the first stage ranker

    if correct_cord_uid not in paper_info['title']:
        print("Positive document not found!")
        continue

    # NOTE(review): get_document_string returns the sentinel 'Abstract missing!' for papers
    # without an abstract; unlike the training data, such documents are kept here.
    positive_document = get_document_string(paper_info, correct_cord_uid)

    # Documents for all negative candidates (everything in the top-k list except the positive)
    negative_documents_map = {
        uid: get_document_string(paper_info, uid)
        for uid in top_k_candidate_uids
        if uid in paper_info['title'] and uid != correct_cord_uid
    }

    # Ensure we have the positive and at least one negative document
    if not (positive_document and negative_documents_map):
        print(f"Either not positive document {positive_document} or not negative_documents map {negative_documents_map}")
        continue

    if correct_cord_uid in top_k_candidate_uids:
        eval_examples.append({
            "query": query_text,
            "positive": [positive_document],
            "negative": list(negative_documents_map.values())
        })
    else:
        # The correct paper was not found in the top-k list from the first stage ranker.
        n_correct_cord_uid_not_in_top_k += 1

if n_correct_cord_uid_not_in_top_k > 0:
    print(f"Warning: {n_correct_cord_uid_not_in_top_k} correct cord_uid's are not in top-k for validation tweets, cannot evaluate re-ranking for them.")

print(f"Created {len(eval_examples)} evaluation examples for RerankingEvaluator.")
# Guard the preview so an empty result does not end the cell with an IndexError.
if eval_examples:
    print(eval_examples[0])

Created 1214 evaluation examples for RerankingEvaluator.
{'query': 'covid recovery: this study from the usa reveals that a proportion of cases experience impairment in some cognitive functions for several months after infection. some possible biases &amp; limitations but more research is required on impact of these long term effects.', 'positive': ['[TITLE]: Assessment of Cognitive Function in Patients After COVID-19 Infection [AUTHORS]: Becker, Jacqueline H.; Lin, Jenny J.; Doernberg, Molly; Stone, Kimberly; Navis, Allison; Festa, Joanne R.; Wisnivesky, Juan P. [JOURNAL]: JAMA Netw Open [ABSTRACT]: This cross-sectional study examines rates of cognitive impairment among patients who survived COVID-19 and whether the care setting was associated with cognitive impairment rates.'], 'negative': ['[TITLE]: Long covid-mechanisms, risk factors, and management. [AUTHORS]: Crook, Harry; Raza, Sanara; Nowell, Joseph; Young, Megan; Edison, Paul [JOURNAL]: BMJ [ABSTRACT]: Since its emergence in Wu

In [8]:
# Inspect the positive document list of the first evaluation example
print(eval_examples[0]['positive'])

['[TITLE]: Assessment of Cognitive Function in Patients After COVID-19 Infection [AUTHORS]: Becker, Jacqueline H.; Lin, Jenny J.; Doernberg, Molly; Stone, Kimberly; Navis, Allison; Festa, Joanne R.; Wisnivesky, Juan P. [JOURNAL]: JAMA Netw Open [ABSTRACT]: This cross-sectional study examines rates of cognitive impairment among patients who survived COVID-19 and whether the care setting was associated with cognitive impairment rates.']


In [9]:
# Number of negative candidates the first evaluation query is ranked against
print(len(eval_examples[0]['negative']))

99


Medical Model
===

In [8]:
# Load the pretrained sentence-transformer that is fine-tuned in the training cell of this section
model = load_sentence_transformer_model('pritamdeka/S-PubMedBert-MS-MARCO')

Successfully loaded pritamdeka/S-PubMedBert-MS-MARCO


In [9]:
# Define train loss: MultipleNegativesRankingLoss bound to the model loaded in the previous cell
loss = MultipleNegativesRankingLoss(model)

Training
---

In [10]:
BATCH_SIZE = 50
num_epochs = 5

# Specify training arguments
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir='output/bi-encoder-S-PubMedBert-final',
    # Optional training parameters:
    num_train_epochs=num_epochs,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=1e-5,
    warmup_ratio=0.1,
    weight_decay=0.01,
    fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=True,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    # Optional tracking/debugging parameters:
    # Evaluate and checkpoint once per epoch; the *_steps values are left unset (None).
    eval_strategy="epoch",
    eval_steps=None,
    save_strategy="epoch",
    save_steps=None,
    save_total_limit=5,
    logging_steps=10,
    run_name="bi-encoder-S-PubMedBert-final",  # Will be used in W&B if `wandb` is installed
)

# (Optional) Create an evaluator & evaluate the base model
# !!! WARNING !!! This evaluator EXCLUDES cases where the true paper IS NOT in the top k !!
dev_evaluator = RerankingEvaluator(eval_examples, batch_size=BATCH_SIZE, name='validation_reranking')
# Keep and print the pre-training scores: a bare call in the middle of a cell is not
# displayed, so the baseline would otherwise be computed and silently thrown away.
baseline_metrics = dev_evaluator(model)
print(f"Base model (before fine-tuning): {baseline_metrics}")

# Create a trainer & train
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=None,  # no eval dataset; the dev evaluator is passed via `evaluator` below
    loss=loss,
    evaluator=dev_evaluator,
)

trainer.train()

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Epoch,Training Loss,Validation Loss,Validation Reranking Map,Validation Reranking Mrr@10,Validation Reranking Ndcg@10
1,0.3675,No log,0.726384,0.72028,0.760579
2,0.1296,No log,0.726089,0.721268,0.763428
3,0.0532,No log,0.719331,0.714358,0.757602
4,0.0563,No log,0.716562,0.711595,0.755624
5,0.0291,No log,0.715011,0.70979,0.753311


TrainOutput(global_step=6430, training_loss=0.2525162873473916, metrics={'train_runtime': 2675.4962, 'train_samples_per_second': 120.099, 'train_steps_per_second': 2.403, 'total_flos': 0.0, 'train_loss': 0.2525162873473916, 'epoch': 5.0})

Store on Huggingface
===

In [11]:
# Save the trained model on hugging face (in a new repo)
from huggingface_hub import create_repo

def create_repo_on_huggingface(repo_id_str):
    """Create a private repository on the Hugging Face Hub, or reuse it if it exists.

    Args:
        repo_id_str: Repository id, e.g. ``'user/model-name'``.

    Returns:
        ``repo_id_str`` unchanged. ``create_repo`` returns the repository URL, not
        the id, so the id string is handed back for use with ``upload_folder``.

    Raises:
        TypeError: Re-raised from ``create_repo`` after printing an upgrade hint.
        Exception: Any other error from ``create_repo``, re-raised after printing
            it. Previously such failures surfaced as an ``UnboundLocalError``
            because the return value had never been assigned.
    """
    try:
        repo_url = create_repo(repo_id=repo_id_str, exist_ok=True, private=True)
    except TypeError as e:
        print(f"Error creating repository: {e}")
        print("It seems your huggingface_hub library version is incompatible.")
        print("Please update it: pip install -U huggingface_hub")
        raise
    except Exception as e:
        print(f"An unexpected error occurred while creating the repository: {e}")
        raise

    print(f"Created or found repository on Hugging Face Hub: {repo_url}")
    return repo_id_str

In [12]:
# Uploads the model to hugging face
from huggingface_hub import upload_folder

def upload_model_to_huggingface(local_folder_path, repo_id,
                                commit_message='Upload final model from checkpoint',
                                ignore_patterns=None):
    """Upload a local model directory to a model repository on the Hugging Face Hub.

    Args:
        local_folder_path: Path to the local directory containing the trained model files.
        repo_id: Id of the target repository.
        commit_message: Commit message used for the upload.
        ignore_patterns: Optional pattern or list of patterns forwarded to
            ``upload_folder`` to exclude files, e.g.
            ``['optimizer.pt', 'scheduler.pt', 'rng_state.pth']`` to leave out
            training-state files of a trainer checkpoint. ``None`` (default)
            uploads everything, as before.
    """
    print(f"Uploading files from {local_folder_path} to {repo_id}...")

    upload_folder(
        folder_path=local_folder_path,
        repo_id=repo_id,
        repo_type='model',  # Specify the type of repository
        commit_message=commit_message,
        ignore_patterns=ignore_patterns,
    )

    print("Upload complete!")

In [13]:
# Create (or reuse) the private Hub repository that will receive the fine-tuned bi-encoder
repo_id_str = 'LukasXperiaZ/bi-encoder-S-PubMedBert-final'
repo_id = create_repo_on_huggingface(repo_id_str)

Created or found repository on Hugging Face Hub: https://huggingface.co/LukasXperiaZ/bi-encoder-S-PubMedBert-final


In [14]:
# Upload one specific trainer checkpoint. The training run reported 6430 steps over 5 epochs,
# so step 2572 falls at the end of epoch 2.
# NOTE(review): presumably picked for its validation scores — confirm.
local_folder_path = 'output/bi-encoder-S-PubMedBert-final/checkpoint-2572'
upload_model_to_huggingface(local_folder_path, repo_id)

Uploading files from output/bi-encoder-S-PubMedBert-final/checkpoint-2572 to LukasXperiaZ/bi-encoder-S-PubMedBert-final...


scheduler.pt:   0%|          | 0.00/1.47k [00:00<?, ?B/s]

optimizer.pt:   0%|          | 0.00/871M [00:00<?, ?B/s]

rng_state.pth:   0%|          | 0.00/14.6k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

Upload 5 LFS files:   0%|          | 0/5 [00:00<?, ?it/s]

training_args.bin:   0%|          | 0.00/6.03k [00:00<?, ?B/s]

Upload complete!


In [15]:
# Reload the uploaded model from the Hub by its repo id, for use in the evaluation cells
model_from_hub = SentenceTransformer(repo_id)
print(f"Model loaded successfully from Hugging Face Hub: {repo_id}")

modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/205 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/77.4k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/583 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.46k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/226k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/706k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/695 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

Model loaded successfully from Hugging Face Hub: LukasXperiaZ/bi-encoder-S-PubMedBert-final


Evaluation
===

TODO: Evaluate the re-ranking with cosine distance — this has not been done yet.

In [None]:
# TODO TEST !!!

from sentence_transformers import util

# --- Define the Re-ranking Function for a single tweet ---
def rerank_tweet(tweet_text: str, initial_top_k_uids: list, paper_info: pd.DataFrame, model: "SentenceTransformer") -> list:
    """
    Re-ranks a list of candidate paper UIDs for a given tweet using a SentenceTransformer model.

    Args:
        tweet_text: The text of the query tweet.
        initial_top_k_uids: A list of paper UIDs from the initial ranker.
        paper_info: DataFrame indexed by CORD UID, in the form expected by
            ``get_document_string``.
        model: The loaded SentenceTransformer model for encoding.

    Returns:
        A list of re-ranked paper UIDs sorted by cosine similarity (descending).
        Candidates that are not in ``paper_info`` are dropped. Returns an empty
        list if no valid candidates are available or re-ranking fails.
    """
    # initial_top_k_uids is a plain list of UIDs: iterate it directly (indexing it with
    # 'title' raises TypeError) and keep only candidates that can be turned into text.
    candidate_uids = [uid for uid in initial_top_k_uids if uid in paper_info.index]
    if not candidate_uids:
        return []

    candidate_document_texts = [get_document_string(paper_info, uid) for uid in candidate_uids]

    try:
        # Encode the tweet and the candidate documents
        query_embedding = model.encode(tweet_text, convert_to_tensor=True, show_progress_bar=False)
        candidate_embeddings = model.encode(candidate_document_texts, convert_to_tensor=True, show_progress_bar=False)

        # Calculate similarity scores (cosine similarity)
        query_embedding = query_embedding.unsqueeze(0)  # Ensure 2D
        scores = util.cos_sim(query_embedding, candidate_embeddings)[0]  # Scores for the single query

        # Pair UIDs with scores and sort; sorting on the score only keeps the
        # first-stage order among candidates with equal scores.
        score_uid_pairs = sorted(zip(scores.tolist(), candidate_uids), key=lambda x: x[0], reverse=True)

        # Return the re-ranked list of UIDs
        return [uid for score, uid in score_uid_pairs]

    except Exception as e:
        # Deliberate best-effort: report the failure and return no ranking for this tweet.
        print(f"Error during encoding or scoring for tweet: '{tweet_text[:50]}...' - {e}")
        return []  # Return empty list on error

In [None]:
# Smoke test: re-rank the first-stage candidates of the first dev tweet with the model loaded from the Hub
tweet_info = tweet_info_dev.iloc[0]
query_text = tweet_info['tweet_text']
initial_top_k_uids = tweet_info['tfidf_topk']
reranked_list = rerank_tweet(query_text, initial_top_k_uids, paper_info, model_from_hub)

In [None]:
def rerank_tweets(model_from_hub):
    """Re-rank the first-stage candidates of every dev tweet and store the result.

    Reads the notebook-level ``tweet_info_dev`` and ``paper_info`` and writes the
    re-ranked UID lists into a new ``'reranked_uids'`` column of ``tweet_info_dev``.

    Args:
        model_from_hub: Model handed to ``rerank_tweet`` for encoding.

    Returns:
        None. The results are stored in ``tweet_info_dev['reranked_uids']``.
    """
    total = len(tweet_info_dev)
    print(f"\nStarting re-ranking for {total} tweets ...")

    reranked_lists = []
    # Progress is counted by position: the index labels are not guaranteed to be
    # the integers 0..n-1, so arithmetic on them can misreport progress or fail.
    for n_done, (_, row) in enumerate(tweet_info_dev.iterrows(), start=1):
        reranked_lists.append(
            rerank_tweet(row['tweet_text'], row['tfidf_topk'], paper_info, model_from_hub)
        )

        if n_done % 100 == 0:
            print(f"Processed {n_done}/{total} tweets.")

    # Assign the whole column at once as a 1-D object array with one list per row.
    tweet_info_dev['reranked_uids'] = pd.Series(reranked_lists, dtype=object).to_numpy()

    print("\nRe-ranking complete for all tweets.")

    # The 'reranked_uids' column in tweet_info_dev now contains the results
    print(tweet_info_dev['reranked_uids'][:5])

In [None]:
# Re-rank all dev tweets; the results are written to tweet_info_dev['reranked_uids']
rerank_tweets(model_from_hub)

In [None]:
# Evaluate retrieved candidates using MRR@k
def get_performance_mrr(data, col_gold, col_pred, list_k=(1, 5, 10)):
    """Compute MRR@k of ranked predictions for several cut-offs.

    Args:
        data: DataFrame with one row per query. A helper column ``'in_topx'`` is
            (over)written in place; it holds the reciprocal ranks of the last k.
        col_gold: Name of the column holding the correct id.
        col_pred: Name of the column holding the ranked list of predicted ids.
        list_k: Iterable of cut-offs k.

    Returns:
        Dict mapping each k to the mean reciprocal rank over all rows. A row
        scores 0 when the gold id is not among its first k predictions or when
        it has no predictions (None / NaN).
    """
    def _reciprocal_rank(gold, preds, k):
        # Missing predictions (None, or NaN for an unfilled cell) count as a miss.
        if preds is None or isinstance(preds, float):
            return 0
        top_k = list(preds[:k])
        return 1 / (top_k.index(gold) + 1) if gold in top_k else 0

    d_performance = {}
    for k in list_k:
        data["in_topx"] = [
            _reciprocal_rank(gold, preds, k)
            for gold, preds in zip(data[col_gold], data[col_pred])
        ]
        d_performance[k] = data["in_topx"].mean()
    return d_performance

In [None]:
# Pair each dev tweet's gold cord_uid with its re-ranked candidate list
df_dev_eval = pd.DataFrame({
    "cord_uid": tweet_info_dev["cord_uid"],
    "topk": tweet_info_dev["reranked_uids"]
})
print(df_dev_eval[:2])

# Evaluate MRR@k
# NOTE(review): the result name mentions "roberta", but the re-ranked lists come from the
# model passed to rerank_tweets (model_from_hub) — consider renaming.
results_def_rerank_roberta = get_performance_mrr(df_dev_eval, 'cord_uid', 'topk')
print(f"Reranking Results on the dev set: {results_def_rerank_roberta}")

Cross-Encoder Model
===

In [14]:
from sentence_transformers import InputExample
from collections import defaultdict
from sentence_transformers import CrossEncoder
from torch.utils.data import DataLoader
import torch
from torch.nn import BCEWithLogitsLoss

In [15]:
def build_samples(tweet_info_df, paper_info, up_to_top, max_queries=None, for_eval=False):
    """Turn ranked tweets into ``InputExample`` objects for cross-encoder training or evaluation.

    Args:
        tweet_info_df: DataFrame with the columns 'tweet_text', 'cord_uid' and 'tfidf_topk'.
        paper_info: DataFrame indexed by cord_uid, as expected by ``get_document_string``.
        up_to_top: Number of top-ranked negatives required per tweet. Tweets for
            which fewer usable negatives are found among the first ``up_to_top``
            non-correct candidates are skipped.
        max_queries: Stop after this many tweets were used (``None`` or 0 = no limit).
        for_eval: If True, emit one example per tweet with
            ``texts = [tweet, positive, negative_1, ...]`` (the positive always
            comes first after the tweet). Otherwise emit one labelled pair per
            document: label 1.0 for the positive, 0.0 for each negative.

    Returns:
        List of ``InputExample`` objects.
    """
    samples = []
    count = 0

    for index, row in tweet_info_df.iterrows():
        tweet_text = row["tweet_text"]
        correct_cord_uid = row["cord_uid"]
        negative_cord_uids = [uid for uid in row['tfidf_topk'] if uid != correct_cord_uid]

        if correct_cord_uid not in paper_info.index:
            print(f"Warning: Correct data not found for {correct_cord_uid} for tweet at index {index}, skipping tweet.")
            continue

        document_string_pos = get_document_string(paper_info, correct_cord_uid)
        if document_string_pos == 'Abstract missing!':
            print(f"Warning: Abstract missing for {correct_cord_uid} in paper_info, skipping tweet.")
            continue

        negatives = []
        for uid in negative_cord_uids[:up_to_top]:
            # A candidate that is not in paper_info cannot be rendered; treat it like a
            # missing abstract instead of letting the lookup raise KeyError.
            if uid not in paper_info.index:
                continue
            doc = get_document_string(paper_info, uid)
            if doc != 'Abstract missing!':
                negatives.append(doc)

        if len(negatives) < up_to_top:
            print(f"Warning: Not enough negative documents for {correct_cord_uid}, skipping tweet.")
            continue

        if for_eval:
            # For evaluation: group all documents of a query in one example
            samples.append(InputExample(texts=[tweet_text] + [document_string_pos] + negatives))
        else:
            # For training: individual labelled (tweet, document) pairs
            samples.append(InputExample(texts=[tweet_text, document_string_pos], label=1.0))
            for neg_doc in negatives:
                samples.append(InputExample(texts=[tweet_text, neg_doc], label=0.0))

        # Only tweets that produced samples count towards max_queries
        count += 1
        if max_queries and count >= max_queries:
            break

    return samples



def get_data(frac=1, up_to_top=5, max_queries=200):
    """Build the training and evaluation samples from the notebook-level tweet frames.

    Args:
        frac: Fraction of ``tweet_info_train`` to sample (seeded with 42) when
            strictly between 0 and 1; any other value uses the full training data.
        up_to_top: Forwarded to ``build_samples`` for both splits.
        max_queries: Forwarded to ``build_samples`` for the evaluation split only.

    Returns:
        Tuple ``(train_samples, eval_samples)``.
    """
    use_subset = 0 < frac < 1

    if use_subset:
        print(f"Sampling {frac * 100}% of the training data ...")
        train_tweets = tweet_info_train.sample(frac=frac, random_state=42)
    else:
        print("Using the full training data ...")
        train_tweets = tweet_info_train

    train_samples = build_samples(train_tweets, paper_info, up_to_top)
    eval_samples = build_samples(tweet_info_dev, paper_info, up_to_top, max_queries, for_eval=True)

    n_train, n_eval = len(train_samples), len(eval_samples)
    print(f"Training examples: {n_train} | Evaluation queries: {n_eval}")
    return train_samples, eval_samples

In [16]:
def evaluate_mrr(model, eval_samples, batch_size, k=10):
    """Compute MRR@k of a cross-encoder over grouped evaluation samples.

    Args:
        model: Object with ``predict(sentence_pairs, batch_size=...)`` returning one
            score per pair.
        eval_samples: Sequence of examples whose ``texts`` attribute is
            ``[query, positive, negative_1, ...]``; the ground-truth document is
            always the first one after the query.
        batch_size: Batch size forwarded to ``model.predict``.
        k: Rank cut-off; a query scores 0 if the positive is not ranked in the
            top ``k``. Defaults to 10.

    Returns:
        Mean reciprocal rank over all samples, or 0.0 when there are no samples.
    """
    # No queries: return 0.0 instead of dividing by zero.
    if len(eval_samples) == 0:
        return 0.0

    mrr_total = 0.0
    for example in eval_samples:
        query, *docs = example.texts

        sentence_pairs = [[query, doc] for doc in docs]
        scores = model.predict(sentence_pairs, batch_size=batch_size)

        # Highest score first; the positive has position 0 in `docs`
        ranking = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)

        for rank, (doc_position, _) in enumerate(ranking[:k], start=1):
            if doc_position == 0:
                mrr_total += 1 / rank
                break

    return mrr_total / len(eval_samples)


In [None]:
# Hyperparameters
# NOTE(review): up_to_top, num_epochs and (below) model rebind names that earlier cells of
# this notebook also define, so those cells see the new values if they are re-run afterwards.
frac = 0.05  # fraction of the training tweets handed to get_data
up_to_top = 10  # number of top-ranked negatives required per tweet
max_queries = 100  # upper bound on the number of evaluation queries

num_epochs = 3
warmup_steps = 50
batch_size = 16
model_save_path = 'output/crossencoder-reranker-ms-marco-MiniLM-L6-v2'

# Setup
train_samples, eval_samples = get_data(frac=frac, up_to_top=up_to_top, max_queries=max_queries)
train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=batch_size)

model_name = 'cross-encoder/ms-marco-MiniLM-L6-v2'
model = CrossEncoder(model_name, num_labels=1)

# Train one epoch per fit() call so MRR@10 can be computed on eval_samples after every epoch.
# NOTE(review): warmup_steps is passed to each fit() call, so the warmup presumably restarts
# every epoch — confirm this is intended.
# NOTE(review): with save_best_model=False and no evaluator passed to fit(), it is unclear
# whether anything is written to output_path — verify, or save the model explicitly.
for epoch in range(num_epochs):
    model.fit(
        train_dataloader=train_dataloader,
        epochs=1,
        warmup_steps=warmup_steps,
        output_path=model_save_path,
        save_best_model=False
    )
    mrr_score = evaluate_mrr(model, eval_samples, batch_size)
    print(f"Epoch {epoch + 1} — MRR@10: {mrr_score:.4f}")

Sampling 5.0% of the training data ...
Training examples: 7073 | Evaluation queries: 100


Token indices sequence length is longer than the specified maximum sequence length for this model (953 > 512). Running this sequence through the model will result in indexing errors


Step,Training Loss


Epoch 1 — MRR@10: 0.7771


Step,Training Loss


Epoch 2 — MRR@10: 0.7784


Step,Training Loss


Epoch 3 — MRR@10: 0.8075
