In [12]:
import torch
from torch.autograd import Variable
from torch.autograd import Function
from torchvision import models
from torchvision import utils
import cv2
import sys
import numpy as np
import argparse
import torch.nn as nn
import pdb 
import torch.utils.data as data
import h5py
from torchvision.transforms import transforms
import matplotlib.pyplot as plt
import PIL.Image as Image
%matplotlib inline

import random

In [13]:
## Binary target: is there a near-collision within the next two seconds, or not?

class FrameDataset(data.Dataset):
    """Dataset of frame sequences with a two-way near-collision target.

    Args:
        f: indexable store exposing ``f["rgb"]`` and ``f["labels"]`` (an open
            ``h5py.File`` in this notebook).
        transform: optional callable applied to every frame; its result is
            written into a ``(3, 224, 224)`` slot, so it must produce that shape.
        test: when true, the random horizontal flip is never applied.

    Each item is ``(rgb, t_rgb, t_label)``:
        * ``rgb``    - the raw frames as a numpy array (flipped in place when the
          augmentation fires).
        * ``t_rgb``  - float tensor ``(num_frames, 3, 224, 224)`` of transformed
          frames; stays all-zero when ``transform`` is ``None``.
        * ``t_label`` - one-hot tensor of size 2; index 0 = near-collision,
          index 1 = no near-collision.
    """

    def __init__(self, f, transform=None, test=False):
        self.f = f
        self.transform = transform
        self.test = test

    def __getitem__(self, index):
        # presumably rgb is (num_frames, H, W, 3) uint8 -- TODO confirm against the h5 file
        rgb = np.array(self.f["rgb"][index])
        label = np.array(self.f["labels"][index], dtype=np.uint8)

        t_label = torch.zeros(2)
        if (label[0] or label[1]):
            t_label[0] = 1 ## Near-collision within next 2 seconds
        else:
            t_label[1] = 1 ## No Near-collision within next 2 seconds

        t_rgb = torch.zeros(rgb.shape[0], 3, 224, 224)

        # Drawn once per sample so every frame of a sequence is flipped (or not) together.
        prob = random.uniform(0, 1)

        if self.transform is not None:
            # Build the flip pipeline once per sample instead of once per frame,
            # and only when it is actually going to be used.
            flip_transform = None
            if (prob > 0.5 and not self.test):
                flip_transform = transforms.Compose([transforms.ToPILImage(), transforms.RandomHorizontalFlip(1.0)])

            for i in range(rgb.shape[0]):
                if flip_transform is not None:
                    # The PIL image is converted back to an array on assignment.
                    rgb[i,:,:,:] = flip_transform(rgb[i,:,:,:])
                t_rgb[i,:,:,:] = self.transform(rgb[i,:,:,:])

        return rgb, t_rgb, t_label

    def __len__(self):
        return len(self.f["rgb"])

In [14]:
# Per-channel mean/std applied after ToTensor (the values commonly used with
# ImageNet-pretrained torchvision models).
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Read-only handle on the test split; it stays open for the lifetime of the loader.
hfp_test = h5py.File('/mnt/hdd1/aashi/cmu_data/threeSecsTest.h5', 'r')

# One sequence per batch, no augmentation (test=True disables the random flip).
test_loader = data.DataLoader(
    FrameDataset(
        f=hfp_test,
        transform=transforms.Compose([transforms.ToTensor(), normalize]),
        test=True,
    ),
    batch_size=1,
)

In [15]:
def load_vgg_voc_weights(MODEL_PATH):
    """Load a saved state dict from ``MODEL_PATH`` into the global ``vgg_model``.

    Args:
        MODEL_PATH: path of a checkpoint written with ``torch.save(state_dict)``.

    The checkpoint is deserialised onto the CPU regardless of the device it was
    saved from, so loading does not fail when that device is unavailable here;
    ``load_state_dict`` then copies the values into the model's own parameters.
    """
    # NOTE: torch.load unpickles the file -- only use it on trusted checkpoints.
    checkpoint_dict = torch.load(MODEL_PATH, map_location=lambda storage, loc: storage)
    vgg_model.load_state_dict(checkpoint_dict)

