# In [None]:
import VGG_FACE
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import os
from PIL import Image
import copy

# add relprop() method to each layer
########################################
class Linear(nn.Linear):
    """nn.Linear that reuses an existing layer's parameters and supports LRP.

    Relevance is redistributed with the z+ rule: only the positive part of the
    weight matrix takes part and the bias is ignored. ``relprop`` reads
    ``self.X`` (the layer input), which this class does not set itself; it must
    be cached beforehand, e.g. by a forward hook.
    """

    def __init__(self, linear):
        # Skip nn.Linear.__init__ (it would allocate fresh parameters) and only
        # run the nn.Module initialisation; the parameters come from `linear`.
        super(nn.Linear, self).__init__()
        self.in_features, self.out_features = linear.in_features, linear.out_features
        self.weight, self.bias = linear.weight, linear.bias

    def relprop(self, R):
        """Map output relevance R back onto the layer input.

        Computes X * ((R / (X @ W+^T + 1e-9)) @ W+), where W+ is the weight
        clamped at zero from below.
        """
        positive_w = self.weight.clamp(min=0)
        denominator = self.X.mm(positive_w.t()) + 1e-9
        contribution = (R / denominator).mm(positive_w)
        return self.X * contribution
        
class ReLU(nn.ReLU):
    """Standard nn.ReLU whose LRP step leaves relevance untouched."""

    def relprop(self, R):
        """Return the incoming relevance unchanged (the very same object)."""
        relevance = R
        return relevance
    
class Reshape(nn.Module):
    """Flatten conv feature maps for the classifier, and undo it in ``relprop``.

    Args:
        shape: per-sample (channels, height, width) of the feature map being
            flattened. Defaults to (512, 7, 7), the size this module has
            always assumed; pass e.g. (256, 6, 6) for an AlexNet-style net.
    """

    def __init__(self, shape=(512, 7, 7)):
        super(Reshape, self).__init__()
        self.shape = tuple(shape)
        # Number of features per sample after flattening.
        self.num_features = 1
        for dim in self.shape:
            self.num_features *= dim

    def forward(self, x):
        """Flatten each sample of x to a vector of ``num_features`` values."""
        return x.view(-1, self.num_features)

    def relprop(self, R):
        """Reshape flat relevance R back to (batch, *shape)."""
        return R.view((-1,) + self.shape)

class MaxPool2d(nn.MaxPool2d):
    """nn.MaxPool2d that copies an existing pooling layer's settings and adds LRP.

    ``gradprop`` and ``relprop`` read ``self.X`` (the layer input) and
    ``self.Y`` (the layer output). This class does not set them itself; they
    must be cached beforehand, e.g. by a forward hook.
    """

    def __init__(self, maxpool2d):
        # Go through nn.MaxPool2d's own constructor. Its private base class
        # (newer PyTorch releases) requires kernel_size, so skipping past
        # nn.MaxPool2d.__init__ with no arguments raises TypeError there.
        super(MaxPool2d, self).__init__(maxpool2d.kernel_size,
                                        stride=maxpool2d.stride,
                                        padding=maxpool2d.padding,
                                        dilation=maxpool2d.dilation,
                                        return_indices=maxpool2d.return_indices,
                                        ceil_mode=maxpool2d.ceil_mode)

    def gradprop(self, DY):
        """Return the gradient of max-pooling w.r.t. its input, given DY.

        Each element of DY is routed back to the input position that won the
        max in its window; all other positions receive zero.
        """
        # Recompute the argmax indices for the cached input.
        _, indices = F.max_pool2d(self.X, self.kernel_size, self.stride,
                                  self.padding, self.dilation,
                                  ceil_mode=self.ceil_mode, return_indices=True)
        # output_size pins the result to the input's spatial size; the size
        # inferred from kernel/stride alone is too small when the input does
        # not tile evenly (e.g. an odd height or width).
        # NOTE(review): with overlapping windows (stride < kernel_size) two
        # windows can select the same position and unpooling does not sum
        # them, so this is only guaranteed exact for non-overlapping pooling.
        return F.max_unpool2d(DY, indices, self.kernel_size, self.stride,
                              self.padding, output_size=self.X.size()[-2:])

    def relprop(self, R):
        """Redistribute relevance R from the pooled output onto the input.

        Computes X * gradprop(R / (Y + 1e-9)), so each pooled value's
        relevance lands (up to the 1e-9 stabilizer) on the position that won
        the max.
        """
        Z = self.Y + 1e-9
        S = R / Z
        C = self.gradprop(S)
        return self.X * C

