In [1]:
import torch
from torch.autograd import Variable
from torch.autograd import Function
from torchvision import transforms
from torchvision import utils
from PIL import Image
import sys
import numpy as np
import argparse
import cv2

In [2]:
class FeatureExtractor():
    """Run a model child-by-child, capturing activations and their gradients.

    The model's direct children (``model._modules``) are applied in
    registration order. For every child whose name is in ``target_layers``
    the output tensor is recorded and a backward hook is registered on it,
    so that its gradient is appended to ``self.gradients`` during a later
    ``backward()`` call.

    Args:
        model: module whose direct children are executed sequentially.
        target_layers: collection of child names whose outputs are captured.
    """
    def __init__(self, model, target_layers):
        self.model = model
        self.target_layers = target_layers
        self.gradients = []

    def save_gradient(self, grad):
        """Backward hook: store the gradient of a captured activation."""
        self.gradients.append(grad)

    def __call__(self, x):
        """Forward ``x`` through every child module.

        Returns:
            (outputs, x): ``outputs`` is the list of captured activations in
            forward order; ``x`` is the output of the last child module.

        Note:
            Hooks fire while backpropagating, so with several target layers
            ``self.gradients`` is filled in the reverse of forward order.
        """
        outputs = []
        # Drop gradients left over from a previous forward/backward pass.
        self.gradients = []
        for name, module in self.model._modules.items():
            x = module(x)
            if name == 'avgpool':
                # Flatten the pooled map to (batch, features) for the layers
                # that follow. Deriving both sizes from the tensor (instead of
                # the former hard-coded (1, 512)) keeps this working for any
                # batch size and channel count.
                x = torch.reshape(x, (x.size(0), -1))
            if name in self.target_layers:
                x.register_hook(self.save_gradient)
                outputs.append(x)
        return outputs, x

In [3]:
class ModelOutputs():
    """Forward-pass wrapper around a FeatureExtractor.

    Calling an instance yields the activations captured at the targeted
    intermediate layers together with the network output flattened to
    (batch, -1); ``get_gradients`` exposes the gradients the extractor
    collected for those layers.
    """
    def __init__(self, model, target_layers):
        self.model = model
        self.feature_extractor = FeatureExtractor(self.model, target_layers)

    def get_gradients(self):
        """Return the extractor's list of recorded gradients."""
        return self.feature_extractor.gradients

    def __call__(self, x):
        """Run the extractor on ``x``; return (activations, 2-D output)."""
        activations, logits = self.feature_extractor(x)
        batch_size = logits.size(0)
        return activations, logits.view(batch_size, -1)

In [4]:
def preprocess_image(img):
    """Turn an image into a tensor for the network.

    Applies, in order, torchvision's ``Resize(270)``, ``CenterCrop(256)``
    and ``ToTensor()`` to ``img`` and returns the result.
    """
    steps = [
        transforms.Resize(270),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
    ]
    pipeline = transforms.Compose(steps)
    return pipeline(img)


In [5]:
def show_cam_on_image(img, mask, output_path="cam.jpg"):
    """Blend a class-activation mask over an image and write it to disk.

    Args:
        img: H x W x 3 array that is added to the colour-mapped mask.
            NOTE(review): cv2.applyColorMap / cv2.imwrite work in BGR channel
            order, so ``img`` should be BGR as well - confirm at call sites.
        mask: H x W array; it is scaled by 255 and cast to uint8, so values
            are presumably expected in [0, 1].
        output_path: destination file; defaults to the previously hard-coded
            "cam.jpg".

    Raises:
        IOError: if OpenCV reports that the file could not be written.
    """
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + np.float32(img)
    # Rescale so the brightest value maps to 255; skip when everything is
    # zero to avoid a 0/0 division producing NaNs.
    peak = np.max(cam)
    if peak > 0:
        cam = cam / peak
    # cv2.imwrite signals failure through its return value, not an exception.
    if not cv2.imwrite(output_path, np.uint8(255 * cam)):
        raise IOError("Failed to write CAM overlay to {!r}".format(output_path))

