# RAG-based Medical Bot

##### No AI-based assistance or code completion was used to solve this problem.

In [1]:
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer, CrossEncoder
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.llms import HuggingFacePipeline
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from utils import read_json



# Step 0: Load and Preprocess Dataset

In [2]:
def remove_exact_duplicate_lines(text):
    """
    Removes exact-duplicate sentences from a passage of text.

    The passage is split on '.', so each "line" is a period-delimited sentence;
    the first occurrence of each sentence is kept, in its original order.

    Args:
        text (str) : Medical domain answers.

    Returns:
        str : De-duplicated text
    """
    lines = [line.strip() for line in text.split('.') if line.strip()]
    unique_lines = []
    seen_lines = set()
    for line in lines:  # lines are already whitespace-stripped
        if line not in seen_lines:
            unique_lines.append(line)
            seen_lines.add(line)

    return ". ".join(unique_lines)
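
Because the function splits on every '.', decimals such as "1.5 mg" are broken apart. The sketch below copies the function and contrasts it with a hypothetical sentence-aware variant, remove_duplicate_sentences, that only splits where '.', '!' or '?' is followed by whitespace. This variant is an assumption for illustration, not part of the original pipeline, and it would still mis-split abbreviations such as "Dr. Smith".

```python
import re


def remove_exact_duplicate_lines(text):
    # Copy of the notebook function: splits on every '.', including decimals.
    lines = [line.strip() for line in text.split('.') if line.strip()]
    unique_lines, seen_lines = [], set()
    for line in lines:
        if line not in seen_lines:
            unique_lines.append(line)
            seen_lines.add(line)
    return ". ".join(unique_lines)


def remove_duplicate_sentences(text):
    # Hypothetical variant: split only where sentence-ending punctuation
    # is followed by whitespace, so "1.5" stays intact.
    sentences = [s.strip() for s in re.split(r'(?<=[.!?])\s+', text.strip()) if s.strip()]
    unique, seen = [], set()
    for sentence in sentences:
        if sentence not in seen:
            unique.append(sentence)
            seen.add(sentence)
    return " ".join(unique)


sample = "Take 1.5 mg daily. Take 1.5 mg daily. See a doctor."
print(remove_exact_duplicate_lines(sample))  # Take 1. 5 mg daily. See a doctor
print(remove_duplicate_sentences(sample))    # Take 1.5 mg daily. See a doctor.
```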

In [28]:
# Constant Global Variables

CHUNK_SIZE = 500
CHUNK_OVERLAP = 150
CLEAN_DATASET_PATH = './data/clean_questions_to_answers_dataset_v1.json'
EMBEDDING_MODEL_ID = "abhinand/MedEmbed-small-v0.1"
RERANKER_ID = "cross-encoder/ms-marco-MiniLM-L-6-v2"


In [29]:
# load dataset
dataset = read_json(CLEAN_DATASET_PATH)

context_passages = []
dataset_questions = []

for question, answers in dataset.items():
    combined_passage = " ".join(answers)    
    context_passages.append(remove_exact_duplicate_lines(combined_passage)) # de-duplicate the data
    dataset_questions.append(question)

print(f'Total number of Questions: {len(dataset_questions)}')
print(f'Total number of Answers: {len(context_passages)}')

# Processed dataset preview (first question and its de-duplicated answer passage)
for q, passage in list(zip(dataset_questions, context_passages))[:1]:
    print(f"\nQ: {q}\nAnswer: {passage}")

Total number of Questions: 14338
Total number of Answers: 14338

Q: what is (are) glaucoma?
Answer: Glaucoma is a group of diseases that can damage the eye's optic nerve and result in vision loss and blindness. The most common form of the disease is open-angle glaucoma. With early treatment, you can often protect your eyes against serious vision loss. (Watch the video to learn more about glaucoma. To enlarge the video, click the brackets in the lower right-hand corner. To reduce the video, press the Escape (Esc) button on your keyboard. ) See this graphic for a quick overview of glaucoma, including how many people it affects, whos at risk, what to do if you have it, and how to learn more. See a glossary of glaucoma terms. The optic nerve is a bundle of more than 1 million nerve fibers. It connects the retina to the brain. Open-angle glaucoma is the most common form of glaucoma. In the normal eye, the clear fluid leaves the anterior chamber at the open angle where the cornea and iris me

# Step 1: Create Data Chunks

In [30]:
# Splits long passages into chunks of up to CHUNK_SIZE = 500 characters with an overlap of up to CHUNK_OVERLAP = 150 characters.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=CHUNK_SIZE,
    chunk_overlap=CHUNK_OVERLAP,
    length_function=len,
)

all_chunks = []
chunk_to_question_map = {}

for i, doc in enumerate(context_passages):
    chunks = text_splitter.split_text(doc)
    question = dataset_questions[i]
    for chunk in chunks:
        chunk_to_question_map[chunk]=question
        all_chunks.append({
            "text": chunk,
            "question": question,  
            "answer_index": i + 1
        })

print(f"Total number of chunks created: {len(all_chunks)}")

Total number of chunks created: 58741
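
To illustrate what chunk_size and chunk_overlap mean, here is a simplified fixed-window character chunker. It is not LangChain's algorithm: RecursiveCharacterTextSplitter first tries to break on separators (paragraphs, newlines, spaces), so its chunks can be shorter than 500 characters and the overlap is approximate. The hypothetical sliding_window_chunks helper only shows that consecutive chunks start chunk_size - chunk_overlap (here 350) characters apart.

```python
def sliding_window_chunks(text, chunk_size=500, chunk_overlap=150):
    """Hypothetical fixed-window chunker (illustration only)."""
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    step = chunk_size - chunk_overlap  # distance between chunk start positions
    chunks, start = [], 0
    while True:
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):  # this window reached the end of the text
            break
        start += step
    return chunks


text = "".join(str(i % 10) for i in range(1000))  # 1,000 characters of dummy text
chunks = sliding_window_chunks(text)
print([len(c) for c in chunks])  # [500, 500, 300]
```

The last 150 characters of each chunk reappear at the start of the next one, which keeps sentences that straddle a boundary retrievable from either chunk.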


In [31]:
question_metadata = []

for i, quest in enumerate(dataset_questions):
    question_metadata.append({
        "text": quest,
        "answer_index": i + 1
    })

print(f"Total number of questions: {len(question_metadata)}")

Total number of questions: 14338


# Step 2: Create Embeddings

In [32]:
# BGE model fine-tuned on medical domain data.
encoder_model = SentenceTransformer(EMBEDDING_MODEL_ID)

chunk_texts = [chunk_info['text'] for chunk_info in all_chunks]

# converts text to embeddings.
answer_embeddings = encoder_model.encode(chunk_texts)
question_embeddings = encoder_model.encode(dataset_questions)

embedding_dimension = answer_embeddings.shape[1]
num_chunks = len(all_chunks)

print(f"Generated {num_chunks} answer embeddings of dimension {embedding_dimension}")
print(f"Generated {len(dataset_questions)} question embeddings of dimension {question_embeddings.shape[1]}")

Generated 58741 answer embeddings of dimension 384
Generated 14338 question embeddings of dimension 384


# Step 3: Create Vector Index

In [33]:
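# Creating vector index for all the answers.
# Embeddings are L2-normalized in place first, so ranking by L2 distance
# is equivalent to ranking by cosine similarity.

answer_index = faiss.IndexFlatL2(embedding_dimension)
faiss.normalize_L2(answer_embeddings)
answer_index.add(answer_embeddings)
metadata_store = all_chunks  # row i of answer_index corresponds to all_chunks[i]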


# Creating vector index for all the questions

question_index = faiss.IndexFlatL2(embedding_dimension)
faiss.normalize_L2(question_embeddings)
question_index.add(question_embeddings)
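
Why IndexFlatL2 works as a cosine-similarity index here: for unit-length vectors the squared L2 distance (which IndexFlatL2 reports) equals 2 - 2 x cosine similarity, so sorting by ascending distance gives the same order as sorting by descending cosine. The numpy check below uses random vectors (float64, no FAISS) as a stand-in for the real embeddings. With normalized vectors, faiss.IndexFlatIP would return the cosine similarities directly instead of distances.

```python
import numpy as np

rng = np.random.default_rng(0)
docs = rng.normal(size=(100, 384))   # stand-in for 384-dim chunk embeddings
query = rng.normal(size=(1, 384))    # stand-in for a query embedding

# Normalize rows to unit length (what faiss.normalize_L2 does in place).
docs /= np.linalg.norm(docs, axis=1, keepdims=True)
query /= np.linalg.norm(query, axis=1, keepdims=True)

cosine = (docs @ query.T).ravel()          # inner product = cosine for unit vectors
sq_l2 = ((docs - query) ** 2).sum(axis=1)  # squared L2 distance

# For unit vectors: ||d - q||^2 = 2 - 2 * cos(d, q)
print(np.allclose(sq_l2, 2 - 2 * cosine))  # True

# Hence ascending L2 order equals descending cosine order.
top_by_l2 = np.argsort(sq_l2)[:5]
top_by_cos = np.argsort(-cosine)[:5]
print(np.array_equal(top_by_l2, top_by_cos))
```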

# Step 4: Initialize Reranker and Smart Retrieval

In [34]:
def smart_hybrid_rerank(
    user_query: str,
    query_encoder,
    answer_index,
    question_index,
    chunk_passages: list,
    chunk_to_question_map: dict,
    dataset_questions: list,
    reranker,
    top_k=50,
    rank_k=10,
    final_k=3,
):
    """
    Re-ranks the top_k answer chunks retrieved for a user query and flags out-of-domain queries.

    Each retrieved chunk gets a weighted score built from three cross-encoder scores:
        - 0.3 x sim(user query, retrieved chunk)
        - 0.6 x sim(user query, question that the retrieved chunk answers)
        - 0.1 x max sim(retrieved chunk's question, top-2 questions from the question index)
    If the mean weighted score is below -0.4, only chunks with a score >= 0 are kept;
    otherwise only chunks scoring at least 0.05 above the mean are kept.
    If no chunk is kept, the query is flagged as out of domain (not related to the dataset).

    Returns:
        tuple: (list of (score, chunk, question) tuples, is_out_of_domain flag).
        At most rank_k chunks are returned; final_k is currently unused.
    """
    # encode user query (normalized, to match the normalized index vectors)
    query_embedding = query_encoder.encode([user_query])
    faiss.normalize_L2(query_embedding)

    # search answer index
    answer_distances, chunk_indices = answer_index.search(query_embedding, top_k)

    # retrieve chunks
    retrieved_chunks = [chunk_passages[i]['text'] for i in chunk_indices[0]]

    # retrieve questions corresponding to the retrieved chunks
    retrieved_questions = [chunk_to_question_map[chunk] for chunk in retrieved_chunks]

    # search question index
    question_distances, question_indices = question_index.search(query_embedding, 2) # only top 2
    similar_questions = [dataset_questions[i] for i in question_indices[0]]

    # score 1: sim(user query -- retrieved chunks)
    match_scores = reranker.predict([(user_query, chunk) for chunk in retrieved_chunks])

    # score 2: sim(user query -- retrieved chunks' questions)
    base_scores = reranker.predict([(user_query, q) for q in retrieved_questions])

    # score 3: sim(retrieved chunks' questions -- similar questions from question index (only top 2))
    sim_scores = [
        max(reranker.predict([(rq, sq) for sq in similar_questions]))
        for rq in retrieved_questions
    ]

    # weighted average (weights sum to 1.0)
    final_scores = [0.3 * ms + 0.6 * qs + 0.1 * ss for ms, qs, ss in zip(match_scores, base_scores, sim_scores)]

    # A mean score below -0.4 suggests an out-of-domain query, so require an absolute
    # score >= 0; otherwise keep only chunks scoring at least 0.05 above the mean.
    avg = np.mean(final_scores)
    if avg < -0.4:
        score_threshold = 0
    else:
        score_threshold = avg + 0.05

    filtered = [
        (score, passage, q)
        for score, passage, q in zip(final_scores, retrieved_chunks, retrieved_questions)
        if score >= score_threshold
    ]

    if not filtered:
        # nothing passed the threshold: return the top FAISS hit and flag the query as out of domain
        fallback = [
            (score, passage, q)
            for score, passage, q in zip(final_scores, retrieved_chunks, retrieved_questions)
        ]
        return fallback[:1], True
    else:
        return sorted(filtered, reverse=True)[:rank_k], False
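
To make the weighting and thresholding concrete, combine_and_filter below is a hypothetical, self-contained copy of that step of smart_hybrid_rerank operating on precomputed scores. The scores are invented for illustration, not real CrossEncoder outputs, and chunk indices stand in for the chunk texts.

```python
import numpy as np


def combine_and_filter(match_scores, base_scores, sim_scores, rank_k=10):
    """Mirrors the weighting/thresholding step of smart_hybrid_rerank on
    precomputed scores; each chunk is represented by its index."""
    final_scores = [0.3 * ms + 0.6 * qs + 0.1 * ss
                    for ms, qs, ss in zip(match_scores, base_scores, sim_scores)]
    avg = np.mean(final_scores)
    # Very negative mean: demand an absolute score >= 0; otherwise demand mean + 0.05.
    score_threshold = 0 if avg < -0.4 else avg + 0.05
    filtered = [(s, i) for i, s in enumerate(final_scores) if s >= score_threshold]
    if not filtered:
        return [(final_scores[0], 0)], True  # out-of-domain fallback: first retrieved chunk
    return sorted(filtered, reverse=True)[:rank_k], False


# In-domain toy example: final scores are 2.5, -1.5, 0.85 (mean ~0.62).
kept, is_ood = combine_and_filter([2.0, -1.0, 0.5], [3.0, -2.0, 1.0], [1.0, 0.0, 1.0])
print([i for _, i in kept], is_ood)  # [0, 2] False

# Out-of-domain toy example: final scores are -6.6 and -7.6, none reach 0.
kept_ood, flag_ood = combine_and_filter([-5.0, -6.0], [-8.0, -9.0], [-3.0, -4.0])
print([i for _, i in kept_ood], flag_ood)  # [0] True
```

Because the in-domain threshold is relative to the mean, roughly the better half of the candidates survives for any in-domain query; the absolute cut-off at 0 only applies once the whole candidate set scores poorly.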


In [35]:
reranker = CrossEncoder(RERANKER_ID)

You are trying to use a model that was created with Sentence Transformers version 4.1.0.dev0, but you're currently using version 4.0.2. This might cause unexpected behavior or errors. In that case, try to update to the latest version.


# Step 5: Load Open-source LLM

In [153]:
model_name = 'TinyLlama/TinyLlama-1.1B-Chat-v1.0' # small LLM chosen to fit GPU and memory constraints.

llm_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype="auto")
llm_tokenizer = AutoTokenizer.from_pretrained(model_name)

# Greedy decoding (do_sample=False); return only the generated text, not the prompt.
llm_pipe = pipeline(
    "text-generation", model=llm_model, tokenizer=llm_tokenizer,
    max_new_tokens=400, do_sample=False, return_full_text=False,
    repetition_penalty=1,  # 1 means no repetition penalty
)

llm = HuggingFacePipeline(pipeline=llm_pipe)

# Step 6: LangChain Pipeline

In [154]:
template = """
You are a knowledgeable and concise medical assistant.
Given the context below, answer the user's question **in your own words**.

Question:
{question}

Context:
{context}

Provide a clear, factual, and summarized response. Do **not** repeat the text verbatim.

Answer:
"""

In [155]:
prompt = PromptTemplate(input_variables=["question", "context"], template=template)
chain = LLMChain(prompt=prompt, llm=llm)

# Step 7: ChatBot Inference

In [156]:
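import torch
import tqdm
import warnings
warnings.filterwarnings('ignore')  # hide library warnings in the notebook output
# Disable the memory-efficient and Flash scaled-dot-product-attention kernels,
# so PyTorch falls back to its other attention backends (e.g. the math implementation).
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)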

In [160]:
def get_llm_response(query):
    """
    Returns the LLM's response for a single user query, or an out-of-scope
    message if retrieval flags the query as out of domain.
    """
    top_chunks, is_ood = smart_hybrid_rerank(
        user_query=query,
        query_encoder=encoder_model,
        answer_index=answer_index,
        question_index=question_index,
        chunk_passages=all_chunks,
        chunk_to_question_map=chunk_to_question_map,
        dataset_questions=dataset_questions,
        reranker=reranker
    )

    if is_ood: # out-of-domain query: skip the LLM call entirely
        return "This may be an out-of-scope query."

    context = " ".join([chunk for _, chunk, _ in top_chunks[:10]]) # put (up to) the top 10 chunks in the context window.
    response = chain.run({"question": query, "context": context})

    return response
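
Ten chunks of up to 500 characters can add roughly 5,000 characters of context. TinyLlama-1.1B-Chat has a 2,048-token context window shared by the prompt and the up to 400 generated tokens, so with the rough rule of thumb of about 4 characters per English token the prompt can approach that limit. A minimal sketch of a guard, assuming a character budget is an acceptable proxy for tokens; build_context and the max_chars value are hypothetical, and counting tokens with llm_tokenizer would be more precise.

```python
def build_context(chunks, max_chars=4000, separator=" "):
    """Hypothetical helper: join ranked chunks in order until a character budget is reached."""
    selected, used = [], 0
    for chunk in chunks:
        extra = len(chunk) + (len(separator) if selected else 0)
        if used + extra > max_chars:  # adding this chunk would exceed the budget
            break
        selected.append(chunk)
        used += extra
    return separator.join(selected)


ranked = ["a" * 500, "b" * 500, "c" * 500]
context = build_context(ranked, max_chars=1200)
print(len(context))  # 1001 (two chunks plus one separator)
```

Since chunks arrive sorted by re-ranking score, stopping early drops the lowest-ranked chunks first.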

### Example responses

In [159]:
# Unrelated question
query = "What is a Transformer based model?"
get_llm_response(query)

'This may be an out-of-scope query.'

In [161]:
# Simple question
query = "What is Glaucoma?"
get_llm_response(query)

'\nGlaucoma is a group of eye disorders in which the optic nerves connecting the eyes and the brain are progressively damaged. This damage can lead to reduction in side (peripheral) vision and eventual blindness. Other signs and symptoms may include bulging eyes, excessive tearing, and abnormal sensitivity to light (photophobia). The term "early-onset glaucoma" may be used when the disorder appears before the age of 40. In most people with glaucoma, the damage to the optic nerves is caused by increased pressure within the eyes (intraocular pressure). Intraocular pressure depends on a balance between fluid entering and leaving the eyes. Usually glaucoma develops in older adults, in whom the risk of developing the disorder may be affected by a variety of medical conditions including high blood pressure (hypertension) and diabetes mellitus, as well as family history. The risk of early-onset glaucoma depends mainly on heredity. Structural abnormalities that impede fluid drainage in the eye

In [162]:
# Asking two questions as one question.
query = "What is High Blood pressure and how to treat it?"
get_llm_response(query)

'High blood pressure is a common disease in which blood flows through blood vessels (arteries) at higher than normal pressures. It is a medical condition that can lead to various health problems, including heart disease, stroke, and kidney disease. Treatment for high blood pressure involves lifestyle changes and medications to control blood pressure. Health care providers develop treatment plans based on the diagnosis, lifestyle changes, and medicines that work best for each individual. Treatment plans may evolve until blood pressure control is achieved. In most cases, the goal is to keep blood pressure below 140/90 mmHg (130/80 if you have diabetes or chronic kidney disease). Normal blood pressure is less than 120/80. Ask your doctor what your blood pressure goal should be. If you have high blood pressure, you will need to treat it and control it for life. This means making lifestyle changes, and, in some cases, taking prescribed medicines, and getting ongoing and maintain normal bloo

### Comments:

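The length of the responses can be controlled with max_new_tokens (400 here). Increasing repetition_penalty above 1 produced much shorter and more concise answers. With greedy decoding (do_sample=False), the output stays closely grounded in the retrieved documents. Raising the temperature only takes effect once sampling is enabled (do_sample=True); it would then make the answers less verbatim with respect to the documents.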
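
Why temperature changes how verbatim the output is: when sampling is enabled, the next-token logits are divided by the temperature before the softmax. The numpy sketch below uses made-up logits (no model involved) to show that a low temperature concentrates probability on the top token, while a high temperature spreads it across alternatives, making other wordings more likely.

```python
import numpy as np


def softmax_with_temperature(logits, temperature):
    """Next-token probabilities after dividing the logits by the temperature."""
    scaled = np.asarray(logits, dtype=float) / temperature
    scaled -= scaled.max()  # subtract the max for numerical stability
    probs = np.exp(scaled)
    return probs / probs.sum()


logits = [4.0, 2.0, 1.0]  # made-up scores for three candidate next tokens
p_cold = softmax_with_temperature(logits, 0.5)
p_hot = softmax_with_temperature(logits, 2.0)
print(np.round(p_cold, 3), np.round(p_hot, 3))

# Greedy decoding (do_sample=False) always picks the highest-logit token,
# so the temperature has no effect on it.
print(int(np.argmax(logits)))  # 0
```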

# Chatbot Evaluation

In [136]:
import evaluate
import random
import tqdm

In [59]:
rouge = evaluate.load("rouge")
bertscore = evaluate.load("bertscore")

In [138]:
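def get_llm_input(queries):
    """
    Runs retrieval and re-ranking for each query and returns a list of
    {"question": ..., "context": ...} dicts to feed into the LLM chain.
    Out-of-domain queries are kept, using the fallback context.
    """
    inputs = []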
    
    for query in tqdm.tqdm(queries):
        top_chunks, is_ood = smart_hybrid_rerank(
            user_query=query,
            query_encoder=encoder_model,
            answer_index=answer_index,
            question_index=question_index,
            chunk_passages=all_chunks,
            chunk_to_question_map=chunk_to_question_map,
            dataset_questions=dataset_questions,
            reranker=reranker
        )
    
        context = " ".join([chunk for _, chunk, _ in top_chunks[:10]])
    
        inputs.append({
            "question": query,
            "context": context
        })
    
    return inputs

In [139]:
sample_set = 50

indices = list(range(len(dataset_questions)))

selected_indices = random.sample(indices, sample_set)

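# Random test set for LLM evaluation (sampled from the indexed dataset itself; no random seed is set, so the sample differs between runs)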

test_questions = [dataset_questions[i] for i in selected_indices]
test_passages = [context_passages[i] for i in selected_indices]

In [140]:
len(test_questions), len(test_passages)

(50, 50)

In [141]:
llm_inputs = get_llm_input(test_questions)

100%|██████████| 50/50 [00:23<00:00,  2.13it/s]


In [144]:
predictions = []
for input_dict in tqdm.tqdm(llm_inputs, desc="Generating answers"):
    result = chain.apply([input_dict])[0] 
    predictions.append(result)

Generating answers: 100%|██████████| 50/50 [10:08<00:00, 12.17s/it]


In [150]:
llm_answers = [pred['text'] for pred in predictions]

# ROUGE Score

In [151]:
rouge.compute(predictions=llm_answers, references=test_passages)

{'rouge1': 0.4115395382943542,
 'rouge2': 0.28478724331104943,
 'rougeL': 0.3245266069281287,
 'rougeLsum': 0.32423982991848277}

### Summary

#### ROUGE Metrics (Lexical Overlap)
ROUGE-1 = 0.41 is the unigram-overlap F-measure (the evaluate ROUGE metric reports the F-measure by default). It balances how many reference unigrams appear in the generated answers (recall) against how many generated unigrams appear in the references (precision).

ROUGE-2 = 0.28 measures bigram overlap: the generated answers share a fair number of two-word phrases with the references, but do not replicate them exactly.

ROUGE-L = 0.325 is based on the longest common subsequence and indicates moderate structural alignment: the model often reorders or paraphrases content rather than copying exact sequences.
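
For reference, the sketch below computes a simplified ROUGE-1 for a single prediction/reference pair. It roughly approximates the default tokenization of the rouge_score package that evaluate wraps (lowercasing, keeping alphanumeric runs, no stemming); the real metric also aggregates over the whole test set, so this only illustrates how precision, recall and F-measure relate. The made-up example shows that a short answer can have perfect precision but only 0.5 recall.

```python
import re
from collections import Counter


def rouge1(prediction, reference):
    """Simplified ROUGE-1 (precision, recall, F-measure) for one pair of texts."""
    def tokenize(text):
        # lowercase and keep runs of letters/digits, no stemming
        return re.findall(r"[a-z0-9]+", text.lower())

    pred, ref = Counter(tokenize(prediction)), Counter(tokenize(reference))
    overlap = sum((pred & ref).values())  # unigram matches, clipped by count
    if overlap == 0:
        return 0.0, 0.0, 0.0
    precision = overlap / sum(pred.values())  # generated unigrams found in the reference
    recall = overlap / sum(ref.values())      # reference unigrams found in the generation
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


p, r, f = rouge1("The cat sat", "the cat sat on the mat")
print(round(p, 3), round(r, 3), round(f, 3))  # 1.0 0.5 0.667
```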

#### Conclusion:
Based on these results and some manual inspection, the model's output shows a moderate level of lexical and structural overlap with the reference passages. Key terms were carried over from the provided documents, and the model did not appear to hallucinate much: its answers stayed grounded in the retrieved passages. Because the model paraphrases, exact phrase matches with the passages are less frequent. The degree of paraphrasing could be tuned through sampling settings such as temperature (which requires do_sample=True), but those experiments are left for future work, depending on broader requirements.

# BERTScore

In [152]:
bertscore_results = bertscore.compute(predictions=llm_answers, references=test_passages, lang="en")
print("BERTScore:")
print(f"Precision: {np.mean(bertscore_results['precision']):.4f}")
print(f"Recall:    {np.mean(bertscore_results['recall']):.4f}")
print(f"F1:        {np.mean(bertscore_results['f1']):.4f}")

BERTScore:
Precision: 0.8591
Recall:    0.8734
F1:        0.8656


#### BERTScore Metrics (Semantic Similarity)
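F1 = 0.866 suggests that the generated answers are semantically close to the reference answers, even when the exact words differ. Note that raw BERTScore values (computed without rescale_with_baseline=True) tend to fall in a narrow, high range, so these numbers are most informative when compared against a baseline or another system.

Precision = 0.859 indicates that most of the generated content is matched by content in the references, suggesting the model mostly avoids adding extra or irrelevant information.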

Recall = 0.873 shows that the model includes a large portion of the relevant content from the reference.
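
How BERTScore arrives at these three numbers: for precision, each token of the generated answer is matched to its most similar reference token (cosine similarity of contextual embeddings); for recall, each reference token is matched to its most similar generated token. The numpy sketch uses made-up 2-D vectors in place of real contextual embeddings (bert_score uses roberta-large by default for English) and omits the optional IDF weighting and baseline rescaling.

```python
import numpy as np


def greedy_bertscore(cand_emb, ref_emb):
    """BERTScore-style precision/recall/F1 from per-token embeddings (rows = tokens)."""
    cand = cand_emb / np.linalg.norm(cand_emb, axis=1, keepdims=True)
    ref = ref_emb / np.linalg.norm(ref_emb, axis=1, keepdims=True)
    sim = cand @ ref.T                  # cosine similarity, shape (n_cand, n_ref)
    precision = sim.max(axis=1).mean()  # each candidate token -> best reference token
    recall = sim.max(axis=0).mean()     # each reference token -> best candidate token
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


# Made-up "token embeddings": the candidate covers only one of two reference tokens.
reference = np.array([[1.0, 0.0], [0.0, 1.0]])
candidate = np.array([[1.0, 0.0]])
p, r, f = greedy_bertscore(candidate, reference)
print(p, r, round(f, 3))  # 1.0 0.5 0.667
```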

#### Conclusion:
The scores suggest that the LLM conveys the correct medical facts from the supplied passages. High recall implies that the answers are informative and fairly complete, while high precision indicates that they are on-topic and relevant.

Apart from these metrics, manual inspection shows that the RAG system produces factually correct and relevant answers, with minor differences in phrasing or answer structure. The model often paraphrases or reformats the information instead of copying exact sentences, which is the desired behavior for a helpful assistant. While lexical similarity (ROUGE) is moderate, semantic similarity (BERTScore) is high, indicating that the overall approach performs well for medical Q&A.

# Future Scope:

The current solution works well within small-memory, single-GPU constraints. There are various ways to improve the performance of each component of the RAG architecture.

### Embedding Model:
- Larger models trained specifically on the medical domain can be used to encode user queries and documents.
- Their embeddings would likely be more robust and focus more on medical terminology than on general English language and expressions.
- The models used in this solution have relatively few parameters and a small embedding dimension (384). Larger models with higher-dimensional embeddings could be used instead.

### RAG:

#### Data Filtering before Indexing
- An LLM can be used to filter out answers that provide no contextual information related to the question. This reduces the number of documents indexed into the vector database, leaving a cleaner index with more relevant content, which in turn helps retrieve semantically similar and useful documents.

#### Topic Modeling/User Intent Classification
- Another classifier can be trained to classify the user query into topics. For example, this dataset covers various health conditions (Glaucoma, Blood Pressure, etc.), with several questions per topic. If a classifier predicts which topic the user query belongs to, that information can be used either to retrieve documents only from that topic (reducing unnecessary noise) or to re-rank documents based on their topic.
- The predicted topic can also help determine whether the query is in or out of domain.
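
A minimal sketch of restricting retrieval to a predicted topic, assuming each chunk carries a topic label (chunk_topics) and that some classifier has already produced predicted_topic; both, and the search_within_topic helper, are hypothetical. It uses brute-force cosine similarity in numpy instead of FAISS to stay self-contained.

```python
import numpy as np


def search_within_topic(query_emb, chunk_embs, chunk_topics, predicted_topic, k=3):
    """Cosine top-k search restricted to chunks whose (hypothetical) topic label matches."""
    candidates = np.flatnonzero(np.asarray(chunk_topics) == predicted_topic)
    if candidates.size == 0:
        return []  # no chunk with this topic: a possible out-of-domain signal
    q = query_emb / np.linalg.norm(query_emb)
    c = chunk_embs[candidates]
    c = c / np.linalg.norm(c, axis=1, keepdims=True)
    scores = c @ q
    order = np.argsort(-scores)[:k]
    return candidates[order].tolist()  # indices into the full chunk list


chunk_embs = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.7, 0.7]])
chunk_topics = ["glaucoma", "blood pressure", "glaucoma", "glaucoma"]  # made-up labels
query = np.array([1.0, 0.05])
print(search_within_topic(query, chunk_embs, chunk_topics, "glaucoma"))  # [0, 3, 2]
```

Chunk 1 is very similar to the query but is skipped because of its label, which is the noise reduction the bullet describes; an empty result could feed the out-of-domain decision.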

#### Query rewriting
- LLMs can be used to expand the user query for better retrieval results, adding more domain knowledge when the query is very short.
- Augment the query with the user's history.

#### Indexing & Retrieval
- There are various large-scale, efficient vector databases to explore (Milvus, DBX, Pinecone, etc.).
- Explore Contextual RAG: use LLMs to generate context about each document chunk and prepend that context to the respective chunk before indexing. This can improve retrieval accuracy.
- Use hybrid search instead of cosine similarity alone: combining BM25 keyword scores with cosine similarity helps retrieve better results.
- Use an LLM to improve the ranking/ordering of the retrieved chunks, e.g. with a prompt that verifies whether the user query and the document match.
- Consider Agentic RAG if compute and latency are not major concerns.
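
A sketch of hybrid retrieval, assuming reciprocal rank fusion (RRF) as the way to merge a BM25 keyword ranking with the embedding ranking; RRF is one common choice, weighted score mixing is another. The bm25_scores and reciprocal_rank_fusion helpers, the whitespace tokenizer, the parameters k1=1.5, b=0.75 and k=60 (commonly used defaults), and the dense ranking are all illustrative assumptions.

```python
import math
from collections import Counter


def bm25_scores(query, docs, k1=1.5, b=0.75):
    """Okapi BM25 score of every doc for the query (lowercase whitespace tokens)."""
    tokenized = [d.lower().split() for d in docs]
    n_docs = len(tokenized)
    avgdl = sum(len(t) for t in tokenized) / n_docs
    df = Counter(term for t in tokenized for term in set(t))  # document frequency
    scores = []
    for tokens in tokenized:
        tf = Counter(tokens)
        score = 0.0
        for term in query.lower().split():
            if term not in tf:
                continue
            idf = math.log((n_docs - df[term] + 0.5) / (df[term] + 0.5) + 1)
            norm = tf[term] + k1 * (1 - b + b * len(tokens) / avgdl)
            score += idf * tf[term] * (k1 + 1) / norm
        scores.append(score)
    return scores


def reciprocal_rank_fusion(rankings, k=60):
    """Fuse ranked lists of doc ids: score(d) = sum over lists of 1 / (k + rank)."""
    fused = Counter()
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking, start=1):
            fused[doc_id] += 1.0 / (k + rank)
    return [doc_id for doc_id, _ in fused.most_common()]


docs = [
    "glaucoma damages the optic nerve",
    "treatment for glaucoma includes eye drops",
    "high blood pressure treatment includes lifestyle changes",
]
bm25 = bm25_scores("glaucoma treatment", docs)
bm25_ranking = sorted(range(len(docs)), key=lambda i: -bm25[i])
dense_ranking = [0, 2, 1]  # made-up stand-in for the FAISS (cosine) ranking
fused = reciprocal_rank_fusion([bm25_ranking, dense_ranking])
print(bm25_ranking, fused)  # [1, 0, 2] [0, 1, 2]
```

RRF only uses ranks, so it sidesteps the fact that BM25 scores and cosine similarities live on different scales.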

#### LLM for Chatbot
- Explore LLMs with larger context windows. Several medical LLMs fine-tuned on medical data could be used for the final step in this architecture.
- LLMs with at least 7B-8B parameters would likely improve performance (ROUGE and BERTScore).
- If a well-curated dataset of medical conditions and general medical information is available, LoRA fine-tuning is another option for grounding the responses in the medical domain.
- Use an LLM as a judge to check whether the generated answer matches the user query.

#### Feasibility
- Many LLM calls per query may not be feasible in a real-time system where latency matters.
- Hosting multiple LLMs is expensive: multiple LLM calls require substantial GPU compute and a dedicated deployment server.
- As the document collection grows, a larger vector index is required, and it must be hosted 24/7.
- LoRA fine-tuning of open-source models can be a cheaper alternative to paid APIs, depending on user traffic.