In [None]:
import os
from flask import Flask, request, render_template
from werkzeug.utils import secure_filename
import numpy as np
import tensorflow as tf
from keras.layers import TFSMLayer
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input, decode_predictions
from PIL import Image
import cv2

# Flask application object; the routes below register against it.
app = Flask(__name__)

# ImageNet-pretrained MobileNetV2, built once at import time and used by
# additional_content_filter() to spot everyday (non-medical) objects.
# NOTE(review): weights="imagenet" presumably requires the Keras weight file to
# be cached locally or downloadable on first import — confirm for offline use.
imgnet = MobileNetV2(weights="imagenet")

def is_valid_brain_mri(filepath):
    """Heuristically decide whether an image file looks like a brain MRI slice.

    This is a chain of cheap intensity and shape checks, not a trained
    classifier: it filters out images that are obviously not grayscale scans.
    Checks run in this fixed order and the first failing one decides the
    returned message:

    1. colour     - per-channel variances must be similar (grayscale-like);
    2. background - at most 85% of pixels may be very dark (< 30);
    3. saturation - at most 40% of pixels may be very bright (>= 220);
    4. tissue     - at least one pixel must be brighter than 20;
    5. contrast   - std of the pixels brighter than 20 must be >= 15;
    6. coverage   - pixels brighter than 20 must cover >= 5% of the image;
    7. shape      - long side at most 2x the short side;
    8. texture    - std of the pixels brighter than 30 must be >= 8.

    Args:
        filepath: Path to an image file that PIL can open.

    Returns:
        ``(is_valid, message)`` where ``message`` is a short user-facing
        reason. Any exception while reading or processing the image yields
        ``(False, "Error processing image")``.
    """
    try:
        # The context manager releases the file handle; convert() returns a
        # separate in-memory image, so `arr` stays valid after the block.
        with Image.open(filepath) as img:
            arr = np.array(img.convert("RGB"))

        # Grayscale content stored as RGB has near-identical variance in each
        # channel; a large spread between channel variances indicates colour.
        channel_variance = np.var(arr, axis=(0, 1))
        print("Channel variance:", channel_variance)
        if np.std(channel_variance) > 800:
            print("Failed: Too colorful for MRI")
            return False, "Image appears to be colorful - MRI scans are typically grayscale"

        gray = cv2.cvtColor(arr, cv2.COLOR_RGB2GRAY)
        height, width = gray.shape
        total_pixels = height * width

        # Counting directly on the uint8 image gives the same numbers as
        # summing bins [0, 30) and [220, 256) of a 256-bin histogram.
        dark_pixels = np.count_nonzero(gray < 30)      # near-black background
        bright_pixels = np.count_nonzero(gray >= 220)  # near-saturated areas

        if dark_pixels / total_pixels > 0.85:
            print("Failed: Too many dark pixels")
            return False, "Image appears to be mostly background"

        if bright_pixels / total_pixels > 0.4:
            print("Failed: Too many bright pixels")
            return False, "Image has excessive bright areas"

        # Pixels above this level are treated as (possible) brain tissue.
        brain_region = gray[gray > 20]
        if brain_region.size == 0:
            print("Failed: No brain tissue detected")
            return False, "No visible brain tissue detected in the image"

        brain_std = np.std(brain_region)
        if brain_std < 15:
            print(f"Failed: Brain tissue too uniform {brain_std}")
            return False, "Brain tissue appears too uniform - may not be a valid MRI"

        brain_ratio = brain_region.size / total_pixels
        if brain_ratio < 0.05:
            print(f"Failed: Too little brain content {brain_ratio}")
            return False, "Insufficient brain tissue visible in the image"

        aspect_ratio = max(width, height) / min(width, height)
        if aspect_ratio > 2.0:
            print(f"Failed: Aspect ratio {aspect_ratio}")
            return False, "Image aspect ratio doesn't match typical brain MRI format"

        # A slightly higher cut-off than above, to measure texture inside
        # the structures rather than at their dim edges.
        non_background = gray[gray > 30]
        if non_background.size == 0:
            print("Failed: No non-background pixels found")
            return False, "No visible brain structures detected"

        texture_variation = np.std(non_background)
        if texture_variation < 8:
            print(f"Failed: Low texture variation {texture_variation}")
            return False, "Image lacks expected brain tissue complexity"

        print("Passed all MRI validation tests")
        return True, "Valid brain MRI detected"

    except Exception as e:
        # Boundary handler: any unreadable/odd image is reported as invalid.
        print(f"Error in validation: {e}")
        return False, "Error processing image"

def additional_content_filter(filepath):
    """Reject images that the ImageNet MobileNetV2 recognises as everyday objects.

    The image is resized to 224x224 and classified by the module-level
    ``imgnet`` model. If any of the top-5 ImageNet labels scores above 0.1 and
    contains one of the listed terms as a case-insensitive substring, the
    image is rejected.

    Args:
        filepath: Path to the uploaded image.

    Returns:
        ``(passed, message)``. The filter fails open: if anything raises, the
        error is printed and ``(True, "Content filter skipped due to error")``
        is returned, so this check alone never blocks an upload on failure.
    """
    # Substrings matched against ImageNet label names (lower-cased).
    banned_terms = (
        "dog", "cat", "car", "flower", "butterfly", "bird", "person", "face",
        "building", "food", "fruit", "animal", "vehicle", "furniture", "tree",
        "plant", "sky", "water", "landscape", "portrait", "selfie", "phone",
        "computer", "book", "toy", "clothing", "sports", "music", "art",
    )

    try:
        pil_image = tf.keras.preprocessing.image.load_img(filepath, target_size=(224, 224))
        pixels = tf.keras.preprocessing.image.img_to_array(pil_image)
        batch = preprocess_input(np.expand_dims(pixels, axis=0))

        top_five = decode_predictions(imgnet.predict(batch), top=5)[0]

        # Each entry is (wordnet_id, label_name, score).
        for _, label_name, score in top_five:
            lowered = label_name.lower()
            if score > 0.1 and any(term in lowered for term in banned_terms):
                return False, f"Image appears to contain: {label_name}"

        return True, "Passed content filter"

    except Exception as e:
        print(f"Error in content filtering: {e}")
        return True, "Content filter skipped due to error"


