## Imports

In [1]:
import cv2
import numpy as np
import sys
import io
from contextlib import redirect_stdout
from lime import lime_image
from skimage.segmentation import mark_boundaries
from tensorflow.keras.models import load_model
import gradio as gr
from PIL import Image
import matplotlib.pyplot as plt
import tempfile
import os
import random

## Supporting Functions

In [2]:
# Function to load and preprocess images
def load_and_preprocess_image(img_array, target_size=(224, 224)):
    """
    Load (if given a path), resize and normalize an image, and add a batch dimension.

    Parameters:
        img_array (str | os.PathLike | numpy.ndarray | None): Either a path to an image file
            (read with ``cv2.imread``) or an already-decoded image array of shape
            (H, W), (H, W, 1), (H, W, 3) or (H, W, 4).
        target_size (tuple[int, int]): Output size as (width, height), following the
            ``cv2.resize`` ``dsize`` convention. Defaults to (224, 224).

    Returns:
        numpy.ndarray: Array of shape (1, height, width, 3) with pixel values divided by 255.

    Raises:
        ValueError: If the image cannot be read from the given path, or ``img_array`` is None.
    """
    if isinstance(img_array, (str, os.PathLike)):
        # cv2.imread signals a missing/unreadable file by returning None, not by raising.
        # NOTE(review): cv2.imread yields BGR channel order while Gradio supplies RGB arrays;
        # irrelevant for grayscale X-rays, but confirm which order the model was trained on.
        img = cv2.imread(os.fspath(img_array))
    else:
        img = img_array

    if img is None:
        raise ValueError("Failed to load the image.")

    img = cv2.resize(img, tuple(target_size))

    # Coerce to exactly 3 channels; the model presumably expects RGB-shaped input
    # (InceptionV3 per the saved-model filename) — TODO confirm. cv2.resize may also
    # drop a trailing singleton channel, so treat 2-D results as single-channel.
    if img.ndim == 2:
        img = img[..., np.newaxis]
    if img.shape[-1] == 1:
        img = np.repeat(img, 3, axis=-1)
    elif img.shape[-1] == 4:
        img = img[..., :3]  # drop the alpha channel

    img = img / 255.0  # scale 8-bit pixel values into [0, 1]
    return np.expand_dims(img, axis=0)  # add the batch dimension


# Function to create a perturbed version of the image based on the LIME mask
def perturb_image(img, mask):
    """
    Black out the superpixels selected by a LIME mask, working on a copy of the image.

    Parameters:
        img (numpy.ndarray): Source image; it is left unmodified.
        mask (numpy.ndarray): LIME mask aligned with the image's leading (spatial) axes;
            every position whose value equals 1 is set to 0 in the result.

    Returns:
        numpy.ndarray: A new image array with the masked positions zeroed.
    """
    hidden = (mask == 1)
    result = img.copy()
    result[hidden] = 0
    return result

