In [10]:
# always reload an imported module before executing a particular cell
# (used to let changes in python files take effect without restarting kernel)
%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
from sentence_transformers import SentenceTransformer, LoggingHandler
import numpy as np
import os
import pandas as pd
from gensim.test.utils import common_texts
from gensim.models.doc2vec import Doc2Vec, TaggedDocument
import nltk
from nltk.corpus import stopwords


In [14]:
# Absolute, machine-specific location of the queries3.csv file read below.
queries_file= '/run/media/root/Windows/Users/agnes/Downloads/data/msmarco/queries3.csv'

# NOTE(review): error_bad_lines=False presumably makes pandas skip malformed rows
# instead of raising; newer pandas releases spell this on_bad_lines='skip' --
# confirm against the installed pandas version.
df = pd.read_csv(queries_file, error_bad_lines=False)

# Preview the first three rows of the loaded frame.
df.head(3)

Unnamed: 0,qid,query,rel,type,doc
0,0,aaa a common cause of a skid is,0,original,Discounts and benefits are available at all He...
1,0,aaa a common cause of a skid is,0,original,AAA North Penn provides Authorized On - Line P...
2,0,aaa a common cause of a skid is,0,original,â¢ EAP-TLS authentication takes place between...


In [21]:
from abc import ABC, abstractmethod

class Ranker(ABC):
    """Abstract base class for objects that order a list of documents for a query."""

    @abstractmethod
    def rank(self, query, docs):
        """
        Return sorted indices

        :param query: The query text
        :param docs: The candidate documents to order
        :return: Positions into ``docs`` in ranked order (first entry = top-ranked document)
        """
        pass

    @staticmethod
    def sims_to_indices(sims):
        """
        Return the positions that sort ``sims`` in ascending order.

        The subclasses in this notebook pass cosine *distances*, so ascending order
        puts the closest document first. Ties keep their original relative order.

        :param sims: Scores as a 1-D sequence/array, or an array with extra
                     length-1 axes such as the ``(1, n)`` result of ``cdist``
        :return: A list of integer positions into ``sims``
        """
        # Flatten instead of squeeze(): squeeze() turns a single score into a 0-d
        # array that cannot be iterated, which broke ranking of a single document.
        scores = np.asarray(sims).reshape(-1)
        # A stable sort reproduces the tie behaviour of Python's sorted().
        return np.argsort(scores, kind='stable').tolist()

In [22]:
from scipy.spatial.distance import cdist
from sklearn.metrics.pairwise import cosine_similarity

class SBERTRanker(Ranker):
    """Ranker that orders documents by cosine distance between sentence embeddings."""

    def __init__(self, model):
        """
        :param model: Encoder object whose ``encode`` method takes a list of texts
                      and returns one embedding per text
        """
        self.model = model

    def rank(self, query, docs):
        """
        Order ``docs`` by increasing cosine distance to ``query``.

        :param query: The query text
        :param docs: The candidate document texts
        :return: Positions into ``docs``, closest document first; ``[]`` when
                 there are no documents
        """
        docs = list(docs)
        if not docs:
            # Nothing to compare the query against.
            return []

        # Encode with the instance's own encoder. Reading the notebook-global
        # `model` here would silently ignore the constructor argument.
        embeddings = np.asarray(self.model.encode([query] + docs))

        # Row 0 is the query; the remaining rows are the documents.
        dists = cdist(embeddings[:1], embeddings[1:], "cosine")[0]
        return self.sims_to_indices(dists)

In [12]:
# Instantiate the sentence encoder; `model` is the object passed to SBERTRanker elsewhere in this notebook.
model = SentenceTransformer('roberta-large-nli-stsb-mean-tokens')

In [26]:
# Wrap the encoder in the Ranker interface for a quick manual check.
sbert_ranker = SBERTRanker(model)

In [54]:
# Smoke test: the candidate identical to the query (position 2) should come first.
sbert_ranker.rank('i saw a cat', ['the cat was seen by me', 'pineapple on pizza is bad', 'i saw a cat'])

