In [1]:
# Run configuration.
# `split` selects which query set / run file to use (e.g. 'test'); it is named
# `split` rather than `type` so the builtin `type` is not shadowed in the kernel.
# `ranker` names the first-stage retriever whose run file is being reranked.
split = 'test'
ranker = 'splade'

# first-stage TREC run file holding the ranked candidate documents per query
res_path = f'data/trec_runfile_{split}_qr_{ranker}.txt'

# query CSV (expected to contain `qid` and `query` columns)
query_path = f'data/queries_{split}_gpt4.csv'

# passage collection, tab-separated doc_id / text
docs_path = 'data/collection.tsv'

# where the reranked TREC run file is written
out_path = f'data/trec_runfile_{split}_qr_{ranker}_reranked.txt'

In [2]:
# Load the cross-encoder reranker.
# Tokenizer and model are loaded from one checkpoint name so they cannot drift
# apart (a mismatched tokenizer would feed the model the wrong vocabulary).
# Alternative checkpoint previously tried: "nboost/pt-bert-large-msmarco".

from transformers import AutoTokenizer, AutoModelForSequenceClassification

MODEL_NAME = "amberoad/bert-multilingual-passage-reranking-msmarco"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

Downloading (…)okenizer_config.json:   0%|          | 0.00/39.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/413 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.34G [00:00<?, ?B/s]

In [3]:
# Load the first-stage run, the queries and the passage collection.

import csv

from tqdm import tqdm
import pandas as pd

# TREC run file: one "<query_id> Q0 <doc_id> <rank> <score> <run_name>" line per hit.
# sep=r'\s+' tolerates runs of spaces and trailing whitespace; with sep=' ' a
# trailing space produces an extra column and the column assignment fails.
results = pd.read_csv(
    res_path,
    sep=r'\s+',
    header=None,
    names=['query_id', 'Q0', 'doc_id', 'rank', 'score', 'run_name'],
)

# Queries indexed by `qid`; the reranking step reads the `query` column.
queries = pd.read_csv(query_path).set_index('qid')

# Passage collection (doc_id <TAB> text), indexed by doc_id.
# QUOTE_NONE: the TSV is assumed to use no CSV quoting, so a passage that begins
# with '"' must not be parsed as a quoted field that swallows following rows.
# keep_default_na=False: passages like "NA" or "null" stay strings instead of
# becoming NaN (which the tokenizer cannot handle).
collection = pd.read_csv(
    docs_path,
    sep='\t',
    header=None,
    names=['doc_id', 'text'],
    index_col='doc_id',
    quoting=csv.QUOTE_NONE,
    keep_default_na=False,
)

In [4]:
from collections import defaultdict

# Map each query_id to its list of candidate doc_ids, in the order they appear in
# the run file (normally rank order for a TREC run; not re-sorted here).
# groupby(sort=False) keeps the first-appearance order of queries and the row
# order within each query, matching the former iterrows loop without building a
# Series per row (which is slow and coerces dtypes row-wise).
res_dict = defaultdict(
    list,
    results.groupby('query_id', sort=False)['doc_id'].agg(list).to_dict(),
)

In [11]:
import os

import torch

# Rerank every query's candidate list with the cross-encoder and append the result
# to `out_path` in TREC run format. The cell is resumable: queries whose id already
# appears in the first column of `out_path` are skipped.

MAX_LENGTH = 512  # maximum number of input tokens for the BERT reranker
BATCH_SIZE = 32   # (query, doc) pairs per forward pass; lower it if the device runs out of memory

# Use Apple's MPS backend as before, but fall back to CPU when it is unavailable.
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
model = model.to(device)

# Query ids already written, matched exactly on the first column. A substring test
# on the whole file would treat e.g. '1_1' as done when only '1_10' or '11_1' had
# been written, and re-reading the file for every query is quadratic.
completed = set()
if os.path.isfile(out_path):
    with open(out_path) as f:
        completed = {line.split(maxsplit=1)[0] for line in f if line.strip()}

for key, doc_ids in tqdm(res_dict.items()):
    if str(key) in completed:
        # tqdm.write keeps the progress bar intact (plain print garbles it)
        tqdm.write(f'Query {key} already in the file')
        continue

    query = queries.loc[key, 'query']
    scores = {}
    # inference_mode: no autograd graph is recorded -> less memory, faster
    with torch.inference_mode():
        for start in range(0, len(doc_ids), BATCH_SIZE):
            batch_ids = doc_ids[start:start + BATCH_SIZE]
            batch_docs = [collection.loc[doc_id, 'text'] for doc_id in batch_ids]
            # Explicit max_length: truncation=True alone may not cap the length when
            # the tokenizer has no model_max_length configured, and slicing the ids
            # to 512 afterwards would cut off the trailing [SEP] token.
            encoding = tokenizer(
                [query] * len(batch_ids),
                batch_docs,
                return_tensors='pt',
                padding=True,
                truncation=True,
                max_length=MAX_LENGTH,
            ).to(device)
            logits = model(**encoding).logits
            # Logit column 1 is the score ranked on (presumably the "relevant" class).
            scores.update(zip(batch_ids, logits[:, 1].tolist()))

    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    # Build the query's whole block first and write it with one call, which keeps the
    # window for a partially written query small if the run is interrupted.
    block = ''.join(
        f'{key} Q0 {doc_id} {rank} {score} bert\n'
        for rank, (doc_id, score) in enumerate(ranked, start=1)
    )
    with open(out_path, 'a') as f:
        f.write(block)

  0%|          | 0/248 [00:00<?, ?it/s]

Query 1_1 already in the file
Query 1_2 already in the file
Query 1_3 already in the file
Query 1_4 already in the file
Query 1_5 already in the file
Query 1_6 already in the file
Query 1_7 already in the file
Query 1_8 already in the file
Query 1_9 already in the file
Query 1_10 already in the file
Query 1_11 already in the file
Query 1_12 already in the file
Query 2_1 already in the file
Query 2_2 already in the file
Query 2_3 already in the file
Query 2_4 already in the file
Query 2_5 already in the file
Query 2_6 already in the file
Query 2_7 already in the file
Query 2_8 already in the file
Query 2_9 already in the file
Query 2_10 already in the file
Query 2_11 already in the file
Query 7_1 already in the file
Query 7_2 already in the file
Query 7_3 already in the file
Query 7_4 already in the file
Query 7_5 already in the file
Query 7_6 already in the file
Query 7_7 already in the file
Query 7_8 already in the file


100%|██████████| 248/248 [2:19:10<00:00, 33.67s/it]  
