<a href="https://colab.research.google.com/github/DinurakshanRavichandran/Visio-Glance/blob/XAI/LIME_for_custom_one.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install lime
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import random
import cv2
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from lime import lime_image
from skimage.segmentation import mark_boundaries

# --- Model, sample selection and explainer setup ---------------------------
# Restore the trained network for inference only; compile=False skips
# compiling it, since no training/evaluation happens in this cell.
model = tf.keras.models.load_model('/content/custom_eye_disease_model.h5', compile=False)

# Draw one validation file uniformly at random. `val_generator` and
# `IMG_SIZE` are defined in an earlier notebook cell.
random_index = random.randrange(len(val_generator.filepaths))
sample_image_path = val_generator.filepaths[random_index]

# Read the file at the network's input resolution, scale pixel values by
# 1/255, and prepend a batch axis so the result has shape
# (1, IMG_SIZE, IMG_SIZE, channels).
img = load_img(sample_image_path, target_size=(IMG_SIZE, IMG_SIZE))
img_array = np.expand_dims(img_to_array(img) / 255.0, axis=0)

print(f"Selected validation data index: {random_index}")

# Preview the chosen sample before explaining it.
plt.imshow(img)
plt.title(f"Selected Image (Index: {random_index})")
plt.axis('off')
plt.show()

# LIME explainer for image inputs.
explainer = lime_image.LimeImageExplainer()

# Function to predict probabilities for LIME
def predict_fn(images):
    images = np.array(images)
    return model.predict([images, images])  # Since the model has dual inputs

# Explain the selected image: LIME perturbs superpixels of the image, scores
# the perturbed copies with predict_fn, and fits a local surrogate model.
explanation = explainer.explain_instance(
    img_array[0],  # LIME expects a single image, not a batch
    predict_fn,
    top_labels=4,  # explain the 4 highest-scoring classes
    hide_color=0,  # switched-off superpixels are filled with 0 (black)
    num_samples=1000,  # number of perturbed samples to generate
)

# First entry of top_labels = highest-scoring (predicted) class.
top_predicted_class = explanation.top_labels[0]

# Superpixels that push the prediction towards that class. hide_rest=False
# keeps the whole image in `temp`; the mask is expected to be 0/1 (selected
# superpixel or not), which the 255 * mask scaling below relies on.
temp, mask = explanation.get_image_and_mask(
    top_predicted_class,
    positive_only=True,
    num_features=5,  # number of superpixels to highlight
    hide_rest=False,
)

# Convert the mask to 8-bit (0-255) for OpenCV and colour it. Because the
# mask is binary, this "heatmap" contains only two JET colours.
heatmap = np.uint8(255 * mask)
heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)  # OpenCV uses BGR order

# Blend in float32 space with both operands in [0, 1].
temp_float = temp.astype(np.float32)
heatmap_float = heatmap.astype(np.float32) / 255.0

# Convex blend: alpha is the heatmap's weight and (1 - alpha) the image's,
# so the result stays in [0, 1]. Weights summing to more than 1 would push
# pixels above 1.0, which matplotlib clips with a warning. The clip is a
# guard against float rounding.
alpha = 0.5  # heatmap opacity
heatmap_overlay = cv2.addWeighted(temp_float, 1.0 - alpha, heatmap_float, alpha, 0)
heatmap_overlay = np.clip(heatmap_overlay, 0.0, 1.0)

# Keep only pixels inside the highlighted superpixels; zero everywhere else.
important_regions = temp * mask[:, :, np.newaxis]

# Side-by-side view: original, LIME boundaries, heatmap overlay.
plt.figure(figsize=(12, 6))

plt.subplot(1, 3, 1)
plt.imshow(img)
plt.title("Original Image")
plt.axis('off')

plt.subplot(1, 3, 2)
plt.imshow(mark_boundaries(temp, mask))
plt.title(f"LIME Explanation (Class: {CATEGORIES[top_predicted_class]})")
plt.axis('off')

plt.subplot(1, 3, 3)
plt.imshow(heatmap_overlay)
plt.title("Heatmap Overlay")
plt.axis('off')

plt.show()

# Important regions on their own.
plt.figure(figsize=(6, 6))
plt.imshow(important_regions)
plt.title("Only Important Regions")
plt.axis('off')
plt.show()
