# Explaining Classifiers using Adversarial Perturbations on the Perceptual Ball
Andrew Elliott, Stephen Law and Chris Russell 

This notebook gives a simple example of running our explainability method on a single image.

Please note that without a GPU this notebook may take a little time to generate the resultant images.


In [None]:
# Loading libraries

# Core, image libraries and image processing functions
import torch
import torchvision
from scipy.ndimage import gaussian_filter


# Our method code
from common_code.mul_perceptualFunc_Final import *
from common_code import utils

# Plotting libraries
import matplotlib.pyplot as plt


First we load the model. This can be changed to other VGG-based variants; however, the layers selected later would then have to be changed to match the ReLUs in that network. Using other networks would require a small change to the feature extractor so that it outputs the correct layers.
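The feature extractor has to return the activations at the selected ReLU layers. As a minimal sketch of that idea, assuming the network body is an ordered sequence of named layers (which is what `model.features.named_children()` yields for torchvision's VGG models), the hypothetical helper `collect_activations` below runs the input layer by layer and keeps only the outputs it is asked for. It is not the extractor in `common_code`; plain Python functions stand in for layers so the logic can be checked without PyTorch.

```python
from typing import Any, Callable, Dict, Iterable, Tuple


def collect_activations(
    named_layers: Iterable[Tuple[str, Callable[[Any], Any]]],
    x: Any,
    keep: Iterable[str],
) -> Dict[str, Any]:
    """Run x through the layers in order and return the outputs of the
    layers whose names are in `keep` (hypothetical helper, for illustration)."""
    wanted = set(keep)
    outputs: Dict[str, Any] = {}
    if not wanted:
        return outputs
    for name, layer in named_layers:
        x = layer(x)
        if name in wanted:
            outputs[name] = x
            if len(outputs) == len(wanted):
                break  # every requested layer has been seen; skip the rest
    return outputs


# Toy "network": three named layers acting on numbers.
toy = [("0", lambda v: v + 1), ("1", lambda v: max(v, 0)), ("2", lambda v: v * 10)]
print(collect_activations(toy, 2, keep=["1", "2"]))  # {'1': 3, '2': 30}
```

On a real model one would call something like `collect_activations(model.features.named_children(), img_batch, keep=layers)`, where `img_batch` is a hypothetical preprocessed (1, 3, H, W) tensor. Stopping once every requested layer has been seen avoids running the remaining layers when only the activations are needed.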

In [None]:
# Simple function wrapper
def perceptual(image_name, k=100, model=None,
               layerLists=('16', '19', '22', '25', '29', '32'), sa=1, ga=10000):
    '''
    Perceptual function that returns a saliency map.
    k          = number of optimisation iterations
    model      = network to explain (defaults to a pretrained VGG19-BN)
    layerLists = the layers to regularise with the perceptual loss
    sa         = value passed as `scalar` to create_perceptual_loss2
    ga         = weight of the perceptual loss (passed as `gamma`)
    '''

    # If the model is not specified, assume a pretrained VGG19-BN
    if model is None:
        model = torchvision.models.vgg19_bn(pretrained=True)
        model.requires_grad_(False)  # freeze every parameter
        model.eval()

    # Load the image and add a batch dimension
    img_tensor = utils.open_and_preprocess(image_name)
    img_variable = img_tensor.unsqueeze(0)

    # Create the perceptual loss with the required parameters
    loss = create_perceptual_loss2(-2, img_variable, model, gamma=ga, scalar=sa,
                                   layers=list(layerLists))

    # Optimise the loss to find the adversarial perturbation
    c = find_direction(loss, img_variable, iterations=k)

    # Saliency map: per-pixel root-mean-square difference over colour channels
    # (the pixelwise Euclidean distance divided by sqrt(number of channels))
    res = torch.sqrt(((c - img_variable) ** 2).mean(1))
    res = res.squeeze().cpu().detach().numpy()
    return res
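
The last step of `perceptual` turns the perturbation into a saliency map. The NumPy sketch below mirrors that computation for a single image stored as a (1, channels, height, width) array: the squared difference is averaged over the channel axis and square-rooted, giving one value per pixel. `saliency_from_perturbation` is a hypothetical name used only for illustration.

```python
import numpy as np


def saliency_from_perturbation(perturbed: np.ndarray, original: np.ndarray) -> np.ndarray:
    """Per-pixel root-mean-square difference over colour channels.

    Inputs are (1, C, H, W) arrays; the result is (H, W), mirroring
    torch.sqrt(((c - img) ** 2).mean(1)).squeeze() for a single image.
    """
    diff_sq = (perturbed - original) ** 2
    return np.sqrt(diff_sq.mean(axis=1)).squeeze(axis=0)


original = np.zeros((1, 3, 2, 2))
perturbed = original.copy()
perturbed[0, :, 0, 0] = [1.0, 2.0, 2.0]  # only the top-left pixel changes
saliency = saliency_from_perturbation(perturbed, original)
# top-left entry is sqrt((1 + 4 + 4) / 3) = sqrt(3); all other entries are 0
print(saliency)
```

Because the mean rather than the sum is taken over channels, each value equals the Euclidean distance of that pixel's colour vector divided by sqrt(C), so the ranking of pixels is the same as with the plain Euclidean distance.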


We will demonstrate the method on the image in Fig. 2 of the paper. First, let's specify our model: we will use a standard pretrained VGG19-BN model from torchvision.

In [None]:
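# Load the relevant model and freeze its parameters
premodel = torchvision.models.vgg19_bn(pretrained=True)
premodel.requires_grad_(False)  # setting premodel.requires_grad = False would only add an unused attribute
premodel.eval()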


We specify the layers of the network that correspond to ReLUs so we can regularise the correct layers. Note we include

In [None]:
# layers list for VGG19_BN
layerAll=['2','5','9','12','16','19','22','25','29','32','35','38','42','45','48','51']
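
The indices in `layerAll` can be derived from the VGG configuration rather than typed by hand, which also gives the ReLU indices for the plain (non-batch-norm) VGG19 mentioned earlier. This sketch assumes the torchvision layout of `features`: each convolution is followed by a batch-norm layer (when enabled) and a ReLU, and each `"M"` entry is a single max-pooling layer. `VGG19_CFG` restates torchvision's VGG19 configuration and `relu_indices` is a hypothetical helper. With PyTorch available, `[str(i) for i, m in enumerate(premodel.features) if isinstance(m, torch.nn.ReLU)]` checks the same thing on the loaded model.

```python
from typing import List, Union

# VGG19 configuration: output channels for each conv layer, "M" for max-pooling.
VGG19_CFG: List[Union[int, str]] = [
    64, 64, "M", 128, 128, "M",
    256, 256, 256, 256, "M",
    512, 512, 512, 512, "M",
    512, 512, 512, 512, "M",
]


def relu_indices(cfg: List[Union[int, str]], batch_norm: bool) -> List[str]:
    """Return the positions (as strings) of the ReLU modules in `features`."""
    indices: List[str] = []
    pos = 0
    for entry in cfg:
        if entry == "M":
            pos += 1              # one MaxPool2d module
            continue
        pos += 1                  # skip the Conv2d module
        if batch_norm:
            pos += 1              # skip the BatchNorm2d module
        indices.append(str(pos))  # the ReLU sits at this position
        pos += 1                  # move past the ReLU
    return indices


print(relu_indices(VGG19_CFG, batch_norm=True))   # the same 16 entries as layerAll
print(relu_indices(VGG19_CFG, batch_norm=False))  # ReLU indices for plain VGG19
```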


Next, let's specify our image. As a demonstration we will use the image from the figure in our paper, which we display below:

In [None]:
# image name
img='ILSVRC2012_val_00000051.JPEG'
im = utils.open_and_resize(img)

fig = plt.figure(frameon=False)
ax = plt.Axes(fig, [0., 0., 1., 1.])
ax.set_axis_off()
fig.add_axes(ax)
ax.imshow(im)
plt.show()

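Next, let's explore the saliency map obtained by regularising on the first 13 ReLU layers (indices 0 to 12 of layerAll). On CPU this may take a little while to run. To run on a GPU, both the model and the image tensor (which is loaded inside perceptual above) need to be moved to the GPU.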

In [None]:
# layers to regularise
layerslist=["0-1-2-3-4-5-6-7-8-9-10-11-12"]

# get layers
layers=[layerAll[int(k)] for k in layerslist[0].split('-')]

# run adversarial perturbation on the perceptual ball
res = perceptual(img,k=100,model=premodel,layerLists=layers,sa=1,ga=10000)

# gaussian blur on image
mat1 = gaussian_filter(res, sigma=2)

# visualise the resultant saliency map
fig = plt.figure(frameon=False)
ax = plt.Axes(fig, [0., 0., 1., 1.])
ax.set_axis_off()
fig.add_axes(ax)
ax.imshow(mat1.squeeze())
plt.show()


Finally, we run this on all of the layer sets in the paper.

We first make all of the images for each collection of layers. Note that this might take a while, especially on CPU.

To make this runnable by a wide audience we have reduced the number of LBFGS iterations and moved the computation to CPU, which may give slightly different results from those in the paper.
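
Each layer set below is written as dash-separated indices into `layerAll`, and the empty string means no perceptual regularisation. A small helper keeps both cases in one place; `parse_layer_spec` is a hypothetical name, and it simply reproduces the branching inside the loop below.

```python
from typing import List, Sequence


def parse_layer_spec(spec: str, layer_all: Sequence[str]) -> List[str]:
    """Map a spec such as "0-1-2" to layer names; "" gives an empty list."""
    if not spec:
        return []  # no perceptual regularisation
    return [layer_all[int(k)] for k in spec.split("-")]


layer_all = ['2', '5', '9', '12', '16', '19', '22', '25', '29', '32',
             '35', '38', '42', '45', '48', '51']
print(parse_layer_spec("0-1-2", layer_all))  # ['2', '5', '9']
print(parse_layer_spec("", layer_all))       # []
```

For example, the spec "0-1-2-3-4-5-6-7-8-9-10-11-12" selects the first 13 ReLU layers, ending at '42'.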

In [None]:
layerslists = []
layerslists.append("")
layerslists.append("0-1-2")
layerslists.append("0-1-2-3-4")
layerslists.append("0-1-2-3-4-5-6")
layerslists.append("0-1-2-3-4-5-6-7-8-9-10-11-12")

images = []
for layerslist in layerslists:
    if len(layerslist)>0:
        layers = [layerAll[int(k)] for k in layerslist.split('-')]
    else:
        layers = []
    res = perceptual(img,k=100,model=premodel,layerLists=layers,sa=1,ga=10000)

    # gaussian blur on image
    mat1 = gaussian_filter(res, sigma=2)
    images.append(mat1)

Plot the resulting saliency maps to obtain a plot similar to the one in the paper.

In [None]:
fig,axs = plt.subplots(1,6,figsize=(18,3))
titles = ['Orig']
titles.append("NoPerceptual")
titles.append("0-2")
titles.append("0-4")
titles.append("0-6")
titles.append("0-12")

for ax,curIm,title in zip(axs,[im,]+images,titles):
    ax.set_axis_off()
    ax.imshow(curIm)
    ax.set_title(title,fontsize=20)
plt.tight_layout(pad=0)
plt.show()
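
The panel titles abbreviate each layer set by its first and last index. A hypothetical helper, `spec_to_title`, derives those titles from the same strings used to build `layerslists`, so extra layer sets could be added without editing two lists by hand.

```python
from typing import List


def spec_to_title(spec: str) -> str:
    """Abbreviate a layer spec: '0-1-2' -> '0-2', '' -> 'NoPerceptual'."""
    if not spec:
        return "NoPerceptual"  # empty spec: no perceptual regularisation
    idx = spec.split("-")
    return idx[0] if len(idx) == 1 else f"{idx[0]}-{idx[-1]}"


specs = ["", "0-1-2", "0-1-2-3-4", "0-1-2-3-4-5-6", "0-1-2-3-4-5-6-7-8-9-10-11-12"]
auto_titles: List[str] = ["Orig"] + [spec_to_title(s) for s in specs]
print(auto_titles)  # ['Orig', 'NoPerceptual', '0-2', '0-4', '0-6', '0-12']
```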