In [None]:
import torch
import torch.nn as nn
import warnings
warnings.filterwarnings('ignore')
from torchvision import models, transforms
from torchvision import models
import numpy as np
import cv2
import requests
from pytorch_grad_cam import GradCAMPlusPlus, GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image, \
    deprocess_image, \
    preprocess_image
from PIL import Image
import matplotlib.pyplot as plt

In [None]:
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f'Using device: {device}')

In [None]:
# Define CLAHE Transform class
class CLAHETransform:
    """Contrast Limited Adaptive Histogram Equalization (CLAHE) as a callable transform.

    Grayscale inputs are equalized directly. Colour inputs are treated as RGB:
    they are converted to LAB, CLAHE is applied to the L (lightness) channel
    only, and the result is converted back to RGB.

    Args:
        clip_limit: Contrast-limiting threshold passed to ``cv2.createCLAHE``.
        tile_grid_size: Tile grid size passed to ``cv2.createCLAHE``.
    """

    def __init__(self, clip_limit=0.10, tile_grid_size=(8, 8)):
        self.clip_limit = clip_limit
        self.tile_grid_size = tile_grid_size
        self.clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)

    def __call__(self, img):
        """Apply CLAHE to ``img``.

        Args:
            img: A PIL image or numpy array of shape (H, W), (H, W, 1),
                (H, W, 3) RGB or (H, W, 4) RGBA.
                NOTE(review): OpenCV's CLAHE accepts only 8- or 16-bit
                single-channel images, and the colour path assumes uint8 RGB;
                confirm the dtype of the images fed in.

        Returns:
            numpy.ndarray: The equalized image. An (H, W, 1) input comes back as
            (H, W); an alpha channel, if present, is passed through unchanged.

        Raises:
            ValueError: If a 3-D input has a channel count other than 1, 3 or 4.
        """
        # Convert PIL image to numpy array if necessary
        if isinstance(img, Image.Image):
            img = np.array(img)

        # A singleton channel axis is still grayscale; clahe.apply expects 2-D input.
        if img.ndim == 3 and img.shape[2] == 1:
            img = np.ascontiguousarray(img[:, :, 0])

        if img.ndim == 2:
            # Grayscale: apply CLAHE directly
            return self.clahe.apply(img)

        n_channels = img.shape[2]
        if n_channels not in (3, 4):
            raise ValueError(f"Unsupported number of channels for CLAHE: {n_channels}")

        # Split off alpha (RGBA) so the RGB->LAB conversion sees exactly 3 channels.
        alpha = img[:, :, 3] if n_channels == 4 else None
        rgb = np.ascontiguousarray(img[:, :, :3])

        lab_img = cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB)
        l_channel, a_channel, b_channel = cv2.split(lab_img)

        # Apply CLAHE only to the L (lightness) channel
        l_channel = self.clahe.apply(l_channel)

        # Merge back and convert to RGB
        lab_img = cv2.merge((l_channel, a_channel, b_channel))
        out = cv2.cvtColor(lab_img, cv2.COLOR_LAB2RGB)

        if alpha is not None:
            out = np.dstack((out, alpha))
        return out

