Install the package with pip:
pip install git+https://github.com/cyberzhg/keras-conv-vis
The code only works when eager execution is enabled (the default in TensorFlow 2).
See the guided backpropagation paper and the accompanying demo for details.
from tensorflow import keras
import numpy as np
from PIL import Image
from keras_conv_vis import replace_relu, get_gradient, Categorical
# Build MobileNetV2 and swap every ReLU for its guided-backpropagation variant.
model = replace_relu(keras.applications.MobileNetV2(), relu_type='guided')
# Stack a head that keeps only the target class: 284 is "Siamese cat" in ImageNet.
gradient_model = keras.models.Sequential([model, Categorical(284)])
# `inputs` must be defined beforehand (presumably a preprocessed image batch
# for MobileNetV2 -- not shown in this snippet).
gradients = get_gradient(gradient_model, inputs)
# Take the first sample's gradient and min-max scale it into [0, 255] so it
# can be shown as an 8-bit image; 1e-4 guards against a zero range.
gradient = gradients.numpy()[0]
low, high = np.min(gradient), np.max(gradient)
gradient = ((gradient - low) / (high - low + 1e-4) * 255.0).astype(np.uint8)
visualization = Image.fromarray(gradient)
Type | Relevant | Irrelevant |
---|---|---|
Input | ||
Gradient | ||
Deconvnet without Pooling Switches | ||
Guided Backpropagation | | |
See:
For Grad-CAM:
from tensorflow import keras
import matplotlib.pyplot as plt
from PIL import Image
from keras_conv_vis import grad_cam
model = keras.applications.MobileNetV2()
# `inputs` and `original_image` must be defined beforehand: presumably a
# preprocessed MobileNetV2 input batch and the PIL image it was built from.
cam = grad_cam(model=model, layer_cut='Conv_1', inputs=inputs, target_class=284)[0]
# Visualization: colour-map the computed CAM (`cam`, not the `grad_cam`
# function). With bytes=True the colormap returns uint8 RGBA values.
# NOTE(review): float colormap input is expected in [0, 1] -- confirm that
# grad_cam returns a normalized map.
heatmap = plt.get_cmap('jet')(cam, bytes=True)
# Drop the alpha channel so the heatmap is a plain RGB image.
heatmap = Image.fromarray(heatmap[..., :3], mode='RGB')
# Upsample the coarse CAM to the original image's size.
heatmap = heatmap.resize((original_image.width, original_image.height), resample=Image.BILINEAR)
# Image.blend requires both images to share a mode, so force RGB
# (e.g. a PNG with an alpha channel would otherwise fail).
visualization = Image.blend(original_image.convert('RGB'), heatmap, alpha=0.5)
For Grad-CAM++:
from tensorflow import keras
from keras_conv_vis import grad_cam, replace_layers
# Grad-CAM++ works on the pre-softmax class scores, so the final `softmax`
# activation is swapped for `linear` while the model is built.
model = replace_layers(
    keras.applications.MobileNetV2(),
    activation_mapping={'softmax': 'linear'},
)
# `inputs` must be defined beforehand (not shown in this snippet).
# plus=True switches grad_cam to Grad-CAM++; [0] keeps the first sample's map.
cam = grad_cam(
    model=model,
    layer_cut='Conv_1',
    inputs=inputs,
    target_class=284,
    plus=True,
)[0]
Type | Input | Relevant CAM | Irrelevant CAM |
---|---|---|---|
Grad-CAM | |||
Grad-CAM++ | | | |