In [1]:
import logging
import asyncio
from typing import Any, Dict, List, Optional
from functools import partial
from langchain.schema import BaseRetriever, Document
from langchain.callbacks.manager import CallbackManagerForRetrieverRun, AsyncCallbackManagerForRetrieverRun

# Configure the root logger: timestamped records at DEBUG level and above.
# NOTE(review): the date part renders without separators (e.g. "01312024");
# confirm that is intended rather than "%m/%d/%Y".
logging.basicConfig(
    level=logging.DEBUG,
    datefmt="%m%d%Y %I:%M:%S %p",
    format="%(asctime)s %(levelname)s:%(message)s",
)

In [None]:
class IRRetriever(BaseRetriever):
    """Retriever that fetches documents through a CustomIR model.

    A project has to be loaded with `load_project` before any query is run:
    that call stores the search arguments returned by the IR model in
    `search_args`, and every subsequent search reuses them.
    """

    ir_model: Any
    """IR vectorizer"""
    topk: int = 5
    """Number of documents to return"""
    search_args: Optional[Dict] = None
    """Search arguments of IR"""
    retrieved_docs: Optional[List[Document]] = None
    """List of documents retrieved"""

    class Config:
        """Configuration for this pydantic object"""

        arbitrary_types_allowed = True

    def check_is_project_exist(self, project_name, where):
        """Tell whether a project with the given name is listed under `where`.

        Args:
            project_name: Name to look for.
            where: Key used to select a project list from
                `ir_model.list_projects()`.
        Returns:
            True if any listed project has a matching "name", else False.
        """
        projects = self.ir_model.list_projects()[where]
        return any(project_name == project.get("name") for project in projects)

    def create_new_project(self, project_name, where, index_args=None):
        """Create a new project on the IR model.

        Args:
            project_name: Name of the project to create.
            where: Location passed through to `ir_model.new_project`.
            index_args: Index arguments; when None, the defaults from
                `ir_model.get_default_index_args()` are used.
        """
        if index_args is None:
            index_args = self.ir_model.get_default_index_args()

        self.ir_model.new_project(
            project_name=project_name, where=where, index_args=index_args
        )
        # The original message left its `{}` placeholder unfilled.
        logging.info("New project named `%s` created successfully.", project_name)

    def load_project(self, project_name, where):
        """Load an existing project and remember its search arguments.

        Args:
            project_name: Name of the project to load.
            where: Location in which the project is looked up.
        Raises:
            ValueError: If no project with that name is listed under `where`.
        """
        if not self.check_is_project_exist(project_name, where):
            raise ValueError(
                f"{project_name} not found in {where}. Please create the project first."
            )

        # load project -> get self.search_args
        self.search_args = self.ir_model.load_project(
            project_name=project_name, where=where
        )
        logging.info("Load project succesfully.")

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant to a query.

        Each hit returned by the IR model must carry "title" and "context";
        those two fields form the page content and every remaining field of
        the hit is stored as the document's metadata.

        Args:
            query: String to find relevant documents for
            run_manager: The callbacks handler to use (not used here)
        Returns:
            List of relevant documents; also stored in `retrieved_docs`
        Raises:
            RuntimeError: If no project has been loaded yet.
        """
        if self.search_args is None:
            raise RuntimeError(
                "No project loaded. Please use `load_project` method to load project in advance."
            )

        # NOTE(review): "tok_k" looks like a typo for "top_k"; the key is kept
        # unchanged because the IR model's search API is not visible here.
        self.search_args["tok_k"] = self.topk
        hits = self.ir_model.search(query, search_args=self.search_args)

        # Build the result locally so `retrieved_docs` is never left
        # half-filled if a hit is malformed.
        documents = []
        for rank, hit in enumerate(hits, start=1):
            title = hit.pop("title")
            context = hit.pop("context")
            documents.append(
                Document(
                    page_content=f"{rank}. {title}\n{context}",
                    # The Document field is `metadata`; the original
                    # `meta_data` keyword never populated it.
                    metadata=hit,
                    type="Document",
                )
            )

        self.retrieved_docs = documents
        return documents

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Asynchronously get documents relevant to a query.

        Runs the synchronous `_get_relevant_documents` in the event loop's
        default executor.

        Args:
            query: String to find relevant documents for
            run_manager: The callbacks handler to use
        Returns:
            List of relevant documents
        """
        loop = asyncio.get_running_loop()
        search = partial(self._get_relevant_documents, run_manager=run_manager)
        return await loop.run_in_executor(None, search, query)