<center>
<img src="banner-small.png">
</center>

<h1>Tutorial: Deep Neural Networks and Explanations in PyTorch</h1>

<p>The goal of this tutorial is to train a neural network to classify images from a small labeled dataset. For this, we consider a subset of the "Labeled Faces in the Wild" (LFW) dataset, which is readily available in scikit-learn:</p>

In [None]:
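import sklearn.datasets

# Keep only people with at least 40 images; crop the 129x129 region [68:197, 68:197]
# of each RGB photo and downscale it by a factor of 0.5, which gives 64x64 color images.
D = sklearn.datasets.fetch_lfw_people(
    slice_=(slice(68, 197, None), slice(68, 197, None)),
    resize=0.5, min_faces_per_person=40, color=True)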

In [None]:
nc = len(D.target_names)

print('Number of examples: %d'%len(D['images']))
print('Number of classes: %d'%nc)

import torch

# Even-indexed examples for training, odd-indexed examples for testing.
# Rescale intensities (assumed to be in [0, 255]) to [0, 1]; layout (N, H, W, C) for plotting.
Itrain = D.data[::2].reshape(-1,64,64,3)/255.0
Itest = D.data[1::2].reshape(-1,64,64,3)/255.0
Ttrain = torch.LongTensor(D.target[::2])
Ttest = torch.LongTensor(D.target[1::2])

# Network inputs: layout (N, C, H, W), values mapped from [0, 1] to [-1.5, 1.5]
Xtrain = torch.FloatTensor(Itrain.transpose(0,3,1,2)*3-1.5)
Xtest = torch.FloatTensor(Itest.transpose(0,3,1,2)*3-1.5)
print(Itrain.shape)
print(Itest.shape)
print(Ttrain.shape)
print(Ttest.shape)
print(Xtrain.shape)
print(Xtest.shape)

%matplotlib inline
from matplotlib import pyplot as plt

def images(start):
    """Show the 8 test images starting at index `start`, labeled with the person's last name."""
    f = plt.figure(figsize=(16,2))
    for j in range(8):
        p = f.add_subplot(1,8,j+1)
        p.imshow(Itest[start+j])
        p.set_xlabel(D.target_names[Ttest[start+j].item()].split(" ")[-1])
        p.set_xticks([])
        p.set_yticks([])

images(2)
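
The preprocessing above keeps a channels-last copy of the images for plotting and builds channels-first tensors for PyTorch, mapping pixel intensities from [0, 1] to [-1.5, 1.5] with x -> 3*x - 1.5. These endpoints are the bounds `lb` and `hb` used for the first layer in the LRP code of Part 3. A minimal check on random stand-in images (not the LFW data):

```python
import numpy

# Random stand-in for two 64x64 RGB images with intensities in [0, 1],
# stored channels-last (N, H, W, C) like Itrain/Itest.
imgs = numpy.random.default_rng(0).random((2, 64, 64, 3))

# Channels-first (N, C, H, W) for PyTorch, then the affine map [0, 1] -> [-1.5, 1.5].
x = imgs.transpose(0, 3, 1, 2) * 3 - 1.5

print(x.shape)                              # (2, 3, 64, 64)
print(numpy.array([0.0, 1.0]) * 3 - 1.5)    # endpoints of the pixel range: [-1.5  1.5]
```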

<h2>Part 1: Training a convolutional neural network</h2>

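<p>We now consider a simple convolutional neural network composed of four convolution layers, rectified linear units (ReLU), and three max-pooling stages. Each pooling stage halves the spatial resolution (64 → 32 → 16 → 8), so the last convolution, with an 8×8 kernel, maps the 8×8 feature maps to one score per class and plays the role of a fully connected output layer.</p>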

In [None]:
import torch.nn as nn
import torch.optim as optim

net1 = nn.Sequential(
    nn.Conv2d(  3, 10, 5, padding=2), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d( 10, 25, 5, padding=2), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d( 25,100, 5, padding=2), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(100, nc, 8)
)
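
To see why the final 8×8 convolution produces exactly one score per class, one can trace the tensor shapes through the same architecture. A sketch using a zero input and a hypothetical class count `nc_demo = 19` (the actual value is `nc`, computed from the data above):

```python
import torch
import torch.nn as nn

nc_demo = 19   # hypothetical class count, for illustration only

demo_net = nn.Sequential(
    nn.Conv2d(  3, 10, 5, padding=2), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d( 10, 25, 5, padding=2), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d( 25,100, 5, padding=2), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(100, nc_demo, 8)
)

h = torch.zeros(1, 3, 64, 64)   # one dummy 64x64 RGB input
shapes = []
with torch.no_grad():
    for layer in demo_net:
        h = layer(h)
        if not isinstance(layer, nn.ReLU):   # ReLU does not change the shape
            shapes.append(tuple(h.shape))

for s in shapes:
    print(s)
# (1, 10, 64, 64), (1, 10, 32, 32), (1, 25, 32, 32), (1, 25, 16, 16),
# (1, 100, 16, 16), (1, 100, 8, 8), (1, 19, 1, 1)
```

The padding of 2 keeps the spatial size unchanged for the 5×5 convolutions, so only the pooling layers and the final 8×8 convolution shrink the maps.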

The following function takes a neural network and prints the training and test accuracy on the face classification data.

In [None]:
def printacc(net):
    net.eval()                       # disable dropout for evaluation
    with torch.no_grad():            # no gradients are needed to measure accuracy
        Ytrain = net(Xtrain).view(-1,nc)
        Ytest  = net(Xtest).view(-1,nc)
    # The predicted class is the index of the largest output score
    acctrain = (torch.max(Ytrain,dim=1)[1] == Ttrain).float().mean().item()
    acctest  = (torch.max(Ytest, dim=1)[1] == Ttest).float().mean().item()
    print('train: %.3f  test: %.3f' %(acctrain,acctest))
    net.train()                      # switch back to training mode

printacc(net1)
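
`printacc` counts a prediction as correct when the index of the largest output score equals the target label. A sketch of the same computation on made-up scores for four examples and three classes:

```python
import torch

# Made-up output scores: 4 examples (rows) x 3 classes (columns)
scores = torch.tensor([[2.0, 0.1, 0.3],
                       [0.2, 1.5, 0.1],
                       [0.3, 0.2, 0.9],
                       [1.2, 0.4, 0.0]])
targets = torch.tensor([0, 1, 0, 1])

pred = torch.max(scores, dim=1)[1]             # predicted classes: tensor([0, 1, 2, 0])
acc = (pred == targets).float().mean().item()  # 2 correct out of 4
print(acc)                                     # 0.5
```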

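The network initially predicts at random, hence the low training and test accuracy. The following function trains the neural network on the training data using stochastic gradient descent with momentum for a given number of iterations. Each iteration uses a random mini-batch of 25 training examples, and the accuracy is printed every `nbit//5` iterations.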

In [None]:
import numpy

def train(net,nbit=2500):

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.005, momentum=0.9)

    for i in range(nbit+1):

        # Random mini-batch of 25 training examples (indexing returns a copy)
        R = numpy.random.permutation(len(Xtrain))[:25]
        xr, tr = Xtrain[R], Ttrain[R]

        # One SGD step on the cross-entropy loss
        optimizer.zero_grad()
        criterion(net.forward(xr).view(-1,nc),tr).backward()
        optimizer.step()

        # Report training and test accuracy every nbit//5 iterations
        if i % (nbit//5) == 0: printacc(net)

train(net1)
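
With `momentum=0.9`, `optim.SGD` keeps a velocity buffer: it is initialized to the first gradient, then updated as v ← 0.9·v + g, and each parameter moves by -lr·v (this is PyTorch's formulation with the default `dampening=0` and `nesterov=False`). A small sketch checking this rule by hand on the quadratic f(w) = ½‖w‖², whose gradient is w:

```python
import torch

step_size, momentum = 0.005, 0.9          # same hyperparameters as in train()

w = torch.nn.Parameter(torch.tensor([1.0, -2.0]))
opt = torch.optim.SGD([w], lr=step_size, momentum=momentum)

w_manual = torch.tensor([1.0, -2.0])
v = None
for it in range(3):
    # PyTorch update
    opt.zero_grad()
    (0.5 * (w ** 2).sum()).backward()
    opt.step()

    # Manual update with the same rule
    g = w_manual.clone()                        # gradient of 0.5*||w||^2 is w
    v = g if v is None else momentum * v + g    # velocity buffer
    w_manual = w_manual - step_size * v

print(torch.allclose(w.detach(), w_manual))     # True
```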

As training proceeds, the training accuracy reaches its maximum while the test accuracy plateaus at a lower value. This suggests that the neural network has sufficient capacity and that the main bottleneck is overfitting.

<h2>Part 2: Improving generalization</h2>

We would like to improve the generalization ability of the neural network.

<h3>Dropout</h3>

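As a first attempt, we add a dropout layer to the network. We use `nn.Dropout2d`, which randomly zeroes entire feature maps during training. Dropout tends to be most effective in the last layers of the network, so we insert it just before the final convolution.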

In [None]:
net2 = nn.Sequential(
    nn.Conv2d(  3, 10, 5, padding=2), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d( 10, 25, 5, padding=2), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d( 25,100, 5, padding=2), nn.ReLU(), nn.MaxPool2d(2),
    nn.Dropout2d(),
    nn.Conv2d(100, nc, 8)
)

train(net2)

The training accuracy still reaches 1.0, but thanks to dropout, the test accuracy is now noticeably higher.
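
`nn.Dropout2d` (default drop probability p = 0.5) zeroes entire channels of its (N, C, H, W) input during training and scales the kept channels by 1/(1 - p); in evaluation mode it is the identity. This is why `printacc` switches the network to `eval()` mode before measuring accuracy. A quick check on an all-ones input:

```python
import torch
import torch.nn as nn

drop = nn.Dropout2d()               # p = 0.5, as in net2
ones = torch.ones(4, 6, 8, 8)       # (batch, channels, height, width)

drop.train()
out = drop(ones).view(4, 6, -1)     # one row of 64 values per (example, channel)
# Every channel is either entirely dropped (0) or entirely kept and scaled by 1/(1-p) = 2
whole_channels = ((out == 0).all(dim=-1) | (out == 2).all(dim=-1)).all().item()
print(whole_channels)               # True

drop.eval()
print(torch.equal(drop(ones), ones))  # True: identity at test time
```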

<h3>Transfer learning</h3>

We now investigate a second technique to improve the test accuracy: using a neural network that has been pretrained on a generic computer vision task with many labels, e.g. ImageNet. We load the first 17 layers of the VGG-16 network pretrained on ImageNet (the convolutional part up to and including the third max-pooling layer) and apply these layers to our face data to generate features.

In [None]:
import torchvision

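# torchvision >= 0.13 API; older versions use torchvision.models.vgg16(pretrained=True)
vgg16 = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1).features
# Layers 0-16: conv blocks 1-3 including the third max-pooling; output shape (N, 256, 8, 8)
features = nn.Sequential(*list(vgg16)[:17])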
print(features)

Pretrained neural networks usually perform highly nonlinear and high-dimensional mappings. Hence, a linear model trained on top of these features may be sufficient. Here, we train a simple logistic regression classifier with scikit-learn.

In [None]:
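import sklearn.linear_model

lr = sklearn.linear_model.LogisticRegression()

# Compute the VGG-16 features of the training images (no gradients needed)
with torch.no_grad():
    Ftrain = features(Xtrain)

# Flatten each (256, 8, 8) feature map into a vector and fit the classifier
lr.fit(Ftrain.reshape(len(Xtrain),-1).numpy(), Ttrain.numpy())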

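A logistic regression model can be seen as a one-layer neural network trained with the cross-entropy loss. Hence, we can convert our logistic regression model into a neural network layer and append it to our sequence of features. Since the VGG-16 feature maps have shape 256×8×8, this layer is an 8×8 convolution with 256 input channels and `nc` output channels, whose weights are the logistic regression coefficients reshaped to (nc, 256, 8, 8).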

In [None]:
# Logistic regression coefficients and intercepts become the weights and bias of an 8x8 convolution
weight = torch.FloatTensor(lr.coef_.reshape(nc,256,8,8))
bias = torch.FloatTensor(lr.intercept_)
topconv = nn.Conv2d(256,nc,8)
topconv.weight = nn.Parameter(weight)
topconv.bias = nn.Parameter(bias)
net3 = nn.Sequential(*(list(features)+[topconv]))
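
The conversion works because an 8×8 convolution applied to an 8×8 feature map has a single output position: it computes a weighted sum over all 256×8×8 feature values plus a bias, i.e. the same scores `X @ coef_.T + intercept_` that scikit-learn's `decision_function` returns, provided the weights are reshaped in the same (channel, row, column) order that was used to flatten the features. A sketch with random stand-ins for the features and the fitted coefficients (the class count is hypothetical):

```python
import torch
import torch.nn as nn

nc_demo, C, H, W = 5, 256, 8, 8           # hypothetical class count; feature size as in net3

feats = torch.randn(3, C, H, W)           # stand-in for VGG-16 features of 3 images
coef = torch.randn(nc_demo, C * H * W)    # stand-in for lr.coef_ (one row per class)
intercept = torch.randn(nc_demo)          # stand-in for lr.intercept_

# Linear scores on the flattened features, as in a logistic regression
scores_linear = feats.reshape(len(feats), -1) @ coef.T + intercept

# Same parameters, reshaped into an 8x8 convolution
conv = nn.Conv2d(C, nc_demo, (H, W))
conv.weight = nn.Parameter(coef.reshape(nc_demo, C, H, W).clone())
conv.bias = nn.Parameter(intercept.clone())
with torch.no_grad():
    scores_conv = conv(feats).view(-1, nc_demo)

print(torch.allclose(scores_linear, scores_conv, rtol=1e-4, atol=1e-2))  # True
```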

The resulting network can now be evaluated on the training and test sets.

In [None]:
printacc(net3)

The training accuracy is still at its maximum, but this time the test accuracy has increased substantially. This suggests that the generic VGG-16 visual features are very useful for our face classification task.

<h2>Part 3: Explaining predictions with LRP</h2>

<p>We would now like to gain insight into the predictions of our models by applying the layer-wise relevance propagation (LRP) method. The LRP-0, LRP-$\epsilon$, and LRP-$\gamma$ rules for propagating relevance in the lower layers, described in Section 10.2.1 of the <a href="https://doi.org/10.1007/978-3-030-28954-6_10">LRP overview paper</a>, are special cases of the more general propagation rule</p>

<p>
$$
R_j = \sum_k \frac{a_j \rho(w_{jk})}{\epsilon + \sum_{0,j} a_j \rho(w_{jk})} R_k
$$
</p>

<p>(cf. Section 10.2.2), where $\rho$ is a function that transforms the weights, $\epsilon$ is a small positive increment, and the sum $\sum_{0,j}$ also includes the bias, represented by $a_0 = 1$ and $w_{0k} = b_k$. We now come to the practical implementation of this general rule. It can be decomposed into a sequence of four computations:</p>

<p>
\begin{align*}
\forall_k:&~z_k = {\textstyle \epsilon + \sum_{0,j}} a_j \rho(w_{jk}) & (\text{step }1)\\
\forall_k:&~s_k = R_k / z_k \qquad & (\text{step }2)\\
\forall_j:&~c_j = {\textstyle \sum_k} \rho(w_{jk}) s_k \qquad & (\text{step }3)\\
\forall_j:&~R_j = a_j \cdot c_j \qquad & (\text{step }4)
\end{align*}
</p>
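
The four steps compute exactly the general rule: steps 1 and 2 form the quotient $s_k = R_k / z_k$ once per output neuron, and steps 3 and 4 redistribute these quotients to the inputs. A minimal NumPy sketch for a dense layer with activations `a`, weights `W[j, k]` and bias `b` (all random stand-ins), using the LRP-$\gamma$ choice $\rho(w) = w + \gamma \max(0, w)$ and treating the bias as the $j = 0$ term:

```python
import numpy

def rho_gamma(w, gamma):
    # LRP-gamma weight transform rho(w) = w + gamma * max(0, w); gamma = 0 gives LRP-0
    return w + gamma * numpy.clip(w, 0, None)

def lrp_direct(a, W, b, R_out, gamma=0.25, eps=1e-9):
    # R_j = sum_k a_j rho(w_jk) / (eps + sum_{0,j} a_j rho(w_jk)) * R_k, with a_0 = 1, w_0k = b_k
    contrib = a[:, None] * rho_gamma(W, gamma)               # (J, K) matrix of a_j rho(w_jk)
    denom = eps + contrib.sum(axis=0) + rho_gamma(b, gamma)  # (K,) denominators, bias included
    return (contrib / denom * R_out).sum(axis=1)

def lrp_four_steps(a, W, b, R_out, gamma=0.25, eps=1e-9):
    z = eps + a @ rho_gamma(W, gamma) + rho_gamma(b, gamma)  # step 1
    s = R_out / z                                            # step 2
    c = rho_gamma(W, gamma) @ s                              # step 3
    return a * c                                             # step 4

rng = numpy.random.default_rng(0)
a = rng.random(6)                  # non-negative input activations (e.g. ReLU outputs)
W = rng.normal(size=(6, 4))        # weights w_jk: 6 inputs, 4 outputs
b = 0.1 * rng.normal(size=4)       # biases b_k
R_out = rng.random(4)              # relevance R_k of the 4 output neurons

R_direct = lrp_direct(a, W, b, R_out)
R_steps = lrp_four_steps(a, W, b, R_out)
print(numpy.allclose(R_direct, R_steps))  # True
```

The four-step form is preferred in practice because steps 1 and 3 are a forward pass and a transposed (backward) pass through the layer, which avoids materializing the $J \times K$ matrix of contributions.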

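<p>The layer-wise relevance propagation procedure then consists of iterating over the layers in reverse order, from the top layer down to the input layer, and applying this sequence of computations at each layer. In the implementation below, convolution layers use LRP-$\gamma$ with $\gamma = 0.25$ (and LRP-0 for the top layer), all other non-pooling layers (ReLU, and dropout in evaluation mode) pass relevance through unchanged, and the first layer, whose inputs are pixels bounded in $[-1.5, 1.5]$, is handled by a special rule that uses these bounds.</p>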
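
The implementation does not compute step 3 explicitly. Because $z_k$ is linear in the activations $a_j$ with coefficients $\rho(w_{jk})$, the gradient of $\sum_k z_k s_k$ with respect to $a_j$, holding $s$ fixed, is exactly $c_j = \sum_k \rho(w_{jk}) s_k$; calling `backward()` on `(z*s).sum()` therefore performs step 3 for any layer type. A sketch checking this on an `nn.Linear` layer with random inputs (LRP-0, i.e. $\rho$ is the identity):

```python
import torch
import torch.nn as nn

dense = nn.Linear(6, 4)                       # weight has shape (4, 6): w_jk is dense.weight[k, j]
a_in = torch.rand(1, 6, requires_grad=True)   # input activations a_j
R_top = torch.rand(1, 4)                      # relevance R_k of the 4 outputs

z = dense(a_in) + 1e-9                        # step 1 (rho = identity, i.e. LRP-0)
s = (R_top / z).detach()                      # step 2, treated as a constant
(z * s).sum().backward()                      # step 3 via autograd ...
c_autograd = a_in.grad

with torch.no_grad():
    c_explicit = s @ dense.weight             # ... equals sum_k w_jk s_k
R_in = (a_in * c_autograd).detach()           # step 4

print(torch.allclose(c_autograd, c_explicit))  # True
```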

In [None]:
import copy

def newlayer(layer,g):
    """Return a copy of `layer` whose weight and bias are transformed by g (e.g. rho)."""

    layer = copy.deepcopy(layer)

    try: layer.weight = nn.Parameter(g(layer.weight))
    except AttributeError: pass

    try: layer.bias   = nn.Parameter(g(layer.bias))
    except AttributeError: pass

    return layer

def LRP(layers,X,T):
    """Propagate the relevance A[-1]*T back to the input; returns the relevance at every layer."""

    L = len(layers)

    # -------------------------------------------------------------
    # Forward pass: store the activations; the top-layer relevance is
    # the output score masked by T (one-hot of the class to explain)
    # -------------------------------------------------------------
    A = [X]+[None]*L
    for l in range(L): A[l+1] = layers[l].forward(A[l])
    R = [None]*L + [A[-1].data*T]

    for l in reversed(range(L)):

        A[l] = (A[l].data).requires_grad_(True)

        # -------------------------------------------------------------
        # Special case: first layer (z^B rule for pixels in [-1.5, 1.5])
        # -------------------------------------------------------------
        if l==0:

            lb = (A[0].data*0-1.5).requires_grad_(True)   # lower bound of the input domain
            hb = (A[0].data*0+1.5).requires_grad_(True)   # upper bound of the input domain

            z = layers[0].forward(A[0]) + 1e-9                                 # step 1 (a)
            z -= newlayer(layers[0],lambda p: p.clamp(min=0)).forward(lb)      # step 1 (b)
            z -= newlayer(layers[0],lambda p: p.clamp(max=0)).forward(hb)      # step 1 (c)
            s = (R[1]/z).data                                                  # step 2
            (z*s).sum().backward(); c,cp,cm = A[0].grad,lb.grad,hb.grad        # step 3
            R[0] = (A[0]*c+lb*cp+hb*cm).data                                   # step 4

        # -------------------------------------------------------------
        # Convolution and pooling layers: LRP-gamma (gamma = 0.25),
        # LRP-0 for the top layer (rho has no effect on pooling layers)
        # -------------------------------------------------------------
        elif isinstance(layers[l],(nn.Conv2d,nn.MaxPool2d)):

            rho = lambda p: p + (0.25 if l < L-1 else 0.0)*p.clamp(min=0); incr = lambda z: z+1e-9

            z = incr(newlayer(layers[l],rho).forward(A[l]))       # step 1
            s = (R[l+1]/z).data                                   # step 2
            (z*s).sum().backward(); c = A[l].grad                 # step 3 (gradient = sum_k rho(w_jk) s_k)
            R[l] = (A[l]*c).data                                  # step 4

        # -------------------------------------------------------------
        # ReLU (and dropout in eval mode): relevance passes through
        # -------------------------------------------------------------
        else:

            R[l] = R[l+1]

    return R
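
The first-layer case applies the $z^\mathcal{B}$-rule, which exploits the known range of the inputs: relevance is computed from $a_j w_{jk} - l_j w_{jk}^+ - h_j w_{jk}^-$, where $l_j = -1.5$ and $h_j = 1.5$ are the bounds of the rescaled pixels and $w^+$, $w^-$ are the positive and negative parts of the weights. Since $l_j \le a_j \le h_j$, every denominator $z_k$ is non-negative, and the redistributed relevance sums to the incoming relevance (up to the small stabilizer). A NumPy sketch for a dense layer without bias, computing step 3 explicitly rather than through autograd, on random stand-in data:

```python
import numpy

def zB_relevance(x, W, R_out, low=-1.5, high=1.5, eps=1e-9):
    # z^B rule for a dense layer (no bias) whose inputs lie in [low, high]
    lo = numpy.full_like(x, low)
    hi = numpy.full_like(x, high)
    Wp = numpy.clip(W, 0, None)          # positive part w+
    Wm = numpy.clip(W, None, 0)          # negative part w-
    z = x @ W - lo @ Wp - hi @ Wm + eps  # step 1: z_k >= eps because low <= x <= high
    s = R_out / z                        # step 2
    # steps 3 and 4: R_j = x_j (W s)_j - low (W+ s)_j - high (W- s)_j
    return x * (W @ s) - lo * (Wp @ s) - hi * (Wm @ s)

rng = numpy.random.default_rng(0)
x_px = rng.uniform(-1.5, 1.5, size=12)  # stand-in pixels, scaled like Xtrain/Xtest
W_px = rng.normal(size=(12, 5))         # stand-in first-layer weights
R_first = rng.random(5)                 # stand-in relevance arriving at the first layer's outputs

R_pixels = zB_relevance(x_px, W_px, R_first)
print(numpy.isclose(R_pixels.sum(), R_first.sum()))  # True: relevance is conserved
```

In the notebook code, the bias terms cancel in step 1 because the same bias $b$ appears in the three forward passes as $b$, $b^+$ and $b^-$, with $b = b^+ + b^-$.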

We can now use the LRP function to identify, for each network, the input pixels that provide evidence for the true class of each test image.

In [None]:
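from matplotlib.colors import ListedColormap

# Slightly darkened 'seismic' colormap (red: positive, blue: negative relevance)
my_cmap = plt.cm.seismic(numpy.arange(plt.cm.seismic.N))
my_cmap[:,0:3] *= 0.85
my_cmap = ListedColormap(my_cmap)

def heatmaps(net,start):
    """Plot LRP heatmaps explaining the true class of the 8 test images starting at `start`."""

    was_training = net.training
    net.eval()   # deactivate dropout so that the explained forward pass is deterministic

    T = torch.eye(nc)[Ttest[start:start+8]].view(-1,nc,1,1)   # one-hot mask of the true class
    R = LRP(list(net),Xtest[start:start+8],T)[0]
    R = numpy.array(R)
    net.train(was_training)

    # Pixel relevance = sum over color channels; symmetric color range around zero
    Rsmax = numpy.abs(R).sum(axis=1).max()

    f = plt.figure(figsize=(16,2))
    for j in range(8):
        p = f.add_subplot(1,8,j+1)
        p.set_xticks([])
        p.set_yticks([])
        p.imshow(R[j].sum(axis=0),cmap=my_cmap,vmin=-Rsmax,vmax=Rsmax)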

In [None]:
images(0)
heatmaps(net1,0)
heatmaps(net2,0)
heatmaps(net3,0)

The last network (based on pretrained VGG-16) not only has the highest accuracy; its decisions are also supported by a broader set of input features.