# Self Explaining Neural Networks
---

## Import Libraries

In [1]:
import torch
import torch.nn as nn

import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline

## Config

In [2]:
# Seed torch's global RNG first so parameter initialisations and toy data are
# reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

in_dim = 2
out_dim = 1
h_dim = 2
bs = 10

## How best to compute the Jacobian with PyTorch autograd?

In [160]:
# Sizes for the Jacobian experiment. NOTE: this rebinds the global
# `in_dim` / `out_dim` from the Config cell; any later cell that reads those
# globals sees in_dim=3, out_dim=2 unless it sets them again.
bsize, in_dim, out_dim = 1, 3, 2

In [161]:
# Draw one input row per sample and copy it once per output -> (bsize, out_dim, in_dim).
x = torch.randn(bsize, in_dim)[:, None, :].expand(-1, out_dim, -1).clone()
x.requires_grad_(True)

tensor([[[0.9477, 0.1124, 0.1092],
         [0.9477, 0.1124, 0.1092]]], requires_grad=True)

In [162]:
x.size()

torch.Size([1, 2, 3])

In [163]:
# A linear map applied to every replicated row: (bsize, out_dim, in_dim) @ (in_dim, out_dim).
w = nn.Parameter(torch.randn((in_dim, out_dim)))
y = torch.matmul(x, w)
y.size()

torch.Size([1, 2, 2])

In [164]:
# Identity seed per sample: copy i of x receives the gradient of output i only,
# so the rows of x.grad are the rows of the Jacobian.
incoming_grad = torch.eye(out_dim).expand(bsize, out_dim, out_dim).clone()
y.backward(gradient=incoming_grad)
x.grad.data

tensor([[[ 0.7564, -0.4524,  0.4692],
         [-0.3818, -0.3083, -0.4885]]])

In [105]:
def jacobian(f, x, out_dim, create_graph=False):
    """Per-sample Jacobian of ``f`` at ``x``, computed with a single backward pass.

    Each input row is replicated ``out_dim`` times. Seeding the backward pass
    with an identity matrix makes copy ``i`` receive the gradient of output
    ``i`` only, so the stacked input gradients form the Jacobian.

    Args:
        f: callable applied independently along the last dimension, mapping
            ``(..., in_dim)`` to ``(..., out_dim)`` (e.g. ``nn.Linear`` or a
            plain matmul).
        x: input batch of shape ``(bs, in_dim)``. It is detached internally, so
            it may already be part of an autograd graph.
        out_dim: size of the last dimension of ``f``'s output.
        create_graph: if True, keep the returned Jacobian attached to the graph
            of ``f``'s parameters so it can appear inside a training loss
            (e.g. a gradient/local-bound penalty).

    Returns:
        Tensor of shape ``(bs, out_dim, in_dim)`` with
        ``J[b, i, j] = d f(x)[b, i] / d x[b, j]``.
    """
    bsize = x.size(0)
    # Detach so the replicated tensor is a fresh leaf even when ``x`` already
    # belongs to a graph (otherwise its ``.grad`` would never be populated).
    x_rep = x.detach().unsqueeze(1).repeat(1, out_dim, 1)  # (bs, out_dim, in_dim)
    x_rep.requires_grad_(True)
    y = f(x_rep)  # (bs, out_dim, out_dim)
    # Identity seed: copy i back-propagates only output i.
    grad_matrix = torch.eye(out_dim, dtype=y.dtype, device=y.device)
    grad_matrix = grad_matrix.unsqueeze(0).repeat(bsize, 1, 1)
    # torch.autograd.grad (unlike .backward) leaves the .grad of f's parameters
    # untouched, so computing a Jacobian has no side effects on training state.
    (J,) = torch.autograd.grad(
        y, x_rep, grad_outputs=grad_matrix, create_graph=create_graph
    )
    return J

In [106]:
class Net(nn.Module):
    """One linear layer followed by a ReLU: ``relu(x W^T + b)``.

    Args:
        in_dim: number of input features.
        out_dim: number of outputs.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # The attribute name ``fc1`` is inspected directly later in the notebook.
        self.fc1 = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        pre_activation = self.fc1(x)
        return torch.relu(pre_activation)

In [120]:
# Small random batch for checking `jacobian` against a one-layer ReLU net.
batch = 1
num_features = 5
num_outputs = 2
x = torch.randn(batch, num_features)
x.shape

torch.Size([1, 5])


In [121]:
f = Net(in_dim=num_features, out_dim=num_outputs)

In [122]:
f(x).size()

torch.Size([1, 2])

In [123]:
# Expected shape: (batch, num_outputs, num_features).
J = jacobian(f, x, out_dim=num_outputs)
J, J.size()

torch.Size([1, 2, 2])


(tensor([[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]]), torch.Size([1, 2, 5]))

In [124]:
# Compare with the Jacobian above: Net applies relu(fc1(x)), so a Jacobian row
# is fc1's weight row where that pre-activation is positive and all zeros otherwise.
f.fc1.weight

Parameter containing:
tensor([[-0.0993,  0.3321, -0.1209, -0.2297, -0.1050],
        [-0.4040,  0.3239, -0.0428,  0.2237,  0.3944]], requires_grad=True)

## Generate Data

In [3]:
# The Jacobian experiment above rebinds `in_dim`/`out_dim` (to 3 and 2). Restore
# the toy SENN sizes from the Config cell so this section works under
# Restart & Run All: identity concepts require h_dim == in_dim.
in_dim, h_dim, out_dim = 2, 2, 1
x = torch.randn((bs, in_dim))

In [4]:
# Binary targets, one per sample.
y = torch.as_tensor([1, 0, 0, 1, 0, 1, 0, 0, 1, 0], dtype=torch.int64)

## Conceptualizer

In [5]:
class Conceptualizer(nn.Module):
    """Identity concept encoder h(x) = x: the raw features are the concepts."""

    def forward(self, x):
        # No parameters: the concepts are the input tensor itself.
        return x

## Parameterizer

In [6]:
class Parameterizer(nn.Module):
    """Relevance network theta(x): a bias-free linear map ``x @ W``.

    Args:
        in_features: input size. Defaults to the notebook global ``in_dim``,
            read at construction time.
        out_features: number of relevance scores. Defaults to the notebook
            global ``h_dim``, read at construction time.
    """

    def __init__(self, in_features=None, out_features=None):
        super().__init__()
        # Fall back to the globals so `Parameterizer()` keeps working, but allow
        # explicit sizes so the layer does not silently pick up whatever
        # `in_dim` / `h_dim` were most recently rebound to.
        if in_features is None:
            in_features = in_dim
        if out_features is None:
            out_features = h_dim
        self.W = nn.Parameter(torch.randn((in_features, out_features)))

    def forward(self, x):
        return x @ self.W

## Aggregator

In [7]:
class Aggregator(nn.Module):
    """SENN aggregation g: combines concepts h(x) with their relevances theta(x).

    SENN defines f(x) = g(theta(x)_1 h(x)_1, ..., theta(x)_k h(x)_k). Each
    concept is scaled by its relevance, and the products are mixed by a
    learned linear map ``W``.

    Args:
        in_features: number of concepts. Defaults to the notebook global
            ``h_dim``, read at construction time.
        out_features: number of outputs. Defaults to the notebook global
            ``out_dim``, read at construction time.
    """

    def __init__(self, in_features=None, out_features=None):
        super().__init__()
        if in_features is None:
            in_features = h_dim
        if out_features is None:
            out_features = out_dim
        self.W = nn.Parameter(torch.randn((in_features, out_features)))

    def forward(self, concepts, relevances):
        # Multiply, not add: f must be linear in the concepts with coefficients
        # theta(x). That is what makes theta(x) readable as relevances, and it is
        # what the theta(x)^T J_x^h(x) local-bound loss assumes.
        return (concepts * relevances) @ self.W

## SENN

In [8]:
# Instantiate concept encoder, relevance network and aggregator (in that order).
h, t, g = Conceptualizer(), Parameterizer(), Aggregator()

In [9]:
# Forward pass: concepts h(x), relevances theta(x), then the aggregated prediction.
concepts, relevances = h(x), t(x)
y_pred = g(concepts=concepts, relevances=relevances)

## SENN as Linear Regression

In [216]:
bsize, in_dim, out_dim = 5, 3, 1
# Identity concepts: the relevances t (h_dim wide) are combined elementwise with
# h = x (in_dim wide), so both widths must match.
# TODO: revisit if a learned, lower-dimensional conceptualizer replaces the identity.
h_dim = in_dim

In [197]:
x = torch.randn(bsize, in_dim)
x

tensor([[-0.9139,  1.4504,  0.2090],
        [-0.0448,  0.3841, -1.2505],
        [ 0.5166,  1.4903, -0.2736],
        [ 0.9951, -0.9236,  1.1216],
        [-0.1194, -0.5551,  0.3706]])

In [198]:
# Constant regression target of 1.0 for every sample.
y = torch.ones(bsize, dtype=torch.float32)
y

tensor([1., 1., 1., 1., 1.])

### Concepts

In [199]:
# Identity concepts h(x) = x, so the concept Jacobian J_x^h is the identity.
# (This rebinds `h`, previously the Conceptualizer module.)
h = x

### Relevances

In [200]:
# Linear relevance model theta(x) = x @ w, giving one score per concept: (bsize, h_dim).
w = nn.Parameter(torch.randn((in_dim, h_dim)))
t = torch.matmul(x, w)
t

tensor([[-2.0601,  2.3076, -2.7704],
        [-0.7508,  0.9061, -0.9987],
        [-1.9809,  0.9535, -3.1562],
        [ 1.6018, -2.4192,  1.9271],
        [ 0.7975, -0.5414,  1.2127]], grad_fn=<MmBackward>)

### Aggregates

In [201]:
# Linear aggregation head over the combined concept / relevance features.
# NOTE(review): SENN aggregates theta(x)_i * h(x)_i, and the local-bound loss
# formula below assumes that form; this cell adds them instead. Confirm intent.
w2 = nn.Parameter(torch.randn((h_dim, out_dim)))
g = torch.matmul(h + t, w2)
g

tensor([[ 2.6248],
        [ 0.7869],
        [ 1.5148],
        [-2.2901],
        [-0.6830]], grad_fn=<MmBackward>)

### MSE Loss

In [205]:
mse_loss = nn.MSELoss()
# Drop only the trailing out_dim == 1 axis. A bare squeeze() would also collapse
# the batch axis when bsize == 1, and MSELoss would then broadcast a 0-d
# prediction against a (1,)-shaped target (with a size-mismatch warning).
L_y = mse_loss(g.squeeze(-1), y)
L_y

### Local Bound Loss

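$$\mathcal{L}_\theta := \| \nabla_{x} f(x) - \theta(x)^{T} J_{x}^{h}(x) \|$$

This penalises how far the model is from being locally linear in the concepts with coefficients $\theta(x)$. For identity concepts $h(x) = x$ the concept Jacobian is $J_{x}^{h}(x) = I$, so the penalty reduces to $\| \nabla_{x} f(x) - \theta(x) \|$.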

In [None]:
# Model output as a function of the raw input, rebuilt from the parameters above:
# h(x) = x, theta(x) = x @ w, f(x) = (h(x) + theta(x)) @ w2.
# `jacobian` needs a callable; `t` is already an evaluated tensor.
def f_lr(inp):
    return (inp + inp @ w) @ w2

# grad_x f(x) per sample: (bsize, out_dim, in_dim).
J_fx = jacobian(f_lr, x, out_dim)
# Identity concepts => J_x^h(x) = I for every sample: (bsize, h_dim, in_dim).
J_hx = torch.eye(h_dim, in_dim).expand(bsize, h_dim, in_dim)
# theta(x)^T J_x^h(x): (bsize, 1, h_dim) @ (bsize, h_dim, in_dim) -> (bsize, 1, in_dim).
theta_J = t.unsqueeze(1) @ J_hx
# Per-sample norm of the mismatch, averaged over the batch. A plain .sum() of
# signed differences is not a norm and lets errors cancel out.
# NOTE(review): J_fx comes back detached from the parameter graph, so L_t only
# back-propagates through theta(x); training f against this penalty needs a
# double-backward (create_graph) Jacobian.
L_t = (J_fx - theta_J).flatten(start_dim=1).norm(dim=1).mean()
L_t

### Total Loss

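$$\mathcal{L}_y (f(x), y) + \lambda \mathcal{L}_\theta (f) + \xi \mathcal{L}_h (x, \hat{x})$$

The terms are:

- $\mathcal{L}_y$ is the task loss (the MSE above).
- $\mathcal{L}_\theta$ is the local bound loss.
- $\mathcal{L}_h$ is a concept (reconstruction) loss between $x$ and its reconstruction $\hat{x}$.
- $\lambda, \xi \ge 0$ weight the two regularisers.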