Author: Thrupthi Ann John https://github.com/ThrupthiAnn


# Demo of CIS and CMS maps

This notebook takes you through the steps to calculate Canonical Image Saliency (CIS) maps and Canonical Model Saliency (CMS) maps.
In our paper, the experiments are conducted on 22,085 random images from the CelebA dataset. In this demo, we provide data for 100 images. Feel free to use your own images. You can use any PyTorch model, although we provide VGG-16 models (recognition and gender) trained on CelebA.

<p>This is the first demo notebook. After this, run <b>demo2_explanation.ipynb</b> and then <b>demo3_metrics.ipynb</b>.</p>


#### Folder names and hyperparameters

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from torchvision import models, transforms
import torch
from pathlib import Path
from os.path import join
from skimage.io import imread
from skimage.transform import resize
from scipy.io import loadmat
from CalculateCIS import Occlusion
import matplotlib.colors
%matplotlib inline
from PIL import Image
from os import makedirs

# Input data: sample face images and their PRNet 3D alignments (meshes).
samplefolder = '../data/SampleData'
imagefolder = join(samplefolder, 'SampleImages')
meshfolder = join(samplefolder, 'Sample3D')

# Outputs: one CIS map (.npy) per input image is written to cisfolder.
resultsfolder = '../results'
cisfolder = join(resultsfolder, 'CIS')
makedirs(cisfolder, exist_ok=True)

# Canonical frontal face image and its mesh. The mesh is subsampled to every
# 20th vertex, the same factor applied to dense per-image meshes in Step 3.
modelfolder = '../data/Models'
frontal = imread(join(modelfolder, 'frontal.jpg'))
frontal_mesh = loadmat(join(modelfolder, 'frontal_mesh.mat'))['vertices'][::20, :]

# Passed to Occlusion() as `size` -- presumably the occluder size; check
# CalculateCIS.Occlusion to confirm its exact meaning/unit.
size = 15

# Fall back to the CPU when CUDA is unavailable instead of failing as soon as
# the model is moved to a non-existent GPU (CPU runs are much slower).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#### Some utility functions

In [None]:
def PlotColormap(colormap, alpha=0.7):
    """Overlay a saliency map on the canonical frontal face image.

    Pixels whose value is exactly 0 are drawn fully transparent, so the
    underlying ``frontal`` image shows through. Every other pixel is coloured
    with the ``jet`` colormap, scaled between the map's min and max, at
    opacity ``alpha``. A colorbar showing that scale is added.

    Args:
        colormap: 2-D array of saliency values. Presumably the same height
            and width as ``frontal`` (TODO confirm this always holds).
        alpha: opacity in [0, 1] used for the non-zero pixels.
    """
    colormap = np.asarray(colormap)
    vmin = np.min(colormap)
    vmax = np.max(colormap)
    norm = matplotlib.colors.Normalize(vmin, vmax, clip=True)
    cmap = plt.cm.jet
    colors = cmap(norm(colormap))
    # Zero-valued pixels become transparent; all others get the requested alpha.
    colors[..., -1] = np.where(colormap == 0, 0.0, alpha)

    plt.figure()
    plt.imshow(frontal)
    plt.imshow(colors)
    plt.axis('off')
    # The overlay is a precomputed RGBA image and carries no colormap of its
    # own, so a bare plt.colorbar() would show the default cmap over 0-1.
    # Build the colorbar from the norm and cmap that were actually used.
    mappable = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
    mappable.set_array([])
    plt.colorbar(mappable, ax=plt.gca())
	
# ImageNet RGB channel statistics.
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
# Image -> tensor, then per-channel normalisation with the statistics above.
Norm = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean, std, inplace=True),
    ]
)

def ViewImage(filename):
    """Read the image stored at ``filename`` and show it in a new figure."""
    picture = imread(filename)
    plt.figure()
    plt.imshow(picture)
    plt.show()
	

def preprocess_image(pil_im, resize_im=True):
    """Convert a PIL image into a normalised, batched tensor for the CNN.

    Args:
        pil_im (PIL.Image.Image): image to process. It is converted to RGB,
            and the caller's image object is not modified.
        resize_im (bool): if True, shrink the image, keeping its aspect ratio,
            so that neither side exceeds 512 pixels. Smaller images are left
            unchanged.

    Returns:
        torch.Tensor: float tensor of shape (1, 3, H, W). It is normalised
        with the ImageNet channel mean/std, placed on the module-level
        ``device``, and has ``requires_grad`` set.
    """
    # ImageNet channel statistics, shaped (3, 1, 1) to broadcast over H and W.
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(3, 1, 1)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(3, 1, 1)
    # convert() returns a new image, so thumbnail() below never mutates the
    # caller's image. Grayscale, CMYK, RGBA and palette inputs become
    # 3-channel, as the normalisation requires.
    pil_im = pil_im.convert('RGB')
    if resize_im:
        pil_im.thumbnail((512, 512))
    # (H, W, C) -> (C, H, W), then scale to [0, 1] and normalise per channel.
    im_as_arr = np.float32(pil_im).transpose(2, 0, 1)
    im_as_arr = (im_as_arr / np.float32(255) - mean) / std
    # Add the batch dimension: (1, 3, H, W).
    im_as_ten = torch.from_numpy(np.ascontiguousarray(im_as_arr)).float().unsqueeze(0)
    im_as_ten = im_as_ten.to(device)
    # Kept from the original pipeline: the input tracks gradients, presumably
    # for gradient-based uses elsewhere (confirm before removing).
    im_as_ten.requires_grad = True
    return im_as_ten
	
def get_image(filename):
    """Load an image file and prepare it for the network.

    Args:
        filename: path to the image file.

    Returns:
        tuple: ``(tensor, orig_size)``.
            - ``tensor`` is ``preprocess_image`` applied to the image after
              converting it to RGB and resizing it to 224x224.
            - ``orig_size`` is the original ``(width, height)`` reported by PIL.
    """
    # The context manager closes the file handle that Image.open keeps open
    # for lazy loading.
    with Image.open(filename) as src:
        orig_size = src.size
        # convert() forces 3 channels (e.g. grayscale/CMYK JPEGs) and, with
        # resize(), yields a new fully-loaded image that stays valid after
        # the file is closed.
        img = src.convert('RGB').resize((224, 224))
    return preprocess_image(img), orig_size

## Step 1: Get test images
Put your test images in `../data/SampleData/SampleImages` (the `imagefolder` set above). 100 sample images are provided.

## Step 2: Obtain the 3D dense alignment of the sample images.
The 3D alignments of the images in `SampleImages` are provided for you in `../data/SampleData/Sample3D` (the `meshfolder` set above). <br>
If you want to run on your own images, clone the repository https://github.com/YadiraF/PRNet. You need TensorFlow for this.

Run the following command:

    python PRNet/demo.py -i SampleImages -o Sample3D --isMat True

<p/>
There is a PyTorch version at https://github.com/tomguluson92/PRNet_PyTorch, although I have not tested it.
    
    
## Step 3: Calculate the CIS map of all sample images

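Here, we provide code for a VGG-16 model trained on CelebA. If you want to use another model, please write your own version of the function <b>Confidence(image, class_index)</b>. It returns the model's score for the class together with the class index that was used; a replacement should return the same two values.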

In [None]:
def Confidence(image, class_index=None, net=None):
    """Return the network's score for one class, and the class index used.

    Args:
        image: input batch for the network, e.g. the (1, 3, H, W) tensor
            returned by ``get_image``.
        class_index: output column to read. If None, the arg-max class of the
            FIRST sample in the batch is used.
        net: model to evaluate. Defaults to the module-level ``model``, so
            existing calls are unchanged. Pass another model to reuse this
            function without rebinding the global.

    Returns:
        tuple: ``(scores, class_index)``. ``scores`` is the detached
        ``output[:, class_index]``, i.e. one score per batch element, assuming
        the network returns (batch, n_classes) outputs.
    """
    if net is None:
        net = model
    with torch.no_grad():
        output = net(image)
    if class_index is None:
        # argmax over output[0] rather than the flattened output keeps the
        # index within [0, n_classes) even for batches of several images.
        # For a single image the result is identical.
        class_index = torch.argmax(output[0])
    return output[:, class_index].detach(), class_index

def loadVGGModel(filename):
    """Build a torchvision VGG-16 and load a saved state dict into it.

    The number of output classes is read from the checkpoint itself (the
    first dimension of ``classifier.6.bias``). The final classifier layer is
    replaced to match before the weights are loaded.

    Args:
        filename: path to a file holding a VGG-16 ``state_dict``. Checkpoints
            saved from an ``nn.DataParallel`` wrapper, whose keys are prefixed
            with ``module.``, are accepted too.

    Returns:
        The torchvision VGG-16 model with the weights loaded, on the CPU.
    """
    # map_location='cpu' lets GPU-saved checkpoints load on any machine.
    # load_state_dict copies values into the CPU-built model regardless, and
    # the caller moves the model to its device afterwards.
    state = torch.load(filename, map_location='cpu')
    prefix = 'module.'
    # Strip the DataParallel prefix from every key that carries it.
    state = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state.items()
    }

    n_classes = state['classifier.6.bias'].shape[0]
    model = models.vgg16(pretrained=False)
    model.classifier[-1] = torch.nn.Linear(
        in_features=model.classifier[-1].in_features,
        out_features=n_classes,
        bias=True,
    )
    model.load_state_dict(state)
    return model

# The recognition model is used below. To explain the gender model instead,
# use this line in place of the next one:
# model = loadVGGModel(join(modelfolder, 'VGG16_CelebA_Gender.pth'))
model = loadVGGModel(join(modelfolder, 'VGG16_CelebA_Recognition.pth'))
# nn.Module.to() and nn.Module.eval() both return the module itself.
model = model.to(device).eval()
# Confidence() evaluates this DataParallel-wrapped model.
model = torch.nn.DataParallel(model)

In [None]:
# Collect the input images (recursively) in a deterministic order.
p = Path(imagefolder)
image_paths = sorted(p.glob('**/*.jpg'))
filenames = [path.stem for path in image_paths]

# Compute and save the CIS map of each image. This takes some time.
for image_path in image_paths:
    # Use the globbed path itself. Re-joining imagefolder with the bare stem
    # lost any sub-directory that the recursive glob had descended into.
    img, sz = get_image(str(image_path))
    mesh = loadmat(join(meshfolder, image_path.stem + '_mesh.mat'))['vertices']
    # Subsample dense meshes (every 20th vertex) to make the calculation
    # faster; the same factor is applied to frontal_mesh during setup.
    if len(mesh) > 2194:
        mesh = mesh[::20, :]
    heatmap = Occlusion(Confidence, img.to(device, dtype=torch.float), mesh, sz,
                        frontal, frontal_mesh, size=size, class_index=None,
                        device=device)
    outfilename = join(cisfolder, image_path.stem)
    np.save(outfilename, heatmap)


### View the CIS maps





In [None]:
# Display one CIS map. '000686' is one of the provided sample images. If it
# is absent (e.g. you ran on your own images), show the first map found.
cis_example = Path(cisfolder) / '000686.npy'
if not cis_example.exists():
    cis_example = next(iter(sorted(Path(cisfolder).glob('*.npy'))), cis_example)
fig = plt.figure()
plt.imshow(np.load(cis_example), cmap='jet')

## Step 4: Calculate the CMS map
The CMS map is obtained by summing the CIS maps of all images.
    


In [None]:
# The CMS map is the pixel-wise sum of all saved CIS maps.
p = Path(cisfolder)
cis_paths = sorted(p.glob('**/*.npy'))
filenames = [path.stem for path in cis_paths]
if not cis_paths:
    raise FileNotFoundError(
        f'No CIS maps (*.npy) found in {cisfolder}; run Step 3 first.')

heatmap = None
for cis_path in cis_paths:
    # Load via the globbed path so maps in sub-directories are found too.
    hmap = np.load(cis_path)
    # Out-of-place addition: the running total never aliases a loaded map,
    # and dtype promotion follows numpy's usual rules.
    heatmap = hmap if heatmap is None else heatmap + hmap

PlotColormap(heatmap)