In [1]:
import numpy as np
import torch
import torchvision.transforms as transforms
#from glob import glob
import torch.backends.cudnn as cudnn
from modules.resnet import resnet50
import matplotlib.cm
from torch.autograd import Variable as V
import torchvision.models as models
from torch.nn import functional as F
import os
from PIL import Image

In [2]:
arch = 'resnet50'
# load the pre-trained weights
model_file = 'ptrained-models/%s_places365.pth.tar' % arch
model = models.__dict__[arch](num_classes=365)
checkpoint = torch.load(model_file, map_location=lambda storage, loc: storage)
state_dict = {str.replace(k,'module.',''): v for k,v in checkpoint['state_dict'].items()}
model.load_state_dict(state_dict)
model.eval()
print('Loaded Original')

Loaded Original


In [3]:
Rmodel = resnet50(num_classes=365)
Rmodel.load_state_dict(state_dict)
Rmodel.eval()
print("Loaded RAP Enabled Model")

Loaded RAP Enabled Model


In [4]:
# load the image transformer
centre_crop = transforms.Compose([
        transforms.Resize((256,256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# load the class label
file_name = 'categories_places365.txt'
classes = list()
with open(file_name) as class_file:
    for line in class_file:
        classes.append(line.strip().split(' ')[0][3:])
classes = tuple(classes)

# load the test image
img_name = 'places365_standard/train/canal-urban/00000559.jpg'
img = Image.open(img_name)
input_img = V(centre_crop(img).unsqueeze(0))

In [5]:
# forward pass
logit = model.forward(input_img)
h_x = F.softmax(logit, 1).data.squeeze()
probs, idx = h_x.sort(0, True)

print('{} original model prediction on {}'.format(arch,img_name))
# output the prediction
for i in range(0, 5):
    print('{:.3f} -> {}'.format(probs[i], classes[idx[i]]))

# forward pass
logit = Rmodel.forward(input_img)
h_x = F.softmax(logit, 1).data.squeeze()
probs, idx = h_x.sort(0, True)

print('{} rap model prediction on {}'.format(arch,img_name))
# output the prediction
for i in range(0, 5):
    print('{:.3f} -> {}'.format(probs[i], classes[idx[i]]))

resnet50 original model prediction on places365_standard/train/canal-urban/00000559.jpg
0.858 -> canal/urban
0.045 -> tower
0.027 -> bridge
0.011 -> church/outdoor
0.006 -> river
resnet50 rap model prediction on places365_standard/train/canal-urban/00000559.jpg
0.858 -> canal/urban
0.045 -> tower
0.027 -> bridge
0.011 -> church/outdoor
0.006 -> river


In [6]:
def enlarge_image(img, scaling = 3):
    if scaling < 1 or not isinstance(scaling,int):
        print ('scaling factor needs to be an int >= 1')

    if len(img.shape) == 2:
        H,W = img.shape
        out = np.zeros((scaling*H, scaling*W))
        for h in range(H):
            fh = scaling*h
            for w in range(W):
                fw = scaling*w
                out[fh:fh+scaling, fw:fw+scaling] = img[h,w]
    elif len(img.shape) == 3:
        H,W,D = img.shape
        out = np.zeros((scaling*H, scaling*W,D))
        for h in range(H):
            fh = scaling*h
            for w in range(W):
                fw = scaling*w
                out[fh:fh+scaling, fw:fw+scaling,:] = img[h,w,:]
    return out

def hm_to_rgb(R, scaling = 3, cmap = 'bwr', normalize = True):
    cmap = eval('matplotlib.cm.{}'.format(cmap))
    if normalize:
        R = R / np.max(np.abs(R)) # normalize to [-1,1] wrt to max relevance magnitude
        R = (R + 1.)/2. # shift/normalize to [0,1] for color mapping
    R = R
    R = enlarge_image(R, scaling)
    rgb = cmap(R.flatten())[...,0:3].reshape([R.shape[0],R.shape[1],3])
    return rgb

def visualize(relevances, img_name):
    # visualize the relevance
    n = len(relevances)
    #print(n)
    heatmap = np.sum(relevances.reshape([n, 224, 224, 1]), axis=3)
    #print(heatmap.shape)
    heatmaps = []
    for h, heat in enumerate(heatmap):
        #print(h,heat.shape)
        maps = hm_to_rgb(heat, scaling=1, cmap = 'seismic')
        heatmaps.append(maps)
        im = Image.fromarray((maps*255).astype(np.uint8))
        im.save('output_heatmap_'+img_name+'.png')
        
def visualize_v2(relevances, img_name):
    # visualize the relevance
    n = len(relevances)
    #print(n)
    heatmap = np.sum(relevances.reshape([n, 224, 224, 1]), axis=3)
    #print(heatmap.shape)
    heatmaps = []
    for h, heat in enumerate(heatmap):
        #print(h,heat.shape)
        #maps = hm_to_rgb(heat, scaling=1, cmap = 'seismic')
        #heatmaps.append(maps)
        heat = heat / np.max(np.abs(heat))
        heat = (heat + 1.)/2
        im = Image.fromarray((heat*255).astype(np.uint8))
        im.save('output_heatmap_'+img_name+'.png')
        
        

def compute_pred(output):
    pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
    print('Pred cls : '+str(pred))
    T = pred.squeeze().cpu().numpy()
    T = np.expand_dims(T, 0)
    T = (T[:, np.newaxis] == np.arange(365)) * 1.0
    T = torch.from_numpy(T).type(torch.FloatTensor)
    Tt = V(T)
    return Tt

In [8]:
input = V(input_img, volatile=True)
input.requires_grad = True
output = model(input)
T = compute_pred(output)
Res = Rmodel.relprop(R = output * T, alpha= 1).sum(dim=1, keepdim=True)
heatmap = Res.permute(0, 2, 3, 1).data.cpu().numpy()
print(heatmap.shape)
visualize_v2(heatmap.reshape([1, 224, 224, 1]), 'lrp')


RAP = Rmodel.RAP_relprop(R=T)
Res = (RAP).sum(dim=1, keepdim=True)
heatmap = Res.permute(0, 2, 3, 1).data.cpu().numpy()
print(heatmap.shape)
visualize(heatmap.reshape([1, 224, 224, 1]), 'rap')

print('Done')

  """Entry point for launching an IPython kernel.


Pred cls : tensor([[79]])
(1, 224, 224, 1)
(1, 224, 224, 1)
Done


In [9]:
!cp places365_standard/train/canal-urban/00000559.jpg ./inp.jpg

In [19]:
path = 'output_heatmap_lrp.png'
with open(path, 'rb') as f:
        img = Image.open(f)
        print(img.mode)
        img.convert('RGB')

print(np.asarray(img))

L
[[162 182 185 ... 130 129 129]
 [170 202 197 ... 132 130 129]
 [180 207 208 ... 133 130 130]
 ...
 [127 128 128 ... 129 129 129]
 [127 128 128 ... 129 130 129]
 [127 127 127 ... 129 129 129]]
