In [1]:
from glob import glob
import os
from natsort import os_sorted
from dino.vision_transformer import DINOHead, VisionTransformer
from dino.vim.models_mamba import VisionMamba
from dino.config import configurations
from dino.main import get_args_parser
from functools import partial
from dino.utils import load_pretrained_weights
from torchvision import transforms
from torch import nn
import torch
from PIL import Image
import torchvision
import matplotlib.pyplot as plt
import numpy as np
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image


from tqdm import tqdm
import random
import matplotlib.gridspec as gridspec
import cv2

In [2]:
# Root folder of the per-class patch directories (one sub-folder per class name).
dataset_dir = (
    'Extracted_labels/Balanced/'
    '224_10x/test/'
)

In [3]:
# Reuse the training script's argument parser for its defaults. parse_known_args ignores
# any argv entries the parser does not recognise (e.g. those a notebook kernel is started with).
parser = get_args_parser()
args, unknown_argv = parser.parse_known_args()

In [4]:
def get_model(args):
    """Build an un-initialised backbone for ``args.arch`` (no weights are loaded here).

    Looks up ``args.arch`` in ``configurations``, overrides the image size, patch size and
    number of classes from ``args``, disables drop-path, and instantiates a ``VisionMamba``
    (arch names starting with ``'vim'``) or a ``VisionTransformer`` (names starting with
    ``'vit'``).

    Args:
        args: namespace providing ``arch``, ``image_size``, ``patch_size``, ``num_classes``,
            ``n_last_blocks`` and ``avgpool_patchtokens``.

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``args.arch`` is not in ``configurations`` or its name starts with
            neither ``'vim'`` nor ``'vit'``.
    """
    if args.arch not in configurations:
        raise ValueError(f"Unknown architecture: {args.arch}")

    # Work on a copy so the overrides below never leak into the shared configuration
    # table (and into later calls with the same arch).
    config = dict(configurations[args.arch])
    config['img_size'] = args.image_size
    config['patch_size'] = args.patch_size
    config['num_classes'] = args.num_classes
    # The config stores the norm layer by name; turn it into a constructor with its eps.
    if config.get('norm_layer') == "nn.LayerNorm":
        config['norm_layer'] = partial(nn.LayerNorm, eps=config['eps'])
    config['drop_path_rate'] = 0

    if args.arch.startswith('vim'):
        model = VisionMamba(return_features=True, **config)
        embed_dim = model.embed_dim
    elif args.arch.startswith('vit'):
        model = VisionTransformer(**config)
        # Presumably the linear-probe feature width (DINO eval convention); only printed here.
        embed_dim = model.embed_dim * (args.n_last_blocks + int(args.avgpool_patchtokens))
    else:
        raise ValueError(f"Unsupported architecture family: {args.arch}")
    print('EMBEDDED DIM:', embed_dim)
    return model