[2, 0, 1]

In [49]:
class Doc2vecRanker(Ranker):
    """Ranker that orders documents by cosine distance between inferred Doc2Vec vectors."""

    # English stop words removed during preprocessing. While the class body runs,
    # the right-hand `stopwords` still resolves to the module-level nltk import.
    stopwords = set(stopwords.words('english'))

    def __init__(self, model_fn):
        """
        :param model_fn: Path of a saved Doc2Vec model, handed to ``Doc2Vec.load``
        """
        self.model = Doc2Vec.load(model_fn)

    def _preprocess_text(self, text):
        """
        Tokenize ``text``, keep purely alphabetic tokens, lowercase them and
        drop English stop words.

        :param text: Raw text
        :return: List of remaining lowercase tokens
        """
        tokens = (tok.lower() for tok in nltk.word_tokenize(text) if tok.isalpha())
        return [tok for tok in tokens if tok not in self.stopwords]

    def rank(self, query, docs):
        """
        Order ``docs`` by increasing cosine distance to ``query``.

        The raw distances are printed before the ranking is returned.

        :param query: The query text
        :param docs: The candidate document texts
        :return: Positions into ``docs``, closest document first; ``[]`` when
                 there are no documents
        """
        docs = list(docs)
        if not docs:
            # cdist needs a 2-D document matrix; an empty list would give a 1-D
            # array and raise, so short-circuit instead.
            return []

        query_vector = self.model.infer_vector(self._preprocess_text(query))
        doc_vectors = np.array(
            [self.model.infer_vector(self._preprocess_text(d)) for d in docs]
        )

        # Compare the query as a single-row matrix against the document matrix.
        dists = cdist(query_vector.reshape(1, -1), doc_vectors, "cosine")[0]

        # Show the raw distances alongside the ranking.
        print(dists)
        return self.sims_to_indices(dists)

In [50]:
# Build the Doc2Vec-based ranker from a saved model at an absolute, machine-specific path.
d2v_ranker = Doc2vecRanker('/run/media/root/Windows/Users/agnes/Downloads/data/msmarco/train_results/doc2vec2/d2v_2')

In [53]:
# Smoke test: rank three candidates against a short query (rank() also prints the raw distances).
d2v_ranker.rank('i saw a cat', ['the cat was seen by me', 'cat cat cat', 'i saw a cat'])

[1.02262315 0.9711936  0.        ]


[2, 1, 0]

In [None]:
# Disabled helper: the whole definition below sits inside a string literal, so
# running this cell defines nothing (lcs_norm_word is not available afterwards).
"""
# Compute the normalized LCS given an answer text and a source text
def lcs_norm_word(answer_text, source_text):
    '''Computes the longest common subsequence of words in two texts; returns a normalized value.
       :param answer_text: The pre-processed text for an answer text
       :param source_text: The pre-processed text for an answer's associated source text
       :return: A normalized LCS value'''
    
    answer_text = answer_text.split()
    source_text = source_text.split()
    
    m = np.zeros((len(answer_text)+1, len(source_text)+1))

    for i in range(1, len(source_text)+1):
        for j in range(1, len(answer_text)+1):
            if source_text[i-1] == answer_text[j-1]:
                m[j,i] = m[j-1,i-1] +1
            else:
                m[j,i] = max(m[j-1,i], m[j,i-1])
    return m[len(answer_text),len(source_text)] / len(answer_text)

"""