NUM_CLASSES = 20 ## in VOC

# VGG16 whose last classifier layer is swapped for a NUM_CLASSES-way head so
# that its parameter shapes match the VOC checkpoint loaded below.
# NOTE(review): load_state_dict below replaces the weights fetched by
# pretrained=True -- confirm whether that download is still needed.
vgg_model = models.vgg16(pretrained=True)
num_final_in = vgg_model.classifier[-1].in_features
vgg_model.classifier[-1] = nn.Linear(num_final_in, NUM_CLASSES)

model_path = '/home/aashi/the_conclusion/model_files/' + 'vgg_on_voc' + str(800)
load_vgg_voc_weights(model_path)

class VGGNet(nn.Module):
    """Two-way classifier over a sequence of four frames.

    Every frame goes through the shared backbone (``rgb_net``), a 3x3
    convolution that reduces 512 channels to 16, and batch norm. The four
    flattened ``16*7*7`` feature vectors are concatenated and passed through two
    linear layers followed by a softmax over the class dimension.
    """

    def __init__(self):
        super(VGGNet, self).__init__()
        self.rgb_net = self.get_vgg_features()

        kernel_size = 3
        padding = (kernel_size - 1) // 2  # "same" padding for an odd kernel
        self.conv_layer = nn.Conv2d(512, 16, kernel_size, 1, padding, bias=True)
        self.conv_bn = nn.BatchNorm2d(16)
        self.feature_size = 16*7*7*4  # 16 channels x 7 x 7 per frame, four frames
        self.final_layer = nn.Sequential(
            nn.Linear(self.feature_size, 256),
            nn.Linear(256, 2),  ## 2 output classes
            # dim=1 is what the implicit choice resolved to for 2-D input; naming
            # it removes the "Implicit dimension choice" deprecation warning.
            nn.Softmax(dim=1)
        )

    def forward(self, rgb): ## sequence of four images - last index is latest
        """Return class probabilities of shape ``(batch, 2)``.

        Args:
            rgb: tensor indexed as ``rgb[:, i]`` per frame; the backbone must
                yield ``(batch, 512, 7, 7)`` for each frame and there must be
                four frames for the first linear layer to accept the result.
        """
        four_imgs = []
        for i in range(rgb.shape[1]):
            img_features = self.rgb_net(rgb[:,i,:,:,:])
            channels_reduced = self.conv_bn(self.conv_layer(img_features))
            four_imgs.append(channels_reduced.view((-1, 16*7*7)))
        concat_output = torch.cat(four_imgs, dim = 1)
        return self.final_layer(concat_output)

    def get_vgg_features(self):
        """Return the global ``vgg_model`` without its last child module.

        The remaining children are wrapped in an ``nn.Sequential`` and cast to
        the default tensor type.
        """
        modules = list(vgg_model.children())[:-1]
        vgg16 = nn.Sequential(*modules)

        return vgg16.type(torch.Tensor)

In [16]:
def load_model_weights(epoch_num):
    """Load the checkpoint of ``epoch_num`` into the global ``model``.

    Args:
        epoch_num: epoch number; it is zero-padded to three digits to build the
            checkpoint file name.

    The checkpoint is deserialised onto the CPU regardless of the device it was
    saved from; ``load_state_dict`` then copies the values into the model's own
    parameters on whatever device they currently live.
    """
    model_file = '/mnt/hdd1/aashi/binary_classification_' + str(epoch_num).zfill(3)
    # NOTE: torch.load unpickles the file -- only use it on trusted checkpoints.
    checkpoint_dict = torch.load(model_file, map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint_dict)