In [15]:
class GradCam:
    """Grad-CAM: gradient-weighted class activation map for one image.

    Args:
        model: network to explain; it is switched to eval mode.
        target_layer_names: names of the model's direct children whose
            activations are captured; the map is built from the last one.
        use_cuda: when true, the model and every input are moved to the GPU.
    """

    def __init__(self, model, target_layer_names, use_cuda):
        self.model = model
        self.model.eval()
        self.cuda = use_cuda
        if self.cuda:
            self.model = model.cuda()

        self.extractor = ModelOutputs(self.model, target_layer_names)

    def forward(self, input):
        """Plain forward pass through the wrapped model."""
        return self.model(input)

    def __call__(self, input, index=None):
        """Compute the activation map for the first sample of ``input``.

        Args:
            input: batch tensor; assumed to be laid out (N, C, H, W), since
                its last two dimensions are used as the output map size.
            index: class index to explain. ``None`` selects the
                highest-scoring class of the first sample.

        Returns:
            float32 array of shape (H, W), min-max scaled to [0, 1]; all
            zeros when the chosen class has no positive evidence.
        """
        if self.cuda:
            input = input.cuda()
        features, output = self.extractor(input)

        if index is None:
            # Restrict the argmax to sample 0 so the index is always a valid
            # class id, even for batches larger than one.
            index = int(torch.argmax(output[0]))

        # Score of the requested class for the first sample; its gradient is
        # what the hooks registered by the extractor record.
        score = output[0, index]

        # Clear parameter gradients on the whole model rather than assuming
        # it exposes an ``fc`` attribute.
        self.model.zero_grad()
        score.backward(retain_graph=True)

        # Backward hooks fire from the output towards the input, so the
        # gradient of the LAST captured activation is the FIRST one recorded.
        grads_val = self.extractor.get_gradients()[0].detach().cpu().numpy()
        target = features[-1].detach().cpu().numpy()[0]

        # Channel weights: spatial average of the gradients of sample 0.
        weights = np.mean(grads_val, axis=(2, 3))[0]

        # Weighted sum over channels, then ReLU.
        cam = np.tensordot(weights, target, axes=1).astype(np.float32)
        cam = np.maximum(cam, 0)

        # cv2.resize takes (width, height); follow the input's own spatial
        # size instead of a fixed 256 x 256.
        height, width = int(input.shape[-2]), int(input.shape[-1])
        cam = cv2.resize(cam, (width, height))

        cam = cam - np.min(cam)
        peak = np.max(cam)
        if peak > 0:
            # Skipped for a flat map, where dividing would yield NaNs.
            cam = cam / peak
        return cam

In [16]:
# Driver cell: load the trained network, compute a Grad-CAM mask for one test
# image and save the overlay through show_cam_on_image.

# NOTE(security): torch.load unpickles arbitrary Python objects - only load
# checkpoints that come from a trusted source.
# Fall back to the CPU when no GPU is present instead of failing on 'cuda'.
use_cuda = torch.cuda.is_available()
cnn = torch.load("resnet_best", map_location="cuda" if use_cuda else "cpu")
grad_cam = GradCam(model = cnn, target_layer_names = ["layer2"], use_cuda=use_cuda)

pil_img = Image.open("../LSTM_AFDB/dwt/testdata/ecg1.png").convert('RGB')
# `input` shadows the builtin; the name is kept in case later cells read it.
input = preprocess_image(pil_img)
# H x W x 3 float array in RGB order (the PIL image was converted to 'RGB').
img = input.numpy().transpose(1,2,0)

# Class index to explain; GradCam picks the top-scoring class when this is None.
target_index = 0
mask = grad_cam(input.unsqueeze_(0), target_index)

# OpenCV's colour map and imwrite use BGR channel order, so reverse the
# channels before blending the image with the heat map.
img_bgr = np.ascontiguousarray(img[:, :, ::-1])
show_cam_on_image(img_bgr, mask)