Skip to content
This repository has been archived by the owner on Mar 3, 2024. It is now read-only.

Latest commit

 

History

History
93 lines (71 loc) · 4.05 KB

README.zh-CN.md

File metadata and controls

93 lines (71 loc) · 4.05 KB

Keras Convolution Visualization

Travis Coverage

[中文|English]

安装

pip install git+https://github.com/cyberzhg/keras-conv-vis

只在启用eager execution的情况下可以使用。

Guided Backpropagation

参考论文样例

import keras
import numpy as np
from PIL import Image

from keras_conv_vis import replace_relu, get_gradient, Categorical

# NOTE(review): `inputs` is not defined in this example; presumably a
# preprocessed image batch accepted by MobileNetV2 — prepare it beforehand.
model = keras.applications.MobileNetV2()
# Replace every ReLU in the model with the requested special backward pass
# (relu_type='guided' selects guided backpropagation).
model = replace_relu(model, relu_type='guided')
gradient_model = keras.models.Sequential()
gradient_model.add(model)
# Let only the selected class propagate gradients.
gradient_model.add(Categorical(284))  # ImageNet class 284 is "Siamese cat"
# Get the gradient with respect to the inputs.
gradients = get_gradient(gradient_model, inputs)

# Normalize the gradient and turn it into an image:
# take the first sample of the batch, min-max scale it to [0, 1]
# (the 1e-4 term avoids division by zero), then convert to uint8 [0, 255].
gradient = gradients.numpy()[0]
gradient = (gradient - np.min(gradient)) / (np.max(gradient) - np.min(gradient) + 1e-4)
gradient = (gradient * 255.0).astype(np.uint8)
visualization = Image.fromarray(gradient)
类别 可视化
输入
梯度
Deconvnet without Pooling Switches
Guided Backpropagation

Grad-CAM

参考论文样例

import keras
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

from keras_conv_vis import split_model_by_layer, get_gradient, Categorical

# NOTE(review): `inputs` (an input batch for MobileNetV2) and `original_image`
# (a PIL image; Image.blend requires it to have the same mode as the RGB
# heatmap below) are not defined in this example and must be prepared first.
model = keras.applications.MobileNetV2()
# Split the model at the last convolution ('Conv_1') and compute the
# intermediate feature maps.
head, tail = split_model_by_layer(model, 'Conv_1')
last_conv_output = head(inputs)
# Compute the gradient of the target class w.r.t. the last convolution output.
gradient_model = keras.models.Sequential()
gradient_model.add(tail)
gradient_model.add(Categorical(284))  # ImageNet class 284 is "Siamese cat"
gradients = get_gradient(gradient_model, last_conv_output)

# Compute Grad-CAM.
gradient = gradients.numpy()[0]
# Per-channel weights: average the gradient over the first two (spatial) axes.
gradient = np.mean(gradient, axis=(0, 1))
# Weighted *sum* of the feature maps over the channel axis, as Grad-CAM defines
# it. A mean would only rescale by 1/C, shrinking the map relative to the
# fixed 1e-4 epsilon used in the normalization below.
grad_cam = np.sum(last_conv_output.numpy()[0] * gradient, axis=-1)
# Grad-CAM keeps only positive evidence (ReLU).
grad_cam = np.maximum(grad_cam, 0)

# Visualize: min-max scale to [0, 1] (1e-4 avoids division by zero), map
# through the 'jet' colormap, resize to the original image and overlay it.
grad_cam = (grad_cam - np.min(grad_cam)) / (np.max(grad_cam) - np.min(grad_cam) + 1e-4)
heatmap = plt.get_cmap('jet')(grad_cam, bytes=True)  # RGBA, uint8
heatmap = Image.fromarray(heatmap[..., :3], mode='RGB')  # drop the alpha channel
heatmap = heatmap.resize((original_image.width, original_image.height), resample=Image.BILINEAR)
visualization = Image.blend(original_image, heatmap, alpha=0.5)
Input Relevant CAM Irrelevant CAM