# Neural Re-ranking Pipeline  
This Jupyter notebook reproduces the re-ranking pipeline script: MonoT5 cross-encoders re-rank the E5 first-stage candidates, and a bi-encoder supplies fallback candidates.

In [2]:
import os
import random
import torch
import numpy as np
from pathlib import Path

# 0) Reproducibility: seed every RNG source the pipeline draws from —
# Python's `random`, NumPy's global RNG and torch (CPU, plus all GPUs below).
SEED = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(SEED)
del _seed_rng

if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    # Turn off cuDNN autotuning and request deterministic cuDNN kernels so
    # repeated GPU runs select the same algorithms.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

print(f"Reproducibility seeded at {SEED}")

Reproducibility seeded at 42


## 1) Metrics

In [3]:
import numpy as np

def compute_list_metrics(pred_lists, refs, ks=(1,5)):
    """Compute MRR@k and Recall@k over ranked prediction lists.

    Args:
        pred_lists: Iterable of ranked id lists, one per query, best first.
        refs: Iterable of gold references aligned with ``pred_lists``. Each
            entry is a collection of relevant ids; a bare string is treated
            as a single id rather than as a collection of characters.
        ks: Cut-offs at which the metrics are reported.

    Returns:
        Dict with ``"MRR@k"`` and ``"Recall@k"`` (floats) for every k in
        ``ks``. Only the first relevant hit in each list counts. A query
        whose list contains no relevant id scores 0 at every cut-off.
        Empty input gives 0.0 for every metric.

    Raises:
        ValueError: If ``pred_lists`` and ``refs`` differ in length.
    """
    pred_lists, refs = list(pred_lists), list(refs)
    if len(pred_lists) != len(refs):
        raise ValueError(
            f"pred_lists and refs differ in length: {len(pred_lists)} != {len(refs)}"
        )
    ranks = []
    for pred, gold in zip(pred_lists, refs):
        if isinstance(gold, str):
            # `pid in "abc"` would be a substring test, not an id match.
            gold = (gold,)
        # 1-based rank of the first relevant id. A miss is ranked at infinity
        # so it can never satisfy `rank <= k`; a finite sentinel such as
        # len(pred) + 1 would count as a hit whenever k > len(pred).
        rank = next((i for i, pid in enumerate(pred, start=1) if pid in gold), np.inf)
        ranks.append(rank)
    ranks = np.asarray(ranks, dtype=float)
    n = ranks.size
    metrics = {}
    for k in ks:
        hits = ranks <= k
        rr = np.where(hits, 1.0 / ranks, 0.0)
        # Guard n == 0: np.mean of an empty array returns NaN with a warning.
        metrics[f"MRR@{k}"] = float(rr.sum() / n) if n else 0.0
        metrics[f"Recall@{k}"] = float(hits.sum() / n) if n else 0.0
    return metrics


# Example usage
print(compute_list_metrics([[1,2,3],[4,5,6]], [[2],[5]]))

{'MRR@1': 0.0, 'Recall@1': 0.0, 'MRR@5': 0.5, 'Recall@5': 1.0}


## 2) MonoT5-base cross-encoder rerank

In [10]:
from transformers import T5ForConditionalGeneration, T5TokenizerFast

def score_pair(model, tokenizer, query, doc, device):
    """Score one (query, document) pair with a MonoT5 cross-encoder.

    The pair is rendered in the MonoT5 prompt format. The model is run with
    the target text "true" as ``labels``, and the score is the negated
    ``out.loss`` for that target, so higher means more relevant.

    Args:
        model: Seq2seq model returning an output with a ``loss`` attribute
            when called with ``labels``.
        tokenizer: Tokenizer matching ``model``.
        query: Query text.
        doc: Document text (the input is truncated to 512 tokens).
        device: Device the model lives on.

    Returns:
        Relevance score as a Python float. It is only meaningful relative to
        other scores from the same model.
    """
    inp = f"Query: {query} Document: {doc} Relevant:"
    tokens = tokenizer(inp, return_tensors="pt", truncation=True, max_length=512).to(device)
    labels = tokenizer("true", return_tensors="pt").input_ids.to(device)
    # Inference only: without no_grad, autograd records the full forward
    # graph (and keeps activations alive) for every scored pair.
    with torch.no_grad():
        out = model(**tokens, labels=labels)
    return -out.loss.item()

