In [None]:
# Show GPU status (devices, memory usage, running processes) via the nvidia-smi CLI.
!nvidia-smi

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import skimage.transform as st
import os
import gc
import warnings
import pickle
from transformation import *
from models import *
from load_data import *
from GradCAM_Salency_map import *
 
print(tf.__version__)

# NOTE: this silences *all* Python warnings for the rest of the session,
# including deprecation warnings from TensorFlow / Matplotlib.
warnings.filterwarnings("ignore")

# Restrict TensorFlow to the first GPU and let it grow its memory allocation
# on demand. Both calls must happen before any op initialises the GPU runtime.
gpus = tf.config.list_physical_devices(device_type='GPU')
if gpus:
    tf.config.set_visible_devices(devices=gpus[0], device_type='GPU')
    tf.config.experimental.set_memory_growth(gpus[0], True)
else:
    # Indexing gpus[0] on an empty list would raise IndexError here.
    print('No GPU found; running on CPU.')

In [None]:
# Global seeds so weight initialisation / sampling are repeatable between runs.
seed = 2021
np.random.seed(seed)
tf.random.set_seed(seed)
# NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it here only
# affects subprocesses launched later, not this kernel's own hash seed.
os.environ['PYTHONHASHSEED'] = str(seed)

## Saliency map

In [None]:
# Output order of the disease classifiers' label vector.
Labels_diseases = [
    'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema',
    'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion', 'Lung Opacity',
    'No Finding', 'Pleural Effusion', 'Pleural Other', 'Pneumonia',
    'Pneumothorax', 'Support Devices',
]
group_name = {
    'race': ['white', 'black', 'asian'],
    'gender': ['male', 'female'],
    'age': ['0_40', '40_60', '60_80', '80+'],
}
# Integer class id of every demographic group, positionally parallel to group_name.
group_class = {attribute: list(range(len(names))) for attribute, names in group_name.items()}
# Indices into Labels_diseases that are analysed
# (Fracture, Lung Lesion, Pleural Other and Support Devices are skipped).
target_label = [0, 1, 2, 3, 4, 7, 8, 9, 11, 12]

def compute_smap(img, model, target_class=None):
    """Vanilla-gradient saliency map of ``model`` with respect to one image.

    Args:
        img: a single image. It is reshaped to ``[1, 224, 224]`` before being
            passed to the model, so it must contain exactly 224*224 values.
        model: callable accepting ``(inputs, training=False)``, e.g. a Keras model.
        target_class: optional index into the first batch element's output.
            When given, the gradient of that single output is returned. When
            ``None`` (the original behaviour), ``tape.gradient`` implicitly sums
            all outputs of the first batch element.

    Returns:
        The signed input gradient for the image, with the batch axis removed.
    """
    image = tf.reshape(img, [1, 224, 224])

    with tf.GradientTape() as tape:
        # `image` is a plain tensor, not a Variable, so it must be watched explicitly.
        tape.watch(image)
        outputs = model(image, training=False)[0]
        # Index inside the tape so the selection is recorded for differentiation.
        loss = outputs if target_class is None else outputs[int(target_class)]

    grads = tape.gradient(loss, image)

    # Called once per image in long loops; reclaim Python garbage between calls.
    gc.collect()

    return grads[0]

In [None]:
def make_gradcam_heatmap(img_array, last_conv_layer_model, classifier_model, target_class=None):
    """Grad-CAM heatmap for one image, given a model split at its last conv layer.

    Args:
        img_array: batched input accepted by ``last_conv_layer_model``. Only the
            first batch element's feature map is used for the heatmap.
        last_conv_layer_model: maps the input to the last conv feature map.
        classifier_model: maps that feature map to class scores.
        target_class: class index to explain. Defaults to the arg-max of the
            first sample's predictions.

    Returns:
        2-D numpy array over the feature map's spatial grid with values in
        ``[0, 1]``. If no location has a positive weighted activation, an
        all-zero map is returned instead of dividing by a zero/negative maximum.
    """
    with tf.GradientTape() as tape:
        # Intermediate tensors are not watched automatically.
        last_conv_layer_output = last_conv_layer_model(img_array)
        tape.watch(last_conv_layer_output)
        preds = classifier_model(last_conv_layer_output)

        if target_class is not None:
            top_pred_index = tf.constant(target_class)
        else:
            top_pred_index = tf.argmax(preds[0])

        top_class_channel = preds[:, top_pred_index]

    # Gradient of the chosen class score w.r.t. the conv feature map.
    grads = tape.gradient(top_class_channel, last_conv_layer_output)

    # One importance weight per channel: mean gradient over batch and spatial axes.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2)).numpy()

    # Weight each channel of the first sample's feature map (broadcast over the
    # last axis), then average across channels.
    feature_map = last_conv_layer_output.numpy()[0]
    heatmap = np.mean(feature_map * pooled_grads, axis=-1)

    # Keep positive evidence only and scale to [0, 1] for display.
    peak = np.max(heatmap)
    heatmap = np.maximum(heatmap, 0)
    if peak > 0:
        heatmap = heatmap / peak
    return heatmap

def show_heatmap(img_array, last_conv_layer_model, classifier_model, target_class=None, size=(224, 224)):
    """Colourise a Grad-CAM heatmap with the 'jet' colormap.

    Args:
        img_array, last_conv_layer_model, classifier_model, target_class:
            forwarded unchanged to ``make_gradcam_heatmap``.
        size: ``(width, height)`` passed to the PIL image's ``resize``.
            Defaults to the previous hard-coded 224x224.

    Returns:
        The resized RGB heatmap converted back to an array by ``img_to_array``.
    """
    heatmap = make_gradcam_heatmap(img_array, last_conv_layer_model, classifier_model, target_class)

    # Quantise [0, 1] to 0..255 so it can index a 256-entry colour table.
    heatmap = np.uint8(255 * heatmap)

    # `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9, and `cm` is not
    # imported explicitly by this notebook; `plt.get_cmap` is available everywhere.
    jet = plt.get_cmap("jet")
    jet_colors = jet(np.arange(256))[:, :3]  # drop the alpha channel
    jet_heatmap = jet_colors[heatmap]

    # Round-trip through a PIL image to resize to the display resolution.
    jet_heatmap = tf.keras.preprocessing.image.array_to_img(jet_heatmap)
    jet_heatmap = jet_heatmap.resize(size)
    jet_heatmap = tf.keras.preprocessing.image.img_to_array(jet_heatmap)

    return jet_heatmap

def grad_cam_plus(img_array, last_conv_layer_model, classifier_model, target_class=None, size=(224, 224)):
    """Colourised Grad-CAM++ heatmap for one image.

    Uses first-, second- and third-order gradients of the target class score
    with respect to the last conv feature map to build pixel-wise weights.

    Args:
        img_array: batched input for ``last_conv_layer_model``. Only the first
            batch element contributes to the map.
        last_conv_layer_model: maps the input to the last conv feature map.
        classifier_model: maps that feature map to class scores.
        target_class: class index to explain. Defaults to the arg-max of the
            first sample's predictions.
        size: ``(width, height)`` passed to the PIL image's ``resize``.
            Defaults to the previous hard-coded 224x224.

    Returns:
        RGB 'jet'-coloured heatmap (min-max normalised before colouring),
        resized to ``size`` and converted back to an array by ``img_to_array``.
    """
    # Nested tapes: each outer tape records the gradient computed inside it,
    # which lets us differentiate 1st -> 2nd -> 3rd order w.r.t. the feature map.
    with tf.GradientTape() as gtape1:
        with tf.GradientTape() as gtape2:
            with tf.GradientTape() as gtape3:
                last_conv_layer_output = last_conv_layer_model(img_array)
                preds = classifier_model(last_conv_layer_output)

                if target_class is not None:
                    top_pred_index = tf.constant(target_class)
                else:
                    top_pred_index = tf.argmax(preds[0])

                top_class_channel = preds[:, top_pred_index]

            conv_first_grad = gtape3.gradient(top_class_channel, last_conv_layer_output)

        conv_second_grad = gtape2.gradient(conv_first_grad, last_conv_layer_output)

    conv_third_grad = gtape1.gradient(conv_second_grad, last_conv_layer_output)

    # Per-channel sum of activations over batch and spatial axes.
    global_sum = np.sum(last_conv_layer_output, axis=(0, 1, 2))

    # alpha = g2 / (2*g2 + sum(A) * g3); zero denominators are replaced by a tiny value.
    alpha_num = conv_second_grad[0]
    alpha_denom = conv_second_grad[0] * 2.0 + conv_third_grad[0] * global_sum
    alpha_denom = np.where(alpha_denom != 0.0, alpha_denom, 1e-10)

    # Normalise alphas per channel over the spatial grid (avoid dividing by 0).
    alphas = alpha_num / alpha_denom
    alpha_normalization_constant = np.sum(alphas, axis=(0, 1))
    alpha_normalization_constant = np.where(alpha_normalization_constant != 0.0, alpha_normalization_constant, 1)
    alphas /= alpha_normalization_constant

    # Only positive first-order gradients contribute.
    weights = np.maximum(conv_first_grad[0], 0.0)

    deep_linearization_weights = np.sum(weights * alphas, axis=(0, 1))
    grad_cam_map = np.sum(deep_linearization_weights * last_conv_layer_output[0], axis=2)

    # Min-max normalise to [0, 1]; the epsilon avoids 0/0 on a constant map.
    arr_min, arr_max = np.min(grad_cam_map), np.max(grad_cam_map)
    heatmap = (grad_cam_map - arr_min) / (arr_max - arr_min + 1e-18)
    heatmap = np.uint8(255 * heatmap)

    # `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9, and `cm` is not
    # imported explicitly by this notebook; `plt.get_cmap` is available everywhere.
    jet = plt.get_cmap("jet")
    jet_colors = jet(np.arange(256))[:, :3]  # drop the alpha channel
    jet_heatmap = jet_colors[heatmap]

    # Round-trip through a PIL image to resize to the display resolution.
    jet_heatmap = tf.keras.preprocessing.image.array_to_img(jet_heatmap)
    jet_heatmap = jet_heatmap.resize(size)
    jet_heatmap = tf.keras.preprocessing.image.img_to_array(jet_heatmap)

    return jet_heatmap

In [None]:
dataset = 'mimic'
task = 'disease'

# Test split plus the demographic metadata `demo`, which later cells index by
# attribute position (race / gender / age).
# NOTE: a cell near the end of the notebook redefines `get_data` with a
# different signature; re-running this cell after that one will fail.
X_test, y_test, demo = get_data(
    dataset=dataset,
    task=task,
    data_split='test',
    aug_method='',
    return_demo=True,
)


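## Extract valid CXR indices

For each demographic attribute, keep the test images whose group the baseline demographic model predicts correctly with top-decile confidence but the proposed model misclassifies. Per disease, further restrict them to true positives that both disease models detect.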

In [None]:
Labels_diseases = ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion', 'Lung Opacity', 'No Finding', 'Pleural Effusion', 'Pleural Other', 'Pneumonia', 'Pneumothorax', 'Support Devices']

target_label = [0, 1, 2, 3, 4, 7, 8, 9, 11, 12]

# Disease-model outputs on the original test images: baseline (suffix 0) and
# proposed (suffix 1), each with its per-label decision thresholds.
best_thresh_0 = np.loadtxt('thresh/model_densenet_mimic_ERM_thresh.txt')
prediction_name = 'predictions/model_densenet_mimic_ERM_on_original'
with open(prediction_name, "rb") as fp:
    y_preds_0 = pickle.load(fp)

best_thresh_1 = np.loadtxt('thresh/model_densenet_mimic_ERM_proposed_thresh.txt')
prediction_name = 'predictions/model_densenet_mimic_ERM_proposed_on_original'
with open(prediction_name, "rb") as fp:
    y_preds_1 = pickle.load(fp)

for demo_idx, demo_group in enumerate(['race', 'gender', 'age']):
    # Demographic-model outputs (baseline, then proposed) for this attribute.
    prediction_name = 'predictions/model_densenet_mimic_ERM_{demo_group}_on_original'.format(demo_group=demo_group)
    with open(prediction_name, "rb") as fp:
        y_preds_0_demo = pickle.load(fp)

    prediction_name = 'predictions/model_densenet_mimic_ERM_{demo_group}_proposed_on_original'.format(demo_group=demo_group)
    with open(prediction_name, "rb") as fp:
        y_preds_1_demo = pickle.load(fp)

    # Per-group confidence cut-off: 90th percentile of the baseline model's score column.
    demo_criteria = [np.percentile(y_preds_0_demo[:, c], 90) for c in range(len(group_class[demo_group]))]

    # Images whose group the baseline recognises with top-decile confidence
    # while the proposed model predicts a different group.
    selected = []
    for i in range(len(y_preds_0)):
        baseline_top = np.argmax(y_preds_0_demo[i])
        baseline_hit = (baseline_top == demo[i][demo_idx]
                        and y_preds_0_demo[i][baseline_top] >= demo_criteria[baseline_top])
        proposed_miss = np.argmax(y_preds_1_demo[i]) != demo[i][demo_idx]
        if baseline_hit and proposed_miss:
            selected.append(i)
    indice = np.array(selected)

    for label_idx in target_label:
        label_name = Labels_diseases[label_idx]

        directory = 'imgs/{label}/'.format(label=label_name)
        if not os.path.exists(directory):
            os.makedirs(directory)

        # Keep only true positives that both disease models also flag.
        valid_idx = [
            idx for idx in indice
            if y_test[idx][label_idx] == 1
            and y_preds_0[idx][label_idx] > best_thresh_0[label_idx]
            and y_preds_1[idx][label_idx] > best_thresh_1[label_idx]
        ]

        with open('imgs/{label}/{demo_group}_idx'.format(label=label_name, demo_group=demo_group), "wb") as fp:
            pickle.dump(np.array(valid_idx), fp)
            

In [None]:
Labels_diseases = ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion', 'Lung Opacity', 'No Finding', 'Pleural Effusion', 'Pleural Other', 'Pneumonia', 'Pneumothorax', 'Support Devices']

target_label = [0, 1, 2, 3, 4, 7, 8, 9, 11, 12]

# Same disease-model inputs as the previous cell (baseline = 0, proposed = 1).
best_thresh_0 = np.loadtxt('thresh/model_densenet_mimic_ERM_thresh.txt')
prediction_name = 'predictions/model_densenet_mimic_ERM_on_original'
with open(prediction_name, "rb") as fp:
    y_preds_0 = pickle.load(fp)

best_thresh_1 = np.loadtxt('thresh/model_densenet_mimic_ERM_proposed_thresh.txt')
prediction_name = 'predictions/model_densenet_mimic_ERM_proposed_on_original'
with open(prediction_name, "rb") as fp:
    y_preds_1 = pickle.load(fp)

for demo_idx, demo_group in enumerate(['race', 'gender', 'age']):
    # Outputs of the *task-transfer* demographic models for this attribute.
    prediction_name = 'predictions/model_densenet_mimic_ERM_task_transfer_{demo_group}_on_original'.format(demo_group=demo_group)
    with open(prediction_name, "rb") as fp:
        y_preds_0_demo = pickle.load(fp)

    prediction_name = 'predictions/model_densenet_mimic_ERM_task_transfer_{demo_group}_proposed_on_original'.format(demo_group=demo_group)
    with open(prediction_name, "rb") as fp:
        y_preds_1_demo = pickle.load(fp)

    # 90th-percentile confidence of the baseline model for each group.
    demo_criteria = [np.percentile(y_preds_0_demo[:, c], 90) for c in range(len(group_class[demo_group]))]

    # Baseline confidently right about the group, proposed model wrong.
    selected = []
    for i in range(len(y_preds_0)):
        baseline_top = np.argmax(y_preds_0_demo[i])
        baseline_hit = (baseline_top == demo[i][demo_idx]
                        and y_preds_0_demo[i][baseline_top] >= demo_criteria[baseline_top])
        proposed_miss = np.argmax(y_preds_1_demo[i]) != demo[i][demo_idx]
        if baseline_hit and proposed_miss:
            selected.append(i)
    indice = np.array(selected)

    for label_idx in target_label:
        label_name = Labels_diseases[label_idx]

        directory = 'imgs/{label}/'.format(label=label_name)
        if not os.path.exists(directory):
            os.makedirs(directory)

        # Keep only true positives detected by both disease models.
        valid_idx = [
            idx for idx in indice
            if y_test[idx][label_idx] == 1
            and y_preds_0[idx][label_idx] > best_thresh_0[label_idx]
            and y_preds_1[idx][label_idx] > best_thresh_1[label_idx]
        ]

        with open('imgs/{label}/task_transfer_{demo_group}_idx'.format(label=label_name, demo_group=demo_group), "wb") as fp:
            pickle.dump(np.array(valid_idx), fp)
            

## Compute saliency maps

In [None]:
def get_last_conv_model(model):
    """Split a wrapped DenseNet classifier at its last convolutional feature map.

    Assumes (from the indexing below) that ``model.layers`` is laid out as
    ``[input, upsampling, densenet, ..., prediction_head]``, with
    ``layers[2]`` the DenseNet sub-model. TODO confirm against ``get_model``.

    Returns:
        ``(last_conv_layer_model, classifier_model)`` where

        * ``last_conv_layer_model`` runs the input and upsampling layers and the
          DenseNet up to the output of ``densenet.layers[-4]``;
        * ``classifier_model`` replays the DenseNet's last three layers and
          ``model.layers[-1]`` on top of that feature map.
    """
    # Keras layer names are unique, so these are the same objects get_layer(name) returns.
    input_layer = model.layers[0]
    upsampling = model.layers[1]
    densenet121 = model.layers[2]

    # Feature extractor: everything up to the split point.
    split_layer = densenet121.layers[-4]
    last_conv_layer_model = tf.keras.Sequential()
    last_conv_layer_model.add(input_layer)
    last_conv_layer_model.add(upsampling)
    last_conv_layer_model.add(tf.keras.Model(densenet121.inputs, split_layer.output))

    # Classifier head. Its input is what densenet.layers[-3] consumes, i.e. the
    # split layer's *output* shape (layers[-3].output only matches when
    # layers[-3] happens to preserve shape).
    classifier_input = tf.keras.Input(shape=split_layer.output.shape[1:])
    x = classifier_input
    for layer in densenet121.layers[-3:]:
        x = layer(x)
    x = model.layers[-1](x)
    classifier_model = tf.keras.Model(classifier_input, x)

    return last_conv_layer_model, classifier_model

In [None]:
target_label = [0, 1, 2, 3, 4, 7, 8, 9, 11, 12]

# Baseline and proposed *disease* classifiers, restored from checkpoints and
# split into (feature extractor, classifier head) for Grad-CAM++.
model_name = 'model_densenet_mimic_ERM'
baseline_model = get_model('ERM', '', 'disease', 'densenet')
checkpoint = tf.train.Checkpoint(baseline_model)
manager = tf.train.CheckpointManager(checkpoint, directory='checkpoints/'+model_name, max_to_keep=1, checkpoint_name=model_name)
baseline_model.load_weights(manager.checkpoints[0])
baseline_model_last_conv, baseline_model_classifier = get_last_conv_model(baseline_model)

model_name = 'model_densenet_mimic_ERM_proposed'
proposed_model = get_model('ERM', '', 'disease', 'densenet')
checkpoint = tf.train.Checkpoint(proposed_model)
manager = tf.train.CheckpointManager(checkpoint, directory='checkpoints/'+model_name, max_to_keep=1, checkpoint_name=model_name)
proposed_model.load_weights(manager.checkpoints[0])
proposed_model_last_conv, proposed_model_classifier = get_last_conv_model(proposed_model)

for label_idx in target_label:
    label_name = Labels_diseases[label_idx]
    all_smap_baseline = []
    all_smap_proposed = []

    directory = 'imgs/{label}/mean_map_gradcampp/disease/'.format(label=label_name)
    if not os.path.exists(directory):
        os.makedirs(directory)

    # Pool the images selected for this label under every demographic attribute.
    for group in ['race', 'gender', 'age']:
        with open('imgs/{label}/{group}_idx'.format(label=label_name, group=group), "rb") as fp:
            indice = pickle.load(fp)

        for idx in indice:
            img = np.reshape(X_test[idx], (1, 224, 224, 1))
            # Explain the disease these maps are filed under (label_idx), not
            # whichever class happens to score highest for the image.
            all_smap_baseline.append(grad_cam_plus(img, baseline_model_last_conv, baseline_model_classifier, target_class=label_idx))
            all_smap_proposed.append(grad_cam_plus(img, proposed_model_last_conv, proposed_model_classifier, target_class=label_idx))

    if not all_smap_baseline:
        # np.mean of an empty list would silently store NaN.
        print('No selected images for {label}; skipping.'.format(label=label_name))
        continue

    mean_smap_baseline = np.mean(all_smap_baseline, axis=0)
    with open(directory + 'baseline_model', "wb") as fp:
        pickle.dump(mean_smap_baseline, fp)

    mean_smap_proposed = np.mean(all_smap_proposed, axis=0)
    with open(directory + 'proposed_model', "wb") as fp:
        pickle.dump(mean_smap_proposed, fp)
    
            

In [None]:
target_label = [0, 1, 2, 3, 4, 7, 8, 9, 11, 12]
group_idx = {'race': 0, 'gender': 1, 'age': 2}

for group in ['race', 'gender', 'age']:
    # Baseline / proposed classifiers trained on this demographic attribute.
    model_name = 'model_densenet_mimic_ERM_{group}'.format(group=group)
    baseline_model = get_model('ERM', '', group, 'densenet')
    checkpoint = tf.train.Checkpoint(baseline_model)
    manager = tf.train.CheckpointManager(checkpoint, directory='checkpoints/'+model_name, max_to_keep=1, checkpoint_name=model_name)
    baseline_model.load_weights(manager.checkpoints[0])
    # Unpack into the names the loop below uses. Assigning the returned tuple to
    # `baseline_model` left those names pointing at stale models from an earlier cell.
    baseline_model_last_conv, baseline_model_classifier = get_last_conv_model(baseline_model)

    model_name = 'model_densenet_mimic_ERM_{group}_proposed'.format(group=group)
    proposed_model = get_model('ERM', '', group, 'densenet')
    checkpoint = tf.train.Checkpoint(proposed_model)
    manager = tf.train.CheckpointManager(checkpoint, directory='checkpoints/'+model_name, max_to_keep=1, checkpoint_name=model_name)
    proposed_model.load_weights(manager.checkpoints[0])
    proposed_model_last_conv, proposed_model_classifier = get_last_conv_model(proposed_model)

    for label_idx in target_label:
        label_name = Labels_diseases[label_idx]
        all_smap_baseline = []
        all_smap_proposed = []

        directory = 'imgs/{label}/mean_map_gradcampp/{group}/'.format(label=label_name, group=group)
        if not os.path.exists(directory):
            os.makedirs(directory)

        with open('imgs/{label}/{group}_idx'.format(label=label_name, group=group), "rb") as fp:
            indice = pickle.load(fp)

        for idx in indice:
            img = np.reshape(X_test[idx], (1, 224, 224, 1))
            all_smap_baseline.append(grad_cam_plus(img, baseline_model_last_conv, baseline_model_classifier))
            all_smap_proposed.append(grad_cam_plus(img, proposed_model_last_conv, proposed_model_classifier))

        if not all_smap_baseline:
            # np.mean of an empty list would silently store NaN.
            print('No selected images for {label} / {group}; skipping.'.format(label=label_name, group=group))
            continue

        mean_smap_baseline = np.mean(all_smap_baseline, axis=0)
        with open(directory + 'baseline_model_{group}'.format(group=group), "wb") as fp:
            pickle.dump(mean_smap_baseline, fp)

        mean_smap_proposed = np.mean(all_smap_proposed, axis=0)
        with open(directory + 'proposed_model_{group}'.format(group=group), "wb") as fp:
            pickle.dump(mean_smap_proposed, fp)


In [None]:
target_label = [0, 1, 2, 3, 4, 7, 8, 9, 11, 12]
group_idx = {'race': 0, 'gender': 1, 'age': 2}


def _class_output(model, target_class):
    """Wrap ``model`` so its output is only the score column of ``target_class``.

    ``compute_smap`` differentiates ``model(x, training=False)[0]``; passing it
    this wrapper makes that the gradient of one class score instead of the
    sum of all outputs.
    """
    target = int(target_class)
    return lambda x, training=False: model(x, training=training)[:, target:target + 1]


for group in ['race', 'gender', 'age']:
    # Task-transfer baseline / proposed classifiers for this demographic attribute.
    model_name = 'model_densenet_mimic_ERM_task_transfer_{group}'.format(group=group)
    baseline_model = get_model('ERM', '', group, 'densenet')
    checkpoint = tf.train.Checkpoint(baseline_model)
    manager = tf.train.CheckpointManager(checkpoint, directory='checkpoints/'+model_name, max_to_keep=1, checkpoint_name=model_name)
    baseline_model.load_weights(manager.checkpoints[0])
    baseline_model_last_conv, baseline_model_classifier = get_last_conv_model(baseline_model)

    model_name = 'model_densenet_mimic_ERM_task_transfer_{group}_proposed'.format(group=group)
    proposed_model = get_model('ERM', '', group, 'densenet')
    checkpoint = tf.train.Checkpoint(proposed_model)
    manager = tf.train.CheckpointManager(checkpoint, directory='checkpoints/'+model_name, max_to_keep=1, checkpoint_name=model_name)
    proposed_model.load_weights(manager.checkpoints[0])
    # Unpack into separate names so `proposed_model` stays the full, callable
    # model that compute_smap differentiates below.
    proposed_model_last_conv, proposed_model_classifier = get_last_conv_model(proposed_model)

    for label_idx in target_label:
        label_name = Labels_diseases[label_idx]
        directory = 'imgs/{label}/task_transfer_mean_map/{group}/'.format(label=label_name, group=group)
        if not os.path.exists(directory):
            os.makedirs(directory)

        with open('imgs/{label}/task_transfer_{group}_idx'.format(label=label_name, group=group), "rb") as fp:
            indice = pickle.load(fp)

        for idx in indice:
            # Saliency of the subject's recorded demographic class.
            true_group = demo[idx][group_idx[group]]

            filename = directory + 'model_densenet_mimic_ERM_task_transfer_{group}_{idx}'.format(group=group, idx=idx)
            smap = compute_smap(X_test[idx], _class_output(baseline_model, true_group))
            with open(filename, "wb") as fp:
                pickle.dump(smap, fp)

            filename = directory + 'model_densenet_mimic_ERM_task_transfer_{group}_proposed_{idx}'.format(group=group, idx=idx)
            smap = compute_smap(X_test[idx], _class_output(proposed_model, true_group))
            with open(filename, "wb") as fp:
                pickle.dump(smap, fp)


## Analysis

In [None]:
dataset, task = 'mimic', 'disease'

# Reload the test split (images, disease labels, demographics) for the analysis.
# NOTE: a later cell redefines `get_data` with a different signature; re-running
# this cell after that one will fail.
X_test, y_test, demo = get_data(aug_method='', dataset=dataset, data_split='test',
                                task=task, return_demo=True)


In [None]:
def show(smap_1, smap_2, mean_img, label, demo_label, path):
    """Compare two saliency maps over a background image and save the figures.

    The first figure shows the baseline (``smap_1``) and proposed (``smap_2``)
    maps side by side, each drawn with the 'Reds' colormap at alpha 0.3 over
    the grayscale ``mean_img``. The second figure overlays both maps'
    100-bin value histograms on one full-width axes.

    Args:
        smap_1: baseline saliency map.
        smap_2: proposed-model saliency map.
        mean_img: grayscale background image.
        label: disease name (or e.g. 'All') used in titles and file names.
        demo_label: demographic group name, or '' for none; capitalised for display.
        path: output directory, created if missing. File names are appended
            directly, so it should end with a path separator.

    Side effects:
        Writes ``{label}_{Demo}.jpg`` and ``{label}_{Demo}_histogram.jpg`` into
        ``path`` and displays both figures.
    """
    demo_title = demo_label.capitalize()
    # The full name is too long for a panel title.
    title_label = 'Enlarged Card.' if label == 'Enlarged Cardiomediastinum' else label

    if not os.path.exists(path):
        os.makedirs(path)

    # Side-by-side overlays.
    plt.figure(figsize=(6, 4.5), dpi=300)
    plt.subplots_adjust(wspace=0, hspace=0)
    for position, smap, model_tag in ((121, smap_1, 'Baseline'), (122, smap_2, 'Proposed')):
        ax = plt.subplot(position)
        ax.axis('off')
        ax.imshow(mean_img, cmap='gray')
        ax.imshow(smap, cmap='Reds', alpha=0.3)
        # Skip empty parts so an empty demo_label does not leave a double space.
        ax.set_title(' '.join(part for part in (title_label, demo_title, model_tag) if part), fontsize=8)

    plt.savefig(path+'{label}_{demo}.jpg'.format(label=label, demo=demo_title), bbox_inches='tight', transparent=True)
    plt.show()

    # Overlaid histograms on a single full-width axes. Calling subplot(121) twice
    # re-selected the same left half and left the right half blank.
    fig, ax = plt.subplots(figsize=(6, 2), dpi=300)
    for smap, model_tag in ((smap_1, 'Baseline'), (smap_2, 'Proposed')):
        counts, bins = np.histogram(smap, bins=100)
        ax.stairs(counts, bins, label=model_tag)
    ax.legend()

    fig.savefig(path+'{label}_{demo}_histogram.jpg'.format(label=label, demo=demo_title), bbox_inches='tight', transparent=True)
    plt.show()


In [None]:
Labels_diseases = ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion', 'Lung Opacity', 'No Finding', 'Pleural Effusion', 'Pleural Other', 'Pneumonia', 'Pneumothorax', 'Support Devices']
target_label = [0, 1, 2, 3, 4, 7, 8, 9, 11, 12]
group_name = {'race': ['white', 'black', 'asian'], 'gender': ['male', 'female'], 'age': ['0_40', '40_60', '60_80', '80+']}
group_type = {'race': [0, 1, 4], 'gender': [0, 1], 'age': [0, 1, 2, 3]}

# Gather every test image selected for any label and attribute, from both the
# direct and the task-transfer selections, to build an average background image.
index_files = ('imgs/{label}/{group}_idx', 'imgs/{label}/task_transfer_{group}_idx')

all_img = []
for label_idx in target_label:
    for group in ['race', 'age', 'gender']:
        for index_file in index_files:
            with open(index_file.format(label=Labels_diseases[label_idx], group=group), "rb") as fp:
                ids = pickle.load(fp)
            all_img.extend(X_test[idx] for idx in ids)

In [None]:
# Pixel-wise average of all selected images, used as the overlay background.
mean_all_img = np.mean(all_img, axis=0)

# Contrast stretch: clip away the darkest 15% of pixel values
# (the upper bound, the 100th percentile, is the image maximum).
minval = np.percentile(mean_all_img, 15)
maxval = np.percentile(mean_all_img, 100)
pixvals = np.clip(mean_all_img, minval, maxval)

plt.axis('off')
plt.imshow(pixvals, cmap='gray')

In [None]:
Labels_diseases = ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion', 'Lung Opacity', 'No Finding', 'Pleural Effusion', 'Pleural Other', 'Pneumonia', 'Pneumothorax', 'Support Devices']
target_label = [0, 1, 2, 3, 4, 7, 8, 9, 11, 12]
group_name = {'race': ['white', 'black', 'asian'], 'gender': ['male', 'female'], 'age': ['0_40', '40_60', '60_80', '80+']}
group_type = {'race': [0, 1, 4], 'gender': [0, 1], 'age': [0, 1, 2, 3]}

# One figure per attribute: mean |saliency| over all selected images of all labels.
for group in ['race', 'age', 'gender']:
    all_smap_baseline = []
    all_smap_proposed = []

    for label_idx in target_label:
        label_name = Labels_diseases[label_idx]
        with open('imgs/{label}/{group}_idx'.format(label=label_name, group=group), "rb") as fp:
            ids = pickle.load(fp)

        for idx in ids:
            with open('imgs/{label}/mean_map/{group}/model_densenet_mimic_ERM_{group}_{idx}'.format(label=label_name, group=group, idx=idx), "rb") as fp:
                smap_baseline = pickle.load(fp)
            with open('imgs/{label}/mean_map/{group}/model_densenet_mimic_ERM_{group}_proposed_{idx}'.format(label=label_name, group=group, idx=idx), "rb") as fp:
                smap_proposed = pickle.load(fp)

            all_smap_baseline.append(np.abs(smap_baseline))
            all_smap_proposed.append(np.abs(smap_proposed))

    if not all_smap_baseline:
        # np.mean([]) is a NaN scalar, and the [0][0] indexing below would crash.
        print('No saliency maps found for {group}; skipping.'.format(group=group))
        continue

    mean_smap_baseline = np.mean(all_smap_baseline, axis=0)
    mean_smap_proposed = np.mean(all_smap_proposed, axis=0)

    # imshow scales each map to its own data range; pinning one pixel of both
    # maps to their common maximum gives the two plots the same upper bound.
    maxval = np.max((np.max(mean_smap_baseline), np.max(mean_smap_proposed)))
    mean_smap_baseline[0][0] = maxval
    mean_smap_proposed[0][0] = maxval

    path = 'heatmaps/mean_map/'
    show(mean_smap_baseline, mean_smap_proposed, pixvals, 'All', group, path)
        


In [None]:
Labels_diseases = ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion', 'Lung Opacity', 'No Finding', 'Pleural Effusion', 'Pleural Other', 'Pneumonia', 'Pneumothorax', 'Support Devices']
target_label = [0, 1, 2, 3, 4, 7, 8, 9, 11, 12]
group_name = {'race': ['white', 'black', 'asian'], 'gender': ['male', 'female'], 'age': ['0_40', '40_60', '60_80', '80+']}
group_type = {'race': [0, 1, 4], 'gender': [0, 1], 'age': [0, 1, 2, 3]}

# Same as the direct version, but for the task-transfer models and selections.
for group in ['race', 'age', 'gender']:
    all_smap_baseline = []
    all_smap_proposed = []

    for label_idx in target_label:
        label_name = Labels_diseases[label_idx]
        with open('imgs/{label}/task_transfer_{group}_idx'.format(label=label_name, group=group), "rb") as fp:
            ids = pickle.load(fp)

        for idx in ids:
            with open('imgs/{label}/task_transfer_mean_map/{group}/model_densenet_mimic_ERM_task_transfer_{group}_{idx}'.format(label=label_name, group=group, idx=idx), "rb") as fp:
                smap_baseline = pickle.load(fp)
            with open('imgs/{label}/task_transfer_mean_map/{group}/model_densenet_mimic_ERM_task_transfer_{group}_proposed_{idx}'.format(label=label_name, group=group, idx=idx), "rb") as fp:
                smap_proposed = pickle.load(fp)

            all_smap_baseline.append(np.abs(smap_baseline))
            all_smap_proposed.append(np.abs(smap_proposed))

    if not all_smap_baseline:
        # np.mean([]) is a NaN scalar, and the [0][0] indexing below would crash.
        print('No task-transfer saliency maps found for {group}; skipping.'.format(group=group))
        continue

    mean_smap_baseline = np.mean(all_smap_baseline, axis=0)
    mean_smap_proposed = np.mean(all_smap_proposed, axis=0)

    # Pin one pixel of both maps to their common maximum so imshow gives the
    # two plots the same upper colour bound.
    maxval = np.max((np.max(mean_smap_baseline), np.max(mean_smap_proposed)))
    mean_smap_baseline[0][0] = maxval
    mean_smap_proposed[0][0] = maxval

    path = 'heatmaps/task_transfer_mean_map/'
    show(mean_smap_baseline, mean_smap_proposed, pixvals, 'All', group, path)
        


In [None]:
Labels_diseases = ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion', 'Lung Opacity', 'No Finding', 'Pleural Effusion', 'Pleural Other', 'Pneumonia', 'Pneumothorax', 'Support Devices']
target_label = [0, 1, 2, 3, 4, 7, 8, 9, 11, 12]
group_name = {'race': ['white', 'black', 'asian'], 'gender': ['male', 'female'], 'age': ['0_40', '40_60', '60_80', '80+']}
group_type = {'race': [0, 1, 4], 'gender': [0, 1], 'age': [0, 1, 2, 3]}

# One figure per disease: mean |saliency| over the images selected for it under
# every attribute. The global `all_img` (used for the background image) is no
# longer reset here.
for label_idx in target_label:
    label_name = Labels_diseases[label_idx]
    all_smap_baseline = []
    all_smap_proposed = []

    for group in ['race', 'age', 'gender']:
        with open('imgs/{label}/{group}_idx'.format(label=label_name, group=group), "rb") as fp:
            ids = pickle.load(fp)

        for idx in ids:
            with open('imgs/{label}/mean_map/{group}/model_densenet_mimic_ERM_{id}'.format(group=group, label=label_name, id=idx), "rb") as fp:
                smap_baseline = pickle.load(fp)
            with open('imgs/{label}/mean_map/{group}/model_densenet_mimic_ERM_proposed_{id}'.format(group=group, label=label_name, id=idx), "rb") as fp:
                smap_proposed = pickle.load(fp)

            all_smap_baseline.append(np.abs(smap_baseline))
            all_smap_proposed.append(np.abs(smap_proposed))

    if not all_smap_baseline:
        # np.mean([]) is a NaN scalar, and the [0][0] indexing below would crash.
        print('No saliency maps found for {label}; skipping.'.format(label=label_name))
        continue

    mean_smap_baseline = np.mean(all_smap_baseline, axis=0)
    mean_smap_proposed = np.mean(all_smap_proposed, axis=0)

    # Pin one pixel of both maps to their common maximum so imshow gives the
    # two plots the same upper colour bound.
    maxval = np.max((np.max(mean_smap_baseline), np.max(mean_smap_proposed)))
    mean_smap_baseline[0][0] = maxval
    mean_smap_proposed[0][0] = maxval

    path = 'heatmaps/mean_map/'
    show(mean_smap_baseline, mean_smap_proposed, pixvals, label_name, '', path)
    

## Show individual heatmaps

In [None]:
def _stack_images(images):
    """Turn the per-record image list into an array, using a 1-D object array when ragged.

    ``images`` mixes (224, 224, 1) tensors with scalar ``0`` placeholders. ``np.array``
    on such ragged input only warns on old NumPy but raises ValueError on NumPy >= 1.24,
    so fall back to building the object array explicitly.
    """
    try:
        return np.array(images)
    except ValueError:
        stacked = np.empty(len(images), dtype=object)
        for pos, item in enumerate(images):
            stacked[pos] = item
        return stacked


def get_data(dataset='mimic', label_type='No Finding'):
    """Read a test-split TFRecord file and keep images positive for ``label_type``.

    Args:
        dataset: 'mimic' or 'chexpert'; selects the TFRecord file to read.
        label_type: feature name; a record's image/labels are kept only when this
            feature equals 1.

    Returns:
        (X, y, demo), one row per record in file order:
          X    -- (224, 224, 1) float32 image for kept records, scalar 0 otherwise.
          y    -- 14 multi-hot labels (order of ``label_names`` below); zeros
                  for records that were not kept.
          demo -- [race, gender, age] codes (MIMIC race 4 -> 2, MIMIC age > 0
                  shifted down by one).

    Raises:
        ValueError: if ``dataset`` is not 'mimic' or 'chexpert'.
    """
    np.random.seed(2021)

    # Column order of the label vectors in ``y``.
    label_names = ('Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema',
                   'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion',
                   'Lung Opacity', 'No Finding', 'Pleural Effusion',
                   'Pleural Other', 'Pneumonia', 'Pneumothorax', 'Support Devices')

    if dataset == 'mimic':
        filename = 'data/mimic_test.tfrecords'
    elif dataset == 'chexpert':
        filename = '../Data/Chexpert_test.tfrecords'
    else:
        # Previously fell through and failed later with an UnboundLocalError.
        raise ValueError("dataset must be 'mimic' or 'chexpert', got {!r}".format(dataset))

    X = []
    y = []
    demo = []
    for raw_record in tf.data.TFRecordDataset(filename):
        example = tf.train.Example()
        example.ParseFromString(raw_record.numpy())
        feature = example.features.feature

        if feature[label_type].float_list.value[0] == 1:
            # Decode only images that are kept; np.frombuffer replaces the deprecated np.fromstring.
            nparr = np.frombuffer(feature['jpg_bytes'].bytes_list.value[0], np.uint8)
            img_np = cv.imdecode(nparr, cv.IMREAD_GRAYSCALE)
            X.append(tf.reshape(np.float32(st.resize(img_np, (224, 224))), [224, 224, 1]))
            y.append([1 if feature[name].float_list.value[0] == 1 else 0 for name in label_names])
        else:
            # Placeholder keeps X aligned with y/demo without storing an image.
            X.append(0)
            y.append(np.zeros(14))

        race = feature['race'].int64_list.value[0]
        if race == 4 and dataset == 'mimic':
            race = 2
        age = feature['age'].int64_list.value[0]
        if age > 0 and dataset == 'mimic':
            age -= 1
        gender = feature['gender'].int64_list.value[0]

        demo.append([race, gender, age])

    return _stack_images(X), np.array(y), np.array(demo)

with tf.device('cpu'):
    img_array, y, demo = get_data(dataset='mimic', label_type='Consolidation')

In [None]:
def _load_predictions(path):
    """Unpickle and return the cached model predictions stored at ``path``."""
    # NOTE(review): pickle.load can execute arbitrary code; only use trusted files.
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Baseline (ERM) predictions on the original test images.
prediction_name = 'predictions/model_densenet_mimic_ERM_on_original'
y_preds_0 = _load_predictions(prediction_name)

# Proposed-model predictions on the same images.
prediction_name = 'predictions/model_densenet_mimic_ERM_proposed_on_original'
y_preds_1 = _load_predictions(prediction_name)

In [None]:
def get_last_conv_model(model):
    """Split ``model`` into a feature extractor and a classifier head for Grad-CAM.

    Layout assumed from the indexing below (NOTE(review): confirm against get_model):
    ``model.layers[0]`` is the input layer, ``[1]`` an upsampling stage, ``[2]`` the
    DenseNet backbone and ``[-1]`` the final prediction layer.

    Returns:
        (last_conv_layer_model, classifier_model): the first maps the input to the
        output of ``backbone.layers[-4]``; the second takes a tensor shaped like
        ``backbone.layers[-3].output`` through the backbone's last three layers and
        then ``model.layers[-1]``.
    """
    outer_layers = model.layers
    input_layer, upsampling, backbone = outer_layers[0], outer_layers[1], outer_layers[2]

    # Feature extractor: input -> upsampling -> backbone truncated at layers[-4].
    last_conv_layer_model = tf.keras.Sequential()
    last_conv_layer_model.add(input_layer)
    last_conv_layer_model.add(upsampling)
    truncated_backbone = tf.keras.Model(backbone.inputs, backbone.layers[-4].output)
    last_conv_layer_model.add(truncated_backbone)

    # Classifier head: re-apply the backbone's last three layers, then the outer
    # model's final layer, starting from a fresh Input.
    head_input = tf.keras.Input(shape=backbone.layers[-3].output.shape[1:])
    head_output = head_input
    for head_layer in backbone.layers[-3:]:
        head_output = head_layer(head_output)
    head_output = outer_layers[-1](head_output)
    classifier_model = tf.keras.Model(head_input, head_output)

    return last_conv_layer_model, classifier_model

In [None]:
model_type = ''

# (prediction target, checkpoint-name template) for each DenseNet head.
_checkpoint_templates = [
    ('disease', 'model_densenet_mimic_ERM{}'),
    ('race', 'model_densenet_mimic_ERM_race{}'),
    ('age', 'model_densenet_mimic_ERM_age{}'),
    ('gender', 'model_densenet_mimic_ERM_gender{}'),
]

_loaded_heads = {}
for _target, _template in _checkpoint_templates:
    checkpoint_name = _template.format(model_type)

    _net = get_model('ERM', '', _target, 'densenet')
    checkpoint = tf.train.Checkpoint(_net)
    manager = tf.train.CheckpointManager(checkpoint, directory='checkpoints/'+checkpoint_name, max_to_keep=1, checkpoint_name=checkpoint_name)
    _net.load_weights(manager.checkpoints[0])
    _conv, _clf = get_last_conv_model(_net)
    _loaded_heads[_target] = (_net, _conv, _clf)

model, model_last_conv, model_classifier = _loaded_heads['disease']
model_race, model_race_last_conv, model_race_classifier = _loaded_heads['race']
model_age, model_age_last_conv, model_age_classifier = _loaded_heads['age']
model_gender, model_gender_last_conv, model_gender_classifier = _loaded_heads['gender']

In [None]:
# Column 2 of the label matrix is Labels_diseases[2] == 'Consolidation'.
target_class = 2
# Row positions of test records that are positive for the target class.
idx = np.flatnonzero(y[:, target_class] == 1)

In [None]:
# Per-class decision thresholds: baseline model first, proposed model second.
best_thresh_0, best_thresh_1 = (
    np.loadtxt(thresh_path)
    for thresh_path in ('thresh/model_densenet_mimic_ERM_thresh.txt',
                        'thresh/model_densenet_mimic_ERM_proposed_thresh.txt')
)

In [None]:
# Positions within `idx` where the baseline scores below its threshold ...
valid_idx_0 = np.flatnonzero(y_preds_0[idx, target_class] < best_thresh_0[target_class])
# ... and where the proposed model scores above its threshold.
valid_idx_1 = np.flatnonzero(y_preds_1[idx, target_class] > best_thresh_1[target_class])

In [None]:
valid_idx = np.intersect1d(valid_idx_1, valid_idx_0)  # sorted positions satisfying both conditions

In [None]:
# One row of four GradCAM overlays (disease / race / age / gender heads) per selected case.
for i in valid_idx:
    print(i)

    sample = img_array[idx][i]

    plt.figure(figsize=(3, 1.5), dpi=300)
    plt.subplots_adjust(wspace=0, hspace=0)

    # (panel title, feature extractor, classifier head, class to explain; None for the demographic heads)
    panels = [
        ('Disease', model_last_conv, model_classifier, target_class),
        ('Race', model_race_last_conv, model_race_classifier, None),
        ('Age', model_age_last_conv, model_age_classifier, None),
        ('Gender', model_gender_last_conv, model_gender_classifier, None),
    ]
    for position, (title, conv_net, clf_net, explained_class) in enumerate(panels, start=1):
        panel_ax = plt.subplot(1, 4, position)
        panel_ax.set_title(title, size=5)
        panel_ax.axis('off')
        jet_heatmap = show_heatmap(np.reshape(sample, (1, 224, 224, 1)), conv_net, clf_net, target_class=explained_class)
        panel_ax.imshow((jet_heatmap/255)*0.4+sample)

    plt.show()

In [None]:
model_type = '_proposed'

# (prediction target, checkpoint-name template) for each DenseNet head.
_checkpoint_templates = [
    ('disease', 'model_densenet_mimic_ERM{}'),
    ('race', 'model_densenet_mimic_ERM_race{}'),
    ('age', 'model_densenet_mimic_ERM_age{}'),
    ('gender', 'model_densenet_mimic_ERM_gender{}'),
]

_loaded_heads = {}
for _target, _template in _checkpoint_templates:
    checkpoint_name = _template.format(model_type)

    _net = get_model('ERM', '', _target, 'densenet')
    checkpoint = tf.train.Checkpoint(_net)
    manager = tf.train.CheckpointManager(checkpoint, directory='checkpoints/'+checkpoint_name, max_to_keep=1, checkpoint_name=checkpoint_name)
    _net.load_weights(manager.checkpoints[0])
    _conv, _clf = get_last_conv_model(_net)
    _loaded_heads[_target] = (_net, _conv, _clf)

model, model_last_conv, model_classifier = _loaded_heads['disease']
model_race, model_race_last_conv, model_race_classifier = _loaded_heads['race']
model_age, model_age_last_conv, model_age_classifier = _loaded_heads['age']
model_gender, model_gender_last_conv, model_gender_classifier = _loaded_heads['gender']

In [None]:
# Same four-panel GradCAM strip, now using whichever head models are currently loaded.
for i in valid_idx:
    print(i)

    sample = img_array[idx][i]

    plt.figure(figsize=(3, 1.5), dpi=300)
    plt.subplots_adjust(wspace=0, hspace=0)

    # (panel title, feature extractor, classifier head, class to explain; None for the demographic heads)
    panels = [
        ('Disease', model_last_conv, model_classifier, target_class),
        ('Race', model_race_last_conv, model_race_classifier, None),
        ('Age', model_age_last_conv, model_age_classifier, None),
        ('Gender', model_gender_last_conv, model_gender_classifier, None),
    ]
    for position, (title, conv_net, clf_net, explained_class) in enumerate(panels, start=1):
        panel_ax = plt.subplot(1, 4, position)
        panel_ax.set_title(title, size=5)
        panel_ax.axis('off')
        jet_heatmap = show_heatmap(np.reshape(sample, (1, 224, 224, 1)), conv_net, clf_net, target_class=explained_class)
        panel_ax.imshow((jet_heatmap/255)*0.4+sample)

    plt.show()

In [None]:
# Load the baseline and proposed disease classifiers, each split for Grad-CAM.
_disease_models = {}
for checkpoint_name in ('model_densenet_mimic_ERM', 'model_densenet_mimic_ERM_proposed'):
    _net = get_model('ERM', '', 'disease', 'densenet')
    checkpoint = tf.train.Checkpoint(_net)
    manager = tf.train.CheckpointManager(checkpoint, directory='checkpoints/'+checkpoint_name, max_to_keep=1, checkpoint_name=checkpoint_name)
    _net.load_weights(manager.checkpoints[0])
    _conv, _clf = get_last_conv_model(_net)
    _disease_models[checkpoint_name] = (_net, _conv, _clf)

baseline_model, baseline_model_last_conv, baseline_model_classifier = _disease_models['model_densenet_mimic_ERM']
proposed_model, proposed_model_last_conv, proposed_model_classifier = _disease_models['model_densenet_mimic_ERM_proposed']

In [None]:
# Six-panel strip per case in valid_idx:
# proposed GradCAM | proposed saliency | proposed GradCAM++ |
# baseline GradCAM | baseline saliency | baseline GradCAM++.
for i in valid_idx:
    print(i)

    sample = img_array[idx][i]
    batch = np.reshape(sample, (1, 224, 224, 1))

    plt.figure(figsize=(3, 1.5), dpi=300)
    plt.subplots_adjust(wspace=0, hspace=0)

    ax1 = plt.subplot(161)
    ax1.axis('off')
    jet_heatmap = show_heatmap(batch, proposed_model_last_conv, proposed_model_classifier, target_class=target_class)
    ax1.imshow(jet_heatmap[0])

    ax2 = plt.subplot(162)
    ax2.axis('off')
    # NOTE(review): compute_smap differentiates the raw output of the last-conv sub-model
    # (no target_class selection), so this map is not class-specific; confirm intent.
    smap = np.abs(compute_smap(sample, proposed_model_last_conv))
    ax2.imshow(sample, cmap='gray')
    ax2.imshow(smap, cmap='Reds', alpha=0.6)

    ax3 = plt.subplot(163)
    ax3.axis('off')
    jet_heatmap = grad_cam_plus(batch, proposed_model_last_conv, proposed_model_classifier, target_class=target_class)
    ax3.imshow(jet_heatmap)

    ax4 = plt.subplot(164)
    ax4.axis('off')
    jet_heatmap = show_heatmap(batch, baseline_model_last_conv, baseline_model_classifier, target_class=target_class)
    ax4.imshow(jet_heatmap[0])

    ax5 = plt.subplot(165)
    ax5.axis('off')
    smap = np.abs(compute_smap(sample, baseline_model_last_conv))
    ax5.imshow(sample, cmap='gray')
    # Overlay goes on the saliency panel itself (mirrors ax2), not on the GradCAM panel ax4.
    ax5.imshow(smap, cmap='Reds', alpha=0.6)

    ax6 = plt.subplot(166)
    ax6.axis('off')
    jet_heatmap = grad_cam_plus(batch, baseline_model_last_conv, baseline_model_classifier, target_class=target_class)
    ax6.imshow(jet_heatmap)

    plt.show()

In [None]:
# Comparison grid for hand-picked cases (positions within idx): one column per model
# (baseline, proposed); rows are input, GradCAM, GradCAM++ and saliency map.
for i in [1435, 2469, 2598, 2621, 2647]:
    print(i)

    sample = img_array[idx][i]

    fig, ax = plt.subplots(4, 2, figsize=(6, 3), dpi=400)
    fig.subplots_adjust(wspace=0, hspace=0)

    model_columns = (
        ('Baseline model', baseline_model_last_conv, baseline_model_classifier),
        ('Proposed model', proposed_model_last_conv, proposed_model_classifier),
    )

    # Row 0 of each column: the untouched input image.
    for col, (col_title, _, _) in enumerate(model_columns):
        ax[0, col].axis('off')
        ax[0, col].set_title(col_title, size=5)
        ax[0, col].imshow(sample, cmap='gray')

    # Rows 1-3 of each column: the three explanation methods for that model.
    for col, (_, conv_net, clf_net) in enumerate(model_columns):
        gradcam_ax, gradcampp_ax, smap_ax = ax[1, col], ax[2, col], ax[3, col]

        gradcam_ax.axis('off')
        gradcam_ax.set_title('GradCAM', size=3)
        jet_heatmap = show_heatmap(np.reshape(sample, (1, 224, 224, 1)), conv_net, clf_net, target_class=target_class)
        gradcam_ax.imshow((jet_heatmap/255)*0.4+sample)

        gradcampp_ax.axis('off')
        gradcampp_ax.set_title('GradCAM++', size=3)
        jet_heatmap = grad_cam_plus(np.reshape(sample, (1, 224, 224, 1)), conv_net, clf_net, target_class=target_class)
        jet_heatmap = np.abs(jet_heatmap - 255)
        gradcampp_ax.imshow((jet_heatmap/255)*0.4+sample)

        smap_ax.axis('off')
        smap_ax.set_title('Saliency map', size=3)
        smap = np.abs(compute_smap(sample, conv_net))
        smap_ax.imshow(sample, cmap='gray')
        smap_ax.imshow(smap, cmap='Reds', alpha=0.6)

    fig.tight_layout()
    plt.show()

In [None]:
# Save the proposed model's GradCAM++ overlay for each selected case to plots/.
os.makedirs('plots', exist_ok=True)  # savefig fails if the directory is missing
for i in [1435, 2469, 2598, 2621, 2647]:
    print(i)

    sample = img_array[idx][i]

    # Explicit figure so a stale open figure is never reused via plt.gca().
    fig, ax = plt.subplots()
    ax.axis('off')
    jet_heatmap = grad_cam_plus(np.reshape(sample, (1, 224, 224, 1)), proposed_model_last_conv, proposed_model_classifier, target_class=target_class)
    # Invert the heatmap values (presumably in 0-255) before blending with the image.
    jet_heatmap = np.abs(jet_heatmap - 255)
    ax.imshow((jet_heatmap/255)*0.4+sample)

    fig.savefig('plots/gradcampp_proposed_{}.jpg'.format(i))
    plt.show()
    # Release the figure; non-inline backends would otherwise accumulate open figures.
    plt.close(fig)

In [None]:
# Save the baseline model's GradCAM++ overlay for each selected case to plots/.
os.makedirs('plots', exist_ok=True)  # savefig fails if the directory is missing
for i in [1435, 2469, 2598, 2621, 2647]:
    print(i)

    sample = img_array[idx][i]

    # Explicit figure so a stale open figure is never reused via plt.gca().
    fig, ax = plt.subplots()
    ax.axis('off')
    jet_heatmap = grad_cam_plus(np.reshape(sample, (1, 224, 224, 1)), baseline_model_last_conv, baseline_model_classifier, target_class=target_class)
    # Invert the heatmap values (presumably in 0-255) before blending with the image.
    jet_heatmap = np.abs(jet_heatmap - 255)
    ax.imshow((jet_heatmap/255)*0.4+sample)

    fig.savefig('plots/gradcampp_baseline_{}.jpg'.format(i))
    plt.show()
    # Release the figure; non-inline backends would otherwise accumulate open figures.
    plt.close(fig)