In [3]:
import json
from tqdm.auto import tqdm
from sentence_transformers import SentenceTransformer
from elasticsearch import Elasticsearch
import pandas as pd
from langchain.embeddings import SentenceTransformerEmbeddings
from typing import Dict
from langchain_elasticsearch import ElasticsearchRetriever

In [4]:

# Elasticsearch endpoint and the index that every search below targets.
es_url = "http://localhost:9201"
index_name = "general-questions-vector"

# Embedding model used to turn query text into a vector at search time.
embeddings = SentenceTransformerEmbeddings(
    model_name="sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
)

  embeddings = SentenceTransformerEmbeddings(model_name="sentence-transformers/multi-qa-MiniLM-L6-cos-v1")


In [5]:
# Read the ground-truth CSV and keep it as a list of per-row dicts.
ground_truth = (
    pd.read_csv('../03-vector-search/ground-truth-data.csv')
    .to_dict(orient='records')
)

In [6]:
# Show the first record to check its keys: 'question', 'course', 'document'.
ground_truth[0]

{'question': 'On what date and time does the course commence?',
 'course': 'data-engineering-zoomcamp',
 'document': '23cb47db'}

In [7]:
def hit_rate(relevance_total):
    """Return the fraction of queries that have at least one relevant result.

    Args:
        relevance_total: Sequence with one entry per query; each entry is a
            sequence of booleans, one per returned document, where ``True``
            marks the expected document.

    Returns:
        A float in ``[0, 1]``. An empty ``relevance_total`` yields ``0.0``
        instead of raising ``ZeroDivisionError``.
    """
    total = len(relevance_total)
    if total == 0:
        # Nothing was evaluated, so there is no hit to report.
        return 0.0

    hits = sum(1 for line in relevance_total if True in line)
    return hits / total

In [8]:
def mrr(relevance_total):
    """Return the mean reciprocal rank over all queries.

    For each query only the first relevant result counts: it contributes
    ``1 / rank`` (rank starting at 1). Queries without a relevant result
    contribute 0.

    Args:
        relevance_total: Sequence with one entry per query; each entry is a
            sequence of booleans, one per returned document in ranked order.

    Returns:
        A float in ``[0, 1]``. An empty ``relevance_total`` yields ``0.0``
        instead of raising ``ZeroDivisionError``.
    """
    total = len(relevance_total)
    if total == 0:
        return 0.0

    total_score = 0.0
    for line in relevance_total:
        for rank, is_relevant in enumerate(line, start=1):
            if is_relevant:
                total_score += 1 / rank
                # Stop at the first relevant hit; summing later hits as well
                # would let a single query score more than 1.
                break

    return total_score / total

In [9]:
# Elasticsearch clients keyed by URL, shared across calls. Creating a fresh
# client for every query opens a new connection pool each time and never
# closes it, which adds up over thousands of evaluation queries.
_ES_CLIENTS = {}


def elastic_search_hybrid(field, query, course, k=5, num_candidates=10000):
    """Run a hybrid (keyword + kNN) search restricted to one course.

    The request combines a ``multi_match`` query over the ``question``,
    ``text`` and ``section`` fields with a kNN query on ``field``. Both parts
    carry a boost of 0.5 and both are filtered on ``course``.

    Args:
        field: Name of the vector field to run the kNN query against.
        query: Search text; used for the keyword match and embedded with the
            module-level ``embeddings`` for the kNN part.
        course: Value for the ``course`` term filter.
        k: Number of documents to return (sent as both the kNN ``k`` and the
            request ``size``). Defaults to 5.
        num_candidates: kNN ``num_candidates`` value. Defaults to 10000.

    Returns:
        A list with the ``_source`` entry of each returned hit's metadata.
    """
    def hybrid_query(search_query: str) -> Dict:
        vector = embeddings.embed_query(search_query)  # same embeddings as for indexing
        course_filter = {"term": {"course": course}}
        return {
            "query": {
                "bool": {
                    "must": {
                        "multi_match": {
                            "query": search_query,
                            "fields": ["question", "text", "section"],
                            "type": "best_fields",
                            "boost": 0.5,
                        }
                    },
                    "filter": course_filter,
                }
            },
            "knn": {
                "field": field,
                "query_vector": vector,
                "k": k,
                "num_candidates": num_candidates,
                "boost": 0.5,
                "filter": course_filter,
            },
            "size": k,
            "_source": ["text", "section", "question", "course", "id"],
            # "rank": {"rrf": {}},
        }

    # Look up (or lazily create) the shared client for the configured URL.
    client = _ES_CLIENTS.get(es_url)
    if client is None:
        client = _ES_CLIENTS[es_url] = Elasticsearch(hosts=[es_url])

    hybrid_retriever = ElasticsearchRetriever(
        es_client=client,
        index_name=index_name,
        body_func=hybrid_query,
        content_field='text',
    )

    hybrid_results = hybrid_retriever.invoke(query)

    return [hit.metadata['_source'] for hit in hybrid_results]

In [10]:
def question_hybird(q, vector_type):
    """Run the hybrid search for one ground-truth record.

    Args:
        q: Mapping with at least the keys ``'question'`` and ``'course'``.
        vector_type: Name of the vector field to search on.

    Returns:
        Whatever ``elastic_search_hybrid`` returns for that question/course.
    """
    return elastic_search_hybrid(vector_type, q['question'], q['course'])

In [11]:
def evaluate(ground_truth, search_function, vector_type):
    """Score a search function against the ground-truth records.

    Args:
        ground_truth: Iterable of records; each must provide a ``'document'``
            key holding the id of the expected document.
        search_function: Callable invoked as ``search_function(record,
            vector_type)``; must return an iterable of dicts with an ``'id'``.
        vector_type: Passed through to ``search_function`` and echoed in the
            result.

    Returns:
        Dict with the keys ``'vector_type'``, ``'hit_rate'`` and ``'mrr'``.
    """
    per_query_relevance = []

    for record in tqdm(ground_truth):
        expected_id = record['document']
        hits = search_function(record, vector_type)
        # One boolean per returned document: is it the expected one?
        per_query_relevance.append([hit['id'] == expected_id for hit in hits])

    metrics = {'vector_type': vector_type}
    metrics['hit_rate'] = hit_rate(per_query_relevance)
    metrics['mrr'] = mrr(per_query_relevance)
    return metrics

In [12]:
# Vector fields to compare with the hybrid search.
vector_types = [
    "text_vector",
    "question_vector",
    "question_text_vector",
]

# Evaluate each field over the full ground-truth set and print its metrics.
for vector_type in vector_types:
    result = evaluate(ground_truth, question_hybird, vector_type)
    print(result)

100%|██████████| 4664/4664 [02:03<00:00, 37.88it/s]


{'vector_type': 'text_vector', 'hit_rate': 0.9187392795883362, 'mrr': 0.8196433676386513}


100%|██████████| 4664/4664 [02:01<00:00, 38.31it/s]


{'vector_type': 'question_vector', 'hit_rate': 0.9195969125214408, 'mrr': 0.8249285305889086}


100%|██████████| 4664/4664 [02:03<00:00, 37.76it/s]

{'vector_type': 'question_text_vector', 'hit_rate': 0.9223842195540308, 'mrr': 0.8255253001715269}



