In [1]:
# Reload imported modules automatically before each cell executes, so that
# edits to Python source files take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2


In [2]:
from sentence_transformers import SentenceTransformer, LoggingHandler
import numpy as np
import os
import pandas as pd

In [3]:
# Sentence-embedding model shared by the rest of the notebook
# (used for the example sentences and handed to SBERTRanker further down).
model = SentenceTransformer('bert-base-nli-mean-tokens')


In [4]:
# The first entry is a short query; the remaining three texts are the
# candidates it is compared against in the similarity cells below.
sentences = ['can polycythemia cause stroke', 
    'Thrombotic strokes can affect large or small arteries in the brain. When a thrombotic stroke occurs in a small artery deep within the brain, the stroke is called a lacunar stroke. Embolic strokes - In an embolic stroke, a blood clot or other solid mass of debris travels to the brain, where it blocks a brain artery.',
     "Menstrual bleeding usually results from a decrease in natural hormone levels about 14 days after the ovulation, if you're not pregnant. The average woman takes one month to three months to start ovulating again after stopping the pill. Sometimes ovulation may occur sooner; other times, it may take longer.f you don't get your period for some time after stopping the pill chances are that you are either pregnant (do a pregnancy test!) or you did not ovulate. Even without getting your period first there might be a chance you are pregnant.",
     "ice cream"]
# One embedding is produced per input string, in the same order as `sentences`.
sentence_embeddings = model.encode(sentences)

In [5]:
# Sanity check: there should be one embedding per input sentence.
len(sentence_embeddings)

4

In [6]:
# Dimensionality of a single embedding vector.
sentence_embeddings[1].shape

(768,)

In [7]:
# Show every input text next to the shape of the vector produced for it,
# with one blank line separating consecutive entries.
for sentence, embedding in zip(sentences, sentence_embeddings):
    print("Sentence:", sentence)
    print("Embedding:", embedding.shape, end="\n\n")

Sentence: can polycythemia cause stroke
Embedding: (768,)

Sentence: Thrombotic strokes can affect large or small arteries in the brain. When a thrombotic stroke occurs in a small artery deep within the brain, the stroke is called a lacunar stroke. Embolic strokes - In an embolic stroke, a blood clot or other solid mass of debris travels to the brain, where it blocks a brain artery.
Embedding: (768,)

Sentence: Menstrual bleeding usually results from a decrease in natural hormone levels about 14 days after the ovulation, if you're not pregnant. The average woman takes one month to three months to start ovulating again after stopping the pill. Sometimes ovulation may occur sooner; other times, it may take longer.f you don't get your period for some time after stopping the pill chances are that you are either pregnant (do a pregnancy test!) or you did not ovulate. Even without getting your period first there might be a chance you are pregnant.
Embedding: (768,)

Sentence: ice cream
Embedding: (768,)

In [8]:
# The query embedding (index 0) has the same dimensionality as the others.
sentence_embeddings[0].shape

(768,)

In [9]:
# Evaluation below

In [10]:
from sklearn.metrics.pairwise import cosine_similarity

In [11]:
# Cosine similarity between the query embedding (as a single-row matrix)
# and each of the remaining embeddings.
cosine_similarity(sentence_embeddings[0].reshape(1, -1), Y=sentence_embeddings[1:])

array([[0.6045091 , 0.4425732 , 0.12299903]], dtype=float32)

In [12]:
from scipy.spatial.distance import cdist

In [13]:
# Cosine distance from the query embedding (as a single-row matrix) to each
# of the remaining embeddings; [0] selects that single row of the result.
cdist(sentence_embeddings[0].reshape(1, -1), sentence_embeddings[1:], metric="cosine")[0]

array([0.39549089, 0.55742685, 0.87700098])

In [14]:
# CSV holding the query/document pairs to evaluate (machine-specific absolute path).
queries_file= '/run/media/root/Windows/Users/agnes/Downloads/data/msmarco/queries_od.csv'

