In [1]:
# !pip install  scikit-learn pandas matplotlib tqdm gradio
# !pip install kagglehub
# !pip install torch torchvision
# !pip install opencv-python

In [2]:
# Core
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import numpy as np
import os

# Grad-CAM visualization
import cv2
import torchvision.transforms.functional as F

# Web interface
import gradio as gr

In [3]:
# Load the trained glaucoma classifier and switch it to inference mode.
device = torch.device("cpu")

# Location of the fine-tuned checkpoint. Set the GLAUCOMA_MODEL_PATH environment
# variable to point elsewhere instead of editing this machine-specific path.
MODEL_PATH = os.environ.get(
    "GLAUCOMA_MODEL_PATH",
    "/Users/bharathreddy/Downloads/best_glaucoma_model.pth",
)

# Rebuild the training-time architecture: MobileNetV3-Large with the last
# classifier layer swapped for a 2-logit head (later cells interpret class
# index 1 as "Glaucoma" and anything else as "Normal").
# `weights=None` is the current spelling of the deprecated `pretrained=False`.
model = models.mobilenet_v3_large(weights=None)
model.classifier[3] = nn.Linear(model.classifier[3].in_features, 2)

# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered .pth file cannot execute arbitrary code when it is loaded.
state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()

# Inference preprocessing: fixed 224x224 resize, then ImageNet normalization.
inference_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],  # ImageNet stats
                         [0.229, 0.224, 0.225])
])



In [4]:
from PIL import Image
import numpy as np
import cv2
import torchvision.transforms.functional as F
import matplotlib.pyplot as plt

def predict_with_explanation(image_pil):
    """Classify a retinal image and return a Grad-CAM overlay with an explanation.

    NOTE: a later cell redefines `predict_with_explanation`; that later version
    is the one the Gradio interface ends up calling.

    Args:
        image_pil: PIL image of the retina (any mode; converted to RGB here).

    Returns:
        (overlay, explanation): `overlay` is a PIL RGB image at the input's
        original size with the Grad-CAM heatmap blended in; `explanation` is a
        Markdown-formatted string with the predicted label and softmax
        confidence.
    """
    # Force 3 channels: RGBA or grayscale uploads would otherwise produce a
    # tensor with the wrong channel count for the model.
    image_rgb = image_pil.convert("RGB")
    # `inference_transform` comes from the model-loading cell above, so this
    # cell no longer depends on `val_transform`, which is defined further down.
    image_tensor = inference_transform(image_rgb).unsqueeze(0).to(device)

    model.eval()
    gradients = []
    activations = []

    def forward_hook(module, inputs, output):
        activations.append(output)
        # A tensor-level hook captures d(score)/d(output) directly and avoids
        # the deprecated nn.Module.register_backward_hook.
        output.register_hook(gradients.append)

    # Hook the last block of the convolutional feature extractor.
    target_layer = model.features[-1]
    handle = target_layer.register_forward_hook(forward_hook)
    try:
        output = model(image_tensor)
        pred_class = output.argmax(dim=1).item()
        confidence = torch.softmax(output, dim=1)[0, pred_class].item()

        model.zero_grad()
        output[0, pred_class].backward()
    finally:
        # Always detach the hook, even if the forward/backward pass raised.
        handle.remove()

    # Grad-CAM: weight each activation channel by its spatially averaged gradient.
    grad = gradients[0]
    act = activations[0]
    pooled_grad = grad.mean(dim=(0, 2, 3))
    weighted_act = (act[0] * pooled_grad[:, None, None]).sum(dim=0)

    heatmap = weighted_act.detach().cpu().numpy()
    heatmap = np.maximum(heatmap, 0)
    heatmap /= heatmap.max() + 1e-8
    heatmap = cv2.resize(heatmap, (image_rgb.width, image_rgb.height))
    heatmap_bgr = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)
    # applyColorMap returns BGR; convert so it blends correctly with the RGB image.
    heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB)

    img_np = np.array(image_rgb)
    overlayed_img = cv2.addWeighted(img_np, 0.6, heatmap_rgb, 0.4, 0)

    # NOTE(review): these texts are fixed per class; they are not derived from
    # where the heatmap actually lights up.
    if pred_class == 1:
        explanation = f"Prediction: **Glaucoma** (Confidence: {confidence*100:.2f}%)\n\nThe model focused on areas of optic nerve cupping or peripheral thinning to make this decision."
    else:
        explanation = f"Prediction: **Normal** (Confidence: {confidence*100:.2f}%)\n\nThe optic nerve appears healthy, and no significant indicators of glaucoma were detected."

    return Image.fromarray(overlayed_img), explanation