def rerank_monot5(model, tokenizer, query, cands, id2text, device, top_k=5, batch_size=16):
    """Re-rank the first ``top_k`` candidates with a MonoT5 cross-encoder.

    Each candidate in the window is scored independently with ``score_pair``.
    The window is then sorted by score, best first; ties keep the incoming
    order because the sort is stable.

    Args:
        model, tokenizer: MonoT5 model and its tokenizer.
        query: Query text.
        cands: Sequence of candidate ids in first-stage order.
        id2text: Mapping from candidate id to document text.
        device: Device the model lives on.
        top_k: Size of the window taken from the front of ``cands``.
        batch_size: Kept for interface compatibility. Pairs are scored one
            at a time, so it does not affect the result.

    Returns:
        The windowed candidate ids, re-ordered by descending score.
    """
    # Bound the window by top_k itself. The previous chunked loop sliced
    # cands[i:i + batch_size] and could score candidates beyond top_k.
    window = cands[:max(top_k, 0)]
    scores = []
    with torch.no_grad():
        for pid in window:
            scores.append((pid, score_pair(model, tokenizer, query, id2text[pid], device)))
    scores.sort(key=lambda x: x[1], reverse=True)
    return [pid for pid, _ in scores]

## 3) Fallback candidates with bi-encoder

In [11]:
from sentence_transformers import SentenceTransformer, util
import torch

def fallback_candidates(query, gold_id, bi_encoder, paper_emb, paper_ids, device, top_k=5):
    """Retrieve bi-encoder candidates and ensure ``gold_id`` is in the list.

    Args:
        query: Query text.
        gold_id: Ground-truth id. It is appended if retrieval missed it.
        bi_encoder: Sentence encoder exposing ``encode``.
        paper_emb: Paper embeddings whose row i corresponds to
            ``paper_ids[i]``. This is assumed; ``main`` builds both from the
            same dataframe.
        paper_ids: Ids aligned with the rows of ``paper_emb``.
        device: Device for the query embedding.
        top_k: Number of nearest papers to retrieve. It is clamped to the
            collection size.

    Returns:
        Up to ``top_k`` ids in descending cosine similarity. When ``gold_id``
        was not retrieved, it is appended at the end, so the list may hold
        ``top_k + 1`` ids.

    NOTE(review): appending ``gold_id`` puts the label into the candidate
    pool. Any evaluation built on this list should account for that.
    """
    q_emb = bi_encoder.encode(query, convert_to_tensor=True, normalize_embeddings=True).to(device)
    cos_scores = util.cos_sim(q_emb, paper_emb)[0]
    # torch.topk raises if k exceeds the number of elements, so clamp it.
    k = max(0, min(top_k, cos_scores.shape[0]))
    top_idxs = torch.topk(cos_scores, k=k).indices.cpu().tolist()
    cands = [paper_ids[i] for i in top_idxs]
    if gold_id not in cands:
        cands.append(gold_id)
    return cands

## 4) MonoT5-3B batched rerank

In [6]:
import numpy as np

