In [1]:
import os
import numpy as np
import glob
import random
from PIL import Image, ImageOps
import matplotlib.pyplot as plt

import torch
from torch.nn import ReLU
from torch.optim import Adam
from torch.autograd import Variable
import torchvision
from torchvision import models, transforms

# from misc_functions import *
from model import *
from dataset import VNOnDB, get_data_loader
from utils import ScaleImageByHeight

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
# NOTE: torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints that come from a trusted source.
checkpoint = torch.load('runs/tf_ctc_word_resnet18.pt', map_location=device)
root_config = checkpoint['config']
best_metrics = dict()
config = root_config['common']

# Preprocessing applied to every PIL image before it is fed to the network:
# invert colours, rescale to the configured height, replicate to 3 channels,
# convert to a tensor and normalise with the usual ImageNet mean/std.
image_transform = transforms.Compose([
        ImageOps.invert,
        ScaleImageByHeight(config['scale_height']),
        transforms.Grayscale(3),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
])
# NOTE(review): the positional `2` and `False` are presumably the number of
# workers and the shuffle flag -- confirm against dataset.get_data_loader.
loader = get_data_loader(config['dataset'],
                             'test',
                             config['batch_size'],
                             2,
                             image_transform,
                             False,
                             flatten_type=config.get('flatten_type', None),
                             add_blank=True) # CTC need add_blank

if config['dataset'] in ['vnondb', 'vnondb_line']:
    vocab = VNOnDB.vocab
else:
    # Without this branch `vocab` would be undefined and the model construction
    # below would fail with an unrelated-looking NameError.
    raise ValueError('Unsupported dataset: {!r}'.format(config['dataset']))

cnn = ResnetFE('resnet18')

model_config = root_config['tf']
model = CTCModelTFEncoder(cnn, vocab, model_config)

model.to(device)
model.load_state_dict(checkpoint['model'])

<All keys matched successfully>

In [3]:
# Folder that receives every image written by this notebook.
img_folder = 'misc'
# exist_ok=True makes the cell safe to re-run and avoids the check-then-create race.
os.makedirs(img_folder, exist_ok=True)

# Sample test image used by the visualisations below.
fp = 'data/VNOnDB/word/test_word/20151208_0146_7105_1_tg_0_0_6.png'

In [4]:
# Pull the first batch out of the test loader and move its tensors to `device`.
iter_test_loader = iter(loader)
batch = next(iter_test_loader)
imgs = batch.images.to(device)
targets = batch.labels.to(device)

In [11]:
# Read the sample image as RGB and run it through the preprocessing pipeline.
with Image.open(fp) as image_file:
    original_img = image_file.convert('RGB')
print(original_img.size)

img = image_transform(original_img)
print(img.size())

(128, 72)
torch.Size([3, 96, 170])