# Function to display lime predictions
def explain_prediction_lime(model, img_path, num_samples=1000, num_features=5, random_seed=None):
    """
    Explain a model's prediction using LIME (Local Interpretable Model-agnostic Explanations).

    Parameters:
    - model (keras.Model): The trained model; its ``predict`` method is used as the LIME classifier.
    - img_path (str | numpy.ndarray): Image path or decoded image array, as accepted by
      ``load_and_preprocess_image``.
    - num_samples (int): Number of perturbed samples LIME generates.
    - num_features (int): Number of superpixels kept in the explanation mask.
    - random_seed (int | None): Seed for LIME's sampling and segmentation; None keeps LIME's
      default (non-reproducible) behavior.

    Returns:
    - tuple: (lime_segmented_img, perturbed_segmented_img, segmentation) where
      ``lime_segmented_img`` shows only the explanatory superpixels with boundaries drawn,
      ``perturbed_segmented_img`` is the image with those superpixels blacked out and
      boundaries drawn, and ``segmentation`` is ``explanation.segments`` (LIME's superpixel
      label map).
    """
    img = load_and_preprocess_image(img_path)

    def classifier_fn(batch):
        predictions = model.predict(batch)
        # A single-column (sigmoid) output gives LIME only label 0, so top_labels=1 would always
        # explain the positive-class score even when the prediction is negative. Expanding to
        # [1 - p, p] lets LIME pick and explain whichever class the model actually predicts.
        if predictions.ndim == 2 and predictions.shape[1] == 1:
            predictions = np.hstack([1.0 - predictions, predictions])
        return predictions

    # Discard anything printed to stdout while LIME calls the classifier many times
    # (presumably Keras' per-call progress output).
    with redirect_stdout(io.StringIO()):
        explainer = lime_image.LimeImageExplainer(random_state=random_seed)
        explanation = explainer.explain_instance(
            img[0],
            classifier_fn,
            top_labels=1,
            hide_color=0,
            num_samples=num_samples,
            random_seed=random_seed,
        )

    temp, mask = explanation.get_image_and_mask(
        explanation.top_labels[0], positive_only=True, num_features=num_features, hide_rest=True
    )

    # Pixels are already in [0, 1] (load_and_preprocess_image divides by 255), which is what
    # mark_boundaries expects; the `/ 2 + 0.5` rescale from the LIME tutorial assumes [-1, 1]
    # input and would wash the image out into [0.5, 1].
    lime_segmented_img = mark_boundaries(temp, mask)

    perturbed_img = perturb_image(img[0], mask)
    perturbed_segmented_img = mark_boundaries(perturbed_img, mask)

    segmentation = explanation.segments

    return lime_segmented_img, perturbed_segmented_img, segmentation

# Function to get random image paths for both classes
def get_random_image_paths(folder_path, num_images=4, seed=None):
    """
    Randomly pick image files from a folder, wrapped for use as Gradio examples.

    Files are matched by extension (.jpg, .jpeg, .png) case-insensitively; directories
    are ignored even if their names end in an image extension.

    Parameters:
        folder_path (str): Path to the folder containing the images.
        num_images (int): Number of image paths to randomly select.
        seed (int | None): Optional seed for a reproducible selection; None uses the
            global ``random`` module state.

    Returns:
        list: List of single-element lists, each containing one selected image file path.

    Raises:
        ValueError: If the folder holds fewer than ``num_images`` matching images.
    """
    image_paths = sorted(
        path
        for path in (os.path.join(folder_path, name) for name in os.listdir(folder_path))
        if path.lower().endswith(('.jpg', '.jpeg', '.png')) and os.path.isfile(path)
    )
    # Sorting above removes os.listdir's arbitrary ordering, so a given seed always
    # selects the same files.

    if len(image_paths) < num_images:
        raise ValueError("Not enough images in the folder.")

    sampler = random.Random(seed) if seed is not None else random
    return [[path] for path in sampler.sample(image_paths, num_images)]

## Gradio Interface

### Defining Gradio function

In [3]:
def gr_interface(image):
    """
    Gradio callback: classify a chest X-ray and build its LIME explanation images.

    Uses the module-level ``model`` loaded in the initialisation cell.

    Parameters:
        image (numpy.ndarray | str | None): Value supplied by the Gradio image input;
            None when the input is cleared.

    Returns:
        tuple: (prediction text, LIME segmented image, perturbed image, segmentation image),
        one value per output component. All four are empty/None when ``image`` is None.
    """
    # With live=True the callback can fire when the input is cleared; clear the outputs
    # instead of raising from load_and_preprocess_image.
    if image is None:
        return "", None, None, None

    preprocessed_img = load_and_preprocess_image(image)
    predictions = model.predict(preprocessed_img)
    threshold = 0.5  # Adjust the threshold as needed
    # The first (only inspected) output is treated as the pneumonia probability.
    pneumonia_prob = float(predictions[0][0])
    binary_prediction = 1 if pneumonia_prob > threshold else 0
    pred_class = "Normal" if binary_prediction == 0 else "Pneumonia"

    # Confidence = model probability assigned to the predicted class (not accuracy).
    if binary_prediction == 1:
        confidence = pneumonia_prob * 100
    else:
        confidence = (1 - pneumonia_prob) * 100

    lime_segmented_img, perturbed_segmented_img, segmentation = explain_prediction_lime(model, image)

    # Colour the superpixel label map with viridis in memory, scaled min..max like plt.imsave,
    # rather than writing a never-deleted temp file (with an unclosed handle) per request.
    segmentation_img = plt.get_cmap("viridis")(plt.Normalize()(segmentation))[..., :3]

    return (
        f"Prediction: {pred_class}\n confidence: {confidence :.2f}%",
        lime_segmented_img,
        perturbed_segmented_img,
        segmentation_img,
    )

