In [4]:
# AI Dementia Classifier with MRI Processing and Gemini Integration
# Improved version with robust image handling and diagnostics

# Core imports
import os
from dotenv import load_dotenv

# Image processing
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from skimage import exposure, measure, filters


# TensorFlow model
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import img_to_array
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Reduce TensorFlow warnings

# LLM: Gemini via Langchain
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.prompts import PromptTemplate

# UI
import gradio as gr

In [5]:
# Load environment variables
load_dotenv()
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
GEMINI_MODEL = "gemini-2.0-flash"  # Using latest Gemini model

# Initialize Gemini model with low temperature for more predictable outputs
llm = ChatGoogleGenerativeAI(
    google_api_key=GEMINI_API_KEY,
    model=GEMINI_MODEL,
    temperature=0.2
)


In [3]:
# Load the CNN model
model_path = "saved_models/dementia_cnn_sequential_4_history_V2.keras"
cnn_model = load_model(model_path)

# Print model summary and input shape for debugging
print(f"CNN model loaded: {cnn_model.name}")
print(f"Model input shape: {cnn_model.input_shape}")
print(f"Model output shape: {cnn_model.output_shape}")

ValueError: File not found: filepath=saved_models/dementia_cnn_sequential_4_history_V2.keras. Please ensure the file is an accessible `.keras` zip file.

In [4]:
# Extract expected input dimensions from the model
model_height, model_width = cnn_model.input_shape[1:3]
model_channels = cnn_model.input_shape[3]
print(f"Model expects images of size {model_height}x{model_width} with {model_channels} channels")


Model expects images of size 240x240 with 3 channels


In [5]:
# Class label mapping - IMPORTANT: Must match your model's training labels
label_map = ['NonDemented', 'VeryMildDemented', 'MildDemented', 'ModerateDemented']

In [6]:
# Function to display the image with predicted class for validation
def display_processed_image(original_image, processed_array, prediction, confidence):
    """Display the original and processed images side by side with prediction"""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
    
    # Original image
    ax1.imshow(original_image)
    ax1.set_title("Original Image")
    ax1.axis('off')
    
    # Processed image - handle both RGB and grayscale
    if model_channels == 1:
        ax2.imshow(processed_array[0, :, :, 0], cmap='gray')
    else:
        ax2.imshow(processed_array[0])
    ax2.set_title(f"Processed: {prediction} ({confidence:.2f})")
    ax2.axis('off')
    
    plt.tight_layout()
    return fig

In [7]:
# Improved prediction function with robust image handling
def predict_from_mri(image, verbose=True):
    """Process MRI image and predict dementia class with confidence scores"""
    original_image = image.copy()  # Keep original for display
    
    # Step 1: Convert to the right color mode (RGB or grayscale)
    if model_channels == 1 and image.mode != "L":
        image = image.convert("L")  # Convert to grayscale
        if verbose:
            print("Converting to grayscale for single-channel model input")
    elif model_channels == 3 and image.mode != "RGB":
        image = image.convert("RGB")  # Convert to RGB
        if verbose:
            print("Converting to RGB for three-channel model input")
    
    # Step 2: Resize image to match model input dimensions
    image = image.resize((model_width, model_height))
    if verbose:
        print(f"Resized image to {model_width}x{model_height}")
    
    # Step 3: Convert to numpy array
    img_array = img_to_array(image)
    if verbose:
        print(f"Image array shape after conversion: {img_array.shape}")
    
    # Step 4: Normalize pixel values to [0, 1]
    img_array = img_array / 255.0
    
    # Step 5: Add batch dimension
    img_array = np.expand_dims(img_array, axis=0)
    if verbose:
        print(f"Final input shape for model: {img_array.shape}")
    
    # Step 6: Make prediction
    preds = cnn_model.predict(img_array, verbose=0)
    
    # Step 7: Get predicted class and confidence
    predicted_class = np.argmax(preds)
    confidence = float(preds[0][predicted_class])
    label = label_map[predicted_class]
    
    # Step 8: Print prediction details
    if verbose:
        print(f"Model prediction: {label} ({confidence:.2f} confidence)")
        print("Class probabilities:")
        for i, prob in enumerate(preds[0]):
            print(f"  {label_map[i]}: {prob:.4f}")
    
    # Optionally display the processed image
    if verbose:
        fig = display_processed_image(original_image, img_array, label, confidence)
        plt.show()
    
    return label, confidence, preds[0], img_array


In [8]:
# Function to validate model on test images
def validate_model(test_image_path):
    """Test model on a known image and display results"""
    print(f"Validating model with test image: {test_image_path}")
    test_image = Image.open(test_image_path)
    label, confidence, probs, processed_img = predict_from_mri(test_image)
    # Display image with prediction
    return label, confidence, probs