In [17]:
class FeatureExtractor():
    """Runs a model child by child, recording chosen intermediate outputs.

    For every direct child whose name is listed in ``target_layers`` the output
    is kept and a backward hook is attached to it, so that its gradient is
    appended to ``self.gradients`` when backward is run later.
    """

    def __init__(self, model, target_layers):
        self.model = model
        self.target_layers = target_layers
        self.gradients = []

    def save_gradient(self, grad):
        """Backward hook: remember the gradient handed in by autograd."""
        self.gradients.append(grad)

    def __call__(self, x):
        """Feed ``x`` through the children in order.

        Returns:
            ``(captured, x)`` - the list of recorded outputs and the output of
            the last child. Gradients stored by earlier calls are dropped.
        """
        self.gradients = []
        captured = []
        for layer_name, layer in self.model._modules.items():
            x = layer(x)
            if layer_name not in self.target_layers:
                continue
            x.register_hook(self.save_gradient)
            captured.append(x)
        return captured, x

In [18]:
class ModelOutputs():
    """ Class for making a forward pass, and getting:
    1. The network output.
    2. Activations from intermediate targeted layers.
    3. Gradients from intermediate targeted layers.

    The pass mirrors the wrapped model frame by frame: backbone (through the
    feature extractor), ``conv_layer``, ``conv_bn``, flatten, concatenate,
    ``final_layer``.

    NOTE(review): the returned activations belong to frame index 0 -- confirm
    that this is the frame the heatmap is meant to describe. """
    def __init__(self, model, target_layers):
        self.model = model
        self.feature_extractor = FeatureExtractor(self.model.rgb_net, target_layers)

    def get_gradients(self):
        """Return the gradients collected by the feature extractor's hooks."""
        return self.feature_extractor.gradients

    def __call__(self, x): ## x is a sequence of four imgs [1,4,3,224,224]
        """Return ``(target_activations, out)`` for the frame sequence ``x``.

        ``target_activations`` are the recorded outputs for frame 0 and ``out``
        is the result of ``final_layer`` over all frames.
        """
        target_activations = None
        four_imgs = []

        for i in range(x.shape[1]):
            activations, output = self.feature_extractor(x[:,i,:,:,:])
            if i == 0:
                # Taken from the pass that feeds `out`, so no separate forward
                # pass over frame 0 is required.
                target_activations = activations

            # Batch norm must be applied here as well, otherwise the scores
            # being explained are not the ones the model's forward produces.
            channels_reduced = self.model.conv_bn(self.model.conv_layer(output))
            img_features = channels_reduced.view((-1, 16*7*7))
            four_imgs.append(img_features)

        concat_output = torch.cat(four_imgs, dim = 1)

        out = self.model.final_layer(concat_output)

        return target_activations, out

