# OPTIONAL: Implement ReRanker methods
***

For those looking for a deeper understanding of how Cross-Encoders work, try implementing the two methods that power the Cross-Encoder reranking system. Once completed, compare your work with the implementation in the `src/reranker.py` file.

#### Instructions
1. Fill in the `None` sections in the `ReRanker` class. Read the comments for additional help.
2. Test your implementation on real search results and note how the results are reordered by their cross-encoder scores.
3. Compare your implementation with the code found in the `src/reranker.py` file.

In [None]:
import logging
from typing import List, Optional, Union

import numpy as np
from sentence_transformers import CrossEncoder
from torch.nn import Sigmoid


class ReRanker(CrossEncoder):
    '''
    Cross-Encoder models achieve higher performance than Bi-Encoders;
    however, they do not scale well to large datasets. The lack of scalability
    is due to the underlying cross-attention mechanism, which is computationally
    expensive. Thus a Bi-Encoder is best used for 1st-stage document retrieval and
    a Cross-Encoder is used to re-rank the retrieved documents.

    https://www.sbert.net/examples/applications/cross-encoder/README.html
    '''

    def __init__(self,
                 model_name: str='cross-encoder/ms-marco-MiniLM-L-6-v2',
                 **kwargs
                 ):
        super().__init__(model_name=model_name,
                         **kwargs)
        self.model_name = model_name
        # Key under which rerank() stores, sorts and filters the cross-encoder score.
        self.score_field = 'cross_score'
        # Applied to the model output when apply_sigmoid=True; maps scores into (0, 1).
        self.activation_fct = Sigmoid()

    def _cross_encoder_score(self,
                             results: List[dict],
                             query: str,
                             cross_score_key: str='cross_score',
                             apply_sigmoid: bool=True,
                             return_scores: bool=False
                             ) -> Union[np.ndarray, None]:
        '''
        Given a list of hits from a Retriever:
            1. Scores hits by passing the query and results through the CrossEncoder model.
            2. Adds a `cross_score_key` key to each hit dictionary (the dicts are updated in place).
            3. If desired, returns a np.ndarray of CrossEncoder scores.

        Args
        ----
        results: List[dict]
            List of search results from OpenSearch client. Each hit is expected
            to carry its text under the 'content' key.
        query: str
            User query.
        cross_score_key: str='cross_score'
            Name of key/field that the newly calculated cross encoder score will be associated with.
        apply_sigmoid: bool=True
            If True, model outputs are passed through self.activation_fct (Sigmoid),
            so scores fall in (0, 1). If False, activation_fct=None is passed to
            predict, which yields logits (see the example output below).
        return_scores: bool=False
            If True, returns a np.ndarray of cross encoder scores.

        Returns
        -------
        Either returns a np.ndarray of cross encoder scores if "return_scores" is True, otherwise
        nothing is returned. The primary purpose of this function is to update the "results" dicts.
        '''
        activation_fct = self.activation_fct if apply_sigmoid else None

        ##########################
        ##      START CODE      ##
        ##########################
        # build query/content list
        # Create a list of lists that contains the query and the 'content' field of each result from "results".
        # Important: the cross_input variable must be a list of lists (one [query, content] pair per hit).

        cross_input = None

        # get scores
        # Call self.predict (the predict method inherited from CrossEncoder) on cross_input,
        # passing activation_fct=activation_fct.
        # Output at this step will be a numpy array of cross-encoder scores (logits if activation_fct is None).
        ######################################################################
        # Example:
        # array([ 1.3296969 ,  0.8297793 ,  1.2054391 ,  2.9448447 ,  2.7284985 ,
        #         4.231843  , -1.6208533 ,  2.4096487 , -1.2081863 ,  2.9743905 ,
        #         3.2194595 , -0.27501446,  1.5256095 ,  2.8193645 ,  1.5568736 ,
        #         2.5138354 ,  1.9419916 ,  2.6341028 , -1.6115644 , -0.49818742,
        #         3.695484  ,  2.93317   ,  3.1728778 , -0.5114989 , -4.076729  ], dtype=float32)

        cross_scores = None

        # Enumerate through the results and update each result dict, using the cross_score_key arg
        # as the key and the matching score (same index) as the value.
        # Example (with the default cross_score_key):
        #      {'cross_score': 5.12345}
        for i, result in enumerate(results):
            None
        ##########################
        ##       END CODE       ##
        ##########################

        if return_scores:
            return cross_scores

    def rerank(self,
               results: List[dict],
               query: str,
               top_k: int=10,
               threshold: Optional[float]=None,
               apply_sigmoid: bool=True
               ) -> List[dict]:
        '''
        Given a list of search results from OpenSearch client, results are scored with a Cross Encoder
        and returned in sorted order by the cross_score field. Threshold allows the user to filter out
        results that do not meet the cross_score threshold value.

        Args
        ----
        results: List[dict]
            List of search results from OpenSearch client. The dicts are updated
            in place with a self.score_field key.
        query: str
            User query.
        top_k: int=10
            Number of reranked results to return.
        threshold: Optional[float]=None
            If None, top_k results will be returned.
            If a float value is present (0 included), only results with a cross_score that meet
            or exceed the threshold will be returned; this list is NOT truncated to top_k.
            If no result meets the threshold, a warning is logged and the top_k results are returned.
            This arg is present to prevent very low scoring documents from being returned.
        apply_sigmoid: bool=True
            Forwarded to _cross_encoder_score. The threshold is compared against scores on
            this scale (sigmoid-squashed if True, logits if False).

        Returns
        -------
        List of reranked search results.
        '''

        ##########################
        ##      START CODE      ##
        ##########################

        # Call the internal _cross_encoder_score method (it's ok that nothing is returned here);
        # each dict in "results" is updated in place. Pass cross_score_key=self.score_field and
        # forward apply_sigmoid so the threshold is compared on the same scale.
        None

        # sort results by the new self.score_field value, highest score first
        sorted_hits = None

        # if the user wants to set a threshold we need to account for that
        # (0 is a valid threshold, so compare against None rather than relying on truthiness)
        if threshold is not None:

            # keep only the sorted_hits whose score meets or exceeds (>=) the threshold
            filtered_hits = None

            if not filtered_hits:
                logging.getLogger(__name__).warning(f'No hits above threshold {threshold}. Returning top {top_k} hits.')
                return sorted_hits[:top_k]
            return filtered_hits

        # if no threshold was set return top_k sorted_hits
        return None

        ##########################
        ##       END CODE       ##
        ##########################