class Conv2d(nn.Conv2d):
    """nn.Conv2d that shares an existing conv layer's parameters and adds LRP.

    ``gradprop`` and ``relprop`` read ``self.X`` (the layer input) and
    ``self.Y`` (the layer output). This class does not set them itself; they
    must be cached beforehand, e.g. by a forward hook.
    """

    def __init__(self, conv2d):
        # Use nn.Conv2d's public constructor: the private _ConvNd base has a
        # signature that differs between PyTorch releases (newer ones also
        # require padding_mode).
        super(Conv2d, self).__init__(conv2d.in_channels,
                                     conv2d.out_channels,
                                     conv2d.kernel_size,
                                     stride=conv2d.stride,
                                     padding=conv2d.padding,
                                     dilation=conv2d.dilation,
                                     groups=conv2d.groups,
                                     bias=conv2d.bias is not None)
        # Reuse the source layer's parameters instead of the fresh ones.
        self.weight = conv2d.weight
        self.bias = conv2d.bias

    def gradprop(self, DY):
        """Return the gradient of the convolution w.r.t. its input, given DY.

        Implemented as a transposed convolution with the layer's own weight,
        stride, padding, dilation and groups. ``output_padding`` is recovered
        per spatial dimension so the result matches the cached input's size.
        """
        # Size a transposed conv produces without output_padding is
        # (Y - 1) * stride - 2 * padding + dilation * (kernel - 1) + 1; the
        # shortfall from X's size is the output_padding for that dimension.
        output_padding = tuple(
            self.X.size(d + 2)
            - ((self.Y.size(d + 2) - 1) * self.stride[d]
               - 2 * self.padding[d]
               + self.dilation[d] * (self.kernel_size[d] - 1) + 1)
            for d in range(2))
        return F.conv_transpose2d(DY, self.weight, stride=self.stride,
                                  padding=self.padding,
                                  output_padding=output_padding,
                                  groups=self.groups, dilation=self.dilation)

    def relprop(self, R):
        """Redistribute relevance R from the conv output onto its input.

        Computes X * gradprop(R / (Y + 1e-9)).
        NOTE(review): the 1e-9 term only guards Y == 0; it does not stabilise
        negative Y the way an epsilon * sign(Y) term would.
        """
        Z = self.Y + 1e-9
        S = R / Z
        C = self.gradprop(S)
        return self.X * C
########################################

# hyperparameters
num_workers = 4
batch_size = 25

# We use a pre-trained Torch model for VGG-Face and convert it to a PyTorch model.
# For details on how to do this, please refer to the README file.
# NOTE(review): vgg_face is indexed like an nn.Sequential elsewhere in this
# file; presumably that is what VGG_FACE.VGG_FACE is — confirm in VGG_FACE.py.
vgg_face = VGG_FACE.VGG_FACE
vgg_face.load_state_dict(torch.load('VGG_FACE.pth'))

# directory for input images (torchvision ImageFolder layout: one
# sub-directory per class)
val_dir = 'VGG_Face'
# directory for output heatmaps; created up front so saving images into it
# does not fail with a missing-directory error
directory = 'LRP'
if not os.path.isdir(directory):
    os.makedirs(directory)

# NOTE(review): these are the ImageNet channel statistics; confirm they match
# the preprocessing the converted VGG-Face weights expect.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
# transforms.Scale was renamed to transforms.Resize and later removed from
# torchvision; use whichever one this installation provides.
resize_transform = getattr(transforms, 'Resize', None) or transforms.Scale
# define data loader (no shuffling, so batches follow the dataset order)
val_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(val_dir,
                         transforms.Compose([
                             resize_transform(256),
                             transforms.CenterCrop(224),
                             transforms.ToTensor(),
                             normalize,
                         ])),
    batch_size=batch_size, shuffle=False,
    num_workers=num_workers, pin_memory=True)
    