In [None]:
class DenseNet121(nn.Module):
    """DenseNet-121 backbone with a multi-label sigmoid head.

    torchvision's classifier is replaced by ``Linear(num_ftrs, out_size)``
    followed by ``Sigmoid``, so each output is an independent probability in
    [0, 1] rather than a softmax score.

    Args:
        out_size: Number of output labels.
        pretrained: If True (default), initialise the backbone from torchvision's
            ImageNet weights (downloaded on first use). Pass False when a full
            checkpoint is loaded immediately afterwards, to skip the redundant
            download.
    """

    def __init__(self, out_size, pretrained=True):
        super().__init__()
        # IMAGENET1K_V1 is the ImageNet checkpoint torchvision provides for DenseNet-121.
        weights = models.DenseNet121_Weights.IMAGENET1K_V1 if pretrained else None
        self.densenet121 = models.densenet121(weights=weights)
        num_ftrs = self.densenet121.classifier.in_features
        self.densenet121.classifier = nn.Sequential(
            nn.Linear(num_ftrs, out_size),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return per-label sigmoid probabilities, shape (batch, out_size)."""
        return self.densenet121(x)

In [None]:
N_LABELS = 14
# Load the saved model.
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (recent torch versions also accept weights_only=True to restrict this).
best_model = DenseNet121(out_size=N_LABELS)
best_model.load_state_dict(torch.load('path_to_model.pth', map_location=device))
best_model = best_model.to(device)
best_model.eval();  # trailing ';' suppresses the (very long) module repr

In [None]:
# Load and preprocess the image for DenseNet121
image_url = "path_to_directory"  # path to a single image *file*
with Image.open(image_url) as pil_img:  # close the file handle once pixels are read
    img = np.array(pil_img)
clahe_transform = CLAHETransform(clip_limit=0.34, tile_grid_size=(8, 8))
img = clahe_transform(img)
img = cv2.resize(img, (256, 256))
# Scale to float32 in [0, 1], the range show_cam_on_image expects.
# NOTE(review): dividing by 255 assumes 8-bit pixels; confirm for 16-bit images.
img = np.float32(img) / 255
# Replicate grayscale to 3 channels. An RGB image already has them; stacking it
# unconditionally would produce an invalid (H, W, 3, 3) array.
if img.ndim == 2:
    img = np.stack([img] * 3, axis=-1)
# HWC -> NCHW batch of one, on the same device as the model.
input_tensor = torch.from_numpy(img[np.newaxis, ...]).permute(0, 3, 1, 2).to(device)

# Explain output index 13 of the N_LABELS-way sigmoid head.
# NOTE(review): which finding index 13 denotes depends on the training label order.
targets = [ClassifierOutputTarget(13)]
target_layers = [best_model.densenet121.features.denseblock4]
with GradCAMPlusPlus(model=best_model, target_layers=target_layers) as cam_extractor:
    grayscale_cams = cam_extractor(input_tensor=input_tensor, targets=targets)
cam_image = show_cam_on_image(img, grayscale_cams[0, :], use_rgb=True)

# Colour-map the [0, 1] CAM for display. cv2.applyColorMap returns BGR, so convert
# to RGB for matplotlib (otherwise red and blue are swapped).
cam = np.uint8(255 * grayscale_cams[0, :])
cam = cv2.applyColorMap(cam, cv2.COLORMAP_TWILIGHT_SHIFTED)
cam = cv2.cvtColor(cam, cv2.COLOR_BGR2RGB)

# Visualization: original, heatmap, overlay side by side.
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
panels = [
    (img, '(a) Original Image'),
    (cam, '(b) Grad-CAM Heatmap'),
    (cam_image, '(c) Overlay Image'),
]
for ax, (panel, title) in zip(axes, panels):
    ax.imshow(panel)
    ax.axis('off')
    ax.set_title(title)

plt.tight_layout()
plt.show()

# Save the overlay image. cam_image is RGB while OpenCV writes BGR; OpenCV also picks
# the output format from the file extension, so the path must end in e.g. ".png".
overlay_image_path = "path_to_save_directory"
if not cv2.imwrite(overlay_image_path, cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR)):
    raise IOError(f"Failed to write overlay image to {overlay_image_path!r}")

##### Lateral view (ResNet50)

In [None]:
class ResNet50(nn.Module):
    """ResNet-50 backbone with a multi-label sigmoid head.

    torchvision's final ``fc`` layer is replaced by ``Linear(num_ftrs, out_size)``
    followed by ``Sigmoid``. Each of the ``out_size`` outputs is therefore an
    independent probability in [0, 1] (multi-label), not a softmax over classes.

    Args:
        out_size: Number of output labels.
        pretrained: If True (default), initialise the backbone from torchvision's
            ``IMAGENET1K_V1`` weights (downloaded on first use). Pass False when a
            full checkpoint is loaded immediately afterwards, to skip the download.
    """

    def __init__(self, out_size, pretrained=True):
        super().__init__()
        # IMAGENET1K_V1 is the original torchvision checkpoint, not the newest:
        # torchvision also ships IMAGENET1K_V2 (its DEFAULT). V1 is kept so existing
        # behaviour is unchanged.
        weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        self.resnet50 = models.resnet50(weights=weights)
        num_ftrs = self.resnet50.fc.in_features
        self.resnet50.fc = nn.Sequential(
            nn.Linear(num_ftrs, out_size),
            nn.Sigmoid()  # independent per-label probabilities (multi-label)
        )

    def forward(self, x):
        """Return per-label sigmoid probabilities, shape (batch, out_size)."""
        return self.resnet50(x)

In [None]:
N_LABELS = 14
# Load the saved model.
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (recent torch versions also accept weights_only=True to restrict this).
lateral_model = ResNet50(out_size=N_LABELS)
lateral_model.load_state_dict(torch.load('path_to_best_model.pth', map_location=device))
lateral_model = lateral_model.to(device)
lateral_model.eval();  # trailing ';' suppresses the (very long) module repr

In [None]:
# Load and preprocess the image for ResNet50
image_url = "path_to_directory"  # path to a single image *file*
with Image.open(image_url) as pil_img:  # close the file handle once pixels are read
    img = np.array(pil_img)
clahe_transform = CLAHETransform(clip_limit=0.34, tile_grid_size=(8, 8))
img = clahe_transform(img)
img = cv2.resize(img, (256, 256))
# Scale to float32 in [0, 1], the range show_cam_on_image expects.
# NOTE(review): dividing by 255 assumes 8-bit pixels; confirm for 16-bit images.
img = np.float32(img) / 255
if img.ndim == 2:  # If grayscale, convert to RGB
    img = np.stack([img] * 3, axis=-1)
# HWC -> NCHW batch of one, on the same device as the model.
input_tensor = torch.from_numpy(img[np.newaxis, ...]).permute(0, 3, 1, 2).to(device)

# Define the target layer and targets for ResNet50.
# NOTE(review): which finding index 13 denotes depends on the training label order.
targets = [ClassifierOutputTarget(13)]  # Replace 13 with the target class index as needed
target_layers = [lateral_model.resnet50.layer4[-1]]

# Use GradCAM++ to generate the visualization for ResNet50
with GradCAMPlusPlus(model=lateral_model, target_layers=target_layers) as cam_extractor:
    grayscale_cams = cam_extractor(input_tensor=input_tensor, targets=targets)
cam_image = show_cam_on_image(img, grayscale_cams[0, :], use_rgb=True)

# Colour-map the [0, 1] CAM for display. cv2.applyColorMap returns BGR, so convert
# to RGB for matplotlib (otherwise red and blue are swapped).
cam = np.uint8(255 * grayscale_cams[0, :])
cam = cv2.applyColorMap(cam, cv2.COLORMAP_TWILIGHT_SHIFTED)
cam = cv2.cvtColor(cam, cv2.COLOR_BGR2RGB)

# Visualization: original, heatmap, overlay side by side.
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
panels = [
    (img, '(a) Original Image'),
    (cam, '(b) Grad-CAM Heatmap'),
    (cam_image, '(c) Overlay Image'),
]
for ax, (panel, title) in zip(axes, panels):
    ax.imshow(panel)
    ax.axis('off')
    ax.set_title(title)

plt.tight_layout()
# Save the final figure before displaying it, so the file is written independently
# of how the backend treats the figure after plt.show().
fig.savefig("file_name.png", bbox_inches='tight')
plt.show()

# Save the overlay image. cam_image is RGB while OpenCV writes BGR; OpenCV also picks
# the output format from the file extension, so the path must end in e.g. ".png".
overlay_image_path = "path_to_save_directory"
if not cv2.imwrite(overlay_image_path, cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR)):
    raise IOError(f"Failed to write overlay image to {overlay_image_path!r}")