In [1]:
# Load data
import pandas as pd
# Read from the JSON file
def load_series_from_json(filename):
    loaded_series = pd.read_json(filename)
    return loaded_series

filename = 'ranked_train'
ranked_train = load_series_from_json(filename)

filename = 'ranked_dev'
ranked_dev = load_series_from_json(filename)

# Import references
PATH_COLLECTION_DATA = '../subtask4b_collection_data.pkl'
df_collection = pd.read_pickle(PATH_COLLECTION_DATA)

paper_info = df_collection.set_index('cord_uid')[['title', 'abstract', 'authors', 'journal']]

tweet_info_train = ranked_train[["tweet_text", "cord_uid", "tfidf_topk"]]
tweet_info_dev = ranked_dev[["post_id", "tweet_text", "cord_uid", "tfidf_topk"]]

In [None]:
in_top_1 = 0
in_top_5 = 0
in_top_10 = 0

for index, sample in tweet_info_dev.iterrows():
    correct_uid = sample['cord_uid']
    tfidf_topk = sample['tfidf_topk']

    for i in range(1, 11):
        uid = tfidf_topk[i-1]
        if correct_uid == uid:
            if i == 1:
                in_top_1 += 1
            elif i > 1 and i <= 5:
                in_top_5 += 1
            else:
                in_top_10 += 1
print(f"In top 1: {in_top_1}")
print(f"In top 5: {in_top_5}")
print(f"In top 10: {in_top_10}")

In top 1: 805
In top 5: 167
In top 10: 65


In [2]:
from sentence_transformers import SentenceTransformer

def load_sentence_transformer_model(name):
    try:
        model = SentenceTransformer(name)
        print(f"Successfully loaded {name}")
    except Exception as e:
        print(f"Failed to load the model. Error: {e}")
        import traceback
        traceback.print_exc() # Print full traceback if it fails
    return model

