In [5]:
import os
from glob import glob
from tqdm import tqdm
import torch
import numpy as np
from PIL import Image
import cv2
from torchvision import transforms
from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.image import show_cam_on_image
from efficientnet_pytorch import EfficientNet

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build EfficientNet-B0, replace its classifier head with a 2-output linear
# layer, then load the fine-tuned checkpoint on top of it.
efficientnet_b0 = EfficientNet.from_pretrained('efficientnet-b0')
efficientnet_b0._fc = torch.nn.Linear(efficientnet_b0._fc.in_features, 2)
# weights_only=True limits unpickling to tensors and plain containers. The
# file is fed straight to load_state_dict, so nothing else is needed from it,
# and arbitrary pickled code in a tampered checkpoint cannot execute.
efficientnet_b0.load_state_dict(
    torch.load(
        'saved_models/Combined-DEPTH.pth',
        map_location=device,
        weights_only=True,
    )
)
efficientnet_b0.to(device)
efficientnet_b0.eval()  # inference mode: freezes dropout / batch-norm statistics

# Grad-CAM++ is computed on the activations of the model's `_conv_head` layer.
target_layer = efficientnet_b0._conv_head
cam = GradCAMPlusPlus(model=efficientnet_b0, target_layers=[target_layer])

# Resize to 224x224, convert to a tensor and normalise each channel with the
# mean / std values below (the usual ImageNet statistics).
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Images are read from anywhere below input_root; masks are written to the
# same relative location below output_root.
input_root = 'Test_Dataset/ColourTest3/Depth_Blender/Before'
output_root = 'GradCAM_Masks2'


def extract_center_mask(cam, center_fraction, threshold):
    """Binarise the central window of a 2-D activation map.

    A centred window covering ``center_fraction`` of each dimension is
    thresholded; everything outside that window is forced to 0.

    Args:
        cam: 2-D array of activation values (``cam.shape`` must unpack into
            height and width).
        center_fraction: Fraction of the height and of the width kept by the
            centred window. Must lie in ``[0, 1]``.
        threshold: Values strictly greater than this become 255, all others 0.

    Returns:
        A ``uint8`` array with the same shape as ``cam`` holding 255 where the
        central window exceeds ``threshold`` and 0 elsewhere.

    Raises:
        ValueError: If ``center_fraction`` is outside ``[0, 1]`` (or NaN), or
            if ``cam`` is not 2-D.
    """
    # A fraction above 1 would make the window offsets negative, and the
    # slices below would then wrap around and silently mark the wrong region.
    if not 0 <= center_fraction <= 1:
        raise ValueError(
            f"center_fraction must be within [0, 1], got {center_fraction!r}"
        )

    h, w = cam.shape
    ch, cw = int(h * center_fraction), int(w * center_fraction)
    y1 = (h - ch) // 2
    x1 = (w - cw) // 2
    rows, cols = slice(y1, y1 + ch), slice(x1, x1 + cw)

    mask = np.zeros_like(cam, dtype=np.uint8)
    mask[rows, cols] = (cam[rows, cols] > threshold).astype(np.uint8) * 255
    return mask

# Every regular file below input_root whose name contains a dot. The recursive
# '*.*' pattern also matches directories with a dot in their name, so those
# are filtered out here instead of failing later inside the loop.
image_paths = [
    path
    for path in glob(os.path.join(input_root, '**', '*.*'), recursive=True)
    if os.path.isfile(path)
]

for img_path in tqdm(image_paths, desc="Generating GradCAM masks"):
    # One bad file must not abort the whole batch: report it and carry on.
    try:
        # Close the source file as soon as the RGB copy has been made.
        with Image.open(img_path) as source_image:
            img = source_image.convert('RGB')
        input_tensor = preprocess(img).unsqueeze(0).to(device)

        # The batch holds a single image, so take the first CAM it returns.
        grayscale_cam = cam(input_tensor=input_tensor)[0]

        binary_mask = extract_center_mask(grayscale_cam, center_fraction=0.8, threshold=0.55)

        # Mirror the input's relative path (and file name) under output_root.
        rel_path = os.path.relpath(img_path, input_root)
        out_path = os.path.join(output_root, rel_path)
        os.makedirs(os.path.dirname(out_path), exist_ok=True)

        # cv2.imwrite signals failure by returning False rather than raising,
        # so turn that into an error that the handler below reports.
        if not cv2.imwrite(out_path, binary_mask):
            raise OSError(f"cv2.imwrite could not write {out_path}")

    except Exception as e:
        print(f"Error processing {img_path}: {e}")


  efficientnet_b0.load_state_dict(torch.load('saved_models/Combined-DEPTH.pth', map_location=device))


Loaded pretrained weights for efficientnet-b0


Generating GradCAM masks: 100%|████████████████████████████████████████████████████████| 50/50 [00:02<00:00, 20.49it/s]
