In [None]:
# Colab setup: install or upgrade the Gemini client library (google-generativeai) and Pillow.
!pip install --upgrade google-generativeai pillow



In [None]:
# NOTE(review): `-U` is the short form of `--upgrade`, so this repeats the
# google-generativeai install already performed by the first cell.
!pip install -U google-generativeai



In [None]:
import os
import re
import numpy as np
import torch
import torch.nn.functional as F
import timm
import cv2
from PIL import Image
from torchvision import transforms
from torchvision.models import resnet50, efficientnet_b0, densenet121
from torchvision.models.densenet import DenseNet
from google.colab import files, userdata
import google.generativeai as genai
from textwrap import fill

In [None]:
def apply_clahe(img):
    """Boost local contrast of an RGB image by applying CLAHE to its lightness channel.

    The image is converted from RGB to LAB, CLAHE (clip limit 2.0, 8x8 tile
    grid) is applied to the L channel only, the A/B channels are left as they
    are, and the result is converted back to RGB.

    Args:
        img: RGB image array accepted by ``cv2.cvtColor``
            (presumably uint8, H x W x 3 -- the caller passes
            ``np.array`` of an RGB PIL image).

    Returns:
        The contrast-enhanced image as an RGB array.
    """
    lightness, chroma_a, chroma_b = cv2.split(cv2.cvtColor(img, cv2.COLOR_RGB2LAB))
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced_lab = cv2.merge((equalizer.apply(lightness), chroma_a, chroma_b))
    return cv2.cvtColor(enhanced_lab, cv2.COLOR_LAB2RGB)

In [None]:
class PatchedDenseNet(DenseNet):
    """DenseNet variant whose forward pass copies the feature map before the in-place ReLU.

    NOTE(review): the clone presumably exists so the in-place ReLU does not
    modify a tensor that the Grad-CAM backward hooks depend on -- confirm.
    """

    def forward(self, x):
        """Run features -> ReLU -> global average pool -> flatten -> classifier."""
        feature_maps = self.features(x).clone()
        activated = F.relu(feature_maps, inplace=True)
        pooled = F.adaptive_avg_pool2d(activated, (1, 1))
        return self.classifier(torch.flatten(pooled, 1))

In [None]:
def ensemble_predict(img_tensor_512, img_tensor_224, models_dict):
    """Average the softmax outputs of every model and return the winning class.

    Models whose name contains ``'vit'`` or ``'swin'`` are fed
    ``img_tensor_224``; every other model is fed ``img_tensor_512``. Each
    tensor gets a leading batch dimension added before the forward pass, and
    inference runs with gradients disabled.

    Args:
        img_tensor_512: Unbatched image tensor for the non-ViT/Swin models.
        img_tensor_224: Unbatched image tensor for the ViT/Swin models.
        models_dict: Mapping of model name -> callable returning class logits.

    Returns:
        ``(pred_class, avg_prob)`` where ``pred_class`` is the argmax index
        (a Python int) and ``avg_prob`` is the averaged probabilities as a
        NumPy array with the batch dimension kept.
    """
    def pick_input(model_name):
        wants_small_input = 'vit' in model_name or 'swin' in model_name
        return img_tensor_224 if wants_small_input else img_tensor_512

    with torch.no_grad():
        per_model_probs = [
            F.softmax(net(pick_input(model_name).unsqueeze(0)), dim=1)
            for model_name, net in models_dict.items()
        ]

    mean_prob = torch.stack(per_model_probs).mean(dim=0)
    winner = mean_prob.argmax(dim=1).item()
    return winner, mean_prob.cpu().numpy()

