In [1]:
import os

import mxnet as mx
import mxnet.ndarray as nd
from mxnet.gluon.block import HybridBlock
from mxnet.gluon import nn
from mxnet.gluon.model_zoo import model_store
from mxnet import autograd

from viz.layers import Activation, Conv2D

import numpy as np
import utils
import cv2

In [2]:
class AlexNet(HybridBlock):
    """AlexNet assembled from the ``viz.layers`` ``Conv2D``/``Activation`` wrappers.

    ``features`` holds five convolution + ReLU stages (max-pooled after the
    first, second and fifth), a ``Flatten``, and two
    ``Dense(4096)`` + ReLU + ``Dropout(0.5)`` stages. ``output`` is a final
    ``Dense`` layer with one unit per class.

    Parameters
    ----------
    classes : int, default 1000
        Number of units in the final ``Dense`` layer.
    **kwargs
        Passed through unchanged to ``HybridBlock.__init__``.
    """

    def __init__(self, classes=1000, **kwargs):
        super(AlexNet, self).__init__(**kwargs)
        with self.name_scope():
            self.features = nn.HybridSequential(prefix='')
            with self.features.name_scope():
                # The layers are instantiated inside the name scope, in network
                # order. NOTE(review): auto-generated parameter names presumably
                # depend on this creation order (pretrained weights are matched
                # by name) -- keep the order stable.
                layer_stack = [
                    # Stage 1
                    Conv2D(64, kernel_size=11, strides=4, padding=2),
                    Activation('relu'),
                    nn.MaxPool2D(pool_size=3, strides=2),
                    # Stage 2
                    Conv2D(192, kernel_size=5, padding=2),
                    Activation('relu'),
                    nn.MaxPool2D(pool_size=3, strides=2),
                    # Stage 3
                    Conv2D(384, kernel_size=3, padding=1),
                    Activation('relu'),
                    # Stage 4
                    Conv2D(256, kernel_size=3, padding=1),
                    Activation('relu'),
                    # Stage 5
                    Conv2D(256, kernel_size=3, padding=1),
                    Activation('relu'),
                    nn.MaxPool2D(pool_size=3, strides=2),
                    # Fully connected part
                    nn.Flatten(),
                    nn.Dense(4096),
                    Activation('relu'),
                    nn.Dropout(0.5),
                    nn.Dense(4096),
                    Activation('relu'),
                    nn.Dropout(0.5),
                ]
                for layer in layer_stack:
                    self.features.add(layer)

            self.output = nn.Dense(classes)

    def hybrid_forward(self, F, x):
        """Run ``x`` through ``features`` and then through ``output``."""
        return self.output(self.features(x))

# Constructor
def alexnet(pretrained=False, ctx=mx.cpu(),
            root=os.path.join('~', '.mxnet', 'models'), **kwargs):
    r"""Build an :class:`AlexNet`, optionally loading pretrained weights.

    Architecture from the `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.

    Parameters
    ----------
    pretrained : bool, default False
        When true, load the parameter file that ``model_store`` resolves for
        ``'alexnet'`` into the freshly built network.
    ctx : Context, default CPU
        Context passed to ``load_params`` when ``pretrained`` is true.
    root : str, default '~/.mxnet/models'
        Directory handed to ``model_store.get_model_file``.
    **kwargs
        Forwarded to the :class:`AlexNet` constructor.

    Returns
    -------
    AlexNet
        The constructed network.
    """
    net = AlexNet(**kwargs)
    if not pretrained:
        return net

    params_path = model_store.get_model_file('alexnet', root=root)
    net.load_params(params_path, ctx=ctx)
    return net

In [3]:
# Build the pretrained network. Later cells refer to the model as `alexnet`,
# so the result intentionally takes over the factory function's name. Because
# of that rebinding, running this cell a second time would call the model
# instance with `pretrained=True`; the guard makes the cell safe to re-run.
if not isinstance(alexnet, AlexNet):
    alexnet = alexnet(pretrained=True)

