# In [None]:  (Jupyter notebook cell marker left over from the export)
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import os
from PIL import Image
import copy

# add relprop() method to each layer
########################################
class Linear(nn.Linear):
    """Fully connected layer sharing an existing layer's parameters, plus LRP.

    The wrapped layer's ``weight`` and ``bias`` objects are reused rather than
    copied. ``relprop`` reads ``self.X``, which must have been stored on the
    instance beforehand (elsewhere in this file a forward hook does that).
    """

    def __init__(self, linear):
        # Only the plain nn.Module set-up is run: nn.Linear's own constructor
        # is skipped so that no fresh parameters are allocated.
        nn.Module.__init__(self)
        self.in_features, self.out_features = linear.in_features, linear.out_features
        self.weight = linear.weight
        self.bias = linear.bias

    def relprop(self, R):
        """Redistribute relevance ``R`` onto the layer input.

        Only the non-negative part of the weights takes part in the
        redistribution; ``1e-9`` keeps the division away from zero.
        """
        positive_w = self.weight.clamp(min=0)
        denom = self.X.mm(positive_w.t()) + 1e-9
        ratio = R / denom
        back = ratio.mm(positive_w)
        return self.X * back
        
class ReLU(nn.ReLU):
    """``nn.ReLU`` with a ``relprop`` method so it can sit in an LRP chain."""

    def relprop(self, R):
        """Hand the incoming relevance ``R`` back unchanged."""
        return R

class Reshape_Alex(nn.Module):
    """Flatten ``256 x 6 x 6`` feature maps and restore that shape for relevance."""

    def forward(self, x):
        """View ``x`` as rows of ``256*6*6`` values."""
        flat_width = 256 * 6 * 6
        return x.view(-1, flat_width)

    def relprop(self, R):
        """View the flat relevance ``R`` as ``(-1, 256, 6, 6)``."""
        channels, height, width = 256, 6, 6
        return R.view(-1, channels, height, width)

class Reshape_VGG(nn.Module):
    """Flatten ``512 x 7 x 7`` feature maps and restore that shape for relevance."""

    def forward(self, x):
        """View ``x`` as rows of ``512*7*7`` values."""
        flat_width = 512 * 7 * 7
        return x.view(-1, flat_width)

    def relprop(self, R):
        """View the flat relevance ``R`` as ``(-1, 512, 7, 7)``."""
        channels, height, width = 512, 7, 7
        return R.view(-1, channels, height, width)

class MaxPool2d(nn.MaxPool2d):
    """``nn.MaxPool2d`` configured like an existing layer, plus LRP support.

    ``gradprop`` and ``relprop`` read ``self.X`` (layer input) and ``self.Y``
    (layer output); both must have been stored on the instance beforehand
    (elsewhere in this file a forward hook does that).
    """

    def __init__(self, maxpool2d):
        # Initialise only the nn.Module machinery. Going through
        # ``super(nn.MaxPool2d, self)`` depends on what sits above
        # nn.MaxPool2d in the class hierarchy: on PyTorch versions where that
        # is ``_MaxPoolNd`` the call fails because ``kernel_size`` is required.
        nn.Module.__init__(self)
        self.kernel_size = maxpool2d.kernel_size
        self.stride = maxpool2d.stride
        self.padding = maxpool2d.padding
        self.dilation = maxpool2d.dilation
        self.return_indices = maxpool2d.return_indices
        self.ceil_mode = maxpool2d.ceil_mode

    def gradprop(self, DY):
        """Scatter ``DY`` back to the input positions that won the max.

        Returns a tensor with the same size as ``self.X``.
        """
        # Re-run the pooling only to recover the argmax indices.
        _, indices = F.max_pool2d(self.X, self.kernel_size, self.stride,
                                  self.padding, self.dilation, self.ceil_mode, True)
        # Pooling can discard trailing rows/columns, so the input size is not
        # always recoverable from the pooled size; request it explicitly so
        # the result can be multiplied with ``self.X``.
        return F.max_unpool2d(DY, indices, self.kernel_size, self.stride,
                              self.padding, output_size=self.X.size())

    def relprop(self, R):
        """Redistribute relevance ``R`` from the pooled output onto the input."""
        Z = self.Y + 1e-9  # keep the division away from zero
        S = R / Z
        C = self.gradprop(S)
        return self.X * C