In [None]:
def generate_gradcam(model, input_tensor, target_class, device, output_size=(512, 512)):
    """Compute a Grad-CAM heatmap for ``target_class``.

    Hooks are attached to ``model.features.denseblock4.denselayer16.conv2`` to
    capture that layer's activations and the gradient of the target logit with
    respect to them. The channel-wise mean gradient weights the activations,
    the weighted sum is rectified, resized, and min-max scaled.

    Args:
        model: Network exposing ``features.denseblock4.denselayer16.conv2``
            (the DenseNet-121 layout). It is switched to eval mode.
        input_tensor: Unbatched image tensor; a batch dimension is added here
            (presumably shaped (C, H, W) -- confirm against the caller).
        target_class: Index of the logit to explain.
        device: Device the input is moved to before the forward pass.
        output_size: ``(width, height)`` passed to ``cv2.resize``. Defaults to
            ``(512, 512)``, the size this function previously hard-coded.

    Returns:
        2-D NumPy array of ``output_size`` with values in ``[0, 1]``.
    """
    model.eval()
    batch = input_tensor.unsqueeze(0).to(device)
    captured_grads, captured_acts = [], []

    target_layer = model.features.denseblock4.denselayer16.conv2

    def forward_hook(module, inputs, output):
        captured_acts.append(output.clone())

    def backward_hook(module, grad_input, grad_output):
        captured_grads.append(grad_output[0].clone())

    fw = target_layer.register_forward_hook(forward_hook)
    bw = target_layer.register_full_backward_hook(backward_hook)
    try:
        output = model(batch)
        loss = output[0, target_class]
        model.zero_grad()
        loss.backward()
    finally:
        # Always detach the hooks, even if the forward/backward pass fails,
        # so they do not stay registered on the shared model.
        fw.remove()
        bw.remove()

    # squeeze(0) drops only the batch axis; a bare squeeze() would also drop
    # a spatial axis of size 1 and break the (1, 2) mean below.
    grads = captured_grads[0].squeeze(0).detach().cpu().numpy()
    acts = captured_acts[0].squeeze(0).detach().cpu().numpy()

    weights = grads.mean(axis=(1, 2))
    cam = np.sum(weights[:, None, None] * acts, axis=0)
    cam = np.maximum(cam, 0)
    cam = cv2.resize(cam, output_size)

    # Min-max scale: divide by the value range (max - min), not by max alone,
    # so the strongest response maps to ~1 even when the minimum is above 0.
    cam_min = cam.min()
    return (cam - cam_min) / (cam.max() - cam_min + 1e-8)

In [None]:
def overlay_cam(cam, image):
    """Blend a Grad-CAM heatmap over an image and return the result as a PIL image.

    The image is resized to 512x512 and scaled to [0, 1]. ``cam`` is rendered
    with OpenCV's JET colormap (converted from BGR to RGB) and mixed at 40%
    heatmap / 60% image.

    Args:
        cam: Heatmap with values expected in [0, 1]
            (presumably 512x512 so it lines up with the resized image -- confirm).
        image: PIL image to draw the heatmap over.

    Returns:
        The blended overlay as an 8-bit PIL image.
    """
    heat_weight, image_weight = 0.4, 0.6

    base_rgb = np.array(image.resize((512, 512))).astype(np.float32) / 255
    jet_bgr = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    jet_rgb = cv2.cvtColor(jet_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255
    blended = np.clip(jet_rgb * heat_weight + base_rgb * image_weight, 0, 1)

    return Image.fromarray(np.uint8(blended * 255))

In [None]:
def clean_gemini_output(text, width=100):
    """Strip Markdown asterisks from ``text`` and re-wrap it to ``width`` columns.

    Runs of ``*`` are removed, runs of two or more newlines are collapsed to
    one, surrounding whitespace is trimmed, and the result is passed through
    ``textwrap.fill``.

    Args:
        text: Raw text to tidy up.
        width: Maximum line width handed to ``fill``.

    Returns:
        The cleaned, wrapped text.
    """
    substitutions = (
        (r"\*+", ""),        # drop Markdown emphasis markers
        (r"\n{2,}", "\n"),   # collapse blank lines
    )
    cleaned = text
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)

    return fill(cleaned.strip(), width=width)