In [7]:
def calc_rr(rels_sorted, types_sorted, ttype='original', rel_label=1):
    """
    Mean reciprocal rank of the documents that have a given relevance label and type.

    Both inputs are in ranked order, so position 0 is rank 1. Every position whose
    relevance equals ``rel_label`` and whose type equals ``ttype`` contributes
    ``1 / rank``; the mean over those positions is returned.

    :param rels_sorted: Relevance labels in ranked order (array or plain sequence)
    :param types_sorted: Document types in ranked order, aligned with ``rels_sorted``
    :param ttype: Document type to select
    :param rel_label: Relevance label to select
    :return: Mean of the reciprocal ranks of the selected documents
    :raises ValueError: If no document has both the requested label and type
    """
    # Coerce to arrays so plain Python lists get element-wise comparison and
    # fancy indexing too, instead of failing.
    rels = np.asarray(rels_sorted)
    kinds = np.asarray(types_sorted)

    # Zero-based ranked positions carrying the requested relevance label ...
    positions = np.nonzero(rels == rel_label)[0]
    # ... narrowed to those that also have the requested type.
    positions = positions[kinds[positions] == ttype]

    if positions.size == 0:
        # ValueError is still caught by callers that handle a bare Exception.
        raise ValueError('no relevant docs of this type')
    return np.mean(1 / (positions + 1))


In [8]:
def calculate_mrr_stats(ranker,
                        types=['original', 'degree_2', 'degree_3', 'degree_4',
                               'degree_4_split', 'degree_8', 'degree_8_split'],
                        limit=None,
                        data=None
                       ):
    """
    Rank every query's documents with ``ranker`` and collect per-type reciprocal ranks.

    For each query id the documents are ranked, then ``calc_rr`` is evaluated once
    per entry of ``types`` (default relevance label) and once more for
    ``'original'`` documents with ``rel_label=0``.

    :param ranker: Object with a ``rank(query, docs)`` method returning positions into ``docs``
    :param types: Document types to score; each becomes one result column
    :param limit: Stop after this many query ids (``None`` processes all of them)
    :param data: Frame with ``qid``, ``query``, ``doc``, ``rel`` and ``type`` columns;
                 defaults to the notebook-level ``df``
    :return: DataFrame with columns ``qid``, one per type, and ``original_irrelevant``;
             ``qid`` keeps its own dtype because rows are no longer passed through a
             mixed float Series
    :raises Exception: Propagated from ``calc_rr`` when a query has no documents of
                       one of the requested types
    """
    if data is None:
        # Fall back to the frame loaded earlier in the notebook.
        data = df

    # Copy so any iterable is accepted and the shared default list is never touched.
    types = list(types)
    columns = ['qid'] + types + ['original_irrelevant']

    qids = list(set(data['qid']))
    nr_qids = len(qids)

    # Collect plain rows and build the frame once at the end: appending to a
    # DataFrame inside the loop copies it every iteration, and DataFrame.append
    # no longer exists in recent pandas.
    rows = []
    for i, qid in enumerate(qids):

        if limit is not None and i >= limit:
            break

        query_df = data[data['qid'] == qid].reset_index(drop=True)
        query = query_df['query'].iloc[0]
        print('query:', query)

        indices_sorted = ranker.rank(query, query_df['doc'].values.tolist())

        # Reorder the query's documents into ranked order.
        docs_sorted = query_df.iloc[indices_sorted]
        rels = docs_sorted['rel'].values
        doc_types = docs_sorted['type'].values

        row = [qid]
        row.extend(calc_rr(rels, doc_types, ttype=ttype) for ttype in types)
        # Extra column: how highly non-relevant 'original' documents are ranked.
        row.append(calc_rr(rels, doc_types, ttype='original', rel_label=0))
        rows.append(row)

        if (i+1) % 5 == 0:
            print('processed {:d} of {:d}'.format(i+1, nr_qids))

    return pd.DataFrame(rows, columns=columns)


In [9]:
# Evaluate the SBERT ranker; limit=400 stops the loop after 400 query ids.
ranker = SBERTRanker(model)

rr_df = calculate_mrr_stats(ranker, limit=400)

query: aaa a common cause of a skid is
query: chionophobia is an abnormal fear of what
query: abcmouse price
query: abnormal condition of blood in a joint
query: about how long can a bald eagle live?
processed 5 of 20117
query: 1 cup of strawberries nutrition facts
query: sprint pay by phone telephone number
query: is lameness a type of abnormal locomotion
query: what is pica condition
query: . in what kind of government does a small group have a firm control over a country? brainly
processed 10 of 20117
query: aaa templin phone number
query: according to the amdr what percentage of your daily intake should be protein
query: how to care for goji berry plants
query: how to become a firefighter in pa
query: how costly is it to remove mold from crawl space
processed 15 of 20117
query: do raccoons eat japanese beetles
query: what is a raccoon dog
query: how to calculate growth of dividend
query: how to cancel your zipcar membership
query: what helps get rid of moisture in basement
processe

