In [None]:
import os

import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
from torchvision.transforms import transforms

from models.helper.dataLoader import KermanyXRayImageFolder
from models.resnet18.model import XRayResNet18
from models.resnet50.model import XRayResNet50

In [None]:
def apply_gradcam(input_tensor, model, target_layer, target_class=None,
                  output_size=None, verbose=False):
    """
    Compute a Grad-CAM heatmap for a single image.

    The model is switched to eval mode and run forward on ``input_tensor``.
    The score of ``target_class`` (default: the arg-max class of the first
    sample) is back-propagated. The activations of ``target_layer`` are then
    weighted by their spatially averaged gradients and summed over channels.
    The result is passed through a ReLU, bilinearly resized and min-max
    normalized.

    Args:
        input_tensor: Batch holding one image, assumed (1, C, H, W) and on the
            same device as ``model``.
        model: Classifier assumed to return scores of shape (1, num_classes).
        target_layer: Sub-module of ``model`` whose output (assumed
            (1, K, h, w)) is explained.
        target_class: Class index to explain; ``None`` picks the top-scoring
            class.
        output_size: ``(height, width)`` of the returned map; ``None`` uses the
            spatial size of ``input_tensor``.
        verbose: If True, print the raw model output.

    Returns:
        numpy.ndarray (float32) of shape ``output_size`` with values in [0, 1].
        If the map has no positive evidence, an all-zero map is returned
        instead of NaNs.

    Raises:
        RuntimeError: if no gradient reaches ``target_layer`` (its output does
            not require grad).
    """
    model.eval()

    captured = {}

    def forward_hook(module, inputs, output):
        captured["activations"] = output
        # A tensor hook on the layer output yields d(score)/d(activations)
        # directly. The deprecated Module.register_backward_hook can report
        # gradients of an inner op for modules made of several operations.
        if output.requires_grad:
            output.register_hook(
                lambda grad: captured.__setitem__("gradients", grad))

    handle = target_layer.register_forward_hook(forward_hook)
    try:
        # enable_grad so the function also works when called under no_grad().
        with torch.enable_grad():
            output = model(input_tensor)
            if verbose:
                print(output)

            if target_class is None:
                target_class = int(output[0].argmax())

            # One-hot seed: back-propagate only the chosen class score of sample 0.
            one_hot = torch.zeros_like(output)
            one_hot[0, target_class] = 1

            model.zero_grad()
            output.backward(gradient=one_hot)
    finally:
        # Always detach the hook so repeated calls do not stack hooks.
        handle.remove()

    if "gradients" not in captured:
        raise RuntimeError(
            "No gradient reached target_layer; make sure its output requires grad.")

    with torch.no_grad():
        activations = captured["activations"]
        gradients = captured["gradients"]

        # Channel weights = spatial mean of the gradients -> (1, K, 1, 1).
        weights = gradients.mean(dim=(2, 3), keepdim=True)
        cam = torch.relu((weights * activations).sum(dim=1, keepdim=True))

        if output_size is None:
            output_size = tuple(input_tensor.shape[-2:])
        cam = torch.nn.functional.interpolate(
            cam.float(), size=tuple(output_size), mode="bilinear",
            align_corners=False)
        cam = cam[0, 0].cpu().numpy()

    cam = cam - cam.min()
    cam_max = cam.max()
    # Guard against 0/0 -> NaN when the map is flat.
    if cam_max > 0:
        cam = cam / cam_max
    return cam

In [None]:
CURR_MODEL = "resnet50"
MODEL_PATH = f"{CURR_MODEL}/cache/{CURR_MODEL}.pth"

# Alternative parameter set, kept for reference:
# MODEL_PARAMS = [128, 6, 5, 0.00011038211496918009, 5.916018460561741e-05, 0.3952291290587717, 256, 256, 1]

# MODEL_PARAMS is passed to step_train_model as-is; the field order is the
# one used by the unpacking below (presumably what step_train_model expects).
MODEL_PARAMS = [32, 6, 3, 0.00015, 0.00005, 0.7, 64, 256, 0]
(
    batch_size, freez_ep, unfreez_ep,
    lr_1, lr_2, dropout_head,
    hl1, hl2, num_hl,
) = MODEL_PARAMS

In [None]:
# Reuse cached weights when a checkpoint exists; otherwise train a new model.
if os.path.isfile(MODEL_PATH):
    print("Found pretrained model in cache")
    print("")
    if CURR_MODEL == "resnet18":
        model = XRayResNet18(hl1, hl2, num_hl, dropout_head)
    elif CURR_MODEL == "resnet50":
        model = XRayResNet50(hl1, hl2, num_hl, dropout_head)
    else:
        # Without this branch an unknown name would surface later as a
        # confusing NameError on `model`.
        raise ValueError(f"Unsupported CURR_MODEL: {CURR_MODEL!r}")
    # map_location="cpu" lets a checkpoint saved on a GPU machine load on a
    # CPU-only one; the model is moved to the target device afterwards.
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
else:
    # Imported only on this path, so the training module is loaded only when
    # no cached weights exist.
    from models.helper.training import step_train_model
    print("No pretrained model in cache found")
    print("")
    score, model = step_train_model(
        MODEL_PARAMS,
        pt_model=CURR_MODEL,
        return_model=True)


