AIRwaves at CheckThat! 2025: Scientific Claim Source Retrieval
Implementation of a two‐stage IR pipeline for CLEF 2025 Task 4b:
1) Sparse + dense candidate generation  
2) Neural re‐ranking with MonoT5 and BERT variants  

In [None]:
import os, ast, random
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from tqdm.auto import tqdm
from transformers import T5ForConditionalGeneration, T5TokenizerFast
from sentence_transformers import SentenceTransformer, util

# Deterministic seeds: fix every RNG the pipeline touches so runs are repeatable.
SEED = 42
# With deterministic algorithms enabled, cuBLAS-backed CUDA ops (CUDA >= 10.2)
# raise a RuntimeError unless a workspace configuration is present in the
# environment before cuBLAS is first used. setdefault keeps any value the user
# has already exported.
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
# Make cuDNN pick deterministic kernels and skip its autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
try:
    torch.use_deterministic_algorithms(True)
except AttributeError:
    # Older PyTorch releases do not expose this switch; the cuDNN flags above
    # are the best effort available there.
    pass

## 1) Metrics

In [None]:
def compute_list_metrics(pred_lists, refs, ks=(1, 5)):
    """Compute MRR@k and Recall@k for ranked prediction lists.

    Args:
        pred_lists: One ranked list of predicted ids per query, best first.
        refs: One collection of gold ids per query, aligned with ``pred_lists``.
        ks: Cut-offs to evaluate.

    Returns:
        Dict with keys ``"MRR@{k}"`` and ``"Recall@{k}"`` for every ``k`` in
        ``ks``. All values are ``0.0`` when there are no queries.
    """
    ranks = []
    for pred, gold in zip(pred_lists, refs):
        # 1-based rank of the first hit. A miss gets an infinite rank so that
        # it can never count as a hit, even when the prediction list is
        # shorter than k.
        rank = next(
            (pos for pos, pid in enumerate(pred, start=1) if pid in gold),
            np.inf,
        )
        ranks.append(rank)
    ranks = np.asarray(ranks, dtype=float)

    res = {}
    for k in ks:
        if ranks.size == 0:
            # Avoid NaN (and a RuntimeWarning) from averaging an empty array.
            res[f"MRR@{k}"] = 0.0
            res[f"Recall@{k}"] = 0.0
            continue
        hits = ranks <= k
        # 1/inf == 0.0, so misses contribute nothing; ranks are always >= 1.
        reciprocal = np.where(hits, 1.0 / ranks, 0.0)
        res[f"MRR@{k}"] = float(reciprocal.mean())
        res[f"Recall@{k}"] = float(hits.mean())
    return res


## 2) Reranking Functions

In [None]:
# MonoT5
def score_pair(model, tok, query, doc, device):
    """Score a single query-document pair with a seq2seq relevance model.

    The pair is rendered as ``Query: ... Document: ... Relevant:`` (truncated
    to 512 tokens) and the model is run with the tokenization of ``"true"`` as
    labels.

    Returns:
        The negated ``loss`` reported by the model as a Python float, so a
        higher value means the model finds ``"true"`` more likely.
    """
    prompt = f"Query: {query} Document: {doc} Relevant:"
    with torch.no_grad():
        features = tok(
            prompt, return_tensors="pt", truncation=True, max_length=512
        ).to(device)
        target_ids = tok("true", return_tensors="pt").input_ids.to(device)
        loss = model(**features, labels=target_ids).loss
    return -loss.item()

def rerank_monot5(model, tok, query, cands, id2text, device, top_k=5, batch_size=16):
    """Rerank candidate ids by their pairwise relevance score.

    Candidates are walked in chunks of ``batch_size``, but each pair is still
    scored individually through ``score_pair``; the chunk size therefore does
    not change the scores.

    Returns:
        The ``top_k`` candidate ids with the highest scores, best first. Ties
        keep their incoming order.
    """
    scored = [
        (pid, score_pair(model, tok, query, id2text[pid], device))
        for start in range(0, len(cands), batch_size)
        for pid in cands[start:start + batch_size]
    ]
    ranked = sorted(scored, key=lambda item: item[1], reverse=True)
    return [pid for pid, _ in ranked[:top_k]]

