# To run the model:
1. Connect to a T4 GPU runtime
2. Run all the cells
3. Enter a prompt in the input box each time "Enter your query:" appears (the demo asks for three queries)



In [1]:
pip install --upgrade torch accelerate datasets pandas scikit-learn sacremoses peft bitsandbytes transformers sentence-transformers faiss-cpu

Collecting accelerate
  Downloading accelerate-1.2.0-py3-none-any.whl.metadata (19 kB)
Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting pandas
  Downloading pandas-2.2.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (89 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m89.9/89.9 kB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
Collecting sacremoses
  Downloading sacremoses-0.1.1-py3-none-any.whl.metadata (8.3 kB)
Collecting peft
  Downloading peft-0.14.0-py3-none-any.whl.metadata (13 kB)
Collecting bitsandbytes
  Downloading bitsandbytes-0.45.0-py3-none-manylinux_2_24_x86_64.whl.metadata (2.9 kB)
Collecting transformers
  Downloading transformers-4.47.0-py3-none-any.whl.metadata (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.5/43.5 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
Collecting sentence-transformers
  Downloading sentence_transformers-3.3.1-py3-none-any.whl.metadata (

In [2]:
import torch
import pandas as pd
import numpy as np
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
    BitsAndBytesConfig,
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    PeftModel,
    PeftConfig
)
from datasets import Dataset
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import faiss
import os
from typing import List, Dict, Optional
import warnings

# Silence every warning of category UserWarning for the rest of the process so
# the notebook output stays readable. Note this is global, not library-specific.
warnings.filterwarnings("ignore", category=UserWarning)

class MedicalChatbotRAG:
    """Retrieval-augmented medical chatbot on a 4-bit quantized seq2seq model with LoRA.

    The class keeps an in-memory knowledge base (a list of text snippets), their
    sentence embeddings and a FAISS L2 index over those embeddings. At query time
    the closest snippets are retrieved and prepended to the prompt before the
    seq2seq model generates an answer.

    Attributes set by ``__init__``:
        tokenizer: tokenizer loaded for ``base_model``.
        embedding_model: ``SentenceTransformer`` used to embed documents and queries.
        knowledge_base: list of document strings, aligned with the FAISS index rows.
        knowledge_embeddings: embeddings of ``knowledge_base`` (``None`` while empty).
        faiss_index: ``faiss.IndexFlatL2`` over the embeddings (``None`` while empty).
        top_k_retrieval: maximum number of snippets retrieved per query.
        model: the (PEFT-wrapped) generation model.
    """

    def __init__(
        self,
        base_model: str = "google/flan-t5-large",
        embedding_model: str = "all-MiniLM-L6-v2",
        checkpoint_path: Optional[str] = None,
        lora_r: int = 16,
        lora_alpha: int = 32,
        lora_dropout: float = 0.1,
        top_k_retrieval: int = 3
    ):
        """Load tokenizer, embedding model and generation model.

        Args:
            base_model: Hugging Face model id of the seq2seq base model.
            embedding_model: SentenceTransformer model id used for retrieval.
            checkpoint_path: Local path of a saved LoRA adapter. It is only used
                when the path exists on disk; otherwise a fresh LoRA adapter is
                attached to the base model.
            lora_r: LoRA rank (used only when no checkpoint is loaded).
            lora_alpha: LoRA alpha (used only when no checkpoint is loaded).
            lora_dropout: LoRA dropout (used only when no checkpoint is loaded).
            top_k_retrieval: Maximum number of snippets to retrieve per query.
        """
        # Tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(base_model)

        # RAG state: embedding model plus an initially empty knowledge base/index
        self.embedding_model = SentenceTransformer(embedding_model)
        self.knowledge_base: List[str] = []
        self.knowledge_embeddings = None
        self.faiss_index = None

        # Retrieval
        self.top_k_retrieval = top_k_retrieval

        # Use the checkpoint when it is available locally, else the base model.
        if checkpoint_path and os.path.exists(checkpoint_path):
            self._load_from_checkpoint(base_model, checkpoint_path)
            return

        if checkpoint_path:
            # A checkpoint was requested but is not a local path (e.g. a URL):
            # say so instead of silently answering with an untuned model.
            print(
                f"Checkpoint not found at local path: {checkpoint_path}. "
                "Falling back to the base model with a fresh LoRA adapter."
            )
        self._initialize_model(base_model, lora_r, lora_alpha, lora_dropout)

    @staticmethod
    def _load_quantized_base(base_model: str):
        """Load ``base_model`` in 4-bit NF4 (double quantization) and prepare it for k-bit training."""
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True
        )
        model = AutoModelForSeq2SeqLM.from_pretrained(
            base_model,
            device_map="auto",
            quantization_config=quantization_config
        )
        return prepare_model_for_kbit_training(model)

    def _initialize_model(self, base_model: str, lora_r: int, lora_alpha: int, lora_dropout: float):
        """
        Load the quantized base model and attach a new LoRA adapter.

        Args:
            base_model: Hugging Face model id of the seq2seq base model.
            lora_r: LoRA rank.
            lora_alpha: LoRA alpha.
            lora_dropout: LoRA dropout.
        """
        quantized_model = self._load_quantized_base(base_model)

        # LoRA config
        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            target_modules=["q", "v", "k", "o"],
            bias="none",
            task_type="SEQ_2_SEQ_LM"
        )

        # Apply LoRA
        self.model = get_peft_model(quantized_model, lora_config)

    def _load_from_checkpoint(self, base_model: str, checkpoint_path: str):
        """
        Load the quantized base model and the LoRA adapter saved at ``checkpoint_path``.

        The adapter is loaded once, directly onto the base model. Wrapping the
        model with ``get_peft_model`` first and then calling
        ``PeftModel.from_pretrained`` would stack a second adapter on top of a
        freshly initialised one.

        Raises:
            RuntimeError: if the base model or the adapter cannot be loaded; the
                original exception is chained as ``__cause__``.
        """
        try:
            quantized_model = self._load_quantized_base(base_model)
            self.model = PeftModel.from_pretrained(quantized_model, checkpoint_path)
            print(f"Model loaded from checkpoint: {checkpoint_path}")
        except Exception as e:
            raise RuntimeError(f"Error loading model from checkpoint: {e}") from e

    def add_to_knowledge_base(self, documents: List[str]):
        """Embed ``documents`` and append them to the knowledge base and FAISS index.

        Only the new documents are encoded; previously added documents are not
        re-embedded. An empty ``documents`` list is a no-op.
        """
        new_documents = list(documents)
        if not new_documents:
            return

        new_embeddings = np.asarray(
            self.embedding_model.encode(new_documents), dtype="float32"
        )

        if self.faiss_index is None:
            self.faiss_index = faiss.IndexFlatL2(new_embeddings.shape[1])
            self.knowledge_embeddings = new_embeddings
        else:
            self.knowledge_embeddings = np.vstack(
                [self.knowledge_embeddings, new_embeddings]
            )
        self.faiss_index.add(new_embeddings)

        # Extend the text list last so it stays aligned with the index rows even
        # if encoding or indexing raises.
        self.knowledge_base.extend(new_documents)

    def retrieve_relevant_context(self, query: str) -> List[str]:
        """Return up to ``top_k_retrieval`` knowledge-base snippets closest to ``query``.

        Returns an empty list when nothing has been indexed yet.
        """
        if self.faiss_index is None or not self.knowledge_base:
            return []

        # Never ask for more neighbours than there are documents.
        k = min(self.top_k_retrieval, len(self.knowledge_base))
        if k <= 0:
            return []

        query_embedding = self.embedding_model.encode([query])
        _distances, indices = self.faiss_index.search(query_embedding, k)

        # FAISS reports missing neighbours as -1, which would otherwise index
        # the last document; keep only labels that are valid positions.
        total = len(self.knowledge_base)
        return [self.knowledge_base[idx] for idx in indices[0] if 0 <= idx < total]

    def generate_response(self, prompt: str) -> str:
        """Answer ``prompt`` using retrieved context plus the generation model.

        Returns:
            The decoded model output, or a message describing the error if
            generation raised an exception.
        """
        # Step 1: use RAG to gather relevant content from the knowledge base
        retrieved_contexts = self.retrieve_relevant_context(prompt)

        # Step 2: augmented prompt with detailed instructions
        augmented_prompt = (
            f"You are a highly knowledgeable medical assistant. "
            f"Your goal is to provide accurate, detailed, and structured responses "
            f"to medical queries. Use the retrieved context below to formulate your answer.\n\n"
            f"Context: {' '.join(retrieved_contexts)}\n\n"
            f"Query: {prompt}\n\n"
            f"Response: "
        )

        # Step 3: tokenize the augmented prompt (truncated to 512 tokens)
        encoded = self.tokenizer(
            augmented_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512
        )

        # Move every input tensor to the device the model lives on.
        device = self.model.device
        inputs = {name: tensor.to(device) for name, tensor in encoded.items()}

        # Step 4: generate the response using the model
        # NOTE(review): temperature/top_p normally only take effect when sampling
        # is enabled (do_sample=True), and early_stopping applies to beam search;
        # neither is enabled here — confirm whether greedy decoding is intended.
        try:
            outputs = self.model.generate(
                **inputs,
                max_length=256,
                num_return_sequences=1,
                temperature=0.7,
                top_p=0.9,
                repetition_penalty=1.2,
                early_stopping=True
            )

            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Best effort: report the failure as the chatbot's reply instead of raising.
        except Exception as e:
            response = f"An error occurred while generating the response: {str(e)}"

        return response

    def prepare_dataset(self, df: pd.DataFrame):
        """Tokenize the ``Prompt``/``Completion`` columns of ``df`` for seq2seq training.

        Both sides are truncated/padded to 512 tokens. Padding positions in the
        labels are set to -100 so the loss ignores them.

        Returns:
            A ``datasets.Dataset`` with the tokenizer outputs plus ``labels``.
        """
        dataset = Dataset.from_pandas(df[['Prompt', 'Completion']])
        pad_token_id = self.tokenizer.pad_token_id

        def tokenize_function(examples):
            inputs = self.tokenizer(
                examples['Prompt'], max_length=512, truncation=True, padding='max_length'
            )
            targets = self.tokenizer(
                examples['Completion'], max_length=512, truncation=True, padding='max_length'
            )
            # Mask padding in the targets so it does not contribute to the loss.
            inputs['labels'] = [
                token_id if token_id != pad_token_id else -100
                for token_id in targets['input_ids']
            ]
            return inputs

        return dataset.map(tokenize_function, remove_columns=['Prompt', 'Completion'])




        #```commented to avoid training```

    # def train(self, df: pd.DataFrame, output_dir: str = "./output", epochs: int = 5, batch_size: int = 4):
    #     """
    #     Train the model with LoRA.
    #     """
    #     train_dataset = self.prepare_dataset(df)
    #     training_args = Seq2SeqTrainingArguments(
    #         output_dir=output_dir,
    #         num_train_epochs=epochs,
    #         per_device_train_batch_size=batch_size,
    #         predict_with_generate=True,
    #         fp16=True,
    #         save_strategy="epoch",
    #         evaluation_strategy="no",
    #         logging_dir="./logs",
    #         logging_steps=10
    #     )
    #     trainer = Seq2SeqTrainer(
    #         model=self.model,
    #         args=training_args,
    #         train_dataset=train_dataset,
    #         tokenizer=self.tokenizer
    #     )
    #     trainer.train()
    #     trainer.save_model(output_dir)

