In [1]:
!pip install git+https://github.com/jacobgil/pytorch-grad-cam.git

Collecting git+https://github.com/jacobgil/pytorch-grad-cam.git
  Cloning https://github.com/jacobgil/pytorch-grad-cam.git to /tmp/pip-req-build-j0b8r527
  Running command git clone --filter=blob:none --quiet https://github.com/jacobgil/pytorch-grad-cam.git /tmp/pip-req-build-j0b8r527
  Resolved https://github.com/jacobgil/pytorch-grad-cam.git to commit 781dbc0d16ffa95b6d18b96b7b829840a82d93d1
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting ttach (from grad-cam==1.5.5)
  Downloading ttach-0.0.3-py3-none-any.whl.metadata (5.2 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.7.1->grad-cam==1.5.5)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.7.1->grad-cam==1.5.5)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux201

In [2]:
import os
import torch
import timm
import cv2
from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms
from google.colab import files
import torch.nn.functional as F
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
import numpy as np
from torchvision.models import resnet50, efficientnet_b0, densenet121
from torchvision.models.densenet import DenseNet
import json
import zipfile

In [3]:
class PatchedDenseNet(DenseNet):
    def forward(self, x):
        features = self.features(x)
        features = features.clone()
        out = F.relu(features, inplace=True)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = torch.flatten(out, 1)
        out = self.classifier(out)
        return out

In [4]:
model_paths = {
    "resnet50": "/content/resnet50_best.pth",
    "efficientnet_b0": "/content/efficientnet_b0_best.pth",
    "densenet121": "/content/densenet121_best.pth",
    "vit_base_patch16_224": "/content/vit_base_patch16_224_best.pth",
    "swin_base_patch4_window7_224": "/content/swin_base_patch4_window7_224_best.pth"
}

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
models = {}

In [6]:
models['resnet50'] = resnet50(pretrained=False, num_classes=2)
models['resnet50'].load_state_dict(torch.load(model_paths['resnet50'], map_location=device))
models['resnet50'].eval().to(device)

models['efficientnet_b0'] = efficientnet_b0(pretrained=False, num_classes=2)
models['efficientnet_b0'].load_state_dict(torch.load(model_paths['efficientnet_b0'], map_location=device))
models['efficientnet_b0'].eval().to(device)

# models['densenet121'] = densenet121(pretrained=False, num_classes=2)
# models['densenet121'].load_state_dict(torch.load(model_paths['densenet121'], map_location=device))
# models['densenet121'].eval().to(device)

models['densenet121'] = PatchedDenseNet(
    growth_rate=32,
    block_config=(6, 12, 24, 16),
    num_init_features=64,
    bn_size=4,
    drop_rate=0,
    num_classes=2
)
models['densenet121'].load_state_dict(torch.load(model_paths['densenet121'], map_location=device))
models['densenet121'].eval().to(device)

models['vit_base_patch16_224'] = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=2)
models['vit_base_patch16_224'].load_state_dict(torch.load(model_paths['vit_base_patch16_224'], map_location=device))
models['vit_base_patch16_224'].eval().to(device)

models['swin_base_patch4_window7_224'] = timm.create_model('swin_base_patch4_window7_224', pretrained=False, num_classes=2)
models['swin_base_patch4_window7_224'].load_state_dict(torch.load(model_paths['swin_base_patch4_window7_224'], map_location=device))
models['swin_base_patch4_window7_224'].eval().to(device)



SwinTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
    (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (layers): Sequential(
    (0): SwinTransformerStage(
      (downsample): Identity()
      (blocks): Sequential(
        (0): SwinTransformerBlock(
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=128, out_features=384, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=128, out_features=128, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path1): Identity()
          (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=128, out_features=512, bias=True)
            (act): GELU(approximate='none')
            (

In [7]:
uploaded = files.upload()
img_path = list(uploaded.keys())[0]
orig_image = Image.open(img_path).convert("RGB")
img_id = os.path.splitext(os.path.basename(img_path))[0]

Saving EyePACS-DEV-RG-1.jpg to EyePACS-DEV-RG-1.jpg


In [8]:
transform_512 = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

transform_224 = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

In [9]:
img_tensor_512 = transform_512(orig_image).to(device)
img_tensor_224 = transform_224(orig_image).to(device)

In [10]:
def ensemble_predict_separate_resolutions(img_tensor_512, img_tensor_224, models_dict, class_names=['Normal', 'Glaucoma']):
    probs = []
    for name, model in models_dict.items():
        if 'vit' in name or 'swin' in name:
            input_tensor = img_tensor_224
        else:
            input_tensor = img_tensor_512
        with torch.no_grad():
            output = model(input_tensor.unsqueeze(0))
            prob = F.softmax(output, dim=1)
            probs.append(prob)
    avg_prob = torch.mean(torch.stack(probs), dim=0)
    pred_class = torch.argmax(avg_prob, dim=1).item()

    return class_names[pred_class], avg_prob.cpu().numpy()

In [11]:
prediction, prob = ensemble_predict_separate_resolutions(img_tensor_512, img_tensor_224, models)
predicted_class = int(np.argmax(prob))
print("Prediction:", prediction)
print("Probabilities:", prob)
print("Class", predicted_class)

Prediction: Glaucoma
Probabilities: [[0.01428036 0.98571956]]
Class 1


In [12]:
# def generate_gradcam(resnet_model, img_tensor, orig_image):
#     device = next(resnet_model.parameters()).device
#     resnet_model.eval()
#     input_tensor = img_tensor.unsqueeze(0).to(device)

#     gradients = []
#     activations = []

#     target_layer = resnet_model.layer4[2].conv3

#     def forward_hook(module, input, output):
#         activations.append(output.detach())

#     def backward_hook(module, grad_input, grad_output):
#         gradients.append(grad_output[0].detach())

#     fw_hook = target_layer.register_forward_hook(forward_hook)
#     bw_hook = target_layer.register_full_backward_hook(backward_hook)

#     output = resnet_model(input_tensor)
#     pred_class = output.argmax(dim=1).item()

#     loss = output[0, pred_class]
#     resnet_model.zero_grad()
#     loss.backward()

#     fw_hook.remove()
#     bw_hook.remove()

#     grads = gradients[0].squeeze().cpu().numpy()
#     acts = activations[0].squeeze().cpu().numpy()

#     weights = grads.mean(axis=(1, 2))

#     cam = np.zeros(acts.shape[1:], dtype=np.float32)
#     for i, w in enumerate(weights):
#         cam += w * acts[i, :, :]

#     cam = np.maximum(cam, 0)

#     cam = cam - cam.min()
#     cam = cam / (cam.max() + 1e-8)

#     cam = cv2.resize(cam, (512, 512))

#     orig_resized = orig_image.resize((512, 512))
#     np_img = np.array(orig_resized).astype(np.float32) / 255

#     heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
#     heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
#     heatmap = np.float32(heatmap) / 255

#     overlay = heatmap * 0.4 + np_img * 0.6
#     overlay = np.clip(overlay, 0, 1)

#     cam_img = Image.fromarray(np.uint8(overlay * 255))

#     return cam_img

def generate_gradcam(model, input_tensor, target_class, device, model_name):
    model.eval()
    input_tensor = input_tensor.unsqueeze(0).to(device)

    gradients = []
    activations = []

    if model_name == 'resnet50':
        target_layer = model.layer4[2].conv3
    elif model_name == 'efficientnet_b0':
        target_layer = model.features[-1][0]
    elif model_name == 'densenet121':
        target_layer = model.features.denseblock4.denselayer16.conv2
    else:
        print(f"Grad-CAM not supported for {model_name}")
        return None

    def forward_hook(module, input, output):
        activations.append(output.clone())

    def backward_hook(module, grad_input, grad_output):
        gradients.append(grad_output[0].clone())

    fw = target_layer.register_forward_hook(forward_hook)
    bw = target_layer.register_full_backward_hook(backward_hook)

    output = model(input_tensor)
    loss = output[0, target_class]
    model.zero_grad()
    loss.backward()

    fw.remove()
    bw.remove()

    grads = gradients[0].squeeze().detach().cpu().numpy()
    acts = activations[0].squeeze().detach().cpu().numpy()

    weights = grads.mean(axis=(1, 2))
    cam = np.zeros(acts.shape[1:], dtype=np.float32)

    for i, w in enumerate(weights):
        cam += w * acts[i, :, :]

    cam = np.maximum(cam, 0)
    cam = cv2.resize(cam, (input_tensor.size(3), input_tensor.size(2)))
    cam = (cam - cam.min()) / (cam.max() + 1e-8)

    return cam

In [13]:
def overlay_cam_on_image(cam, orig_img):
    np_img = np.array(orig_img.resize((cam.shape[1], cam.shape[0]))).astype(np.float32) / 255
    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255
    overlay = heatmap * 0.4 + np_img * 0.6
    overlay = np.clip(overlay, 0, 1)
    return Image.fromarray(np.uint8(overlay * 255))

In [14]:
cam_resnet = generate_gradcam(models['resnet50'], img_tensor_512, predicted_class, device, 'resnet50')
cam_densenet = generate_gradcam(models['densenet121'], img_tensor_512, predicted_class, device, 'densenet121')
cam_efficientnet = generate_gradcam(models['efficientnet_b0'], img_tensor_512, predicted_class, device, 'efficientnet_b0')

overlay_resnet = overlay_cam_on_image(cam_resnet, orig_image)
overlay_densenet = overlay_cam_on_image(cam_densenet, orig_image)
overlay_efficientnet = overlay_cam_on_image(cam_efficientnet, orig_image)

In [15]:
resized_orig = orig_image.resize((512, 512))
overlay_resnet = overlay_resnet.resize((512, 512))
overlay_densenet = overlay_densenet.resize((512, 512))
overlay_efficientnet = overlay_efficientnet.resize((512, 512))

combined = Image.new("RGB", (512 * 4, 512))
combined.paste(resized_orig, (0, 0))
combined.paste(overlay_resnet, (512, 0))
combined.paste(overlay_densenet, (1024, 0))
combined.paste(overlay_efficientnet, (1536, 0))
combined.save("gradcam_all_models.jpg")
combined.show()

In [16]:
from PIL import ImageDraw, ImageFont

def add_label(image, label, font_size=24):
    img = image.convert("RGBA").copy()
    draw = ImageDraw.Draw(img)

    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", font_size)
    except:
        font = ImageFont.load_default()

    text_bbox = draw.textbbox((0, 0), label, font=font)
    text_width = text_bbox[2] - text_bbox[0]
    text_height = text_bbox[3] - text_bbox[1]

    padding = 10
    bar_height = text_height + 2 * padding

    header = Image.new("RGBA", (img.width, bar_height), (0, 0, 0, 180))
    img.paste(header, (0, 0), header)

    text_x = (img.width - text_width) // 2
    text_y = padding
    draw.text((text_x, text_y), label, font=font, fill=(255, 255, 255, 255))

    return img.convert("RGB")

In [17]:
os.makedirs("llava_inputs/images", exist_ok=True)
os.makedirs("llava_inputs/prompts", exist_ok=True)

In [18]:
resized_orig = orig_image.resize((512, 512))
resized_orig.save(f"llava_inputs/images/{img_id}_original.jpg")

combined = Image.new("RGB", (512 * 4, 512))
combined.paste(resized_orig, (0, 0))

In [19]:
model_labels = {
    "resnet50": "ResNet50",
    "efficientnet_b0": "EfficientNetB0",
    "densenet121": "DenseNet121",
}

In [20]:
x_pos = 1
for model_name, label in model_labels.items():
    model = models[model_name]
    cam = generate_gradcam(model, img_tensor_512, predicted_class, device, model_name)
    overlay = overlay_cam_on_image(cam, orig_image)
    labeled_overlay = add_label(overlay, label)
    labeled_overlay.save(f"llava_inputs/images/{img_id}_{model_name}_cam.jpg")
    combined.paste(labeled_overlay, (512 * x_pos, 0))
    x_pos += 1

combined.save("llava_inputs/images/combined_with_header.jpg")

In [21]:
combined_path = "llava_inputs/images/combined_gradcam.jpg"
combined.save(combined_path)
combined.show()

In [22]:
prob_rounded = [round(float(x), 3) for x in prob.flatten()]
prob_str = json.dumps(prob_rounded)

In [23]:
prompt_zero_shot = f"""
This image shows a fundus photograph and Grad-CAM overlays (left to right): Original Fundus Image, ResNet50 Grad-CAM, DenseNet121 Grad-CAM, EfficientNetB0 Grad-CAM.

The ensemble prediction is: **{prediction}**.
Probabilities: {prob_rounded}

Please analyze the highlighted regions in the combined image and explain if they medically justify the prediction.
"""

In [24]:
prompt_one_shot = f"""
- Image: A glaucoma fundus photograph with Grad-CAM highlighting the optic disc and cup region.
- Prediction: Glaucoma.
- Explanation: The highlighted regions correspond to increased cup-to-disc ratio and rim thinning typical of glaucoma.

Now analyze the combined image showing:

- Fundus and Grad-CAM overlays (Original Fundus Image, ResNet50 Grad-CAM, DenseNet121 Grad-CAM, EfficientNetB0 Grad-CAM).

Prediction: **{prediction}**
Probabilities: {prob_rounded}

Provide your clinical interpretation.
"""

In [25]:
prompt_few_shot = f"""
Example 1 (Glaucoma):

- Image: Fundus photo with Grad-CAM highlighting glaucoma features.
- Prediction: Glaucoma.
- Explanation: The highlighted areas correspond to increased cup-to-disc ratio and rim thinning.

Example 2 (Normal):

- Image: Fundus photo with Grad-CAM showing healthy optic nerve head.
- Prediction: Normal.
- Explanation: No signs of glaucomatous damage in highlighted regions.

Now analyze this combined image containing:

- Original Fundus Image, ResNet50 Grad-CAM, DenseNet121 Grad-CAM, EfficientNetB0 Grad-CAM (left to right).

Prediction: **{prediction}**
Probabilities: {prob_rounded}

Provide a detailed diagnostic explanation.
"""

In [26]:
with open("llava_inputs/prompts/prompt_zero_shot.txt", "w") as f:
    f.write(prompt_zero_shot.strip())

with open("llava_inputs/prompts/prompt_one_shot.txt", "w") as f:
    f.write(prompt_one_shot.strip())

with open("llava_inputs/prompts/prompt_few_shot.txt", "w") as f:
    f.write(prompt_few_shot.strip())

In [27]:
metadata = {
    "image_id": img_id,
    "ensemble_prediction": prediction,
    "probabilities": prob_rounded,
    "class_index": predicted_class,
    "combined_image_path": combined_path
}

In [28]:
with open("llava_inputs/metadata.json", "w") as f:
    json.dump(metadata, f, indent=2)

with open("llava_inputs/ensemble_prediction.txt", "w") as f:
    f.write(f"{prediction} - {prob_rounded}")


In [29]:
zip_path = "llava_inputs.zip"
with zipfile.ZipFile(zip_path, 'w') as zipf:
    for root, _, files in os.walk("llava_inputs"):
        for file in files:
            file_path = os.path.join(root, file)
            arcname = os.path.relpath(file_path, "llava_inputs")
            zipf.write(file_path, arcname=arcname)

print(f"✅ All files saved and zipped at: {zip_path}")

✅ All files saved and zipped at: llava_inputs.zip
