# GradCAM Visualization Demo with ResNet50


In [1]:
# Replace the vanilla ReLU gradient with a guided ReLU gradient to get guided backpropagation.
import tensorflow as tf

from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_nn_ops

@ops.RegisterGradient("GuidedRelu")
def _GuidedReluGrad(op, grad):
    """Gradient function registered under the name "GuidedRelu".

    Args:
        op: the forward op whose gradient is being computed; only its first
            output (``op.outputs[0]``) is read.
        grad: the incoming gradient with respect to that output.

    Returns:
        A tensor shaped like ``grad``: the result of ``relu_grad`` wherever the
        incoming gradient is strictly positive, and zero everywhere else.
    """
    # Keep only positions where the upstream gradient is positive.
    upstream_is_positive = 0. < grad
    # Standard ReLU backprop of `grad` against the op's forward output.
    relu_backprop = gen_nn_ops.relu_grad(grad, op.outputs[0])
    suppressed = tf.zeros_like(grad)
    return tf.where(upstream_is_positive, relu_backprop, suppressed)

  from ._conv import register_converters as _register_converters


In [2]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
import tensorflow as tf
import numpy as np

from tensorflow.contrib.slim.python.slim.nets import resnet_v1
slim = tf.contrib.slim
from nets import resnet_v1

import utils

# Create mini-batch for demo

img1 = utils.load_image("./adv/FGSM_coyote.jpg")
img2 = utils.load_image("./input/coyote.jpg")
img3 = utils.load_image("./adv/FGSM_pandas.jpg")
img4 = utils.load_image("./input/pandas.jpg")

# Add a leading batch axis: (224, 224, 3) -> (1, 224, 224, 3).
batch1_img = img1.reshape((1, 224, 224, 3))
# One-hot target over 1000 classes selecting class index 271.
# NOTE(review): the original trailing comments named "Boxer" / "Shih-Tzu" for
# these indices, which looks stale (the images are coyote / panda files);
# confirm the intended class names against ./synset.txt.
batch1_label = np.array([1 if i == 271 else 0 for i in range(1000)])
batch1_label = batch1_label.reshape(1, -1)

# To visualize more images, uncomment the entries below and add them to the
# tuples passed to np.concatenate further down.
# batch2_img = img2.reshape((1, 224, 224, 3))
# batch2_label = np.array([1 if i == 272 else 0 for i in range(1000)])
# batch2_label = batch2_label.reshape(1, -1)

# batch3_img = img3.reshape((1, 224, 224, 3))
# batch3_label = np.array([1 if i == 368 else 0 for i in range(1000)])
# batch3_label = batch3_label.reshape(1, -1)

# batch4_img = img4.reshape((1, 224, 224, 3))
# batch4_label = np.array([1 if i == 388 else 0 for i in range(1000)])
# batch4_label = batch4_label.reshape(1, -1)
# batch_img = np.concatenate((batch1_img, batch2_img,batch3_img,batch4_img), 0)
# batch_label = np.concatenate((batch1_label, batch2_label,batch3_label,batch4_label), 0)

# np.concatenate takes a *sequence* of arrays. `(batch1_img)` is just the array
# itself (parentheses without a comma do not make a tuple), so NumPy would
# iterate over its first axis and silently drop the batch dimension, yielding
# shapes (224, 224, 3) and (1000,). The one-element tuples below keep the
# shapes at (1, 224, 224, 3) and (1, 1000).
batch_img = np.concatenate((batch1_img,), 0)
batch_label = np.concatenate((batch1_label,), 0)

# Derive the batch size from the data so it stays correct when more images
# are added above.
batch_size = batch_img.shape[0]

# batch_img = np.concatenate((batch1_img), 0)
# batch_label = np.concatenate((batch1_label), 0)
# batch_size = 1
# batch_img = np.expand_dims(batch_img, 0)
# batch_label = batch_label.reshape(batch_size, -1)