In [3]:
def main():
    """Run the interactive demo: build the chatbot, load the knowledge base, answer three queries."""
    # Snippets indexed for retrieval, grouped by topic.
    medical_knowledge_base = [
        # Cardiovascular Diseases
        "Hypertension (high blood pressure) is a common condition where the force of the blood against the walls of your arteries is too high. It can lead to serious complications like heart disease and stroke.",
        "Coronary artery disease occurs when blood vessels become narrowed or blocked by plaque, leading to reduced blood flow to the heart muscle. It can result in chest pain (angina) or a heart attack.",

        # Diabetes and Metabolic Disorders
        "Type 2 diabetes is a condition where the body becomes resistant to insulin or does not produce enough insulin. It often results in increased thirst, frequent urination, and fatigue.",
        "Insulin resistance can lead to metabolic syndrome, characterized by obesity, high blood pressure, and high blood sugar, which increase the risk of cardiovascular diseases.",

        # Respiratory Diseases
        "Asthma is a chronic condition characterized by inflammation and narrowing of the airways, leading to symptoms like wheezing, shortness of breath, and coughing.",
        "COPD is a group of lung diseases that cause chronic airflow limitation. It is mainly caused by long-term smoking and includes conditions like emphysema and chronic bronchitis.",

        # Neurological Conditions
        "Alzheimer's disease is a neurodegenerative condition that causes memory loss, confusion, and behavioral changes, commonly affecting older adults.",
        "Parkinson's disease is a neurodegenerative disorder that leads to tremors, stiffness, and difficulty with movement. It is caused by the loss of dopamine-producing neurons.",

        # Cancer Types and Treatments
        "Breast cancer is a common type of cancer that can present as a lump in the breast. Treatment may include surgery, chemotherapy, and radiation therapy.",
        "Lung cancer is often diagnosed through imaging and biopsy. Treatment options include surgery, chemotherapy, immunotherapy, and targeted therapies.",

        # Mental Health Disorders
        "Depression is a mood disorder marked by persistent feelings of sadness, loss of interest, and fatigue. It is typically treated with therapy and antidepressant medications.",
        "Anxiety disorders are characterized by excessive worry and fear. Treatment includes therapy (e.g., CBT) and medications (e.g., SSRIs, benzodiazepines).",

        # Infectious Diseases
        "Tuberculosis (TB) primarily affects the lungs and is characterized by a chronic cough, fever, and weight loss. It is treated with a combination of antibiotics.",
        "COVID-19 is caused by the SARS-CoV-2 virus and can present with fever, cough, difficulty breathing, and fatigue. Vaccines and antiviral treatments are available.",

        # Orthopedic Conditions
        "Osteoarthritis is a degenerative joint disease that causes pain, stiffness, and swelling, particularly in weight-bearing joints such as the knees and hips.",
        "Fractures are breaks in the bone caused by trauma. Treatment may include immobilization with a cast or surgery, depending on the type and location of the fracture.",

        # Pediatric Conditions
        "Pediatric asthma is common in children and is characterized by wheezing, coughing, and shortness of breath. It is often triggered by allergens or respiratory infections.",
        "Viral infections like RSV (Respiratory Syncytial Virus) can cause symptoms such as wheezing and difficulty breathing, especially in infants.",

        # Gastrointestinal Diseases
        "Irritable Bowel Syndrome (IBS) is a functional gastrointestinal disorder causing symptoms like abdominal pain, bloating, and changes in bowel habits.",
        "Gastroesophageal reflux disease (GERD) occurs when stomach acid frequently leaks into the esophagus, leading to symptoms like heartburn and regurgitation.",

        # Optional dataset-related entries: osteoarthritis
        "Osteoarthritis is a degenerative joint disease that commonly affects older adults, leading to pain, stiffness, and reduced function.",
        "Management of osteoarthritis typically includes lifestyle modifications, physical therapy, pain relief with NSAIDs, and weight management.",
        "In cases of worsened symptoms, corticosteroid injections, hyaluronic acid injections, or surgical interventions may be considered.",

        # Optional dataset-related entries: wrist injuries
        "A fall onto an outstretched hand may result in a wrist sprain, fracture, or other injuries, presenting with pain, swelling, and difficulty moving the wrist.",
        "Common injuries include scaphoid fractures, distal radius fractures, or wrist sprains. These injuries can be diagnosed with X-rays or other imaging techniques.",
        "Treatment may involve immobilization with a cast or splint, pain management, and potentially surgery, depending on the severity of the fracture.",

        # Optional dataset-related entries: ADHD
        "Attention-Deficit/Hyperactivity Disorder (ADHD) is characterized by symptoms like restlessness, impulsivity, and difficulty concentrating, often beginning in childhood.",
        "ADHD can significantly affect personal and professional life and is diagnosed through clinical evaluation, including symptom assessment and history.",
        "Treatment options for ADHD include stimulant medications such as methylphenidate or non-stimulants like atomoxetine, along with behavioral therapy."
    ]

    # Build the chatbot, pointing it at the fine-tuned checkpoint.
    # NOTE(review): checkpoint_path is a Google Drive URL, while the constructor
    # only loads a checkpoint when os.path.exists(checkpoint_path) is true, so
    # this presumably falls back to the base model unless the folder has been
    # downloaded or mounted locally — confirm.
    chatbot = MedicalChatbotRAG(
        base_model="google/flan-t5-large",
        embedding_model="all-MiniLM-L6-v2",
        checkpoint_path="https://drive.google.com/drive/folders/11w3BsCdeqUHRURhMPO4EmI8cohiTS-Ei?usp=sharing"
    )

    # Index the snippets so they can be retrieved at query time.
    chatbot.add_to_knowledge_base(medical_knowledge_base)

    # Three rounds of: read a query, wrap it in the instruction, print the reply.
    for _ in range(3):
        user_query = input("Enter your query: ")
        answer = chatbot.generate_response(
            f"Please provide a detailed medical explanation for the following case: {user_query}"
        )
        print(f"Chatbot: {answer}")

# Run the interactive demo only when this file is executed directly
# (not when it is imported as a module).
if __name__ == "__main__":
    main()


tokenizer_config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.7k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.13G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

Enter your query: A 50-year-old male with obesity presents with chronic knee pain, which is worse with activity and relieved by rest. Physical examination reveals knee joint tenderness and crepitus with movement. What is the likely diagnosis and the next steps?
Chatbot: Osteoarthritis is a degenerative joint disease that causes pain, stiffness, and swelling, particularly in weight-bearing joints such as the knees and hips.
Enter your query: A 26-year-old man has been experiencing persistent restlessness, impulsivity, and difficulty concentrating both in his professional and personal life. He also mentions a history of similar problems during childhood. What could be the potential causes?
Chatbot: ADHD is characterized by symptoms like restlessness, impulsivity, and difficulty concentrating, often beginning in childhood. It is caused by the loss of dopamine-producing neurons.
Enter your query: A 45-year-old woman fell onto her outstretched hand while jogging. She now has pain and swelli