# Loaded into the notebook-level `df`, which calculate_mrr_stats reads directly.
df = pd.read_csv(queries_file)

# Preview: the frame exposes qid, query, rel, type and doc columns.
# NOTE(review): the extra 'Unnamed: 0' column looks like a saved index - confirm.
df.head(3)

Unnamed: 0,qid,query,rel,type,doc
0,0,are cnn ratings falling,0,original,Mustard greens are also a good food choice for...
1,0,are cnn ratings falling,0,original,The only concessions Jay obtained was a surren...
2,0,are cnn ratings falling,0,original,Allen: Constitution Prevails Over President's ...


In [15]:
def calc_rr(rels_sorted, types_sorted, ttype='original', rel_label=1):
    """Mean reciprocal rank of the documents matching a relevance label and a type.

    Both inputs describe the same ranked list, best-ranked document first.
    Every position whose relevance equals ``rel_label`` and whose type equals
    ``ttype`` contributes ``1 / rank`` (ranks are 1-based), and the mean of
    those contributions is returned. Note that all matching documents count,
    not only the first one.

    Parameters
    ----------
    rels_sorted : array-like
        Relevance label of each ranked document.
    types_sorted : array-like
        Type label of each ranked document, aligned with ``rels_sorted``.
    ttype : str, default 'original'
        Type label to select.
    rel_label : int, default 1
        Relevance label to select.

    Returns
    -------
    float
        Mean of the reciprocal ranks of the selected documents.

    Raises
    ------
    ValueError
        If no document has both the requested relevance label and type.
    """
    # Coerce so that plain Python lists behave like the NumPy arrays passed
    # by calculate_mrr_stats (list == scalar would not compare elementwise).
    rels = np.asarray(rels_sorted)
    doc_types = np.asarray(types_sorted)

    # 0-based positions of documents satisfying both conditions.
    positions = np.flatnonzero((rels == rel_label) & (doc_types == ttype))
    if positions.size == 0:
        raise ValueError('no relevant docs of this type')

    # +1 converts 0-based positions into 1-based ranks.
    return np.mean(1 / (positions + 1))


In [16]:
from abc import ABC, abstractmethod

class Ranker(ABC):
    """Abstract interface for anything that can order documents for a query."""

    @abstractmethod
    def rank(self, query, docs):
        """Return the sorted indices of ``docs`` for the given ``query``."""

In [47]:
from scipy.spatial.distance import cdist
from sklearn.metrics.pairwise import cosine_similarity

class SBERTRanker(Ranker):
    """Ranker that orders documents by cosine distance between embeddings.

    Parameters
    ----------
    model
        Object with an ``encode(list_of_str)`` method returning one embedding
        per input string, in input order.
    """

    def __init__(self, model):
        self.model = model

    def rank(self, query, docs):
        """Return the indices of ``docs`` ordered from closest to farthest.

        Parameters
        ----------
        query : str
            Query text.
        docs : sequence of str
            Candidate document texts.

        Returns
        -------
        list of int
            Indices into ``docs``, sorted by ascending cosine distance to the
            query embedding. Ties keep their original relative order.
            An empty ``docs`` yields an empty list.
        """
        docs = list(docs)
        if not docs:
            return []

        # Encode with the model given to this instance; the previous code
        # silently used the notebook-level global `model` instead.
        embeddings = np.asarray(self.model.encode([query] + docs))

        # Row 0 is the query; slicing with [:1] keeps it two-dimensional.
        distances = cdist(embeddings[:1], embeddings[1:], "cosine")[0]

        # A stable argsort also works for a single document, where the old
        # squeeze() produced a 0-d array that could not be enumerated.
        return np.argsort(distances, kind="stable").tolist()

