In [27]:
from torch.utils.data import Dataset
from transformers import AutoTokenizer

class ClaimSourceDataset(Dataset):
    def __init__(self, df, collection_df, tokenizer, max_len=256):
        self.df = df
        self.collection = collection_df.set_index('cord_uid')
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        tweet = row['tweet_text']
        paper_id = row['cord_uid']
        paper_row = self.collection.loc[paper_id]
        paper = f"{paper_row['title']} {paper_row['abstract']}"

        tweet_enc = self.tokenizer(tweet, truncation=True, padding='max_length', max_length=self.max_len, return_tensors="pt")
        paper_enc = self.tokenizer(paper, truncation=True, padding='max_length', max_length=self.max_len, return_tensors="pt")

        return {
            'tweet_input_ids': tweet_enc['input_ids'].squeeze(),
            'tweet_attention_mask': tweet_enc['attention_mask'].squeeze(),
            'paper_input_ids': paper_enc['input_ids'].squeeze(),
            'paper_attention_mask': paper_enc['attention_mask'].squeeze()
        }


In [28]:
import torch
import torch.nn as nn
from transformers import AutoModel

class DualEncoder(nn.Module):
    def __init__(self, model_name="scibert_scivocab_uncased"):
        super().__init__()
        self.encoder = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased')


    def forward(self, tweet_ids, tweet_mask, paper_ids, paper_mask):
        tweet_vec = self.encoder(tweet_ids, attention_mask=tweet_mask).last_hidden_state[:, 0]
        paper_vec = self.encoder(paper_ids, attention_mask=paper_mask).last_hidden_state[:, 0]
        return tweet_vec, paper_vec


In [29]:
import torch.nn.functional as F

def contrastive_loss(tweet_vecs, paper_vecs, temperature=0.05):
    tweet_vecs = F.normalize(tweet_vecs, dim=1)
    paper_vecs = F.normalize(paper_vecs, dim=1)

    logits = torch.matmul(tweet_vecs, paper_vecs.T) / temperature
    labels = torch.arange(len(tweet_vecs)).to(tweet_vecs.device)
    return F.cross_entropy(logits, labels)


In [30]:
import numpy as np
import pandas as pd

In [31]:
# 1) Download the collection set from the Gitlab repository: https://gitlab.com/checkthat_lab/clef2025-checkthat-lab/-/tree/main/task4/subtask_4b
# 2) Drag and drop the downloaded file to the "Files" section (left vertical menu on Colab)
# 3) Modify the path to your local file path
PATH_COLLECTION_DATA = 'subtask4b_collection_data.pkl' #MODIFY PATH
df_collection = pd.read_pickle(PATH_COLLECTION_DATA)


In [32]:
PATH_QUERY_TRAIN_DATA = 'subtask4b_query_tweets_train.tsv' #MODIFY PATH
PATH_QUERY_DEV_DATA = 'subtask4b_query_tweets_dev.tsv' #MODIFY PATH

In [33]:
df_query_train = pd.read_csv(PATH_QUERY_TRAIN_DATA, sep = '\t')
df_query_dev = pd.read_csv(PATH_QUERY_DEV_DATA, sep = '\t')

In [34]:
from torch.utils.data import DataLoader
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
model = DualEncoder().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

train_dataset = ClaimSourceDataset(df_query_train, df_collection, tokenizer)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)


for epoch in range(3):
    model.train()
    total_loss = 0
    for batch in tqdm(train_loader):
        tweet_ids = batch['tweet_input_ids'].to(device)
        tweet_mask = batch['tweet_attention_mask'].to(device)
        paper_ids = batch['paper_input_ids'].to(device)
        paper_mask = batch['paper_attention_mask'].to(device)

        tweet_vecs, paper_vecs = model(tweet_ids, tweet_mask, paper_ids, paper_mask)

        loss = contrastive_loss(tweet_vecs, paper_vecs)
   
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        total_loss += loss.item()
    print(f"Epoch {epoch+1} Loss: {total_loss/len(train_loader):.4f}")


100%|██████████| 804/804 [19:13<00:00,  1.43s/it]


