In [None]:
import os
import cv2
import numpy as np
import pandas as pd
import tensorflow as tf
from matplotlib import pyplot as plt

In [None]:
from utils.gradcam import GradCAM, overlay_gradCAM

## Load models, class labels and a test image

In [None]:
# ---- Model configuration ----
# Backbones to compare. Other available options:
# MobileNetV3Large, MobileNetV3Small, CustomizeLarge, CustomizeSmall.
backbone = [
    'CustomizeLarge',
    'CustomizeSmall',
]

# Attention modules to compare (each is combined with every backbone).
SE_CBAM_CA = [
    'SE',
    'CBAM',
    'CA',
]

# Training-mode tag used in the weights directory name: Adam, RMSprop or CLR.
LR_mode = 'Adam'

# Dataset name, also used in the weights directory name. One of:
# "100 Bird Species", "325 Bird Species", "cifar100", "cifar10".
Dataset = "325 Bird Species"

In [None]:
# Load one trained model per (backbone, attention) pair. The resulting order is
# backbone-major: every attention variant of backbone[0], then of backbone[1], ...
models = []
gradcam = []
for bb in backbone:
    for attention in SE_CBAM_CA:
        model_dir = './weights/{0}/{1}_{2}_{3}/best_model'.format(
            bb, attention, LR_mode, Dataset)
        print(model_dir)
        loaded_model = tf.keras.models.load_model(model_dir)
        models += [loaded_model]
        print('model load.')

In [None]:
# Class-index -> class-name lookup for the configured dataset (same `Dataset`
# value the model weights were loaded for).
# NOTE(review): assumes the CSV row order matches the models' output indices — confirm.
brid_pd = pd.read_csv('./Dataset/{0}/class_dict.csv'.format(Dataset))
brid_label = brid_pd['class']

In [None]:
# Load one training image and build the model input batch:
# 1 x 224 x 224 x 3, float32, pixel values scaled by 1/255.
brid_name = 'AFRICAN FIREFINCH'
img_path = 'Dataset/{0}/train/{1}/020.jpg'.format(Dataset, brid_name)
img = cv2.imread(img_path)
# cv2.imread returns None (instead of raising) for a missing/unreadable file.
if img is None:
    raise FileNotFoundError('Cannot read image: {0}'.format(img_path))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# Original image size, used to upsample the heatmap. This is (height, width);
# NOTE(review): confirm GradCAM.compute_heatmap expects that order
# (cv2.resize takes (width, height)).
upsample_size = (img.shape[0], img.shape[1])
print(upsample_size)
pred_img = tf.image.resize(img, (224, 224))
pred_img = tf.reshape(pred_img, (1, 224, 224, 3))
pred_img = tf.cast(pred_img, dtype=tf.float32) / 255

In [None]:
# Layers to visualise with Grad-CAM; each name must exist in every loaded model.
lagers_name = ["CSP1_concatenate",
               "CSP2_concatenate",
               "CSP3_concatenate",
               "CSP4_concatenate",
               "CSP5_concatenate",
               'Conv2',
               "Conv2hardswish",
               'Conv3',
               "Conv3hardswish"]

# Model outputs depend only on the input image, not on the Grad-CAM layer,
# so predict once per model instead of once per (layer, model) pair.
preds = [model.predict(pred_img) for model in models]

# Grid layout: one row per backbone, one column per attention module. This
# matches the backbone-major order in which `models` was built.
n_rows, n_cols = len(backbone), len(SE_CBAM_CA)

