# In [None]:
import tensorflow as tf
from tensorflow import keras
import cv2
import numpy as np
from PIL import Image, ImageOps
import gradio as gr
from tensorflow.keras import backend as K
from keras.saving import register_keras_serializable

# --- Constants ---
# Target height/width (in pixels) used when resizing images for display and
# for the segmentation model.
im_height, im_width = 256, 256
# Smoothing term added to both numerator and denominator of the Dice / IoU
# ratios below, which keeps them defined when the sums are zero.
smooth = 100
# Class names, indexed by the argmax of the classifier's prediction.
# NOTE(review): order presumably mirrors the label encoding used at training
# time — confirm against the training code.
labels = [
    'glioma_tumor',
    'no_tumor',
    'meningioma_tumor',
    'pituitary_tumor',
]

# --- Custom Losses and Metrics ---
@register_keras_serializable()
def dice_loss(y_true, y_pred):
    """Return the smoothed Dice loss, i.e. ``1 - dice``.

    Both tensors are flattened first, so the ratio is computed over every
    element passed in rather than per sample.
    """
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    numerator = 2. * overlap + smooth
    denominator = K.sum(truth) + K.sum(pred) + smooth
    return 1 - (numerator / denominator)

def dice_coef(y_true, y_pred):
    """Return the smoothed Dice coefficient of the two flattened tensors."""
    flat_true = K.flatten(y_true)
    flat_pred = K.flatten(y_pred)
    overlap = K.sum(flat_true * flat_pred)
    total = K.sum(flat_true) + K.sum(flat_pred)
    return (2 * overlap + smooth) / (total + smooth)

def dice_coef_loss(y_true, y_pred):
    """Return the negated Dice coefficient, so a lower value means more overlap."""
    score = dice_coef(y_true, y_pred)
    return -score

def iou(y_true, y_pred):
    """Return the smoothed intersection-over-union (Jaccard index).

    The union is obtained as ``sum(y_true + y_pred) - intersection``; the
    module-level ``smooth`` is added to both halves of the ratio.
    """
    overlap = K.sum(y_true * y_pred)
    union = K.sum(y_true + y_pred) - overlap
    return (overlap + smooth) / (union + smooth)

def jac_distance(y_true, y_pred):
    """Return the negated smoothed IoU of ``y_true`` and ``y_pred``.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Predicted tensor.

    Returns:
        ``-iou(y_true, y_pred)``.
    """
    # The inputs are passed through untouched: ``iou`` already sums over
    # every element, so flattening them here first would be dead work.
    return -iou(y_true, y_pred)

# --- Model Loading ---
def load_classification_model(model_path='/content/BTCMODEL.h5'):
    """Load the tumor-classification Keras model from disk.

    Args:
        model_path: Location of the saved model file. Defaults to the path
            that was previously hard-coded, so existing zero-argument calls
            behave exactly as before.

    Returns:
        The model object returned by ``keras.models.load_model``.
    """
    return keras.models.load_model(model_path)

def load_segmentation_model(model_path='/content/BTSMODEL.keras'):
    """Load the tumor-segmentation Keras model from disk.

    The custom loss and metric functions defined in this file are handed to
    ``keras.models.load_model`` through ``custom_objects`` under their own
    names.

    Args:
        model_path: Location of the saved model file. Defaults to the path
            that was previously hard-coded, so existing zero-argument calls
            behave exactly as before.

    Returns:
        The model object returned by ``keras.models.load_model``.
    """
    custom_objects = {
        'dice_loss': dice_loss,
        'dice_coef': dice_coef,
        'dice_coef_loss': dice_coef_loss,
        'iou': iou,
        'jac_distance': jac_distance,
    }
    return keras.models.load_model(model_path, custom_objects=custom_objects)

# Load models once
# Both models are loaded when this code runs and kept as module-level
# globals; process_image reads them instead of reloading on each call.
model = load_classification_model()
model_seg = load_segmentation_model()

# --- Prediction Functions ---
def upload_predict(image, model):
    """Classify one image and return ``(class_index, confidence)``.

    The image is converted to an array, resized to 150x150 with bicubic
    interpolation and given a leading batch axis before being passed to
    ``model.predict``. The class index is the argmax of the prediction and
    the confidence is its maximum value.
    """
    pixels = np.asarray(image)
    resized = cv2.resize(pixels, dsize=(150, 150), interpolation=cv2.INTER_CUBIC)
    batch = resized[np.newaxis, ...]
    scores = model.predict(batch)
    return np.argmax(scores), np.max(scores)

def add_grid(image, grid_size=50):
    """Draw a gray grid onto ``image`` in place and return that same array.

    Lines are one pixel thick and spaced ``grid_size`` pixels apart, starting
    from the top and left edges.
    """
    height, width = image.shape[:2]
    gray = (128, 128, 128)
    # Horizontal lines.
    for y in range(0, height, grid_size):
        cv2.line(image, (0, y), (width, y), gray, 1)
    # Vertical lines.
    for x in range(0, width, grid_size):
        cv2.line(image, (x, 0), (x, height), gray, 1)
    return image

def _largest_component(mask):
    """Return a 0/1 uint8 mask keeping only the largest connected region of ``mask``."""
    num_labels, label_map = cv2.connectedComponents(mask)
    if num_labels <= 1:
        # Only the background label exists, so there is nothing to isolate.
        return mask
    # Pixel count per label; label 0 is skipped, as in the usual convention
    # that it denotes background. argmax returns the first maximum, so ties
    # go to the lowest label.
    areas = np.bincount(label_map.ravel(), minlength=num_labels)[1:]
    biggest = int(np.argmax(areas)) + 1
    return (label_map == biggest).astype(np.uint8)


