Skip to content
This repository has been archived by the owner on Mar 3, 2024. It is now read-only.

Commit

Permalink
Add Grad-CAM
Browse files Browse the repository at this point in the history
  • Loading branch information
CyberZHG committed Oct 1, 2020
1 parent 367e04b commit 3cfc143
Show file tree
Hide file tree
Showing 8 changed files with 222 additions and 15 deletions.
45 changes: 44 additions & 1 deletion README.md
Expand Up @@ -13,9 +13,11 @@
pip install git+https://github.com/cyberzhg/keras-conv-vis
```

The codes only work when eager execution is enabled.

## Guided Backpropagation

See [the paper](https://arxiv.org/pdf/1412.6806.pdf) and [demo](./demo/guided_backpropagation.py).
See [the paper](https://arxiv.org/pdf/1412.6806.pdf) and [demo](https://github.com/CyberZHG/keras-conv-vis/blob/master/demo/guided_backpropagation.py).

```python
import keras
Expand Down Expand Up @@ -47,3 +49,44 @@ visualization = Image.fromarray(gradient)
| Gradient | <img src="https://github.com/CyberZHG/keras-conv-vis/raw/master/samples/cat_gradient_relevant.jpg" width="224" height="224" /> |
| Deconvnet without Pooling Switches | <img src="https://github.com/CyberZHG/keras-conv-vis/raw/master/samples/cat_deconvnet_relevant.jpg" width="224" height="224" /> |
| Guided Backpropagation | <img src="https://github.com/CyberZHG/keras-conv-vis/raw/master/samples/cat_guided_relevant.jpg" width="224" height="224" /> |


## Grad-CAM