In [5]:
import cv2
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F

def generate_gradcam(model, image_tensor, target_class, device, target_layer=None):
    """Compute a Grad-CAM heatmap explaining the logit `target_class`.

    Args:
        model: classifier module. When `target_layer` is None it must expose a
            `features` sequence, whose last element is hooked.
        image_tensor: preprocessed image, either (C, H, W) or a batch of one
            (1, C, H, W).
        target_class: index of the output logit to explain.
        device: device to run the forward/backward pass on.
        target_layer: optional module whose output activations are used;
            defaults to `model.features[-1]`.

    Returns:
        2-D numpy array at the target layer's spatial resolution. Negative
        evidence is clipped to 0 and the map is scaled so its maximum is just
        under 1; it is all zeros when there is no positive evidence.
    """
    model.eval()
    # Accept both a single image and an already-batched tensor.
    if image_tensor.dim() == 3:
        image_tensor = image_tensor.unsqueeze(0)
    image_tensor = image_tensor.to(device)

    if target_layer is None:
        target_layer = model.features[-1]

    activations = []
    gradients = []

    def forward_hook(module, inputs, output):
        activations.append(output)
        # Tensor hook: receives d(score)/d(output) during backward. Replaces
        # the deprecated nn.Module.register_backward_hook.
        output.register_hook(gradients.append)

    handle = target_layer.register_forward_hook(forward_hook)
    try:
        output = model(image_tensor)
        model.zero_grad()
        output[0, target_class].backward()
    finally:
        # Never leave the hook attached to the caller's model.
        handle.remove()

    grad = gradients[0]
    act = activations[0]
    # Channel weights = gradients averaged over batch and spatial dimensions.
    pooled_grad = grad.mean(dim=(0, 2, 3))
    weighted_act = (act[0] * pooled_grad[:, None, None]).sum(dim=0)

    heatmap = weighted_act.detach().cpu().numpy()
    heatmap = np.maximum(heatmap, 0)
    heatmap /= heatmap.max() + 1e-8
    return heatmap


