### **Dropping Fully Connected Units in Scene Detection on the Places365 Dataset**

In [1]:
import os
import cv2
import io
import torch
import numpy as np
import matplotlib
from IPython.display import display
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from torch.autograd import Variable as V
import torchvision.models as models
from torchvision import transforms as trn
from torch.nn import functional as F
from PIL import ImageFont, ImageDraw, Image
from IPython.display import Image as ipy_Image
from ipywidgets import widgets, interact, interactive, FloatRangeSlider, HBox, Label


# Raise matplotlib's size limit for animations embedded as HTML (ani.to_jshtml() further down)
matplotlib.rcParams['animation.embed_limit'] = 2**32
# render matplotlib figures inline in the notebook output
%matplotlib inline

In [2]:
# Input locations, relative to the notebook's working directory.
image_dir, model_dir, file_name = (
    'Place365 Images/',          # images offered in the file dropdown
    'Models/',                   # pre-trained checkpoint folder
    'categories_places365.txt',  # class labels, read one per line
)

In [3]:
# the architecture to use
arch = 'resnet50'
# One label per line: keep the first space-separated token and drop its first
# 3 characters (presumably a "/x/" category prefix -- confirm against the file).
with open(file_name) as label_file:
    classes = tuple(entry.strip().split(' ')[0][3:] for entry in label_file)

In [4]:
# load the pre-trained weights; built from model_dir so the folder is configured in one place
model_file = os.path.join(model_dir, '%s_places365.pth.tar' % arch)

In [5]:
# Build the chosen torchvision architecture with a 365-way head and load the checkpoint on CPU.
model = getattr(models, arch)(num_classes=365)
checkpoint = torch.load(model_file, map_location='cpu')
# strip every 'module.' occurrence so checkpoint keys match the plain (unwrapped) model
state_dict = {key.replace('module.', ''): tensor for key, tensor in checkpoint['state_dict'].items()}
model.load_state_dict(state_dict)

<All keys matched successfully>

In [6]:
features_blobs = []
def hook_feature(module, input, output):
    """Forward hook: store the hooked module's output as a CPU numpy array.

    Only the most recent output is kept. Previously every forward pass was
    appended, so the list grew without bound and ``features_blobs[0]`` stayed
    pinned to the very first image ever run through the model, which made the
    CAMs of later images use the wrong feature map.
    """
    features_blobs.clear()  # mutate in place so every reference to the list sees the update
    features_blobs.append(output.detach().cpu().numpy())
# Keep the handle so that re-running this cell replaces the hook instead of stacking
# a second one (which would run hook_feature twice per forward pass).
if 'layer4_hook' in globals():
    layer4_hook.remove()
layer4_hook = model._modules.get('layer4').register_forward_hook(hook_feature)
layer4_hook

<torch.utils.hooks.RemovableHandle at 0x7f93c2382ad0>