class VGGFaceNet(nn.Module):
    """VGG-Face rebuilt from the LRP-capable layer wrappers in this file.

    Each conv/linear wrapper is constructed from the matching module of the
    module-level ``vgg_face`` model. ``relprop`` walks the layers in reverse,
    passing relevance from the output back to the input image.
    """

    def __init__(self):
        super(VGGFaceNet, self).__init__()
        # (indices of the conv layers in vgg_face, index of the max-pool that
        # closes that stage). Each conv is followed by a ReLU.
        stages = (
            ((0, 2), 4),
            ((5, 7), 9),
            ((10, 12, 14), 16),
            ((17, 19, 21), 23),
            ((24, 26, 28), 30),
        )
        modules = []
        for conv_indices, pool_index in stages:
            for conv_index in conv_indices:
                modules += [Conv2d(vgg_face[conv_index]), ReLU()]
            modules.append(MaxPool2d(vgg_face[pool_index]))
        modules.append(Reshape())
        # The classifier entries of vgg_face are indexed with [1] to reach
        # the linear layer (presumably wrapper containers produced by the
        # Torch conversion — confirm in VGG_FACE.py). ReLU between them, none
        # after the last one.
        fc_indices = (32, 35, 38)
        for position, fc_index in enumerate(fc_indices):
            modules.append(Linear(vgg_face[fc_index][1]))
            if position < len(fc_indices) - 1:
                modules.append(ReLU())
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Run x through every layer in order; returns the last Linear's output."""
        return self.layers(x)

    def relprop(self, R):
        """Propagate relevance R from the network output back to its input."""
        for layer in reversed(list(self.layers)):
            R = layer.relprop(R)
        return R
    
# Build the LRP-capable network and move its parameters to the GPU.
model = VGGFaceNet()
model = model.cuda()

# Forward hook that caches what each layer's relprop() needs.
def forward_hook(self, input, output):
    """Store the layer's first input as ``self.X`` and its output as ``self.Y``.

    Args:
        self: the module the hook is registered on.
        input: tuple of positional inputs passed to the module's forward.
        output: the value the module's forward returned.
    """
    first_input = input[0]
    self.X, self.Y = first_input, output
    
# Cache every layer's input/output during the forward pass for relprop().
for layer in model.layers:
    layer.register_forward_hook(forward_hook)


def _to_uint8(array):
    """Min-max rescale ``array`` to [0, 255] and convert it to uint8.

    A constant array (max == min) becomes all zeros instead of dividing by zero.
    """
    low, high = array.min(), array.max()
    if high > low:
        array = 255 * (array - low) / (high - low)
    else:
        array = np.zeros_like(array)
    return array.astype('uint8')


model.eval()
for idx, (images, _) in enumerate(val_loader):
    # NOTE(review): volatile= is a PyTorch <= 0.3 flag; newer releases ignore
    # it (with a warning) and expect torch.no_grad() instead.
    images = Variable(images, volatile=True).cuda()

    output = model(images)
    pred = output.data.max(1, keepdim=True)[1]  # index of the top-scoring class

    # One-hot mask of each sample's predicted class. view(-1) rather than
    # squeeze() keeps a 1-D array even for a batch of a single image.
    T = pred.view(-1).cpu().numpy()
    num_classes = output.size(1)  # 2622 for the original VGG-Face
    T = (T[:, np.newaxis] == np.arange(num_classes)) * 1.0
    T = torch.from_numpy(T).type(torch.FloatTensor)
    T = Variable(T).cuda()
    # Start relevance propagation from the predicted class's score only.
    LRP = model.relprop(output * T)

    # save results; iterate the actual batch size because the last batch
    # can be smaller than batch_size
    for i in range(images.size(0)):
        sample_id = idx * batch_size + i
        img = images[i].permute(1, 2, 0).data.cpu().numpy()
        Image.fromarray(_to_uint8(img), 'RGB').save(directory + '/%d_input.JPEG' % (sample_id))

        heatmap = LRP[i].permute(1, 2, 0).data.cpu().numpy()
        heatmap = np.absolute(heatmap)
        Image.fromarray(_to_uint8(heatmap), 'RGB').save(directory + '/%d_LRP.JPEG' % (sample_id))

print('Done...')