In [1]:
from pymilvus import MilvusClient, DataType
from openai import OpenAI, Embedding
from sentence_transformers import SentenceTransformer
import pandas as pd
import time
import numpy as np
import json


In [2]:
# Client for the Milvus server; every search below goes through this handle.
milvus_uri = "http://localhost:19530"
milvus_client = MilvusClient(uri=milvus_uri)


In [3]:
# Evaluation set: the last 50 rows of the query / ground-truth CSV.
# The read and the slice stay chained so no reference to the full frame is kept.
queries_csv_path = "../queries_for_documents.csv"
queries = pd.read_csv(queries_csv_path).tail(50)

In [4]:
# Number of evaluation queries (row count of the frame).
queries.shape[0]

50

In [5]:
### BERT Embeddings
# Alternative checkpoint that can be swapped in: "sentence-transformers/all-mpnet-base-v2"
model_name = "all-MiniLM-L6-v2"
model = SentenceTransformer(model_name)


def get_embeddings_from_bert(sentence):
    """Return the embedding of a single sentence.

    The sentence is wrapped in a one-element batch for ``model.encode`` and the
    first (only) entry of the encoded batch is returned.
    """
    batch = [sentence]
    encoded_batch = model.encode(batch)
    return encoded_batch[0]





In [6]:
# Peek at the first five evaluation rows.
queries.head(n=5)

Unnamed: 0,query,doc_index
25,Who are the plaintiffs involved in the Mazzaga...,556
26,How much of a hit did Best Buy's stock take af...,302
27,How are medicare pricing adjustments affecting...,312
28,How much has twitter been devalued since rebra...,314
29,What effect is China's imports on crude oil ha...,545


In [7]:
def compute_score(row, p, metric_type="COSINE"):
    """Score one query against the collection built with dropout rate ``p``.

    The query text is embedded with ``get_embeddings_from_bert``, the five
    nearest entries are fetched from the ``yahoo_finance_article_DROPOUT_<pct>``
    collection, and the labelled document id is looked up among the returned ids.

    Args:
        row: Mapping-like record (e.g. a DataFrame row) providing ``"query"``
            (the query text) and ``"doc_index"`` (the expected document id).
        p: Dropout rate as a fraction; ``round(p * 100)`` selects the collection.
        metric_type: Metric name forwarded to the search call. Defaults to
            ``"COSINE"``; pass e.g. ``"IP"`` to search by inner product.

    Returns:
        Tuple ``(top1, top3, top5, exec_time)``. Each ``topK`` is 1 if the
        expected id is among the first K returned ids, else 0. ``exec_time`` is
        the duration of the search call alone, in seconds (embedding the query
        is not included).
    """
    query, ground_truth_document_id = row["query"], row["doc_index"]

    # round() rather than int(): p * 100 carries float error (0.29 * 100 ==
    # 28.999999999999996), and truncation would pick the wrong collection.
    collection_name = f"yahoo_finance_article_DROPOUT_{round(p * 100)}"
    embedded_query = get_embeddings_from_bert(query)

    # perf_counter() is monotonic and high-resolution, which matters for
    # millisecond-scale searches; time.time() is a coarse wall clock.
    start = time.perf_counter()
    res = milvus_client.search(
        collection_name=collection_name,
        data=[embedded_query],
        limit=5,
        search_params={"metric_type": metric_type, "params": {}},
    )
    exec_time = time.perf_counter() - start

    # One query was submitted, so only the first result list is relevant.
    # An empty response is treated as "no hits" instead of raising.
    hits = res[0] if res else []
    ranked_ids = [hit["id"] for hit in hits]

    top1, top3, top5 = (
        1 if ground_truth_document_id in ranked_ids[:k] else 0 for k in (1, 3, 5)
    )

    return (top1, top3, top5, exec_time)

In [8]:
# Score every query against each dropout-variant collection. Each pass adds one
# column holding the (top1, top3, top5, exec_time) tuple returned per row.
for p in [0, 0.1, 0.3, 0.5, 0.7, 0.9]:
    score_column = f"scores_{p}"
    queries[score_column] = queries.apply(compute_score, axis=1, args=(p,))



In [9]:
# Preview the scored frame; showing a bounded slice keeps the saved output small.
queries.head(10)

