## Reranking with a cross-encoder

In [2]:
# %pip installs into the running kernel's environment; the version is pinned
# to the release this notebook was last executed with.
%pip install flashrank==0.2.10

Collecting flashrank
  Downloading FlashRank-0.2.10-py3-none-any.whl.metadata (14 kB)
Downloading FlashRank-0.2.10-py3-none-any.whl (14 kB)
Installing collected packages: flashrank
Successfully installed flashrank-0.2.10


In [3]:
import pprint
from typing import List, Dict, Any

from flashrank import Ranker, RerankRequest

In [10]:
class RerankProcessor:
    """Rerank retrieved passages against a query using a FlashRank ranker.

    The ranker is created once in ``__init__`` and reused by every
    ``refine_context`` call.
    """

    def __init__(
            self,
            model_name: str = "ms-marco-TinyBERT-L-2-v2",
            cache_dir: str = "/tmp/"
    ):
        """
        Initializes the FlashRank engine.
        Model options: 'ms-marco-MiniLM-L-12-v2' (More accurate)
        or 'ms-marco-TinyBERT-L-2-v2' (Faster/Smaller).

        Args:
            model_name: Name of the FlashRank model to load.
            cache_dir: Directory passed to ``Ranker`` as its cache location.
                Defaults to "/tmp/"; override it on systems where that path
                is not available.
        """
        self.ranker = Ranker(model_name=model_name, cache_dir=cache_dir)

    def refine_context(
            self,
            query: str,
            documents: List[Dict[str, Any]],
            threshold: float = 0.1,
            top_k: int = 3
    ) -> List[Dict[str, Any]]:
        """Score ``documents`` against ``query`` and keep the best matches.

        Args:
            query: The search query the passages are scored against.
            documents: Candidate passages as dicts (this notebook uses the
                keys "id", "text" and "meta"). Neither the list nor its
                dicts are modified.
            threshold: Minimum score (inclusive) a passage must reach to be
                kept.
            top_k: Maximum number of passages to return. Zero or a negative
                value yields an empty list.

        Returns:
            Copies of the surviving passages, each carrying a "score" entry,
            ordered from highest to lowest score and truncated to ``top_k``.
        """
        # Nothing to rank, or nothing requested: skip the model call. The
        # top_k guard also prevents a negative value from being interpreted
        # as "drop the last N items" by the final slice.
        if not documents or top_k <= 0:
            return []

        # Hand the ranker shallow copies: FlashRank may attach the score to
        # the passage dicts it receives and reorder the list in place, which
        # would otherwise silently alter the caller's data between cell runs.
        passages = [dict(doc) for doc in documents]

        rerank_request = RerankRequest(query=query, passages=passages)
        rerank_response = self.ranker.rerank(rerank_request)

        refined_results = [
            doc for doc in rerank_response
            if doc['score'] >= threshold
        ]

        # Sort explicitly so the ordering does not depend on the ranker
        # returning its results pre-sorted.
        refined_results.sort(key=lambda doc: doc['score'], reverse=True)

        return refined_results[:top_k]
       

In [11]:
# Sample retrieval results to rerank: each entry has an identifier, the
# passage text that gets scored, and a free-form metadata tag.
raw_retrieved_data = [
    {
        "id": "DOC_001",
        "text": "To reset your password, click the 'Forgot Password' link on the login page.",
        "meta": "auth_docs",
    },
    {
        "id": "DOC_002",
        "text": "The company kitchen provides free coffee and snacks for all employees.",
        "meta": "office_policy",
    },
    {
        "id": "DOC_003",
        "text": "Security protocols require a password reset every 90 days for administrative accounts.",
        "meta": "security_manual",
    },
    {
        "id": "DOC_004",
        "text": "Passwords must be at least 12 characters long and contain one special symbol.",
        "meta": "security_manual",
    },
    {
        "id": "DOC_005",
        "text": "The annual holiday party is scheduled for December 15th at the grand ballroom.",
        "meta": "events",
    },
]

In [12]:
# Build the processor with its default model.
rerank = RerankProcessor()

# Keep at most two passages whose score is at least 0.2.
result = rerank.refine_context(
    "How do I reset my password?",
    raw_retrieved_data,
    threshold=0.2,
    top_k=2,
)

In [13]:
# Raw dump of the reranked passages; each dict now includes a 'score' entry.
print(result)

[{'id': 'DOC_001', 'text': "To reset your password, click the 'Forgot Password' link on the login page.", 'meta': 'auth_docs', 'score': np.float32(0.99954754)}, {'id': 'DOC_003', 'text': 'Security protocols require a password reset every 90 days for administrative accounts.', 'meta': 'security_manual', 'score': np.float32(0.98111445)}]


In [15]:
# Pretty-print the same reranked passages with one field per line.
pprint.pprint(result)

[{'id': 'DOC_001',
  'meta': 'auth_docs',
  'score': np.float32(0.99954754),
  'text': "To reset your password, click the 'Forgot Password' link on the "
          'login page.'},
 {'id': 'DOC_003',
  'meta': 'security_manual',
  'score': np.float32(0.98111445),
  'text': 'Security protocols require a password reset every 90 days for '
          'administrative accounts.'}]
