In [None]:
import io

import ipywidgets
import PIL
import PIL.Image

import keras
import keras.applications.imagenet_utils

import numpy as np
import pandas as pd

import tensorflow as tf
# Silence TensorFlow INFO/WARNING logs. `tf.logging` exists only in TF 1.x;
# `tf.compat.v1.logging` is available from TF 1.13 onward, including TF 2.x.
try:
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
except AttributeError:
    tf.logging.set_verbosity(tf.logging.ERROR)

# Network input size in Keras' channels-last layout: (height, width, channels).
inputShape = (299, 299, 3)
network = keras.applications.xception

print("[INFO] loading {}...".format(network.__name__))
model = network.Xception(weights="imagenet")
# Preprocessing function that matches the chosen network.
preprocess = network.preprocess_input

In [None]:
# Grayscale -> Colormap mapping functions
def s_jetmap(s):
    """Map one grayscale intensity to an RGB "jet" colour.

    Colours go blue -> cyan -> green -> yellow -> red as the intensity rises
    through the four 64-wide bands of 0..255.

    Args:
        s: array-like whose first element is the intensity (expected 0..255).

    Returns:
        uint8 array of shape (3,) holding (R, G, B).
    """
    # Work on a Python scalar. With a numpy uint8 input, expressions such as
    # `64 - g` would wrap around and emit overflow RuntimeWarnings.
    g = np.asarray(s).item(0)
    if g < 64:  return np.asarray((0, 4*g, 255)).astype('uint8')
    if g < 128: return np.asarray((0, 255, 255+4*(64-g))).astype('uint8')
    if g < 192: return np.asarray((4*(g-128), 255, 0)).astype('uint8')
    return np.asarray((255, 255+4*(192-g), 0)).astype('uint8')

def s_graymap(s):
    """Map one grayscale intensity to an (R, G, B) triple of equal uint8 values.

    Args:
        s: array-like whose first element is the intensity.
    """
    level = s[0]
    rgb = np.asarray([level] * 3)
    return rgb.astype('uint8')

def colormap(heatmap, cmap=s_jetmap):
    """Colourise a 2-D grayscale heatmap one pixel at a time.

    Args:
        heatmap: 2-D array (height, width) of intensities.
        cmap: function mapping a length-1 array [intensity] to a length-3 RGB
            array, e.g. ``s_jetmap`` (default) or ``s_graymap``.

    Returns:
        Array of shape (height, width, 3); its dtype is whatever ``cmap`` returns.
    """
    height, width = heatmap.shape[0], heatmap.shape[1]
    # A trailing length-1 axis makes apply_along_axis call cmap once per pixel.
    per_pixel = np.expand_dims(heatmap, axis=-1)
    rgb = np.apply_along_axis(cmap, -1, per_pixel)
    return rgb.reshape((height, width, 3))

# Pillow -> ipywidgets transformation
def pil2ipy(image):
    """Encode a PIL image as PNG bytes for use as an ``ipywidgets.Image`` value."""
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format='png')
        return png_buffer.getvalue()

# Image scale to uint8 0-255
def uint8scale(array):
    """Linearly rescale an array so its maximum maps to 255, returned as uint8.

    Negative values are clipped to 0 and fractional results are truncated.
    The input is never modified.

    Args:
        array: numeric array-like (float or integer).

    Returns:
        uint8 ndarray with the same shape as ``array``. If the maximum is not
        positive (or is NaN) there is nothing to scale, so all zeros are returned.
    """
    values = np.asarray(array)
    peak = np.max(values)
    if not peak > 0:
        # Dividing by a max <= 0 or NaN would give inf/NaN or flip signs,
        # and casting those to uint8 is undefined.
        return np.zeros(values.shape, dtype='uint8')
    # Out-of-place arithmetic: leaves the caller's array untouched and also
    # works for integer inputs (true division promotes to float).
    scaled = values / peak * 255
    return np.maximum(scaled, 0.0).astype('uint8')

