Keras Convolution Visualization
[中文 |English ]
pip install git+https://github.com/cyberzhg/keras-conv-vis
只在启用eager execution的情况下可以使用。
参考论文 和样例 。
import keras
import numpy as np
from PIL import Image
from keras_conv_vis import replace_relu , get_gradient , Categorical
model = keras .applications .MobileNetV2 ()
# 将模型中所有的ReLU替换为所需的特殊反向传播
model = replace_relu (model , relu_type = 'guided' )
gradient_model = keras .models .Sequential ()
gradient_model .add (model )
# 只让特定的类别传递梯度
gradient_model .add (Categorical (284 )) # ImageNet第284类是暹罗猫
# 获取输入的梯度
gradients = get_gradient (gradient_model , inputs )
# 将得到梯度归一化和可视化
gradient = gradients .numpy ()[0 ]
gradient = (gradient - np .min (gradient )) / (np .max (gradient ) - np .min (gradient ) + 1e-4 )
gradient = (gradient * 255.0 ).astype (np .uint8 )
visualization = Image .fromarray (gradient )
类别
可视化
输入
梯度
Deconvnet without Pooling Switches
Guided Backpropagation
参考论文 和样例 。
import keras
import numpy as np
import matplotlib .pyplot as plt
from PIL import Image
from keras_conv_vis import split_model_by_layer , get_gradient , Categorical
model = keras .applications .MobileNetV2 ()
# 将模型从最后一个卷积处分开,计算出中间结果
head , tail = split_model_by_layer (model , 'Conv_1' )
last_conv_output = head (inputs )
# 给最后一个卷积计算梯度
gradient_model = keras .models .Sequential ()
gradient_model .add (tail )
gradient_model .add (Categorical (284 )) # ImageNet第284类是暹罗猫
gradients = get_gradient (gradient_model , last_conv_output )
# 计算Grad-CAM
gradient = gradients .numpy ()[0 ]
gradient = np .mean (gradient , axis = (0 , 1 )) # 根据梯度计算每一层输出的权重
grad_cam = np .mean (last_conv_output .numpy ()[0 ] * gradient , axis = - 1 ) # 对卷积输出进行加权求和
grad_cam = grad_cam * (grad_cam > 0 ).astype (grad_cam .dtype ) # Grad-CAM输出需要经过ReLU
# 可视化
grad_cam = (grad_cam - np .min (grad_cam )) / (np .max (grad_cam ) - np .min (grad_cam ) + 1e-4 )
heatmap = plt .get_cmap ('jet' )(grad_cam , bytes = True )
heatmap = Image .fromarray (heatmap [..., :3 ], mode = 'RGB' )
heatmap = heatmap .resize ((original_image .width , original_image .height ), resample = Image .BILINEAR )
visualization = Image .blend (original_image , heatmap , alpha = 0.5 )
Input
Relevant CAM
Irrelevant CAM