class Conv2d(nn.Conv2d):
    """``nn.Conv2d`` configured like an existing layer, sharing its parameters.

    Adds ``gradprop``/``relprop`` for LRP. Both read ``self.X`` (layer input)
    and ``self.Y`` (layer output), which must have been stored on the instance
    beforehand (elsewhere in this file a forward hook does that).
    """

    def __init__(self, conv2d):
        # Use the public nn.Conv2d constructor: the private ``_ConvNd``
        # signature has changed between PyTorch releases, so calling it
        # positionally is not portable.
        extra = {}
        if hasattr(conv2d, 'padding_mode'):
            # only present on newer PyTorch; keep forward() identical there
            extra['padding_mode'] = conv2d.padding_mode
        super(Conv2d, self).__init__(conv2d.in_channels,
                                     conv2d.out_channels,
                                     conv2d.kernel_size,
                                     stride=conv2d.stride,
                                     padding=conv2d.padding,
                                     dilation=conv2d.dilation,
                                     groups=conv2d.groups,
                                     bias=conv2d.bias is not None,
                                     **extra)
        # Replace the freshly allocated parameters with the wrapped layer's.
        self.weight = conv2d.weight
        self.bias = conv2d.bias

    def gradprop(self, DY):
        """Propagate ``DY`` back through the convolution (bias excluded).

        Returns a tensor with the same spatial size as ``self.X``.
        """
        # A strided convolution can map several input sizes to the same output
        # size; ``output_padding`` selects the one that matches ``self.X``.
        # It is computed per spatial dimension so non-square inputs work too.
        output_padding = tuple(
            self.X.size(dim + 2)
            - ((self.Y.size(dim + 2) - 1) * self.stride[dim]
               - 2 * self.padding[dim]
               + self.dilation[dim] * (self.kernel_size[dim] - 1) + 1)
            for dim in range(2)
        )
        return F.conv_transpose2d(DY, self.weight, stride=self.stride,
                                  padding=self.padding,
                                  output_padding=output_padding,
                                  groups=self.groups, dilation=self.dilation)

    def relprop(self, R):
        """Redistribute relevance ``R`` from the layer output onto its input."""
        # NOTE(review): ``self.Y`` can be negative here, so the fixed +1e-9
        # offset does not rule out a near-zero denominator — confirm this is
        # the intended stabiliser.
        Z = self.Y + 1e-9
        S = R / Z
        C = self.gradprop(S)
        return self.X * C
########################################

# hyperparameters
num_workers = 4   # worker processes used by the DataLoader
batch_size = 10   # images per batch
# buffers to store predictions & labels
buffer_label = []
buffer_Alex = []
buffer_VGG = []

# directory for input images
val_dir = 'ILSVRC2012_img_val'
# directory for output heatmaps
out_dir = 'LRP'

# per-channel mean/std applied to the input tensors
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# ``transforms.Scale`` was renamed to ``transforms.Resize`` and later removed
# from torchvision; prefer the new name and fall back to the old one so the
# script keeps working with legacy installations.
_resize = getattr(transforms, 'Resize', None) or transforms.Scale

# define data loader (fixed order, so batch index * batch_size identifies an image)
val_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(val_dir,
                         transforms.Compose([
                             _resize(256),
                             transforms.CenterCrop(224),
                             transforms.ToTensor(),
                             normalize,
                         ])),
    batch_size=batch_size, shuffle=False,
    num_workers=num_workers, pin_memory=True)

# load pre-trained Alexnet and freeze its parameters
alex = models.alexnet(pretrained=True).cuda()
for param in alex.parameters():
    param.requires_grad = False
# load pre-trained VGG16 and freeze its parameters
vgg16 = models.vgg16(pretrained=True).cuda()
for param in vgg16.parameters():
    param.requires_grad = False