class GAPCAM:
    """Class Activation Mapping (CAM) for a classifier ending in conv -> GAP -> Dense.

    Wraps ``model`` in a second Keras model that outputs both the feature maps
    of ``model.layers[-3]`` and the predictions of ``model.layers[-1]``. One
    forward pass therefore yields the classification and everything needed
    for the heatmap.

    NOTE(review): this assumes the network head is
    ``[..., last feature-map layer, global average pooling, Dense classifier]``,
    which is what the -3 / -1 layer indices rely on. Confirm this before using
    another model.
    """

    def __init__(self, model, preprocess):
        """Build the two-output model.

        Args:
            model: Keras classification model (see the class docstring for the
                assumed head layout).
            preprocess: input preprocessing function matching ``model``; it is
                applied to a batch of one image in ``process``.
        """
        self.GAP_model = keras.models.Model(
            inputs=model.input,
            outputs=[model.layers[-3].output, model.layers[-1].output],
        )
        # Dense classifier whose kernel weights the feature maps in get_heatmap.
        # This is the same layer whose output was taken above.
        self.classifier = model.layers[-1]
        self.preprocess = preprocess
        self.pred = None  # predictions from the most recent process() call
        self.conv = None  # feature maps from the most recent process() call

    def _require_processed(self):
        """Raise a clear error if process() has not been run yet."""
        if self.pred is None or self.conv is None:
            raise RuntimeError("GAPCAM.process(image) must be called first")

    def process(self, image):
        """Run the network on a PIL image and cache feature maps and predictions.

        The image is converted to RGB, so grayscale, palette and RGBA images
        give the 3 channels the network expects. It is then resized to the
        global ``inputShape`` before ``self.preprocess`` is applied.
        """
        # inputShape is (height, width, channels); PIL's resize takes (width, height).
        target_size = (inputShape[1], inputShape[0])
        rgb = image.convert('RGB').resize(target_size)
        batch = np.expand_dims(keras.preprocessing.image.img_to_array(rgb), axis=0)
        self.conv, self.pred = self.GAP_model.predict(self.preprocess(batch))

    def get_predictions(self, top=5):
        """Return the ``top`` ImageNet predictions for the last processed image.

        Returns:
            List of dicts ordered by decreasing score, with keys:
            - 'rank': 1-based rank.
            - 'score': model output x 100 (a percentage for softmax outputs).
            - 'class': index into the prediction vector.
            - 'id', 'label': as returned by ``decode_predictions``.
        """
        self._require_processed()
        decoded = keras.applications.imagenet_utils.decode_predictions(self.pred, top)[0]
        # Rank indices the way decode_predictions does (argsort, last `top`,
        # reversed) so they line up entry-for-entry with `decoded`. Looking
        # indices up by float equality instead returns the same index twice
        # when two scores tie.
        class_indices = self.pred[0].argsort()[-top:][::-1]
        return [
            {
                'rank': rank,
                'score': prob * 100.0,
                'class': class_idx,
                'id': imagenet_id,
                'label': label,
            }
            for rank, (class_idx, (imagenet_id, label, prob))
            in enumerate(zip(class_indices, decoded), start=1)
        ]

    def get_heatmap(self, channel):
        """Compute the CAM heatmap of one output class for the last processed image.

        Each spatial location's feature vector is dotted with the classifier
        weights of ``channel``. The result is rescaled to uint8 by ``uint8scale``.

        Args:
            channel: class index into the prediction vector, e.g. the 'class'
                field returned by get_predictions.

        Returns:
            uint8 array of shape (fh, fw), the spatial size of the feature maps.
        """
        self._require_processed()
        # Take the first (only) image of the batch. Unlike np.squeeze, this keeps
        # the spatial axes even if one of them has size 1.
        feature_maps = self.conv[0]
        kernel = self.classifier.get_weights()[0]  # Dense kernel, (n_features, n_classes)
        class_weights = kernel[:, channel]
        fh, fw, fn = feature_maps.shape
        cam = feature_maps.reshape((fh * fw, fn)).dot(class_weights).reshape(fh, fw)
        return uint8scale(cam)


In [None]:
# This is the image to be classified with the network loaded above.
# NOTE: absolute, machine-specific path; adjust for your environment.
image_file = '/tmp/workspace/Pictures/Saxophone.jpg'

print("[INFO] loading image...")

# PIL.Image.open is lazy and keeps the file open. Decode inside `with` so the
# handle is closed right away. Converting to RGB gives every later step
# (network input, heatmap overlay) exactly 3 channels, even for RGBA,
# palette or grayscale files.
with PIL.Image.open(image_file) as source_image:
    orig_image = source_image.convert('RGB')
img = ipywidgets.Image(value=pil2ipy(orig_image))
display(img)

In [None]:
print("[INFO] classifying image {} with {}...".format(image_file, network.__name__))
gap_cam = GAPCAM(model, preprocess)
gap_cam.process(orig_image)
results = gap_cam.get_predictions(top=10)

# Top-10 predictions as a table indexed by rank. display() renders a rich
# HTML table in Jupyter, unlike print().
p = pd.DataFrame(results).set_index('rank')
display(p)

In [None]:
# Get the GAP-CAM heatmap for the dominant (rank-1) class
maxclass = results[0]['class']
heatmap = gap_cam.get_heatmap(maxclass)
jetmap = colormap(heatmap, s_jetmap)
# Upsample the coarse heatmap to the original picture's (width, height).
jetmap_image = PIL.Image.fromarray(jetmap).resize(orig_image.size, resample=PIL.Image.BICUBIC)

# Overlay the heatmap over the original picture. Force RGB so the picture has
# the same 3 channels as the jetmap. RGBA/LA sources would fail to broadcast,
# and palette sources would multiply palette indices instead of colours.
image_array = keras.preprocessing.image.img_to_array(orig_image.convert('RGB'))
jetmap_array = keras.preprocessing.image.img_to_array(jetmap_image)

overlay_image = PIL.Image.fromarray(uint8scale(image_array * jetmap_array))
img.value = pil2ipy(overlay_image)