In [9]:
# Configuration: which data split / first-stage ranker to rerank, and where files live.
# NOTE: the split is stored in `split` rather than `type`, so Python's builtin `type()`
# is not shadowed for the rest of the kernel session.

split = 'train'   # data split; selects both the run file and the query file
ranker = 'bm25'   # first-stage ranker whose run file is being reranked

# TREC run file produced by the first-stage ranker (the file to rerank)
res_path = f'data/trec_runfile_{split}_qr_{ranker}.txt'

# queries for this split (CSV with a header row)
query_path = f'data/queries_{split}_gpt4.csv'

# document collection (TSV, no header row)
docs_path = 'data/collection.tsv'

# where the reranked TREC run file is written
out_path = f'data/trec_runfile_{split}_qr_{ranker}_reranked.txt'

In [2]:
# Load the cross-encoder reranker: its tokenizer and its sequence-classification model,
# both obtained through `from_pretrained` from the same checkpoint id.
from transformers import AutoModelForSequenceClassification, AutoTokenizer

reranker_checkpoint = "amberoad/bert-multilingual-passage-reranking-msmarco"

tokenizer = AutoTokenizer.from_pretrained(reranker_checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(reranker_checkpoint)

In [10]:
# Load the three inputs: the first-stage TREC run, the queries and the document
# collection. Queries and documents are indexed by their ids so they can be looked
# up with `.loc` during reranking.

import pandas as pd
from tqdm import tqdm

# run file: space-separated, no header row
results = pd.read_csv(res_path, sep=' ', header=None).set_axis(
    ['query_id', 'Q0', 'doc_id', 'rank', 'score', 'run_name'], axis=1
)

# queries: CSV with a header row, indexed by its `qid` column
queries = pd.read_csv(query_path).set_index('qid')

# collection: tab-separated `doc_id<TAB>text`, no header row
collection = (
    pd.read_csv(docs_path, sep='\t', header=None)
    .set_axis(['doc_id', 'text'], axis=1)
    .set_index('doc_id')
)

In [11]:
from collections import defaultdict

# Group the first-stage run by query: key = query_id, value = list of doc_ids.
# Each list keeps the order in which the rows appear in the run file; it is NOT
# re-sorted by the `rank` column here.
# Zipping the two needed columns avoids building a Series per row (as iterrows does).
res_dict = defaultdict(list)
for query_id, doc_id in zip(results['query_id'], results['doc_id']):
    res_dict[query_id].append(doc_id)

In [12]:
# Rerank every query's first-stage candidates with the cross-encoder and append the
# reranked list to `out_path` in TREC run format (qid Q0 docid rank score run_name).
# The output file doubles as a checkpoint: queries that already have lines in it are
# skipped, so the cell can be re-run to resume after an interruption.

import os

import torch


def _completed_query_ids(path):
    """Return the set of query ids (as strings) already present in the run file at `path`.

    Only the first whitespace-separated field of each non-blank line is read, so ids
    match exactly (e.g. '4_1' does not match lines written for '14_1' or '4_10').
    A missing file means nothing has been reranked yet and yields an empty set.
    """
    if not os.path.exists(path):
        return set()
    with open(path) as run_file:
        return {line.split(maxsplit=1)[0] for line in run_file if line.strip()}


# Prefer Apple-silicon MPS, then CUDA, and fall back to CPU instead of crashing
# on machines without MPS.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
model = model.to(device)
model.eval()  # inference only: disable dropout

# Read the checkpoint once instead of re-reading the whole output file per query.
done_ids = _completed_query_ids(out_path)

# No gradients are needed for scoring; skipping autograd saves memory and time.
with torch.no_grad():
    for key, doc_ids in tqdm(res_dict.items()):
        if str(key) in done_ids:
            print(f'Query {key} already in the file')
            continue

        query = queries.loc[key]['query']
        scores = {}
        for doc_id in doc_ids:
            doc = collection.loc[doc_id]['text']
            # Let the tokenizer truncate the (query, doc) pair to 512 tokens; unlike
            # slicing input_ids afterwards, this keeps the trailing [SEP] token.
            encoding = tokenizer(
                query, doc, return_tensors='pt', truncation=True, max_length=512
            ).to(device)
            logits = model(**encoding).logits
            # Logit column 1 is used as the relevance score
            # (presumably the "relevant" class of this reranker — TODO confirm on the model card).
            scores[doc_id] = logits[0][1].item()

        ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
        # Build all lines for this query and write them in one call, so an interruption
        # is less likely to leave a partially written query in the checkpoint file.
        lines = ''.join(
            f'{key} Q0 {doc_id} {rank} {score} bert \n'
            for rank, (doc_id, score) in enumerate(ranked, start=1)
        )
        with open(out_path, 'a') as out_file:
            out_file.write(lines)
        done_ids.add(str(key))

    
    

 17%|█▋        | 42/253 [00:00<00:00, 413.11it/s]

Query 4_1 already in the file
Query 4_2 already in the file
Query 4_3 already in the file
Query 4_4 already in the file
Query 4_5 already in the file
Query 4_6 already in the file
Query 4_7 already in the file
Query 4_8 already in the file
Query 4_9 already in the file
Query 17_1 already in the file
Query 17_2 already in the file
Query 17_3 already in the file
Query 17_4 already in the file
Query 17_5 already in the file
Query 17_6 already in the file
Query 17_7 already in the file
Query 17_8 already in the file
Query 17_9 already in the file
Query 17_10 already in the file
Query 18_1 already in the file
Query 18_2 already in the file
Query 18_3 already in the file
Query 18_4 already in the file
Query 18_5 already in the file
Query 18_6 already in the file
Query 18_7 already in the file
Query 18_8 already in the file
Query 18_9 already in the file
Query 18_10 already in the file
Query 18_11 already in the file
Query 22_1 already in the file
Query 22_2 already in the file
Query 22_3 alr

 33%|███▎      | 84/253 [00:00<00:00, 287.70it/s]

Query 40_7 already in the file
Query 40_8 already in the file
Query 49_1 already in the file
Query 49_2 already in the file
Query 49_3 already in the file
Query 49_4 already in the file
Query 49_5 already in the file
Query 49_6 already in the file
Query 49_7 already in the file
Query 49_8 already in the file
Query 50_1 already in the file
Query 50_2 already in the file
Query 50_3 already in the file
Query 50_4 already in the file
Query 50_5 already in the file
Query 50_6 already in the file
Query 50_7 already in the file
Query 50_8 already in the file
Query 54_1 already in the file
Query 54_2 already in the file
Query 54_3 already in the file
Query 54_4 already in the file
Query 54_5 already in the file
Query 54_6 already in the file
Query 54_7 already in the file
Query 54_8 already in the file
Query 54_9 already in the file
Query 56_1 already in the file
Query 56_2 already in the file
Query 56_3 already in the file
Query 56_4 already in the file
Query 56_5 already in the file
Query 56

 45%|████▌     | 115/253 [00:00<00:00, 231.68it/s]

Query 58_6 already in the file
Query 58_7 already in the file
Query 58_8 already in the file
Query 61_1 already in the file
Query 61_2 already in the file
Query 61_3 already in the file
Query 61_4 already in the file
Query 61_5 already in the file
Query 61_6 already in the file
Query 61_7 already in the file
Query 61_8 already in the file
Query 69_1 already in the file
Query 69_2 already in the file
Query 69_3 already in the file
Query 69_4 already in the file
Query 69_5 already in the file
Query 69_6 already in the file
Query 69_7 already in the file
Query 69_8 already in the file
Query 69_9 already in the file
Query 69_10 already in the file
Query 75_1 already in the file
Query 75_2 already in the file
Query 75_3 already in the file
Query 75_4 already in the file
Query 75_5 already in the file
Query 75_6 already in the file
Query 75_8 already in the file
Query 79_1 already in the file
Query 79_2 already in the file
Query 79_3 already in the file


 64%|██████▎   | 161/253 [00:00<00:00, 176.33it/s]

Query 79_4 already in the file
Query 79_5 already in the file
Query 79_6 already in the file
Query 79_7 already in the file
Query 79_8 already in the file
Query 79_9 already in the file
Query 82_1 already in the file
Query 82_2 already in the file
Query 82_3 already in the file
Query 82_4 already in the file
Query 82_5 already in the file
Query 82_6 already in the file
Query 82_7 already in the file
Query 82_8 already in the file
Query 82_9 already in the file
Query 82_10 already in the file
Query 83_1 already in the file
Query 83_2 already in the file
Query 83_3 already in the file
Query 83_4 already in the file
Query 83_5 already in the file
Query 83_6 already in the file
Query 83_7 already in the file
Query 83_8 already in the file
Query 84_1 already in the file
Query 84_2 already in the file


100%|██████████| 253/253 [29:52<00:00,  7.09s/it] 