## Initialisation

In [4]:
# Path of the best-performing saved model
model = load_model("Saved_models/InceptionV3_8_100.h5")

# Test-set folders for each class; os.path.join builds separators that work on
# Windows and POSIX alike (a hard-coded backslash only resolves on Windows)
normal_examples = os.path.join("chest_xray", "test", "NORMAL")
pneumonia_examples = os.path.join("chest_xray", "test", "PNEUMONIA")

# Number of examples drawn per class; also used as examples_per_page so each class fills one page
eg_limit = 12

# Input/output components for the Gradio function
image_input = gr.Image(sources = ["upload", "clipboard"])
text_output = gr.Text(label="prediction_output", placeholder="The Prediction along with the confidence will be displayed here", show_copy_button=True)
image_output_1 = gr.Image(type="pil", label="Lime Segmented Image", height=425, width=425)
image_output_2 = gr.Image(type="pil", label="Perturbed Image", height=425, width=425)
image_output_3 = gr.Image(type="pil", label="Segmentation", height=425, width=425)
pred_output = [text_output, image_output_1, image_output_2, image_output_3]

# Title, description and article shown on the Gradio page

title_text = "Chest X-Ray Pneumonia Prediction"

desc_text = """
Upload a chest X-ray image to see the model's predictions and explanations.\n
You may select an example from below for a quick demo.
"""

article_text = (
    "<div style='text-align: center; font-size: 20px;'>"
    "The first page contains examples of normal X-rays. <br>"
    "The second page contains examples of X-rays with pneumonia. <br>"
    "</div>"
    "<div style='text-align: center; font-size: 15px;'>"
    "Connect with me on "
    "<a href='https://github.com/Kishor-Gulati' target='_blank'>GitHub</a>, "
    "<a href='https://www.linkedin.com/in/kishor-gulati' target='_blank'>LinkedIn</a>, "
    "<a href='https://leetcode.com/kanushgulati' target='_blank'>Leetcode</a>, "
    "<a href='https://www.hackerrank.com/profile/kanushgulati' target='_blank'>Hackerrank</a>, "
    "and <a href='mailto:kanushgulati@gmail.com'>G-mail</a>."
    "</div>"
)

## Launch

In [5]:
# Every run of this cell draws a fresh random set of example images for each class
normal_examples_paths = get_random_image_paths(normal_examples, num_images=eg_limit)
pneumonia_examples_paths = get_random_image_paths(pneumonia_examples, num_images=eg_limit)

# Normal examples come first, then pneumonia; with examples_per_page = eg_limit each
# class occupies its own page, as the article text describes
examples = [*normal_examples_paths, *pneumonia_examples_paths]

# Build the Gradio interface: callback and components first, then page presentation
iface = gr.Interface(
    fn=gr_interface,
    inputs=image_input,
    outputs=pred_output,
    live=True,  # re-run the callback whenever the input changes
    title=title_text,
    description=desc_text,
    article=article_text,
    thumbnail="thumbnail.jpg",
    theme=gr.themes.Glass(),  # alternatives: gr.themes.Soft()
    examples=examples,
    examples_per_page=eg_limit,
)

# Launch; NOTE(review): share=True publishes a public gradio.live URL for anyone with the link
iface.launch(share=True, height=1600, favicon_path="favicon.png")

Running on local URL:  http://127.0.0.1:7860
Running on public URL: https://f9f3f9d0c9a64b15bb.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