# Bi‐encoder fallback
def fallback_candidates(query, gold, bi, emb, ids, device, top_k=5):
    """Build a candidate list from bi-encoder similarity and add the gold id.

    Args:
        query: Query text to encode.
        gold: Gold document id; appended to the result if not already retrieved.
        bi: Encoder exposing ``encode(..., convert_to_tensor=True)``.
        emb: Corpus embedding matrix, one row per entry of ``ids``.
        ids: Document ids aligned with the rows of ``emb``.
        device: Device the query vector is moved to before scoring.
        top_k: Number of nearest documents to retrieve.

    Returns:
        Up to ``top_k`` ids ordered by decreasing cosine similarity, followed
        by ``gold`` when it was not among them.

    NOTE(review): appending ``gold`` feeds label information into the candidate
    list, so any metric computed from the full returned list is an oracle number.
    """
    qv = bi.encode(query, convert_to_tensor=True, normalize_embeddings=True).to(device)
    sims = util.cos_sim(qv, emb)[0]
    # torch.topk raises when k exceeds the number of scored documents, so cap
    # it for corpora smaller than top_k.
    k = min(top_k, sims.shape[0])
    idxs = torch.topk(sims, k=k).indices.cpu().tolist()
    c = [ids[i] for i in idxs]
    if gold not in c:
        c.append(gold)
    return c

# MonoT5‐3B batched
def rerank_batched(model, tok, query, cands, id2text, device, top_k=5, batch_size=16, alpha=0.7):
    """Rerank the first ``top_k`` candidates with batched seq2seq scoring.

    Only ``cands[:top_k]`` is scored, so the result is a reordering of that
    window. For each candidate the log-probability of the first token of
    ``"true"`` at the first decoding step is min-max normalised over the
    window and blended with a prior that decays linearly from 1 to 0 along the
    incoming order: ``alpha * model_score + (1 - alpha) * prior``.

    Args:
        model: Encoder-decoder model returning ``.logits`` and exposing
            ``config.decoder_start_token_id``.
        tok: Tokenizer matching ``model``.
        query: Query text.
        cands: Candidate ids in their incoming (first-stage) order.
        id2text: Mapping from candidate id to document text.
        device: Device for the model inputs.
        top_k: Size of the window that is reranked and returned.
        batch_size: Number of prompts per forward pass.
        alpha: Weight of the model score in the blend.

    Returns:
        The window's ids sorted by blended score, best first; ``[]`` when the
        window is empty.
    """
    window = cands[:top_k]
    if len(window) == 0:
        # Nothing to score; min/max over an empty array would raise.
        return []

    inputs = [f"Query: {query} Document: {id2text[p]} Relevant:" for p in window]
    # Use the id the tokenizer actually emits for "true" instead of looking up
    # the bare string in the vocabulary. Loop-invariant, so resolved once.
    true_id = tok.encode("true", add_special_tokens=False)[0]
    start_id = model.config.decoder_start_token_id

    raw = []
    for i in range(0, len(inputs), batch_size):
        enc = tok(
            inputs[i:i + batch_size],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        ).to(device)
        # An encoder-decoder forward pass needs decoder inputs; feed only the
        # start token so that position 0 of the logits is the first
        # generation step.
        decoder_input_ids = torch.full(
            (enc["input_ids"].shape[0], 1), start_id, dtype=torch.long, device=device
        )
        with torch.no_grad():
            logits = model(**enc, decoder_input_ids=decoder_input_ids).logits
        raw.extend(torch.log_softmax(logits, -1)[:, 0, true_id].cpu().tolist())

    arr = np.array(raw)
    mn, mx = arr.min(), arr.max()
    # The epsilon keeps the division finite when all scores are equal.
    norm = (arr - mn) / (mx - mn + 1e-8)
    base = np.linspace(1, 0, len(window))
    final = alpha * norm + (1 - alpha) * base
    paired = list(zip(window, final))
    paired.sort(key=lambda x: x[1], reverse=True)
    return [p for p, _ in paired[:top_k]]

## 3) Load Data & Models

In [None]:
# Input locations: the document collection (pickle) and one query TSV per split.
col_path    = Path("/home/fs72760/nikitaz/data/subtask4b_collection_data.pkl")
train_q     = Path("/home/fs72760/nikitaz/data/subtask4b_query_tweets_train.tsv")
dev_q       = Path("/home/fs72760/nikitaz/data/subtask4b_query_tweets_dev.tsv")
test_q      = Path("/home/fs72760/nikitaz/data/subtask4b_query_tweets_test.tsv")