In [None]:
def generate_ophthalmology_report(img_id, prediction, confidence, gemini_text):
    """Assemble the plain-text report from the model prediction and Gemini's answer.

    The final assessment is taken from a ``Final Verdict: Likely ...`` line in
    ``gemini_text`` when one is present; otherwise it falls back to the label
    derived from ``prediction``. The verdict line is removed from the text
    before it is cleaned and used as the explanation.

    Args:
        img_id: Identifier printed in the report header.
        prediction: Model label; ``'Glaucoma'`` maps to "Likely Glaucomatous",
            anything else to "Likely Non-Glaucomatous".
        confidence: Confidence score, rounded to 3 decimals in the report.
        gemini_text: Raw text returned by Gemini (``None`` is treated as empty).

    Returns:
        The formatted report string.
    """
    pred_label = "Likely Glaucomatous" if prediction == "Glaucoma" else "Likely Non-Glaucomatous"

    # Remove Markdown emphasis *before* parsing. Otherwise a bolded verdict such
    # as "**Final Verdict:** Likely Glaucomatous" is missed, the verdict silently
    # falls back to the model label, and the verdict line stays in the explanation.
    plain_text = re.sub(r"\*+", "", gemini_text or "")

    verdict_pattern = re.compile(r"Final Verdict:\s*(Likely [\w\-]+)\s*", re.IGNORECASE)
    verdict_match = verdict_pattern.search(plain_text)
    final_verdict = verdict_match.group(1).strip() if verdict_match else pred_label

    explanation = clean_gemini_output(verdict_pattern.sub("", plain_text).strip())

    report = f"""
Simulated Ophthalmologist Report

Image ID: {img_id}
Classification (Model): {pred_label}
Confidence Score: {round(confidence, 3)}

Explanation: {explanation}

Final Assessment: {final_verdict}

Disclaimer: This report is generated with the assistance of artificial intelligence and should
be considered a supplementary tool. It is not a substitute for professional medical diagnosis or advice.
Clinical correlation and further ophthalmological evaluation are recommended.
""".strip()

    return report