Epoch 1 Loss: 0.2893


100%|██████████| 804/804 [19:15<00:00,  1.44s/it]


Epoch 2 Loss: 0.1295


100%|██████████| 804/804 [19:16<00:00,  1.44s/it]

Epoch 3 Loss: 0.0693





In [35]:
import faiss
import numpy as np

def encode_papers(model, df_collection, tokenizer, batch_size=64):
    model.eval()
    paper_texts = df_collection.apply(lambda row: f"{row['title']} {row['abstract']}", axis=1).tolist()
    paper_ids = df_collection['cord_uid'].tolist()

    all_embeddings = []
    with torch.no_grad():
        for i in range(0, len(paper_texts), batch_size):
            batch = paper_texts[i:i+batch_size]
            encodings = tokenizer(batch, padding=True, truncation=True, return_tensors="pt", max_length=256)
            input_ids = encodings['input_ids'].to(device)
            attention_mask = encodings['attention_mask'].to(device)
            vecs = model.encoder(input_ids, attention_mask=attention_mask).last_hidden_state[:, 0]
            all_embeddings.append(vecs.cpu().numpy())

    return paper_ids, np.vstack(all_embeddings)


In [36]:
paper_ids, paper_embeddings = encode_papers(model, df_collection, tokenizer)

faiss_index = faiss.IndexFlatIP(paper_embeddings.shape[1])
faiss_index.add(paper_embeddings)

paper_id_map = {i: pid for i, pid in enumerate(paper_ids)}


In [37]:
def retrieve(model, df_query_dev, tokenizer, faiss_index, paper_id_map, topk=5):
    model.eval()
    predictions = []

    with torch.no_grad():
        for text in df_query_dev['tweet_text']:
            enc = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=256).to(device)
            tweet_vec = model.encoder(enc['input_ids'], attention_mask=enc['attention_mask']).last_hidden_state[:, 0]
            tweet_vec = F.normalize(tweet_vec, dim=1).cpu().numpy()

            D, I = faiss_index.search(tweet_vec, topk)
            preds = [paper_id_map[idx] for idx in I[0]]
            predictions.append(preds)

    df_query_dev['dense_topk'] = predictions


In [38]:
def get_performance_mrr(data, col_gold, col_pred, list_k = [1, 5]):
    d_performance = {}
    for k in list_k:
        data["in_topx"] = data.apply(lambda x: (1/([i for i in x[col_pred][:k]].index(x[col_gold]) + 1) if x[col_gold] in [i for i in x[col_pred][:k]] else 0), axis=1)
        d_performance[k] = data["in_topx"].mean()
    return d_performance

# Evaluate
retrieve(model, df_query_dev, tokenizer, faiss_index, paper_id_map)
results_dev = get_performance_mrr(df_query_dev, 'cord_uid', 'dense_topk', list_k=[1, 5])
print("MRR Results:", results_dev)


MRR Results: {1: 0.5092857142857142, 5: 0.5843452380952381}


In [39]:
df_query_dev['preds'] = df_query_dev['dense_topk'].apply(lambda x: x[:5])
df_query_dev[['post_id', 'preds']].to_csv('predictions_scibert.tsv', index=False, sep='\t')


In [40]:
torch.save(model, "out_scibert.pth")

In [17]:
df_query_train

Unnamed: 0,post_id,tweet_text,cord_uid
0,0,Oral care in rehabilitation medicine: oral vul...,htlvpvz5
1,1,this study isn't receiving sufficient attentio...,4kfl29ul
2,2,"thanks, xi jinping. a reminder that this study...",jtwb17u8
3,3,Taiwan - a population of 23 million has had ju...,0w9k8iy1
4,4,Obtaining a diagnosis of autism in lower incom...,tiqksd69
...,...,...,...
12848,14248,"""evidence on covid-19 reveals a growing body o...",9169o29b
12849,14249,Outdoor lighting has detrimental impacts on lo...,s2bpha8l
12850,14250,"26/ and influenza virus (and other pathogens, ...",atloc9th
12851,14251,does it?'sars-cov-2-naïve vaccinees had a 13.0...,t4y1ylb3
