In [3]:
import gradio as gr
import tensorflow as tf
import numpy as np
import cv2
import os
from tensorflow.keras.models import load_model
import matplotlib.pyplot as plt

# Load the trained multi-task model once, at import time; every request
# handled by predict_and_explain reuses this single instance.
# The file location defaults to the bundled path but can be overridden with
# the MODEL_PATH environment variable (e.g. when deploying elsewhere).
# compile=False: this app only runs inference, so the optimizer/loss/metrics
# do not need to be restored (this also avoids load failures when training
# used custom loss or metric objects).
model = load_model(
    os.environ.get("MODEL_PATH", "models/multi_task_chest_xray_model.h5"),
    compile=False,
)

# Class labels, indexed by the argmax of the corresponding model head.
# NOTE(review): order must match the output units of 'disease_output' /
# 'severity_output' — presumably the label encoding used at training time; verify.
disease_names = ['COVID19', 'PNEUMONIA', 'NORMAL']
severity_names = ['Mild', 'Moderate', 'Severe']

# Grad-CAM utility functions
def get_last_conv_layer(model):
    """Return the name of the last ``Conv2D`` layer in ``model``.

    Only ``model.layers`` (the top-level layers) is searched, walking from
    the output end backwards, so the match is the Conv2D closest to the
    outputs.

    Raises:
        ValueError: if none of the top-level layers is a ``Conv2D``.
    """
    newest_conv = next(
        (
            candidate.name
            for candidate in reversed(model.layers)
            if isinstance(candidate, tf.keras.layers.Conv2D)
        ),
        None,
    )
    if newest_conv is None:
        raise ValueError("No convolutional layer found in the model.")
    return newest_conv

def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None, output_head='disease_output'):
    """Compute a Grad-CAM heatmap for one output head of ``model``.

    Args:
        img_array: Batched input accepted by ``model``. Only the first sample
            (index 0) contributes to the returned map.
        model: Keras model containing layers named ``last_conv_layer_name``
            and ``output_head``.
        last_conv_layer_name: Convolutional layer whose activations are
            weighted by the pooled gradients.
        pred_index: Index of the class score (in ``output_head``'s output) to
            explain. Defaults to the argmax of the first sample's prediction.
        output_head: Name of the layer producing the class scores.

    Returns:
        ``numpy`` array of the conv layer's spatial size with values in
        ``[0, 1]``. When no location has positive evidence the map is all
        zeros (previously this produced NaN from a 0/0 division).

    Raises:
        ValueError: if ``output_head`` does not depend on the conv layer
            (the gradient is ``None``).
    """
    # Pass model.inputs directly; wrapping it in another list creates a nested
    # input structure that newer Keras versions reject or warn about.
    grad_model = tf.keras.models.Model(
        model.inputs,
        [model.get_layer(last_conv_layer_name).output, model.get_layer(output_head).output],
    )
    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img_array)
        if pred_index is None:
            pred_index = tf.argmax(predictions[0])
        class_channel = predictions[:, pred_index]
    grads = tape.gradient(class_channel, conv_outputs)
    if grads is None:
        raise ValueError(
            f"Output '{output_head}' is not differentiable w.r.t. layer '{last_conv_layer_name}'."
        )
    # One importance weight per channel: mean gradient over batch and spatial
    # axes (assumes channels-last 4-D conv output, as reduce over (0, 1, 2) implies).
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    conv_outputs = conv_outputs[0]
    heatmap = conv_outputs @ pooled_grads[..., tf.newaxis]
    heatmap = tf.squeeze(heatmap)
    # ReLU first, then normalize by the max of the clipped map; divide_no_nan
    # returns 0 instead of NaN when that max is 0.
    heatmap = tf.maximum(heatmap, 0)
    heatmap = tf.math.divide_no_nan(heatmap, tf.math.reduce_max(heatmap))
    return heatmap.numpy()

