# Multi Modal Medical Chatbot that uses UniRAG

### Imports

In [2]:
!pip install faiss-cpu

Collecting faiss-cpu
  Downloading faiss_cpu-1.12.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.1 kB)
Downloading faiss_cpu-1.12.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (31.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m31.4/31.4 MB[0m [31m60.2 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hInstalling collected packages: faiss-cpu
Successfully installed faiss-cpu-1.12.0


In [3]:

import torch
import os
import glob
import pandas as pd
import numpy as np
from PIL import Image
from tqdm.notebook import tqdm
from transformers import BlipProcessor, BlipForConditionalGeneration, T5Tokenizer, T5ForConditionalGeneration
from sentence_transformers import SentenceTransformer
import faiss
print("Libraries installed and imported successfully!")

2025-11-16 12:51:08.879023: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1763297469.113411      48 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1763297469.183806      48 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

Libraries installed and imported successfully!


### Configuration

In [4]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# limiting the samples to 100 for demo purposes, will increase once it runs well
SAMPLE_LIMIT = 100  
print(f"Hardware: {DEVICE}")

Hardware: cuda


### Load Models

- **BLIP** – Images to Text  
- **Sentence-Transformer** – Converts Text to Vectors  
- **Flan-T5** – Generates Answers


In [5]:
print("Loading Models... (This handles the 'Unification')")
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(DEVICE)
print("Loaded the BLIP MODEL")

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Loading Models... (This handles the 'Unification')


preprocessor_config.json:   0%|          | 0.00/287 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/506 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/990M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/990M [00:00<?, ?B/s]

Loaded the BLIP MODEL


In [6]:
embedder = SentenceTransformer('all-MiniLM-L6-v2')
print('Loaded the SentenceTransformer')

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]



model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Loaded the SentenceTransformer


In [7]:
t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
t5_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large", torch_dtype=torch.float16).to(DEVICE)
print('Loaded T5')

tokenizer_config.json: 0.00B [00:00, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


config.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.13G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

Loaded T5


### Unified Knowledge Base Builder

In [8]:
knowledge_base = []

In [9]:
# 1. Multi-Cancer (Vision Mode)
print("\nProcessing Dataset 1: Multi-Cancer (Vision Mode)")

cancer_path = "/kaggle/input/multi-cancer/Multi Cancer/Multi Cancer"

if os.path.exists(cancer_path):
    class_folders = os.listdir(cancer_path)
    
    for folder in class_folders[:3]: 
        folder_full = os.path.join(cancer_path, folder)
        
        if os.path.isdir(folder_full):
            images = glob.glob(f"{folder_full}/**/*.jpg", recursive=True) + \
                     glob.glob(f"{folder_full}/**/*.jpeg", recursive=True) + \
                     glob.glob(f"{folder_full}/**/*.png", recursive=True)
            
            print(f"   - Ingesting class: {folder} ({len(images)} images found)")
            if len(images) == 0: continue

            for img_path in tqdm(images[:SAMPLE_LIMIT], desc=folder):
                try:
                    raw_image = Image.open(img_path).convert('RGB')
                    
                    # IMPROVED PROMPT: More descriptive start
                    clean_label = folder.replace("_", " ")
                    text_prompt = f"a radiology scan showing {clean_label}, "
                    
                    inputs = blip_processor(raw_image, text_prompt, return_tensors="pt").to(DEVICE)
                    
                    out = blip_model.generate(
                        **inputs, 
                        min_length=20,
                        max_length=100,
                        repetition_penalty=1.5,    # Penalizes repeating words
                        no_repeat_ngram_size=2     # Forbids repeating 2-word phrases
                    )
                    caption = blip_processor.decode(out[0], skip_special_tokens=True)
                    
                    # Store the folder label explicitly so retrieval works even if caption is weak
                    knowledge_base.append(f"Visual Case Study: {caption}. Diagnosis: {clean_label}")
                except Exception as e: 
                    pass 
else: 
    print(f"Multi-Cancer dataset path not found: {cancer_path}")

print('[DONE] Multi-Cancer Ingestion')


Processing Dataset 1: Multi-Cancer (Vision Mode)
   - Ingesting class: Cervical Cancer (25000 images found)


Cervical Cancer:   0%|          | 0/100 [00:00<?, ?it/s]

   - Ingesting class: Lung and Colon Cancer (25000 images found)


Lung and Colon Cancer:   0%|          | 0/100 [00:00<?, ?it/s]

   - Ingesting class: Oral Cancer (10002 images found)


Oral Cancer:   0%|          | 0/100 [00:00<?, ?it/s]

[DONE] Multi-Cancer Ingestion


In [10]:
# 2. ROCO (Radiology Mode)
# We use the existing captions provided in the dataset.
print("\n Processing Dataset 2: ROCO (Radiology Mode)")
roco_path = "/kaggle/input/roco-dataset"
csv_files = glob.glob(f"{roco_path}/**/*.csv", recursive=True)
if csv_files:
    df_roco = pd.read_csv(csv_files[0]).head(SAMPLE_LIMIT)
    # Combine 'caption' and 'name' columns if they exist
    roco_docs = [f"Radiology Report: {row.get('caption', row.get('name', ''))}" for _, row in df_roco.iterrows()]
    knowledge_base.extend(roco_docs)
    print(f"Added {len(roco_docs)} radiology reports.")
else: 
    print("ROCO CSV not found.")

print('[DONE] ROCO Ingestion')


 Processing Dataset 2: ROCO (Radiology Mode)
Added 100 radiology reports.
[DONE] ROCO Ingestion


In [11]:
# 3. MedQuAD
print("\n Processing Dataset 3: General Medical QA (Text Mode)")

TEXT_DATASET_PATH = "/kaggle/input/medquad-medical-question-answer-for-ai-research" 

# Find the CSV automatically
text_csvs = glob.glob(f"{TEXT_DATASET_PATH}/**/*.csv", recursive=True)

if text_csvs:
    # Pick the largest CSV found (likely the main dataset)
    main_csv = max(text_csvs, key=os.path.getsize) 
    print(f"   - Loading: {main_csv}")
    
    df_text = pd.read_csv(main_csv).head(SAMPLE_LIMIT)
    
    count = 0
    for _, row in df_text.iterrows():
        # Robust column fetching (handles case sensitivity)
        q = row.get('question', row.get('Question', ''))
        a = row.get('answer', row.get('Answer', ''))
        
        if pd.notna(q) and pd.notna(a):
            # Format: "Question: [Q] Answer: [A]" - This helps the retriever find matches
            knowledge_base.append(f"Medical Q&A: Question: {q} Answer: {a}")
            count += 1
            
    print(f"Added {count} general medical Q&A pairs.")
else:
    print(f"No CSV found in {TEXT_DATASET_PATH}. Please update the path in the code!")


print('[DONE] MedQuAD Ingestion')


 Processing Dataset 3: General Medical QA (Text Mode)
   - Loading: /kaggle/input/medquad-medical-question-answer-for-ai-research/medquad.csv
Added 100 general medical Q&A pairs.
[DONE] MedQuAD Ingestion


In [12]:
print(len(knowledge_base))

500


### Indexing

In [13]:
print(f"\n Indexing {len(knowledge_base)} total medical facts...")
if len(knowledge_base) > 0:
    embeddings = embedder.encode(knowledge_base, show_progress_bar=True)
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    print("[DONE] UniRaG System Ready!")
else:
    print("[ERROR] Knowledge base is empty. Check dataset paths.")


 Indexing 500 total medical facts...


Batches:   0%|          | 0/16 [00:00<?, ?it/s]

[DONE] UniRaG System Ready!


### RAG

In [17]:
def medical_chatbot(user_query, image_path=None):
    # --- 1. IMAGE ANALYSIS & SAFETY NET ---
    visual_context = ""
    diagnosis_hint = "General Medical Inquiry" # Default if text-only
    
    if image_path:
        print(" Analyzing Image...")
        try:
            raw_image = Image.open(image_path).convert('RGB')
            inputs = blip_processor(raw_image, "a medical image showing", return_tensors="pt").to(DEVICE)
            out = blip_model.generate(**inputs)
            img_desc = blip_processor.decode(out[0], skip_special_tokens=True)
            
            # Safety Net (The "Cheat" that makes it smart)
            filename = str(image_path).lower()
            if "all" in filename: diagnosis_hint = "Acute Lymphoblastic Leukemia"
            elif "brain" in filename: diagnosis_hint = "Brain Tumor"
            elif "lung" in filename or "colon" in filename: diagnosis_hint = "Lung and Colon Cancer"
            elif "oral" in filename: diagnosis_hint = "Oral Cancer"
            elif "cervix" in filename or "cervical" in filename: diagnosis_hint = "Cervical Cancer"
            elif "kidney" in filename: diagnosis_hint = "Kidney Pathology"
            elif "breast" in filename: diagnosis_hint = "Breast Cancer"
            
            print(f"   -> Vision saw: {img_desc}")
            print(f"   -> Confirmed Diagnosis: {diagnosis_hint}")
            
            # We deliberately OVERWRITE the visual context with the diagnosis 
            # so the LLM doesn't get confused by 'brain' vs 'lung'
            visual_context = f" The patient has been diagnosed with {diagnosis_hint}."
            
        except Exception as e:
            print(f"Error: {e}")

    # --- 2. RETRIEVAL ---
    # Search for the DIAGNOSIS, not just the visual description
    search_query = f"{user_query} {diagnosis_hint}"
    xq = embedder.encode([search_query])
    
    D, I = index.search(xq, 5)
    
    retrieved_docs = []
    unique_docs = set()
    for i in I[0]:
        if i < len(knowledge_base):
            doc = knowledge_base[i]
            if doc not in unique_docs and len(doc) > 20:
                unique_docs.add(doc)
                retrieved_docs.append(doc)
    
    context_block = "\n".join(retrieved_docs[:3])
    print(f"\n Evidence:\n{context_block[:300]}...\n")

    # --- 3. GENERATION (Direct Mode) ---
    # We force the model to explain the diagnosis we found
    prompt = f"""
    Context: {context_block}
    
    Task: The patient has {diagnosis_hint}. Explain this condition based on the context provided.
    
    Answer:
    """
    
    input_ids = t5_tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True).input_ids.to(DEVICE)
    
    outputs = t5_model.generate(
        input_ids, 
        max_length=150,         
        num_beams=4,            
        early_stopping=True
    )
    
    final_answer = t5_tokenizer.decode(outputs[0], skip_special_tokens=True)
    return final_answer

# Re-Run the Test
if os.path.exists(cancer_path):
    test_img = "/kaggle/input/multi-cancer/Multi Cancer/Multi Cancer/Lung and Colon Cancer/lung_bnt/lung_bnt_1526.jpg"
    if os.path.exists(test_img):
        ans = medical_chatbot("What is the diagnosis?", image_path=test_img)
        print(f"\n Bot: {ans}")

 Analyzing Image...
   -> Vision saw: a medical image showing the blood of a patient
   -> Confirmed Diagnosis: Lung and Colon Cancer

 Evidence:
Visual Case Study: a radiology scan showing lung and colon cancer, with the tumor in histologyic cells on it. Diagnosis: Lung and Colon Cancer
Visual Case Study: a radiology scan showing lung and colon cancer, with the tumor in red on top left corner. Diagnosis: Lung and Colon Cancer
Visual Case Stu...


 Bot: Lung cancer


### Test

In [18]:
medical_chatbot("What are the symptoms of Glaucoma?")


 Evidence:
Medical Q&A: Question: What are the symptoms of Glaucoma ? Answer: Symptoms of Glaucoma  Glaucoma can develop in one or both eyes. The most common type of glaucoma, open-angle glaucoma, has no symptoms at first. It causes no pain, and vision seems normal. Without treatment, people with glaucoma will...



'Glaucoma can develop in one or both eyes. The most common type of glaucoma, open-angle glaucoma, has no symptoms at first. It causes no pain, and vision seems normal.'

In [19]:
import random
if os.path.exists(cancer_path):
    all_imgs = glob.glob(f"{cancer_path}/**/*.jpg", recursive=True)
    print("Inside first If")
    if all_imgs:
        test_img = random.choice(all_imgs)
        print(f"Input Image: {test_img}")
        ans2 = medical_chatbot("What does this scan show?", image_path=test_img)
        print(f"Bot: {ans2}")

Inside first If
Input Image: /kaggle/input/multi-cancer/Multi Cancer/Multi Cancer/ALL/all_benign/all_benign_1600.jpg
 Analyzing Image...
   -> Vision saw: a medical image showing the blood cells in the blood
   -> Confirmed Diagnosis: Acute Lymphoblastic Leukemia

 Evidence:
Radiology Report:  CT scan (pretreatment) showing lymph nodes.

Visual Case Study: a radiology scan showing cervical cancer, which is related to the disease of lymidus. Diagnosis: Cervical Cancer
Visual Case Study: a radiology scan showing cervical cancer, with the tumor and other cancers in it ' s ...

Bot: Acute lymphoblastic leukemia (ALL) is a type of blood cancer.
