#### Week 2: Vector Search Applications w/ LLMs.  Authored by Chris Sanchez.

# Week 2 - Notebook 4

# Overview
We will divide the approach to building our system into several parts over two weeks:

#### Week One
* ~Part 1:~
  * ~Data ingest and preprocessing~
  * ~Convert text into vectors~
* ~Part 2~:
  * ~Index data on Weaviate database~
  * ~Search over data~
* ~Part 2.5~:
  * ~Benchmark retrieval results~

#### Week Two
* Part 3 **(THIS NOTEBOOK)**:
  * Add a reranker to the mix (new benchmark)
* Part 4:
  * Integrate with GPT-Turbo
* Part 5:
  * Display results in Streamlit

#### This notebook will cover the highlighted portion of the technical diagram below, as initially referenced in the Course content:

# PLACEHOLDER CELL FOR ARCH DIAGRAM

In [3]:
from sentence_transformers import CrossEncoder
from weaviate_interface import WeaviateClient
from typing import List, Union
import numpy as np
from loguru import logger
from reranker import ReRanker
import pandas as pd
from rich import print  # nice library that provides improved printing output (overrides default print function)
from tqdm.notebook import tqdm
import os

#load from local .env file
from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv(), override=True)

from preprocessing import FileIO

# Assignment 2.3 - Build a Hybrid Search method
**Implement a hybrid search method on the Weaviate python client**. 