See [the paper](https://arxiv.org/abs/1610.02391) and [demo](https://github.com/CyberZHG/keras-conv-vis/blob/master/demo/grad_cam.py).

```python
import keras
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

from keras_conv_vis import split_model_by_layer, get_gradient, Categorical

model = keras.applications.MobileNetV2()
# Split the model at the last convolutional layer and compute the intermediate result
head, tail = split_model_by_layer(model, 'Conv_1')
last_conv_output = head(inputs)
# Computer the gradient for the convolution
gradient_model = keras.models.Sequential()
gradient_model.add(tail)
gradient_model.add(Categorical(284)) # 284 is the siamese cat in ImageNet
gradients = get_gradient(gradient_model, last_conv_output)

# Calculate Grad-CAM
gradient = gradients.numpy()[0]
gradient = np.mean(gradient, axis=(0, 1))
grad_cam = np.mean(last_conv_output.numpy()[0] * gradient, axis=-1)
grad_cam = grad_cam * (grad_cam > 0).astype(grad_cam.dtype)

# Visualization
grad_cam = (grad_cam - np.min(grad_cam)) / (np.max(grad_cam) - np.min(grad_cam) + 1e-4)
heatmap = plt.get_cmap('jet')(grad_cam, bytes=True)
heatmap = Image.fromarray(heatmap[..., :3], mode='RGB')
heatmap = heatmap.resize((original_image.width, original_image.height), resample=Image.BILINEAR)
visualization = Image.blend(original_image, heatmap, alpha=0.5)
```

| Input | Relevant CAM | Irrelevant CAM|
|:-:|:-:|:-:|
| <img src="https://github.com/CyberZHG/keras-conv-vis/raw/master/samples/cat.jpg" width="224" height="224" /> | <img src="https://github.com/CyberZHG/keras-conv-vis/raw/master/samples/cat_grad-cam_relevant.jpg" width="224" height="224" /> | <img src="https://github.com/CyberZHG/keras-conv-vis/raw/master/samples/cat_grad-cam_irrelevant.jpg" width="224" height="224" /> |
46 changes: 45 additions & 1 deletion README.zh-CN.md
Expand Up @@ -13,9 +13,11 @@
pip install git+https://github.com/cyberzhg/keras-conv-vis
```

只在启用eager execution的情况下可以使用。

## Guided Backpropagation

参考[论文](https://arxiv.org/pdf/1412.6806.pdf)[样例](./demo/guided_backpropagation.py).
参考[论文](https://arxiv.org/pdf/1412.6806.pdf)[样例](./demo/guided_backpropagation.py)

```python
import keras
Expand Down Expand Up @@ -47,3 +49,45 @@ visualization = Image.fromarray(gradient)
| 梯度 | <img src="https://github.com/CyberZHG/keras-conv-vis/raw/master/samples/cat_gradient_relevant.jpg" width="224" height="224" /> |
| Deconvnet without Pooling Switches | <img src="https://github.com/CyberZHG/keras-conv-vis/raw/master/samples/cat_deconvnet_relevant.jpg" width="224" height="224" /> |
| Guided Backpropagation | <img src="https://github.com/CyberZHG/keras-conv-vis/raw/master/samples/cat_guided_relevant.jpg" width="224" height="224" /> |


## Grad-CAM


参考[论文](https://arxiv.org/abs/1610.02391)[样例](https://github.com/CyberZHG/keras-conv-vis/blob/master/demo/grad_cam.py)

```python
import keras
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

from keras_conv_vis import split_model_by_layer, get_gradient, Categorical

model = keras.applications.MobileNetV2()
# 将模型从最后一个卷积处分开,计算出中间结果
head, tail = split_model_by_layer(model, 'Conv_1')
last_conv_output = head(inputs)
# 给最后一个卷积计算梯度
gradient_model = keras.models.Sequential()
gradient_model.add(tail)
gradient_model.add(Categorical(284)) # ImageNet第284类是暹罗猫
gradients = get_gradient(gradient_model, last_conv_output)

# 计算Grad-CAM
gradient = gradients.numpy()[0]
gradient = np.mean(gradient, axis=(0, 1)) # 根据梯度计算每一层输出的权重
grad_cam = np.mean(last_conv_output.numpy()[0] * gradient, axis=-1) # 对卷积输出进行加权求和
grad_cam = grad_cam * (grad_cam > 0).astype(grad_cam.dtype) # Grad-CAM输出需要经过ReLU

# 可视化
grad_cam = (grad_cam - np.min(grad_cam)) / (np.max(grad_cam) - np.min(grad_cam) + 1e-4)
heatmap = plt.get_cmap('jet')(grad_cam, bytes=True)
heatmap = Image.fromarray(heatmap[..., :3], mode='RGB')
heatmap = heatmap.resize((original_image.width, original_image.height), resample=Image.BILINEAR)
visualization = Image.blend(original_image, heatmap, alpha=0.5)
```

| Input | Relevant CAM | Irrelevant CAM|
|:-:|:-:|:-:|
| <img src="https://github.com/CyberZHG/keras-conv-vis/raw/master/samples/cat.jpg" width="224" height="224" /> | <img src="https://github.com/CyberZHG/keras-conv-vis/raw/master/samples/cat_grad-cam_relevant.jpg" width="224" height="224" /> | <img src="https://github.com/CyberZHG/keras-conv-vis/raw/master/samples/cat_grad-cam_irrelevant.jpg" width="224" height="224" /> |
55 changes: 55 additions & 0 deletions demo/grad_cam.py
@@ -0,0 +1,55 @@
import os

from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

from keras_conv_vis import split_model_by_layer, get_gradient, Categorical
from keras_conv_vis.backend import keras

CLASS_CAT = 284
CLASS_GUITAR = 546

# Load an image
current_path = os.path.dirname(os.path.realpath(__file__))
sample_path = os.path.join(current_path, '..', 'samples')
image_path = os.path.join(sample_path, 'cat.jpg')
original_image = Image.open(image_path)
image = original_image.resize((224, 224))
inputs = np.expand_dims(np.array(image).astype(np.float) / 255.0, axis=0)
inputs = inputs * 2.0 - 1.0


def process(target_class,
cmap='jet',
alpha=0.5):
# Build model and get gradients
model = keras.applications.MobileNetV2()
head, tail = split_model_by_layer(model, 'Conv_1')
last_conv_output = head(inputs)
gradient_model = keras.models.Sequential()
gradient_model.add(tail)
gradient_model.add(Categorical(target_class))
gradients = get_gradient(gradient_model, last_conv_output)

# Calculate Grad-CAM
gradient = gradients.numpy()[0]
gradient = np.mean(gradient, axis=(0, 1))
grad_cam = np.mean(last_conv_output.numpy()[0] * gradient, axis=-1)
grad_cam = grad_cam * (grad_cam > 0).astype(grad_cam.dtype)

# Visualization
grad_cam = (grad_cam - np.min(grad_cam)) / (np.max(grad_cam) - np.min(grad_cam) + 1e-4)
heatmap = plt.get_cmap(cmap)(grad_cam, bytes=True)
heatmap = Image.fromarray(heatmap[..., :3], mode='RGB')
heatmap = heatmap.resize((original_image.width, original_image.height), resample=Image.BILINEAR)
return Image.blend(original_image, heatmap, alpha=alpha)


for target_class in [CLASS_CAT, CLASS_GUITAR]:
visualization = process(target_class)
cat_name = 'relevant'
if target_class != CLASS_CAT:
cat_name = 'irrelevant'
save_name = f'cat_grad-cam_{cat_name}.jpg'
visualization.save(os.path.join(sample_path, save_name))
79 changes: 67 additions & 12 deletions keras_conv_vis/gradient.py
Expand Up @@ -5,7 +5,7 @@

from .backend import keras

__all__ = ['Categorical', 'get_gradient']
__all__ = ['Categorical', 'get_gradient', 'split_model_by_layer']


class Categorical(keras.layers.Layer):
Expand All @@ -27,27 +27,82 @@ def get_config(self):
return dict(list(base_config.items()) + list(config.items()))


def get_gradient(model: keras.models.Model,
def get_gradient(model: Union[keras.models.Model, List[keras.models.Model]],
inputs: Union[np.ndarray, tf.Tensor, List[Union[np.ndarray, tf.Tensor]]],
targets: Optional[Union[tf.Tensor, List[tf.Tensor]]] = None):
"""Get the gradient of input, weights, of intermediate outputs.
:param model: The keras model.
:param inputs: The batched input data.
:param targets: The default is the input tensor.
:return: The gradients of the targets.
"""
models = model
if not isinstance(model, list):
models = [model]
if not isinstance(inputs, list):
inputs = [inputs]
for i, input_item in enumerate(inputs):
if isinstance(input_item, np.ndarray):
inputs[i] = tf.convert_to_tensor(input_item)
if len(inputs) == 1:
inputs = inputs[0]
input_targets = targets
if targets is None:
target_tensors = inputs
else:
target_tensors = []
if not isinstance(targets, list):
targets = [targets]
for target in targets:
target_tensors.append(target)
targets = []
if not isinstance(targets, list):
targets = [targets]
model_output = inputs
with tf.GradientTape() as tape:
for target in target_tensors:
for target in targets:
tape.watch(target)
model_output = model(inputs)
gradients = tape.gradient(model_output, target_tensors)
for i, model in enumerate(models):
if input_targets is None:
targets.append(model_output)
tape.watch(model_output)
model_output = model(model_output)
gradients = tape.gradient(model_output, targets)
if len(gradients) == 1:
gradients = gradients[0]
return gradients


def split_model_by_layer(model: keras.models.Model,
layer_cut: Union[str, keras.layers.Layer]):
"""Split a model into two parts.
:param model: The keras model.
:param layer_cut: The layer whose output will be cut. The layer must be a cut point,
a.k.a., the output edge is a bridge.
:return: The two models.
"""
if isinstance(layer_cut, str):
layer_cut = model.get_layer(layer_cut)
head = keras.models.Model(model.inputs, layer_cut.output)
meet = False
mappings, depends = {}, set()
for layer in model.layers:
if layer_cut is layer:
meet = True
mappings[layer.name] = keras.layers.Input(layer.output.shape[1:])
depends.add(layer.name)
continue
if not meet:
continue
inputs = layer.input
if not isinstance(inputs, list):
inputs = [inputs]
new_inputs = []
for input_tensor in inputs:
name = input_tensor.name.rsplit('/')[0]
depends.add(name)
new_inputs.append(mappings[name])
if not isinstance(layer.input, list):
new_inputs = new_inputs[0]
mappings[layer.name] = layer(new_inputs)
outputs = []
for layer in model.layers:
if layer.name in mappings and layer.name not in depends:
outputs.append(mappings[layer.name])
tail = keras.models.Model(mappings[layer_cut.name], outputs)
return head, tail
1 change: 1 addition & 0 deletions requirements-dev.txt
Expand Up @@ -6,3 +6,4 @@ nose
pycodestyle
coverage
Pillow
matplotlib
Binary file added samples/cat_grad-cam_irrelevant.jpg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added samples/cat_grad-cam_relevant.jpg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
11 changes: 10 additions & 1 deletion tests/test_get_gradient.py
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from keras_conv_vis import get_gradient, Categorical
from keras_conv_vis import get_gradient, Categorical, split_model_by_layer
from keras_conv_vis.backend import keras


Expand All @@ -17,3 +17,12 @@ def test_get_gradient(self):
get_gradient(gradient_model, np.random.random((1, 224, 224, 3)))
get_gradient(gradient_model, np.random.random((1, 224, 224, 3)),
targets=model.get_layer('bn_Conv1').trainable_weights[0])

def test_cut_model(self):
model = keras.applications.MobileNetV2()
head, tail = split_model_by_layer(model, 'block_5_add')
gradient_model = keras.models.Sequential()
gradient_model.add(tail)
gradient_model.add(Categorical(7))
gradients = get_gradient([head, gradient_model], np.random.random((1, 224, 224, 3)))
self.assertEqual(2, len(gradients))

0 comments on commit 3cfc143

Please sign in to comment.