In [1]:
import pandas as pd
# Location of the pickled paper collection; point this at the local copy.
# NOTE(review): unpickling can execute arbitrary code, so only load this file
# when it comes from a trusted source.
PATH_COLLECTION_DATA = 'subtask4b_collection_data.pkl' #MODIFY PATH
df_collection = pd.read_pickle(PATH_COLLECTION_DATA)

In [2]:
# Tab-separated query-tweet files, one per split; point these at the local copies.
PATH_QUERY_TRAIN_DATA = 'subtask4b_query_tweets_train.tsv' #MODIFY PATH
PATH_QUERY_DEV_DATA = 'subtask4b_query_tweets_dev.tsv' #MODIFY PATH
PATH_QUERY_DEV_TEST = 'subtask4b_query_tweets_test.tsv' #MODIFY PATH

# Read the three splits in a single pass (order: train, dev, test).
df_query_train, df_query_dev, df_query_test = (
    pd.read_csv(split_path, sep='\t')
    for split_path in (PATH_QUERY_TRAIN_DATA, PATH_QUERY_DEV_DATA, PATH_QUERY_DEV_TEST)
)

In [3]:
import pandas as pd
import random
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments


# Checkpoint used for the cross-encoder re-ranker; the commented-out line is an
# alternative checkpoint that can be swapped in.
model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
#model_name = "Alibaba-NLP/gte-reranker-modernbert-base"
tokenizer_cross_encoder = AutoTokenizer.from_pretrained(model_name)
# num_labels=1: the sequence-classification head is configured with one output.
cross_encoder = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=1,
)

# --- DATA PREPARATION ---
# Inner join on cord_uid: every training tweet is paired with the metadata of
# the paper it refers to; tweets without a matching paper are dropped.
df_merged = df_query_train.merge(df_collection, on='cord_uid', how='inner')

def format_paper(row):
    """Render one paper record as ``"<title> [SEP] <abstract>"``.

    Args:
        row: Mapping-like object (e.g. a DataFrame row) exposing the string
            fields ``'title'`` and ``'abstract'``.

    Returns:
        The title and abstract, each stripped of surrounding whitespace and
        joined by the literal separator ``" [SEP] "``.
    """
    title = row['title'].strip()
    abstract = row['abstract'].strip()
    return f"{title} [SEP] {abstract}"

# Text fed to the cross-encoder for each matched paper.
df_merged['paper_text'] = df_merged.apply(format_paper, axis=1)

# Positive examples: every (tweet, cited paper) pair from the merge gets label 1.0.
positive_samples = [
    {"tweet": tweet_text, "paper": paper_text, "label": 1.0}
    for tweet_text, paper_text in zip(df_merged['tweet_text'], df_merged['paper_text'])
]

# Negative examples
NEGATIVES_PER_TWEET = 3  # random wrong papers paired with each training tweet
SAMPLING_SEED = 42       # fixes the negative draws and the final shuffle

all_paper_ids = set(df_collection['cord_uid'])
# Distinct ids in collection order. Sampling from a list built out of the set
# would not be reproducible: the iteration order of a set of strings changes
# between interpreter runs. Building the list once also avoids re-creating a
# set difference and a list for every single tweet.
candidate_ids = list(dict.fromkeys(df_collection['cord_uid']))
assert len(candidate_ids) > 1, "need at least two distinct papers to draw a negative"

paper_text_lookup = {
    row['cord_uid']: format_paper(row) for _, row in df_collection.iterrows()
}

random.seed(SAMPLING_SEED)
negative_samples = []
for _ in range(NEGATIVES_PER_TWEET):
    for tweet, correct_id in zip(df_query_train['tweet_text'], df_query_train['cord_uid']):
        # Rejection sampling: redraw on the rare hit of the true paper. This is
        # still a uniform draw over every paper except the correct one.
        negative_id = random.choice(candidate_ids)
        while negative_id == correct_id:
            negative_id = random.choice(candidate_ids)
        negative_samples.append({
            "tweet": tweet, "paper": paper_text_lookup[negative_id], "label": 0.0
        })

# Combine and shuffle
all_samples = positive_samples + negative_samples
random.shuffle(all_samples)