def rerank_batched(model, tokenizer, query, cands, id2text, device, top_k=5, batch_size=16, alpha=0.7):
    """Re-rank the first ``top_k`` candidates with a MonoT5 model, in batches.

    Each candidate is scored by the negated mean cross-entropy of the target
    "true". This is computed per example, so candidates sharing a batch get
    their own scores. Scores are min-max normalised within the window and
    blended with a linear prior that favours the incoming order::

        final = alpha * norm_score + (1 - alpha) * prior
        prior = 1.0 for the first candidate ... 0.0 for the last

    Args:
        model, tokenizer: MonoT5 model (called with ``labels``, exposing
            ``logits``) and its tokenizer.
        query: Query text.
        cands: Sequence of candidate ids in first-stage order.
        id2text: Mapping from candidate id to document text.
        device: Device the model lives on.
        top_k: Size of the window taken from the front of ``cands``.
        batch_size: Number of (query, document) pairs per forward pass.
        alpha: Weight of the model score versus the order prior.

    Returns:
        The windowed candidate ids sorted by blended score, descending. Ties
        keep the incoming order. Empty input yields ``[]``.
    """
    window = cands[:max(top_k, 0)]
    if not window:
        return []
    inputs = [f"Query: {query} Document: {id2text[pid]} Relevant:" for pid in window]
    raw_scores = []
    for start in range(0, len(inputs), batch_size):
        batch_in = inputs[start:start + batch_size]
        enc = tokenizer(batch_in, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
        lab_enc = tokenizer(["true"] * len(batch_in), return_tensors="pt", padding=True)
        labs = lab_enc.input_ids.to(device)
        lab_mask = lab_enc.attention_mask.to(device).float()
        with torch.no_grad():
            out = model(**enc, labels=labs)
            # out.loss is averaged over the whole batch and would give every
            # candidate in the batch the same score. Recover the per-example
            # loss from the logits instead: (batch, vocab, tgt_len) vs labels.
            token_nll = torch.nn.functional.cross_entropy(
                out.logits.float().transpose(1, 2), labs, reduction="none"
            )
            # Average only over real (non-padding) target tokens.
            per_example = (token_nll * lab_mask).sum(dim=1) / lab_mask.sum(dim=1).clamp(min=1.0)
        raw_scores.extend((-per_example).cpu().tolist())
    arr = np.array(raw_scores)
    mn, mx = arr.min(), arr.max()
    norm = (arr - mn) / (mx - mn + 1e-8)
    # linspace yields [1.0] for a single candidate; 1 - i/(n-1) divided by zero.
    prior = np.linspace(1.0, 0.0, num=len(window))
    final = alpha * norm + (1 - alpha) * prior
    scored = sorted(zip(window, final), key=lambda x: x[1], reverse=True)
    return [pid for pid, _ in scored]

## 5) Main pipeline

In [None]:
import os
import pandas as pd
import ast
from pathlib import Path
from tqdm.auto import tqdm
from transformers import T5ForConditionalGeneration, T5TokenizerFast

def _find_data_file(data_dir, filename):
    """Return the first file named ``filename`` found recursively under ``data_dir``.

    Raises:
        FileNotFoundError: If no such file exists. A bare ``next()`` on the
            ``rglob`` iterator would raise an uninformative StopIteration.
    """
    match = next(data_dir.rglob(filename), None)
    if match is None:
        raise FileNotFoundError(f"{filename!r} not found under {data_dir}")
    return match


def _paper_text(papers_df):
    """Build the stripped ``title + ' ' + abstract`` text for each paper.

    Missing columns and missing (NaN/None) values count as empty strings.
    Row-wise ``str + str`` concatenation raised TypeError on a NaN abstract.
    """
    def column(name):
        if name not in papers_df.columns:
            return pd.Series("", index=papers_df.index)
        return papers_df[name].fillna("").astype(str)

    return (column("title") + " " + column("abstract")).str.strip()


def _format_metrics(metrics):
    """Render the MRR@1 / MRR@5 / Recall@5 summary line printed per re-ranker."""
    return (
        f"MRR@1: {metrics['MRR@1']:.4f}, "
        f"MRR@5: {metrics['MRR@5']:.4f}, "
        f"Recall@5: {metrics['Recall@5']:.4f}"
    )


def main():
    """Evaluate MonoT5-base and MonoT5-3B re-ranking on the train and dev splits.

    Loads the paper collection, the query tweets and the E5 top-100 candidate
    files. For each split it prints MRR@1, MRR@5 and Recall@5 for:

    (a) MonoT5-base over the top-5 E5 candidates;
    (b) MonoT5-3B batched re-ranking, which switches to bi-encoder candidates
        when the gold paper is absent from the E5 list.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    # Paths & data. Prediction files are addressed from base_dir, the root
    # that data_dir also defaults to (same location as a "../../" prefix).
    base_dir = Path.cwd().parent.parent
    data_dir = Path(os.getenv("DATA_DIR", base_dir / "data"))
    preds_dir = base_dir / "predictions" / "Neural_Representation_Learning"
    col_path = _find_data_file(data_dir, "subtask4b_collection_data.pkl")
    train_preds_path = preds_dir / "E5_base_setup_train_TOP100.tsv"
    dev_preds_path = preds_dir / "E5_base_setup_dev_TOP100.tsv"

    # Load collection. NOTE: read_pickle can execute arbitrary code, so only
    # point DATA_DIR at trusted files.
    papers_df = pd.read_pickle(col_path)
    papers_df["text"] = _paper_text(papers_df)
    id2text = dict(zip(papers_df["cord_uid"], papers_df["text"]))
    paper_ids = papers_df["cord_uid"].tolist()

    # Load queries & first-stage predictions (indexed by post_id).
    train_df = pd.read_csv(_find_data_file(data_dir, "subtask4b_query_tweets_train.tsv"), sep="\t")
    dev_df = pd.read_csv(_find_data_file(data_dir, "subtask4b_query_tweets_dev.tsv"), sep="\t")
    pred_train = pd.read_csv(train_preds_path, sep="\t", index_col="post_id")
    pred_dev = pd.read_csv(dev_preds_path, sep="\t", index_col="post_id")
    # The candidate-list column is "preds" when present, else the first column.
    pred_col = "preds" if "preds" in pred_dev.columns else pred_dev.columns[0]

    # Fallback bi-encoder. paper_emb rows follow papers_df order, i.e. paper_ids.
    bi_encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=device)
    paper_emb = bi_encoder.encode(
        papers_df["text"].tolist(), convert_to_tensor=True, normalize_embeddings=True
    ).to(device)

    # Load re-rankers in eval mode.
    m1_name = "castorini/monot5-base-msmarco-10k"
    tok1 = T5TokenizerFast.from_pretrained(m1_name)
    m1 = T5ForConditionalGeneration.from_pretrained(m1_name).to(device)
    m1.eval()
    m2_name = "castorini/monot5-3b-msmarco"
    tok2 = T5TokenizerFast.from_pretrained(m2_name)
    m2 = T5ForConditionalGeneration.from_pretrained(m2_name).to(device)
    m2.eval()

    # Evaluate both splits
    for split_name, df, pred_df in [("Train", train_df, pred_train), ("Dev", dev_df, pred_dev)]:
        refs = [[gold] for gold in df["cord_uid"]]
        print(f"\n=== Split: {split_name} ===")

        # MonoT5-base over the E5 top-5.
        preds1 = []
        for _, row in tqdm(df.iterrows(), total=len(df), desc="Base rerank"):
            cands = ast.literal_eval(pred_df.at[row["post_id"], pred_col])[:5]
            preds1.append(rerank_monot5(m1, tok1, row["tweet_text"], cands, id2text, device))
        print(_format_metrics(compute_list_metrics(preds1, refs, ks=(1, 5))))

        # MonoT5-3B batched
        preds2 = []
        for _, row in tqdm(df.iterrows(), total=len(df), desc="Batched rerank"):
            qid, gold = row["post_id"], row["cord_uid"]
            cands = ast.literal_eval(pred_df.at[qid, pred_col])
            # NOTE(review): this branch consults the gold label at inference
            # time to pick the candidate source (and fallback_candidates may
            # append the gold id). The MonoT5-3B numbers are therefore not a
            # leakage-free evaluation; keep that in mind when comparing them.
            if gold not in cands:
                cands = fallback_candidates(row["tweet_text"], gold, bi_encoder, paper_emb, paper_ids, device)
            preds2.append(rerank_batched(m2, tok2, row["tweet_text"], cands, id2text, device))
        print(_format_metrics(compute_list_metrics(preds2, refs, ks=(1, 5))))


if __name__ == "__main__":
    main()