In [None]:
# Persist the weights so later runs can skip training. torch.save does not
# create missing parent directories, so make sure the cache folder exists.
os.makedirs(os.path.dirname(MODEL_PATH) or ".", exist_ok=True)
torch.save(model.state_dict(), MODEL_PATH)

In [None]:
# Run on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)

# Validation preprocessing: resize to 224x224, then convert to a tensor.
# NOTE(review): there is no Normalize step — confirm this matches the
# preprocessing used when the model was trained.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

val_dataset = KermanyXRayImageFolder('../data/val', transform=transform)

# Raw samples and labels as exposed by the dataset object.
X, Y = val_dataset.data, val_dataset.targets



In [None]:
# Grad-CAM heatmap for one validation image; change the index to inspect another.
sample_idx = 0
input_image, _, _ = val_dataset[sample_idx]
input_tensor = input_image.unsqueeze(0).to(device)

# Presumably the last residual block of the wrapped ResNet backbone; adjust
# if the model structure differs (inspect `model.model` to find the layer).
target_layer = model.model.layer4[-1]

cam = apply_gradcam(input_tensor, model, target_layer)

fig, ax = plt.subplots(figsize=(5, 5))
heatmap_artist = ax.imshow(cam, cmap='jet', vmin=0, vmax=1)
ax.set_title(f"Grad-CAM heatmap (validation image {sample_idx})")
ax.axis('off')
fig.colorbar(heatmap_artist, ax=ax, fraction=0.046, pad=0.04)
plt.show()

In [None]:
from torch.utils.data import DataLoader


def overlay_heatmap_on_image(heatmap, original_image, alpha=0.6):
    """
    Blend a Grad-CAM heatmap with the image it explains.

    Args:
        heatmap: 2-D array with values expected in [0, 1] (e.g. the output of
            ``apply_gradcam``). It is resized to the image size if needed.
            NaNs are treated as 0 and values are clipped to [0, 1].
        original_image: Image tensor of shape (C, H, W) or (1, C, H, W) with
            values expected in [0, 1]. C == 3 is treated as RGB and C == 1 as
            grayscale.
        alpha: Weight of the original image; the heatmap gets ``1 - alpha``.

    Returns:
        uint8 array of shape (H, W, 3) in BGR channel order (OpenCV convention).
    """
    if original_image.dim() == 4 and original_image.size(0) == 1:
        original_image = original_image.squeeze(0)

    # CHW tensor -> HWC uint8 array. detach() lets tensors that require grad
    # be converted; clipping prevents out-of-range values wrapping in uint8.
    image = original_image.detach().permute(1, 2, 0).cpu().numpy()
    image = np.uint8(255 * np.clip(image, 0.0, 1.0))
    if image.shape[2] == 1:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    else:
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    height, width = image.shape[:2]
    heat = np.nan_to_num(np.asarray(heatmap, dtype=np.float32), nan=0.0)
    if heat.shape[:2] != (height, width):
        # cv2.addWeighted requires both inputs to have the same size.
        heat = cv2.resize(heat, (width, height))
    heat = np.uint8(255 * np.clip(heat, 0.0, 1.0))
    heat = cv2.applyColorMap(heat, cv2.COLORMAP_JET)

    return cv2.addWeighted(image, alpha, heat, 1 - alpha, 0)

# Show Grad-CAM overlays for a few randomly drawn validation images per class.
SAMPLES_PER_CLASS = 6
GRADCAM_SEED = 0  # fixes which images are drawn, so re-runs show the same ones

# Number of images already shown, keyed by class label.
shown_per_class = {}

val_loader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=True,
    generator=torch.Generator().manual_seed(GRADCAM_SEED),
)

num_classes = len(val_dataset.classes)

for input_image, lab_tar, _ in val_loader:
    label = int(lab_tar[0])
    if shown_per_class.get(label, 0) >= SAMPLES_PER_CLASS:
        continue
    shown_per_class[label] = shown_per_class.get(label, 0) + 1

    tensor_img = input_image.to(device)

    # Compute Grad-CAM and blend it with the (CPU) input image.
    cam = apply_gradcam(tensor_img, model, target_layer)
    overlayed_image = overlay_heatmap_on_image(cam, input_image)

    if input_image.dim() == 4 and input_image.size(0) == 1:
        input_image = input_image.squeeze(0)
    image_hwc = input_image.permute(1, 2, 0).numpy()

    fig, ax = plt.subplots(1, 2, figsize=(12, 6))
    fig.suptitle(f"True label: {val_dataset.classes[label]}")

    # Original image (matplotlib cannot display an (H, W, 1) array directly).
    if image_hwc.shape[2] == 1:
        ax[0].imshow(image_hwc[:, :, 0], cmap='gray')
    else:
        ax[0].imshow(image_hwc)
    ax[0].set_title("Original Image")
    ax[0].axis('off')

    # Image with Grad-CAM (overlay is BGR; matplotlib expects RGB).
    ax[1].imshow(cv2.cvtColor(overlayed_image, cv2.COLOR_BGR2RGB))
    ax[1].set_title("Grad-CAM Overlay")
    ax[1].axis('off')

    plt.show()
    plt.close(fig)  # avoid accumulating open figures across iterations

    # Every class has its quota: stop instead of loading the rest of the data.
    if all(shown_per_class.get(c, 0) >= SAMPLES_PER_CLASS for c in range(num_classes)):
        break