class AlexNet(nn.Module):
    """AlexNet rebuilt from the module-level ``alex`` model with LRP-aware layers.

    Every entry of ``self.layers`` exposes ``relprop``, so relevance can be
    pushed from the network output back to its input.
    """

    def __init__(self):
        super(AlexNet, self).__init__()
        features, classifier = alex.features, alex.classifier
        # Each stage: (indices of its conv layers, index of the closing max-pool).
        stages = (((0,), 2), ((3,), 5), ((6, 8, 10), 12))
        stack = []
        for conv_ids, pool_id in stages:
            for conv_id in conv_ids:
                stack += [Conv2d(features[conv_id]), ReLU()]
            stack.append(MaxPool2d(features[pool_id]))
        stack += [
            Reshape_Alex(),
            Linear(classifier[1]), ReLU(),
            Linear(classifier[4]), ReLU(),
            Linear(classifier[6]),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through all layers in order."""
        return self.layers(x)

    def relprop(self, R):
        """Propagate relevance ``R`` through the layers from last to first."""
        for layer in reversed(self.layers):
            R = layer.relprop(R)
        return R

class VGG16Net(nn.Module):
    """VGG16 rebuilt from the module-level ``vgg16`` model with LRP-aware layers.

    Every entry of ``self.layers`` exposes ``relprop``, so relevance can be
    pushed from the network output back to its input.
    """

    def __init__(self):
        super(VGG16Net, self).__init__()
        features, classifier = vgg16.features, vgg16.classifier
        # Each stage: (indices of its conv layers, index of the closing max-pool).
        stages = (
            ((0, 2), 4),
            ((5, 7), 9),
            ((10, 12, 14), 16),
            ((17, 19, 21), 23),
            ((24, 26, 28), 30),
        )
        stack = []
        for conv_ids, pool_id in stages:
            for conv_id in conv_ids:
                stack += [Conv2d(features[conv_id]), ReLU()]
            stack.append(MaxPool2d(features[pool_id]))
        stack += [
            Reshape_VGG(),
            Linear(classifier[0]), ReLU(),
            Linear(classifier[3]), ReLU(),
            Linear(classifier[6]),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through all layers in order."""
        return self.layers(x)

    def relprop(self, R):
        """Propagate relevance ``R`` through the layers from last to first."""
        for layer in reversed(self.layers):
            R = layer.relprop(R)
        return R

# build both LRP-capable networks and move them to the GPU
model_Alex, model_VGG = AlexNet().cuda(), VGG16Net().cuda()

# forward hook method for retrieving intermediate results
def forward_hook(self, input, output):
    """Record a module's first input and its output on the module itself.

    Stores ``input[0]`` as ``self.X`` and ``output`` as ``self.Y``.
    """
    self.X, self.Y = input[0], output
    
# attach the recording hook to every layer of both networks
for layer in model_Alex.layers:
    layer.register_forward_hook(forward_hook)

for layer in model_VGG.layers:
    layer.register_forward_hook(forward_hook)

# switch both networks to evaluation mode
model_Alex.eval()
model_VGG.eval()

# running counters and per-batch buffers for labels / predictions
correct_Alex = correct_VGG = 0
buffer_label, buffer_Alex, buffer_VGG = [], [], []

def _to_uint8_image(array):
    """Linearly rescale ``array`` to the 0-255 range and cast it to uint8.

    A constant array (max == min) is mapped to all zeros instead of dividing
    by zero.
    """
    low = array.min()
    span = array.max() - low
    if span == 0:
        return np.zeros(array.shape, dtype='uint8')
    return (255 * (array - low) / span).astype('uint8')


def _one_hot_cuda(class_ids, num_classes):
    """Build a float one-hot mask on the GPU, one row per entry of ``class_ids``."""
    mask = (class_ids[:, np.newaxis] == np.arange(num_classes)) * 1.0
    mask = torch.from_numpy(mask).type(torch.FloatTensor)
    return Variable(mask).cuda()


# make sure the heatmap directory exists before the first save
if not os.path.isdir(out_dir):
    os.makedirs(out_dir)

for idx, (input, label) in enumerate(val_loader):
    input, label = Variable(input, volatile=True).cuda(), Variable(label).cuda()
    # convert once per batch instead of once per comparison
    labels_np = label.data.cpu().numpy()

    output_Alex = model_Alex(input)
    pred_Alex = output_Alex.data.max(1, keepdim=True)[1]  # index of the largest output score
    correct_Alex += pred_Alex.eq(label.data.view_as(pred_Alex)).cpu().sum()

    # reshape(-1) keeps a 1-D array even for a batch of one (squeeze() would not)
    pred_Alex_np = pred_Alex.cpu().numpy().reshape(-1)
    # start the relevance from the predicted class only
    T_Alex = _one_hot_cuda(pred_Alex_np, output_Alex.size(1))
    LRP_Alex = model_Alex.relprop(output_Alex * T_Alex)

    output_VGG = model_VGG(input)
    pred_VGG = output_VGG.data.max(1, keepdim=True)[1]  # index of the largest output score
    correct_VGG += pred_VGG.eq(label.data.view_as(pred_VGG)).cpu().sum()

    pred_VGG_np = pred_VGG.cpu().numpy().reshape(-1)
    T_VGG = _one_hot_cuda(pred_VGG_np, output_VGG.size(1))
    LRP_VGG = model_VGG.relprop(output_VGG * T_VGG)

    buffer_label.append(labels_np)
    buffer_Alex.append(pred_Alex.cpu().numpy())
    buffer_VGG.append(pred_VGG.cpu().numpy())

    # save results which are classified correctly by VGG16, incorrectly by AlexNet
    # (iterate over the real batch length: the last batch may be shorter)
    for i, target in enumerate(labels_np):
        if pred_Alex_np[i] == target or pred_VGG_np[i] != target:
            continue
        sample_id = idx * batch_size + i

        img = _to_uint8_image(input[i].permute(1, 2, 0).data.cpu().numpy())
        Image.fromarray(img, 'RGB').save(
            os.path.join(out_dir, '%d_input_%d.JPEG' % (sample_id, target)))

        heatmap_Alex = np.absolute(LRP_Alex[i].permute(1, 2, 0).data.cpu().numpy())
        heatmap_Alex = _to_uint8_image(heatmap_Alex)
        Image.fromarray(heatmap_Alex, 'RGB').save(
            os.path.join(out_dir, '%d_LRP_Alex_%d.JPEG' % (sample_id, pred_Alex_np[i])))

        heatmap_VGG = np.absolute(LRP_VGG[i].permute(1, 2, 0).data.cpu().numpy())
        heatmap_VGG = _to_uint8_image(heatmap_VGG)
        Image.fromarray(heatmap_VGG, 'RGB').save(
            os.path.join(out_dir, '%d_LRP_VGG_%d.JPEG' % (sample_id, pred_VGG_np[i])))

print('Done...')