# --- DATASET CLASS ---
class TweetPaperDataset(Dataset):
    """Torch dataset of (tweet, paper) pairs for cross-encoder fine-tuning.

    Args:
        data: Sequence of dicts with the keys ``"tweet"``, ``"paper"`` and
            ``"label"``.
        tokenizer_cross_encoder: Callable tokenizer that accepts a text pair
            and returns a mapping holding ``"input_ids"`` and
            ``"attention_mask"`` tensors with a leading batch axis of size 1.
        max_length: Length every pair is padded / truncated to.
    """

    def __init__(self, data, tokenizer_cross_encoder, max_length=512):
        self.data = data
        self.tokenizer_cross_encoder = tokenizer_cross_encoder
        self.max_length = max_length

    def __len__(self):
        """Return the number of (tweet, paper) pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize pair ``idx`` and return its model inputs plus a float label.

        Returns:
            Dict with ``"input_ids"``, ``"attention_mask"``, ``"labels"`` and,
            when the tokenizer emits them, ``"token_type_ids"``.
        """
        item = self.data[idx]
        encoded = self.tokenizer_cross_encoder(
            item["tweet"],
            item["paper"],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        # squeeze(0) drops only the batch axis added by return_tensors="pt".
        features = {
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
        }
        # Forward the segment ids too: the re-ranking code in this notebook
        # passes every tokenizer output to the model, so leaving them out here
        # would train on different inputs than the model sees at inference.
        if "token_type_ids" in encoded:
            features["token_type_ids"] = encoded["token_type_ids"].squeeze(0)
        features["labels"] = torch.tensor(item["label"], dtype=torch.float)
        return features

# Create Dataset
train_dataset = TweetPaperDataset(
    data=all_samples,
    tokenizer_cross_encoder=tokenizer_cross_encoder,
)

# --- TRAINING ---
# Hyper-parameters collected in one mapping so they can be inspected or logged
# before the Trainer is built.
training_config = dict(
    output_dir="./fine-tuned-cross-encoder",
    num_train_epochs=3,
    per_device_train_batch_size=32,
    warmup_steps=100,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=50,
    save_strategy="epoch",   # one checkpoint per epoch
    eval_strategy="no",      # no evaluation during training
    report_to="none",        # no external experiment tracker
)
training_args = TrainingArguments(**training_config)

trainer = Trainer(model=cross_encoder, args=training_args, train_dataset=train_dataset)

# Start fine-tuning
trainer.train()


2025-05-09 14:08:56.513084: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746799736.554421    4840 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746799736.567207    4840 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1746799736.593645    4840 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746799736.593665    4840 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746799736.593668    4840 computation_placer.cc:177] computation placer alr

Step,Training Loss
50,14.159
100,0.1833
150,0.1199
200,0.0862
250,0.0653
300,0.0468
350,0.0465
400,0.041
450,0.0387
500,0.0342


KeyboardInterrupt: 