def process_image(image):
    """Classify an uploaded image and, if a tumor is predicted, segment it.

    Args:
        image: The uploaded image (a PIL image when called from the Gradio
            interface below), or ``None`` when nothing was uploaded.

    Returns:
        A ``(classification_text, metrics_text, gallery_list)`` tuple, where
        ``gallery_list`` is a list of ``(image_array, caption)`` pairs.
    """
    if image is None:
        # Submitting the form without an upload passes None; report it
        # instead of crashing inside the classifier.
        return "No image provided.", "Please upload an image to analyze.", []

    pred_class, confidence = upload_predict(image, model)
    classification_text = f"The image is classified as: {labels[pred_class]} with confidence {confidence:.2f}"

    # cv2.resize expects dsize as (width, height).
    img_seg = cv2.resize(np.asarray(image), (im_width, im_height))
    # add_grid draws in place, so it gets its own copy and img_seg stays
    # clean for the segmentation model and the overlay.
    original_display = add_grid(img_seg.copy())

    if labels[pred_class] == 'no_tumor':
        return (
            classification_text,
            "No tumor detected, metrics not applicable.",
            [(original_display, "Original Image - No Tumor Detected")],
        )

    # Shared result for every "segmentation produced nothing usable" exit.
    nothing_segmented = (
        classification_text,
        "No tumor segmented, metrics not applicable.",
        [(original_display, "Original Image - No Tumor Segmented")],
    )

    input_seg = np.expand_dims(img_seg / 255.0, axis=0)
    pred_seg = model_seg.predict(input_seg)
    pred_mask = (np.squeeze(pred_seg) > 0.5).astype(np.uint8)

    largest_component = _largest_component(pred_mask)
    if not largest_component.any():
        return nothing_segmented

    contours, _ = cv2.findContours(largest_component, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if len(contours) == 0:
        return nothing_segmented

    # --- Shape metrics of the biggest outer contour ---
    contour = max(contours, key=cv2.contourArea)
    area = cv2.contourArea(contour)
    perimeter = cv2.arcLength(contour, True)
    circularity = (4 * np.pi * area) / (perimeter ** 2) if perimeter > 0 else 0

    major_axis, minor_axis, eccentricity = 0, 0, 0
    if len(contour) >= 5:  # cv2.fitEllipse needs at least five points
        _, axes, _ = cv2.fitEllipse(contour)
        major_axis, minor_axis = max(axes), min(axes)
        if major_axis > 0:
            eccentricity = np.sqrt(1 - (minor_axis / major_axis) ** 2)

    hull_area = cv2.contourArea(cv2.convexHull(contour))
    solidity = area / hull_area if hull_area > 0 else 0

    # Centroid from the image moments of the mask.
    M = cv2.moments(largest_component)
    cx = int(M["m10"] / M["m00"]) if M["m00"] != 0 else 0
    cy = int(M["m01"] / M["m00"]) if M["m00"] != 0 else 0

    # The whitespace inside this literal is part of the displayed text and is
    # deliberately kept exactly as it was.
    metrics_text = f"""
                Tumor Metrics:
                - Area: {area:.0f} pixels
                - Perimeter: {perimeter:.2f} pixels
                - Circularity: {circularity:.2f}
                - Major Axis Length: {major_axis:.2f} pixels
                - Minor Axis Length: {minor_axis:.2f} pixels
                - Eccentricity: {eccentricity:.2f}
                - Solidity: {solidity:.2f}
                - Centroid: ({cx}, {cy}) pixels
                """

    # --- Visualization: blend a colored mask over the resized image ---
    alpha = 0.5
    color = np.array([255, 255, 0], dtype=np.uint8)  # Yellow for tumor
    overlay = np.zeros_like(img_seg, dtype=np.uint8)
    overlay[largest_component == 1] = color
    blended_img = cv2.addWeighted(img_seg, 1 - alpha, overlay, alpha, 0)
    blended_img = add_grid(blended_img)

    # Mark the centroid on top of the overlay.
    cv2.drawMarker(blended_img, (cx, cy), (255, 0, 0), markerType=cv2.MARKER_CROSS, markerSize=10, thickness=2)

    gallery_list = [
        (original_display, "Original Image"),
        (blended_img, "Segmentation Overlay"),
    ]
    return classification_text, metrics_text, gallery_list

# --- Gradio UI ---
# Components are built up front so the Interface call itself stays short.
_image_input = gr.Image(type="pil", label="Upload the image to be classified")
_result_outputs = [
    gr.Textbox(label="Classification Result"),
    gr.Textbox(label="Tumor Metrics"),
    gr.Gallery(label="Visualization"),
]
_description = (
    "Upload a brain MRI image to classify the tumor type and view detailed tumor metrics "
    "along with visualizations including the original image and segmentation overlay with centroid marked."
)

# One input image in, three outputs back: two text boxes and a gallery,
# matching the 3-tuple returned by process_image.
iface = gr.Interface(
    fn=process_image,
    inputs=_image_input,
    outputs=_result_outputs,
    title="Brain Tumor Classification & Segmentation",
    description=_description,
)

iface.launch()