#### Instructions:
- Fill in the areas of the code wherever you see a `None` statement.
- `fusion_type` is one of the hybrid search parameters.  Leave this value as [**`relativeScoreFusion`**](https://weaviate.io/blog/hybrid-search-fusion-algorithms) in your code; doing so will ensure that you are using Weaviate's preferred ranking algorithm for hybrid search.

In [None]:
#configure search constants
# (assignment placeholders: replace each None with your own value)
class_name = None
query = None
query_embedding = None

#design hybrid search query
'''
Get objects using bm25 and vector, then combine the two result sets using a fusion algorithm
(this notebook uses 'relativeScoreFusion').

Args
----
query: str
    User query.
class_name: str
    Class (index) to search.
properties: List[str]
    List of properties to search across (using BM25)
alpha: float=0.5
    Weighting factor for BM25 and Vector search.
    alpha can be any number from 0 to 1, defaulting to 0.5:
        alpha = 0 executes a pure keyword search method (BM25)
        alpha = 0.5 weighs the BM25 and vector methods evenly
        alpha = 1 executes a pure vector search method
limit: int=10
    Number of results to return.
display_properties: List[str]=None
    List of properties to return in response.
    If None, returns all properties.
'''

response = (
    client.query

    # reminder that the properties param here refers to the "display_properties"
    .get(None, None)

    # use hybrid (BM25 + vector) as our search method; "properties" limits which fields BM25 searches over
    .with_hybrid(query=None,
                 alpha=None,
                 vector=None,
                 properties=None,
                 fusion_type='relativeScoreFusion')

    # instead of "score", vector search can return a "distance" property for scoring, the smaller the distance, the more semantically similar is the result
    .with_additional(['score', 'distance', 'explainScore'])

    # limit the number of returned results
    .with_limit(None)

    # execute the search with the "do" command
    .do()
)

# To show cleaned up results we'll use the built-in format response method
print(client.format_response(response, class_name))

In [3]:

class ReRanker:
    '''
    Cross-Encoder models achieve higher performance than Bi-Encoders, 
    however, they do not scale well to large datasets. The lack of scalability
    is due to the underlying cross-attention mechanism, which is computationally
    expensive.  Thus a Bi-Encoder is best used for 1st-stage document retrieval and 
    a Cross-Encoder is used to re-rank the retrieved documents. 

    https://www.sbert.net/examples/applications/cross-encoder/README.html
    '''

    def __init__(self, model_name: str='cross-encoder/ms-marco-MiniLM-L-6-v2', local_files: bool=False):
        '''
        Loads the Cross Encoder model used for scoring.

        Args
        ----
        model_name: str
            Name of the Cross Encoder model to load.
        local_files: bool=False
            Forwarded to the model loader as "local_files_only".
        '''
        self.model_name = model_name
        self.model = CrossEncoder(self.model_name, automodel_args={'local_files_only':local_files})
        #key under which rerank() stores, sorts and filters the cross encoder score
        self.score_field = 'cross_score'

    def _cross_encoder_score(self, 
                             results: List[dict], 
                             query: str, 
                             cross_score_key: str='cross-score',
                             return_scores: bool=False
                             ) -> Union[np.ndarray, None]:
        '''
        Given a list of hits from a Retriever:
            1. Scores hits by passing query and results through CrossEncoder model. 
            2. Adds cross-score key to hits dictionary. 
            3. If desired returns np.array of Cross Encoder scores.

        Args
        ----
        results: List[dict]
            List of search results (hit dicts) from the retriever.  
            Each hit must contain a "content" field.
        query: str
            User query.
        cross_score_key: str='cross-score'
            Name of key/field that the new calculated cross encoder score will be associated with.
        return_scores: bool=False
            If True, returns a np.array of cross encoder scores. 

        Returns
        -------
        Either returns a np.array of cross encoder scores if "return_scores" is True, otherwise
        nothing is returned.  The primary purpose of this function is to update the "results" dict. 
        '''
        #nothing to score, so skip the model call entirely
        if not results:
            return np.array([], dtype=np.float32) if return_scores else None

        #build query/content list 
        #the cross encoder expects a list of [query, content] pairs (a list of lists)
        cross_input = [[query, result['content']] for result in results]
        
        #get scores
        # Output at this step will be a numpy array of cross-encoder scores, one per pair
        ######################################################################
        # Example: 
        # array([ 1.3296969 ,  0.8297793 ,  1.2054391 ,  2.9448447 ,  2.7284985 ,
        #         4.231843  , -1.6208533 ,  2.4096487 , -1.2081863 ,  2.9743905 ], dtype=float32)
        cross_scores = self.model.predict(cross_input)

        #update each hit in place with its score, stored under cross_score_key:
        #Example:
             # {'cross-score' : 5.12345}
        for result, score in zip(results, cross_scores):
            result[cross_score_key] = score

        if return_scores:
            return cross_scores

    def rerank(self, results: List[dict], query: str, top_k: int=10, threshold: Union[float, None]=None) -> List[dict]:
        '''
        Given a list of search results from the retriever, results are scored with a Cross Encoder 
        and returned in sorted order by the cross_score field.  Threshold allows user to filter out 
        results that do not meet cross_score threshold value:

        Args
        ----
        results: List[dict]
            List of search results (hit dicts) from the retriever.  Updated in place with the score.
        query: str
            User query.
        top_k: int=10
            Number of reranked results to return
        threshold: float=None
            If None, top_k results will be returned.  
            If float value is present, all results with a cross_score that meets or exceeds the threshold
            will be returned (not capped at top_k).  This arg is present to prevent very low scoring 
            documents from being returned.  If no result meets the threshold, a warning is logged 
            and the top_k results are returned instead.

        Returns
        -------
        List of reranked search results. 
        '''
        if not results:
            return []

        # score the hits in place; pass self.score_field explicitly so the key written here
        # is the same key used for sorting and filtering below
        self._cross_encoder_score(results=results, query=query, cross_score_key=self.score_field)

        #sort results by the new cross-score field, best match first
        sorted_hits = sorted(results, key=lambda hit: hit[self.score_field], reverse=True)

        #if user wants to set a threshold we need to account for that (0 is a valid threshold)
        if threshold is not None:

            #filter sorted_hits by the threshold value
            filtered_hits = [hit for hit in sorted_hits if hit[self.score_field] >= threshold]
            
            if not filtered_hits:
                logger.warning(f'No hits above threshold {threshold}. Returning top {top_k} hits.')
                return sorted_hits[:top_k]
            return filtered_hits
            
        #if no threshold was set return top_k sorted_hits
        return sorted_hits[:top_k]

### Instantiate the ReRanker instance

In [35]:
# Create the reranker with the constructor's default arguments
ranker = ReRanker()

In [43]:
# Score the hits in `response` against `query`; return_scores=True makes the call return the scores
# NOTE(review): the method is annotated to take a list of hit dicts - confirm `response` is the formatted result list
ranker._cross_encoder_score(response, query=query, return_scores=True)

array([ 1.3296969 ,  0.8297793 ,  1.2054391 ,  2.9448447 ,  2.7284985 ,
        4.231843  , -1.6208533 ,  2.4096487 , -1.2081863 ,  2.9743905 ,
        3.2194595 , -0.27501446,  1.5256095 ,  2.8193645 ,  1.5568736 ,
        2.5138354 ,  1.9419916 ,  2.6341028 , -1.6115644 , -0.49818742,
        3.695484  ,  2.93317   ,  3.1728778 , -0.5114989 , -4.076729  ],
      dtype=float32)

In [5]:
#set index name
# NOTE(review): the latency benchmark further down searches 'Impact_theory_minilm_256' - confirm which index is intended
index_name = 'impact-theory-minilm-196'

In [6]:
#set query (the user question; passed as `query` to the reranker call above)
query = "how do I change my life for good"

### Test ReRanker class by conducting a hybrid search

# Evaluation of Reranker Effect on Latency

In [None]:
import time


def time_search(limit: int,
                rerank: bool=True,
                query: str='how do I make a million dollars',
                class_name: str='Impact_theory_minilm_256'
                ) -> float:
    '''
    Times a single keyword search, optionally followed by a reranking pass.

    Args
    ----
    limit: int
        Number of results requested from the keyword search.
    rerank: bool=True
        If True, the returned hits are also reranked and that work is included in the timing.
        If False, only the search itself is timed.
    query: str
        Query used for both the search and the reranker.
    class_name: str
        Class (index) to search.

    Returns
    -------
    Elapsed wall-clock time in seconds, rounded to 3 decimal places.
    '''
    start = time.perf_counter()
    response = client.keyword_search(query, class_name, limit=limit, display_properties=['content', 'title'])
    if rerank:
        # only the elapsed time matters here, so the reranked hits are discarded
        ranker.rerank(response, query)
    elapsed = time.perf_counter() - start
    return round(elapsed, 3)

import pandas as pd

# Sweep the number of requested results and record (latency, n) pairs,
# first for search alone and then for search + reranking
result_sizes = range(10, 400, 10)
false_times = [(time_search(n, rerank=False), n) for n in tqdm(result_sizes)]
ranked_times = [(time_search(n, rerank=True), n) for n in tqdm(result_sizes)]

false_df = pd.DataFrame(false_times, columns=['time', 'n'])
ranked = pd.DataFrame(ranked_times, columns=['time', 'n'])

# Overlay both series on one set of axes; time_search reports seconds, so the axis is labelled in seconds
ax = false_df.plot.scatter(x='n', y='time', label='No Reranker')
ax2 = ranked.plot.scatter(x='n', y='time', ax=ax, color='orange', ylabel='Latency (s)', label='With Reranker', xlabel='# of Returned Results')

# Validation Dataset

In [1]:
# Relative path to the validation dataset JSON file
data_path = './data/valid_dataset.json'

In [None]:
valid = E