### Chatbot to upload dementia brain scan image located in the /data_split_trainTest/test folder.
You can select any images under MildDemented, ModerateDemented, NonDemented, or VeryMildDemented folder
Note: We chose 60% for the model to use the image and 40% for the text diagnostic.
So keep in mind this is only an exploratory project and we won't be able to hit 100% accuracy

At the Gradio interface you can also add patient diagnostic.

Note: Image is required. Diagnostic is optional but can be added later. End result, Gemini LLM will give you a diagnostic. Remember this is only an exploratory project, not to replace a true doctor's diagnosis


In [1]:
!pip install tensorflow gradio transformers



In [2]:
# Load necessary libraries
import re
import os
#import google.generativeai as genai
import numpy as np
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
import gradio as gr
from transformers import pipeline

from dotenv import load_dotenv
load_dotenv()
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.prompts import PromptTemplate
from dotenv import load_dotenv

In [3]:
# Load trained Keras model from correct path
# Remember this Keras file is very big, go to https://www.dropbox.com/scl/fo/mmtv94e8t4u9x7vgvqwvw/AEigp2bJ9nK4hC_juE2aVkY?rlkey=5s76fv2w7303kxymyccdinx0h&dl=0
# Request ACCESS to download the file from Matt Le (Matt@luxiumcreative.com) from thesaved_models folder, and upload the keras file to your own local

model = load_model("saved_models/dementia_cnn_sequential_model_V2.keras")

# Define label map
label_map = {
    0: "Mild Demented",
    1: "Moderate Demented",
    2: "Non Demented",
    3: "Very Mild Demented"
}

2025-05-18 20:31:33.225910: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M4 Pro
2025-05-18 20:31:33.225935: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 48.00 GB
2025-05-18 20:31:33.225941: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 18.00 GB
I0000 00:00:1747621893.225959  847872 pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
I0000 00:00:1747621893.225984  847872 pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


In [4]:
# Load a pre-trained text classifier (multi-class classification using BERT)
sentiment_model = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")

Device set to use mps:0


In [5]:
# Image preprocessing
def preprocess_mri_image(img_path_or_pil):
    model_input_shape = model.input_shape[1:3]
    img = image.load_img(img_path_or_pil, target_size=model_input_shape)
    img_array = image.img_to_array(img) / 255.0
    return np.expand_dims(img_array, axis=0)

In [6]:
# MRI prediction
def predict_dementia_stage(img):
    processed = preprocess_mri_image(img)
    preds = model.predict(processed)[0]
    return preds

In [7]:
# Rule-based text classification
def rule_based_classifier(text):
    text = text.lower()
    symptoms = {
        'moderate': [
            r'cannot recognize', r'difficulty dressing', r'cannot function alone', r'severe memory', r'wandering', r'safety concern'
        ],
        'mild': [
            r'forgetting recent events', r'losing things', r'confusion with time', r'poor judgment', r'difficulty with planning'
        ],
        'very_mild': [
            r'slight forgetfulness', r'misplacing keys', r'word finding difficulty', r'mild memory problems'
        ],
        'none': [
            r'no memory issue', r'normal memory', r'sharp', r'alert', r'good cognition'
        ]
    }
    scores = {k: 0 for k in symptoms}
    for stage, patterns in symptoms.items():
        for p in patterns:
            if re.search(p, text):
                scores[stage] += 1
    total = sum(scores.values())
    if total == 0:
        return [0.25, 0.25, 0.25, 0.25]
    return [
        scores['mild'] / total,
        scores['moderate'] / total,
        scores['none'] / total,
        scores['very_mild'] / total
    ]

In [8]:
# BERT sentiment classifier

def sentiment_classifier(text):
    result = sentiment_model(text, top_k=1)[0]
    label = result['label'].lower()
    if 'negative' in label:
        return [0.1, 0.6, 0.1, 0.2]
    elif 'positive' in label:
        return [0.2, 0.1, 0.6, 0.1]
    else:
        return [0.25, 0.25, 0.25, 0.25]

In [9]:
# Combined text classifier