In [7]:
class DNNUnits():
    """Interactively drop fully connected units of a scene classifier and inspect the effect.

    For a single image the class predicts the top scene class, lets the user zero
    the fc weights of that class that fall outside a [low, high] range chosen on a
    slider, re-runs the prediction with the reduced weights and writes an annotated
    class activation map (CAM) overlay to ``dnn_units/``.

    Relies on notebook globals: ``classes`` (class labels indexed by class id) and
    ``features_blobs`` (filled by a forward hook on the model's ``layer4``).
    """

    def __init__(self, filename, model):
        """Read the image at ``filename`` and snapshot the model's fc weights.

        Args:
            filename: path of the image file.
            model: classifier; its second-to-last parameter is taken as the fc
                weight matrix and ``model.fc`` is used for re-prediction.

        Raises:
            FileNotFoundError: if OpenCV cannot read ``filename`` as an image.
        """
        self.filename = filename
        self.model = model
        self.wrk_img = cv2.imread(self.filename)
        if self.wrk_img is None:
            # cv2.imread reports unreadable files by returning None instead of raising
            raise FileNotFoundError(f'Could not read image file {self.filename!r}')
        # get the softmax weight
        self.params = list(model.parameters())
        # Fully Connected Layer weights, one row per class. Copied: .numpy() shares memory
        # with the parameter, so zeroing entries here would silently and permanently
        # modify the model shared by every image.
        self.weight_softmax = np.squeeze(self.params[-2].data.cpu().numpy()).copy()
        self.font = cv2.FONT_HERSHEY_SIMPLEX  # Font
        with open(self.filename, "rb") as file:
            self.img = file.read()

    def returnCAM(self, feature_conv, weight_softmax, class_idx, size_upsample=(256, 256)):
        """Build one class activation map per requested class.

        Args:
            feature_conv: feature map shaped (batch, channels, h, w); the reshape
                below assumes batch size 1.
            weight_softmax: fc weight matrix with one row of ``channels`` weights per class.
            class_idx: iterable of class indices to build maps for.
            size_upsample: target size passed to ``cv2.resize`` as ``dsize``.

        Returns:
            List of uint8 maps scaled to 0..255, one per entry of ``class_idx``.
        """
        _, nc, h, w = feature_conv.shape
        features = feature_conv.reshape((nc, h * w))
        output_cam = []
        for idx in class_idx:
            cam = weight_softmax[idx].dot(features).reshape(h, w)
            cam = cam - np.min(cam)
            peak = np.max(cam)
            if peak > 0:
                # a constant map (e.g. all weights dropped) would otherwise divide by zero -> NaN
                cam = cam / peak
            output_cam.append(cv2.resize(np.uint8(255 * cam), size_upsample))
        return output_cam

    def image_prediction(self):
        """Run the model on the image and record the ranked class indices.

        Side effects: sets ``self.idx`` (class indices by descending probability)
        and, via the layer4 forward hook, refreshes ``features_blobs``.

        Returns:
            (img_variable, max_wgt, min_wgt): the preprocessed 1x3x224x224 input tensor
            and the largest / smallest fc weight of the top class as '%.3f' strings.
        """
        self.model.eval()
        # Bytes to PIL Image; convert('RGB') so grayscale/RGBA files match the 3-channel Normalize
        img = Image.open(io.BytesIO(self.img)).convert('RGB')
        normalize = trn.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        preprocess = trn.Compose([trn.Resize((224, 224)), trn.ToTensor(), normalize])
        img_variable = preprocess(img).unsqueeze(0)

        # inference only: no autograd graph needed
        with torch.no_grad():
            logit = self.model(img_variable)
        h_x = F.softmax(logit, dim=1).squeeze()
        _, idx = h_x.sort(0, True)
        self.idx = idx.numpy()

        top_row = self.weight_softmax[self.idx[0], :]
        max_wgt = '%.3f' % round(float(top_row.max()), 3)
        min_wgt = '%.3f' % round(float(top_row.min()), 3)
        return img_variable, max_wgt, min_wgt

    def weight_slider(self, threshold_min, threshold_max, max_val=1, min_val=-1):
        """Create the image widget and the weight-range slider.

        Args:
            threshold_min, threshold_max: initial slider range; numbers or numeric
                strings (as returned by ``image_prediction``).
            max_val, min_val: default slider bounds; widened when needed so the
                initial thresholds always lie inside the slider.

        Returns:
            (image widget, FloatRangeSlider, HBox with a label and the slider).
        """
        self.image = widgets.Image(value=self.img)
        low, high = float(threshold_min), float(threshold_max)
        # Stored under a different name: assigning to self.weight_slider would shadow
        # this method on the instance and break any second call.
        self.weight_range_slider = FloatRangeSlider(
            value=[low, high],
            min=min(min_val, low),
            max=max(max_val, high),
            step=0.001,
        )
        self.ui = widgets.HBox([Label('Weight Value'), self.weight_range_slider])
        return self.image, self.weight_range_slider, self.ui

    def _predict_with_fc_weights(self, img_variable, weights):
        """Run the model with ``weights`` temporarily copied into ``model.fc.weight``.

        The original fc weights are restored afterwards, even if the forward pass
        raises, so the shared model is never left modified. The copy happens under
        ``torch.no_grad`` because in-place writes to a Parameter that requires grad
        raise otherwise.

        Returns:
            (probs, idx): numpy arrays of class probabilities in descending order and
            the matching class indices.
        """
        fc_weight = self.model.fc.weight
        with torch.no_grad():
            saved = fc_weight.detach().clone()
            try:
                fc_weight.copy_(torch.from_numpy(weights).to(fc_weight))
                logit = self.model(img_variable)
            finally:
                fc_weight.copy_(saved)
        h_x = F.softmax(logit, dim=1).squeeze()
        probs, idx = h_x.sort(0, True)
        return probs.numpy(), idx.numpy()

    def _save_cam(self, cam, probs, idx, units_dropped_R, units_dropped_L, wgt):
        """Overlay ``cam`` on the image, annotate it and write it to ``dnn_units/``.

        The file is named after ``wgt[1]`` when top-most units were dropped,
        otherwise after ``wgt[0]``.
        """
        height, width = self.wrk_img.shape[:2]
        heatmap = cv2.applyColorMap(cv2.resize(cam, (width, height)), cv2.COLORMAP_JET)
        result = heatmap * 0.4 + self.wrk_img * 0.6

        x, y = 0, int(height / 15)
        for i in range(3):
            string = "Confidence: " + str(round(probs[i].item(), 2)) + "->" + "Label: " + classes[idx[i]]
            cv2.putText(result, string, (x, y), self.font, 0.38, (0, 0, 0), 2)
            y += 15
        if units_dropped_R > 0:
            cv2.putText(result, "Top Most " + str(units_dropped_R) + " units dropped", (x, y), self.font, 0.38, (0, 0, 0), 2)
            y += 15
        if units_dropped_L > 0:
            cv2.putText(result, "Bottom Most " + str(units_dropped_L) + " units dropped", (x, y), self.font, 0.38, (0, 0, 0), 2)

        # cv2.imwrite returns False (no exception) when the target folder is missing
        os.makedirs('dnn_units', exist_ok=True)
        threshold = wgt[1] if units_dropped_R > 0 else wgt[0]
        out_path = 'dnn_units/CAM' + str(threshold) + '.jpg'
        if not cv2.imwrite(out_path, result):
            print(f'Could not write {out_path}')

    def f(self, wgt, img):
        """Slider callback: drop fc units of the predicted class outside ``wgt`` and re-predict.

        Units whose fc weight for the class predicted by the intact model is above
        ``wgt[1]`` ("top most") or below ``wgt[0]`` ("bottom most") are zeroed in a
        copy of the fc weights. The model is re-run with that copy and the new top
        class is printed; if any unit was dropped, a CAM overlay annotated with the
        top-3 predictions is saved as ``dnn_units/CAM<threshold>.jpg``.

        Args:
            wgt: (low, high) weight range from the FloatRangeSlider.
            img: image widget value; unused, it only makes the widget an input of
                ``interactive_output``.
        """
        low, high = wgt
        # prediction with all units intact; refreshes self.idx and features_blobs
        img_variable, _, _ = self.image_prediction()
        target = self.idx[0]
        row = self.weight_softmax[target]

        # Count the units BEFORE zeroing them; counting after zeroing misses them.
        drop_right = row > high
        drop_left = row < low
        units_dropped_R = int(np.count_nonzero(drop_right))
        units_dropped_L = int(np.count_nonzero(drop_left))
        print(f'Number of units dropped in Fully Connected Layer is {units_dropped_R + units_dropped_L}')

        weights = self.weight_softmax.copy()
        weights[target, drop_right | drop_left] = 0

        probs, idx = self._predict_with_fc_weights(img_variable, weights)
        print("Predicted Class: ", classes[idx[0]])

        if (units_dropped_L > 0) or (units_dropped_R > 0):
            # features_blobs[-1] is the layer4 output of the latest forward pass, i.e. this image
            CAMs = self.returnCAM(features_blobs[-1], weights, [idx[0]])
            self._save_cam(CAMs[0], probs, idx, units_dropped_R, units_dropped_L, wgt)

    def units_slider(self, weight, image, ui):
        """Wire ``f`` to the slider and image widgets and display the controls with their output."""
        out = widgets.interactive_output(self.f, {'wgt': weight, 'img': image})
        display(ui, out)