In [None]:
class GlaucomaAnalysisPipeline:
    """Colab pipeline: upload a fundus image, classify it, explain it, and write a report.

    ``run()`` performs these steps in order: file upload, CLAHE preprocessing,
    ensemble prediction over five models, Grad-CAM on the DenseNet-121 member,
    a side-by-side original/overlay image, a Gemini request with that image,
    and finally report formatting.
    """

    # Index -> label mapping used to decode the ensemble's argmax.
    CLASS_NAMES = ('Normal', 'Glaucoma')

    # Per-channel mean/std handed to transforms.Normalize for both input sizes.
    _NORM_MEAN = [0.485, 0.456, 0.406]
    _NORM_STD = [0.229, 0.224, 0.225]

    def __init__(self):
        """Pick the compute device, load all checkpoints, and configure the Gemini client."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.models = self.load_models()
        genai.configure(api_key=userdata.get("GOOGLE_API_KEY"))

    def load_models(self):
        """Build the five two-class networks and load their weights from ``/content``.

        Returns:
            Dict of model name -> model, each in eval mode on ``self.device``.
        """
        model_paths = {
            "resnet50": "/content/resnet50_best.pth",
            "efficientnet_b0": "/content/efficientnet_b0_best.pth",
            "densenet121": "/content/densenet121_best.pth",
            "vit_base_patch16_224": "/content/vit_base_patch16_224_best.pth",
            "swin_base_patch4_window7_224": "/content/swin_base_patch4_window7_224_best.pth"
        }

        models = {
            # `weights=None` replaces torchvision's deprecated `pretrained=False`;
            # both build the architecture without downloading weights.
            "resnet50": resnet50(weights=None, num_classes=2),
            "efficientnet_b0": efficientnet_b0(weights=None, num_classes=2),
            # Positional args are passed straight to torchvision's DenseNet constructor.
            "densenet121": PatchedDenseNet(32, (6, 12, 24, 16), 64, 4, 0, 2),
            "vit_base_patch16_224": timm.create_model("vit_base_patch16_224", pretrained=False, num_classes=2),
            "swin_base_patch4_window7_224": timm.create_model("swin_base_patch4_window7_224", pretrained=False, num_classes=2),
        }

        for name, model in models.items():
            # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
            state_dict = torch.load(model_paths[name], map_location=self.device)
            model.load_state_dict(state_dict)
            model.eval().to(self.device)

        return models

    @classmethod
    def _build_transform(cls, size):
        """Return the resize -> tensor -> normalise pipeline for a ``size`` x ``size`` input."""
        return transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor(),
            transforms.Normalize(cls._NORM_MEAN, cls._NORM_STD)
        ])

    @staticmethod
    def _build_prompt(prediction, confidence):
        """Return the Gemini instruction text with the model's label and confidence filled in."""
        return f"""
Hello Gemini, you are simulating an ophthalmologist specialized in glaucoma detection using fundus photographs.

The ensemble model prediction is: **{prediction}**
The confidence score for this prediction is: **{confidence:.3f}**

Instructions:

1. Analyze the attached Grad-CAM overlay image, focusing on the regions it highlights.
2. Assess for glaucomatous signs: increased cup-to-disc ratio (CDR), neuroretinal rim thinning, optic disc hemorrhages.
3. Review the rest of the fundus for additional signs that support or contradict glaucoma.
4. Based on the findings, provide a clear final classification: 'Likely Glaucomatous' or 'Likely Non-Glaucomatous'.
5. Explain the reasoning behind your classification using concise clinical language.
6. Include a brief statement if your verdict differs from the model's prediction.
7. End with:
    Final Verdict: Likely Glaucomatous
    or
    Final Verdict: Likely Non-Glaucomatous

The entire response **must fit within 512 tokens**. Be clear, medically precise, and avoid unnecessary repetition.
"""

    def run(self):
        """Run the full analysis on one uploaded image and return the report text.

        Side effects: opens the Colab upload dialog, prints the prediction and
        confidence, writes ``merged_densenet121.jpg`` to the working directory,
        and sends that image to Gemini.

        Returns:
            The formatted report string.

        Raises:
            ValueError: If the upload dialog returns no file.
        """
        uploaded = files.upload()
        if not uploaded:
            # Previously surfaced as an opaque IndexError from indexing an empty list.
            raise ValueError("No image was uploaded; cannot run the analysis.")

        # Only the first uploaded file is analysed.
        img_path = next(iter(uploaded))
        img_id = os.path.splitext(os.path.basename(img_path))[0]
        orig_image = Image.open(img_path).convert("RGB")

        # The models see the CLAHE-enhanced image; the overlay below uses the original.
        clahe_img = Image.fromarray(apply_clahe(np.array(orig_image)))

        img_tensor_512 = self._build_transform(512)(clahe_img).to(self.device)
        img_tensor_224 = self._build_transform(224)(clahe_img).to(self.device)

        predicted_class, prob_array = ensemble_predict(img_tensor_512, img_tensor_224, self.models)
        prediction = self.CLASS_NAMES[predicted_class]
        confidence = float(prob_array[0][predicted_class])

        print(f"Prediction: {prediction}")
        print(f"Confidence: {confidence:.3f}")

        cam = generate_gradcam(self.models['densenet121'], img_tensor_512, predicted_class, self.device)
        gradcam_image = overlay_cam(cam, orig_image)

        # Side-by-side canvas: original on the left, Grad-CAM overlay on the right.
        combined = Image.new("RGB", (1024, 512))
        combined.paste(orig_image.resize((512, 512)), (0, 0))
        combined.paste(gradcam_image, (512, 0))
        combined_path = "merged_densenet121.jpg"
        combined.save(combined_path)
        # NOTE(review): PIL's show() hands the image to an external viewer; it may
        # not render inline in a notebook -- verify in Colab.
        combined.show()

        prompt = self._build_prompt(prediction, confidence)

        with open(combined_path, "rb") as f:
            image_data = f.read()

        model_gemini = genai.GenerativeModel("gemini-1.5-flash")
        response = model_gemini.generate_content(
            [prompt, {"mime_type": "image/jpeg", "data": image_data}],
            generation_config={"max_output_tokens": 512, "temperature": 0.7}
        )

        return generate_ophthalmology_report(img_id, prediction, confidence, response.text)

In [None]:
# Build the pipeline: loads the five checkpoints from /content and configures the Gemini client.
pipeline = GlaucomaAnalysisPipeline()



In [None]:
# Opens the upload dialog, runs the ensemble, Grad-CAM and Gemini steps, and returns the report text.
report = pipeline.run()

In [None]:
# Show the plain-text report returned by pipeline.run().
print(report)

Simulated Ophthalmologist Report

Image ID: drishtiGS_095
Classification (Model): Likely Non-Glaucomatous
Confidence Score: 0.988

Explanation: The Grad-CAM overlay highlights the optic disc.  Upon examination of the fundus photograph, the cup-
to-disc ratio appears slightly increased, although precise measurement is impossible without
calibrated imaging.  There is subtle suggestion of neuroretinal rim thinning, particularly in the
inferior aspect. No hemorrhages are evident. The overall appearance of the neuroretinal rim is not
significantly altered.  The peripheral retina appears unremarkable. While the CDR increase and
subtle thinning are suggestive of early glaucoma, they are not definitive.  The high confidence
score from the model (0.988) suggests a normal finding, which contrasts with my assessment.
Considering the subtle findings and the model's strong prediction, I lean towards a non-glaucomatous
classification.  Further investigation, including visual field testing and additi