# Lab: Explainability in Neural Networks with Activation Maximization


------------------------------------------------------
*Pablo M. Olmos pamartin@ing.uc3m.es*

Signal theory and communications department.

**Universidad Carlos III de Madrid**

<img src='http://www.tsc.uc3m.es/~emipar/BBVA/INTRO/img/logo_uc3m_foot.jpg' width=400 />

------------------------------------------------------

In this part of the lab, we will implement a simple example of **activation maximization** to find out which input patterns a previously trained neural network needs in order to produce a desired output (for example, high confidence in one class of a classification task).

We will visualize this technique using the MNIST database.

In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'  #To get figures with high quality!

import numpy as np
import torch
from torch import nn
from torch import optim
import matplotlib.pyplot as plt

In [None]:
from IPython.display import Image
from IPython.display import HTML

Image(url= "https://i1.wp.com/datasmarts.net/es/wp-content/uploads/2019/09/1_yBdJCRwIJGoM7pwU-LNW6Q.png?w=479&ssl=1", width=400, height=200)

Load the dataset with torchvision ...

In [None]:
### Run this cell

from torchvision import datasets, transforms

# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize((0.,), (1.0,)), # Mean 0, std 1 (leaves the [0, 1] pixel values unchanged)
                              ])

# Download and load the training  data
trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# Download and load the test data
testset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)


In [None]:
dataiter = iter(trainloader)   # To iterate through the dataset

images, labels = next(dataiter)  # iterators no longer expose .next() in recent PyTorch versions
print(type(images))
print(images.shape)
print(labels.shape)


Let's show one image from that batch ...

In [None]:
plt.imshow(images[1].numpy().reshape([28,28]), cmap='viridis')
plt.colorbar()

We also create a validation set.

In [None]:
import copy

validloader = copy.deepcopy(trainloader)  # Creates a copy of the object 

#We take the first 45k images for training
trainloader.dataset.data = trainloader.dataset.data[:45000,:,:]
trainloader.dataset.targets = trainloader.dataset.targets[:45000]

#And the rest for validation
validloader.dataset.data = validloader.dataset.data[45000:,:,:]
validloader.dataset.targets = validloader.dataset.targets[45000:]

> **Exercise:** Train an MLP with four dense layers of 256, 128, 64 and 10 units respectively (the last one being the output layer). Compute the accuracy on the training and test sets after training for only 3 epochs. If you notice overfitting, introduce whatever mechanisms you consider appropriate to reduce it.

In [None]:
#YOUR CODE HERE
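
As a starting point, the architecture described in the exercise could be sketched as follows. This is only one possible layout, not the reference solution: the class name `MLPDrop`, the dropout probability `p_drop=0.2` and the ReLU activations are assumptions made for illustration. The log-softmax output matches what the rest of the lab expects from the classifier. The training loop is left out.

```python
import torch
from torch import nn


class MLPDrop(nn.Module):
    """Hypothetical MLP: 784 -> 256 -> 128 -> 64 -> 10, with dropout and log-softmax output."""

    def __init__(self, in_dim=784, n_classes=10, p_drop=0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, 256), nn.ReLU(), nn.Dropout(p_drop),
            nn.Linear(256, 128), nn.ReLU(), nn.Dropout(p_drop),
            nn.Linear(128, 64), nn.ReLU(), nn.Dropout(p_drop),
            nn.Linear(64, n_classes),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, x):
        # Flatten [batch, 1, 28, 28] images (or [batch, 784] vectors) to [batch, 784]
        return self.net(x.view(x.shape[0], -1))


model = MLPDrop()
model.eval()  # disable dropout for this shape check
with torch.no_grad():
    log_probs = model(torch.rand(5, 1, 28, 28))  # random stand-in for a batch of images
print(log_probs.shape)
```

Calling `model.eval()` before evaluating matters here: with dropout active, the same input would give different outputs on each forward pass.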

## Activation Maximization

Once our network has been trained, the goal is to find which input features (pixels in our case) matter most for the classifier to choose one class over another, that is, one digit over another at the output.

We follow two steps:

- We fix the parameters that we have just trained.

- We define a new set of trainable parameters that emulates the input of the network, and we optimize it to obtain high confidence in a specific digit.

We implement the first step with the following code:

In [None]:
for param in my_MLP_drop.parameters(): #my_MLP_drop is the name of your neural network!
    param.requires_grad = False
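
The two steps listed above can be sketched end to end on a toy model. This is an illustration under stated assumptions, not the lab solution: `toy_net` is a small, randomly initialized stand-in for a trained classifier, and the names `x`, `category` and `target_log_prob`, the learning rate and the number of iterations are all arbitrary choices.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)

# Toy stand-in for a trained classifier (random weights, for illustration only)
toy_net = nn.Sequential(nn.Linear(16, 4), nn.LogSoftmax(dim=1))

# Step 1: freeze the network parameters
for param in toy_net.parameters():
    param.requires_grad = False

# Step 2: the input is now the only trainable tensor
x = nn.Parameter(torch.randn(1, 16) * 0.01)
optimizer = optim.Adam([x], lr=5e-2)
criterion = nn.MSELoss()

category = 2                                    # class whose confidence we want to raise
target_log_prob = torch.log(torch.tensor(0.9))  # desired confidence, in the log domain

losses = []
for _ in range(500):
    optimizer.zero_grad()
    log_probs = toy_net(x)
    # Squared error between the obtained and the desired log-probability
    loss = criterion(log_probs[0, category], target_log_prob)
    loss.backward()    # gradients flow only to x
    optimizer.step()
    losses.append(loss.item())

final_prob = toy_net(x).exp()[0, category].item()
print(f"initial loss {losses[0]:.4f}, final loss {losses[-1]:.6f}, p = {final_prob:.3f}")
```

Because the weights are frozen, the optimizer can only reduce the loss by changing `x`, so the optimized input shows what the network responds to for the chosen class.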

> **Exercise:** Complete the following code, in which we define the optimization problem to find the input image that yields a given confidence in a given digit.

In [None]:
class Optimize_NN_Input(nn.Module):

    def __init__(self,NN_trained,lr,img0): 
      
        # NN_trained is the trained network, as argument

        # img0 is a random initialization for the input image
        
        super().__init__()
        
        # self.input --> Image to be optimized

        self.input = nn.Parameter(img0,requires_grad = True)

        self.NN_trained = NN_trained

        self.lr = lr
        
        # We use the mean squared error to minimize the difference between the desired log-probability
        # and the one the network assigns to self.input

        self.criterion = nn.MSELoss() 

        self.optim = optim.Adam(self.parameters(), self.lr)

        self.loss_during_training = []

    def forward(self):

        # In the forward method we only evaluate the log-probabilities given self.input!

        logprobs = self.NN_trained.forward(#YOUR CODE HERE)
        
        return logprobs

    def trainloop(self,category,true_prob,sgd_iter):

        # Category is the digit we are going to look at

        # true_prob is the desired probability, which we pass to logarithm
        
        true_log_prob = torch.log(#YOUR CODE HERE)

        # sgd_iter is the number of iterations
        
        for i in range(sgd_iter):

            # Reset the gradients
            #YOUR CODE HERE 
            
            # Compute the network output
            #YOUR CODE HERE
            
            # We calculate the difference between the desired probability (true_log_prob) and the obtained one
            # for the digit
            
            loss = self.criterion(logprobs[0,#YOUR CODE HERE].view(true_log_prob.shape),#YOUR CODE HERE)
  

            self.loss_during_training.append(loss.item())  # store a float, not the graph-attached tensor

            # Compute gradients
            #YOUR CODE HERE
          
            # Optimize
            #YOUR CODE HERE

Let's instantiate an object of the previous class with a random initialization (independent Gaussian entries with zero mean and standard deviation 0.01).

In [None]:
opt_input = Optimize_NN_Input(NN_trained=my_MLP_drop,lr=5e-3,img0=torch.randn(1,28**2)*0.01)

> **Exercise:** Obtain the probability distribution at the output of the classifier for the chosen initialization. Remember that the classifier applies a `logsoftmax` at its output. Discuss the result.

In [None]:
#YOUR CODE HERE
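
As a reminder of how log-probabilities relate to probabilities, here is a minimal sketch on made-up values. The tensor `logits` is hypothetical and merely stands in for the pre-softmax activations of a network. Taking the exponential of a log-softmax output recovers a proper probability vector.

```python
import torch

# Hypothetical pre-softmax activations for one input and four classes
logits = torch.tensor([[0.1, -0.3, 0.2, 0.0]])
log_probs = torch.log_softmax(logits, dim=1)  # what a LogSoftmax output layer returns

probs = log_probs.exp()   # exp undoes the log, giving probabilities
print(probs)
print(probs.sum(dim=1))   # each row sums to 1 (up to floating-point error)
```

With a near-zero input image, such as the initialization used in this lab, the output is driven mostly by the bias terms of the network, which is worth keeping in mind when discussing the result.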

> **Exercise:** Optimize the input representation for 50 iterations to obtain 90% confidence in the digit 3. Plot the loss function.

In [None]:
#YOUR CODE HERE

> **Exercise:** Compare the probabilities at the output of the classifier before and after optimizing the input.

In [None]:
#YOUR CODE HERE

> **Exercise:** Display the optimized input image. Does it roughly match the expected digit? Discuss the results.

In [None]:
#YOUR CODE HERE

> **Exercise:** Repeat the experiment for all digits and discuss the results.

In [None]:
#YOUR CODE HERE

# Explainability at input using CNNs

Although the dense network performs very well as a classifier, we will now check whether a CNN classifier builds its decision from more interpretable input patterns, since convolutions take the correlation between neighboring pixels into account.

> **Exercise:** Implement a classifier for MNIST based on the LeNet-5 CNN seen in one of the class examples. Note that the figure assumes 32x32 images; with the 28x28 MNIST images, the feature maps at the output of the second convolutional stage are not 5x5, as indicated in the figure, but 4x4.
>
> Train the classifier for 3 epochs and compute the accuracy on the training and test sets.

In [None]:
Image(url= "https://ichi.pro/assets/images/max/724/0*H9_eGAtkQXJXtkoK")

In [None]:
#YOUR CODE HERE
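
The feature-map sizes mentioned in the exercise can be checked with a few lines of arithmetic. The sketch below assumes the usual LeNet-5 layout (two 5x5 convolutions with stride 1 and no padding, each followed by 2x2 pooling with stride 2). The helper names `conv_out` and `lenet5_feature_size` are hypothetical.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


def lenet5_feature_size(input_size):
    """Spatial size after the two conv(5x5) + pool(2x2) stages assumed above."""
    s = conv_out(input_size, kernel=5)      # first 5x5 convolution
    s = conv_out(s, kernel=2, stride=2)     # first 2x2 pooling
    s = conv_out(s, kernel=5)               # second 5x5 convolution
    s = conv_out(s, kernel=2, stride=2)     # second 2x2 pooling
    return s


print(lenet5_feature_size(32))  # 5: the 32x32 input drawn in the figure
print(lenet5_feature_size(28))  # 4: the 28x28 MNIST images used here
```

With 16 feature maps at that stage, as in the standard LeNet-5, the first dense layer would then receive 16 * 4 * 4 = 256 inputs instead of 16 * 5 * 5 = 400.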

> **Exercise:** Using the class `Optimize_NN_Input`, perform the activation maximization analysis for the trained CNN. Don't forget to disable gradients with respect to the network parameters. Initialize the input image with the same independent Gaussian distribution with zero mean and standard deviation 0.01. Note that this input image now has dimensions [1,1,28,28]. Optimize the input image for at least 500 iterations (with CNNs the optimization is slower).
>
> Visualize the optimized input image for each digit and discuss the results.

In [None]:
#YOUR CODE HERE