def apply_heatmap(img_path, heatmap, alpha=0.4, size=(224, 224)):
    """Blend a heatmap over the image stored at ``img_path``.

    Args:
        img_path: Path to an image file readable by ``cv2.imread``.
        heatmap: 2-D array with values expected in ``[0, 1]`` (e.g. from
            ``make_gradcam_heatmap``). NaN/inf are treated as 0/1 and values
            are clamped to ``[0, 1]``.
        alpha: Weight of the colored heatmap; the image gets ``1 - alpha``.
        size: ``(width, height)`` the image is resized to before blending.
            The default matches the 224x224 model input.

    Returns:
        ``uint8`` RGB array of shape ``(height, width, 3)``.

    Raises:
        ValueError: if OpenCV cannot read ``img_path``.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise ValueError(f"Could not read image: {img_path}")
    img = cv2.resize(img, size)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # Sanitize before resizing so a NaN cannot spread through interpolation,
    # and clamp so the uint8 cast below cannot wrap around.
    heatmap = np.clip(
        np.nan_to_num(np.asarray(heatmap, dtype=np.float32), nan=0.0, posinf=1.0, neginf=0.0),
        0.0,
        1.0,
    )
    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
    heatmap = np.uint8(255 * heatmap)
    heatmap_color = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    # applyColorMap returns BGR; convert it to match the RGB base image,
    # otherwise high-activation regions render blue instead of red.
    heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)
    superimposed_img = cv2.addWeighted(img, 1 - alpha, heatmap_color, alpha, 0)
    return superimposed_img

# Prediction + Grad-CAM function
def predict_and_explain(img_path):
    """Classify a chest X-ray and build Grad-CAM overlays for both heads.

    Args:
        img_path: Path of the uploaded image (the Gradio input uses
            ``type="filepath"``); ``None`` when nothing was uploaded.

    Returns:
        Tuple ``(disease_result, severity_result, heatmap_img_disease,
        heatmap_img_severity)``: the disease label from ``disease_names``,
        the severity label from ``severity_names`` (``"N/A"`` when the
        disease is ``'NORMAL'``), and two RGB overlay arrays.

    Raises:
        gr.Error: if no image was provided (shown to the user by Gradio).
    """
    if not img_path:
        # Gradio calls the function with None when Submit is pressed without
        # an upload; surface a readable message instead of a TF stack trace.
        raise gr.Error("Please upload a chest X-ray image.")

    # Preprocess: decode to 3 channels, resize to 224x224, scale to [0, 1],
    # and add a batch dimension.
    image = tf.io.read_file(img_path)
    # expand_animations=False always yields a 3-D tensor; for GIFs the default
    # returns 4-D frames, which would end up as a 5-D "batch" below.
    image = tf.image.decode_image(image, channels=3, expand_animations=False)
    image = tf.image.resize(image, [224, 224])
    image = tf.cast(image, tf.float32) / 255.0
    image = tf.expand_dims(image, 0)

    # Predict. verbose=0 stops Keras printing a progress bar on every request.
    disease_pred, severity_pred = model.predict(image, verbose=0)
    disease_label = int(np.argmax(disease_pred[0]))
    severity_label = int(np.argmax(severity_pred[0]))

    disease_result = disease_names[disease_label]
    # Severity is only meaningful for a non-NORMAL diagnosis.
    severity_result = severity_names[severity_label] if disease_result != 'NORMAL' else "N/A"

    # Grad-CAM on the same conv layer for each output head.
    conv_layer = get_last_conv_layer(model)
    heatmap_disease = make_gradcam_heatmap(image, model, conv_layer, output_head='disease_output')
    heatmap_severity = make_gradcam_heatmap(image, model, conv_layer, output_head='severity_output')

    # Overlay each heatmap on the original image file.
    heatmap_img_disease = apply_heatmap(img_path, heatmap_disease)
    heatmap_img_severity = apply_heatmap(img_path, heatmap_severity)

    return disease_result, severity_result, heatmap_img_disease, heatmap_img_severity

# Gradio UI
demo = gr.Interface(
    fn=predict_and_explain,
    inputs=gr.Image(type="filepath", label="Upload Chest X-ray"),
    outputs=[
        gr.Text(label="Predicted Disease"),
        gr.Text(label="Predicted Severity"),
        gr.Image(label="Grad-CAM: Disease"),
        gr.Image(label="Grad-CAM: Severity"),
    ],
    title="Chest X-ray Classification with Severity Prediction and Explainability",
    description="Upload a Chest X-ray to classify it as COVID-19, Pneumonia, or Normal and see severity (if applicable) with Grad-CAM heatmaps.",
)

# Start the server only when executed directly (a script run or a notebook
# cell, since IPython executes cells in ``__main__``), so importing this
# module to reuse its functions does not launch a web server.
if __name__ == "__main__":
    demo.launch()



Running on local URL:  http://127.0.0.1:7861

To create a public link, set `share=True` in `launch()`.


