In [None]:
import json
from pathlib import Path

import pandas as pd
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.std import tqdm
from transformers import T5Tokenizer, T5ForConditionalGeneration

In [None]:
# Prefer the first visible CUDA GPU; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
if use_cuda:
    print("Using GPU:", torch.cuda.get_device_name(0))
else:
    print("Using CPU")

In [None]:
# Hugging Face checkpoint used for re-ranking.
model_name = "Soyoung97/RankT5-base"

# Tab-separated input candidates and ranking output, relative to the kernel's working directory.
data_path = "../passage_ranking_input_true_data/passage_ranking_query.tsv"
output_path = "../passage_output_result/rankT5_result.tsv"

In [None]:
# Each row is one (query, candidate passage) pair; the cells below read these four columns.
dataset = pd.read_csv(data_path, sep="\t")

# Fail fast with a clear message instead of a KeyError/AttributeError deep in the grouping loop.
required_columns = {"qid", "pid", "query", "passage"}
missing_columns = required_columns - set(dataset.columns)
if missing_columns:
    raise ValueError(f"{data_path} is missing required columns: {sorted(missing_columns)}")

In [None]:
# Preview a few rows rather than rendering the whole candidate table.
dataset.head()

In [None]:
# Collect each query's candidate rows together; groups are visited in ascending qid order.
group_dataset = dataset.groupby(by="qid", sort=True)

In [None]:
# One record per query: its text plus the aligned lists of candidate pids and passages.
# Every row of a group carries the same query text, so the first one is taken.
input_data = [
    {
        "qid": qid,
        "query": group["query"].tolist()[0],
        "pids": group["pid"].tolist(),
        "passages": group["passage"].tolist(),
    }
    for qid, group in tqdm(group_dataset, total=group_dataset.ngroups)
]

## Load the tokenizer and model

In [None]:
# SentencePiece tokenizer and seq2seq model for the configured checkpoint.
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = (
    T5ForConditionalGeneration.from_pretrained(model_name)
    .to(device)
    .eval()  # inference only: keep dropout disabled
)

In [None]:
def format_input(query, passage):
    """Build the text prompt fed to the encoder for one query/passage pair.

    Args:
        query: Query text.
        passage: Candidate passage text.

    Returns:
        The string ``"query: <query> passage: <passage>"``.

    NOTE(review): confirm this template matches the input format the RankT5
    checkpoint was fine-tuned with (see its model card); a mismatched prompt
    degrades ranking quality without raising any error.
    """
    template = "query: {} passage: {}"
    return template.format(query, passage)

In [None]:
def rank_passages(query, passages, docids, batch_size=16, score_token="<extra_id_10>"):
    """Score every passage against ``query`` with the RankT5 model and sort by score.

    Uses the module-level ``tokenizer``, ``model``, ``device`` and ``format_input``.

    Args:
        query: Query text.
        passages: Candidate passage texts.
        docids: Identifiers aligned element-wise with ``passages``.
        batch_size: Number of query/passage pairs per forward pass.
        score_token: Vocabulary token whose logit at the first decoder step is the
            relevance score. ``<extra_id_10>`` is the token the RankT5 paper uses for
            its encoder-decoder ranker. NOTE(review): confirm against the checkpoint's
            model card.

    Returns:
        ``(ranked_docids, ranked_scores)``: docids sorted by descending score and the
        matching scores as Python floats. Equal scores keep their input order.

    Raises:
        ValueError: if ``docids`` and ``passages`` differ in length, or ``score_token``
            is not in the tokenizer vocabulary.
    """
    if len(docids) != len(passages):
        # zip() below would otherwise silently drop the unmatched tail.
        raise ValueError(
            f"docids and passages must have the same length ({len(docids)} != {len(passages)})"
        )

    score_token_id = tokenizer.convert_tokens_to_ids(score_token)
    if score_token_id is None or score_token_id == tokenizer.unk_token_id:
        raise ValueError(f"score_token {score_token!r} is not in the tokenizer vocabulary")

    formatted_inputs = [format_input(query, passage) for passage in passages]
    scores = []

    for start in range(0, len(formatted_inputs), batch_size):
        batch_inputs = formatted_inputs[start:start + batch_size]
        batch_docids = docids[start:start + batch_size]

        inputs = tokenizer(batch_inputs, return_tensors="pt", padding=True,
                           truncation=True, max_length=512).to(device)

        # A single decoder step starting from <pad> is all that is needed: the score is
        # read from that step's logits.
        decoder_input_ids = torch.full(
            (len(batch_inputs), 1), tokenizer.pad_token_id, dtype=torch.long, device=device
        )

        with torch.no_grad():
            outputs = model(input_ids=inputs["input_ids"],
                            attention_mask=inputs["attention_mask"],
                            decoder_input_ids=decoder_input_ids)

        # logits are (batch, 1, vocab) because the decoder ran one step. Read the score
        # token's logit; averaging over the whole vocabulary does not measure relevance.
        batch_scores = outputs.logits[:, 0, score_token_id].tolist()
        scores.extend(zip(batch_docids, batch_scores))

    # sorted() is stable, so tied scores keep their original relative order.
    ranked = sorted(scores, key=lambda pair: pair[1], reverse=True)
    ranked_docids = [docid for docid, _ in ranked]
    ranked_scores = [score for _, score in ranked]
    return ranked_docids, ranked_scores


## Run the model and rank each query's passages

In [None]:
# Rank every query's candidates; one record per query with list-valued columns.
# NOTE(review): "pid" keeps the *input* order while "ranked_pid"/"scores" are sorted by
# score. Once these lists are exploded row-wise, "pid" on a row is generally NOT the
# passage that row's score belongs to; treat "ranked_pid" + "scores" as the ranking.
final_result = []
for item in tqdm(input_data):
    ranked_docids, ranked_scores = rank_passages(item["query"], item["passages"], item["pids"])
    final_result.append({
        "qid": item["qid"],
        "pid": item["pids"],
        "ranked_pid": ranked_docids,
        "scores": ranked_scores,
    })

In [None]:
# One output row per (qid, list position). Passing `columns` keeps the schema, and lets
# explode succeed, even when final_result is empty (pd.DataFrame([]) has no columns).
result_df = (
    pd.DataFrame(final_result, columns=["qid", "pid", "ranked_pid", "scores"])
    .explode(["pid", "ranked_pid", "scores"], ignore_index=True)
)

In [None]:
# Create the output directory first so the write does not fail on a fresh checkout.
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
result_df.to_csv(output_path, sep="\t", index=False)