for layer_name in lagers_name:
    gradcam = []
    # result[0]: predicted class index per model; result[1]: its score
    result = [[], []]
    for model, pred in zip(models, preds):
        gradCAM = GradCAM(model=model, layerName=layer_name)
        idx = pred.argmax()
        # Explain the predicted class (the argmax). Casting the top *score*
        # (pred.max()) to int would pick an unrelated class, e.g. 0 for any
        # score below 1.
        classIdx = idx

        result[1].append(pred[:, idx][0])
        result[0].append(idx)

        cam3 = gradCAM.compute_heatmap(image=pred_img, classIdx=classIdx, upsample_size=upsample_size)

        gradcam.append(cv2.cvtColor(overlay_gradCAM(img, cam3), cv2.COLOR_BGR2RGB))

    # --------------------- show Grad-CAM ---------------------
    print('layer_name:', layer_name)

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 15), squeeze=False)
    for index, gcam in enumerate(gradcam):
        i, j = divmod(index, n_cols)
        ax = axes[i, j]
        # '.4f' = 4 decimal places (the old '4f' was a minimum width, not a precision)
        ax.set_title(backbone[i] + "_" + SE_CBAM_CA[j] + '\n'
                     + '{0}:{1:.4f}'.format(brid_label[result[0][index]], result[1][index]))
        ax.imshow(gcam)
    plt.show()
    plt.close(fig)

# Save per-class Grad-CAM heatmaps

In [None]:
# Export Grad-CAM grids for the first `classes` labels (in label order), one PNG
# per (class, layer) under ./result/gradcam/<class name>/<layer>.png.
# Capped so a dataset with fewer labels does not index past the end.
classes = min(50, len(brid_label))

# Layers to visualise; each name must exist in every loaded model.
lagers_name = ["CSP1_concatenate",
               "CSP2_concatenate",
               "CSP3_concatenate",
               "CSP4_concatenate",
               "CSP5_concatenate",
               'Conv2',
               "Conv2hardswish",
               'Conv3',
               "Conv3hardswish"]

# Grid layout: one row per backbone, one column per attention module
# (matches the backbone-major order of `models`).
n_rows, n_cols = len(backbone), len(SE_CBAM_CA)

for num in range(classes):
    brid_name = brid_label[num]
    img_path = 'Dataset/{0}/train/{1}/001.jpg'.format(Dataset, brid_name)
    img = cv2.imread(img_path)
    # cv2.imread returns None (instead of raising) for a missing/unreadable file.
    if img is None:
        raise FileNotFoundError('Cannot read image: {0}'.format(img_path))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    upsample_size = (img.shape[0], img.shape[1])
    pred_img = tf.image.resize(img, (224, 224))
    pred_img = tf.reshape(pred_img, (1, 224, 224, 3))
    pred_img = tf.cast(pred_img, dtype=tf.float32) / 255

    # Predictions depend only on the image, so run each model once per image
    # rather than once per layer.
    preds = [model.predict(pred_img) for model in models]

    output_dir = './result/gradcam/' + brid_name
    os.makedirs(output_dir, exist_ok=True)

    for layer_name in lagers_name:
        gradcam = []
        # result[0]: predicted class index per model; result[1]: its score
        result = [[], []]
        for model, pred in zip(models, preds):
            gradCAM = GradCAM(model=model, layerName=layer_name)
            idx = pred.argmax()
            classIdx = idx

            result[1].append(pred[:, idx][0])
            result[0].append(idx)

            cam3 = gradCAM.compute_heatmap(image=pred_img, classIdx=classIdx, upsample_size=upsample_size)

            gradcam.append(cv2.cvtColor(overlay_gradCAM(img, cam3), cv2.COLOR_BGR2RGB))

        # --------------------- save Grad-CAM grid ---------------------
        # A fresh figure per plot; reusing a named figure would ignore figsize
        # and could keep axes from an earlier plot.
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 9), squeeze=False)
        for index, gcam in enumerate(gradcam):
            i, j = divmod(index, n_cols)
            ax = axes[i, j]
            # '.4f' = 4 decimal places (the old '4f' was a minimum width, not a precision)
            ax.set_title(backbone[i] + "_" + SE_CBAM_CA[j] + '\n'
                         + '{0}:{1:.4f}'.format(brid_label[result[0][index]], result[1][index]))
            ax.axis('off')
            ax.imshow(gcam)

        fig.savefig(output_dir + '/{0}.png'.format(layer_name),
                    dpi=300,
                    facecolor='white',
                    bbox_inches='tight',
                    )
        # Close explicitly: this loop creates classes * len(lagers_name) figures.
        plt.close(fig)