## MediPal -- Multi-Vector Retriever with rerank

### In this section, I built a multi-vector retriever with a re-ranking mechanism
![](../assets/screenshots/rerank_retriever.PNG "")

In the previous section, I used a medical-domain LLM to generate questions from multiple perspectives based on the given content.

As planned, I implemented a Multi-Vector architecture as the foundation of the RAG application.

Here’s what I did:

* Embedded the generated questions into the vector store and stored the corresponding documents in the doc store.

* Used the `doc_id` to link each entry in the vector store to its document in the doc store.

* Applied a cross-encoder to re-rank the retrieved documents and improve retrieval precision.
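
To make the link between the two stores concrete, here is a minimal, dependency-free sketch of the idea. It is a toy, not the code used in this project: the bag-of-words `embed` function, the sample documents, and the `retrieve_parents` helper are all hypothetical stand-ins for the real embedding model, vector store, and doc store.

```python
import math
from collections import Counter


def embed(text: str) -> Counter:
    """Toy bag-of-words 'embedding'; a stand-in for a real embedding model."""
    return Counter(text.lower().replace("?", "").split())


def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse word-count vectors."""
    dot = sum(a[t] * b[t] for t in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


# Doc store: doc_id -> original document (hypothetical sample data).
docstore = {
    "doc-1": "Phenylephrine is a nasal decongestant.",
    "doc-2": "Ibuprofen is used to relieve pain and fever.",
}

# Vector store: one entry per generated question, tagged with its parent doc_id.
vectorstore = [
    {"doc_id": "doc-1", "text": "what is phenylephrine used for"},
    {"doc_id": "doc-1", "text": "which drug relieves a stuffy nose"},
    {"doc_id": "doc-2", "text": "what is ibuprofen used for"},
    {"doc_id": "doc-2", "text": "which drug relieves pain and fever"},
]


def retrieve_parents(query: str, k: int = 3) -> list:
    """Search the questions, then follow doc_id back to the original documents."""
    q = embed(query)
    hits = sorted(vectorstore, key=lambda e: cosine(q, embed(e["text"])), reverse=True)[:k]
    doc_ids = []
    for hit in hits:
        # Several questions can point at the same document; keep each doc_id once.
        if hit["doc_id"] not in doc_ids:
            doc_ids.append(hit["doc_id"])
    return [docstore[i] for i in doc_ids]


results = retrieve_parents("what is phenylephrine", k=2)
```

The search runs over the short questions, but what comes back is the full original document, and several matching questions from the same document collapse into a single result.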

In [1]:
# Multi-Vector implementation
from src.mytools import timed, login_huggingface
import os, json, copy
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import InMemoryStore
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings
from sentence_transformers import CrossEncoder


##### I chose sentence-transformers/embeddinggemma-300m-medical, a sentence-transformers model fine-tuned from google/embeddinggemma-300m on the miriad/miriad-4.4M dataset. It maps sentences and documents to a 768-dimensional dense vector space and is designed for medical information retrieval, specifically for searching passages (up to 1k tokens) of scientific medical papers with detailed medical questions.

##### Citation:

@inproceedings{reimers-2019-sentence-bert,
    title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
    author = "Reimers, Nils and Gurevych, Iryna",
    booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
    month = "11",
    year = "2019",
    publisher = "Association for Computational Linguistics",
    url = "https://arxiv.org/abs/1908.10084",
}

@misc{gao2021scaling,
    title={Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup},
    author={Luyu Gao and Yunyi Zhang and Jiawei Han and Jamie Callan},
    year={2021},
    eprint={2101.06983},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}
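
The retriever class in this section loads the embedding model with `normalize_embeddings` set to `True`. The sketch below illustrates why that is convenient: once vectors have unit length, a plain dot product equals cosine similarity. The three-dimensional vectors are made-up toy values standing in for the model's 768-dimensional output, and the helper names are hypothetical.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit length (returned unchanged if it is all zeros)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm else list(vec)


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def cosine(a, b):
    """Reference cosine similarity computed on the raw, unnormalized vectors."""
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))


# Toy vectors (assumption): passage_a points in the same direction as the query.
query_vec = [0.2, 0.9, 0.1]
passages = {
    "passage_a": [0.4, 1.8, 0.2],
    "passage_b": [0.9, 0.1, 0.3],
}

q = l2_normalize(query_vec)
scores = {name: dot(q, l2_normalize(vec)) for name, vec in passages.items()}
best = max(scores, key=scores.get)
```

Here `passage_a` is twice as long as the query vector but points the same way, so after normalization it scores as a perfect match; vector length no longer affects the ranking.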

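### I use a cross-encoder (ncbi/MedCPT-Cross-Encoder) to re-rank the retrieved documents and return the top_k results.
##### This BERT-based cross-encoder was trained on large-scale PubMed search logs (see the citation below). The 30522 figure is the size of its token vocabulary, not the amount of training data.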

##### Citation:

@article{jin2023medcpt,
  title={MedCPT: Contrastive Pre-trained Transformers with large-scale PubMed search logs for zero-shot biomedical information retrieval},
  author={Jin, Qiao and Kim, Won and Chen, Qingyu and Comeau, Donald C and Yeganova, Lana and Wilbur, W John and Lu, Zhiyong},
  journal={Bioinformatics},
  volume={39},
  number={11},
  pages={btad651},
  year={2023},
  publisher={Oxford University Press}
}
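
The re-ranking step itself is small: score every (query, passage) pair, sort by score, and keep the best few. Below is a minimal sketch of that pattern with the scorer passed in as a function. The `rerank` helper and the word-overlap `toy_overlap_scorer` are hypothetical; in the real pipeline the scorer would be the cross-encoder's `predict` method, which likewise takes a list of pairs and returns one score per pair.

```python
def rerank(query, passages, score_fn, top_k=5):
    """Score each (query, passage) pair and return the top_k as (passage, score)."""
    if not passages:
        return []
    pairs = [(query, p) for p in passages]
    scores = score_fn(pairs)
    ranked = sorted(zip(passages, scores), key=lambda item: item[1], reverse=True)
    return [(p, float(s)) for p, s in ranked[:top_k]]


def toy_overlap_scorer(pairs):
    """Hypothetical stand-in for a cross-encoder: fraction of query words found in the passage."""
    scores = []
    for query, passage in pairs:
        q_words = set(query.lower().split())
        p_words = set(passage.lower().split())
        scores.append(len(q_words & p_words) / len(q_words) if q_words else 0.0)
    return scores


candidates = [
    "ibuprofen relieves pain",
    "phenylephrine relieves nasal congestion",
    "phenylephrine side effects include dizziness",
]
top = rerank("phenylephrine nasal congestion", candidates, toy_overlap_scorer, top_k=2)
```

Unlike the embedding model, which encodes the query and the passages separately, a cross-encoder reads each pair together, which is slower but usually more precise. That is why it is applied only to the small candidate set returned by the vector search.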


In [5]:
class Rerank_Retriever():
    """
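        Multi-vector retriever with cross-encoder reranking.

            Attributes:
                workspace_base_path: The current working directory.
                dataset_path: Path to the JSON file of generated questions and their original documents.
                chunked_dataset_path: Path to the JSON file of chunked documents.
                vector_persist_directory: Directory where the Chroma vector store is persisted.
                embedding_model_id: Name of the embedding model.
                cross_encoder_model_id: Name of the cross-encoder model used for reranking.
                embedding_model: The loaded embedding model.
                cross_encoder: The loaded cross-encoder model.
                vectorstore: The Chroma vector store that indexes questions and chunks.
                retriever: The MultiVectorRetriever that runs the similarity search for a query.

            Methods:
                load_embedding_model: Load the embedding model.
                load_crossencoder: Load the cross-encoder model.
                load_questions_data: Load the generated-questions JSON file.
                load_chunked_data: Load the chunked-documents JSON file.
                build_medicine_retriever: Build the multi-vector store: embed generated questions and chunks into the vectorstore and store the original documents in the docstore.
                load_existing_retriever: Reopen the persisted vectorstore and rebuild the in-memory docstore.
                setup_retriever: Log in to Hugging Face, load both models, then load or build the retriever.
                retrieve: Run the retriever, rerank the results, and return the top_k documents.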
    """
    def __init__(self) -> None:

        self.workspace_base_path = os.getcwd()
        self.dataset_path = os.path.join(self.workspace_base_path, "src", "datasets", "medicine_data_questions.json")  
        self.chunked_dataset_path = os.path.join(self.workspace_base_path, "src", "datasets", "chunked_medicine_data.json")  
        self.vector_persist_directory = os.path.join(self.workspace_base_path, "src", "datasets", "vectordb")
        self.embedding_model_id = "sentence-transformers/embeddinggemma-300m-medical"
        self.cross_encoder_model_id = "ncbi/MedCPT-Cross-Encoder" 
        self.vectorstore = None
        self.embedding_model = None
        self.retriever = None
        self.cross_encoder = None

    @timed
    def load_embedding_model(self):        
        self.embedding_model = HuggingFaceEmbeddings(
            model_name=self.embedding_model_id,
            model_kwargs = {'device': 'cpu'},            
            # Normalizing helps cosine similarity behave better across models
            encode_kwargs={"normalize_embeddings": True},
        )      
    
    @timed
    def load_crossencoder(self):
        self.cross_encoder = CrossEncoder(self.cross_encoder_model_id)

    def load_questions_data(self):    
        with open(self.dataset_path, mode = "r", encoding="utf-8") as f:
            return json.load(f)
        
    def load_chunked_data(self):    
        with open(self.chunked_dataset_path, mode = "r", encoding="utf-8") as f:
            return json.load(f)      
        
    def build_medicine_retriever(self):        
        questions_data = self.load_questions_data()  
        chunked_data = self.load_chunked_data()          
        docstore = InMemoryStore()
        id_key = "doc_id"

        # The vectorstore to use to index the questions
        self.vectorstore = Chroma(
            collection_name = "medicine_data", 
            embedding_function = self.embedding_model,
            persist_directory=self.vector_persist_directory
        )
        # The Multi-Vector retriever
        self.retriever = MultiVectorRetriever(
            vectorstore=self.vectorstore,
            docstore=docstore,
            id_key=id_key,
        )

        doc_ids = list()
        questions = list()
        docs = list()
        for d in questions_data:
            doc_id = d["doc_id"]
            doc_ids.append(doc_id)
            docs.append(Document(metadata={"doc_id": doc_id}, page_content=d["original_doc"]))
            for q in d["questions"]:
                questions.append(Document(metadata={"doc_id": doc_id}, page_content=q))

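        # Also index the raw chunks, so a query can match a generated question or the chunk text itself
        for d in chunked_data:
            doc_id = d["doc_id"]
            for chunk in d["docs"]:
                questions.append(Document(metadata={"doc_id": doc_id}, page_content=chunk))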

        self.retriever.vectorstore.add_documents(questions)
        self.retriever.docstore.mset(list(zip(doc_ids,docs)))  
        
    def load_existing_retriever(self):
        questions_data = self.load_questions_data()
        docstore = InMemoryStore()
        id_key = "doc_id"
        # The vectorstore to use to index the questions
        self.vectorstore = Chroma(
            collection_name = "medicine_data", 
            embedding_function = self.embedding_model,
            persist_directory=self.vector_persist_directory
        )
        # The Multi-Vector retriever
        self.retriever = MultiVectorRetriever(
            vectorstore=self.vectorstore,
            docstore=docstore,
            id_key=id_key,
        )

        doc_ids = list()        
        docs = list()
        for d in questions_data:
            doc_id = d["doc_id"]
            doc_ids.append(doc_id)
            docs.append(Document(metadata={"doc_id": doc_id}, page_content=d["original_doc"]))
            
        self.retriever.docstore.mset(list(zip(doc_ids,docs)))

    @timed       
    def setup_retriever(self):
        login_huggingface()      
        self.load_embedding_model()
        self.load_crossencoder()

        if os.path.isdir(self.vector_persist_directory) and os.listdir(self.vector_persist_directory):
            self.load_existing_retriever()
        else:
            self.build_medicine_retriever()

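    @timed
    def retrieve(self, query: str, top_k: int = 5, fetch_k: int = 10):
        # MultiVectorRetriever takes k from search_kwargs; a kwargs={"k": 10} argument
        # passed to invoke() is not applied to the vector search.
        self.retriever.search_kwargs["k"] = fetch_k
        retrieved_docs = self.retriever.invoke(query)
        if not retrieved_docs:
            return []
        # Copy so that the rerank scores are not written into the stored documents
        retrieved_docs = copy.deepcopy(retrieved_docs)
        # Rerank: score each (query, document) pair with the cross-encoder
        pairs = [[query, d.page_content] for d in retrieved_docs]
        scores = self.cross_encoder.predict(pairs, batch_size=32)
        for r_d, score in zip(retrieved_docs, scores):
            r_d.metadata["rerank_score"] = float(score)
        retrieved_docs.sort(key=lambda d: d.metadata["rerank_score"], reverse=True)
        return retrieved_docs[:top_k]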

In [6]:
rag = Rerank_Retriever()

In [7]:
rag.setup_retriever()

setup_retriever starts runing!
Login HuggingFace!
load_embedding_model starts runing!
load_embedding_model took 3.9873s
load_crossencoder starts runing!
load_crossencoder took 1.1828s
setup_retriever took 23.6951s
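
The start and duration lines in this log come from the `timed` decorator imported from `src.mytools`, whose source is not shown in this section. As a rough guide, here is a minimal sketch of how a decorator like that could be written. It is an assumption about its shape, not the project's actual implementation, and the message wording is my own.

```python
import functools
import time


def timed(func):
    """Print when the wrapped function starts and how long it took."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        print(f"{func.__name__} starts running!")
        start = time.perf_counter()
        try:
            return func(*args, **kwargs)
        finally:
            # Runs even if func raises, so the duration is always reported.
            print(f"{func.__name__} took {time.perf_counter() - start:.4f}s")
    return wrapper


@timed
def add(a, b):
    return a + b


result = add(2, 3)
```

Because the decorator wraps each method separately, nested calls such as `setup_retriever` calling `load_embedding_model` report their own timings inside the outer one, which matches the nesting seen in the log.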


In [8]:
rag.retrieve("what is Phenylephrine?",top_k=4)

retrieve starts runing!


Batches: 100%|██████████| 1/1 [00:00<00:00,  4.30it/s]

retrieve took 0.7255s





[Document(metadata={'doc_id': 'b1a6af68-fb24-4754-90f4-a49b9e9b039d', 'rerank_score': 0.9999997615814209}, page_content="phenylephrine may cause side effects. some side effects can be serious. if you experience any of these symptoms, stop using phenylephrine and call your doctor:nervousnessdizzinesssleeplessnessphenylephrine may cause other side effects. call your doctor if you have any unusual problems while taking Phenylephrine.if you experience a serious side effect, you or your doctor may send a report to the food and drug administration's (fda) medwatch adverse event reporting program online (https://www.fda.gov/safety/medwatch) or by phone ([phone])."),
 Document(metadata={'doc_id': 'a9505f8e-bd92-481f-acae-23d98d08be73', 'rerank_score': 0.9999997615814209}, page_content="phenylephrine comes as a tablet, a liquid, or a dissolving strip to take by mouth. it is usually taken every 4 hours as needed. follow the directions on your prescription label or the package label carefully, an

In [9]:
rag.retrieve("How can I avoid it?",top_k=4)

retrieve starts runing!


Batches: 100%|██████████| 1/1 [00:00<00:00, 16.39it/s]

retrieve took 0.7503s





[Document(metadata={'doc_id': '60559657-95b8-4f79-9a28-a6548c1f986d', 'rerank_score': 0.9443322420120239}, page_content="pyrethrin and piperonyl butoxide comes as a shampoo to apply to the skin and hair. it is usually applied to the skin and hair in two or three treatments. the second treatment must be applied 7-10 days after the first one. sometimes a third treatment may be necessary, as recommended by your doctor. follow the directions on your prescription label or the package label carefully, and ask your doctor or pharmacist to explain any part you do not understand. use pyrethrin and piperonyl butoxide shampoo exactly as directed. do not use more or less of it or use it more often than directed on the package label or prescribed by your doctor.the package label gives you an estimate of how much shampoo you will need based on your hair length. be sure to use enough shampoo to cover all of your scalp area and hair.pyrethrin and piperonyl butoxide shampoo should only be used on the s

In [10]:
rag.retrieve("Can phenylephrine be used to relieve nasal discomfort?",top_k=4)

retrieve starts runing!


Batches: 100%|██████████| 1/1 [00:00<00:00, 18.47it/s]

retrieve took 1.0253s





[Document(metadata={'doc_id': '9f78f45c-bdae-442b-aa03-755076702cbd', 'rerank_score': 0.9999997615814209}, page_content='phenylephrine is used to relieve nasal discomfort caused by colds, allergies, and hay fever. it is also used to relieve sinus congestion and pressure. phenylephrine will relieve symptoms but will not treat the cause of the symptoms or speed recovery. phenylephrine is in a class of medications called nasal decongestants. it works by reducing swelling of the blood vessels in the nasal passages. about Phenylephrine'),
 Document(metadata={'doc_id': 'a9505f8e-bd92-481f-acae-23d98d08be73', 'rerank_score': 0.994716465473175}, page_content="phenylephrine comes as a tablet, a liquid, or a dissolving strip to take by mouth. it is usually taken every 4 hours as needed. follow the directions on your prescription label or the package label carefully, and ask your doctor or pharmacist to explain any part you do not understand. take phenylephrine exactly as directed. do not take mo