# Visualize CNN kernels

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Load the ImageNet-pretrained network whose kernels and activations are
# visualized below. VGG11 (models.vgg11) could be swapped in, but the conv
# layer indices in `layers_list` are specific to AlexNet's `features` module
# and would likely need to change.
import torchvision.models as models

# NOTE: torchvision >= 0.13 deprecates `pretrained=True` in favour of
# `weights=models.AlexNet_Weights.IMAGENET1K_V1`; the older keyword is kept
# here for compatibility with earlier torchvision releases.
alexnet = models.alexnet(pretrained=True).to(device)
# The model is only used for inference in this notebook, never trained.
alexnet.eval()
print(alexnet)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F

# Indices of the Conv2d modules inside alexnet.features (see print(alexnet)).
layers_list = [0, 3, 6, 8, 10]

row_num = 8       # subplot rows per figure
max_kernels = 64  # plot at most this many 2-D kernel slices per layer

for k in layers_list:
    # Copy the weights of conv layer `k` off the autograd graph onto the CPU.
    kernels = alexnet.features[k].weight.detach().clone().cpu()

    # Sanity check; Conv2d weights are (out_channels, in_channels, kH, kW).
    print(kernels.size())

    n, c, h, w = kernels.shape

    # Show every (output, input) channel pair as its own single-channel
    # image, keeping only the first `max_kernels` of them.
    kernels = kernels.view(n * c, 1, h, w)[:max_kernels]

    # Ceiling division: just enough columns to fit every kept kernel into
    # `row_num` rows.
    n_cols = (kernels.shape[0] + row_num - 1) // row_num

    # Rescale all kernels of this layer jointly to [0, 1]. Guard the division
    # so constant weights do not turn into NaNs.
    kernels = kernels - kernels.min()
    peak = kernels.max()
    if peak > 0:
        kernels = kernels / peak

    plt.figure(figsize=(1.5 * n_cols, 1.5 * row_num))
    plt.suptitle(f"Kernels of alexnet.features[{k}]")
    for i, kernel in enumerate(kernels):
        plt.subplot(row_num, n_cols, i + 1)
        # kernel is (1, kH, kW); plot the 2-D slice. Fixing vmin/vmax keeps the
        # shared [0, 1] scaling instead of matplotlib's per-image autoscale.
        plt.imshow(kernel[0], cmap="gray", vmin=0.0, vmax=1.0)
        plt.axis("off")


# Visualize CNN activations

In [None]:
from PIL import Image # pip install Pillow

# Load the example image. Converting to RGB guarantees exactly three colour
# channels (grayscale, palette and RGBA inputs would otherwise reach the
# network with the wrong channel count or raw palette indices).
img = Image.open('res/cat.jpg').convert('RGB')
plt.imshow(img)
plt.axis('off')

In [None]:
# Preprocess the image for the pretrained AlexNet:
#   * centre-crop to 512x512 (torchvision zero-pads images smaller than that),
#   * downscale to 256x256,
#   * convert to a float tensor in [0, 1] and drop any alpha channel,
#   * apply the ImageNet mean/std normalization the pretrained weights expect,
#     so the activations reflect the inputs the network was trained on.
transform = transforms.Compose([
    transforms.CenterCrop((512, 512)),
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x[:3]),  # remove the alpha channel if present
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
img_tensor = transform(img).to(device)

In [None]:
# Use HOOKS: every hooked layer appends its output here on each forward pass.
conv_output = []


def append_conv(module, inputs, output):
    """Forward hook that records a detached CPU copy of ``output``.

    PyTorch invokes forward hooks as ``hook(module, inputs, output)``; only
    ``output`` is used. The copy is appended to the module-level
    ``conv_output`` list.

    ``copy=True`` is deliberate. ``Tensor.cpu()`` returns the same storage
    when the tensor is already on the CPU, and torchvision's AlexNet follows
    each conv with ``ReLU(inplace=True)``. Without the copy, that ReLU would
    overwrite the stored activation on CPU (but not on GPU). Forcing a copy
    records the raw conv output on every device.
    """
    conv_output.append(output.detach().to("cpu", copy=True))

In [None]:
# Attach append_conv to every conv layer listed in layers_list. Hooks left by a
# previous run of this cell are removed first; otherwise re-running the cell
# would stack duplicate hooks and record every activation several times.
for handle in globals().get("hook_handles", []):
    handle.remove()
hook_handles = [
    alexnet.features[k].register_forward_hook(append_conv) for k in layers_list
]

In [None]:
# Start from an empty list so only this forward pass's activations are kept.
conv_output = []

# Inference only: eval() switches train-time layers (e.g. dropout) to
# inference behaviour, and no_grad() skips building an autograd graph.
alexnet.eval()
with torch.no_grad():
    out = alexnet(img_tensor.unsqueeze(0))  # add a batch dimension

for c_out in conv_output:
    print(c_out.size())

In [None]:
grid = 8                 # feature maps are laid out on a grid x grid figure
max_maps = grid * grid   # so at most 64 maps are shown per layer

for layer_idx, activations in enumerate(conv_output):
    plt.figure(figsize=(30, 30))
    plt.suptitle(f"Activations of hooked layer {layer_idx}")
    # Take the first (and only) image of the batch: (channels, H, W).
    feature_maps = activations[0]
    print(feature_maps.size())
    for i, fmap in enumerate(feature_maps[:max_maps]):
        plt.subplot(grid, grid, i + 1)
        plt.imshow(fmap)
        plt.axis("off")