In [34]:
from sentence_transformers import CrossEncoder
from opensearch_interface import OpenSearchClient
from typing import List, Union
import numpy as np
from loguru import logger
from reranker import ReRanker

In [3]:

class ReRanker:
    '''
    Second-stage re-ranker that scores (query, document) pairs with a Cross-Encoder.

    Cross-Encoder models achieve higher performance than Bi-Encoders, 
    however, they do not scale well to large datasets. The lack of scalability
    is due to the underlying cross-attention mechanism, which is computationally
    expensive.  Thus a Bi-Encoder is best used for 1st-stage document retrieval and 
    a Cross-Encoder is used to re-rank the retrieved documents. 

    https://www.sbert.net/examples/applications/cross-encoder/README.html

    Attributes
    ----------
    model_name: name of the Cross-Encoder model that was loaded.
    model: the loaded CrossEncoder instance; only its predict method is used here.
    score_field: key under which rerank stores (and sorts by) the cross encoder score.
    '''

    def __init__(self, model_name: str='cross-encoder/ms-marco-MiniLM-L-6-v2', local_files: bool=False):
        '''
        Args
        ----
        model_name: str
            Name (or path) of the Cross-Encoder model handed to CrossEncoder.
        local_files: bool=False
            Forwarded as 'local_files_only' inside automodel_args.
            NOTE(review): the accepted constructor keywords depend on the installed
            sentence-transformers version -- confirm before upgrading.
        '''
        self.model_name = model_name
        self.model = CrossEncoder(self.model_name, automodel_args={'local_files_only':local_files})
        self.score_field = 'cross_score'

    @staticmethod
    def _hit_content(hit: dict, content_field: str='content') -> str:
        '''
        Pulls the text to be scored out of a single search hit.

        NOTE(review): the exact hit layout is not visible from this class.  This helper
        assumes the text lives under "content", either nested in the hit's "_source"
        dict (raw OpenSearch hit) or directly on the hit (flattened hit) -- confirm 
        against what OpenSearchClient actually returns.

        Raises
        ------
        KeyError if the content field is found in neither location.
        '''
        source = hit.get('_source')
        if isinstance(source, dict) and content_field in source:
            return source[content_field]
        return hit[content_field]

    def _cross_encoder_score(self, 
                             results: List[dict], 
                             query: str, 
                             cross_score_key: str='cross-score',
                             return_scores: bool=False
                             ) -> Union[np.ndarray, None]:
        '''
        Given a list of hits from a Retriever:
            1. Scores hits by passing query and results through CrossEncoder model. 
            2. Adds cross-score key to hits dictionary. 
            3. If desired returns np.ndarray of Cross Encoder scores.

        Args
        ----
        results: List[dict]
            List of search results from OpenSearch client.  Each dict is updated in place.
        query: str
            User query.
        cross_score_key: str='cross-score'
            Name of key/field that the new calculated cross encoder score will be associated with.
        return_scores: bool=False
            If True, returns a np.ndarray of cross encoder scores. 

        Returns
        -------
        Either returns a np.ndarray of cross encoder scores (one per result, in the same order
        as "results") if "return_scores" is True, otherwise nothing is returned.  
        The primary purpose of this function is to update the "results" dict. 
        '''
        #nothing to score: skip the model call entirely
        if not results:
            return np.array([], dtype=np.float32) if return_scores else None

        #build query/content list 
        #the model expects a list of [query, content] pairs, one pair per result
        cross_input = [[query, self._hit_content(result)] for result in results]

        #get scores: one score per [query, content] pair
        cross_scores = np.asarray(self.model.predict(cross_input))

        #store each score on its hit as a plain Python float so hits stay easy to serialize
        #Example:
             # {'cross-score' : 5.12345}
        for result, score in zip(results, cross_scores):
            result[cross_score_key] = float(score)

        if return_scores:
            return cross_scores
        return None

    def rerank(self, results: List[dict], query: str, top_k: int=10, threshold: Union[float, None]=None) -> List[dict]:
        '''
        Given a list of search results from OpenSearch client, results are scored with a Cross Encoder 
        and returned in sorted order (highest score first) by the cross_score field.  Threshold allows 
        user to filter out results that do not meet cross_score threshold value:

        Args
        ----
        results: List[dict]
            List of search results from OpenSearch client.  Each dict gains a cross_score entry.
        query: str
            User query.
        top_k: int=10
            Number of reranked results to return
        threshold: float=None
            If None, top_k results will be returned.  
            If float value is present, only results with a cross_score that meet or exceed the threshold
            will be returned (all of them; top_k is not applied to the filtered list).  This arg is 
            present to prevent very low scoring document from being returned.  If no result meets the 
            threshold a warning is logged and the top_k results are returned instead.

        Returns
        -------
        List of reranked search results. 
        '''
        if not results:
            return []

        #score under self.score_field so that scoring, sorting and filtering all agree on one key
        #(the default key of _cross_encoder_score is spelled differently: 'cross-score')
        self._cross_encoder_score(results, query, cross_score_key=self.score_field)

        #sort results by the new cross_score field, best hit first
        sorted_hits = sorted(results, key=lambda hit: hit[self.score_field], reverse=True)

        #if user wants to set a threshold we need to account for that (0 is a valid threshold)
        if threshold is not None:

            #filter sorted_hits by the threshold value
            filtered_hits = [hit for hit in sorted_hits if hit[self.score_field] >= threshold]
            
            if not filtered_hits:
                logger.warning(f'No hits above threshold {threshold}. Returning top {top_k} hits.')
                return sorted_hits[:top_k]
            return filtered_hits
            
        #if no threshold was set return top_k sorted_hits
        return sorted_hits[:top_k]

