In [1]:
from PIL import Image
import numpy as np
import torch
# import resnet50_128ASC as ASC_model
import resnet50_128BSC as BSC_model
# import resnet50_128BSG as BSG_model
# import resnet50_128ASG as ASG_model
from misc_functions import get_example_params, save_class_activation_images

In [2]:
class CamExtractor():
    """
        Extracts cam features from the model.

        The wrapped model is driven step by step through its
        ``grad_layer(x, y, j)`` method. A gradient hook is attached to the
        output of step ``target_layer``, so that both its activations and the
        gradient flowing back into them are available after ``backward()``.
    """
    def __init__(self, model, target_layer, num_layers=16):
        """
        Args:
            model: object exposing ``grad_layer(x, y, j) -> (x, y)`` and
                ``classifier(x)`` (the only members used here).
            target_layer: index ``j`` of the ``grad_layer`` step whose output
                is captured and hooked.
            num_layers: number of ``grad_layer`` steps to run. Defaults to 16,
                the value that was previously hard-coded.
        """
        self.model = model
        self.target_layer = target_layer
        self.num_layers = num_layers
        # Filled in by save_gradient() during the backward pass.
        self.gradients = None

    def save_gradient(self, grad):
        """Tensor hook: remember the gradient w.r.t. the target layer output."""
        self.gradients = grad

    def forward_pass_on_convolutions(self, x):
        """
            Does a forward pass on convolutions, hooks the function at given layer.

            Returns:
                (conv_output, x): the output of step ``target_layer`` and the
                output of the last ``grad_layer`` step.

            Raises:
                ValueError: if ``target_layer`` is not one of the executed steps
                    ``0 .. num_layers - 1``. Previously this case silently
                    returned ``None`` and failed later, far from the cause.
        """
        conv_output = None
        # Second tensor threaded through grad_layer, initialised to the input;
        # its meaning is defined by the model (NOTE(review): confirm).
        y = x
        for j in range(self.num_layers):
            x, y = self.model.grad_layer(x, y, j)
            if j == self.target_layer:
                # register_hook requires x to be part of an autograd graph.
                x.register_hook(self.save_gradient)
                conv_output = x  # Save the convolution output on that layer
        if conv_output is None:
            raise ValueError(
                "target_layer=%r was never reached; expected an index in "
                "[0, %d)" % (self.target_layer, self.num_layers))
        return conv_output, x

    def forward_pass(self, x):
        """
            Does a full forward pass on the model.

            Returns:
                (conv_output, model_output): the hooked target-layer output and
                the result of ``model.classifier`` on the last step's output.
        """
        # Forward pass on the convolutions
        conv_output, x = self.forward_pass_on_convolutions(x)
        x = self.model.classifier(x)
        return conv_output, x

In [3]:
class GradCam():
    """
        Produces a Grad-CAM class activation map using CamExtractor.
    """
    def __init__(self, model, target_layer):
        """
        Args:
            model: model exposing ``grad_layer``/``classifier`` (see
                CamExtractor) and ``zero_grad``; it is switched to eval mode here.
            target_layer: index of the ``grad_layer`` step to visualize.
        """
        self.model = model
        self.model.eval()
        # Define extractor
        self.extractor = CamExtractor(self.model, target_layer)

    def generate_cam(self, input_image, target_class=None):
        """
        Compute a Grad-CAM heat map for ``input_image``.

        Args:
            input_image: image tensor whose ``shape[2]`` and ``shape[3]`` give
                the output map's height and width (assumes NCHW with batch
                size 1 -- TODO confirm against the caller's preprocessing).
            target_class: currently ignored. The arg-max of the model output is
                always used, which keeps existing callers' behaviour.
                NOTE(review): honour this argument once callers pass class
                indices that are valid for this model.

        Returns:
            numpy float array of shape (H, W) with values in [0, 1].
        """
        # conv_output: activations at the target layer.
        # model_output: classifier output, assumed (1, num_classes) -- TODO confirm.
        conv_output, model_output = self.extractor.forward_pass(input_image)
        scores = model_output.detach().cpu().numpy()
        target_class = int(np.argmax(scores))

        # One-hot gradient seed that selects the target score; zeros_like
        # matches the output's shape, dtype and device.
        one_hot_output = torch.zeros_like(model_output)
        one_hot_output[0][target_class] = 1
        self.model.zero_grad()
        model_output.backward(gradient=one_hot_output, retain_graph=True)

        # Gradients captured by the hook, and the activations, for sample 0.
        guided_gradients = self.extractor.gradients.detach().cpu().numpy()[0]
        target = conv_output.detach().cpu().numpy()[0]

        # Channel weights = spatially averaged gradients.
        weights = np.mean(guided_gradients, axis=(1, 2))
        # Weighted sum of the activation maps over channels, starting from
        # zero. The previous np.ones start shifted the map by +1 before the
        # ReLU, letting slightly negative regions survive.
        cam = np.tensordot(weights, target, axes=1).astype(np.float32)
        cam = np.maximum(cam, 0)

        cam_min, cam_max = cam.min(), cam.max()
        if cam_max > cam_min:
            cam = (cam - cam_min) / (cam_max - cam_min)  # Normalize to 0-1
        else:
            # Flat map (e.g. everything clipped by the ReLU): avoid 0/0 -> NaN.
            cam = np.zeros_like(cam)
        cam = np.uint8(cam * 255)  # Scale to 0-255 for PIL

        # PIL sizes are (width, height); the tensor is (N, C, H, W).
        height, width = input_image.shape[2], input_image.shape[3]
        resized = Image.fromarray(cam).resize((width, height), Image.LANCZOS)
        cam = np.asarray(resized, dtype=np.uint8) / 255

        return cam

In [6]:
def GRAD(model, key, target_example=0, target_layer=8,
         weights_path='./model/resnet50_128.pth'):
    """
    Run Grad-CAM for one model variant and save the activation images.

    Args:
        model: module providing ``resnet50_128(weights_path=...)``
            (e.g. BSC_model).
        key: variant label forwarded to save_class_activation_images.
        target_example: index passed to get_example_params. The default of 0
            was labelled "Snake" in the original notebook.
        target_layer: grad_layer step to visualize (default 8).
        weights_path: checkpoint path passed to ``model.resnet50_128``.
    """
    # get_example_params also returns a pretrained model; it is discarded
    # here in favour of the variant's resnet50_128.
    (original_image, prep_img, target_class, file_name_to_export, _) = \
        get_example_params(target_example)
    pretrained_model = model.resnet50_128(weights_path=weights_path)

    grad_cam = GradCam(pretrained_model, target_layer)
    # Generate the cam mask
    cam = grad_cam.generate_cam(prep_img, target_class)
    # Save mask
    save_class_activation_images(key, original_image, cam, file_name_to_export,
                                 str(target_layer))

In [7]:
# Grad-CAM for the BSC variant only. To also run BSG, ASG or ASC, re-enable
# the matching import in the first cell and call GRAD with that module/key.
GRAD(model=BSC_model, key='BSC')