In [None]:
import os, sys

# Make the project root and its AttentionMap / Sparse module folders importable.
# Paths are appended in this order and only when not already on sys.path.
project_dir = os.path.join(os.getcwd(), '../..')
attention_dir = os.path.join(project_dir, 'modules/AttentionMap')
sparse_dir = os.path.join(project_dir, 'modules/Sparse')

for _extra_path in (project_dir, attention_dir, sparse_dir):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)
del _extra_path

import numpy as np
import torch, config
from torch import nn
import os
import pandas as pd

In [None]:
from torchvision.transforms import Compose, ToTensor, ToPILImage, Normalize, Resize
from derma.utils import UnNormalize
from derma.dataset import Derma
from torch.utils.data import DataLoader

# ImageNet channel statistics, used both to normalise inputs and to undo it for display.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Test-time preprocessing: fixed 256x256 resize, tensor conversion, normalisation.
transform = Compose([Resize((256, 256)), ToTensor(), Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)])

# Inverse pipeline: de-normalise a tensor and turn it back into a PIL image.
to_pil = Compose([UnNormalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), ToPILImage()])

# Unshuffled loader over the test split, so batches come out in dataset order.
dataset_dir = os.path.join(config.DATASET_DIR, 'test')
dataset = Derma(dataset_dir, transform=transform)
loader = DataLoader(dataset, batch_size=16, shuffle=False, num_workers=0)

# GradCAM

Generar un DataFrame con el nombre de cada fichero (`<carpeta de clase>/<fichero>`), indexado por la posición de la imagen en el conjunto de test.

In [None]:
# Relative path "<class folder>/<file name>" of every test image, in `dataset.x` order.
# basename(dirname(...)) keeps the whole folder name, not just its last character.
# Row i is meant to name the image behind row i of the GradCAM tensors; this assumes
# Derma serves samples in `dataset.x` order (the loader uses shuffle=False) — confirm.
result = [os.path.join(os.path.basename(os.path.dirname(path)), os.path.basename(path))
          for path in dataset.x]

df = pd.DataFrame(result)

In [None]:
from derma.doc.utils import GradCamAttribute

def ObtainGradCam(model, loader, target_layer=None):
    """Compute GradCAM attribution maps for every batch yielded by ``loader``.

    The model is put in eval mode and moved to the CPU (both in place) before
    attributions are computed with ``GradCamAttribute``.

    Args:
        model: network to explain. When ``target_layer`` is None it must expose
            ``model.features[-1][0]``, as the MobileNetV2 models built in this
            notebook do; that layer is then used for GradCAM.
        loader: iterable of ``(inputs, targets)`` batches. Batches are processed
            in order, so row ``i`` of the result belongs to the ``i``-th sample
            yielded by the loader.
        target_layer: layer passed to ``GradCamAttribute``. Defaults to
            ``model.features[-1][0]``.

    Returns:
        The per-batch attributions averaged over dim 1, made absolute, detached
        and concatenated along dim 0; ``None`` if the loader yields no batches.
    """
    model.eval()
    model.cpu()

    if target_layer is None:
        target_layer = model.features[-1][0]

    batch_maps = []
    for inputs, targets in loader:
        # The attribution method needs gradients with respect to the inputs.
        inputs.requires_grad = True

        attribution = GradCamAttribute(model, target_layer, inputs, targets)
        # Average over dim 1 (presumably channels — confirm against GradCamAttribute)
        # and keep the magnitude: abs() turns negative contributions positive,
        # it does not discard them.
        batch_maps.append(attribution.mean(dim=1).abs().detach())

    # Concatenate once instead of re-growing the result tensor on every batch.
    return torch.cat(batch_maps, dim=0) if batch_maps else None

Obtener los mapas GradCAM de cada modelo sobre el conjunto de test y almacenarlos en disco.

In [None]:
from derma.architecture import InvertedResidual
from torchvision.models import MobileNetV2
from derma.doc.utils import summary

# For every configured experiment:
# - rebuild its MobileNetV2 and load the trained HAM10000 weights;
# - compute GradCAM maps on the test loader;
# - save the maps next to the index -> image-name table (`df`), so that maps can be
#   matched back to images later.
for model_name, net_config in config.experiment.items():
    weights_path = os.path.join(config.RESULT_DIR, 'weights/classification/{}/HAM10000/model.pth'.format(model_name))
    # block=None lets torchvision use its default InvertedResidual block.
    model = MobileNetV2(num_classes=2, inverted_residual_setting=net_config['inverted_residual_setting'],
                        block=InvertedResidual if net_config['attention'] else None)

    # ObtainGradCam runs on the CPU, so load the weights there too; this also
    # allows weights saved from a GPU run to be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(weights_path, map_location='cpu'))

    grad_cam_save_dir = os.path.join(config.RESULT_DIR, 'grad_cam/HAM10000/{}/'.format(model_name))
    os.makedirs(grad_cam_save_dir, exist_ok=True)

    grad_cam = ObtainGradCam(model, loader)
    torch.save(grad_cam, os.path.join(grad_cam_save_dir, 'grad_cam.pth'))
    df.to_csv(os.path.join(grad_cam_save_dir, 'images_name.csv'), index=True, header=None)


# Generate images

In [None]:
# Root folder of the per-model GradCAM outputs, and the test-split image folder.
grad_cam_result_dir, dataset_dir = (
    os.path.join(config.RESULT_DIR, 'grad_cam/HAM10000/'),
    os.path.join(config.DATASET_DIR, 'test'),
)

In [None]:
from PIL import Image
from derma.doc.utils import plot_attribution
from matplotlib import pyplot as plt

# Render every stored GradCAM map over its source image and save the figure as
# <grad_cam_result_dir>/<model_name>/imgs/<class folder>/<image name>.pdf.
for model_name in config.experiment:
    grad_cam_model_dir = os.path.join(grad_cam_result_dir, model_name)
    images_name = pd.read_csv(os.path.join(grad_cam_model_dir, 'images_name.csv'), header=None, index_col=0)
    gradcam_att = torch.load(os.path.join(grad_cam_model_dir, 'grad_cam.pth'))
    # Images and maps are paired purely by row position, so the counts must agree.
    assert len(images_name) == len(gradcam_att), \
        'images_name.csv and grad_cam.pth disagree for {}'.format(model_name)

    img_save_dir = os.path.join(grad_cam_model_dir, 'imgs')

    for idx, filename in enumerate(images_name.iloc[:, 0]):
        target, name = os.path.split(filename)
        name = os.path.splitext(name)[0]
        full_path_filename = os.path.join(dataset_dir, filename)

        # The context manager closes the file handle; convert() already loads the pixels.
        with Image.open(full_path_filename) as src:
            img = src.convert('RGB').resize((256, 256))
        att = gradcam_att[idx].numpy()

        # The per-class output folders are not created anywhere else in this notebook;
        # without this, savefig fails whenever the folder does not exist yet.
        current_save_dir = os.path.join(img_save_dir, target)
        os.makedirs(current_save_dir, exist_ok=True)

        fig = None
        try:
            fig = plot_attribution(att, img)
            fig.savefig(os.path.join(current_save_dir, '{}.pdf'.format(name)), bbox_inches='tight', pad_inches=0)
        except Exception as err:
            # Best effort: report the failing image with the cause and keep going.
            print("An exception occurred: {} ({})".format(filename, err))
        finally:
            # Close only a figure created in this iteration (plot_attribution may have failed).
            if fig is not None:
                plt.close(fig)