# Create tensorflow graph for evaluation
eval_graph = tf.Graph()
with eval_graph.as_default():
    # Ops of type 'Relu' created inside this context use the gradient function
    # registered as 'GuidedRelu' instead of the default ReLU gradient.
    with eval_graph.gradient_override_map({'Relu': 'GuidedRelu'}):
        images = tf.placeholder("float", [batch_size, 224, 224, 3])
        labels = tf.placeholder(tf.float32, [batch_size, 1000])

        preprocessed_images = utils.resnet_preprocess(images)

        with slim.arg_scope(resnet_v1.resnet_arg_scope()):
            with slim.arg_scope([slim.batch_norm], is_training=False):
                # is_training=False means batch-norm is not in training mode. Fixing batch norm layer.
                # net is logit for resnet_v1. See is_training messing up issue: https://github.com/tensorflow/tensorflow/issues/4887
                # NOTE(review): resnet_v1_50 is called without an explicit is_training
                # argument; depending on the resnet_v1 implementation in use, its own
                # default may re-scope batch_norm and override the outer arg_scope — confirm.
                net, end_points = resnet_v1.resnet_v1_50(preprocessed_images, 1000)

        prob = end_points['predictions'] # after softmax
        print('prob:', prob)

        print('lables:', labels)
        # Cross-entropy between the fed labels and the predicted probabilities,
        # summed over axis 1.
        cost = (-1) * tf.reduce_sum(tf.multiply(labels, tf.log(prob)), axis=1)
        print('cost:', cost)

        # Get last convolutional layer gradient for generating gradCAM visualization
        # print('endpoints:', end_points.keys())
        target_conv_layer = end_points['resnet_v1_50/block4/unit_2/bottleneck_v1']
        # target_conv_layer = end_points['resnet_v1_50/block3/unit_5/bottleneck_v1']

        # Gradient for partial linearization. We only care about the target
        # visualization class: multiplying by the labels and summing over axis 1
        # selects the logit(s) marked in `labels`.
        y_c = tf.reduce_sum(tf.multiply(net, labels), axis=1)
        print('y_c:', y_c)
        target_conv_layer_grad = tf.gradients(y_c, target_conv_layer)[0]
        print('target_conv_layer_grad:', target_conv_layer_grad)

        # Guided backpropagation back to the input layer
        gb_grad = tf.gradients(cost, images)[0]
        print('gb_grad:', gb_grad)

        init = tf.global_variables_initializer()

        # Load resnet v1 weights

        # latest_checkpoint = tf.train.latest_checkpoint("model/resnet_v1_50.ckpt")
        latest_checkpoint = "model/resnet_v1_50.ckpt"
        ## Optimistic restore: restore only the graph variables that exist in
        ## the checkpoint with a matching shape; everything else keeps the value
        ## produced by `init`.
        reader = tf.train.NewCheckpointReader(latest_checkpoint)
        saved_shapes = reader.get_variable_to_shape_map()
        variables_to_restore = tf.global_variables()

        # Pair each variable's full name with its checkpoint key (the name
        # without the ':<index>' suffix), computed once and reused below.
        named_vars = [(var.name, var.name.split(':')[0]) for var in variables_to_restore]
        for var_name, saved_var_name in named_vars:
            if saved_var_name not in saved_shapes:
                print("WARNING. Saved weight not exists in checkpoint. Init var:", var_name)

        var_names = sorted(pair for pair in named_vars if pair[1] in saved_shapes)
        restore_vars = []
        # reuse=True at the root scope makes get_variable look up the existing
        # variables by name instead of creating new ones.
        with tf.variable_scope('', reuse=True):
            for var_name, saved_var_name in var_names:
                try:
                    curr_var = tf.get_variable(saved_var_name)
                    var_shape = curr_var.get_shape().as_list()
                    if var_shape == saved_shapes[saved_var_name]:
                        # print("restore var:", saved_var_name)
                        restore_vars.append(curr_var)
                    else:
                        print("cannot restore var:", saved_var_name)
                except ValueError:
                    print("Ignore due to ValueError on getting var:", saved_var_name)
        saver = tf.train.Saver(restore_vars)
        
        
        
# Run tensorflow 

with tf.Session(graph=eval_graph) as sess:
    # Initialize every variable first, then overwrite the restorable ones with
    # the checkpoint values.
    sess.run(init)
    saver.restore(sess, latest_checkpoint)

    # First pass: class probabilities only (no labels required).
    prob_np = sess.run(prob, feed_dict={images: batch_img})
    print('prob_np:', prob_np.shape)

    # Second pass: logits, target-class score, guided-backprop gradient, and the
    # target conv layer's activations and gradients, using the one-hot labels.
    # (Alternative kept from the original: feed `labels: prob_np` and fetch
    # `cost` instead of `net`/`y_c` to use the predicted distribution as target.)
    fetches = [net, y_c, gb_grad, target_conv_layer, target_conv_layer_grad]
    feed = {images: batch_img, labels: batch_label}
    (net_np, y_c_np, gb_grad_value,
     target_conv_layer_value, target_conv_layer_grad_value) = sess.run(fetches, feed_dict=feed)

    # Report predictions and render the visualization for each batch item.
    for i in range(batch_size):
        utils.print_prob(prob_np[i], './synset.txt')
        utils.visualize(batch_img[i], target_conv_layer_value[i], target_conv_layer_grad_value[i], gb_grad_value[i])
    

ModuleNotFoundError: No module named 'nets'