In [59]:
# always reload an imported module before executing a particular cell
# (used to let changes in python files take effect without restarting kernel)
%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [60]:
from sentence_transformers import SentenceTransformer, LoggingHandler
import numpy as np
import os
import pandas as pd

In [61]:
model = SentenceTransformer('bert-base-nli-mean-tokens')

In [62]:
queries_file= '/run/media/root/Windows/Users/agnes/Downloads/data/msmarco/queries.csv'

df = pd.read_csv(queries_file)

df.head(3)

Unnamed: 0,qid,query,rel,type,doc
0,0,are cnn ratings falling,0,original,CNN NewsStand. Not to be confused with routine...
1,0,are cnn ratings falling,0,original,Using Phone Numbers and Addresses. 1 1. Call ...
2,0,are cnn ratings falling,0,original,LSMW is the gold standard tool used by functio...


In [63]:
from abc import ABC, abstractmethod

class Ranker(ABC):
    
    @abstractmethod
    def rank(self, query, docs):
        """
        Return sorted indices
        """
        pass

In [64]:
from scipy.spatial.distance import cdist
from sklearn.metrics.pairwise import cosine_similarity

class SBERTRanker(Ranker):
    def __init__(self, model):
        self.model = model
        
    def rank(self, query, docs):
        embeddings = model.encode([query] + docs)

        sims = cdist(embeddings[0].reshape(-1,1).T, embeddings[1:], "cosine")[0]
        indexed_sims = enumerate(sims.squeeze())   
        indexed_sims_sorted = sorted(indexed_sims, key=lambda x:x[1], reverse=False)
        indices_sorted = [x[0] for x in indexed_sims_sorted]
        
        return indices_sorted

In [65]:
def calc_rr(rels_sorted, types_sorted, ttype='original', rel_label=1):
    indices = np.where(rels_sorted == rel_label)[0]
    #print(indices)
    original = np.where(types_sorted[indices] == ttype)[0]
    if len(original) == 0:
        raise Exception('no relevant docs of this type')
    return np.mean(1 / (indices[original]+1))


In [66]:
def calculate_mrr_stats(ranker,
                        types=['original', 'degree_2', 'degree_3', 'degree_4', 
                               'degree_4_split', 'degree_8', 'degree_8_split'],
                        limit=None
                       ):

    rr_df = pd.DataFrame(columns=['qid'] + types + ['original_irrelevant'])

    qids = list(set(df['qid']))
    nr_qids = len(qids)
    for i, qid in enumerate(qids):
        
        if limit is not None and i >= limit:
            break
            
        query_df = df[df['qid'] == qid].reset_index(drop=True)
        query = query_df['query'].iloc[0]
        print('query:', query)

        indices_sorted = ranker.rank(query, query_df['doc'].values.tolist())   

        docs_sorted = query_df.iloc[indices_sorted]

        query_rr_row = [qid]
        for ttype in types:
            query_rr_row.append(calc_rr(docs_sorted['rel'].values, docs_sorted['type'].values, ttype=ttype))
        query_rr_row.append(calc_rr(docs_sorted['rel'].values, 
                                    docs_sorted['type'].values, ttype='original', 
                                    rel_label=0))        
        rr_df = rr_df.append(pd.Series(query_rr_row, index=rr_df.columns), ignore_index=True)

        if (i+1) % 5 ==  0:
            print('processed {:d} of {:d}'.format(i+1, nr_qids))
        
    
    return rr_df


In [67]:
ranker = SBERTRanker(model)

rr_df = calculate_mrr_stats(ranker, limit=100)

query: are cnn ratings falling
query: can polycythemia cause stroke
query: cancer of the pancreas symptoms
query: amnesty define
query: can you have a high magnesium with high calcium
processed 5 of 86
query: an electron group is defined as
query: can you get pregnant day one of your period
query: can ciprofloxacin be used to treat pneumonia
query: are most favored nation clauses legal
query: amnesty definition
processed 10 of 86
query: how large is the canadian military
query: foods that will help lower blood sugar
query: cost to install a sump pump
query: ferdinand digital release date
query: how long it takes to get a debit memo at uop
processed 15 of 86
query: how long does blood take to replenish afet lossof blood
query: how long is the typical bungee jump
query: how long is recovery from a broken humerus
query: causes of complex visual hallucinations
query: chart of good and bad cholesterol
processed 20 of 86
query: how many calories a day are lost breastfeeding
query: how much d

In [68]:
rr_df

Unnamed: 0,qid,original,degree_2,degree_3,degree_4,degree_4_split,degree_8,degree_8_split,original_irrelevant
0,0.0,0.125000,0.106061,0.138995,0.069765,0.119697,0.397500,0.069826,0.044553
1,1.0,0.083333,0.191850,0.168333,0.057706,0.100033,0.057555,0.078882,0.173504
2,2.0,0.500000,0.087302,0.097510,0.086710,0.072791,0.063782,0.157887,0.165609
3,3.0,0.071429,0.156136,0.107792,0.074496,0.191685,0.046238,0.066561,0.180717
4,4.0,0.050000,0.288383,0.142120,0.157023,0.096338,0.079176,0.077837,0.084411
...,...,...,...,...,...,...,...,...,...
81,81.0,0.071429,0.209028,0.084828,0.061150,0.128382,0.038967,0.080688,0.200650
82,82.0,0.166667,0.162202,0.108333,0.294850,0.069179,0.094428,0.056634,0.097453
83,83.0,0.142857,0.106755,0.055571,0.064624,0.078687,0.038101,0.051005,0.295871
84,84.0,0.333333,0.347619,0.061651,0.207917,0.046234,0.118189,0.058250,0.049503


In [69]:
from datetime import datetime

def get_timestamp():
    now = datetime.now()
    return now.strftime("%Y-%m-%d_%H-%M")

In [71]:
EVAL_DIR = '/run/media/root/Windows/Users/agnes/Downloads/data/msmarco/eval'

fn = 'rr_queries_model_sbert_' + get_timestamp() + '.csv'
rr_df.to_csv(os.path.join(EVAL_DIR, fn), index=None)