@app.route('/')
def upload_form():
    """Serve the empty upload page (no prediction message, no preview)."""
    page = 'index.html'
    return render_template(page)


# Alzheimer's classifier (TF SavedModel in the "model" directory). Loading a
# SavedModel is expensive, so it is done once per process on first use and
# then reused by every request.
_alzheimer_model = None

# Output index -> label, in the model's output order.
_DEMENTIA_CLASSES = (
    "Mild Dementia",
    "Moderate Dementia",
    "No Dementia, Patient is Safe",
    "Very Mild Dementia",
)


def _get_alzheimer_model():
    """Return the cached SavedModel wrapper, loading it on first call."""
    global _alzheimer_model
    if _alzheimer_model is None:
        # NOTE(review): two concurrent first requests may both load the
        # model; the last assignment wins, which is wasteful but harmless.
        _alzheimer_model = TFSMLayer("model", call_endpoint="serving_default")
    return _alzheimer_model


def _render_result(message, image_file):
    """Render the index page with a status message and the uploaded-image preview path."""
    return render_template('index.html', type=message, image_file=image_file)


@app.route('/', methods=['POST'])
def upload_image():
    """Handle an MRI upload: save it, validate it, then classify it.

    Expects a multipart form with a ``file`` field. The file is saved under
    ``<static folder>/uploads`` for preview, checked with
    ``is_valid_brain_mri`` and ``additional_content_filter``, and finally run
    through the Alzheimer's model. Every outcome renders ``index.html`` with a
    ``type`` message, plus ``image_file`` once a file has been saved.
    """
    file = request.files.get('file')
    # Browsers submit an empty filename when nothing was chosen, and
    # secure_filename() may itself return "" (e.g. for "../"); either way
    # file.save() would otherwise target the upload directory itself.
    filename = secure_filename(file.filename or "") if file is not None else ""
    if not filename:
        return render_template(
            'index.html',
            type="❌ No valid file uploaded. Please choose a brain MRI scan to upload.",
        )

    # Save inside the folder Flask serves at /static, so the preview path
    # "uploads/<name>" (presumably resolved by the template against the static
    # folder) refers to this exact file regardless of the working directory.
    upload_dir = os.path.join(app.static_folder, "uploads")
    os.makedirs(upload_dir, exist_ok=True)
    filepath = os.path.join(upload_dir, filename)
    # NOTE(review): uploads sharing a name overwrite each other, so concurrent
    # requests can race on this path; consider a unique per-upload prefix.
    file.save(filepath)
    image_file = "uploads/" + filename

    is_valid_mri, mri_message = is_valid_brain_mri(filepath)
    if not is_valid_mri:
        return _render_result(
            f"❌ Invalid input: {mri_message}. Please upload a valid brain MRI scan.",
            image_file,
        )

    passed_filter, filter_message = additional_content_filter(filepath)
    if not passed_filter:
        return _render_result(
            f"❌ Invalid input: {filter_message}. Please upload a brain MRI scan only.",
            image_file,
        )

    print("All validation tests passed - proceeding with Alzheimer's analysis")

    try:
        img = tf.keras.preprocessing.image.load_img(filepath, target_size=(176, 176))
        batch = np.expand_dims(tf.keras.preprocessing.image.img_to_array(img), axis=0)  # (1,176,176,3)

        outputs = _get_alzheimer_model()(batch)
        # serving_default returns a dict of named tensors; the first entry is
        # taken as the per-class scores (single-output model assumed).
        scores = list(outputs.values())[0].numpy()[0]

        # NOTE(review): reading the max score as a confidence assumes the
        # exported model ends in softmax — confirm against the model.
        confidence = np.max(scores)
        class_index = int(np.argmax(scores))

        if confidence < 0.7:
            prediction = f"⚠️ Uncertain prediction (confidence {confidence:.2f}). Please upload a clearer MRI scan."
        else:
            prediction = f"✅ {_DEMENTIA_CLASSES[class_index]} (Confidence: {confidence:.2f})"
    except Exception as e:
        # Covers model loading, preprocessing, inference, and an unexpected
        # number of output classes (IndexError).
        print(f"Error in model prediction: {e}")
        return _render_result(
            "❌ Error processing the MRI scan. Please try uploading a different image.",
            image_file,
        )

    print("Prediction:", prediction)
    return _render_result(prediction, image_file)


if __name__ == "__main__":
    # Ensure the upload directory exists before the server starts.
    os.makedirs("static/uploads", exist_ok=True)
    # use_reloader=False keeps the server in this process; the reloader
    # restarts the interpreter, which presumably breaks inside a notebook.
    # NOTE(review): debug=True enables the Werkzeug debugger — never expose
    # this server beyond localhost.
    server_options = {"debug": True, "port": 5001, "use_reloader": False}
    app.run(**server_options)

  from scipy.sparse import issparse  # pylint: disable=g-import-not-at-top


 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on http://127.0.0.1:5000
Press CTRL+C to quit
127.0.0.1 - - [28/Aug/2025 17:02:19] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [28/Aug/2025 17:02:19] "GET /favicon.ico HTTP/1.1" 404 -
127.0.0.1 - - [28/Aug/2025 17:02:19] "GET /static/pexels_shvetsa.jpg HTTP/1.1" 200 -