from sentence_transformers import (
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.evaluation import RerankingEvaluator

In [2]:
# Function for generating the document string
def get_document_string(paper_info, cord_uid):
    
    title = paper_info['title'][cord_uid]
    abstract = paper_info['abstract'][cord_uid]
    authors = paper_info['authors'][cord_uid]
    journal = paper_info['journal'][cord_uid]

    if not (isinstance(abstract, str) and abstract.strip()):
        return 'Abstract missing!'

    if not (isinstance(title, str) and title.strip()):
        # Title is missing, leave it blank
        title = ''

    if not (isinstance(authors, str) and authors.strip()):
        authors = ''

    if not (isinstance(journal, str) and journal.strip()):
        journal = ''

    document_string = '[TITLE]: ' + title + ' [AUTHORS]: ' + authors + ' [JOURNAL]: ' + journal + ' [ABSTRACT]: ' + abstract
    return document_string

In [5]:
from datasets import Dataset

train_data_list_of_dicts = []

up_to_top = 5  # How many (excluding the correct one) of the top x should be taken, i.e. 10 means that (excluding the correct one) the top 10 documents will be taken.

# For hyperparameter tuning, we reduced the number of tweets to have faster learning:
perc_n_tweets = 0.1
if perc_n_tweets < 1:
    print(f"sampling {perc_n_tweets}% of the tweets ...")
    tweet_info_train = tweet_info_train.sample(frac=perc_n_tweets, random_state=42)

# Iterate through the rows of the tweet_info_train DataFrame
for index, row in tweet_info_train.iterrows():
    tweet_text = row["tweet_text"]
    correct_cord_uid = row["cord_uid"]
    negative_cord_uids = [elem for elem in row['tfidf_topk'] if elem != correct_cord_uid]

    # Get the title, abstract, authors and journal of the correct paper
    if correct_cord_uid in paper_info['title']:
        
        document_string_pos = get_document_string(paper_info, correct_cord_uid)
        
        if document_string_pos == 'Abstract missing!':
            print(f"Warning: Abstract missing for {correct_cord_uid} in paper_info, skipping tweet.")
            continue

        negatives = []
        for i in range(0, up_to_top):
            negative_cord_uid = negative_cord_uids[i]
            
            document_string_neg = get_document_string(paper_info, negative_cord_uid)
            if document_string_neg == 'Abstract missing!':
                print(f"Warning: Abstract missing for {correct_cord_uid} in paper_info, skipping tweet.")
                continue

            negatives.append(document_string_neg)

        # Create a dictionary for this training example
        if not tweet_text or not document_string_pos or not negatives:
            print(f"Warning: One of tweet_text {tweet_text} or document_string_pos {document_string_pos} or {document_string_neg} is None or empty!")
            continue

        for negative in negatives:
            example_dict = {
                'anchor': tweet_text,
                'positive': document_string_pos,
                'negative': negative
            }
            train_data_list_of_dicts.append(example_dict)

    else:
        # Handle cases where the correct paper's abstract is not found in your paper_info data
        print(f"Warning: Correct data not found for {correct_cord_uid} for tweet at index {index}, skipping tweet.")

print(f"Created {len(train_data_list_of_dicts)} training examples as dictionaries.")

# Convert the list of dictionaries into a datasets.Dataset
train_dataset = Dataset.from_list(train_data_list_of_dicts)

print(f"Converted to Hugging Face Triplet Dataset with {len(train_dataset)} rows and columns: {train_dataset.column_names}")
print(train_dataset[0])

sampling 0.1% of the tweets ...
Created 6425 training examples as dictionaries.
Converted to Hugging Face Triplet Dataset with 6425 rows and columns: ['anchor', 'positive', 'negative']
{'anchor': 'and as you are well aware the mortality rate of covid is not the only concern. the number of covid patients requiring icu is over 30% worldwide. a huge burden on hospital infrastructure, blocking essential services required for large numbers of vital surgical work.', 'positive': '[TITLE]: Rate of Intensive Care Unit admission and outcomes among patients with coronavirus: A systematic review and Meta-analysis [AUTHORS]: Abate, Semagn Mekonnen; Ahmed Ali, Siraj; Mantfardo, Bahiru; Basu, Bivash [JOURNAL]: PLoS One [ABSTRACT]: BACKGROUND: The rate of ICU admission among patients with coronavirus varied from 3% to 100% and the mortality was as high as 86% of admitted patients. The objective of the systematic review was to investigate the rate of ICU admission, mortality, morbidity, and complicatio

In [6]:
print(train_dataset[0]['negative'])

[TITLE]: High expression of ACE2 receptor of 2019-nCoV on the epithelial cells of oral mucosa [AUTHORS]: Xu, Hao; Zhong, Liang; Deng, Jiaxin; Peng, Jiakuan; Dan, Hongxia; Zeng, Xin; Li, Taiwen; Chen, Qianming [JOURNAL]: Int J Oral Sci [ABSTRACT]: It has been reported that ACE2 is the main host cell receptor of 2019-nCoV and plays a crucial role in the entry of virus into the cell to cause the final infection. To investigate the potential route of 2019-nCov infection on the mucosa of oral cavity, bulk RNA-seq profiles from two public databases including The Cancer Genome Atlas (TCGA) and Functional Annotation of The Mammalian Genome Cap Analysis of Gene Expression (FANTOM5 CAGE) dataset were collected. RNA-seq profiling data of 13 organ types with para-carcinoma normal tissues from TCGA and 14 organ types with normal tissues from FANTOM5 CAGE were analyzed in order to explore and validate the expression of ACE2 on the mucosa of oral cavity. Further, single-cell transcriptomes from an in

In [6]:
# Generate Evaluation Examples
eval_examples = []
n_correct_cord_uid_not_in_top_k = 0

# Iterate through the rows of the validation tweet_info DataFrame
for index, row in tweet_info_dev.iterrows():
    query_text = row['tweet_text']
    correct_cord_uid = row['cord_uid'] # This is the ID of the correct paper
    top_k_candidate_uids = row['tfidf_topk'] # This is the list of UIDs from the first stage ranker

    # We need the document for the correct paper and for all candidate papers
    positive_document = None
    negative_documents_map = {} # Map UID to abstract text for candidates

    # Get the positive document
    if correct_cord_uid in paper_info['title']:
        positive_document = get_document_string(paper_info, correct_cord_uid)

        # Get documents for all negative candidates (all in the top-k list except the positive document)
        for uid in top_k_candidate_uids:
            if uid in paper_info['title'] and uid != correct_cord_uid:
                negative_documents_map[uid] = get_document_string(paper_info, uid)

        if positive_document and negative_documents_map: # Ensure we have the positive and at least one negative abstract
            negative_uids_list = list(negative_documents_map.keys())
            negative_document_list = list(negative_documents_map.values())

            if correct_cord_uid in top_k_candidate_uids:
                eval_examples.append({
                    "query": query_text,
                    "positive": [positive_document],
                    "negative": negative_document_list
                })
                
            else:
                # This case means the correct paper was not found in the top-k list from the first stage ranker.
                n_correct_cord_uid_not_in_top_k += 1
    
        else:
            print(f"Either not positive document {positive_document} or not negative_documents map {negative_documents_map}")
    else:
        print("Positive document not found!")

if n_correct_cord_uid_not_in_top_k > 0:
    print(f"Warning: {n_correct_cord_uid_not_in_top_k} correct cord_uid's are not in top-k for validation tweets, cannot evaluate re-ranking for them.")

print(f"Created {len(eval_examples)} evaluation examples for RerankingEvaluator.")
print(eval_examples[0])

Created 1214 evaluation examples for RerankingEvaluator.
{'query': 'covid recovery: this study from the usa reveals that a proportion of cases experience impairment in some cognitive functions for several months after infection. some possible biases &amp; limitations but more research is required on impact of these long term effects.', 'positive': ['[TITLE]: Assessment of Cognitive Function in Patients After COVID-19 Infection [AUTHORS]: Becker, Jacqueline H.; Lin, Jenny J.; Doernberg, Molly; Stone, Kimberly; Navis, Allison; Festa, Joanne R.; Wisnivesky, Juan P. [JOURNAL]: JAMA Netw Open [ABSTRACT]: This cross-sectional study examines rates of cognitive impairment among patients who survived COVID-19 and whether the care setting was associated with cognitive impairment rates.'], 'negative': ['[TITLE]: Long covid-mechanisms, risk factors, and management. [AUTHORS]: Crook, Harry; Raza, Sanara; Nowell, Joseph; Young, Megan; Edison, Paul [JOURNAL]: BMJ [ABSTRACT]: Since its emergence in Wu

In [8]:
print(eval_examples[0]['positive'])

['[TITLE]: Assessment of Cognitive Function in Patients After COVID-19 Infection [AUTHORS]: Becker, Jacqueline H.; Lin, Jenny J.; Doernberg, Molly; Stone, Kimberly; Navis, Allison; Festa, Joanne R.; Wisnivesky, Juan P. [JOURNAL]: JAMA Netw Open [ABSTRACT]: This cross-sectional study examines rates of cognitive impairment among patients who survived COVID-19 and whether the care setting was associated with cognitive impairment rates.']


In [9]:
print(len(eval_examples[0]['negative']))

99


Medical Model
===

In [8]:
model = load_sentence_transformer_model('pritamdeka/S-PubMedBert-MS-MARCO')

Successfully loaded pritamdeka/S-PubMedBert-MS-MARCO


In [9]:
# Define train loss
loss = MultipleNegativesRankingLoss(model)

Training
---

In [10]:
BATCH_SIZE = 50
num_epochs = 5
#every_steps = int((len(train_dataset) / BATCH_SIZE) / 3)
#print(f"Evaluating every {every_steps} steps (1 Batch has {int((len(train_dataset) / BATCH_SIZE))} steps)")

# Specify training arguments
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir='output/bi-encoder-S-PubMedBert-final',
    # Optional training parameters:
    num_train_epochs=num_epochs,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=1e-5,
    warmup_ratio=0.1,
    weight_decay=0.01,
    fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=True,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    # Optional tracking/debugging parameters:
    eval_strategy="epoch",
    eval_steps=None,
    save_strategy="epoch",
    save_steps=None,
    save_total_limit=5,
    logging_steps=10,
    run_name="bi-encoder-S-PubMedBert-final",  # Will be used in W&B if `wandb` is installed
)

# (Optional) Create an evaluator & evaluate the base model
# !!! WARNING !!! This evaluator EXCLUDES cases where the true paper IS NOT in the top k !!
dev_evaluator = RerankingEvaluator(eval_examples, batch_size=BATCH_SIZE, name='validation_reranking')
dev_evaluator(model)

# Create a trainer & train
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=None,
    loss=loss,
    evaluator=dev_evaluator,
)

trainer.train()

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Epoch,Training Loss,Validation Loss,Validation Reranking Map,Validation Reranking Mrr@10,Validation Reranking Ndcg@10
1,0.3675,No log,0.726384,0.72028,0.760579
2,0.1296,No log,0.726089,0.721268,0.763428
3,0.0532,No log,0.719331,0.714358,0.757602
4,0.0563,No log,0.716562,0.711595,0.755624
5,0.0291,No log,0.715011,0.70979,0.753311


TrainOutput(global_step=6430, training_loss=0.2525162873473916, metrics={'train_runtime': 2675.4962, 'train_samples_per_second': 120.099, 'train_steps_per_second': 2.403, 'total_flos': 0.0, 'train_loss': 0.2525162873473916, 'epoch': 5.0})

Store on Huggingface
===

In [8]:
# Save the trained model on hugging face (in a new repo)
from huggingface_hub import create_repo

def create_repo_on_huggingface(repo_id_str):
    try:
        repo_url = create_repo(repo_id=repo_id_str, exist_ok=True, private=True)
        print(f"Created or found repository on Hugging Face Hub: {repo_url}")
        # create_repo returns the URL of the repository, not the repo_id string.
        # Let's keep the repo_id string for upload_folder
        repo_id = repo_id_str

    except TypeError as e:
        print(f"Error creating repository: {e}")
        print("It seems your huggingface_hub library version is incompatible.")
        print("Please update it: pip install -U huggingface_hub")
    except Exception as e:
        print(f"An unexpected error occurred while creating the repository: {e}")
    return repo_id

In [9]:
# Uploads the model to hugging face
from huggingface_hub import upload_folder

def upload_model_to_huggingface(local_folder_path, repo_id):
    # Path to your local directory containing the trained model files

    print(f"Uploading files from {local_folder_path} to {repo_id}...")

    upload_folder(
        folder_path=local_folder_path,
        repo_id=repo_id,
        repo_type='model', # Specify the type of repository
        commit_message='Upload final model from checkpoint',
    )

    print("Upload complete!")

In [10]:
repo_id_str = 'LukasXperiaZ/bi-encoder-S-PubMedBert-final'
repo_id = create_repo_on_huggingface(repo_id_str)

Created or found repository on Hugging Face Hub: https://huggingface.co/LukasXperiaZ/bi-encoder-S-PubMedBert-final


In [14]:
local_folder_path = 'output/bi-encoder-S-PubMedBert-final/checkpoint-2572'
upload_model_to_huggingface(local_folder_path, repo_id)

Uploading files from output/bi-encoder-S-PubMedBert-final/checkpoint-2572 to LukasXperiaZ/bi-encoder-S-PubMedBert-final...


scheduler.pt:   0%|          | 0.00/1.47k [00:00<?, ?B/s]

optimizer.pt:   0%|          | 0.00/871M [00:00<?, ?B/s]

rng_state.pth:   0%|          | 0.00/14.6k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

Upload 5 LFS files:   0%|          | 0/5 [00:00<?, ?it/s]

training_args.bin:   0%|          | 0.00/6.03k [00:00<?, ?B/s]

Upload complete!


In [11]:
model_from_hub = SentenceTransformer(repo_id)
print(f"Model loaded successfully from Hugging Face Hub: {repo_id}")

Model loaded successfully from Hugging Face Hub: LukasXperiaZ/bi-encoder-S-PubMedBert-final


Evaluation
===

In [None]:
from sentence_transformers import util

# --- Define the Re-ranking Function for a single tweet ---
def rerank_tweet(tweet_text: str, initial_top_k_uids: list, paper_info: pd.Series, model: SentenceTransformer) -> list:
    """
    Re-ranks a list of candidate paper UIDs for a given tweet using a SentenceTransformer model.

    Args:
        tweet_text: The text of the query tweet.
        initial_top_k_uids: A list of paper UIDs from the initial ranker.
        paper_info: A pandas Series mapping CORD UIDs to abstract texts.
        model: The loaded SentenceTransformer model for encoding.

    Returns:
        A list of re-ranked paper UIDs sorted by relevance score (descending).
        Returns an empty list if no valid candidates are available or re-ranking fails.
    """

    candidate_document_texts = []
    candidate_uids = []
    for uid in initial_top_k_uids:
        candidate_document_texts.append(get_document_string(paper_info, uid))
        candidate_uids.append(uid)


    try:
        # Encode the Tweet and Candidate Documents
        # Ensure inputs are on the same device as the model
        query_embedding = model.encode(tweet_text, convert_to_tensor=True, show_progress_bar=False)
        candidate_embeddings = model.encode(candidate_document_texts, convert_to_tensor=True, show_progress_bar=False)

        # Calculate Similarity Scores (Dot Product)
        query_embedding = query_embedding.unsqueeze(0) # Ensure 2D
        scores = util.cos_sim(query_embedding, candidate_embeddings)[0] # Get scores for single query

        # Pair UIDs with Scores and Sort
        score_uid_pairs = sorted(zip(scores.tolist(), candidate_uids), key=lambda x: x[0], reverse=True)

        # Return the re-ranked list of UIDs
        return [uid for score, uid in score_uid_pairs]

    except Exception as e:
        print(f"Error during encoding or scoring for tweet: '{tweet_text[:50]}...' - {e}")
        return [] # Return empty list on error

In [33]:
tweet_info = tweet_info_dev.iloc[7]
query_text = tweet_info['tweet_text']
print(tweet_info['cord_uid'])
initial_top_k_uids = tweet_info['tfidf_topk'][:10]
print(initial_top_k_uids)
reranked_list = rerank_tweet(query_text, initial_top_k_uids, paper_info, model_from_hub)
print(reranked_list)

yoiq6cgt
['yoiq6cgt', 'xurnbrod', '8uxntauq', 'w7o2r4g1', '4gr6i8rf', 'vmmztj0a', 'ueb7mjnv', 'ecobfbpg', 'ii0ceksc', 'jp3pijw2']
['yoiq6cgt', '8uxntauq', 'jp3pijw2', 'ueb7mjnv', 'vmmztj0a', 'ecobfbpg', 'ii0ceksc', 'xurnbrod', '4gr6i8rf', 'w7o2r4g1']


In [34]:
def rerank_tweets(model_from_hub, only_rerank_top_k=100):
    # --- Prepare DataFrame to store re-ranked results ---
    # Create a new column in tweet_info_dev
    tweet_info_dev['reranked_uids'] = None


    # --- Iterate and Re-rank for each tweet ---
    print(f"\nStarting re-ranking for {len(tweet_info_dev)} tweets ...")

    for index, row in tweet_info_dev.iterrows():
        query_text = row['tweet_text']
        initial_top_k_uids = row['tfidf_topk'][:only_rerank_top_k]

        # Call the rerank_tweet function
        reranked_list = rerank_tweet(query_text, initial_top_k_uids, paper_info, model_from_hub)

        # Store the re-ranked list in the DataFrame
        tweet_info_dev.at[index, 'reranked_uids'] = reranked_list

        # Optional: Print progress
        if (index + 1) % 100 == 0:
            print(f"Processed {index + 1}/{len(tweet_info_dev)} tweets.")


    print(f"\nRe-ranking complete for all tweets.")

    # --- The 'reranked_uids' column in tweet_info_dev now contains the results ---
    # You can access and evaluate tweet_info_dev['reranked_uids']
    print(tweet_info_dev['reranked_uids'][:5])

In [42]:
rerank_tweets(model_from_hub, only_rerank_top_k=10)


Starting re-ranking for 1400 tweets ...
Processed 100/1400 tweets.
Processed 200/1400 tweets.
Processed 300/1400 tweets.
Processed 400/1400 tweets.
Processed 500/1400 tweets.
Processed 600/1400 tweets.
Processed 700/1400 tweets.
Processed 800/1400 tweets.
Processed 900/1400 tweets.
Processed 1000/1400 tweets.
Processed 1100/1400 tweets.
Processed 1200/1400 tweets.
Processed 1300/1400 tweets.
Processed 1400/1400 tweets.

Re-ranking complete for all tweets.
0    [hg3xpej0, styavbvi, 59up4v56, bqn29m9k, jwei2...
1    [r58aohnu, s2vckt2w, yrowv62k, icgsbelo, j1ucr...
2    [gruir7aw, vtcq6jgf, sgo76prc, mkwgkkoi, l6kcp...
3    [3sr2exq9, k0f4cwig, z795y51f, sv48gjkk, 8j3bb...
4    [ybwwmyqy, ouvq2wpq, rs3umc1x, sxx3yid9, lzddn...
Name: reranked_uids, dtype: object


In [36]:
# Evaluate retrieved candidates using MRR@k
def get_performance_mrr(data, col_gold, col_pred, list_k=[1, 5, 10]):
    d_performance = {}
    for k in list_k:
        data["in_topx"] = data.apply(
            lambda x: (1 / ([i for i in x[col_pred][:k]].index(x[col_gold]) + 1)
                      if x[col_gold] in [i for i in x[col_pred][:k]] else 0), axis=1)
        d_performance[k] = data["in_topx"].mean()
    return d_performance

In [43]:
df_dev_eval = pd.DataFrame({
    "cord_uid": tweet_info_dev["cord_uid"],
    "topk": tweet_info_dev["reranked_uids"]
})
print(df_dev_eval[:2])

# Evaluate MRR@k
results_def_rerank_roberta = get_performance_mrr(df_dev_eval, 'cord_uid', 'topk')
print(f"Reranking Results on the dev set: {results_def_rerank_roberta}")

   cord_uid                                               topk
0  3qvh482o  [hg3xpej0, styavbvi, 59up4v56, bqn29m9k, jwei2...
1  r58aohnu  [r58aohnu, s2vckt2w, yrowv62k, icgsbelo, j1ucr...
Reranking Results on the dev set: {1: np.float64(0.5671428571428572), 5: np.float64(0.6249642857142857), 10: np.float64(0.6298973922902494)}


Cross-Encoder Model
===

self-contained from here on

In [1]:
# Load data
import pandas as pd
# Read from the JSON file
def load_series_from_json(filename):
    loaded_series = pd.read_json(filename)
    return loaded_series

filename = 'ranked_train'
ranked_train = load_series_from_json(filename)

filename = 'ranked_dev'
ranked_dev = load_series_from_json(filename)

# Import references
PATH_COLLECTION_DATA = '../subtask4b_collection_data.pkl'
df_collection = pd.read_pickle(PATH_COLLECTION_DATA)

paper_info = df_collection.set_index('cord_uid')[['title', 'abstract', 'authors', 'journal']]

tweet_info_train = ranked_train[["tweet_text", "cord_uid", "tfidf_topk"]]
tweet_info_dev = ranked_dev[["post_id", "tweet_text", "cord_uid", "tfidf_topk"]]

In [2]:
# Function for generating the document string
def get_document_string(paper_info, cord_uid):
    
    title = paper_info['title'][cord_uid]
    abstract = paper_info['abstract'][cord_uid]
    authors = paper_info['authors'][cord_uid]
    journal = paper_info['journal'][cord_uid]

    if not (isinstance(abstract, str) and abstract.strip()):
        return 'Abstract missing!'

    if not (isinstance(title, str) and title.strip()):
        # Title is missing, leave it blank
        title = ''

    if not (isinstance(authors, str) and authors.strip()):
        authors = ''

    if not (isinstance(journal, str) and journal.strip()):
        journal = ''

    document_string = '[TITLE]: ' + title + ' [AUTHORS]: ' + authors + ' [JOURNAL]: ' + journal + ' [ABSTRACT]: ' + abstract
    return document_string

In [3]:
from sentence_transformers import CrossEncoder
from sentence_transformers.cross_encoder.losses import BinaryCrossEntropyLoss
from sentence_transformers.cross_encoder.evaluation import CrossEncoderRerankingEvaluator
from sentence_transformers.cross_encoder import CrossEncoderTrainingArguments
from sentence_transformers.cross_encoder import CrossEncoderTrainer

from datasets import Dataset

In [4]:
def build_train_samples(tweet_info_df, paper_info, up_to_top):
    train_samples = {
        'query': [],
        'response': [],
        'label': []
    }

    for index, row in tweet_info_df.iterrows():
        tweet_text = row["tweet_text"]
        correct_cord_uid = row["cord_uid"]
        negative_cord_uids = [uid for uid in row['tfidf_topk'] if uid != correct_cord_uid]

        if correct_cord_uid not in paper_info.index:
            print(f"Warning: Correct data not found for {correct_cord_uid} for tweet at index {index}, skipping tweet.")
            continue

        document_string_pos = get_document_string(paper_info, correct_cord_uid)
        if document_string_pos == 'Abstract missing!':
            print(f"Warning: Abstract missing for {correct_cord_uid} in paper_info, skipping tweet.")
            continue

        negatives = []
        for uid in negative_cord_uids[:up_to_top]:
            doc = get_document_string(paper_info, uid)
            if doc != 'Abstract missing!':
                negatives.append(doc)

        if len(negatives) < up_to_top:
            print(f"Warning: Not enough negative documents for {correct_cord_uid}, skipping tweet.")
            continue

        # --- For training: individual examples ---
        # Append positive sample
        train_samples['query'].append(tweet_text)
        train_samples['response'].append(document_string_pos)
        train_samples['label'].append(1.0)

        # Append negative samples
        for neg_doc in negatives:
            train_samples['query'].append(tweet_text)
            train_samples['response'].append(neg_doc)
            train_samples['label'].append(0.0)
        # --- --- --- --- --- --- --- --- --- --- ---

    return train_samples

def build_eval_samples(tweet_info_df, paper_info):
    eval_samples = []

    for index, row in tweet_info_df.iterrows():
        tweet_text = row["tweet_text"]
        correct_cord_uid = row["cord_uid"]
        negative_cord_uids = [uid for uid in row['tfidf_topk'] if uid != correct_cord_uid]

        if correct_cord_uid not in paper_info.index:
            print(f"Warning: Correct data not found for {correct_cord_uid} for tweet at index {index}, skipping tweet.")
            continue

        document_string_pos = get_document_string(paper_info, correct_cord_uid)
        if document_string_pos == 'Abstract missing!':
            print(f"Warning: Abstract missing for {correct_cord_uid} in paper_info, skipping tweet.")
            continue

        negatives = []
        for uid in negative_cord_uids:
            doc = get_document_string(paper_info, uid)
            if doc != 'Abstract missing!':
                negatives.append(doc)

        # --- For RerankingEvaluator: group by query ---
        eval_samples.append({
            'query': tweet_text,
            'positive': [document_string_pos],
            'negative': negatives
            })
        # --- --- --- --- --- --- --- --- --- --- ---
    return eval_samples


def get_data(frac_train=1, frac_eval=1, up_to_top=5):
    
    if frac_train > 0 and frac_train < 1:
        print(f"Sampling {frac_train * 100}% of the training data ...")
        tweet_info_train_sample = tweet_info_train.sample(frac=frac_train, random_state=42)
        
    else:
        print("Using the full training data ...")
        tweet_info_train_sample = tweet_info_train

    if frac_eval > 0 and frac_eval < 1:
        print(f"Sampling {frac_eval * 100}% of the evaluation data ...")
        tweet_info_dev_sample = tweet_info_dev.sample(frac=frac_eval, random_state=42)
    
    else:
        print("Using the full evaluation data ...")
        tweet_info_dev_sample = tweet_info_dev
    
    train_samples = build_train_samples(tweet_info_train_sample, paper_info, up_to_top)
    eval_samples = build_eval_samples(tweet_info_dev_sample, paper_info)
    
    print(f"Training examples: {len(train_samples['query'])} | Evaluation queries: {len(eval_samples)}")
    return train_samples, eval_samples

In [5]:
model_name = 'cross-encoder/ms-marco-MiniLM-L6-v2'
model = CrossEncoder(model_name, num_labels=1)

In [15]:
name = 'crossencoder-reranker-ms-marco-MiniLM-L6-v2'
batch_size = 350    # 350 optimal for RTX 4090

In [None]:
# === Hyperparameters ===

# Fraction of the set of tweets to use for training
frac_train = 1
# Fraction of the set of tweets to use for evaluating
frac_eval = 1

# How much of the top k negatives to use as negative samples
up_to_top = 50

num_epochs = 5
warmup_ratio = 0.1

# Setup
train_samples, eval_samples = get_data(frac_train=frac_train, frac_eval=frac_eval, up_to_top=up_to_top)
print("\nFirst eval sample:\n", eval_samples[:1], "\n")
train_samples = Dataset.from_dict(train_samples)
print(train_samples)

# Define our training loss.
loss = BinaryCrossEntropyLoss(
    model=model
)

# Evaluator
evaluator = CrossEncoderRerankingEvaluator(
    samples=eval_samples,
    name='dev_reranking',
    batch_size=batch_size,
    show_progress_bar=True
)

# Training Args
args = CrossEncoderTrainingArguments(
    # Required parameter:
    output_dir='output/' + name,
    # Optional training parameters:
    num_train_epochs=num_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=2e-5,
    warmup_ratio=warmup_ratio,
    fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=True,  # Set to True if you have a GPU that supports BF16
    # Optional tracking/debugging parameters:
    eval_strategy="epoch",
    eval_steps=None,
    save_strategy="epoch",
    save_steps=None,
    save_total_limit=num_epochs,
    logging_steps=100,
    run_name=name,  # Will be used in W&B if `wandb` is installed
    seed=12
)

Using the full training data ...
Using the full evaluation data ...
Training examples: 655503 | Evaluation queries: 1400

First eval sample:
 [{'query': 'covid recovery: this study from the usa reveals that a proportion of cases experience impairment in some cognitive functions for several months after infection. some possible biases &amp; limitations but more research is required on impact of these long term effects.', 'positive': ['[TITLE]: Assessment of Cognitive Function in Patients After COVID-19 Infection [AUTHORS]: Becker, Jacqueline H.; Lin, Jenny J.; Doernberg, Molly; Stone, Kimberly; Navis, Allison; Festa, Joanne R.; Wisnivesky, Juan P. [JOURNAL]: JAMA Netw Open [ABSTRACT]: This cross-sectional study examines rates of cognitive impairment among patients who survived COVID-19 and whether the care setting was associated with cognitive impairment rates.'], 'negative': ['[TITLE]: Long covid-mechanisms, risk factors, and management. [AUTHORS]: Crook, Harry; Raza, Sanara; Nowell, J

KeyboardInterrupt: 

In [7]:
trainer = CrossEncoderTrainer(
    model=model,
    args=args,
    train_dataset=train_samples,
    eval_dataset=None,
    loss=loss,
    evaluator=evaluator
)

trainer.train()

Token indices sequence length is longer than the specified maximum sequence length for this model (623 > 512). Running this sequence through the model will result in indexing errors


Epoch,Training Loss,Validation Loss,Dev Reranking Map,Dev Reranking Mrr@10,Dev Reranking Ndcg@10
1,0.0494,No log,0.723702,0.717599,0.750762
2,0.0392,No log,0.726664,0.720466,0.753783
3,0.0392,No log,0.714428,0.707654,0.740395


                                                                       

KeyboardInterrupt: 

Save to hugging_face

In [6]:
# Save the trained model on hugging face (in a new repo)
from huggingface_hub import create_repo, upload_folder

def create_repo_on_huggingface(repo_id_str):
    try:
        repo_url = create_repo(repo_id=repo_id_str, exist_ok=True, private=True)
        print(f"Created or found repository on Hugging Face Hub: {repo_url}")
        # create_repo returns the URL of the repository, not the repo_id string.
        # Let's keep the repo_id string for upload_folder
        repo_id = repo_id_str

    except TypeError as e:
        print(f"Error creating repository: {e}")
        print("It seems your huggingface_hub library version is incompatible.")
        print("Please update it: pip install -U huggingface_hub")
    except Exception as e:
        print(f"An unexpected error occurred while creating the repository: {e}")
    return repo_id

# Uploads the model to hugging face
def upload_model_to_huggingface(local_folder_path, repo_id):
    # Path to your local directory containing the trained model files
    print(f"Uploading files from {local_folder_path} to {repo_id}...")
    upload_folder(
        folder_path=local_folder_path,
        repo_id=repo_id,
        repo_type='model', # Specify the type of repository
        commit_message='Upload final model from checkpoint',
    )
    print("Upload complete!")

In [7]:
repo_id_str = 'LukasXperiaZ/' + name
repo_id = create_repo_on_huggingface(repo_id_str)

Created or found repository on Hugging Face Hub: https://huggingface.co/LukasXperiaZ/crossencoder-reranker-ms-marco-MiniLM-L6-v2


In [10]:
local_folder_path = 'output/' + name + '/checkpoint-3746'
upload_model_to_huggingface(local_folder_path, repo_id)

Uploading files from output/crossencoder-reranker-ms-marco-MiniLM-L6-v2/checkpoint-3746 to LukasXperiaZ/crossencoder-reranker-ms-marco-MiniLM-L6-v2...


training_args.bin:   0%|          | 0.00/6.10k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

optimizer.pt:   0%|          | 0.00/182M [00:00<?, ?B/s]

scheduler.pt:   0%|          | 0.00/1.47k [00:00<?, ?B/s]

rng_state.pth:   0%|          | 0.00/14.6k [00:00<?, ?B/s]

Upload 5 LFS files:   0%|          | 0/5 [00:00<?, ?it/s]

Upload complete!


In [8]:
model_from_hub = CrossEncoder(repo_id)
print(f"Model loaded successfully from Hugging Face Hub: {repo_id}")

Model loaded successfully from Hugging Face Hub: LukasXperiaZ/crossencoder-reranker-ms-marco-MiniLM-L6-v2


Evaluation
===

In [40]:
evaluator(model_from_hub)

                                                                       

{'dev_reranking_map': 0.7266636129050209,
 'dev_reranking_mrr@10': 0.7204657029478458,
 'dev_reranking_ndcg@10': 0.7537827224501197}

Use top 500 for evaluation
===
(worse)

In [49]:
filename = 'ranked_dev_500'
ranked_dev_500 = load_series_from_json(filename)

tweet_info_dev = ranked_dev_500[["post_id", "tweet_text", "cord_uid", "tfidf_topk"]]
print(len(tweet_info_dev['tfidf_topk'][0]))
eval_samples = build_eval_samples(tweet_info_dev, paper_info)

500


In [51]:
evaluator = CrossEncoderRerankingEvaluator(
    samples=eval_samples,
    name='dev_reranking_500',
    batch_size=batch_size,
    show_progress_bar=True
)
evaluator(model_from_hub)

                                                                       

{'dev_reranking_500_map': 0.6856913593217681,
 'dev_reranking_500_mrr@10': 0.6802120181405896,
 'dev_reranking_500_ndcg@10': 0.7054858550615047}

Evaluate Train Samples
---

In [20]:
train_samples = build_eval_samples(tweet_info_train, paper_info)
print(train_samples[:1])



In [21]:
evaluator = CrossEncoderRerankingEvaluator(
    samples=train_samples,
    name='train_reranking',
    batch_size=batch_size,
    show_progress_bar=True
)
evaluator(model_from_hub)

                                                                          

{'train_reranking_map': 0.8089165473727975,
 'train_reranking_mrr@10': 0.8050934622143678,
 'train_reranking_ndcg@10': 0.8333077189432733}

Export results to prepare the submission on Codalab
---

In [10]:
filename = 'ranked_test'
ranked_test = load_series_from_json(filename)

tweet_info_test = ranked_test[["post_id", "tweet_text", "tfidf_topk"]]

In [11]:
print(tweet_info_test)

      post_id                                         tweet_text  \
0           1  A recent research study published yesterday cl...   
1           2  "We should track the long-term effects of thes...   
2           3        the agony of "long haul" covid-19 symptoms.   
3           4  Home and online monitoring and assessment of b...   
4           5  it may be a long one, folks! to avoid exceedin...   
...       ...                                                ...   
1441     1442  Clinical presentations, predisposing factors, ...   
1442     1443  risk factors for post-covid-19 condition in ho...   
1443     1444       do not assume children are less susceptible.   
1444     1445  eurosurveillance | estimated number of fatalit...   
1445     1446  breaking update: hydroxychloroquine still does...   

                                             tfidf_topk  
0     [nswj8x43, j0bu0upi, 41jqgsv0, bttme4wn, ix4zo...  
1     [evf9nz05, nie9mud9, 5vp2r2bd, mnsm39a8, z0hy5...  
2     [tr

In [12]:
def rerank_row_for_submission(row: pd.Series, model: CrossEncoder, paper_info: pd.Series, batch_size: int) -> tuple[str, str]:
    post_id = row['post_id']
    tweet_text = row['tweet_text']
    initial_top_k_uids = row['tfidf_topk']

    # Retrieve document texts for candidates, keeping only valid ones
    valid_candidates = [] # List of (uid, document_string) tuples
    for uid in initial_top_k_uids:
        doc_text = get_document_string(paper_info, uid)
        # Assuming get_document_string returns a non-empty string for valid docs
        # and handles missing abstracts (e.g., returns '' or 'Abstract missing!')
        if doc_text and doc_text != 'Abstract missing!':
                valid_candidates.append((uid, doc_text))
        else:
             print("Warning, abstract is missing! Skipping this tweet!")

    # Separate UIDs and texts for valid candidates
    valid_candidate_uids = [uid for uid, doc_text in valid_candidates]
    valid_candidate_document_texts = [doc_text for uid, doc_text in valid_candidates]

    # Create pairs for Cross-Encoder prediction: [[query, doc1], [query, doc2], ...]
    sentences_to_score = [[tweet_text, doc] for doc in valid_candidate_document_texts]

    # Get scores from the Cross-Encoder model
    try:
        # Use a reasonable batch size for prediction (can be larger than training batch size)
        # Set show_progress_bar=True if you want to see progress during prediction
        scores = model.predict(sentences_to_score, batch_size=batch_size, show_progress_bar=False)
    except Exception as e:
        print(f"Error during Cross-Encoder prediction for tweet {post_id}: {e}")
        return post_id, "" # Return empty prediction string on error

    # Pair scores with the valid candidate UIDs and Sort
    # scores correspond to valid_candidate_uids in the same order
    score_uid_pairs = sorted(zip(scores.tolist(), valid_candidate_uids), key=lambda x: x[0], reverse=True)

    # Get the re-ranked list of UIDs (all of them, in sorted order)
    reranked_uids = [uid for score, uid in score_uid_pairs]

    return post_id, reranked_uids


In [16]:
reranking_results = []

total = len(tweet_info_test)
print(f"Reranking {total} posts")

steps = int(total/10)
i = 1
for index, row in tweet_info_test.iterrows():
    post_id, reranked_uids = rerank_row_for_submission(row, model_from_hub, paper_info, batch_size)
    reranking_results.append({
        'post_id': post_id,
        'preds': reranked_uids[:5]
        })
    
    if i % steps == 0:
        print(f"{i} out of {total} posts reranked.")
    i += 1

df_preds = pd.DataFrame(reranking_results)

Reranking 1446 posts
144 out of 1446 posts reranked.
288 out of 1446 posts reranked.
432 out of 1446 posts reranked.
576 out of 1446 posts reranked.
720 out of 1446 posts reranked.
864 out of 1446 posts reranked.
1008 out of 1446 posts reranked.
1152 out of 1446 posts reranked.
1296 out of 1446 posts reranked.
1440 out of 1446 posts reranked.


In [17]:
print(df_preds['preds'][0])

['x4zuv4jo', '8zufbeuz', 'bv7hvc1e', 'tpic8ddl', 'rbgoabfk']


In [18]:
df_preds.to_csv('predictions.tsv', index=None, sep='\t')