In [1]:
'''
Author        : Aditya Jain
Date Started  : 21st June, 2021
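About         : This script generates Class Activation Maps (CAM) for test images and draws bounding boxes from them; the maps are weighted by the classifier layer's weights rather than by gradients (i.e. CAM, not Grad-CAM)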
'''
import sys, os
sys.path.append('/home/mila/a/aditya.jain/mothAI/deeplearning')

import numpy as np
import tensorflow as tf
from tensorflow import keras
import json
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torch
from models.resnet50 import Resnet50
from torchsummary import summary
import torchvision.models as models
from torch import nn
from PIL import Image
import cv2
from scipy import ndimage

# Display
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.cm as cm

from data.mothdataset import MOTHDataset

# ---- user-tunable settings -------------------------------------------------
image_path   = 'test2.jpg'   # NOTE(review): not referenced by the cells visible here - confirm it is still needed
THRESHOLD    = 0.25          # fraction of the heatmap peak a pixel must exceed to be kept (see bounding_box)

# NOTE(review): absolute, user-specific paths; adjust when running on another machine
bbox_dir     = '/home/mila/a/aditya.jain/bboximages/correct/'
config_file  = '/home/mila/a/aditya.jain/mothAI/deeplearning/config/01-config.json'
PATH         = '/home/mila/a/aditya.jain/logs/v01_mothmodel_2021-06-08-04-53.pt'
label_file   = '/home/mila/a/aditya.jain/mothAI/deeplearning/data/numeric_labels.json'

# read the experiment configuration; the context manager closes the file handle
# as soon as parsing is done (the original left it open for the kernel's lifetime)
with open(config_file) as f:
    config_data = json.load(f)

# values pulled out of the configuration for later cells
root_dir      = config_data['dataset']['root_dir']
test_set      = config_data['dataset']['test_set']
label_list    = config_data['dataset']['label_info']
batch_size    = 1
image_resize  = config_data['training']['image_resize']

In [18]:
def _to_float_array(values):
    '''
    converts a torch tensor (any device, with or without grad) or an array-like into a float64 numpy array
    '''
    if isinstance(values, torch.Tensor):
        values = values.detach().cpu().numpy()
    return np.asarray(values, dtype=np.float64)


def generate_cam(model, image, label_file, predict_indx, last_conv_out):
    '''
    given an image and model, returns the CAM

    The map is the sum over channels of the captured feature maps, each weighted by
    the classifier weight linking that channel to the requested class; negative values
    are clipped to zero and the result is scaled so that its maximum is 1.

    Args:
        model         : module whose named_parameters() contains 'classifier.weight'
        image         : unused; kept so existing callers keep working
        label_file    : unused; kept so existing callers keep working (it used to be
                        opened and parsed on every call although nothing read the result)
        predict_indx  : index of the class (row of the classifier weight matrix) to explain
        last_conv_out : captured activations; element [0] is taken and must be indexable
                        as (channel, row, column)

    Returns:
        2-D numpy array of shape (rows, columns) with values in [0, 1]; all zeros when
        the map has no positive entry (the original returned NaN in that case)

    Raises:
        ValueError: if the model has no parameter named 'classifier.weight'
    '''
    classifier_weights = None
    for name, param in model.named_parameters():
        if name == 'classifier.weight':
            classifier_weights = param
            break
    if classifier_weights is None:
        raise ValueError("model has no parameter named 'classifier.weight'")

    # move everything to numpy once, instead of mixing tensors and numpy scalars element by element
    feature_maps     = _to_float_array(last_conv_out[0])                     # (channel, row, column)
    topclass_weights = _to_float_array(classifier_weights[predict_indx, :])  # (channel,)

    # weighted sum over the channel axis; also handles non-square feature maps
    cam = np.einsum('k,kij->ij', topclass_weights, feature_maps)

    cam  = np.maximum(cam, 0)
    peak = cam.max() if cam.size else 0.0
    if peak > 0:                       # avoid 0/0 -> NaN when nothing is positive
        cam = cam / peak

    return cam


def cropped_image(cam, threshold, orig_im):
    '''
    crops the original image to the region where the CAM exceeds the chosen threshold

    The CAM is resized to (image_resize, image_resize) using the module-level
    `image_resize`, converted to 8-bit, and every pixel not above
    `threshold * max` is zeroed. If exactly one connected region remains, the
    slice of `orig_im` covering that region's bounding box is returned;
    otherwise the string 'Multiple segments found' is returned.
    '''
    side       = image_resize
    scaled_cam = np.uint8(255 * cv2.resize(cam, (side, side)))

    cutoff     = threshold * np.amax(scaled_cam)
    kept       = scaled_cam * (scaled_cam > cutoff)

    segment_ids, n_segments = ndimage.label(kept)
    if n_segments != 1:
        return 'Multiple segments found'

    # bounding slices (rows, columns) of the only remaining segment
    region = ndimage.find_objects(segment_ids == 1)[0]
    return orig_im[region]

In [19]:
def bounding_box(cam, image_batch, image_resize, threshold, bbox_dir, count):
    '''
    draws the CAM-derived bounding box on the original image and saves the figure

    Args:
        cam          : 2-D class activation map
        image_batch  : batch holding a single image; after squeeze() the first axis is
                       moved to the end before plotting. May be a torch tensor on any
                       device or anything np.array accepts.
        image_resize : side length the CAM is resized to before thresholding
        threshold    : fraction of the heatmap maximum a pixel must exceed to be kept
        bbox_dir     : directory the figure is written to
        count        : number used as the file name ('<count>.jpg')

    Returns:
        None. A file is written only when exactly one connected region survives the
        threshold; otherwise a message is printed and nothing is saved.
    '''
    image = image_batch.squeeze()
    if isinstance(image, torch.Tensor):
        # np.array() cannot read CUDA tensors directly; bring the data to host memory first
        image = image.detach().cpu().numpy()
    orig_im           = np.moveaxis(np.array(image), 0, -1)

    heatmap           = cv2.resize(cam, (image_resize, image_resize))
    heatmap           = np.uint8(255 * heatmap)
    threshold_val     = threshold*np.amax(heatmap)
    threshold_heatmap = (heatmap > threshold_val) * heatmap
    labels, nb        = ndimage.label(threshold_heatmap)

    if nb != 1:
        # zero regions used to be reported as "multiple" as well, which was misleading
        print('Multiple segments found' if nb > 1 else 'No segment found')
        return

    # find_objects yields (row slice, column slice) for the single region
    row_span, col_span = ndimage.find_objects(labels==1)[0]
    y_start, y_stop    = row_span.start, row_span.stop
    x_start, x_stop    = col_span.start, col_span.stop

    fig, ax  = plt.subplots()
    ax.imshow(orig_im)
    rect     = patches.Rectangle((x_start,y_start),
                                 x_stop-x_start,
                                 y_stop-y_start,
                                 linewidth=3,
                                 edgecolor='red',
                                 fill = False)
    ax.add_patch(rect)
    # os.path.join works whether or not bbox_dir ends with a path separator
    fig.savefig(os.path.join(bbox_dir, str(count) + '.jpg'))
    plt.close(fig)

#### Getting bboxes for test data images

In [22]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device        = "cuda" if torch.cuda.is_available() else "cpu"
model         = Resnet50(config_data).to(device)

# Restore the trained weights stored under 'model_state_dict' in the checkpoint at PATH.
checkpoint  = torch.load(PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

# adding hook on the last convolution layer
# NOTE(review): the hook is attached to the `relu` submodule of model.backbone[7][2]; if that
# submodule is called more than once during a forward pass, the hook fires every time and only
# the output of the last call is kept - confirm this is the intended activation.
last_conv_output = None

def hook(module, input, output):
    '''
    forward hook: stores the hooked module's output in the module-level `last_conv_output`
    '''
    global last_conv_output
    last_conv_output = output    

# The handle returned by register_forward_hook is not stored, so the hook cannot be removed later.
model.backbone[7][2].relu.register_forward_hook(hook)

<torch.utils.hooks.RemovableHandle at 0x7fa6e3ff1ed0>

In [23]:
# test data loader
test_transformer  = transforms.Compose([
                        transforms.Resize((image_resize, image_resize)),              # resize the image to image_resize x image_resize (value read from the config)
                        transforms.ToTensor()])
test_data         = MOTHDataset(root_dir, test_set, label_list, test_transformer)
# shuffle=True: the sample order (and therefore which image gets processed below) changes between runs
# NOTE(review): no random seed is set in the cells visible here - confirm whether reproducibility matters
test_dataloader   = DataLoader(test_data,batch_size=batch_size, shuffle=True)

# count is passed to bounding_box as the output file number; it is incremented for every
# class tried, whether or not bounding_box actually saved a figure
count = 0
model.eval()
with torch.no_grad():                                 # switching off gradient computation in evaluation mode
    for image_batch, label_batch in test_dataloader:  
        image_batch, label_batch = image_batch.to(device), label_batch.to(device)
        # the forward pass also triggers the registered hook, which refreshes last_conv_output
        out                      = model(image_batch)
        _, predict_indx          = torch.topk(out, 1)
        print(label_batch, predict_indx)
        
        # only continue with a sample whose top-1 prediction matches its label
        # NOTE(review): this membership test presumably relies on batch_size being 1 - confirm
        if label_batch in predict_indx:
            # take the five highest-scoring classes of the (single) sample in the batch
            _, predict_indx          = torch.topk(out, 5)
            predict_indx             = predict_indx[0]
            
            # one CAM and one bounding-box attempt per top-5 class, all from the same captured activations
            for index in predict_indx:
                print(index)
                cam                      = generate_cam(model, image_batch, label_file, index, last_conv_output)
                bounding_box(cam, image_batch, image_resize, THRESHOLD, bbox_dir, count)        
                count +=1 
        
            # stop after the first correctly classified sample has been processed
            break

tensor([[437]]) tensor([[437]])
tensor(437)
tensor(509)
tensor(619)
tensor(310)
tensor(691)