In [48]:
def calculate_mrr_stats(ranker,
                        types=['original', 'degree_2', 'degree_3', 'degree_4',
                               'degree_4_split', 'degree_8', 'degree_8_split'],
                        limit=None
                       ):
    """Rank the documents of each query and collect reciprocal-rank statistics.

    Reads the notebook-level DataFrame ``df`` (columns ``qid``, ``query``,
    ``doc``, ``rel`` and ``type``). For every distinct ``qid`` the query's
    documents are ranked with ``ranker`` and ``calc_rr`` is evaluated once per
    entry of ``types`` (relevance label 1), plus once for ``'original'``
    documents with relevance label 0.

    Parameters
    ----------
    ranker
        Object whose ``rank(query, docs)`` returns indices into ``docs`` in
        ranked order.
    types : list of str
        Document types to score; each becomes one output column.
    limit : int or None, default None
        If given, stop after this many queries.

    Returns
    -------
    pandas.DataFrame
        One row per processed query with columns
        ``['qid'] + types + ['original_irrelevant']``.

    Raises
    ------
    Exception
        Propagated from ``calc_rr`` when a query has no matching document for
        one of the requested types.
    """
    types = list(types)  # private copy; never touch the shared default
    columns = ['qid'] + types + ['original_irrelevant']

    qids = list(set(df['qid']))
    nr_qids = len(qids)

    # Rows are gathered in a plain list and turned into a frame once at the
    # end: DataFrame.append was removed in pandas 2.0 and copied the whole
    # frame on every call.
    rows = []
    for i, qid in enumerate(qids):

        if limit is not None and i >= limit:
            break

        query_df = df[df['qid'] == qid].reset_index(drop=True)
        query = query_df['query'].iloc[0]
        print('query:', query)

        indices_sorted = ranker.rank(query, query_df['doc'].values.tolist())

        # Reorder this query's rows into ranked order.
        docs_sorted = query_df.iloc[indices_sorted]
        rels = docs_sorted['rel'].values
        doc_types = docs_sorted['type'].values

        query_rr_row = [qid]
        for ttype in types:
            query_rr_row.append(calc_rr(rels, doc_types, ttype=ttype))
        # Same measure for 'original' documents labelled 0.
        query_rr_row.append(calc_rr(rels, doc_types, ttype='original',
                                    rel_label=0))
        rows.append(query_rr_row)

        if (i+1) % 5 ==  0:
            print('processed {:d} of {:d}'.format(i+1, nr_qids))

    return pd.DataFrame(rows, columns=columns)

# Wrap the sentence-embedding model loaded earlier in the ranking interface.
ranker = SBERTRanker(model)

# Smoke run restricted to the first two queries.
rr_df = calculate_mrr_stats(ranker, limit=2)

query: are cnn ratings falling
query: can polycythemia cause stroke


In [49]:
rr_df

Unnamed: 0,qid,original,degree_2,degree_3,degree_4,degree_4_split,degree_8,degree_8_split,original_irrelevant
0,0.0,0.25,0.149463,0.399731,0.076326,0.081496,0.044835,0.109524,0.03887
1,1.0,0.5,0.172222,0.078781,0.294139,0.063667,0.084675,0.066787,0.049164


In [281]:
from datetime import datetime

def get_timestamp():
    """Return the current time formatted as ``YYYY-MM-DD_HH-MM``."""
    return datetime.now().strftime("%Y-%m-%d_%H-%M")

In [282]:
# Directory that collects evaluation results (machine-specific absolute path).
EVAL_DIR = '/run/media/root/Windows/Users/agnes/Downloads/data/msmarco/eval'

# The file name carries a minute-resolution timestamp so that separate runs
# are written to separate files.
fn = 'rr_queries_od_model_sbert_' + get_timestamp() + '.csv'
# Use the EVAL_DIR constant defined above; the lowercase `eval_dir` that was
# referenced here before is never defined and raised NameError.
rr_df.to_csv(os.path.join(EVAL_DIR, fn), index=False)