In [4]:
from torch.utils.data import Dataset
from transformers import AutoTokenizer
class ClaimSourceDataset(Dataset):
    """Torch dataset yielding separately tokenized tweets and their cited papers.

    Args:
        df: DataFrame with the columns ``'tweet_text'`` and ``'cord_uid'``.
        collection_df: DataFrame with the columns ``'cord_uid'``, ``'title'``
            and ``'abstract'``; it is re-indexed by ``'cord_uid'`` for lookup.
        tokenizer: Callable tokenizer returning a mapping with ``'input_ids'``
            and ``'attention_mask'`` tensors of shape ``(1, max_len)``.
        max_len: Length both texts are padded / truncated to.
    """

    def __init__(self, df, collection_df, tokenizer, max_len=512):
        self.df = df
        self.collection = collection_df.set_index('cord_uid')
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Return the number of tweet rows."""
        return len(self.df)

    def _encode(self, text):
        """Tokenize one text to exactly ``max_len`` tokens as PyTorch tensors."""
        return self.tokenizer(text, truncation=True, padding='max_length', max_length=self.max_len, return_tensors="pt")

    def __getitem__(self, idx):
        """Return input ids and attention masks for tweet ``idx`` and its paper."""
        row = self.df.iloc[idx]
        tweet = row['tweet_text']
        paper_id = row['cord_uid']
        # NOTE(review): .loc returns a DataFrame rather than a single row if
        # cord_uid is duplicated in the collection — assumes ids are unique.
        paper_row = self.collection.loc[paper_id]
        paper = f"{paper_row['title']} {paper_row['abstract']}"

        tweet_enc = self._encode(tweet)
        paper_enc = self._encode(paper)

        # squeeze(0) removes only the batch axis; a bare squeeze() would also
        # collapse the sequence axis whenever max_len == 1.
        return {
            'tweet_input_ids': tweet_enc['input_ids'].squeeze(0),
            'tweet_attention_mask': tweet_enc['attention_mask'].squeeze(0),
            'paper_input_ids': paper_enc['input_ids'].squeeze(0),
            'paper_attention_mask': paper_enc['attention_mask'].squeeze(0)
        }


In [5]:
import torch
import torch.nn as nn
from transformers import AutoModel

class DualEncoder(nn.Module):
    """Shared-weight encoder that embeds tweets and papers with the same model.

    Args:
        model_name: Hugging Face checkpoint used when ``weights_path`` is None.
        weights_path: Path to a locally pickled encoder object. Defaults to the
            fine-tuned checkpoint this notebook loads; pass ``None`` to load
            ``model_name`` from the hub instead.
    """

    def __init__(self, model_name="allenai/scibert_scivocab_uncased", weights_path='final_scibert_512_5_r.pt'):
        super().__init__()
        if weights_path is not None:
            # SECURITY: the file stores a whole pickled object, and unpickling can
            # execute arbitrary code — only load checkpoints you trust.
            # weights_only=False is required for such files on torch >= 2.6,
            # where torch.load defaults to weights_only=True.
            self.encoder = torch.load(weights_path, weights_only=False)
        else:
            self.encoder = AutoModel.from_pretrained(model_name)

    def forward(self, tweet_ids, tweet_mask, paper_ids, paper_mask):
        """Embed a batch of tweets and a batch of papers.

        Returns:
            ``(tweet_vec, paper_vec)``: the hidden state at sequence position 0
            of ``last_hidden_state`` for each input batch.
        """
        # NOTE(review): other cells reach the transformer through
        # `model.encoder.encoder` when the local checkpoint is loaded, which
        # suggests the pickled object is itself a wrapper — confirm that calling
        # it as below works for that checkpoint.
        tweet_vec = self.encoder(tweet_ids, attention_mask=tweet_mask).last_hidden_state[:, 0]
        paper_vec = self.encoder(paper_ids, attention_mask=paper_mask).last_hidden_state[:, 0]
        return tweet_vec, paper_vec


In [6]:
# Tokenizer for the dense (first-stage) encoder.
SCIBERT_CHECKPOINT = 'allenai/scibert_scivocab_uncased'
tokenizer = AutoTokenizer.from_pretrained(SCIBERT_CHECKPOINT)

In [10]:
import faiss
import numpy as np
from tqdm.auto import tqdm
def encode_papers(model, df_collection, tokenizer, batch_size=32, max_length=256):
    """Embed every paper of the collection with the dense encoder.

    Args:
        model: Module whose transformer is reachable as ``model.encoder.encoder``.
        df_collection: DataFrame with ``'cord_uid'``, ``'title'`` and
            ``'abstract'`` columns.
        tokenizer: Tokenizer matching the encoder.
        batch_size: Number of papers encoded per forward pass.
        max_length: Token limit per paper; longer texts are truncated.

    Returns:
        ``(paper_ids, embeddings)``: the ``cord_uid`` values in collection order
        and a 2-D array with one row (position-0 hidden state) per paper.
    """
    model.eval()
    paper_texts = [
        f"{title} {abstract}"
        for title, abstract in zip(df_collection['title'], df_collection['abstract'])
    ]
    paper_ids = df_collection['cord_uid'].tolist()
    # Send inputs to wherever the weights live instead of relying on a global
    # `device` that may be undefined or may not match the model.
    model_device = next(model.parameters()).device

    all_embeddings = []
    with torch.no_grad():
        for start in tqdm(range(0, len(paper_texts), batch_size)):
            batch = paper_texts[start:start + batch_size]
            encodings = tokenizer(batch, padding=True, truncation=True, return_tensors="pt", max_length=max_length, return_overflowing_tokens=False)
            input_ids = encodings['input_ids'].to(model_device)
            attention_mask = encodings['attention_mask'].to(model_device)
            # remove one .encoder if you download the model and do not use the local weight file
            vecs = model.encoder.encoder(input_ids, attention_mask=attention_mask).last_hidden_state[:, 0]
            all_embeddings.append(vecs.cpu().numpy())

    return paper_ids, np.vstack(all_embeddings)

In [8]:
# Persist the whole cross-encoder object (pickled module, not just a state_dict).
torch.save(obj=cross_encoder, f='ranking.pt')

In [11]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Inputs are moved to `device` before each forward pass, so the encoder has to
# live there as well; without this a model on another device fails on first use.
bert = DualEncoder().to(device)
paper_ids, paper_embeddings = encode_papers(bert, df_collection, tokenizer)

# Exact inner-product index over all paper embeddings.
faiss_index = faiss.IndexFlatIP(paper_embeddings.shape[1])
faiss_index.add(paper_embeddings)

# FAISS row number -> cord_uid, in the order the embeddings were added.
paper_id_map = dict(enumerate(paper_ids))


  0%|          | 0/242 [00:00<?, ?it/s]

In [12]:
import torch.nn.functional as F
from tqdm.auto import tqdm
def retrieve(model, cross_encoder, df_query_dev, tokenizer, tokenizer_cross_encoder, faiss_index, paper_id_map, paper_text_lookup, topk=10, num_candidates=20):
    """Retrieve papers for every tweet: dense search, then cross-encoder re-ranking.

    The ranked ids are written to ``df_query_dev['dense_topk']`` in place; the
    function returns nothing.

    Args:
        model: Dense encoder whose transformer is reachable as
            ``model.encoder.encoder``.
        cross_encoder: Sequence-classification model scoring (tweet, paper) pairs.
        df_query_dev: DataFrame with a ``'tweet_text'`` column; gets the new column.
        tokenizer: Tokenizer of the dense encoder.
        tokenizer_cross_encoder: Tokenizer of the cross-encoder.
        faiss_index: Index holding the paper embeddings.
        paper_id_map: Maps FAISS row numbers to ``cord_uid``.
        paper_text_lookup: Maps ``cord_uid`` to the paper text shown to the
            cross-encoder.
        topk: Number of re-ranked papers kept per tweet.
        num_candidates: Number of first-stage candidates handed to the
            cross-encoder (raised to ``topk`` if smaller).
    """
    model.eval()
    cross_encoder.eval()
    # Each model receives inputs on its own device rather than on a global one.
    encoder_device = next(model.parameters()).device
    reranker_device = next(cross_encoder.parameters()).device
    n_candidates = max(num_candidates, topk)
    predictions = []

    with torch.no_grad():
        for text in tqdm(df_query_dev['tweet_text']):
            # Step 1: embed the tweet with the dense encoder.
            enc = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(encoder_device)
            tweet_vec = model.encoder.encoder(enc['input_ids'], attention_mask=enc['attention_mask']).last_hidden_state[:, 0]
            tweet_vec = F.normalize(tweet_vec, dim=1).cpu().numpy()

            # Step 2: nearest neighbours. FAISS pads with -1 when the index holds
            # fewer than n_candidates vectors, so those slots are skipped.
            _, neighbour_rows = faiss_index.search(tweet_vec, n_candidates)
            candidate_ids = [paper_id_map[idx] for idx in neighbour_rows[0] if idx >= 0]
            if not candidate_ids:
                predictions.append([])
                continue
            candidate_texts = [paper_text_lookup[pid] for pid in candidate_ids]

            # Step 3: Use cross-encoder to score the pairs
            inputs = tokenizer_cross_encoder(
                [text] * len(candidate_texts),
                candidate_texts,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt"
            ).to(reranker_device)

            outputs = cross_encoder(**inputs)

            # squeeze(-1) keeps a list even for a single candidate; a bare
            # squeeze() would return a plain float there and break the zip below.
            scores = outputs.logits.squeeze(-1).tolist()  # one score per candidate

            reranked = sorted(zip(candidate_ids, scores), key=lambda pair: pair[1], reverse=True)
            predictions.append([doc_id for doc_id, _ in reranked[:topk]])

    df_query_dev['dense_topk'] = predictions


In [13]:
# cord_uid -> "title [SEP] abstract" text handed to the cross-encoder.
paper_text_lookup = dict(
    zip(df_collection['cord_uid'], df_collection.apply(format_paper, axis=1))
)
# Rank the test tweets; the result is stored in df_query_test['dense_topk'].
retrieve(
    model=bert,
    cross_encoder=cross_encoder,
    df_query_dev=df_query_test,
    tokenizer=tokenizer,
    tokenizer_cross_encoder=tokenizer_cross_encoder,
    faiss_index=faiss_index,
    paper_id_map=paper_id_map,
    paper_text_lookup=paper_text_lookup,
)

  0%|          | 0/1446 [00:00<?, ?it/s]

In [14]:
# Keep only the five best-ranked papers per tweet, then write the submission file.
df_query_test['preds'] = df_query_test['dense_topk'].map(lambda ranked: ranked[:5])
submission = df_query_test[['post_id', 'preds']]
submission.to_csv('predictions_final_ranking.tsv', sep='\t', index=None)

In [15]:
# Evaluate retrieved candidates using MRR@k
def get_performance_mrr(data, col_gold, col_pred, list_k = [1, 5, 10]):
    """Compute the mean reciprocal rank at each cutoff in ``list_k``.

    Args:
        data: DataFrame with one row per query. A helper column ``"in_topx"``
            is (over)written on it and ends up holding the per-row scores of
            the last cutoff processed.
        col_gold: Column holding the correct document id.
        col_pred: Column holding the ranked sequence of predicted ids.
        list_k: Cutoffs to evaluate.

    Returns:
        Dict mapping each cutoff ``k`` to the mean of the per-row reciprocal
        ranks, where a row scores 0 if the gold id is not in its top ``k``.
    """
    def reciprocal_rank(record, cutoff):
        shortlist = list(record[col_pred][:cutoff])
        gold_id = record[col_gold]
        if gold_id not in shortlist:
            return 0
        return 1 / (shortlist.index(gold_id) + 1)

    mrr_by_cutoff = {}
    for cutoff in list_k:
        data["in_topx"] = data.apply(lambda record: reciprocal_rank(record, cutoff), axis=1)
        mrr_by_cutoff[cutoff] = data["in_topx"].mean()
    return mrr_by_cutoff

# Evaluate
# Rank the dev tweets (stored in df_query_dev['dense_topk']), then score them.
retrieve(
    model=bert,
    cross_encoder=cross_encoder,
    df_query_dev=df_query_dev,
    tokenizer=tokenizer,
    tokenizer_cross_encoder=tokenizer_cross_encoder,
    faiss_index=faiss_index,
    paper_id_map=paper_id_map,
    paper_text_lookup=paper_text_lookup,
)
results_test = get_performance_mrr(data=df_query_dev, col_gold='cord_uid', col_pred='dense_topk')
print("MRR Results:", results_test)

  0%|          | 0/1400 [00:00<?, ?it/s]

MRR Results: {1: 0.5057142857142857, 5: 0.5909999999999999, 10: 0.598437074829932}


In [None]:
# Top 20, weak cross encoder
# MRR Results: {1: 0.5914285714285714, 5: 0.6446071428571429, 10: 0.6499339569160998}
# top 20, loaded cross encoder
# MRR Results: {1: 0.27714285714285714, 5: 0.37324999999999997, 10: 0.3931893424036282}
# not finetuned, top 10
# MRR Results: {1: 0.585, 5: 0.6383928571428571, 10: 0.6441590136054423}

# Top 20, a little more fine-tuning, positive pairs only.
# MRR Results: {1: 0.595, 5: 0.6461071428571429, 10: 0.651310941043084}

# 20 reranking big one
# MRR Results: {1: 0.5057142857142857, 5: 0.5909999999999999, 10: 0.598437074829932}