In [46]:
import os
import pickle
from typing import List, Type, Union
from dotenv import load_dotenv
from pydantic import BaseModel, Field, model_validator, PrivateAttr
from langchain_core.documents import Document
from crewai.tools.base_tool import BaseTool
from langchain_community.retrievers import BM25Retriever

In [47]:
def add_ids_to_documents(docs: List[Document]) -> List[Document]:
    
    fixed_docs = []
    for i, doc in enumerate(docs):
        # Copy metadata to avoid mutating originals
        new_metadata = dict(doc.metadata)
        new_metadata["id"] = f"doc_{i}"
        fixed_docs.append(Document(page_content=doc.page_content, metadata=new_metadata))
    return fixed_docs

In [49]:
class ChunkRetrievalInput(BaseModel):
    user_query: str = Field(..., description="The query to search for relevant context.")
    chunks: List[Union[Document]] = Field(..., description="Chunks of documents to search in.")

    # @model_validator(mode="before")
    # def validate_chunks_format(cls, values):
    #     values["chunks"] = convert_to_document(values.get("chunks", []))
    #     return values


In [68]:
class BM25ChunkRetrieverTool(BaseTool):
    name: str = "BM25PatientChunkTool"
    description:str = (
        "Retrieves the most relevant patient document chunks from 'Chunks/patient_chunks.pkl' "
        "using the BM25 algorithm. Takes a search query as input."
    )
    def augment_query(self, query: str, n=5) -> List[str]:
        return [f"{query} variation {i+1}" for i in range(n)]
    
    def _run(self, user_query: str):
        # print(os.getcwd())

        # file_path = os.path.join(os.getcwd(), "Chunks", "patient_chunks.pkl")
        file_path = r"C:\Users\rahul.g\Downloads\rahul\rahul\Chunks\patient_chunks.pkl"
        with open(file_path, "rb") as f:
            chunks = pickle.load(f)

        if chunks and isinstance(chunks[0], dict):
            chunks = [
                Document(page_content=c["page_content"], metadata=c.get("metadata", {}))
                for c in chunks
            ]
        chunk = add_ids_to_documents(chunks)
        retriever = BM25Retriever.from_documents(chunk)
        retriever.k = 5
        queries = self.augment_query(user_query,n=5)

        all_docs= []
        for q in queries:
            all_docs.extend(retriever.get_relevant_documents(q))
        
        seen = set()
        unique_document = [] 
        for doc in all_docs:
            content = getattr(doc, "page_content",str(doc))
            if doc.page_content not in seen:
                unique_document.append(doc)
                seen.add(doc.page_content)
            
        return "\n\n".join([doc.page_content for doc in unique_document])


    

In [69]:
retriever_tool = BM25ChunkRetrieverTool()
result = retriever_tool._run(
    user_query="how to spread awareness about cancer?",
)

print("\nRetrieved context:\n")
print(result)


Retrieved context:

### International Overdose Awareness Day‎

On August 31 of each year, International Overdose Awareness Day (IOAD) is recognized globally as a day to remember and grieve those that we've lost, take action to encourage support and recovery, and help end overdose by spreading awareness about drug overdose prevention. Join us as an IOAD partner by using your voice and platforms to spread messages about ending overdose.

!["End overdose. August 31. International Overdose Awareness Day"](/overdose-prevention/media/images/awareness/ioad-2023/23_IOAD_General_end1_1200x675.png)

International Overdose Awareness Day is a day to remember those lost to overdose, acknowledge the grief of loved ones left behind, and work ...

[Show More](javascript:void(0))

## Promising prevention strategies

Dr. Vemuri is passionate about training new and diverse cadres of scientists, and he enjoys presenting to and recruiting junior investigators at national and international meetings.

## In

  all_docs.extend(retriever.get_relevant_documents(q))
