diff --git a/README.md b/README.md
index 5c0667a..1461e48 100644
--- a/README.md
+++ b/README.md
@@ -13,9 +13,11 @@
pip install git+https://github.com/cyberzhg/keras-conv-vis
```
+The code only works when eager execution is enabled.
+
## Guided Backpropagation
-See [the paper](https://arxiv.org/pdf/1412.6806.pdf) and [demo](./demo/guided_backpropagation.py).
+See [the paper](https://arxiv.org/pdf/1412.6806.pdf) and [demo](https://github.com/CyberZHG/keras-conv-vis/blob/master/demo/guided_backpropagation.py).
```python
import keras
@@ -47,3 +49,44 @@ visualization = Image.fromarray(gradient)
| Gradient | |
| Deconvnet without Pooling Switches | |
| Guided Backpropagation | |
+
+
+## Grad-CAM
+
+See [the paper](https://arxiv.org/abs/1610.02391) and [demo](https://github.com/CyberZHG/keras-conv-vis/blob/master/demo/grad_cam.py).
+
+```python
+import keras
+import numpy as np
+import matplotlib.pyplot as plt
+from PIL import Image
+
+from keras_conv_vis import split_model_by_layer, get_gradient, Categorical
+
+model = keras.applications.MobileNetV2()
+# Split the model at the last convolutional layer and compute the intermediate result
+head, tail = split_model_by_layer(model, 'Conv_1')
+last_conv_output = head(inputs)
+# Compute the gradient for the convolution
+gradient_model = keras.models.Sequential()
+gradient_model.add(tail)
+gradient_model.add(Categorical(284)) # 284 is the siamese cat in ImageNet
+gradients = get_gradient(gradient_model, last_conv_output)
+
+# Calculate Grad-CAM
+gradient = gradients.numpy()[0]
+gradient = np.mean(gradient, axis=(0, 1))
+grad_cam = np.mean(last_conv_output.numpy()[0] * gradient, axis=-1)
+grad_cam = grad_cam * (grad_cam > 0).astype(grad_cam.dtype)
+
+# Visualization
+grad_cam = (grad_cam - np.min(grad_cam)) / (np.max(grad_cam) - np.min(grad_cam) + 1e-4)
+heatmap = plt.get_cmap('jet')(grad_cam, bytes=True)
+heatmap = Image.fromarray(heatmap[..., :3], mode='RGB')
+heatmap = heatmap.resize((original_image.width, original_image.height), resample=Image.BILINEAR)
+visualization = Image.blend(original_image, heatmap, alpha=0.5)
+```
+
+| Input | Relevant CAM | Irrelevant CAM|
+|:-:|:-:|:-:|
+| | | |
diff --git a/README.zh-CN.md b/README.zh-CN.md
index 52a6bcc..b9417d0 100644
--- a/README.zh-CN.md
+++ b/README.zh-CN.md
@@ -13,9 +13,11 @@
pip install git+https://github.com/cyberzhg/keras-conv-vis
```
+只在启用eager execution的情况下可以使用。
+
## Guided Backpropagation
-参考[论文](https://arxiv.org/pdf/1412.6806.pdf)和[样例](./demo/guided_backpropagation.py).
+参考[论文](https://arxiv.org/pdf/1412.6806.pdf)和[样例](./demo/guided_backpropagation.py)。
```python
import keras
@@ -47,3 +49,45 @@ visualization = Image.fromarray(gradient)
| 梯度 | |
| Deconvnet without Pooling Switches | |
| Guided Backpropagation | |
+
+
+## Grad-CAM
+
+
+参考[论文](https://arxiv.org/abs/1610.02391)和[样例](https://github.com/CyberZHG/keras-conv-vis/blob/master/demo/grad_cam.py)。
+
+```python
+import keras
+import numpy as np
+import matplotlib.pyplot as plt
+from PIL import Image
+
+from keras_conv_vis import split_model_by_layer, get_gradient, Categorical
+
+model = keras.applications.MobileNetV2()
+# 将模型从最后一个卷积处分开,计算出中间结果
+head, tail = split_model_by_layer(model, 'Conv_1')
+last_conv_output = head(inputs)
+# 给最后一个卷积计算梯度
+gradient_model = keras.models.Sequential()
+gradient_model.add(tail)
+gradient_model.add(Categorical(284)) # ImageNet第284类是暹罗猫
+gradients = get_gradient(gradient_model, last_conv_output)
+
+# 计算Grad-CAM
+gradient = gradients.numpy()[0]
+gradient = np.mean(gradient, axis=(0, 1)) # 根据梯度计算每一层输出的权重
+grad_cam = np.mean(last_conv_output.numpy()[0] * gradient, axis=-1) # 对卷积输出进行加权求和
+grad_cam = grad_cam * (grad_cam > 0).astype(grad_cam.dtype) # Grad-CAM输出需要经过ReLU
+
+# 可视化
+grad_cam = (grad_cam - np.min(grad_cam)) / (np.max(grad_cam) - np.min(grad_cam) + 1e-4)
+heatmap = plt.get_cmap('jet')(grad_cam, bytes=True)
+heatmap = Image.fromarray(heatmap[..., :3], mode='RGB')
+heatmap = heatmap.resize((original_image.width, original_image.height), resample=Image.BILINEAR)
+visualization = Image.blend(original_image, heatmap, alpha=0.5)
+```
+
+| Input | Relevant CAM | Irrelevant CAM|
+|:-:|:-:|:-:|
+| | | |
diff --git a/demo/grad_cam.py b/demo/grad_cam.py
new file mode 100644
index 0000000..8726528
--- /dev/null
+++ b/demo/grad_cam.py
@@ -0,0 +1,55 @@
+import os
+
+from PIL import Image
+import numpy as np
+import matplotlib.pyplot as plt
+
+from keras_conv_vis import split_model_by_layer, get_gradient, Categorical
+from keras_conv_vis.backend import keras
+
+CLASS_CAT = 284
+CLASS_GUITAR = 546
+
+# Load an image
+current_path = os.path.dirname(os.path.realpath(__file__))
+sample_path = os.path.join(current_path, '..', 'samples')
+image_path = os.path.join(sample_path, 'cat.jpg')
+original_image = Image.open(image_path)
+image = original_image.resize((224, 224))
+inputs = np.expand_dims(np.array(image).astype(np.float) / 255.0, axis=0)
+inputs = inputs * 2.0 - 1.0
+
+
+def process(target_class,  # class index handed to Categorical
+ cmap='jet',  # matplotlib colormap name used for the heat map
+ alpha=0.5):  # blend factor passed to Image.blend
+ # Purpose: return the sample image blended with a Grad-CAM heat map. Step 1: build the model and get gradients
+ model = keras.applications.MobileNetV2()  # MobileNetV2 with default arguments
+ head, tail = split_model_by_layer(model, 'Conv_1')  # cut at the layer named 'Conv_1'
+ last_conv_output = head(inputs)  # output of the head for the module-level `inputs`
+ gradient_model = keras.models.Sequential()
+ gradient_model.add(tail)
+ gradient_model.add(Categorical(target_class))  # presumably restricts the output to target_class; confirm in Categorical
+ gradients = get_gradient(gradient_model, last_conv_output)  # no targets given, so the gradient is taken w.r.t. last_conv_output
+
+ # Step 2: calculate Grad-CAM
+ gradient = gradients.numpy()[0]  # first item of the batch
+ gradient = np.mean(gradient, axis=(0, 1))  # average over the first two axes: one weight per channel (assumes H, W, C layout)
+ grad_cam = np.mean(last_conv_output.numpy()[0] * gradient, axis=-1)  # weight each channel, then average across channels
+ grad_cam = grad_cam * (grad_cam > 0).astype(grad_cam.dtype)  # keep only positive values (ReLU)
+
+ # Step 3: visualization
+ grad_cam = (grad_cam - np.min(grad_cam)) / (np.max(grad_cam) - np.min(grad_cam) + 1e-4)  # rescale to [0, 1); the 1e-4 avoids division by zero on a constant map
+ heatmap = plt.get_cmap(cmap)(grad_cam, bytes=True)  # colormap lookup; bytes=True yields 8-bit integer channels
+ heatmap = Image.fromarray(heatmap[..., :3], mode='RGB')  # keep only the first three channels (RGB)
+ heatmap = heatmap.resize((original_image.width, original_image.height), resample=Image.BILINEAR)  # scale the heat map to the original image size
+ return Image.blend(original_image, heatmap, alpha=alpha)
+
+
+for target_class in [CLASS_CAT, CLASS_GUITAR]:  # render one overlay per target class
+ visualization = process(target_class)
+ cat_name = 'relevant'  # the sample image is cat.jpg, so the cat class is the "relevant" target
+ if target_class != CLASS_CAT:
+ cat_name = 'irrelevant'
+ save_name = f'cat_grad-cam_{cat_name}.jpg'
+ visualization.save(os.path.join(sample_path, save_name))  # written into the samples directory
diff --git a/keras_conv_vis/gradient.py b/keras_conv_vis/gradient.py
index dc3abd4..e2af4ce 100644
--- a/keras_conv_vis/gradient.py
+++ b/keras_conv_vis/gradient.py
@@ -5,7 +5,7 @@
from .backend import keras
-__all__ = ['Categorical', 'get_gradient']
+__all__ = ['Categorical', 'get_gradient', 'split_model_by_layer']
class Categorical(keras.layers.Layer):
@@ -27,27 +27,82 @@ def get_config(self):
return dict(list(base_config.items()) + list(config.items()))
-def get_gradient(model: keras.models.Model,
+def get_gradient(model: Union[keras.models.Model, List[keras.models.Model]],
inputs: Union[np.ndarray, tf.Tensor, List[Union[np.ndarray, tf.Tensor]]],
targets: Optional[Union[tf.Tensor, List[tf.Tensor]]] = None):
+ """Get the gradient of inputs, weights, or intermediate outputs.
+
+ :param model: The keras model, or a list of models that are applied one after another.
+ :param inputs: The batched input data.
+ :param targets: The tensor(s) to differentiate with respect to. The default is the input of each model.
+ :return: The gradients of the targets. A single gradient is returned as is rather than in a list.
+ """
+ models = model
+ if not isinstance(model, list):
+ models = [model]  # treat a single model as a chain of length one
if not isinstance(inputs, list):
inputs = [inputs]
for i, input_item in enumerate(inputs):
if isinstance(input_item, np.ndarray):
inputs[i] = tf.convert_to_tensor(input_item)
+ if len(inputs) == 1:
+ inputs = inputs[0]  # a lone input is passed to the model directly rather than in a list
+ input_targets = targets  # remembers whether the caller gave targets; `targets` is rebound below
if targets is None:
- target_tensors = inputs
- else:
- target_tensors = []
- if not isinstance(targets, list):
- targets = [targets]
- for target in targets:
- target_tensors.append(target)
+ targets = []
+ if not isinstance(targets, list):
+ targets = [targets]
+ model_output = inputs
with tf.GradientTape() as tape:
- for target in target_tensors:
+ for target in targets:
tape.watch(target)
- model_output = model(inputs)
- gradients = tape.gradient(model_output, target_tensors)
+ for i, model in enumerate(models):  # NOTE(review): `i` is unused and `model` shadows the parameter
+ if input_targets is None:
+ targets.append(model_output)
+ tape.watch(model_output)  # the input of this model becomes a gradient target
+ model_output = model(model_output)  # feed the result into the next model
+ gradients = tape.gradient(model_output, targets)
if len(gradients) == 1:
gradients = gradients[0]
return gradients
+
+
+def split_model_by_layer(model: keras.models.Model,
+ layer_cut: Union[str, keras.layers.Layer]):
+ """Split a model into two parts at the output of one layer.
+
+ :param model: The keras model.
+ :param layer_cut: The layer (or its name) whose output will be cut. The layer must be a cut point,
+ i.e., the output edge is a bridge: every later layer may only depend on layers at or after the cut.
+ :return: A tuple `(head, tail)`: `head` maps the model inputs to the cut output, `tail` maps a new input of that shape to the final outputs.
+ """
+ if isinstance(layer_cut, str):
+ layer_cut = model.get_layer(layer_cut)  # resolve a layer name to the layer object
+ head = keras.models.Model(model.inputs, layer_cut.output)
+ meet = False  # becomes True once the loop has reached the cut layer
+ mappings, depends = {}, set()  # layer name -> rebuilt tensor; names of layers whose output is consumed
+ for layer in model.layers:
+ if layer_cut is layer:
+ meet = True
+ mappings[layer.name] = keras.layers.Input(layer.output.shape[1:])  # new input standing in for the cut output; first (presumably batch) dimension dropped
+ depends.add(layer.name)  # the cut layer itself must never be reported as a tail output
+ continue
+ if not meet:
+ continue  # layers before the cut belong to the head only
+ inputs = layer.input
+ if not isinstance(inputs, list):
+ inputs = [inputs]
+ new_inputs = []
+ for input_tensor in inputs:
+ name = input_tensor.name.rsplit('/')[0]  # text before the first '/'; assumes tensor names start with the producing layer's name (TODO confirm)
+ depends.add(name)
+ new_inputs.append(mappings[name])  # KeyError here means an input was not rebuilt after the cut (layer_cut is not a cut point)
+ if not isinstance(layer.input, list):
+ new_inputs = new_inputs[0]  # hand a single tensor to layers that originally took a single tensor
+ mappings[layer.name] = layer(new_inputs)  # re-apply the same layer object to the rebuilt inputs
+ outputs = []
+ for layer in model.layers:
+ if layer.name in mappings and layer.name not in depends:  # rebuilt layers whose output no later layer consumes
+ outputs.append(mappings[layer.name])
+ tail = keras.models.Model(mappings[layer_cut.name], outputs)
+ return head, tail
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 1f22a48..dcd634e 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -6,3 +6,4 @@ nose
pycodestyle
coverage
Pillow
+matplotlib
diff --git a/samples/cat_grad-cam_irrelevant.jpg b/samples/cat_grad-cam_irrelevant.jpg
new file mode 100644
index 0000000..d9b067c
Binary files /dev/null and b/samples/cat_grad-cam_irrelevant.jpg differ
diff --git a/samples/cat_grad-cam_relevant.jpg b/samples/cat_grad-cam_relevant.jpg
new file mode 100644
index 0000000..9358e1c
Binary files /dev/null and b/samples/cat_grad-cam_relevant.jpg differ
diff --git a/tests/test_get_gradient.py b/tests/test_get_gradient.py
index 31c93be..e2f87a3 100644
--- a/tests/test_get_gradient.py
+++ b/tests/test_get_gradient.py
@@ -2,7 +2,7 @@
import numpy as np
-from keras_conv_vis import get_gradient, Categorical
+from keras_conv_vis import get_gradient, Categorical, split_model_by_layer
from keras_conv_vis.backend import keras
@@ -17,3 +17,12 @@ def test_get_gradient(self):
get_gradient(gradient_model, np.random.random((1, 224, 224, 3)))
get_gradient(gradient_model, np.random.random((1, 224, 224, 3)),
targets=model.get_layer('bn_Conv1').trainable_weights[0])
+
+ def test_cut_model(self):  # split a model at a layer, then run head and tail as a chain through get_gradient
+ model = keras.applications.MobileNetV2()
+ head, tail = split_model_by_layer(model, 'block_5_add')
+ gradient_model = keras.models.Sequential()
+ gradient_model.add(tail)
+ gradient_model.add(Categorical(7))
+ gradients = get_gradient([head, gradient_model], np.random.random((1, 224, 224, 3)))  # no targets given, so the input of each model is a target
+ self.assertEqual(2, len(gradients))  # one gradient per model in the chain