def classify_text_to_dementia(text):
    rule_probs = np.array(rule_based_classifier(text))
    sentiment_probs = np.array(sentiment_classifier(text))
    return (0.5 * rule_probs + 0.5 * sentiment_probs).tolist()

In [None]:
# Combine MRI and text prediction

def combine_predictions(mri_probs, text_probs, mri_weight=0.6): # this is the weight for MRI
    return (mri_weight * np.array(mri_probs) + (1 - mri_weight) * np.array(text_probs)).tolist()

In [11]:
# Store the API Key in a variable
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")

# Set up the Google Generative AI model for our LLMs
GEMINI_MODEL = "gemini-2.0-flash"

llm_gemini = ChatGoogleGenerativeAI(
    model=GEMINI_MODEL,
    google_api_key=GEMINI_API_KEY,
    temperature=0.2
)



In [12]:

# Gemini summary for dementia stage explanation
# ✅ Define the Gemini stage explanation function
def generate_dementia_stage_summary(diagnosis):
    try:
        prompt = f"What does it mean if someone is diagnosed with {diagnosis}? Explain the dementia stage."
        response = llm_gemini.invoke(prompt)
        return response.content if hasattr(response, "content") else str(response)
    except Exception as e:
        return f"❌ Gemini stage summary error: {str(e)}"




In [13]:
# ✅ Define the summary generator function
def generate_gemini_summary(input_text):
    try:
        prompt = f"Explain in simple clinical terms: {input_text}"
        response = llm_gemini.invoke(prompt)
        return response.content if hasattr(response, "content") else str(response)
    except Exception as e:
        return f"❌ Gemini error: {str(e)}"

In [14]:
# Chatbot response

def dementia_chatbot(image, text_input):
    if image is None:
        return "Please upload an MRI image."

    mri_probs = predict_dementia_stage(image)
    if text_input:
        text_probs = classify_text_to_dementia(text_input)
        final_probs = combine_predictions(mri_probs, text_probs)
    else:
        final_probs = mri_probs

    idx = int(np.argmax(final_probs))
    confidence = final_probs[idx]
    diagnosis = label_map[idx]
    prob_text = "\n".join([f"{label_map[i]}: {p:.2f}" for i, p in enumerate(final_probs)])
    

    response = f"🧠 **Predicted Dementia Stage**: {diagnosis} (Confidence: {confidence:.2f})\n"
    response += f"📊 **Combined Probabilities**:\n{prob_text}\n"
    if text_input:
        response += "💬 **Entered Text Analysis**: Combined rule-based + BERT sentiment (40% weight).\n"
    # ✅ Add Gemini MRI interpretation summary
    try:
        explanation = generate_gemini_summary(f"MRI indicates: {diagnosis} with {confidence:.2f} confidence")
        response += f"\n🧬 **Gemini MRI Interpretation:**\n{explanation}\n"
    except Exception as e:
        response += f"\n❌ Gemini MRI Interpretation Error: {e}\n"

    # ✅ Add Gemini explanation of the dementia stage
    try:
        stage_summary = generate_dementia_stage_summary(diagnosis)
        response += f"\n📘 **Gemini Stage Explanation:**\n{stage_summary}\n"
    except Exception as e:
        response += f"\n❌ Gemini Stage Explanation Error: {e}\n"

    return response

In [None]:
# Gradio app interface
demo = gr.Interface(
    fn=dementia_chatbot,
    inputs=[
        gr.Image(type="filepath", label="Upload Brain MRI Image"),
        gr.Textbox(lines=2, placeholder="Enter symptoms or previous diagnosis (optional)", label="Symptom Text")
    ],
    outputs="text",
    title="🧠 Dementia MRI Chatbot (MRI + Text Analysis)",
    description="Upload a brain MRI and optionally describe symptoms. MRI contributes 60%, and combined rule-based + sentiment text classification contributes 40% to the final result."
)

demo.launch(share=True)


* Running on local URL:  http://127.0.0.1:7860
* Running on public URL: https://5526421984f64cd7ea.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)




[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 62ms/step


2025-05-18 20:32:48.116814: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23