Unnamed: 0,query,doc_index,scores_0,scores_0.1,scores_0.3,scores_0.5,scores_0.7,scores_0.9
25,Who are the plaintiffs involved in the Mazzaga...,556,"(1, 1, 1, 0.0217742919921875)","(1, 1, 1, 0.00596165657043457)","(0, 1, 1, 0.007549762725830078)","(0, 0, 0, 0.006362438201904297)","(1, 1, 1, 0.0060079097747802734)","(1, 1, 1, 0.006003141403198242)"
26,How much of a hit did Best Buy's stock take af...,302,"(0, 1, 1, 0.006374835968017578)","(1, 1, 1, 0.006056070327758789)","(0, 1, 1, 0.006224393844604492)","(0, 0, 0, 0.006007671356201172)","(0, 0, 0, 0.007529258728027344)","(0, 0, 0, 0.006522178649902344)"
27,How are medicare pricing adjustments affecting...,312,"(1, 1, 1, 0.006029844284057617)","(1, 1, 1, 0.0045473575592041016)","(1, 1, 1, 0.0060272216796875)","(1, 1, 1, 0.006700277328491211)","(1, 1, 1, 0.005166292190551758)","(0, 0, 0, 0.004997968673706055)"
28,How much has twitter been devalued since rebra...,314,"(1, 1, 1, 0.0055086612701416016)","(0, 0, 1, 0.0053081512451171875)","(0, 0, 1, 0.00589442253112793)","(0, 1, 1, 0.005031108856201172)","(0, 0, 1, 0.006310701370239258)","(0, 1, 1, 0.0064618587493896484)"
29,What effect is China's imports on crude oil ha...,545,"(0, 0, 0, 0.00652623176574707)","(0, 0, 0, 0.0044286251068115234)","(1, 1, 1, 0.005524635314941406)","(1, 1, 1, 0.004998207092285156)","(0, 0, 0, 0.005009889602661133)","(0, 1, 1, 0.007614850997924805)"
30,Why is Brazil looking to make income tax cuts ...,12,"(1, 1, 1, 0.00732111930847168)","(1, 1, 1, 0.006031036376953125)","(1, 1, 1, 0.006000518798828125)","(1, 1, 1, 0.007057905197143555)","(1, 1, 1, 0.005017280578613281)","(1, 1, 1, 0.0069332122802734375)"
31,What background checking service do large firm...,25,"(0, 1, 1, 0.007109880447387695)","(0, 0, 0, 0.005045652389526367)","(0, 0, 1, 0.005025625228881836)","(1, 1, 1, 0.006009817123413086)","(0, 1, 1, 0.004525184631347656)","(0, 1, 1, 0.004999399185180664)"
32,How could President-elect Donald Trump's propo...,238,"(1, 1, 1, 0.006033182144165039)","(0, 0, 1, 0.006503582000732422)","(0, 1, 1, 0.006495952606201172)","(1, 1, 1, 0.00570225715637207)","(0, 0, 0, 0.005112648010253906)","(0, 1, 1, 0.005901336669921875)"
33,How is scaled solutions providing value to cli...,70,"(0, 1, 1, 0.006047725677490234)","(1, 1, 1, 0.005007028579711914)","(0, 0, 0, 0.005387306213378906)","(1, 1, 1, 0.005506038665771484)","(0, 0, 0, 0.0050046443939208984)","(0, 0, 1, 0.005006551742553711)"
34,How big of a problem was it for Symbotic to de...,121,"(0, 1, 1, 0.006068706512451172)","(1, 1, 1, 0.004998683929443359)","(0, 1, 1, 0.0050563812255859375)","(0, 1, 1, 0.006298065185546875)","(0, 1, 1, 0.007357120513916016)","(0, 1, 1, 0.0051555633544921875)"


In [10]:
# For each dropout rate, transpose the per-query score tuples into one sequence
# per tuple position and average each sequence.
scores_percentage = {
    str(p): tuple(
        sum(column) / len(column) for column in zip(*queries[f"scores_{p}"])
    )
    for p in [0, 0.1, 0.3, 0.5, 0.7, 0.9]
}

In [11]:
# Keys are the dropout rates as strings. Each value averages the per-query
# (top1, top3, top5, exec_time) tuples, i.e. three hit rates followed by the
# mean search time in seconds.
scores_percentage

{'0': (0.56, 0.84, 0.86, 0.00627105712890625),
 '0.1': (0.56, 0.74, 0.82, 0.005938982963562012),
 '0.3': (0.54, 0.74, 0.84, 0.006064205169677734),
 '0.5': (0.5, 0.7, 0.72, 0.005978379249572754),
 '0.7': (0.44, 0.64, 0.68, 0.005655636787414551),
 '0.9': (0.2, 0.44, 0.52, 0.006411447525024414)}

In [46]:
# Per-dropout-rate totals of the (top1, top3, top5, exec_time) tuples:
# three hit counts followed by the summed search time.
# NOTE(review): this cell's saved execution count is out of sequence and its
# stored totals do not match the averages in the preceding cell; re-run the
# notebook top to bottom before relying on these numbers.
sums = {
    str(p): tuple(sum(column) for column in zip(*queries[f"scores_{p}"]))
    for p in [0, 0.1, 0.3, 0.5, 0.7, 0.9]
}
sums

{'0': (35, 43, 45, 0.22560715675354004),
 '0.1': (29, 40, 42, 0.22805333137512207),
 '0.3': (30, 43, 44, 0.21669435501098633),
 '0.5': (34, 40, 41, 0.21144366264343262),
 '0.7': (24, 30, 32, 0.22225189208984375),
 '0.9': (17, 27, 30, 0.22592711448669434)}