In [25]:
def show_cam_on_img(img, mask, seq):
    """Blend a CAM mask onto an image and save it as ``heatmaps/<seq+1>.jpg``.

    Args:
        img: image array scaled to [0, 1], with the same height/width as ``mask``.
        mask: 2-D array in [0, 1]; it is colour-mapped with OpenCV's JET map.
        seq: zero-based sequence number; the file is named ``seq + 1``.

    NOTE(review): ``cv2.applyColorMap`` produces BGR and ``cv2.imwrite`` expects
    BGR; if ``img`` is RGB the saved picture has its red/blue channels swapped
    -- confirm the channel order of ``img``.
    """
    import os  # local import: only needed to make sure the output folder exists

    heatmap = cv2.applyColorMap(np.uint8(255*mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + np.float32(img)
    cam = cam / np.max(cam)

    # cv2.imwrite only returns False when the folder is missing, so create it.
    if not os.path.isdir("heatmaps"):
        os.makedirs("heatmaps")

    # Use the `seq` argument rather than a loop variable leaked from the caller's cell.
    cv2.imwrite("heatmaps/" + str(seq+1) + ".jpg", np.uint8(255 * cam))

In [20]:
class GradCam:
    """Grad-CAM for the frame-sequence classifier.

    Args:
        model: network exposing ``rgb_net``, ``conv_layer``, ``conv_bn`` and
            ``final_layer``; it is switched to eval mode.
        target_layer_names: names of the ``rgb_net`` children to explain.
        use_cuda: truthy to move the model (and every input) to the GPU.
    """

    def __init__(self, model, target_layer_names, use_cuda):
        self.model = model
        self.model.eval()
        self.cuda = use_cuda
        if self.cuda:
            self.model = model.cuda()

        self.extractor = ModelOutputs(self.model, target_layer_names)

    def forward(self, input):
        """Plain forward pass through the wrapped model."""
        return self.model(input)

    def __call__(self, input, index = None):
        """Return a 224x224 class-activation map scaled to [0, 1].

        Args:
            input: frame sequence accepted by the model (batch size 1).
            index: class to explain; defaults to the highest-scoring class.

        A map that is flat after the ReLU (for instance when every weighted
        activation is negative) is returned as all zeros instead of NaN.

        NOTE(review): the gradient used is the last one recorded by the hooks;
        with several frames sharing one extractor, which frame that is depends
        on the order in which autograd runs the hooks -- confirm it matches the
        frame the activations come from.
        """
        if self.cuda:
            features, output = self.extractor(input.cuda())
        else:
            features, output = self.extractor(input)

        if index is None:
            index = np.argmax(output.cpu().data.numpy())

        # Select the score of the requested class so it can be back-propagated.
        one_hot = np.zeros((1, output.size()[-1]), dtype = np.float32)
        one_hot[0][index] = 1
        one_hot = Variable(torch.from_numpy(one_hot), requires_grad = True)
        if self.cuda:
            one_hot = one_hot.cuda()
        score = torch.sum(one_hot * output)

        # Clears every sub-module, including conv_bn.
        self.model.zero_grad()
        score.backward()

        grads_val = self.extractor.get_gradients()[-1].cpu().data.numpy()

        target = features[-1]
        target = target.cpu().data.numpy()[0, :]

        # Channel weights = spatial mean of the gradients.
        weights = np.mean(grads_val, axis = (2, 3))[0, :]
        cam = np.zeros(target.shape[1 : ], dtype = np.float32)

        for i, w in enumerate(weights):
            cam += w * target[i, :, :]

        cam = np.maximum(cam, 0)        ## ReLU here ## only the pixels which have positive influence on the class
        cam = cv2.resize(cam, (224, 224))
        cam = cam - np.min(cam) ## Normalized for visualization purpose
        peak = np.max(cam)
        if peak > 0:
            # Skipped for a flat map: dividing by zero would fill it with NaN.
            cam = cam / peak
        return cam

In [26]:
use_cuda = 1

model = VGGNet()

# Grad-CAM on the child of `rgb_net` named "0"; this also puts the model in
# eval mode and moves it to the GPU when use_cuda is set.
grad_cam = GradCam(model = model, target_layer_names = ["0"], use_cuda=use_cuda)

# Epoch whose checkpoint is visualised.
e = 2
load_model_weights(e)

target_index = None  # None -> explain the highest-scoring class

# NOTE(review): `iter` shadows the builtin; rename it only after confirming
# that no helper reads it as a global.
for iter, (rgb, t_rgb, label) in enumerate(test_loader, 0):

    batch = t_rgb.float()
    if use_cuda:
        # Honour the flag instead of calling .cuda() unconditionally.
        batch = batch.cuda()
    mask = grad_cam(batch, target_index)
    vrgb = rgb[0,-1,:,:,:] ## heatmap of last image

    img = np.ascontiguousarray(vrgb)

    # 255.0 keeps this a true division on every Python version.
    show_cam_on_img(img/255.0, mask, iter)

  input = module(input)