# Precomputed ranked candidate lists per split (read below, never written here).
train_p     = Path("/home/fs72760/nikitaz/predictions/train_ranked_preds.tsv")
dev_p       = Path("/home/fs72760/nikitaz/predictions/dev_ranked_preds.tsv")
test_p      = Path("/home/fs72760/nikitaz/predictions/test_ranked_preds.tsv")

# Load the collection and build one text field per paper: title + abstract,
# with missing values treated as empty strings.
# NOTE(review): unpickling can execute arbitrary code; only load a trusted file.
papers = pd.read_pickle(col_path)
papers["text"] = (papers["title"].fillna("")+" "+papers["abstract"].fillna("")).str.strip()
# cord_uid -> text lookup, and the ids in the collection's row order.
id2text = dict(zip(papers.cord_uid, papers.text))
paper_ids = papers.cord_uid.tolist()

# Load queries and the candidate predictions, indexed by post_id.
train_df = pd.read_csv(train_q, sep="\t")
dev_df   = pd.read_csv(dev_q,   sep="\t")
test_df  = pd.read_csv(test_q,  sep="\t")
pred_train = pd.read_csv(train_p, sep="\t", index_col="post_id")
pred_dev   = pd.read_csv(dev_p,   sep="\t", index_col="post_id")
pred_test  = pd.read_csv(test_p,  sep="\t", index_col="post_id")
# Column holding the candidate list: "preds" when present, else the first column.
# NOTE(review): decided from the dev file only — confirm train/test share the layout.
pred_col   = "preds" if "preds" in pred_dev.columns else pred_dev.columns[0]

# Embed corpus with bi-encoder: prefer the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
bi = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=device)
def encode_corpus(texts, model, bs=64):
    """Encode ``texts`` in slices of ``bs`` and return one stacked tensor.

    Each slice is encoded with normalised embeddings; the slices are
    concatenated in order and the result is moved to the module-level
    ``device``.
    """
    starts = range(0, len(texts), bs)
    chunks = [
        model.encode(texts[lo:lo + bs], convert_to_tensor=True, normalize_embeddings=True)
        for lo in tqdm(starts, desc="Encode corpus")
    ]
    return torch.cat(chunks).to(device)
# Embeddings for every paper text, in the same row order as `papers`.
paper_emb = encode_corpus(papers.text.tolist(), bi)

# Load rerankers: one tokenizer is shared by both checkpoints; both models are
# moved to `device` and put in eval mode.
tok = T5TokenizerFast.from_pretrained("t5-base")
m1  = T5ForConditionalGeneration.from_pretrained("castorini/monot5-base-msmarco-10k").to(device).eval()
m2  = T5ForConditionalGeneration.from_pretrained("castorini/monot5-3b-msmarco").to(device).eval()

## 4) Evaluate on Train / Dev / Test

In [None]:
# For every split: rerank the precomputed candidates with each model and print
# MRR/Recall at the default cut-offs.
# NOTE(review): each split's query frame must provide a `cord_uid` column —
# confirm this holds for the test file.
for name, df, preds in [("Train", train_df, pred_train),
                        ("Dev",   dev_df,   pred_dev),
                        ("Test",  test_df,  pred_test)]:
    # One single-element gold list per query, in the frame's row order.
    refs = [[r] for r in df.cord_uid]
    print(f"\n=== {name} ===")
    # MonoT5-base: rerank the stored candidate list as-is.
    p1=[]
    for _,row in tqdm(df.iterrows(), total=len(df), desc="MonoT5-base"):
        qid, q = row.post_id, row.tweet_text
        # The prediction cell is parsed as a Python literal (presumably a list
        # of cord_uids — verify against the prediction files).
        cands = ast.literal_eval(preds.at[qid,pred_col])
        p1.append(rerank_monot5(m1, tok, q, cands, id2text, device))
    print("Base:", compute_list_metrics(p1, refs))
    # MonoT5-3B: same candidates, unless the gold id is missing from them.
    p2=[]
    for _,row in tqdm(df.iterrows(), total=len(df), desc="MonoT5-3B"):
        qid, q, gold = row.post_id, row.tweet_text, row.cord_uid
        cands = ast.literal_eval(preds.at[qid,pred_col])
        # NOTE(review): this branch depends on the gold label, so the 3B numbers
        # are not directly comparable with the base run above, which has no
        # fallback.
        if gold not in cands:
            cands = fallback_candidates(q, gold, bi, paper_emb, paper_ids, device)
        p2.append(rerank_batched(m2, tok, q, cands, id2text, device))
    print("Batched:", compute_list_metrics(p2, refs))