In [5]:
# Deterministic evaluation preprocessing: resize, center-crop to a square of
# args.image_size, convert to a tensor and normalize with the standard ImageNet mean/std.
# NOTE(review): args.image_size is read here, before a later cell sets it explicitly to 224,
# so this relies on the parser default matching — confirm.
val_transform = transforms.Compose([
    transforms.Resize(args.image_size, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(args.image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

In [6]:
def reshape_transform_vit(tensor, height=14, width=14):
    """Turn ViT token activations into a CNN-style feature map for Grad-CAM.

    Drops token 0 (the class token in a ViT) and lays the remaining ``height * width``
    tokens out on a spatial grid.

    Args:
        tensor: activations shaped ``(batch, 1 + height * width, channels)``.
        height, width: patch-grid size (14 x 14 for 224 px inputs with 16 px patches).

    Returns:
        Tensor shaped ``(batch, channels, height, width)``.
    """
    patch_tokens = tensor[:, 1:, :]
    batch_size, channels = tensor.size(0), tensor.size(2)
    grid = patch_tokens.reshape(batch_size, height, width, channels)
    # (B, H, W, C) -> (B, C, H, W)
    return grid.permute(0, 3, 1, 2)

In [7]:
def reshape_transform_vim(tensor, height=14, width=14, token_position=98):
    """Turn Vision Mamba token activations into a CNN-style feature map for Grad-CAM.

    Removes the token at index 0 and the token at ``token_position`` (presumably the
    mid-sequence class token), then lays the remaining tokens out on a
    ``height x width`` grid.

    NOTE(review): two tokens are dropped, so the input must hold exactly
    ``height * width + 2`` tokens for the reshape to succeed — confirm this matches the
    model's token layout.

    Returns:
        Tensor shaped ``(batch, channels, height, width)``.
    """
    before_cls = tensor[:, 1:token_position, :]
    after_cls = tensor[:, token_position + 1:, :]
    patch_tokens = torch.cat([before_cls, after_cls], dim=1)
    grid = patch_tokens.reshape(patch_tokens.size(0), height, width, patch_tokens.size(2))
    # (B, H, W, C) -> (B, C, H, W)
    return grid.permute(0, 3, 1, 2)

## ImageNet vs. Camelyon16 Heatmaps

In [8]:
# Common settings for all four backbones compared below.
args.image_size = 224
args.patch_size = 16
args.num_classes = 2
args.n_last_blocks = 4
args.avgpool_patchtokens = False
args.checkpoint_key = 'teacher'


def _load_dino_backbone(arch, weights_path, leading=''):
    """Point ``args`` at ``arch``/``weights_path``, build the model on GPU in eval mode and
    load the checkpoint with ``load_pretrained_weights`` (using ``args.checkpoint_key``).

    ``leading`` is printed in front of the ``Loading <arch>`` banner.
    """
    args.arch = arch
    print(f'{leading}Loading {args.arch}')
    args.pretrained_weights = weights_path
    backbone = get_model(args)
    backbone.cuda()
    backbone.eval()
    load_pretrained_weights(backbone, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size)
    return backbone


# Vim-Ti+ checkpoint from the Camelyon16 run directory.
model_vim = _load_dino_backbone(
    'vim-t-plus', '/home/ubuntu/checkpoints/camelyon16_224_10x/vim-t-plus_224-96/checkpoint.pth')

# Vim-S ImageNet checkpoint: its weights sit under the 'model' key. Entries that are missing
# or shaped differently from this model (e.g. the classifier head) are dropped so the
# non-strict load below only copies compatible tensors.
args.arch = 'vim-s'
print(f'\nLoading {args.arch}')
args.pretrained_weights = '/home/ubuntu/checkpoints/imagenet/vim_s_midclstok_80p5acc.pth'
model_vim2 = get_model(args)
state_dict = torch.load(args.pretrained_weights, map_location="cpu")['model']

for param_name, param in model_vim2.named_parameters():
    if param_name in state_dict and state_dict[param_name].size() == param.size():
        continue
    print(f"Skipping loading parameter {param_name} due to size mismatch or it not being present in the checkpoint.")
    state_dict.pop(param_name, None)  # Remove incompatible parameters

model_vim2.load_state_dict(state_dict, strict=False)
model_vim2.cuda()
model_vim2.eval()

# ViT-S from the Camelyon16 run directory, then the DINO ViT-S/16 ImageNet release.
# NOTE(review): the ImageNet file may have no 'teacher' key; load_pretrained_weights is
# expected to fall back in that case — confirm.
model_vit = _load_dino_backbone(
    'vit-s', '/home/ubuntu/checkpoints/camelyon16_224_10x/vit-s_224-96/checkpoint.pth', leading='\n')
model_vit2 = _load_dino_backbone(
    'vit-s', '/home/ubuntu/checkpoints/imagenet/dino_deitsmall16_pretrain.pth', leading='\n')

Loading vim-t-plus
EMBEDDED DIM: 384
Take key teacher in provided checkpoint dict
Skipping loading parameter head.weight due to size mismatch or it not being present in the checkpoint.
Skipping loading parameter head.bias due to size mismatch or it not being present in the checkpoint.
Pretrained weights found at /home/ubuntu/checkpoints/camelyon16_224_10x/vim-t-plus_224-96/checkpoint.pth and loaded with msg: _IncompatibleKeys(missing_keys=['head.weight', 'head.bias'], unexpected_keys=['head.mlp.0.weight', 'head.mlp.0.bias', 'head.mlp.2.weight', 'head.mlp.2.bias', 'head.mlp.4.weight', 'head.mlp.4.bias', 'head.last_layer.weight_g', 'head.last_layer.weight_v'])

Loading vim-s
EMBEDDED DIM: 384
Skipping loading parameter head.weight due to size mismatch or it not being present in the checkpoint.
Skipping loading parameter head.bias due to size mismatch or it not being present in the checkpoint.

Loading vit-s
EMBEDDED DIM: 1536
Take key teacher in provided checkpoint dict
Skipping loading 

In [9]:
# Grad-CAM maps of the Camelyon16 and ImageNet Vim / ViT models, side by side.
N_IMAGENET_IMAGES_PER_CLASS = 60

# GradCAM registers hooks on its target layer, so build each one once and reuse it for
# every image instead of re-creating (and re-hooking) it on each iteration.
imagenet_cams = {
    'vim': GradCAM(model=model_vim, target_layers=[model_vim.layers[-1].drop_path],
                   reshape_transform=reshape_transform_vim),
    'vim_imagenet': GradCAM(model=model_vim2, target_layers=[model_vim2.layers[-1].drop_path],
                            reshape_transform=reshape_transform_vim),
    'vit': GradCAM(model=model_vit, target_layers=[model_vit.blocks[-1].norm1],
                   reshape_transform=reshape_transform_vit),
    'vit_imagenet': GradCAM(model=model_vit2, target_layers=[model_vit2.blocks[-1].norm1],
                            reshape_transform=reshape_transform_vit),
}

# (cam key, grid cell, title) for each heatmap panel of the 2x4 layout.
imagenet_panels = [
    ('vim', (0, 2), 'Grad-CAM-Vim-Cam16'),
    ('vit', (0, 3), 'Grad-CAM-ViT-Cam16'),
    ('vim_imagenet', (1, 2), 'Grad-CAM-Vim-ImageNet'),
    ('vit_imagenet', (1, 3), 'Grad-CAM-ViT-ImageNet'),
]

try:
    for class_name in ['tumor', 'normal']:
        img_paths = os_sorted(glob(os.path.join(dataset_dir, class_name, "*jpg")))
        os.makedirs(f'heatmaps/cam_imagenet/{class_name}', exist_ok=True)

        # Never index past the available images of a class.
        for i in tqdm(range(min(N_IMAGENET_IMAGES_PER_CLASS, len(img_paths)))):
            img = Image.open(img_paths[i])
            img_transformed = val_transform(img).unsqueeze(0)

            grayscale_cams = {key: cam(input_tensor=img_transformed)[0, :]
                              for key, cam in imagenet_cams.items()}

            img_show = img_transformed.cpu().squeeze().permute(1, 2, 0).numpy()
            img_show = (img_show - img_show.min()) / (img_show.max() - img_show.min())  # Normalize to [0,1]

            fig = plt.figure(figsize=(35, 20))
            gs = gridspec.GridSpec(2, 4, figure=fig)

            # Original image spans the left 2x2 block.
            ax0 = fig.add_subplot(gs[0:2, 0:2])
            ax0.imshow(img_show)
            ax0.set_title('Original Image', fontsize=40)
            ax0.axis('off')

            for cam_key, (row, col), title in imagenet_panels:
                ax = fig.add_subplot(gs[row, col])
                ax.imshow(show_cam_on_image(img_show, grayscale_cams[cam_key], use_rgb=True))
                ax.set_title(title, fontsize=40)
                ax.axis('off')

            fig.tight_layout()
            fig.savefig(f'heatmaps/cam_imagenet/{class_name}/{i}.jpg', bbox_inches='tight', dpi=200)
            plt.close(fig)
finally:
    # Detach the Grad-CAM hooks so they stop firing on later forward passes of these models.
    for cam in imagenet_cams.values():
        cam.activations_and_grads.release()

100%|███████████████████████████████████████████| 60/60 [01:02<00:00,  1.04s/it]
100%|███████████████████████████████████████████| 60/60 [00:59<00:00,  1.02it/s]


## Scaling Vim Model Sizes

In [10]:
# Common settings for the model-scaling comparison.
args.image_size = 224
args.patch_size = 16
args.num_classes = 2
args.n_last_blocks = 4
args.avgpool_patchtokens = False
args.checkpoint_key = 'teacher'

# Every Camelyon16 checkpoint follows the `<arch>_224-96/checkpoint.pth` layout under this root.
CAMELYON16_CKPT_ROOT = '/home/ubuntu/checkpoints/camelyon16_224_10x'
SCALING_ARCHS = ['vim-t', 'vim-t-plus', 'vim-s', 'vit-t', 'vit-s']

scaling_models = {}
for arch in SCALING_ARCHS:
    args.arch = arch
    args.pretrained_weights = f'{CAMELYON16_CKPT_ROOT}/{arch}_224-96/checkpoint.pth'
    backbone = get_model(args)
    backbone.cuda()
    backbone.eval()
    load_pretrained_weights(backbone, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size)
    scaling_models[arch] = backbone

model_vim_ti = scaling_models['vim-t']
model_vim_ti_plus = scaling_models['vim-t-plus']
model_vim_s = scaling_models['vim-s']
model_vit_ti = scaling_models['vit-t']
model_vit_s = scaling_models['vit-s']



EMBEDDED DIM: 192
Take key teacher in provided checkpoint dict
Skipping loading parameter head.weight due to size mismatch or it not being present in the checkpoint.
Skipping loading parameter head.bias due to size mismatch or it not being present in the checkpoint.
Pretrained weights found at /home/ubuntu/checkpoints/camelyon16_224_10x/vim-t_224-96/checkpoint.pth and loaded with msg: _IncompatibleKeys(missing_keys=['head.weight', 'head.bias'], unexpected_keys=['head.mlp.0.weight', 'head.mlp.0.bias', 'head.mlp.2.weight', 'head.mlp.2.bias', 'head.mlp.4.weight', 'head.mlp.4.bias', 'head.last_layer.weight_g', 'head.last_layer.weight_v'])
EMBEDDED DIM: 384
Take key teacher in provided checkpoint dict
Skipping loading parameter head.weight due to size mismatch or it not being present in the checkpoint.
Skipping loading parameter head.bias due to size mismatch or it not being present in the checkpoint.
Pretrained weights found at /home/ubuntu/checkpoints/camelyon16_224_10x/vim-t-plus_224-96/

In [11]:
N_SCALING_IMAGES_PER_CLASS = 120

# One GradCAM per Vim size, built once (each registers hooks on its target layer) and reused
# for every image. Do not use `model_vim` as a loop variable here: that would overwrite the
# global `model_vim` model.
scaling_cams = [
    (model_name, GradCAM(model=vim_model, target_layers=[vim_model.layers[-1].drop_path],
                         reshape_transform=reshape_transform_vim))
    for model_name, vim_model in [('Vim-ti', model_vim_ti),
                                  ('Vim-ti-plus', model_vim_ti_plus),
                                  ('Vim-s', model_vim_s)]
]

try:
    for class_name in ['tumor', 'normal']:
        img_paths = os_sorted(glob(os.path.join(dataset_dir, class_name, "*jpg")))
        os.makedirs(f'heatmaps/heatmaps_scaling/{class_name}', exist_ok=True)

        # Never index past the available images of a class.
        for i in tqdm(range(min(N_SCALING_IMAGES_PER_CLASS, len(img_paths)))):
            img = Image.open(img_paths[i])
            img_transformed = val_transform(img).unsqueeze(0)
            img_show = img_transformed.cpu().squeeze().permute(1, 2, 0).numpy()
            img_show = (img_show - img_show.min()) / (img_show.max() - img_show.min())  # Normalize to [0,1]

            overlays = [
                (model_name,
                 show_cam_on_image(img_show, cam(input_tensor=img_transformed)[0, :], use_rgb=True))
                for model_name, cam in scaling_cams
            ]

            # One row: the original image followed by one heatmap per model size.
            fig, axes = plt.subplots(1, 1 + len(overlays), figsize=(35, 20))
            axes[0].imshow(img_show)
            axes[0].set_title('Original Image', fontsize=40)
            for ax, (model_name, cam_image) in zip(axes[1:], overlays):
                ax.imshow(cam_image)
                ax.set_title(model_name, fontsize=40)
            for ax in axes:
                ax.axis('off')

            # Lay out once, after every panel has been added.
            fig.tight_layout()
            fig.savefig(f'heatmaps/heatmaps_scaling/{class_name}/{i}.jpg', bbox_inches='tight', dpi=200)
            plt.close(fig)
finally:
    # Detach the Grad-CAM hooks so they stop firing on later forward passes of these models.
    for model_name, cam in scaling_cams:
        cam.activations_and_grads.release()

100%|█████████████████████████████████████████| 120/120 [01:02<00:00,  1.93it/s]
100%|█████████████████████████████████████████| 120/120 [01:02<00:00,  1.91it/s]


# All models

In [12]:
import openslide
import h5py
from shapely.affinity import scale
from shapely.geometry import Polygon, box, mapping
from lxml import etree

In [13]:
def validate_and_correct_polygon(polygon):
    """Return ``polygon`` if valid, otherwise a repaired version of it.

    Repairs are tried in order: a zero-width ``buffer(0)``, then
    ``simplify(1.0, preserve_topology=True)`` applied to the original polygon.

    Raises:
        ValueError: if neither repair yields a valid geometry.
    """
    if polygon.is_valid:
        return polygon

    repaired = polygon.buffer(0)
    if repaired.is_valid:
        return repaired

    # Last resort; complex cases may still need manual review.
    simplified = polygon.simplify(1.0, preserve_topology=True)
    if not simplified.is_valid:
        raise ValueError("Polygon could not be corrected and remains invalid.")
    return simplified

def xml_to_polygons(xml_path):
    """Parse an annotation XML file into validated polygons.

    Every ``Annotation`` element contributes the ``X``/``Y`` attributes of its
    ``Coordinate`` descendants as one polygon, which is then passed through
    ``validate_and_correct_polygon``.

    NOTE(review): annotations with 3 or fewer points are skipped, so plain triangles are
    dropped. Keep this as is if the patch extraction used the same rule, since patch
    indices depend on polygon order — confirm.
    """
    document = etree.parse(xml_path)
    point_lists = [
        [[float(coord.get('X')), float(coord.get('Y'))] for coord in node.xpath('.//Coordinate')]
        for node in document.xpath('.//Annotation')
    ]
    return [
        validate_and_correct_polygon(Polygon(points))
        for points in point_lists
        if len(points) > 3
    ]

def polygons_to_patches(polygons, slide_dimension, patch_level, overlap = 0.1):
    """Enumerate square candidate patches over each polygon's bounding box.

    For every polygon, patch origins step through its integer bounds with stride
    ``int(slide_dimension * (1 - overlap)) * 2**patch_level``. Each patch is a
    ``slide_dimension``-sided box, and the share of that box covered by the polygon is
    recorded as a percentage.

    NOTE(review): the stride is scaled by ``2**patch_level`` but the box size is not, so the
    two appear to be in different units. Keep as is if this mirrors the extraction
    pipeline, since patch indices depend on this enumeration order — confirm.

    Returns:
        List of dicts with keys ``'coordinates'`` (x, y origin), ``'percentage_inside'``,
        ``'target_polygon'`` and ``'intersection'``, ordered by polygon, then x, then y.
    """
    stride = int(slide_dimension * (1 - overlap)) * (2 ** patch_level)
    box_area = slide_dimension ** 2
    records = []
    for polygon in polygons:
        x_lo, y_lo, x_hi, y_hi = (int(v) for v in polygon.bounds)
        x_starts = np.arange(x_lo, x_hi - slide_dimension + 1, stride)
        y_starts = np.arange(y_lo, y_hi - slide_dimension + 1, stride)
        for x0 in x_starts:
            for y0 in y_starts:
                tile = box(x0, y0, int(x0) + slide_dimension, int(y0) + slide_dimension)
                covered = polygon.intersection(tile)
                records.append({
                    'coordinates': (x0, y0),
                    'percentage_inside': (covered.area / box_area) * 100,
                    'target_polygon': polygon,
                    'intersection': covered,
                })
    return records

In [14]:
def get_high_res_img(img_path,
                     slide_dir='/home/ubuntu/Downloads/Camelyon16/testing/images',
                     h5_dir='dataset/Camelyon16/testing/224_10x/h5/images/patches',
                     annotation_dir='/home/ubuntu/Downloads/Camelyon16/testing/lesion_annotations',
                     tumor_threshold=50):
    """Re-read a tumor test patch from its whole-slide image at a finer pyramid level.

    The patch file name is expected to look like ``<wsi_name>_<tumor_idx>.<ext>``.
    ``tumor_idx`` selects the ``tumor_idx``-th (0-based) candidate from
    ``polygons_to_patches`` whose lesion overlap exceeds ``tumor_threshold`` percent.

    Args:
        img_path: path of the extracted patch image.
        slide_dir: directory holding ``<wsi_name>.tif`` slides.
        h5_dir: directory holding ``<wsi_name>.h5`` files. Their ``coords`` attrs provide
            ``patch_level`` and ``patch_size``.
        annotation_dir: directory holding ``<wsi_name>.xml`` lesion annotations.
        tumor_threshold: minimum percentage of a patch inside a lesion for it to count.

    Returns:
        RGB PIL image of a ``patch_size * 4`` square region read at level ``patch_level // 4``.

    Raises:
        ValueError: if the slide has fewer than ``tumor_idx + 1`` qualifying patches.
    """
    file_name = os.path.splitext(os.path.basename(img_path))[0]
    # Everything before the last '_' is the slide name; the suffix is the tumor-patch index.
    wsi_name, _sep, idx_text = file_name.rpartition('_')
    tumor_idx = int(idx_text)

    with h5py.File(os.path.join(h5_dir, f'{wsi_name}.h5'), 'r') as hdf5_file:
        patch_level = hdf5_file['coords'].attrs['patch_level']
        patch_size = hdf5_file['coords'].attrs['patch_size']

    polygons = xml_to_polygons(os.path.join(annotation_dir, f'{wsi_name}.xml'))
    patches_info = polygons_to_patches(polygons, slide_dimension=patch_size, patch_level=patch_level)

    # The index counts only patches lying mostly inside a lesion, in enumeration order.
    qualifying = [info for info in patches_info if info['percentage_inside'] > tumor_threshold]
    if not 0 <= tumor_idx < len(qualifying):
        # Fail loudly rather than returning an unrelated patch.
        raise ValueError(
            f'{img_path}: requested tumor patch #{tumor_idx} but only {len(qualifying)} '
            f'patches exceed {tumor_threshold}% lesion overlap')
    x, y = qualifying[tumor_idx]['coordinates']

    wsi = openslide.open_slide(os.path.join(slide_dir, f'{wsi_name}.tif'))
    try:
        # NOTE(review): reading a patch_size*4 window at level patch_level // 4 covers the
        # patch footprint at full resolution only when patch_level == 2 on a 2x-per-level
        # pyramid — confirm before reusing with other levels.
        region = wsi.read_region((int(x), int(y)), patch_level // 4,
                                 (int(patch_size) * 4, int(patch_size) * 4)).convert('RGB')
    finally:
        wsi.close()
    return region

First slide only (the first 60 patches of each class, in natural sort order)

In [16]:
models = {
    'Vim-ti':(model_vim_ti, model_vim_ti.layers[-1].drop_path),
    'Vim-ti-plus':(model_vim_ti_plus, model_vim_ti_plus.layers[-1].drop_path),
    'Vim-s':(model_vim_s, model_vim_s.layers[-1].drop_path),
    'ViT-ti':(model_vit_ti, model_vit_ti.blocks[-1].norm1),
    'ViT-s':(model_vit_s, model_vit_s.blocks[-1].norm1)
}
N_SINGLE_IMAGES_PER_CLASS = 60

# One GradCAM per model, built once and reused. Re-creating it per image would re-register
# hooks each time. Vision Mamba models need the Vim token reshape; the rest use the ViT one.
cams = {
    model_name: GradCAM(
        model=model,
        target_layers=[target_layer],
        reshape_transform=(reshape_transform_vim
                           if 'mamba' in model.__class__.__name__.lower()
                           else reshape_transform_vit),
    )
    for model_name, (model, target_layer) in models.items()
}

try:
    for class_name in ['tumor', 'normal']:
        img_paths = os_sorted(glob(os.path.join(dataset_dir, class_name, "*jpg")))
        out_dir = f'heatmaps/heatmaps_single/{class_name}'
        os.makedirs(f'{out_dir}/raw', exist_ok=True)  # also creates out_dir

        # Never index past the available images of a class.
        for i in tqdm(range(min(N_SINGLE_IMAGES_PER_CLASS, len(img_paths)))):
            img = Image.open(img_paths[i])
            img_transformed = val_transform(img).unsqueeze(0)
            img_name = os.path.splitext(os.path.basename(img_paths[i]))[0]

            # Tumor patches are displayed using the finer-resolution read from the slide.
            final_img = get_high_res_img(img_paths[i]) if class_name == 'tumor' else img
            final_img.save(f'{out_dir}/raw/{img_name}_orig.png')
            final_img = np.array(final_img) / 255
            # cv2.resize takes dsize as (width, height), the reverse of numpy's (rows, cols).
            cam_dsize = (final_img.shape[1], final_img.shape[0])

            fig = plt.figure(figsize=(50, 23))
            gs = gridspec.GridSpec(2, 5, figure=fig)

            ax0 = fig.add_subplot(gs[0:2, 0:2])
            ax0.imshow(final_img)
            ax0.set_title('Original Image', fontsize=40)
            ax0.axis('off')

            for idx, (model_name, cam) in enumerate(cams.items()):
                grayscale_cam = cv2.resize(cam(input_tensor=img_transformed)[0, :], cam_dsize)
                cam_image = show_cam_on_image(final_img, grayscale_cam, use_rgb=True)
                Image.fromarray(cam_image).save(f'{out_dir}/raw/{img_name}_{model_name}.png')

                # Heatmaps fill the 2x3 block to the right of the original image.
                ax = fig.add_subplot(gs[idx // 3, idx % 3 + 2])
                ax.imshow(cam_image)
                ax.set_title(f'{model_name} Heatmap', fontsize=40)
                ax.axis('off')

            fig.tight_layout()
            fig.savefig(f'{out_dir}/{img_name}.jpg', bbox_inches='tight', dpi=200)
            plt.close(fig)
finally:
    # Detach the Grad-CAM hooks so they stop firing on later forward passes of these models.
    for cam in cams.values():
        cam.activations_and_grads.release()

100%|███████████████████████████████████████████| 60/60 [04:00<00:00,  4.00s/it]
100%|███████████████████████████████████████████| 60/60 [01:43<00:00,  1.72s/it]


Random samples (60 randomly drawn patches per class)

In [17]:
models = {
    'Vim-ti':(model_vim_ti, model_vim_ti.layers[-1].drop_path),
    'Vim-ti-plus':(model_vim_ti_plus, model_vim_ti_plus.layers[-1].drop_path),
    'Vim-s':(model_vim_s, model_vim_s.layers[-1].drop_path),
    'ViT-ti':(model_vit_ti, model_vit_ti.blocks[-1].norm1),
    'ViT-s':(model_vit_s, model_vit_s.blocks[-1].norm1)
}
RANDOM_SEED = 0
N_RANDOM_IMAGES_PER_CLASS = 60
sample_rng = np.random.default_rng(RANDOM_SEED)  # reproducible sampling

# One GradCAM per model, built once and reused. Re-creating it per image would re-register
# hooks each time. Vision Mamba models need the Vim token reshape; the rest use the ViT one.
diverse_cams = {
    model_name: GradCAM(
        model=model,
        target_layers=[target_layer],
        reshape_transform=(reshape_transform_vim
                           if 'mamba' in model.__class__.__name__.lower()
                           else reshape_transform_vit),
    )
    for model_name, (model, target_layer) in models.items()
}

try:
    for class_name in ['tumor', 'normal']:
        img_paths = os_sorted(glob(os.path.join(dataset_dir, class_name, "*jpg")))
        # Draw without replacement so no patch is processed (and its output overwritten) twice.
        n_draw = min(N_RANDOM_IMAGES_PER_CLASS, len(img_paths))
        target_image_idx = np.sort(sample_rng.choice(len(img_paths), size=n_draw, replace=False))

        out_dir = f'heatmaps/heatmaps_diverse/{class_name}'
        os.makedirs(f'{out_dir}/raw', exist_ok=True)  # also creates out_dir

        for i in tqdm(target_image_idx):
            img = Image.open(img_paths[i])
            img_transformed = val_transform(img).unsqueeze(0)
            img_name = os.path.splitext(os.path.basename(img_paths[i]))[0]

            # Tumor patches are displayed using the finer-resolution read from the slide.
            final_img = get_high_res_img(img_paths[i]) if class_name == 'tumor' else img
            final_img.save(f'{out_dir}/raw/{img_name}_orig.png')
            final_img = np.array(final_img) / 255
            # cv2.resize takes dsize as (width, height), the reverse of numpy's (rows, cols).
            cam_dsize = (final_img.shape[1], final_img.shape[0])

            fig = plt.figure(figsize=(50, 23))
            gs = gridspec.GridSpec(2, 5, figure=fig)

            ax0 = fig.add_subplot(gs[0:2, 0:2])
            ax0.imshow(final_img)
            ax0.set_title('Original Image', fontsize=40)
            ax0.axis('off')

            for idx, (model_name, cam) in enumerate(diverse_cams.items()):
                grayscale_cam = cv2.resize(cam(input_tensor=img_transformed)[0, :], cam_dsize)
                cam_image = show_cam_on_image(final_img, grayscale_cam, use_rgb=True)
                Image.fromarray(cam_image).save(f'{out_dir}/raw/{img_name}_{model_name}.png')

                # Heatmaps fill the 2x3 block to the right of the original image.
                ax = fig.add_subplot(gs[idx // 3, idx % 3 + 2])
                ax.imshow(cam_image)
                ax.set_title(f'{model_name} Heatmap', fontsize=40)
                ax.axis('off')

            fig.tight_layout()
            fig.savefig(f'{out_dir}/{img_name}.jpg', bbox_inches='tight', dpi=200)
            plt.close(fig)
finally:
    # Detach the Grad-CAM hooks so they stop firing on later forward passes of these models.
    for cam in diverse_cams.values():
        cam.activations_and_grads.release()

100%|███████████████████████████████████████████| 60/60 [09:57<00:00,  9.96s/it]
100%|███████████████████████████████████████████| 60/60 [01:43<00:00,  1.73s/it]