In [8]:
@interact
def show_images(file=sorted(name for name in os.listdir(image_dir)
                            if name.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp')))):
    """Show the selected image and attach the unit-dropping weight slider to it.

    The dropdown lists only image files (sorted), so folders such as
    ``.ipynb_checkpoints`` or stray non-image files cannot be selected.
    """
    image_path = os.path.join(image_dir, file)
    explorer = DNNUnits(filename=image_path, model=model)
    display(ipy_Image(image_path))
    _, max_wgt, min_wgt = explorer.image_prediction()
    image, weight, ui = explorer.weight_slider(threshold_min=min_wgt, threshold_max=max_wgt)
    explorer.units_slider(weight=weight, image=image, ui=ui)

interactive(children=(Dropdown(description='file', options=('attic.jpg', 'aquarium.jpg', 'airplane_cabin.jpg',…

In [11]:
# Animate the saved CAM frames, ordered by descending slider threshold.
path = "dnn_units/"

def _cam_threshold(frame_name):
    """Return the threshold encoded in a 'CAM<threshold>.jpg' name, or None for any other file."""
    if not (frame_name.startswith('CAM') and frame_name.endswith('.jpg')):
        return None
    try:
        return float(frame_name[len('CAM'):-len('.jpg')])
    except ValueError:
        return None

# Only CAM frames (animation.gif is saved into the same folder) and a numeric sort:
# a plain string sort misorders negative thresholds such as -0.01 vs -0.1.
files = sorted((name for name in os.listdir(path) if _cam_threshold(name) is not None),
               key=_cam_threshold, reverse=True)
if not files:
    raise FileNotFoundError(f'No CAM*.jpg frames in {path!r}; move the weight slider above first.')

fig, ax = plt.subplots(figsize=(10, 10))
ax.axis('off')
ims = []
for file in files:
    # np.asarray copies the pixels, so the file can be closed right away
    with Image.open(path + file) as frame:
        im = ax.imshow(np.asarray(frame), animated=True)
    ims.append([im])

ani = animation.ArtistAnimation(fig, ims, interval=1000, blit=True,
                                repeat_delay=1000)
plt.close(fig)  # avoid an extra static copy of the figure below the animation
HTML(ani.to_jshtml())

<Figure size 720x720 with 0 Axes>

In [10]:
# The Pillow writer only needs Pillow (already imported above as PIL), not an external ImageMagick install.
ani.save('dnn_units/animation.gif', writer='pillow', fps=2)