In [None]:
from pathlib import Path

import pandas as pd
import torch
from tqdm.std import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU,
# and report which of the two was picked.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("Using GPU:", torch.cuda.get_device_name(0))
else:
    print("Using CPU")

In [None]:
# Hugging Face model identifier of the reranker loaded below.
model_name = "amberoad/bert-multilingual-passage-reranking-msmarco"

# Tab-separated input that is read into `data_df`.
data_path = "../passage_ranking_input_true_data/passage_ranking_query.tsv"

# Tab-separated file the reranked results are written to.
output_path = "../passage_output_result/bert_multilingual_result.tsv"

## loading data

In [None]:
# Read the tab-separated query/passage table into a DataFrame.
data_df = pd.read_csv(
    data_path,
    sep="\t",
)

In [None]:
# Display the loaded frame as this cell's output.
data_df

In [None]:
# Keep the first five rows as a small sample for a quick look at the data.
data_sample = data_df.head(5)

In [None]:
# Display the five-row sample as this cell's output.
data_sample

## loading model and tokenizer

In [None]:
# Load the tokenizer and the sequence-classification model named by
# `model_name`, then place the model on the selected device.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model = model.to(device)

In [None]:
def encode_batch(query, docs):
    """Tokenize one query against a batch of documents as (query, document) pairs.

    Args:
        query: Query text. A non-string value is converted with ``str()``.
        docs: Iterable of document texts. Non-string items are converted with
            ``str()``, so e.g. a float NaN becomes the literal text ``"nan"``.

    Returns:
        dict: Every entry produced by the tokenizer (name -> tensor), with each
        tensor moved to the module-level ``device``.
    """
    if not isinstance(query, str):
        query = str(query)

    docs = [doc if isinstance(doc, str) else str(doc) for doc in docs]

    # The same query is paired with every document of the batch.
    queries = [query] * len(docs)

    # Call the tokenizer directly (first argument = text, second = text_pair);
    # `batch_encode_plus` is deprecated in favour of `__call__`.
    encoded_inputs = tokenizer(
        queries,
        docs,
        padding=True,       # pad so the batch forms rectangular tensors
        truncation=True,    # cut pairs that exceed max_length
        max_length=512,
        return_tensors='pt',
    )
    return {key: value.to(device) for key, value in encoded_inputs.items()}

## inference

In [None]:
def rerank_batch(query, docs, batch_size=16):
    """Score every document in ``docs`` against ``query``.

    Documents are fed to the model in consecutive slices of at most
    ``batch_size`` items; gradient tracking is disabled for the forward pass.

    Args:
        query: Query text, passed through to ``encode_batch``.
        docs: Sequence of documents to score (must support ``len`` and slicing).
        batch_size: Maximum number of documents per forward pass.

    Returns:
        list: One score per document, in the same order as ``docs``. Each score
        is the value at index 1 of that document's logits, which this notebook
        treats as the positive-class score.
    """
    relevance = []
    for start in range(0, len(docs), batch_size):
        chunk = docs[start:start + batch_size]
        model_inputs = encode_batch(query, chunk)

        with torch.no_grad():
            result = model(**model_inputs)

        # Column 1 of the logits is used as the relevance score.
        relevance += result.logits[:, 1].tolist()

    return relevance

In [None]:
def process_reranking(data, batch_size=16):
    """Rerank the passages of every query in ``data``.

    Args:
        data: DataFrame with the columns ``qid``, ``query``, ``passage`` and
            ``pid``. Rows sharing a ``qid`` form one query; the query text is
            taken from the first row of each group.
        batch_size: Number of passages scored per forward pass (forwarded to
            ``rerank_batch``).

    Returns:
        dict: ``qid -> list of (pid, ranked_pid, score)`` tuples, in order of
        each qid's first appearance in ``data``. Within a tuple, ``pid`` is the
        passage id at that position in the *original* row order, while
        ``ranked_pid`` and ``score`` belong to the passage at that position in
        the *score-descending* ranking. ``score`` therefore describes
        ``ranked_pid``, not ``pid``.

    Rows whose ``qid`` is missing (NaN) are skipped, because ``groupby`` drops
    missing keys by default.
    """
    reranked_results = {}

    # One grouping pass instead of a boolean mask over the whole frame per
    # query; sort=False keeps the qids in order of first appearance.
    grouped = data.groupby('qid', sort=False)

    for qid, subset in tqdm(grouped, total=grouped.ngroups):
        query = subset['query'].iloc[0]
        docs = subset['passage'].tolist()
        pids = subset['pid'].tolist()

        # Score all passages of this query in batches.
        scores = rerank_batch(query, docs, batch_size=batch_size)

        # Highest score first; the sort is stable, so ties keep input order.
        sorted_docs = sorted(zip(pids, scores), key=lambda x: x[1], reverse=True)

        ranked_pids = [pid for pid, _ in sorted_docs]
        scores_new = [s for _, s in sorted_docs]

        # Original-order pids side by side with the ranked pids and their scores.
        reranked_results[qid] = list(zip(pids, ranked_pids, scores_new))

    return reranked_results

In [None]:
# Rerank every query of the full dataset (not just the five-row sample).
batch_size = 32  # passages scored per forward pass
reranked_results = process_reranking(data_df, batch_size=batch_size)

## saving the results

In [None]:
# Create the destination folder up front so the write cannot fail with
# FileNotFoundError after the (long) reranking run has finished.
Path(output_path).parent.mkdir(parents=True, exist_ok=True)

# Explicit UTF-8 keeps the file identical across platforms instead of
# depending on the locale's default encoding.
with open(output_path, 'w', encoding='utf-8') as f:
    f.write("qid\tpid\tranked_pid\tscores\n")
    for qid, sorted_docs in reranked_results.items():
        # One line per rank position: the pid in original order, then the pid
        # and score that hold this position in the ranking.
        for pid, ranked_pid, score in sorted_docs:
            f.write(f"{qid}\t{pid}\t{ranked_pid}\t{score}\n")