In [9]:
# Function to use Gemini to generate natural language explanation
def generate_gemini_summary(input_text):
    """Use Gemini to explain technical findings in patient-friendly language"""
    prompt = PromptTemplate.from_template("""
You are a dementia diagnostic assistant providing information to medical professionals.

Interpret the following patient data or classification and explain what it means in simple, compassionate terms suitable for a clinical audience:

"{input_text}"

                                          Focus on:
1. Explaining what this classification typically indicates about brain structure and function
2. Common cognitive or behavioral symptoms associated with this stage
3. Important considerations for the physician to discuss with the patient/family

Do not diagnose or suggest specific treatment plans. Just explain what the MRI findings and symptoms may indicate about dementia stages. Be precise but compassionate.
""")
    final_prompt = prompt.format(input_text=input_text)
    response = llm.invoke(final_prompt)
    return response.content.strip()

In [10]:


def analyze_ventricles(image):
    """Analyze ventricle size to help correct misclassifications"""
    # Convert to grayscale if needed
    if image.mode != "L":
        image = image.convert("L")
    
    # Resize for consistency
    image = image.resize((240, 240))
    
    # Convert to numpy array
    img_array = np.array(image)
    
    # Normalize
    img_array = img_array / 255.0
    
    # Apply CLAHE for better ventricle visibility
    img_eq = exposure.equalize_adapthist(img_array, clip_limit=0.03)
    
    # Create a central region mask where ventricles are typically located
    h, w = img_array.shape
    center_y, center_x = h // 2, w // 2
    Y, X = np.ogrid[:h, :w]
    
    # Create masks for ventricle region (typically central)
    ventricle_region = (X - center_x)**2 + (Y - center_y)**2 <= (min(h,w)//4)**2
    
    # Apply threshold to find ventricles (dark regions)
    thresh_val = filters.threshold_otsu(img_eq * ventricle_region)
    ventricles = (img_eq < thresh_val * 0.8) & ventricle_region
    
    # Measure ventricle properties
    ventricle_size = np.sum(ventricles) / np.sum(ventricle_region)
    
    # Visualize for debugging
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
    ax1.imshow(img_array, cmap='gray')
    ax1.set_title("Original MRI")
    ax1.axis('off')
    
    ax2.imshow(img_eq, cmap='gray')
    ax2.set_title("Enhanced MRI")
    ax2.axis('off')
    
    ax3.imshow(img_array, cmap='gray')
    ax3.imshow(ventricles, cmap='hot', alpha=0.5)
    ax3.set_title(f"Ventricles (Size: {ventricle_size:.4f})")
    ax3.axis('off')
    
    plt.tight_layout()
    
    return ventricle_size, fig

def correct_misclassification(original_prediction, ventricle_size, all_probs):
    """Apply rules to correct common misclassifications based on ventricle analysis"""
    label, confidence = original_prediction
    
    # Create a dictionary for easier referencing
    class_probs = {label_map[i]: prob for i, prob in enumerate(all_probs)}
    
    # Rule 1: Large ventricles typically indicate moderate or severe dementia
    if ventricle_size > 0.12:  # Threshold determined empirically
        if label == "VeryMildDemented" or label == "MildDemented":
            # Check if there's reasonable probability for moderate
            if class_probs["ModerateDemented"] > 0.1:  # Even a small probability
                return "ModerateDemented", class_probs["ModerateDemented"]
            # Upgrade at least to mild if predicted as very mild
            elif label == "VeryMildDemented" and class_probs["MildDemented"] > 0.1:
                return "MildDemented", class_probs["MildDemented"]
    
    # Rule 2: If moderate is the second highest prediction with close probability
    sorted_probs = sorted(class_probs.items(), key=lambda x: x[1], reverse=True)
    if sorted_probs[0][0] != "ModerateDemented" and sorted_probs[1][0] == "ModerateDemented":
        if sorted_probs[1][1] > 0.7 * sorted_probs[0][1] and ventricle_size > 0.1:
            return "ModerateDemented", sorted_probs[1][1]
    
    # No correction needed
    return label, confidence

# Modify your existing predict_from_mri function to include this correction
def enhanced_predict_from_mri(image):
    """Enhanced prediction with ventricle analysis to correct misclassifications"""
    # Step 1: Get original model prediction
    label, confidence, all_probs, processed_img = predict_from_mri(image, verbose=False)
    
    # Step 2: Analyze ventricles
    ventricle_size, ventricle_fig = analyze_ventricles(image)
    print(f"Ventricle analysis complete - size: {ventricle_size:.4f}")
    
    # Step 3: Apply correction rules
    corrected_label, corrected_confidence = correct_misclassification(
        (label, confidence), ventricle_size, all_probs)
    
    # Step 4: Log the correction if any
    if label != corrected_label:
        print(f"⚠️ Prediction corrected: {label} → {corrected_label}")
        print(f"Original confidence: {confidence:.4f}, Corrected confidence: {corrected_confidence:.4f}")
    else:
        print(f"✓ Original prediction maintained: {label} ({confidence:.4f})")
    
    # Return both original and corrected predictions
    return {
        "original": {"label": label, "confidence": confidence},
        "corrected": {"label": corrected_label, "confidence": corrected_confidence},
        "ventricle_size": ventricle_size,
        "probabilities": all_probs,
        "visualization": ventricle_fig
    }

# Update your multimodal_chatbot function to use the enhanced prediction
def improved_multimodal_chatbot(mri_image, symptom_text):
    """Improved chatbot with ventricle analysis to fix misclassifications"""
    responses = []
    
    mri_text = ""
    symptom_summary = ""
    
    if mri_image is not None:
        try:
            # Use enhanced prediction with ventricle analysis
            result = enhanced_predict_from_mri(mri_image)
            
            # Extract results
            original = result["original"]
            corrected = result["corrected"]
            ventricle_size = result["ventricle_size"]
            all_probs = result["probabilities"]
            
            # Format probabilities for all classes
            class_probs = [f"{label_map[i]}: {prob:.4f}" for i, prob in enumerate(all_probs)]
            prob_text = ", ".join(class_probs)
            
            # Add results to responses
            responses.append(f"**MRI Classification:** {corrected['label']} (Confidence: {corrected['confidence']:.4f})")
            
            if original['label'] != corrected['label']:
                responses.append(f"*Note: Original model prediction ({original['label']}) was corrected based on ventricle analysis (size: {ventricle_size:.4f})*")
            
            responses.append(f"**All Probabilities:** {prob_text}")
            
            # Store MRI analysis for combined reasoning
            mri_text = (f"MRI analysis indicates a classification of **{corrected['label']}** with "
                         f"confidence of {corrected['confidence']:.4f}. Ventricle analysis shows "
                         f"a ventricle size of {ventricle_size:.4f}, which is "
                         f"{'enlarged' if ventricle_size > 0.12 else 'within normal range'}.")
            
        except Exception as e:
            error_msg = f"Error processing MRI image: {str(e)}"
            print(error_msg)
            responses.append(f"**Error:** {error_msg}")
    
    # Rest of your function remains the same...
    # Process symptoms, generate Gemini summary, etc.
    
    return "\n\n".join(responses)

In [11]:
# Create a diagnostic tool to validate the model on test images
def diagnostic_tool():
    """Run a series of diagnostic tests on the model with sample images"""
    print("Running model diagnostics...")
    
    # Check if test directory exists
    test_dir = "test_images"
    if not os.path.exists(test_dir):
        print(f"Test directory {test_dir} not found. Please create it and add sample MRI images.")
        return
    
    # Get list of test images
    test_images = [f for f in os.listdir(test_dir) if f.endswith(('.jpg', '.jpeg', '.png'))]
    if not test_images:
        print("No test images found. Add some sample MRIs to the test_images directory.")
        return
    
    # Test each image
    results = []
    for img_file in test_images:
        img_path = os.path.join(test_dir, img_file)
        print(f"\nTesting image: {img_file}")
        try:
            image = Image.open(img_path)
            label, confidence, probs, _ = predict_from_mri(image)
            results.append({
                'image': img_file,
                'prediction': label,
                'confidence': confidence,
                'probabilities': probs
            })
        except Exception as e:
            print(f"Error processing {img_file}: {str(e)}")
    
    # Print summary of results
    print("\n=== Diagnostic Results ===")
    for r in results:
        print(f"Image: {r['image']}, Prediction: {r['prediction']}, Confidence: {r['confidence']:.4f}")
    
    return results


In [12]:
# Create Gradio Interface with improved layout and error handling
def create_gradio_interface():
    """Create and launch Gradio UI for the dementia classifier"""
    
    with gr.Blocks(title="Enhanced Dementia MRI & Symptom Assistant") as demo:
        gr.Markdown("# Dementia MRI & Symptom Assistant (Powered by Gemini)")
        gr.Markdown("Upload an MRI scan and/or enter symptoms to receive AI-assisted analysis.")
        
        with gr.Row():
            with gr.Column():
                # Input components
                image_input = gr.Image(
                    type="pil", 
                    label="Upload MRI Image",
                    elem_id="mri-upload"
                )
                
                symptom_input = gr.Textbox(
                    lines=4, 
                    placeholder="Describe symptoms, cognitive issues, and relevant patient history...", 
                    label="Clinical Notes & Symptoms"
                )
                
                submit_btn = gr.Button("Analyze", variant="primary")
                
            with gr.Column():
                # Output components
                output = gr.Markdown(label="Analysis Results")
        
        # Set up event handler
        submit_btn.click(
            fn=improved_multimodal_chatbot,
            inputs=[image_input, symptom_input],
            outputs=output
        )
        
        gr.Markdown("### Disclaimer: This tool is for research purposes only and not for clinical use.")
    
    # Launch the interface
    demo.launch()

In [13]:
# Main execution function
if __name__ == "__main__":
    # Option 1: Run diagnostic tests on sample images
    # diagnostic_tool()
    
    # Option 2: Launch the Gradio interface
    create_gradio_interface()

* Running on local URL:  http://127.0.0.1:7866

To create a public link, set `share=True` in `launch()`.


Ventricle analysis complete - size: 0.2853
✓ Original prediction maintained: MildDemented (1.0000)
Ventricle analysis complete - size: 0.3056
✓ Original prediction maintained: MildDemented (1.0000)
Ventricle analysis complete - size: 0.2493
✓ Original prediction maintained: MildDemented (1.0000)
