In [1]:
import pandas as pd
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
)
from dataclasses import dataclass
from typing import Tuple, List, Optional, Mapping, Any

# Seed torch's default random number generator with a fixed value.
torch.manual_seed(101)
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Read the saved BM25 ranking from disk and preview its first rows.
# Later cells read the "query" and "doc" columns of this frame.
df = pd.read_json(path_or_buf="../data/bm_25_ranking.json")
df.head(n=5)

In [None]:
# Re-rank documents using fine-tuned model


class Reranker():
    """Re-rank candidate documents for a query with a seq2seq language model.

    Each (query, document) pair is rendered into a single prompt, run through
    the model with the prompt's own token ids as labels, and scored by the
    loss the model reports. A lower loss is treated as a better match.

    NOTE(review): the loss measures how well the model reproduces the prompt
    itself; whether that tracks relevance for a given checkpoint should be
    confirmed against labelled data.
    """

    # Prompts longer than this many tokens are truncated by the tokenizer.
    MAX_INPUT_TOKENS = 512

    def __init__(self, model_name):
        """Load the model and tokenizer for ``model_name``.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                seq2seq model and its tokenizer.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def rank_documents(self,
                       query: str,
                       documents: List[str],
                       top_k: int = 10) -> List[Tuple[str, float]]:
        """Score every document against ``query`` and return the best ones.

        Args:
            query: The search query.
            documents: Candidate document texts, in their current order.
            top_k: Maximum number of results to return.

        Returns:
            ``(document, loss)`` pairs ordered from lowest to highest loss
            (best match first), truncated to ``top_k`` entries. Documents
            with equal loss keep their input order.
        """
        scores = []
        # Gradients are never needed for scoring, so disable them once for
        # the whole loop instead of once per document.
        with torch.no_grad():
            for doc in documents:
                input_text = f"Query: {query} Document: {doc} Relevant: "
                inputs = self.tokenizer(
                    input_text,
                    return_tensors="pt",
                    max_length=self.MAX_INPUT_TOKENS,
                    truncation=True
                ).to(self.device)

                outputs = self.model(**inputs, labels=inputs["input_ids"])
                scores.append(outputs.loss.item())

        # The score is a loss, so smaller means the model fits the pair
        # better: sort ascending. (Sorting descending would put the
        # worst-fitting documents at the top of the ranking.)
        ranked_docs = sorted(zip(documents, scores), key=lambda pair: pair[1])

        return ranked_docs[:top_k]

    def print_ranking(self,
                      query: str,
                      documents: List[str],
                      top_k: int = 10) -> None:
        """Print the query followed by its top-ranked documents.

        Args:
            query: The search query.
            documents: Candidate document texts.
            top_k: Maximum number of ranked documents to print.
        """
        ranked_results = self.rank_documents(query, documents, top_k=top_k)
        print(f"\n {query} \n")
        # Only the first 100 characters of each document are shown; the
        # score is intentionally not printed.
        for rank, (doc, _score) in enumerate(ranked_results, 1):
            print(f"{rank}. Document: {doc[:100]}...\n")

In [None]:
# Example re-ranking input: the first distinct query in the BM25 results and
# every document listed for it.
# NOTE(review): presumably 1000 documents per query — confirm against the data.

query = df["query"].unique()[0]
relevant_documents = df.loc[df["query"] == query, "doc"].values

In [None]:
# Checkpoint identifier handed to Reranker, which loads it via from_pretrained.
model_1 = "pratham4521/T5-base-masmarco-finetuned"
ranker = Reranker(model_1)
ranker.print_ranking(query=query, documents=relevant_documents)