# Explaining a Self-Explaining Neural Network
---

<img src="img/SENN.png" alt="SENN Architecture Diagram (Alvarez-Melis &amp; Jaakkola)" style="width: 640px;"/>

## Import Libraries

In [1]:
import torch
import torch.nn as nn

import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline

## Computing the Jacobian with PyTorch Autograd

**How does Autograd build a computation graph?**
* The PyTorch autograd engine is based on two major classes: ```Tensor``` and ```Function```
* Using these together, the autograd engine builds a computation graph
* Each ```Tensor``` has a ```.grad_fn``` attribute that records how it was generated
* A user-defined ```Tensor``` (a leaf node on the computation graph) has its ```.grad_fn``` set to ```None```
* A ```Tensor``` generated from an operation like ```+``` or ```*``` has its ```.grad_fn``` set to that ```Function``` operation
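
The points above can be checked directly. The following is a minimal sketch; the variable names `a`, `b`, and `c` are illustrative only and do not appear elsewhere in this notebook.

```python
import torch

# A user-created tensor is a leaf of the graph: it has no grad_fn.
a = torch.ones(3, requires_grad=True)

# Tensors produced by operations record the Function that created them.
b = a + 2
c = (b * b).sum()

print(a.is_leaf, a.grad_fn)  # leaf tensor, grad_fn is None
print(b.grad_fn)             # node created by the addition
print(c.grad_fn)             # node created by the sum
```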

**How do we compute gradients in the Autograd computation graph?**
* If the ```Tensor``` we want to differentiate is a scalar, we simply call ```.backward()``` on it
* If that ```Tensor``` is multi-dimensional, we pass ```.backward()``` a gradient tensor of the same shape, which weights each output element

In [2]:
def jacobian_demo(in_dim, out_dim):
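    '''
    A Jacobian demo:
    Set out_dim = 1 for the default behaviour (single-element output, no gradient argument needed)
    Set out_dim > 1 to see why backward() then needs an explicit gradient vector
    '''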
    x = torch.ones(in_dim, requires_grad=True)
    w = torch.randn((in_dim, out_dim))
    y = x@w
    print(f"x= {x}")
    print(f"W= {w}")
    print(f"y= {y}")

    if out_dim == 1:
        y.backward() # for a single-element y, equivalent to y.backward(torch.ones_like(y))
    else:
        y.backward(torch.ones(out_dim))

    print(f"dy/dx = {x.grad}")

In [3]:
jacobian_demo(in_dim=3, out_dim=1)

x= tensor([1., 1., 1.], requires_grad=True)
W= tensor([[1.7266],
        [0.2947],
        [0.8249]])
y= tensor([2.8462], grad_fn=<SqueezeBackward3>)
dy/dx = tensor([1.7266, 0.2947, 0.8249])


In [4]:
jacobian_demo(in_dim=3, out_dim=2)

x= tensor([1., 1., 1.], requires_grad=True)
W= tensor([[ 0.4594, -0.0022],
        [ 0.1305,  1.7321],
        [-1.6384,  0.9940]])
y= tensor([-1.0485,  2.7239], grad_fn=<SqueezeBackward3>)
dy/dx = tensor([ 0.4572,  1.8626, -0.6444])


**PyTorch Autograd behind the scenes**
* Every node in the computation graph receives the gradient from the node above it
* Autograd computes the product of the incoming gradient vector with the Jacobian of the current node
* So the grad of a node is computed as: $\mathbf{v}^{T} \cdot \mathbf{J}$
* For a scalar node, the incoming gradient vector is just a ```1 x 1``` vector
* The nodes after a scalar node then receive the Jacobian of the scalar node w.r.t. a vector as an ```out_dim x 1``` vector
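
To make the $\mathbf{v}^{T} \cdot \mathbf{J}$ rule concrete, the sketch below compares the result of `.backward(v)` with an explicit vector-Jacobian product. It assumes `torch.autograd.functional.jacobian` is available (imported here under the illustrative alias `full_jacobian`); the vector `v` and the seed are arbitrary choices.

```python
import torch
from torch.autograd.functional import jacobian as full_jacobian

torch.manual_seed(0)
in_dim, out_dim = 3, 2
w = torch.randn(in_dim, out_dim)
f = lambda x: x @ w

x = torch.ones(in_dim, requires_grad=True)
v = torch.tensor([1.0, 2.0])       # incoming gradient: one weight per output

y = f(x)
y.backward(v)                      # autograd accumulates v^T J into x.grad

J = full_jacobian(f, x.detach())   # full Jacobian, shape (out_dim, in_dim)
vjp = v @ J                        # explicit vector-Jacobian product

print(torch.allclose(x.grad, vjp))
```

Passing `torch.ones(out_dim)` as `v`, as `jacobian_demo` does, therefore returns the sum of the Jacobian's rows rather than the Jacobian itself.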

In [5]:
batch = 1
num_features = 3
num_outputs = 2
x = torch.randn(batch, num_features)
print(x.shape)
w = nn.Parameter(torch.randn(num_features, num_outputs))
net = lambda x: x@w

torch.Size([1, 3])


In [6]:
x = x.squeeze()
n = x.size()[0]
x = x.repeat(num_outputs, 1)
x.requires_grad_(True)
print(x.shape)
y = net(x)
print(y.shape)
# passing the identity as the gradient makes row i of x.grad equal to dy_i/dx
y.backward(torch.eye(num_outputs))

torch.Size([2, 3])
torch.Size([2, 2])


In [7]:
print(x.grad.shape)

torch.Size([2, 3])
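
As a sanity check on the repeat-and-identity trick, the sketch below redoes the steps on a fresh linear map and compares the result with the known Jacobian `w.T`. The name `x_rep` and the seed are illustrative assumptions, not part of the cells above.

```python
import torch

torch.manual_seed(0)
num_features, num_outputs = 3, 2
w = torch.randn(num_features, num_outputs)
net = lambda x: x @ w

x = torch.randn(num_features)
# One copy of x per output, so each copy can receive its own gradient.
x_rep = x.repeat(num_outputs, 1).requires_grad_(True)  # (num_outputs, num_features)
y = net(x_rep)                                         # (num_outputs, num_outputs)

# Row i of the identity selects y[i, i]: output i computed from copy i.
y.backward(torch.eye(num_outputs))

# For y = x @ w the Jacobian dy/dx is w.T, and x_rep.grad should match it.
print(torch.allclose(x_rep.grad, w.T))
```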


## Jacobian Function

In [8]:
def jacobian(f, x, out_dim):
    '''
    Batched Jacobian of f at x via the repeat-and-identity trick.
    x: (bs, in_dim) -> returns (bs, out_dim, in_dim)
    '''
    inp = x.clone().detach()  # avoid shadowing the built-in `input`
    bsize = inp.size(0)
    # (bs, in_dim) --repeated--> (bs, out_dim, in_dim): one copy of x per output
    inp = inp.unsqueeze(1).repeat(1, out_dim, 1)
    inp.requires_grad_(True)
    # only supports 2-D inputs (bs, in_dim) and 2-D outputs (bs, out_dim)
    out = f(inp).reshape(bsize, out_dim, out_dim)
    # identity gradient: copy i only receives the gradient of output i
    grad_matrix = torch.eye(out_dim).reshape(1, out_dim, out_dim).repeat(bsize, 1, 1)
    out.backward(grad_matrix)
    return inp.grad

## Test the Jacobian

In [9]:
class Net(torch.nn.Module):
    def __init__(self, in_dim, out_dim):
        super(Net, self).__init__()
        self.fc1 = torch.nn.Linear(in_dim, out_dim)
    
    def forward(self, x):
        return torch.nn.functional.relu(self.fc1(x))

In [10]:
batch = 2
num_features = 5
num_outputs = 2
x = torch.randn(batch, num_features)
print(x.shape)

torch.Size([2, 5])


In [11]:
f = Net(num_features, num_outputs)

In [12]:
f(x).shape

torch.Size([2, 2])

In [13]:
J = jacobian(f, x, num_outputs)
J, J.shape

(tensor([[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [-0.1231,  0.4033, -0.0072, -0.3862,  0.3884]],
 
         [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [-0.1231,  0.4033, -0.0072, -0.3862,  0.3884]]]),
 torch.Size([2, 2, 5]))
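
One way to gain confidence in the result is to compare it with a per-sample reference. The sketch below is self-contained: `batched_jacobian` restates the same trick under a different name, `reference_jacobian` is an illustrative alias for `torch.autograd.functional.jacobian`, and the `nn.Sequential` model stands in for `Net`.

```python
import torch
from torch.autograd.functional import jacobian as reference_jacobian

def batched_jacobian(f, x, out_dim):
    # Same repeat-and-identity trick as the notebook's jacobian helper.
    inp = x.clone().detach().unsqueeze(1).repeat(1, out_dim, 1)
    inp.requires_grad_(True)
    out = f(inp).reshape(x.size(0), out_dim, out_dim)
    eye = torch.eye(out_dim).repeat(x.size(0), 1, 1)
    out.backward(eye)
    return inp.grad

torch.manual_seed(0)
batch, num_features, num_outputs = 2, 5, 2
f = torch.nn.Sequential(
    torch.nn.Linear(num_features, num_outputs),
    torch.nn.ReLU(),
)
x = torch.randn(batch, num_features)

J = batched_jacobian(f, x, num_outputs)
# Reference: one (out_dim, in_dim) Jacobian per sample, stacked along the batch.
J_ref = torch.stack([reference_jacobian(f, x_i) for x_i in x])

print(J.shape, torch.allclose(J, J_ref))
```

A row of zeros in the Jacobian, as in the output above, corresponds to an output unit whose ReLU is inactive for that sample.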

## SENN as Linear Regression

In [14]:
bsize = 5
in_dim = 3
h_dim = in_dim # for the identity conceptizer case
out_dim = 1

In [15]:
x = torch.randn((bsize, in_dim))
x

tensor([[ 0.3792,  0.1336,  0.4417],
        [ 0.4731,  0.6780,  0.3843],
        [ 1.5343,  0.2343, -0.1518],
        [ 2.6796, -0.9424,  0.4785],
        [-1.2728,  0.2543, -1.0110]])

In [16]:
y = torch.tensor(bsize*[1], dtype=torch.float32)
y

tensor([1., 1., 1., 1., 1.])

### Concepts

In [17]:
Conceptualizer = lambda x: x.unsqueeze(-1)
h = Conceptualizer(x)
h.shape

torch.Size([5, 3, 1])

### Relevances

In [18]:
w = nn.Parameter(torch.randn(in_dim, h_dim))
Parameterizer = lambda x: (x@w).unsqueeze(-1) 
t = Parameterizer(x)
t.shape

torch.Size([5, 3, 1])

### Aggregates

In [19]:
Aggregator = lambda x: (x[0].squeeze()*x[1].squeeze()).sum(1)
g = Aggregator((t, h))  # aggregate relevances with concepts (h equals x for the identity conceptizer)
g.shape

torch.Size([5])

## SENN

In [20]:
def SENN(x):
    h = Conceptualizer(x)
    t = Parameterizer(x)
    g = Aggregator((t, h))  # use the concepts h; identical to x for the identity conceptizer
    return g, t, h

### MSE Loss

In [21]:
mse_loss = nn.MSELoss()
L_y = mse_loss(g, y) 

### Local Bound Loss

$$\mathcal{L}_\theta := \| \nabla_{x} f(x) - \theta(x)^{T} J_{x}^{h}(x) \|$$

In [22]:
J_hx = jacobian(Conceptualizer, x, h_dim)
J_hx.shape

torch.Size([5, 3, 3])

In [23]:
t.shape

torch.Size([5, 3, 1])

In [24]:
def g_SENN(x):
    g, _, _ = SENN(x)
    return g

In [25]:
J_yx = jacobian(g_SENN, x, out_dim)
J_yx.shape

torch.Size([5, 1, 3])

In [26]:
torch.bmm(t.permute(0,2,1), J_hx).shape

torch.Size([5, 1, 3])

In [27]:
L_t = (J_yx - torch.bmm(t.permute(0,2,1), J_hx))

In [28]:
L_t.shape

torch.Size([5, 1, 3])
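
The formula above takes a norm of the residual, whereas `L_t` here is still the raw, signed difference. A minimal sketch of reducing such a residual to a scalar penalty is shown below, assuming a per-sample L2 norm averaged over the batch; that reduction is an assumption, not something this notebook fixes. The sketch rebuilds a small linear SENN with the identity conceptizer so that it runs on its own.

```python
import torch

torch.manual_seed(0)
bsize, in_dim = 5, 3
x = torch.randn(bsize, in_dim, requires_grad=True)
w = torch.randn(in_dim, in_dim)

theta = x @ w                   # relevances theta(x), shape (bsize, in_dim)
f = (theta * x).sum(dim=1)      # f(x) = sum_i theta_i(x) * h_i(x), with h(x) = x

# Samples are independent, so the gradient of f.sum() gives each sample's own gradient.
grad_f, = torch.autograd.grad(f.sum(), x)

# With the identity conceptizer, J_x^h is the identity, so theta^T J_x^h = theta.
residual = grad_f - theta

# Assumed reduction: per-sample L2 norm, then the mean over the batch.
L_theta = residual.norm(dim=1).mean()
print(L_theta.item())
```

If this penalty were used for training, `torch.autograd.grad` would likely need `create_graph=True` so that the loss can itself be differentiated with respect to the parameters.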

### Total Loss

$$\mathcal{L}_y (f(x), y) + \lambda \mathcal{L}_\theta (f) + \xi \mathcal{L}_h (x, \hat{x})$$

In [29]:
L = L_y + L_t.mean() + 0.  # lambda = 1; the reconstruction term L_h is 0 for the identity conceptizer

In [30]:
L

tensor(4.8127, grad_fn=<AddBackward0>)

___

## SENN with Deep Neural Networks

In [133]:
bs = 10  # batch size, assumed from the 10 labels defined in the next cell
x = torch.randn((bs, in_dim))

In [134]:
y = torch.tensor([1,0,0,1,0,1,0,0,1,0])

### Conceptualizer

In [135]:
class Conceptualizer(nn.Module):
    
    def __init__(self):
        super().__init__()
    
    def forward(self, x):
        return x

### Parameterizer

In [136]:
class Parameterizer(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.W = nn.Parameter(torch.randn((in_dim, h_dim)))
    
    def forward(self, x):
        return x@self.W

### Aggregator

In [137]:
class Aggregator(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.W = nn.Parameter(torch.randn((h_dim, out_dim)))

    def forward(self, concepts, relevances):
        # weight each concept by its relevance (theta_i * h_i) before the linear aggregation
        return (concepts * relevances)@self.W