def overlay_gradcam_on_image(image_tensor, heatmap, mean=None, std=None):
    """Blend a Grad-CAM heatmap over the image it was computed for.

    Args:
        image_tensor: (C, H, W) image tensor. If it was normalized (e.g. by
            `val_transform`), pass the same `mean` and `std` so it is
            un-normalized before display.
        heatmap: 2-D array, expected in [0, 1] (as returned by
            `generate_gradcam`); it is resized to the image size.
        mean, std: optional per-channel normalization statistics to undo.
            They must be given together.

    Returns:
        uint8 RGB numpy array of shape (H, W, 3).

    Raises:
        ValueError: if only one of `mean` / `std` is given.
    """
    if (mean is None) != (std is None):
        raise ValueError("mean and std must be provided together")

    image = image_tensor.detach().cpu()
    if mean is not None:
        mean_t = torch.as_tensor(mean, dtype=image.dtype).view(-1, 1, 1)
        std_t = torch.as_tensor(std, dtype=image.dtype).view(-1, 1, 1)
        image = image * std_t + mean_t
    if image.is_floating_point():
        # to_pil_image multiplies floats by 255 and casts to uint8, so values
        # outside [0, 1] (e.g. a still-normalized tensor) would overflow.
        image = image.clamp(0.0, 1.0)
    image_np = np.array(F.to_pil_image(image).convert("RGB"))

    height, width = image_np.shape[:2]
    heatmap_resized = cv2.resize(np.asarray(heatmap, dtype=np.float32), (width, height))
    heatmap_u8 = np.uint8(255 * np.clip(heatmap_resized, 0.0, 1.0))
    # applyColorMap yields BGR; convert to RGB to match image_np before blending.
    heatmap_rgb = cv2.cvtColor(cv2.applyColorMap(heatmap_u8, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)

    overlayed_img = cv2.addWeighted(image_np, 0.6, heatmap_rgb, 0.4, 0)
    return overlayed_img


# Validation / inference preprocessing used by the active predictor:
# resize to a fixed 224x224, convert to a float tensor, then apply
# ImageNet channel statistics.
val_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def predict_with_explanation(image_pil):
    """Classify a retinal image and return a Grad-CAM overlay plus a label.

    This definition replaces the earlier `predict_with_explanation` and is the
    one called by `inference_interface`.

    Args:
        image_pil: PIL image of the retina (any mode; converted to RGB here).

    Returns:
        (overlay, explanation): `overlay` is a 224x224 uint8 RGB numpy array
        with the Grad-CAM heatmap blended over the resized input, and
        `explanation` is "Prediction: <label>". If no activations/gradients
        were captured, the resized input image is returned with an error
        message instead.
    """
    model.eval()

    # Force 3 channels: RGBA or grayscale images would give the wrong channel
    # count for the model. Gradients w.r.t. the input are not needed for
    # Grad-CAM, so the input tensor does not require grad.
    image_rgb = image_pil.convert("RGB")
    image_tensor = val_transform(image_rgb).unsqueeze(0).to(device)

    # Capture activations and gradients of the last feature block.
    activations = []
    gradients = []

    def forward_hook(module, inputs, output):
        activations.append(output)
        output.register_hook(gradients.append)

    target_layer = model.features[-1]
    handle = target_layer.register_forward_hook(forward_hook)
    try:
        output = model(image_tensor)
        pred_class = output.argmax(dim=1).item()

        model.zero_grad()
        output[0, pred_class].backward()
    finally:
        # Always detach the hook, even if the forward/backward pass raised.
        handle.remove()

    if not gradients or not activations:
        explanation = "Grad-CAM failed: no gradients or activations collected."
        return np.array(image_rgb.resize((224, 224))), explanation

    grad = gradients[0][0]  # drop batch dim -> per-channel spatial gradients
    act = activations[0][0]  # drop batch dim -> per-channel spatial activations

    # Grad-CAM: weight each channel by its mean gradient and sum over channels.
    weights = grad.mean(dim=(1, 2))
    cam = (weights[:, None, None] * act).sum(dim=0)

    cam = cam.detach().cpu().numpy()
    cam = np.maximum(cam, 0)
    cam = cam / (cam.max() + 1e-8)
    cam = cv2.resize(cam, (224, 224))
    heatmap_bgr = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    # applyColorMap returns BGR; convert so it blends correctly with RGB pixels.
    heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB)

    # Original image resized to the heatmap size.
    img_np = np.array(image_rgb.resize((224, 224)))
    overlay = cv2.addWeighted(img_np, 0.6, heatmap_rgb, 0.4, 0)

    label = "Glaucoma" if pred_class == 1 else "Normal"
    explanation = f"Prediction: {label}"

    return overlay, explanation


In [6]:
import gradio as gr
from PIL import Image

def inference_interface(image):
    """Gradio callback: run prediction and Grad-CAM on an uploaded image.

    Args:
        image: the upload from `gr.Image(type="numpy")`, i.e. a numpy array;
            a PIL image is also accepted. None means no image was provided.

    Returns:
        (overlay, explanation) from `predict_with_explanation`, or
        (None, message) when no image was provided.
    """
    if image is None:
        # An empty image input arrives as None; previously this crashed with
        # an AttributeError inside predict_with_explanation.
        return None, "Please upload a retinal image."
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    overlay_img, explanation = predict_with_explanation(image)
    return overlay_img, explanation

# Gradio UI: one numpy image in -> (Grad-CAM overlay image, explanation text) out.
# Components are created in the same order as before: input, overlay, textbox.
retina_upload = gr.Image(type="numpy", label="Upload Retinal Image")
overlay_view = gr.Image(type="numpy", label="Grad-CAM Overlay")
explanation_box = gr.Textbox(label="Explanation")

demo = gr.Interface(
    fn=inference_interface,
    inputs=retina_upload,
    outputs=[overlay_view, explanation_box],
    title="Glaucoma Detection AI",
    description="Upload a retinal image. The model will predict whether it is glaucoma or normal and show a Grad-CAM heatmap as explanation.",
    # NOTE(review): newer Gradio releases rename this option to
    # `flagging_mode`; confirm against the installed version.
    allow_flagging="never",
)




In [7]:
# Close any server left from a previous run of this cell so it can be
# re-executed (the recorded output shows this call does not fail before the
# first launch).
demo.close()

# share=True asks Gradio for a public link, which exposes uploaded retinal
# images beyond this machine. Set GRADIO_SHARE=0 (or false/no) to keep the app
# local-only; sharing stays enabled by default as before.
share_publicly = os.environ.get("GRADIO_SHARE", "1").strip().lower() not in {"0", "false", "no"}
demo.launch(share=share_publicly)

* Running on local URL:  http://127.0.0.1:7860

Could not create share link. Please check your internet connection or our status page: https://status.gradio.app.