query: how mych does golds gym membership cost in el paso?
query: krewe of femme fatale membership cost
query: what qualifications do a gym general manager needs
processed 180 of 20117
query: gold gym membership cost
query: how much is a anytime fitness membership
query: monthly cost of a horse
query: adt average monthly cost
query: which cost is an example of a variable cost?
processed 185 of 20117
query: cost of planet fitness membership
query: what's the cost of a costco executive membership
query: average entertainment cost per month
query: monthly cost of avastin cancer treatment
query: monthly t1 line cost
processed 190 of 20117
query: what are david barton gyms
query: usga basic membership cost
query: monthly cost for hulu
query: types of membership models
query: orange theory fitness membership cost
processed 195 of 20117
query: how much is membership at a fitness club
query: average human stomach size
query: how much do gold's gym employees make
query: how much is a kickboxing

processed 350 of 20117
query: how to be a child psychologist
query: average salary human resources recruiter
query: average salary first year lawyer
query: what units are used to measure both velocity and speed
query: average speed on a bike
processed 355 of 20117
query: typing speed average
query: schooling you need to become a child psychologist
query: if quantity demanded goes down what happens to total revenue
query: is acceleration considered a vector or scalar
query: educational psychologist salary
processed 360 of 20117
query: average income clinical psychologist
query: most probable speed formula for gases
query: what is the average starting salary for an art therapist?
query: average salary for a school psychologist
query: average vs instantaneous velocity
processed 365 of 20117
query: average salary medical oncology
query: average salary clinical data coordinator
query: what is the relationship between displacement and the spring constant k
query: average hourly rate for phil

In [10]:
# Display the per-query reciprocal-rank table.
rr_df

Unnamed: 0,qid,original,degree_2,degree_3,degree_4,degree_4_split,degree_8,degree_8_split,original_irrelevant
0,0.0,0.500000,0.206250,0.136905,0.296259,0.057528,0.040323,0.066947,0.046494
1,1.0,1.000000,0.213745,0.105420,0.091106,0.148611,0.055914,0.078646,0.039379
2,2.0,0.071429,0.051042,0.054717,0.078517,0.169338,0.182870,0.085419,0.191220
3,3.0,1.000000,0.263278,0.064440,0.057346,0.068040,0.148898,0.089815,0.040191
4,4.0,1.000000,0.200505,0.178030,0.049946,0.114469,0.077782,0.074963,0.038252
...,...,...,...,...,...,...,...,...,...
395,395.0,0.111111,0.110227,0.081496,0.052324,0.047498,0.418779,0.156481,0.063808
396,396.0,0.083333,0.117469,0.060886,0.429825,0.101190,0.066011,0.108352,0.058816
397,397.0,1.000000,0.138492,0.189732,0.054852,0.129167,0.046634,0.136555,0.038384
398,398.0,1.000000,0.134733,0.172393,0.100603,0.056061,0.144037,0.075956,0.044208


In [11]:
from datetime import datetime

def get_timestamp():
    """Return the current time from ``datetime.now()`` formatted as ``YYYY-MM-DD_HH-MM``."""
    return datetime.now().strftime("%Y-%m-%d_%H-%M")

In [None]:
# Directory that receives the evaluation results (absolute, machine-specific path).
EVAL_DIR = '/run/media/root/Windows/Users/agnes/Downloads/data/msmarco/eval'

# The file name embeds the timestamp returned by get_timestamp().
fn = 'rr_queries3_od_model_roberta_400_' + get_timestamp() + '.csv'
# NOTE(review): index=None presumably suppresses the row index like index=False -- confirm.
rr_df.to_csv(os.path.join(EVAL_DIR, fn), index=None)