In [6]:
# Push the sample image through the CNN one layer at a time and save every
# layer's feature maps as a single image grid.
model.eval()
with torch.no_grad():
    # The model was moved to `device`, so the input has to live there as well.
    x = img.unsqueeze(0).to(device)
    print(x.size())
    for index, layer in enumerate(model.cnn.cnn):
        x = layer(x)
        print(x.size())
        # Turn every channel into its own single-channel image.
        grid_x = x.reshape(-1, 1, x.size(2), x.size(3))
        # Integer division with a floor of 1: make_grid divides by nrow, so a
        # layer with fewer than 8 channels must not yield nrow == 0.
        grid = torchvision.utils.make_grid(grid_x, nrow=max(1, grid_x.size(0) // 8))
        # NOTE(review): save_image clamps values to [0, 1]; consider
        # make_grid(..., normalize=True) if the maps look washed out.
        torchvision.utils.save_image(grid, '{}/cnn_layer_{}.png'.format(img_folder, index))

torch.Size([1, 3, 96, 170])
torch.Size([1, 64, 48, 85])
torch.Size([1, 64, 48, 85])
torch.Size([1, 64, 48, 85])
torch.Size([1, 64, 24, 43])
torch.Size([1, 64, 24, 43])
torch.Size([1, 128, 12, 22])
torch.Size([1, 256, 6, 11])
torch.Size([1, 512, 3, 6])


# Layer activation with guided backprop

In [7]:
class GuidedBackprop():
    """
    Produces gradients generated with guided back propagation from the given image.

    Constructing an instance puts ``model`` in evaluation mode and registers
    hooks on ``model.cnn.cnn``: a backward hook on its first layer (to capture
    the gradient w.r.t. the input) and forward/backward hooks on every ReLU
    that is a direct child of ``model.cnn.cnn``. ReLUs nested inside
    sub-blocks are not hooked.

    The hooks stay attached to the model until ``remove_hooks()`` is called,
    so prefer creating one instance and reusing it for many
    ``generate_gradients`` calls.
    """
    def __init__(self, model):
        self.model = model
        self.gradients = None
        self.forward_relu_outputs = []
        # Handles of every hook registered by this instance (see remove_hooks)
        self._hook_handles = []
        # Put model in evaluation mode
        self.model.eval()
        self.update_relus()
        self.hook_layers()

    def remove_hooks(self):
        """
            Detaches every hook this instance registered on the model and
            drops any stored forward outputs. Safe to call more than once.
        """
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()
        self.forward_relu_outputs.clear()

    def hook_layers(self):
        """
            Registers a backward hook on the first layer that stores the
            gradient w.r.t. the layer input in self.gradients
        """
        def hook_function(module, grad_in, grad_out):
            self.gradients = grad_in[0]
        # Register hook to the first layer
        # NOTE: the legacy register_backward_hook is kept on purpose; the
        # "full" backward hooks do not work with in-place ReLU modules.
        first_layer = self.model.cnn.cnn[0]
        self._hook_handles.append(first_layer.register_backward_hook(hook_function))

    def update_relus(self):
        """
            Updates relu activation functions so that
                1- stores output in forward pass
                2- imputes zero for gradient values that are less than zero
        """
        def relu_backward_hook_function(module, grad_in, grad_out):
            """
            Keep only positive gradients at positions where the forward
            activation was positive; everything else becomes zero
            """
            # Get (and remove) last forward output
            forward_output = self.forward_relu_outputs.pop()
            # Build a 0/1 mask instead of overwriting the stored activation in place
            positive_mask = (forward_output > 0).to(grad_in[0].dtype)
            modified_grad_out = positive_mask * torch.clamp(grad_in[0], min=0.0)
            return (modified_grad_out,)

        def relu_forward_hook_function(module, ten_in, ten_out):
            """
            Store results of forward pass
            """
            self.forward_relu_outputs.append(ten_out)

        # Loop through the direct children only, hook up ReLUs
        for pos, module in self.model.cnn.cnn._modules.items():
            if isinstance(module, ReLU):
                self._hook_handles.append(
                    module.register_backward_hook(relu_backward_hook_function))
                self._hook_handles.append(
                    module.register_forward_hook(relu_forward_hook_function))

    def generate_gradients(self, input_image, cnn_layer, filter_pos):
        """
            Computes the guided-backprop gradient of one filter's activation
            with respect to the input image.

            Args:
                input_image: batched input tensor; must require grad so that a
                    gradient reaches the first layer's hook
                cnn_layer: index (within model.cnn.cnn) of the layer to stop at
                filter_pos: channel index of the filter inside that layer

            Returns:
                numpy array holding the input gradient of the first batch element

            Raises:
                RuntimeError: if no gradient reached the first layer
        """
        self.model.zero_grad()
        # Drop leftovers of a previous (possibly failed) call so that a stale
        # gradient or activation can never leak into this result
        self.gradients = None
        self.forward_relu_outputs.clear()
        # Run on whatever device the model parameters live on
        first_param = next(self.model.parameters(), None)
        x = input_image if first_param is None else input_image.to(first_param.device)
        # Forward pass
        for index, layer in enumerate(self.model.cnn.cnn):
            # Forward pass layer by layer (this also triggers the forward hooks)
            x = layer(x)
            # Only need to forward until the selected layer is reached
            if index == cnn_layer:
                break
        # Scalar objective: total absolute activation of the selected filter
        conv_output = torch.sum(torch.abs(x[0, filter_pos]))
        # Backward pass
        conv_output.backward()
        if self.gradients is None:
            raise RuntimeError(
                'No gradient reached the first layer; '
                'input_image must be created with requires_grad=True')
        # Convert the gradient to a numpy array (works for CPU and CUDA tensors)
        # [0] to get rid of the batch dimension
        gradients_as_arr = self.gradients.detach().cpu().numpy()[0]
        return gradients_as_arr

In [14]:
# Load the sample image and turn it into a batched input that tracks gradients.
with Image.open(fp) as image_file:
    original_image = image_file.convert('RGB')
print(original_image.size)
# Add the batch dimension, place the tensor on the same device as the model and
# enable gradient tracking directly on the tensor (torch.autograd.Variable is
# deprecated; tensors carry requires_grad themselves).
prep_img = image_transform(original_image).unsqueeze(0).to(device).requires_grad_(True)
print(prep_img.size())

(128, 72)
torch.Size([1, 3, 96, 170])


In [None]:
# Number of output channels of each layer in model.cnn.cnn, in layer order.
n_filters = [64, 64, 64, 64, 64, 128, 256, 512]

# Constructing GuidedBackprop registers hooks on the shared model, and nothing
# here removes them. Build it once and reuse it for every layer/filter instead
# of stacking a fresh set of hooks on the model for each filter.
GBP = GuidedBackprop(model)

# NOTE(review): only layers 0-6 are visualised although n_filters has 8
# entries -- confirm that skipping the last (512-filter) layer is intended.
for cnn_layer_idx in range(7):
    grads = []
    for filter_pos in range(n_filters[cnn_layer_idx]):
        # Guided backprop: gradient of this filter's activation w.r.t. the input
        guided_grads = GBP.generate_gradients(prep_img, cnn_layer_idx, filter_pos)
        grads.append(guided_grads)

    # Stack into one ndarray first; torch.tensor() on a list of ndarrays is very slow.
    grads_tensor = torch.from_numpy(np.stack(grads))
    # NOTE(review): save_image clamps to [0, 1], so raw (small, signed)
    # gradients may render almost black; consider make_grid(..., normalize=True).
    grid = torchvision.utils.make_grid(grads_tensor, nrow=8)
    torchvision.utils.save_image(grid, '{}/layer_activation_{}.png'.format(img_folder, cnn_layer_idx))