In [4]:
def preprocess(data):
    """Turn a decoded image into a normalised, channel-first float32 array.

    Steps, in order: resize to 256x256, centre-crop to 224x224, cast to
    float32, divide by 255, normalise each channel with the mean/std triples
    below, and move the last axis to the front.

    Parameters
    ----------
    data : NDArray
        Decoded image. NOTE(review): assumed to be height x width x channel
        with three channels, as implied by the 3-element mean/std and the
        final transpose -- confirm against callers.

    Returns
    -------
    NDArray
        The preprocessed image with axes reordered by ``(2, 0, 1)``.
    """
    resized = mx.image.imresize(data, 256, 256)
    cropped, _ = mx.image.center_crop(resized, (224, 224))

    # Bring pixel values into [0, 1] before the per-channel normalisation.
    scaled = cropped.astype(np.float32) / 255

    channel_mean = mx.nd.array([0.485, 0.456, 0.406])
    channel_std = mx.nd.array([0.229, 0.224, 0.225])
    normalised = mx.image.color_normalize(scaled,
                                          mean=channel_mean,
                                          std=channel_std)

    return mx.nd.transpose(normalised, (2, 0, 1))

In [5]:
# Read the encoded image from disk as raw bytes.
with open("img/snake.jpg", 'rb') as fp:
    str_image = fp.read()

# Decode, preprocess and add a leading batch axis of size 1.
image = mx.img.imdecode(str_image)
image = preprocess(image)
image = image.expand_dims(axis=0)

# Record the forward pass so gradients can be back-propagated later.
# record() defaults to train_mode=True, which would keep the Dropout layers
# active and make both the prediction and the gradients random from run to
# run; train_mode=False records the graph while running in inference mode.
with autograd.record(train_mode=False):
    out = alexnet(image)

# Index of the highest-scoring class for the single image in the batch.
out.argmax(axis=1)


[ 56.]
<NDArray 1 @cpu(0)>

In [6]:
# For every output stored in Conv2D.outputs, show its name, its shape and the
# shape of its gradient buffer.
for layer_name, layer_output in Conv2D.outputs.items():
    print(layer_name, layer_output.shape, layer_output.grad.shape)

alexnet0_conv2d0 (1, 64, 55, 55) (1, 64, 55, 55)
alexnet0_conv2d1 (1, 192, 27, 27) (1, 192, 27, 27)
alexnet0_conv2d2 (1, 384, 13, 13) (1, 384, 13, 13)
alexnet0_conv2d3 (1, 256, 13, 13) (1, 256, 13, 13)
alexnet0_conv2d4 (1, 256, 13, 13) (1, 256, 13, 13)


In [7]:
def save_class_activation_on_image(org_img, activation_map, file_name):
    """Write an activation map as grayscale, heatmap and overlay JPEG files.

    Three files are written, named ``file_name`` followed by
    ``_Cam_Grayscale.jpg``, ``_Cam_Heatmap.jpg`` and ``_Cam_On_Image.jpg``.

    Parameters
    ----------
    org_img : numpy.ndarray
        Image to blend with the heatmap. It is resized to the spatial size of
        ``activation_map`` before blending. NOTE(review): the blend is a plain
        sum with the 0-255 heatmap, so this is presumably expected on the same
        scale and in OpenCV's channel order -- confirm against callers.
    activation_map : numpy.ndarray
        Map passed to ``cv2.applyColorMap``; its first two axes define the
        output size.
    file_name : str
        Path prefix for the three output files.
    """
    # Grayscale activation map
    cv2.imwrite(file_name + '_Cam_Grayscale.jpg', activation_map)

    # Heatmap of activation map
    activation_heatmap = cv2.applyColorMap(activation_map, cv2.COLORMAP_HSV)
    cv2.imwrite(file_name + '_Cam_Heatmap.jpg', activation_heatmap)

    # Heatmap on picture. Match the picture to the map's own size instead of a
    # fixed 224x224 so the sum below is valid for any map size; cv2.resize
    # takes the target size as (width, height).
    map_height, map_width = activation_map.shape[:2]
    org_img = cv2.resize(org_img, (map_width, map_height))
    img_with_heatmap = np.float32(activation_heatmap) + np.float32(org_img)

    # Scale so the brightest value becomes 1; skip the division when there is
    # no positive peak to avoid dividing by zero.
    peak = np.max(img_with_heatmap)
    if peak > 0:
        img_with_heatmap = img_with_heatmap / peak
    # Clip so negative inputs cannot wrap around in the uint8 conversion.
    img_with_heatmap = np.clip(img_with_heatmap, 0, 1)
    cv2.imwrite(file_name + '_Cam_On_Image.jpg', np.uint8(255 * img_with_heatmap))

