In [None]:
# --- User configuration: replace every placeholder below before running ---
folder_path = 'Path/To/Your/Project'    # project root, added to sys.path below
csv_path = 'path to test csv/test.csv'  # CSV describing the test images
ckpt_path = 'path to ckpt'              # checkpoint file handed to torch.load
save_path = 'path to save images'       # directory that receives the Grad-CAM overlays


import sys
# Add the project root only once, so re-running this cell does not keep
# appending duplicate entries to sys.path.
if folder_path not in sys.path:
    sys.path.append(folder_path)

import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
from torchvision.io import read_image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from torchvision import transforms


from PIL import Image, UnidentifiedImageError, ImageFile
# Global PIL switch: decode image files whose data is truncated instead of
# raising an error while loading them.
ImageFile.LOAD_TRUNCATED_IMAGES = True




# gradcam function

In [None]:
# Side length (in pixels) of the square image fed to the network.
img_size = 224

# Preprocessing for a PIL image: resize to (img_size, img_size), then convert to a tensor.
test_transform = transforms.Compose(
    [transforms.Resize(size=(img_size, img_size)), transforms.ToTensor()]
)



def visualize_grad_cam(model, target_layers, img_path='', class_target=1, save_path=None, show_image=True, device='cpu'):
    '''
    Compute a Grad-CAM heatmap for one image and overlay it on that image.

    The image is opened with PIL, converted to RGB, resized to
    (img_size, img_size) and passed through the module-level `test_transform`
    before it is handed to GradCAM. The overlay is resized back to the
    original image width and height before it is saved, shown and returned.

    Parameters
    ----------
    model : network passed to GradCAM as `model`.
    target_layers : list of layers passed to GradCAM as `target_layers`.
    img_path : str
        Path of the image file to explain.
    class_target : int-like
        Output index wrapped in ClassifierOutputTarget. It is cast to `int`
        so NumPy / pandas integer scalars are accepted.
    save_path : str or None
        Directory in which the overlay is written, using the base name of
        `img_path` as file name. The directory is created if it is missing.
        Nothing is written when this is falsy.
    show_image : bool
        When true, display the overlay with matplotlib.
    device : str or torch.device
        Device the input tensor is moved to. It should be the device the
        model lives on.

    Returns
    -------
    The overlay image when `save_path` is falsy, otherwise the tuple
    `(overlay image, path of the written file)`.
    '''
    import os  # local import: `os` is not among this notebook's top-level imports

    # Honour the caller's choice instead of silently replacing it with an
    # auto-detected device that may differ from the model's device.
    device = torch.device(device)

    # Open the file once; convert() returns an in-memory copy, so the file
    # handle can be closed right away.
    with Image.open(img_path) as source_image:
        rgb_image = source_image.convert('RGB')
    orig_width, orig_height = rgb_image.size

    # The same resized image feeds the network and serves as overlay background,
    # so the CAM and the background always have matching sizes.
    resized_image = rgb_image.resize((img_size, img_size))
    input_tensor = test_transform(resized_image).unsqueeze(0).to(device)

    cam = GradCAM(model=model, target_layers=target_layers)
    targets = [ClassifierOutputTarget(int(class_target))]
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]  # CAM of the first (only) image

    # show_cam_on_image expects a float image scaled to [0, 1].
    background = np.float32(resized_image) / 255
    visualization = show_cam_on_image(background, grayscale_cam, use_rgb=True)

    # Back to the original resolution; cv2.resize takes (width, height).
    visualization = cv2.resize(visualization, (orig_width, orig_height))

    # Save and/or show the visualization
    saved_path = None
    if save_path:
        os.makedirs(save_path, exist_ok=True)
        saved_path = os.path.join(save_path, os.path.basename(img_path))
        plt.imsave(saved_path, visualization)

    if show_image:
        plt.imshow(visualization)
        plt.axis('off')
        plt.show()

    # Cleanup
    del input_tensor, grayscale_cam, cam  # drop references before clearing the cache
    torch.cuda.empty_cache()

    if save_path:
        return visualization, saved_path

    return visualization


# setup: load ckpt

In [None]:
from models.classification import ClassificationNet
from util.get_models import get_baseline_model

import pandas as pd


# Class-name -> class-index mapping passed to ClassificationNet.
classes = {"TED_1": 1, "CONT_": 0}
# load df
df_test = pd.read_csv(csv_path)

# load model
encoder = get_baseline_model(pretrained=True, model_architecture='resnet50')
model = ClassificationNet(  # make sure to define the model architecture and parameters correctly
    feature_dim=2048,
    encoder=encoder,
    classes=classes,
    lr=3e-5,
    loss_type="focal"
)

# The notebook only runs inference, so put the model in evaluation mode.
# The mode flag does not depend on the weights, so it is set before the final
# load_state_dict call, which stays the last expression of the cell.
model.eval()

# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
model.load_state_dict(state_dict)

## one example

In [6]:
# Path of the first test image. Positional access does not depend on the index
# labels, and bracket access matches the row['directory'] usage in the loop cell.
df_test['directory'].iloc[0]

'/home/CenteredData/TED Federated Learning Project/ted_manual_preprocessing/TED_1081.png'

In [None]:
# this image is class 0
# NOTE(review): `model.encoder` is what Grad-CAM differentiates, so `class_target`
# indexes the encoder's output -- confirm that output is the class scores.
img_path = '../TED4Share/demo_images/100224.jpg.png'
visualization = visualize_grad_cam(
    model=model.encoder,
    target_layers=[model.encoder[0].layer4[-1]],
    img_path=img_path,
    class_target=0,
    save_path=None,
    show_image=True,
)


## for the whole dataset

In [None]:

# Run Grad-CAM for every row of the test CSV, using the row's label as target
# and passing `save_path` so each overlay is written to disk.
for img_path, class_target in zip(df_test['directory'], df_test['label']):
    print(img_path)
    visualization = visualize_grad_cam(
        model=model.encoder,
        target_layers=[model.encoder[0].layer4[-1]],
        img_path=img_path,
        class_target=class_target,
        save_path=save_path,
        show_image=False,
    )