# Explaining a Self-Explaining Neural Network
---

<img src="img/SENN.png" alt="SENN Architecture Diagram (Alvarez-Melis &amp; Jaakkola)" style="width: 640px;"/>

## Import Libraries

In [1]:
import torch
import torch.nn as nn

import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline

## Computing the Jacobian with PyTorch Autograd

**How does Autograd build a computation graph?**
* PyTorch's autograd engine is built on two major classes: ```Tensor``` and ```Function```
* Using these together, the autograd engine builds a computation graph
* Each ```Tensor``` has a ```.grad_fn``` attribute that records how it was generated
* A user-defined ```Tensor``` (a leaf node on the computation graph) has its ```.grad_fn``` set to ```None```
* A ```Tensor``` generated from an operation like ```+``` or ```*``` has its ```.grad_fn``` set to that ```Function``` operation
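
The points above can be checked directly. The following is a minimal sketch; the variable names `a`, `b` and `c` are illustrative only and do not appear elsewhere in this notebook.

```python
import torch

# A user-defined tensor is a leaf of the graph: it has no grad_fn
a = torch.ones(3, requires_grad=True)

# Tensors produced by operations record the Function that created them
b = a + 2
c = (b * b).sum()

print(a.grad_fn)             # None for a leaf tensor
print(b.grad_fn)             # the Function object for the addition
print(c.grad_fn)             # the Function object for the sum
print(a.is_leaf, b.is_leaf)  # leaf vs. non-leaf
```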

**How are gradients computed in the Autograd computation graph?**
* If the ```Tensor``` we differentiate is a scalar, we simply call ```.backward()``` on it
* If the ```Tensor``` we differentiate is multi-dimensional, we pass a gradient tensor of the same shape to the ```.backward()``` method

In [14]:
def jacobian_demo(in_dim, out_dim):
    '''
    A Jacobian Demo:
    Set out_dim = 1 for default behaviour
    Set out_dim > 1 for magic
    '''
    x = torch.ones(in_dim, requires_grad=True)
    w = torch.randn((in_dim, out_dim))
    y = x@w
    print(f"x= {x}")
    print(f"W= {w}")
    print(f"y= {y}")

    if out_dim == 1:
        y.backward() # equivalent to y.backward(torch.tensor(1.))
    else:
        y.backward(torch.ones(out_dim))

    print(f"dy/dx = {x.grad}")

In [16]:
jacobian_demo(in_dim=3, out_dim=1)

x= tensor([1., 1., 1.], requires_grad=True)
W= tensor([[ 0.2518],
        [ 1.2410],
        [-0.0681]])
y= tensor([1.4247], grad_fn=<SqueezeBackward3>)
dy/dx = tensor([ 0.2518,  1.2410, -0.0681])


In [17]:
jacobian_demo(in_dim=3, out_dim=2)

x= tensor([1., 1., 1.], requires_grad=True)
W= tensor([[-0.0038,  0.9529],
        [ 1.4555,  0.3204],
        [ 0.0197,  1.1247]])
y= tensor([1.4714, 2.3980], grad_fn=<SqueezeBackward3>)
dy/dx = tensor([0.9491, 1.7759, 1.1445])


**PyTorch Autograd behind the scenes**
* Every node in the computation graph receives the gradient from the node above it
* Autograd computes the product of the incoming gradient vector with the Jacobian of the current node
* So the grad of a node is computed as: $\mathbf{v}^{T} \cdot \mathbf{J}$
* For a scalar node, the incoming gradient vector is just a ```1 x 1``` vector, which is why ```.backward()``` needs no argument in that case
* The nodes below a scalar node then receive the gradient of the scalar node w.r.t. a vector as an ```out_dim x 1``` vector
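
As a sanity check on the $\mathbf{v}^{T} \cdot \mathbf{J}$ rule, here is a minimal sketch that compares ```x.grad``` after ```.backward(v)``` with an explicit product against the full Jacobian from ```torch.autograd.functional.jacobian```. The vector `v`, the seed and the alias `full_jacobian` are arbitrary choices made for this illustration.

```python
import torch
from torch.autograd.functional import jacobian as full_jacobian

torch.manual_seed(0)
in_dim, out_dim = 3, 2
w = torch.randn(in_dim, out_dim)
x = torch.ones(in_dim, requires_grad=True)
v = torch.tensor([1.0, 2.0])  # arbitrary incoming gradient, one entry per output

y = x @ w
y.backward(v)                 # autograd accumulates v^T J into x.grad

# Full Jacobian of y w.r.t. x, shape (out_dim, in_dim)
J = full_jacobian(lambda inp: inp @ w, x.detach())
vjp = v @ J                   # explicit vector-Jacobian product

print(x.grad)
print(vjp)
```

With `v = torch.ones(out_dim)`, as in `jacobian_demo`, the result is the sum of the Jacobian's rows, which is why the `out_dim=2` run above prints the row sums of `W`.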

In [42]:
batch = 1
num_features = 3
num_outputs = 2
x = torch.randn(batch, num_features)
print(x.shape)
w = nn.Parameter(torch.randn(num_features, num_outputs))
net = lambda x: x@w

torch.Size([1, 3])


In [43]:
x = x.squeeze()
n = x.size()[0]
x = x.repeat(num_outputs, 1)
x.requires_grad_(True)
print(x.shape)
y = net(x)
print(y.shape)
y.backward(torch.eye(num_outputs))

torch.Size([2, 3])
torch.Size([2, 2])


In [44]:
print(x.grad.shape)

torch.Size([2, 3])
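
The repeat trick in the cells above recovers the whole Jacobian in a single backward pass: the input is copied once per output, and the identity matrix passed to ```.backward()``` makes copy `i` receive only the gradient of output `i`. For the linear map used here, row `i` of the gradient should therefore equal column `i` of the weight matrix. A minimal sketch of that check, with a plain tensor `w` and a fixed seed as assumptions of this example:

```python
import torch

torch.manual_seed(0)
num_features, num_outputs = 3, 2
w = torch.randn(num_features, num_outputs)

x = torch.randn(num_features)
# One copy of x per output; the repeated tensor is the leaf that collects the gradient
x_rep = x.repeat(num_outputs, 1).requires_grad_(True)  # (num_outputs, num_features)
y = x_rep @ w                                           # (num_outputs, num_outputs)
y.backward(torch.eye(num_outputs))                      # row i selects output i

# Row i of x_rep.grad is d y_i / d x, i.e. column i of w
print(x_rep.grad)
print(w.t())
```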


## Jacobian Function

In [190]:
def jacobian(f, x, out_dim):
    bsize = x.size(0)
    # (bs, in_dim) --repeated--> (bs, out_dim, in_dim)
    # detach first and mark the repeated copy as the leaf, otherwise .grad stays None
    inp = x.detach().unsqueeze(1).repeat(1, out_dim, 1).requires_grad_(True)
    out = f(inp)  # expected shape: (bs, out_dim, out_dim)
    # for autograd of non-scalar outputs: row i of the identity selects output i
    grad_matrix = torch.eye(out_dim).reshape(1, out_dim, out_dim).repeat(bsize, 1, 1)
    out.backward(grad_matrix)
    return inp.grad

## Test the Jacobian

In [153]:
class Net(torch.nn.Module):
    def __init__(self, in_dim, out_dim):
        super(Net, self).__init__()
        self.fc1 = torch.nn.Linear(in_dim, out_dim)
    
    def forward(self, x):
        return torch.nn.functional.relu(self.fc1(x))

In [154]:
batch = 1
num_features = 5
num_outputs = 2
x = torch.randn(batch, num_features)
print(x.shape)

torch.Size([1, 5])


In [155]:
f = Net(num_features, num_outputs)

In [156]:
f(x).shape

torch.Size([1, 2])

In [None]:
J = jacobian(f, x, num_outputs)
J, J.shape
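
One way to test a batched Jacobian is to compare it against ```torch.autograd.functional.jacobian``` applied sample by sample. The sketch below is self-contained: `batched_jacobian` is a hypothetical stand-in that uses the same repeat-and-identity trick, and the `nn.Sequential` model, batch size and seed are assumptions made for this check.

```python
import torch
from torch.autograd.functional import jacobian as reference_jacobian

def batched_jacobian(f, x, out_dim):
    # Repeat each sample out_dim times; the repeated copy is the leaf tensor
    inp = x.detach().unsqueeze(1).repeat(1, out_dim, 1).requires_grad_(True)
    out = f(inp)                                   # (bs, out_dim, out_dim)
    eye = torch.eye(out_dim).repeat(x.size(0), 1, 1)
    out.backward(eye)                              # row i selects output i
    return inp.grad                                # (bs, out_dim, in_dim)

torch.manual_seed(0)
net = torch.nn.Sequential(torch.nn.Linear(5, 2), torch.nn.ReLU())
x = torch.randn(4, 5)

J = batched_jacobian(net, x, out_dim=2)
# Reference: one full Jacobian per sample, stacked along the batch dimension
J_ref = torch.stack([reference_jacobian(net, xi) for xi in x])

print(J.shape)
print(torch.allclose(J, J_ref))
```

Note that ```.backward()``` also accumulates gradients into the network's parameters, so they should be zeroed before a training step that follows a call like this.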

## Generate Data

In [133]:
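bs, in_dim, h_dim, out_dim = 10, 3, 3, 1  # assumed sizes, not set earlier; bs matches the 10 labels below
x = torch.randn((bs, in_dim))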

In [134]:
y = torch.tensor([1,0,0,1,0,1,0,0,1,0])

## Conceptualizer

In [135]:
class Conceptualizer(nn.Module):
    
    def __init__(self):
        super().__init__()
    
    def forward(self, x):
        return x

## Parameterizer

In [136]:
class Parameterizer(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.W = nn.Parameter(torch.randn((in_dim, h_dim)))
    
    def forward(self, x):
        return x@self.W

## Aggregator

In [137]:
class Aggregator(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.W = nn.Parameter(torch.randn((h_dim, out_dim)))

    def forward(self, concepts, relevances):
        return (concepts + relevances)@self.W

## SENN

In [138]:
h = Conceptualizer()
t = Parameterizer()
g = Aggregator()

In [139]:
concepts = h(x)
relevances = t(x)
y_pred = g(concepts, relevances) 

## SENN as Linear Regression

In [240]:
bsize = 5
in_dim = 3
h_dim = in_dim # needs discussion
out_dim = 1

In [241]:
x = torch.randn((bsize, in_dim))
x.requires_grad_(True)
x

tensor([[ 0.0170, -1.0168,  0.6083],
        [ 0.2550,  1.6877,  0.5968],
        [ 0.3455,  2.0961,  1.8915],
        [-0.3178,  1.0152,  2.0694],
        [ 1.0242,  0.3378,  1.0517]], requires_grad=True)

In [242]:
y = torch.tensor(bsize*[1], dtype=torch.float32)
y

tensor([1., 1., 1., 1., 1.])

### Concepts

In [243]:
Conceptualizer = lambda x: x
h = Conceptualizer(x)

### Relevances

In [244]:
w = nn.Parameter(torch.randn(in_dim, h_dim))
Parameterizer = lambda x: x@w 
t = Parameterizer(x)
t

tensor([[ 0.3459,  1.0733, -0.4981],
        [-1.7992, -0.1439,  1.4994],
        [-3.1064,  0.8840,  2.2621],
        [-2.3626,  1.1694,  1.2552],
        [-1.0963,  1.4317,  1.0703]], grad_fn=<MmBackward>)

In [246]:
w

Parameter containing:
tensor([[-0.0312,  0.6309,  0.4697],
        [-0.7936, -0.5017,  0.6987],
        [-0.7571,  0.9081,  0.3359]], requires_grad=True)

### Aggregates

In [247]:
Aggregator = lambda x: (x[0]*x[1]).sum()
g = Aggregator((t, x))
g

tensor(8.8851, grad_fn=<SumBackward0>)

### MSE Loss

In [249]:
mse_loss = nn.MSELoss()
L_y = mse_loss(g, y) 

### Local Bound Loss

$$\mathcal{L}_\theta := \| \nabla_{x} f(x) - \theta(x)^{T} J_{x}^{h}(x) \|$$

In [250]:
t.shape

torch.Size([5, 3])

In [252]:
x_dummy = x.clone().detach()

In [253]:
J_hx = jacobian(Conceptualizer, x_dummy, h_dim).squeeze()

In [255]:
J_hx.shape

torch.Size([5, 3, 3])

In [256]:
g.backward()

In [257]:
x.grad

tensor([[-0.0104,  1.9950, -1.2299],
        [-0.4621, -0.7762,  3.0393],
        [-0.9064,  0.8796,  4.5394],
        [-0.7402,  2.3581,  3.1129],
        [-0.4212,  1.1842,  0.9550]])

In [238]:
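# batched theta(x)^T J_x^h(x): (bs, 1, h_dim) @ (bs, h_dim, in_dim) -> (bs, in_dim)
theta_J = torch.bmm(t.unsqueeze(1), J_hx).squeeze(1)
# norm of the difference, as in the formula above (taken over the whole batch)
L_t = (x.grad - theta_J).norm()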

### Total Loss

$$\mathcal{L}_y (f(x), y) + \lambda \mathcal{L}_\theta (f) + \xi \mathcal{L}_h (x, \hat{x})$$
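
The total loss can be sketched as a weighted sum of the three terms. This is only an illustration under stated assumptions: `senn_total_loss` is a hypothetical helper, the default values of `lam` and `xi` are placeholders rather than tuned values, $\mathcal{L}_y$ is taken to be the MSE used earlier in this notebook, and $\mathcal{L}_h$ is assumed to be an MSE reconstruction loss between `x` and `x_hat`.

```python
import torch
import torch.nn.functional as F

def senn_total_loss(y_pred, y, robustness_loss, x, x_hat, lam=1e-2, xi=1e-2):
    """Hypothetical helper: L_y + lam * L_theta + xi * L_h."""
    L_y = F.mse_loss(y_pred, y)   # prediction loss (MSE, as in the regression section)
    L_h = F.mse_loss(x_hat, x)    # assumed form of the concept reconstruction loss
    return L_y + lam * robustness_loss + xi * L_h

# Toy usage with made-up tensors
x = torch.randn(5, 3)
x_hat = x.clone()                 # identity conceptualizer: perfect reconstruction
y = torch.ones(5)
y_pred = torch.ones(5)
L_theta = torch.tensor(2.0)       # stand-in for the local bound loss

total = senn_total_loss(y_pred, y, L_theta, x, x_hat, lam=0.5, xi=0.1)
print(total)
```

With an identity conceptualizer the reconstruction term is zero, so only the prediction and local bound terms contribute.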