# Label Reranker

This notebook analyses the BERT top-100 reranking. For each query, it labels the MSMARCO-relevant passage according to its rank in the BERT top 100.

The labels are:
- high: rank 1-20
- medium: rank 21-80
- low: rank 81-100
- outside scope: rank >100
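
The rank buckets above can be written as a small helper. This is a minimal sketch, not code from the notebook: the function name `rank_to_label` is hypothetical, and it assumes ranks are 1-based integers, with `None` standing for a passage that does not appear in the top 100.

```python
def rank_to_label(rank):
    """Map a 1-based BERT rank to one of the notebook's labels.

    Hypothetical helper: `rank` is None when the passage is not in the top 100.
    """
    if rank is None or rank > 100:
        return "outside scope"
    if rank <= 20:
        return "high"
    if rank <= 80:
        return "medium"
    return "low"


# Print the label at each bucket boundary.
for example_rank in (1, 20, 21, 80, 81, 100, None):
    print(example_rank, rank_to_label(example_rank))
```

Checking the boundaries (20/21, 80/81, 100) is a quick way to confirm the buckets match the list above.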

## Imports

In [1]:
import pandas as pd
import numpy as np

## Paths

In [2]:
msmarco_dir = "../data/msmarco_files/"
anserini_dir = "../data/anserini_output/"
output_dir = "../data/output/"

## Load Data

Change the following filenames to match your run:

In [3]:
bert_output_filename = 'bert_run_development_top100.tsv'
relevance_filename = 'qrels.dev.small.tsv'
output_filename = 'bert_rank_labels_dev_small.tsv'

#### BERT Reranks

In [4]:
bert_output_df = pd.read_csv(output_dir + bert_output_filename,delimiter='\t',encoding='utf-8', header=None)
bert_output_df.columns = ['query_id', 'passage_id', 'bm25_rank', 'query', 'passage', 'score_bert', 'bert_rank']

In [5]:
bert_output_df.head(1)

Unnamed: 0,query_id,passage_id,bm25_rank,query,passage,score_bert,bert_rank
0,2,1782337,1,androgen receptor define,enzalutamide is an androgen receptor inhibitor...,-1.713628,34


#### MSMARCO Relevance File

In [6]:
msmarco_relevance_df = pd.read_csv(msmarco_dir + relevance_filename,delimiter='\t',encoding='utf-8', header=None)
msmarco_relevance_df.columns = ['query_id', 'label1', 'passage_id', 'label2']

In [7]:
msmarco_relevance_df.head(1)

Unnamed: 0,query_id,label1,passage_id,label2
0,300674,0,7067032,1


## Merge BERT Reranking with MSMARCO Relevance

#### Merge and Count Relevant Passages

In [8]:
ranking_relevance_merged_df = bert_output_df.merge(msmarco_relevance_df,how='left',on=['query_id','passage_id'])
ranking_relevance_merged_df['true_label'] = ranking_relevance_merged_df['label2'].fillna(0)
ranking_relevance_merged_df['true_label'].sum()

4925.0

In [9]:
len(np.unique(ranking_relevance_merged_df['query_id'].tolist()))

4737

In [10]:
ranking_relevance_merged_df.head(5)

Unnamed: 0,query_id,passage_id,bm25_rank,query,passage,score_bert,bert_rank,label1,label2,true_label
0,2,1782337,1,androgen receptor define,enzalutamide is an androgen receptor inhibitor...,-1.713628,34,,,0.0
1,2,1001873,2,androgen receptor define,the ar gene provides instructions for making a...,2.283455,5,,,0.0
2,2,4339075,3,androgen receptor define,during androgen independent progression prost...,-1.800174,36,,,0.0
3,2,6285817,4,androgen receptor define,the term sarms stands for aselective androgen ...,2.404728,4,,,0.0
4,2,3634076,5,androgen receptor define,sarms or selective androgen receptor modulator...,2.02906,8,,,0.0


In [11]:
top100_query_ids = np.unique(ranking_relevance_merged_df[ranking_relevance_merged_df['label2'] == 1]['query_id'].tolist())

In [12]:
len(top100_query_ids)

4737

## Label Relevance Rank

#### Split Relevant Passages by Rank Bucket

In [13]:
query_label_dict = {}

# Boolean masks selecting the relevant passages (label2 == 1) in each BERT rank bucket.
relevant = ranking_relevance_merged_df['label2'] == 1
bert_rank = ranking_relevance_merged_df['bert_rank']
high_mask = relevant & (bert_rank < 21)
medium_mask = relevant & (bert_rank > 20) & (bert_rank < 81)
low_mask = relevant & (bert_rank > 80)

query_ids_high = ranking_relevance_merged_df.loc[high_mask, 'query_id'].tolist()
passage_ids_high = ranking_relevance_merged_df.loc[high_mask, 'passage_id'].tolist()
query_ids_medium = ranking_relevance_merged_df.loc[medium_mask, 'query_id'].tolist()
passage_ids_medium = ranking_relevance_merged_df.loc[medium_mask, 'passage_id'].tolist()
query_ids_low = ranking_relevance_merged_df.loc[low_mask, 'query_id'].tolist()
passage_ids_low = ranking_relevance_merged_df.loc[low_mask, 'passage_id'].tolist()

In [14]:
high_ids = ranking_relevance_merged_df[(ranking_relevance_merged_df['bert_rank'] < 21) & (ranking_relevance_merged_df['label2'] == 1)][['query_id', 'passage_id']].values.tolist()
medium_ids = ranking_relevance_merged_df[(ranking_relevance_merged_df['bert_rank'] > 20) & (ranking_relevance_merged_df['bert_rank'] < 81) & (ranking_relevance_merged_df['label2'] == 1)][['query_id', 'passage_id']].values.tolist()
low_ids = ranking_relevance_merged_df[(ranking_relevance_merged_df['bert_rank'] > 80) & (ranking_relevance_merged_df['label2'] == 1)][['query_id', 'passage_id']].values.tolist()

BERT assigns some passages the same rank, so not every query has a full top 100 with ranks 1 to 100. Ranks 1 to 100 sum to 5050, so a lower per-query sum in the output below points to shared ranks or fewer than 100 passages. Some queries also have more than one MSMARCO-relevant passage. It is key to label each query according to the passage that was used for the BM25 labeling.

In [15]:
bert_output_df.groupby(['query_id'])['bert_rank'].sum()

query_id
2          5050
1215       5050
1288       5050
2235       5050
2798       5050
2962       5037
4696       5007
4947       5050
6217       5050
7968       5050
8701       5050
8798       5050
8854       5050
9083       5050
9926       5025
10157      5050
10205      5050
10264      5050
11050      5025
12903      5050
13397      5050
14151      5050
14947      5050
14963      5050
15039      5050
15063      5050
15382      4963
15441      5050
16559      5050
16860      5014
           ... 
1101723    4995
1101739    5050
1101761    5050
1101784    5050
1101806    5050
1101822    4855
1101861    5050
1101868    5050
1101870    5013
1101871    5050
1101902    4892
1101906    5050
1101961    5012
1101977    5050
1102001    5050
1102028    5050
1102088    4932
1102099    4986
1102121    5050
1102177    5050
1102206    5050
1102235    5050
1102240    5050
1102262    5050
1102300    5050
1102325    5050
1102330    5050
1102351    5017
1102393    5050
1102400    5050
Name: bert_rank
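
The sums above can be read against the sum of the ranks 1 to 100, which is 100 * 101 / 2 = 5050. The sketch below is illustrative only and uses plain Python rather than the notebook's dataframes. The helper names (`expected_rank_sum`, `shared_ranks`, `is_full_tie_free`) and the toy rank lists are assumptions, not data from the run.

```python
from collections import Counter


def expected_rank_sum(k):
    """Sum of the ranks 1..k, e.g. 5050 for k = 100."""
    return k * (k + 1) // 2


def shared_ranks(ranks):
    """Return the ranks that occur more than once, in ascending order."""
    counts = Counter(ranks)
    return sorted(rank for rank, count in counts.items() if count > 1)


def is_full_tie_free(ranks, k=100):
    """True when `ranks` is exactly the ranks 1..k with no repeats."""
    return sorted(ranks) == list(range(1, k + 1))


# Toy data (hypothetical): a clean top 5 and one where two passages share rank 3.
clean = [1, 2, 3, 4, 5]
tied = [1, 2, 3, 3, 4]

print(expected_rank_sum(100))
print(sum(clean), is_full_tie_free(clean, k=5))
print(sum(tied), shared_ranks(tied))
```

The rank sum is only a quick screen: a sum below the expected value shows that something is off, while the explicit comparison in `is_full_tie_free` is the stricter check, and `shared_ranks` shows which ranks are duplicated.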

#### Create new Dataframe

In [16]:
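def getLabel(query_id, passage_id):
    """Return the rank label for a (query_id, passage_id) pair.

    Pairs missing from all three buckets fall outside the BERT top 100.
    """
    entry = [query_id,passage_id]
    if entry in high_ids:
        return "high"
    elif entry in medium_ids:
        return "medium"
    elif entry in low_ids:
        return "low"
    else:
        return "outside scope"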

In [17]:
output_df = pd.DataFrame(top100_query_ids)
output_df.columns = ['query_id']
output_df.shape

(4737, 1)

In [18]:
output_df = output_df.merge(msmarco_relevance_df,how='left',on=['query_id'])
output_df.shape

(5057, 4)

In [19]:
output_df['label'] = output_df.apply(lambda x: getLabel(x.query_id, x.passage_id), axis=1)
output_df.shape

(5057, 5)

In [20]:
# Collect the query ids that occur more than once, i.e. queries with several relevant passages.
vc = output_df['query_id'].value_counts()
ids_of_interest = vc[vc > 1].index.tolist()

In [21]:
output_df = output_df[~output_df['query_id'].isin(ids_of_interest)]
output_df.shape

(4460, 5)

In [22]:
del output_df['label1']
del output_df['label2']

In [23]:
output_df.to_csv(output_dir + output_filename,sep="\t", header=False,index=False)