### Instantiate the ReRanker

In [35]:
# Build a ReRanker using its default constructor arguments.
# NOTE(review): `ReRanker` is both imported from `reranker` in the import cell and
# defined as a class in this notebook; whichever ran last in the kernel is used here.
ranker = ReRanker()

In [43]:
# Score every hit in `response` against `query` and display the returned score array.
# NOTE(review): `response` and `query` are assigned in cells that appear further down
# this notebook, so this cell fails on a fresh kernel unless those cells run first.
ranker._cross_encoder_score(response, query=query, return_scores=True)

array([ 1.3296969 ,  0.8297793 ,  1.2054391 ,  2.9448447 ,  2.7284985 ,
        4.231843  , -1.6208533 ,  2.4096487 , -1.2081863 ,  2.9743905 ,
        3.2194595 , -0.27501446,  1.5256095 ,  2.8193645 ,  1.5568736 ,
        2.5138354 ,  1.9419916 ,  2.6341028 , -1.6115644 , -0.49818742,
        3.695484  ,  2.93317   ,  3.1728778 , -0.5114989 , -4.076729  ],
      dtype=float32)

In [5]:
#set index name (passed to the hybrid search call later in the notebook)
index_name = 'impact-theory-minilm-196'

In [6]:
#set query (used for both the hybrid search and the cross encoder scoring)
query = "how do I change my life for good"

### Test the ReRanker class by conducting a hybrid search and reranking the results

In [13]:
#create new OpenSearch client
# The string argument is a sentence-transformers model name
# (presumably the model used to embed queries -- verify against OpenSearchClient).
osclient = OpenSearchClient('sentence-transformers/all-MiniLM-L6-v2')
# Display the table of indexes available on the cluster.
osclient.show_indexes()

health status index                              uuid                   pri rep docs.count docs.deleted store.size pri.store.size
yellow open   kw-impact-theory                   2MjMun4bQYOoeUpv5UsJxg   3   1      33164            0     29.4mb         29.4mb
yellow open   semantic-impact-theory-196         SY2nXyvmQ9i5LAS4hmn82g   3   1      37007            0    694.6mb        694.6mb
yellow open   kw-impact-theory-196               vsuHausxRb6EjysQriOX5w   3   1      37007            0     30.5mb         30.5mb
yellow open   security-auditlog-2023.10.21       Vj43Da3dTQm0mwBFNWHjCg   1   1          9            0    151.3kb        151.3kb
yellow open   security-auditlog-2023.10.22       YXYp6DkYT-aLgGxRZNGsUA   1   1       1704            0      1.6mb          1.6mb
yellow open   paul-graham3                       -74ZPvxoSMmtCPSzAI9o1A   1   1         18            0    768.2kb        768.2kb
yellow open   security-auditlog-2023.10.25       1Cn9t6VhT227XHl2KJJ-WQ   1   1        852

In [40]:
#conduct hybrid search
# The same index name is passed twice positionally (presumably the keyword index and the
# vector index -- verify against OpenSearchClient.hybrid_search).
# NOTE(review): kw_size=25 / vec_size=0 presumably requests 25 keyword hits and no vector
# hits -- TODO confirm.
response = osclient.hybrid_search(query, index_name, index_name, kw_size=25, vec_size=0)

In [25]:
#rerank the hybrid search response and display the 10 highest scoring hits
ranker.rerank(response, query, top_k=10)

In [19]:
# Score the hits without asking for the scores back: per the ReRanker definition in this
# notebook, return_scores defaults to False, so nothing is returned and the cell shows no output.
ranker._cross_encoder_score(response, query)

In [33]:
# Sanity check: call the model held by `ranker` directly on two hand-written
# [query, passage] pairs; one score is returned per pair.
ranker.model.predict([['dead man walking', 'prisoner on death row with no chance to live'], ['what is a dead man walking', 'sean penn with susan sarandon']])

array([ -9.827699, -11.387819], dtype=float32)