In [8]:
# Snapshot of the last convolution's stored output before the backward pass.
conv_output = Conv2D.outputs["alexnet0_conv2d4"].asnumpy()

# Pick the class with the highest score in the network output.
model_output = out.asnumpy()
target_class = np.argmax(model_output)

# Back-propagate a one-hot vector so only the chosen class contributes
# gradient; parameter gradients are cleared first.
one_hot_output = mx.nd.one_hot(mx.nd.array([target_class]), 1000)
alexnet.collect_params().zero_grad()
out.backward(one_hot_output)

# Gradients and activations of the last convolution for the first batch item.
last_conv = Conv2D.outputs['alexnet0_conv2d4']
guided_gradients = last_conv.grad[0].asnumpy()
target = last_conv[0].asnumpy()

# One weight per channel: the gradient averaged over axes 1 and 2.
weights = guided_gradients.mean(axis=(1, 2))

In [9]:
# Shapes of the activations, their gradients and the per-channel weights.
for array in (target, guided_gradients, weights):
    print(array.shape)

(256, 13, 13)
(256, 13, 13)
(256,)


In [10]:
# List every network parameter together with the shape of its gradient.
for param_name, param in alexnet.collect_params().items():
    print(param_name, param.grad().shape)

alexnet0_conv0_weight (64, 3, 11, 11)
alexnet0_conv0_bias (64,)
alexnet0_conv1_weight (192, 64, 5, 5)
alexnet0_conv1_bias (192,)
alexnet0_conv2_weight (384, 192, 3, 3)
alexnet0_conv2_bias (384,)
alexnet0_conv3_weight (256, 384, 3, 3)
alexnet0_conv3_bias (256,)
alexnet0_conv4_weight (256, 256, 3, 3)
alexnet0_conv4_bias (256,)
alexnet0_dense0_weight (4096, 9216)
alexnet0_dense0_bias (4096,)
alexnet0_dense1_weight (4096, 4096)
alexnet0_dense1_bias (4096,)
alexnet0_dense2_weight (1000, 4096)
alexnet0_dense2_bias (1000,)


In [11]:
# Weighted sum of the activation channels, one weight per channel.
# NOTE(review): the accumulator starts at ones rather than zeros, which adds a
# constant offset of 1 to the weighted sum -- confirm this is intended.
cam = np.ones(target.shape[1:], dtype=np.float32)

for channel_weight, channel_map in zip(weights, target):
    cam += channel_weight * channel_map

In [12]:


    
# Upsample the map to 224x224 (the size of the cropped network input).
cam = cv2.resize(cam, (224, 224))
# Discard negative values.
cam = np.maximum(cam, 0)

# Min-max scale to [0, 1]. A constant map has zero range, which would turn
# the division into 0/0 (NaN) and make the uint8 cast meaningless, so it is
# mapped to all zeros instead.
cam_min = np.min(cam)
cam_span = np.max(cam) - cam_min
if cam_span > 0:
    cam = (cam - cam_min) / cam_span
else:
    cam = np.zeros_like(cam)

# 8-bit image, as needed for colour-mapping and saving.
cam = np.uint8(cam * 255)



In [15]:
# `image` is the mean/std-normalised, channel-first tensor that was fed to the
# network, so its values are not on the 0-255 scale of the heatmap and would
# be practically invisible in the overlay. Undo the normalisation applied in
# preprocess() (x * std + mean), rescale to 8-bit, and switch from RGB
# (mx.image.imdecode's default) to the BGR order that cv2.imwrite expects.
org_img = image[0].transpose((1,2,0)).asnumpy()
org_img = (org_img * np.array([0.229, 0.224, 0.225], dtype=np.float32)
           + np.array([0.485, 0.456, 0.406], dtype=np.float32))
org_img = np.round(np.clip(org_img, 0, 1) * 255).astype(np.uint8)
org_img = cv2.cvtColor(org_img, cv2.COLOR_RGB2BGR)
save_class_activation_on_image(org_img, cam, "gradcam")