In [5]:
import pickle
from elasticsearch import Elasticsearch
from sentence_transformers import SentenceTransformer

with open('fixed_ground_truth.pkl','rb') as infile:
    data = pickle.load(infile)

In [6]:
model_name = 'multi-qa-MiniLM-L6-cos-v1'
model = SentenceTransformer(model_name)
es = Elasticsearch("http://localhost:9200")

In [7]:
load_data = []
for i in data:
    d = {}
    d['title'] = i['title']
    d['text'] = i['text']
    d['timecode_text'] = i['timecode_text']
    d['description'] = i['description']
    d['id'] = i['id']
    d['title_vector'] = model.encode(i['title'])
    d['timecode_vector'] = model.encode(i['timecode_text'])
    d['text_vector'] = model.encode(i['text'])
    d['description_vector'] = model.encode(i['description'])
    load_data.append(d)




In [8]:
from elasticsearch import Elasticsearch
es_client = Elasticsearch('http://localhost:9200') 

index_settings = {
    "settings": {
        "number_of_shards": 1,
        "number_of_replicas": 0
    },
    "mappings": {
        "properties": {
            "title": {"type": "text"},
            "text": {"type": "text"},
            "timecode_text": {"type": "text"},
            "description": {"type": "keyword"},
            "id": {"type": "keyword"},
            "title_vector": {
                "type": "dense_vector",
                "dims": 384,
                "index": True,
                "similarity": "cosine"
            },
            "timecode_vector": {
                "type": "dense_vector",
                "dims": 384,
                "index": True,
                "similarity": "cosine"
            },
            "text_vector": {
                "type": "dense_vector",
                "dims": 384,
                "index": True,
                "similarity": "cosine"
            },
            "description_vector": {
                "type": "dense_vector",
                "dims": 384,
                "index": True,
                "similarity": "cosine"
            },
        }
        }
    }


index_name = "test-search"

es_client.indices.delete(index=index_name, ignore_unavailable=True)
es_client.indices.create(index=index_name, body=index_settings)

ObjectApiResponse({'acknowledged': True, 'shards_acknowledged': True, 'index': 'test-search'})

In [9]:
for doc in load_data:
    try:
        es_client.index(index=index_name, document=doc)
    except Exception as e:
        print(e)

In [10]:

def knn_query(question):
    return  {
        "field": "text_vector",
        "query_vector": model.encode(question),
        "k": 5,
        "num_candidates": 10000,
        "boost": 0.5,
        
    }

In [11]:
def keyword_query(question):
    return {
        "bool": {
            "must": {
                "multi_match": {
                    "query": f"{question}",
                    "fields": ["description^3", "text", "title"],
                    "type": "best_fields",
                    "boost": 0.5,
                }
            },
        }
    }

In [14]:
def multi_search(key_word):
    response = es_client.search(
        index=index_name,
        query=keyword_query(key_word),
        knn=knn_query(key_word),
        size=10
    )
    return response["hits"]["hits"]

In [15]:
from tqdm.auto import tqdm
relevance_total = []
for q in tqdm(data):
    doc_id = q['id']
    results = multi_search(q['student_question'])
    relevance = [d["_source"]['id'] == doc_id for d in results]
    relevance_total.append(relevance)

  0%|          | 0/837 [00:00<?, ?it/s]

In [16]:
def hit_rate(relevance_total):
    cnt = 0

    for line in relevance_total:
        if True in line:
            cnt = cnt + 1

    return cnt / len(relevance_total)
print(f"Hit rate is: {hit_rate(relevance_total)}")

Hit rate is: 0.5579450418160096


In [17]:
def mrr(relevance_total):
    total_score = 0.0

    for line in relevance_total:
        for rank in range(len(line)):
            if line[rank] == True:
                total_score = total_score + 1 / (rank + 1)

    return total_score / len(relevance_total)
print(f"MRR is: {mrr(relevance_total)}")

MRR is: 0.737086818000797


In [21]:
# import hashlib
# def gen_ids(data):
#     for rec in data:
#         unique_id = rec['vid_id'].strip()+rec['timecode'].strip()
#         hash_object = hashlib.md5(unique_id.encode())
#         hash_hex = hash_object.hexdigest()
#         rec['id'] = hash_hex
#     return data
# data_with_ids = gen_ids(data)