In [None]:
import os
import numpy as np
import torch.nn as nn
import torch
import plotly.graph_objects as go
from torch.optim import Adam

import resnet50_128BSC as model
from misc_functions import preprocess_image, recreate_image, save_image

In [None]:
class CNNLayerVisualization():
    """
        Produces an image that maximizes the mean activation of a specific
        filter in a specific layer (implemented by minimizing the negated
        mean activation with Adam).

        The model is expected to expose ``cho_layer(x, layer)``, which is
        called here to obtain the output of ``selected_layer`` for input
        ``x``.  The result is indexed as ``[0, filter]``, i.e. it is assumed
        to be (batch, channels, ...) -- TODO confirm against the model file.
    """
    def __init__(self, model, selected_layer):
        """
            Args:
                model: network to inspect; it is switched to eval mode.
                selected_layer: layer identifier forwarded to ``cho_layer``.
        """
        self.model = model
        self.model.eval()
        self.selected_layer = selected_layer
        self.conv_output = 0
        # Create the folder to export images if not exists
        os.makedirs('generated', exist_ok=True)
        # Generate a random image
        random_image = np.uint8(np.random.uniform(150, 180, (224, 224, 3)))
        # Process image and return variable.  This tensor is kept as the
        # pristine starting point: it is never optimized in place.
        self.processed_image = preprocess_image(random_image, False)

    def _model_device(self):
        """Return the device the model runs on, so inputs can follow it."""
        parameters = getattr(self.model, 'parameters', None)
        first_param = next(parameters(), None) if callable(parameters) else None
        if first_param is not None:
            return first_param.device
        # Parameter-less model: fall back to CUDA when present (the previous
        # hard-coded behaviour), otherwise the CPU.
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def calculate_mean_activation(self, sel_filter=15):
        """
            Rank the filters of the selected layer by their mean activation
            on the starting image.

            Args:
                sel_filter: how many filters to take from each end of the
                    ranking (clamped to the number of filters available).

            Returns:
                List of unique filter indices, interleaving the strongest and
                weakest filters: [1st max, 1st min, 2nd max, 2nd min, ...].
                When the two ends overlap (fewer than ``2 * sel_filter``
                filters) each index is reported only once.
        """
        # Only a ranking is needed here, so skip building the autograd graph.
        with torch.no_grad():
            x = self.processed_image.to(self._model_device())
            x = self.model.cho_layer(x, self.selected_layer)
        print(self.selected_layer,'-', x.shape)

        mean_activations = [float(torch.mean(x[0, k])) for k in range(x.shape[1])]
        # Filter indices ordered from highest to lowest mean activation.
        ranked = sorted(range(len(mean_activations)),
                        key=mean_activations.__getitem__, reverse=True)

        sel_filter = max(0, min(sel_filter, len(ranked)))
        maxfilter = []
        seen = set()
        for f in range(sel_filter):
            # Take the f-th strongest and the f-th weakest filter, skipping
            # any index that was already picked from the other end.
            for k in (ranked[f], ranked[-f-1]):
                if k not in seen:
                    seen.add(k)
                    maxfilter.append(k)
        print(self.selected_layer,'-',maxfilter)
        return maxfilter

    def visualise_layer_without_hooks(self, selected_filter, iterations=180, save_every=90):
        """
            Optimize an image so that ``selected_filter`` of the selected
            layer responds strongly, saving snapshots under ``generated/``.

            Every call starts from a fresh copy of ``self.processed_image``,
            so visualizing several filters with one instance does not carry
            the previous filter's result into the next run.

            Args:
                selected_filter: channel index in the layer output.
                iterations: number of Adam steps.
                save_every: a snapshot is written every ``save_every`` steps
                    (0 or None disables saving).

            Side effects:
                Sets ``self.conv_output`` (last filter activation) and
                ``self.created_image`` (latest recreated image).
        """
        device = self._model_device()
        # Fresh leaf tensor per call; the stored starting image stays intact.
        image = self.processed_image.detach().clone().requires_grad_(True)
        # Define optimizer for the image
        optimizer = Adam([image], lr=0.1, weight_decay=1e-6)

        for i in range(1, iterations + 1):
            optimizer.zero_grad()
            x = self.model.cho_layer(image.to(device), self.selected_layer)
            self.conv_output = x[0, selected_filter]
            # Loss is the negated mean of the selected filter's output:
            # minimizing it maximizes the filter's mean activation.
            loss = -torch.mean(self.conv_output)
            # Backward
            loss.backward()
            # Update image
            optimizer.step()

            should_save = bool(save_every) and i % save_every == 0
            # Recreating the image is only needed when it is consumed: at a
            # snapshot, or on the last step so created_image holds the result.
            if should_save or i == iterations:
                self.created_image = recreate_image(image)
            if should_save:
                im_path = 'generated/ASC_res50_vis_l' + str(self.selected_layer) + \
                    '_f' + str(selected_filter) + '_iter' + str(i) + '.jpg'
                save_image(self.created_image, im_path)

In [None]:
# Use the first CUDA device when one is available; otherwise stay on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Build the network from the stored weights file and move it to that device.
modelRes50 = model.resnet50_128(weights_path='./model/resnet50_128.pth').to(device)

In [None]:
# For layers 7 and 8: rank the filters by mean activation, then generate a
# visualisation for every filter index the ranking returns.
for layer_index in (7, 8):
    layer_vis = CNNLayerVisualization(modelRes50, layer_index)
    for filter_index in layer_vis.calculate_mean_activation():
        layer_vis.visualise_layer_without_hooks(filter_index)