# DiSENN: Self-Explaining Neural Networks with Disentanglement
---

## Import libraries

In [None]:
import os
import sys
sys.path.append('..')
sys.path.append(os.path.abspath(os.path.join('..', '..')))

In [None]:
import torch
import torch.nn as nn
from torchvision.utils import make_grid
import numpy as np
from pathlib import Path
from types import SimpleNamespace
import matplotlib.pyplot as plt

%matplotlib inline

In [None]:
from models.losses import *
from models.parameterizers import ConvParameterizer
from models.conceptizers import ConvConceptizer, VaeConceptizer
from models.aggregators import SumAggregator

## Configuration

In [None]:
# Matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and
# newer releases no longer ship the old names, so fall back to the new name
# when the legacy one cannot be found.
try:
    plt.style.use('seaborn-paper')
except OSError:
    plt.style.use('seaborn-v0_8-paper')

In [None]:
# Experiment configuration for the default MNIST DiSENN run, exposed as an
# attribute-access namespace (config.lr, config.device, ...).
config = SimpleNamespace(
    model_class="DiSENN",
    # Class names given as strings are resolved to classes when the model
    # components are built later in the notebook.
    conceptizer="VaeConceptizer",
    pretrain_epochs=1,
    pre_beta=1.0,
    beta=4.0,
    concept_loss="BVAE_loss",
    robustness_loss="mnist_robustness_loss",
    train=True,
    image_size=28,
    concept_dim=1,
    concept_visualization=None,
    parameterizer="ConvParameterizer",
    cl_sizes=[1, 10, 20],
    hidden_sizes=[320, 50],
    num_concepts=5,
    num_classes=10,
    dropout=0.5,
    aggregator="SumAggregator",
    device="cpu",
    lr=1e-3,
    epochs=5,
    robust_reg=1e-3,
    concept_reg=1,
    print_freq=50,
    # exp_name also selects the results sub-folder used for checkpoints and
    # saved figures further down.
    exp_name="MNIST_DiSENN",
    dataloader="mnist",
    data_path="datasets/data/mnist_data",
    batch_size=128,
    eval_freq=100,
)

config = {
  "model_class": "DiSENN",
  "conceptizer": "VaeConceptizer",
  "pretrain_epochs": 3,
  "beta": 4.0,
  "concept_loss": "BVAE_loss",
  "robustness_loss": "mnist_robustness_loss",
  "train": True,
  "image_size": 28,
  "concept_dim": 1,
  "concept_visualization": None,
  "parameterizer": "ConvParameterizer",
  "cl_sizes": [1, 10, 20],
  "hidden_sizes": [320, 50],
  "num_concepts": 5,
  "num_classes": 10,
  "dropout": 0.5,
  "aggregator": "SumAggregator",
  "device": "cuda:0",
  "lr": 1e-3,
  "epochs": 5,
  "robust_reg": 1e-3,
  "concept_reg": 1,
  "print_freq": 50,
  "exp_name": "mnist_bvae_default",
  "dataloader": "mnist",
  "data_path": "datasets/data/mnist_data",
  "batch_size" : 128,
  "eval_freq" : 100
}
config = SimpleNamespace(**config)

config = {
  "model_class": "DiSENN",
  "conceptizer": "VaeConceptizer",
  "pretrain_epochs": 3,
  "beta": 4.0,
  "concept_loss": "BVAE_loss",
  "robustness_loss": "mnist_robustness_loss",
  "train": True,
  "image_size": 28,
  "concept_dim": 1,
  "concept_visualization": None,
  "parameterizer": "ConvParameterizer",
  "hidden_sizes": [320, 100],
  "num_concepts": 10,
  "num_classes": 10,
  "dropout": 0.5,
  "aggregator": "SumAggregator",
  "device": "cuda:0",
  "lr": 1e-3,
  "epochs": 5,
  "robust_reg": 1e-3,
  "concept_reg": 1,
  "print_freq": 50,
  "exp_name": "mnist_bvae_concepts10",
  "dataloader": "mnist",
  "data_path": "datasets/data/mnist_data",
  "batch_size" : 128,
  "eval_freq" : 100
}
config = SimpleNamespace(**config)

config = {
  "model_class": "DiSENN",
  "conceptizer": "VaeConceptizer",
  "pretrain_epochs": 3,
  "beta": 4.0,
  "concept_loss": "BVAE_loss",
  "robustness_loss": "mnist_robustness_loss",
  "train": True,
  "image_size": 28,
  "concept_dim": 1,
  "concept_visualization": None,
  "parameterizer": "ConvParameterizer",
  "hidden_sizes": [320, 200],
  "num_concepts": 20,
  "num_classes": 10,
  "dropout": 0.5,
  "aggregator": "SumAggregator",
  "device": "cpu",
  "lr": 1e-3,
  "epochs": 5,
  "robust_reg": 1e-3,
  "concept_reg": 1,
  "print_freq": 50,
  "exp_name": "mnist_bvae_concepts20",
  "dataloader": "mnist",
  "data_path": "datasets/data/mnist_data",
  "batch_size" : 128,
  "eval_freq" : 100
}
config = SimpleNamespace(**config)

# DiSENN
DiSENN (Self-Explaining Neural Networks with Disentanglement) is an extension of the Self-Explaining Neural Network proposed in [1].  

DiSENN incorporates a constrained variational inference framework on a 
SENN Concept Encoder to learn disentangled representations of the 
basis concepts as in [2]. The basis concepts are then independently
sensitive to single generative factors, leading to better interpretability 
and less overlap with other basis concepts. Such a strong constraint 
better fulfills the "diversity" desideratum for basis concepts
in a Self-Explaining Neural Network.


References  
[1] Alvarez Melis, et al.
"Towards Robust Interpretability with Self-Explaining Neural Networks" NIPS 2018  
[2] Irina Higgins, et al. 
"β-VAE: Learning basic visual concepts with a constrained variational framework." ICLR 2017. 

In [None]:
from models.senn import DiSENN

## Load Data

In [None]:
from datasets.dataloaders import get_dataloader
# Build the data loaders from the experiment config. Only the first two
# returned loaders are kept; the third value is discarded
# (presumably the test loader -- confirm in get_dataloader).
train_dl, val_dl, _ = get_dataloader(config)

## Training

In [None]:
from senn.trainer import DiSENN_Trainer

In [None]:
# Create the trainer from the experiment config, run it, then finalize it.
# NOTE(review): presumably this is what writes the checkpoints loaded in the
# following cells -- confirm against DiSENN_Trainer.
trainer = DiSENN_Trainer(config)
trainer.run()
trainer.finalize()

## Load Trained Model

In [None]:
# Instantiate the three model components named in the config. Every config
# entry is forwarded as a keyword argument to each constructor.
# NOTE: eval() turns the class-name strings into the imported classes; this is
# acceptable only because the config is defined inside this notebook.
component_kwargs = vars(config)

conceptizer_cls = eval(config.conceptizer)
conceptizer = conceptizer_cls(**component_kwargs)

parameterizer_cls = eval(config.parameterizer)
parameterizer = parameterizer_cls(**component_kwargs)

aggregator_cls = eval(config.aggregator)
aggregator = aggregator_cls(**component_kwargs)

In [None]:
# Assemble the DiSENN model from the conceptizer, parameterizer and aggregator.
model = DiSENN(conceptizer, parameterizer, aggregator)

In [None]:
# Locate this experiment's best checkpoint and load it, mapping the stored
# tensors onto the configured device.
checkpoint_path = Path('../senn/results')
model_file = checkpoint_path.joinpath(config.exp_name, "checkpoints", "best_model.pt")
checkpoint = torch.load(model_file, map_location=config.device)
print(f"Loading trained model: {model_file}")

In [None]:
# Restore the weights stored under the checkpoint's 'model_state' key.
model.load_state_dict(checkpoint['model_state'])

# Classification

In [None]:
# Take one validation batch and run it through the model for inspection.
# A freshly constructed module is in training mode, so switch to eval mode
# first: otherwise dropout (and any other train-only stochastic behaviour)
# makes the predictions and the reconstruction noisy. Autograd is disabled
# because no gradients are needed for inference.
xb, yb = next(iter(val_dl))
model.eval()
with torch.no_grad():
    y_pred, explanations, x_reconstruct = model(xb)

In [None]:
# Keep figures thumbnail-sized, then show channel 0 of the last input image
# of the batch in grayscale.
plt.rcParams.update({'figure.figsize': (2, 2)})
last_input = xb[-1][0]
plt.imshow(last_input, cmap='gray')

In [None]:
# Show channel 0 of the last reconstructed image; it is detached from autograd
# so that it can be converted to a NumPy array for plotting.
last_reconstruction = x_reconstruct[-1][0].detach()
plt.imshow(last_reconstruction.numpy(), cmap='gray')

In [None]:
# Move the model to the device named in the config before evaluating it.
model = model.to(config.device)

In [None]:
# Measure classification accuracy over the whole validation loader.
# Correct predictions are counted per sample and divided by the number of
# samples seen, so a smaller final batch is weighted properly. True division
# is required here: floor-dividing summed batch accuracies by the last batch
# index collapses the result to a whole number and is off by one.
num_correct = 0
num_seen = 0
model.eval()
with torch.no_grad():
    for x, labels in val_dl:
        x = x.float().to(config.device)
        labels = labels.long().to(config.device)
        y_pred, explanations, x_reconstructed = model(x)
        num_correct += (y_pred.argmax(dim=1) == labels).sum().item()
        num_seen += labels.size(0)
# Report NaN for an empty loader instead of failing on a division by zero.
accuracy = num_correct / num_seen if num_seen else float("nan")
print(f"Test Mean Accuracy: {accuracy * 100:.2f} %")

In [None]:
# Move the model back to the CPU; the explanation cells below feed it CPU
# tensors.
model = model.to(torch.device('cpu'))

# Explanation

## Generate Prototypes from Disentangled Concepts

In [None]:
# List the positions in the current batch whose label equals 3.
# NOTE(review): the hard-coded indices in the next cells are presumably picked
# from this output -- re-check them if the batch changes.
(yb==3).nonzero()

In [None]:
# Explain one sample of the batch and save the figure under this experiment's
# results folder.
# NOTE(review): index 32 is presumably one of the label-3 positions listed by
# the previous cell -- this only holds if the validation batch is deterministic.
fname = "/digit3a.png"
x = xb[32].cpu()
figure_path = "results/" + config.exp_name + fname
model.explain(
    x,
    traversal_range=0.45,
    gridsize=(1, 6),
    col_span=3,
    figure_size=(18, 3),
    show=True,
    save_as=figure_path,
    use_cdf=True,
)

In [None]:
# Explain a second sample of the batch and save the figure under this
# experiment's results folder.
# NOTE(review): index 48 is presumably another label-3 position from the
# nonzero() listing -- this only holds if the validation batch is deterministic.
fname = "/digit3b.png"
x = xb[48].cpu()
figure_path = "results/" + config.exp_name + fname
model.explain(
    x,
    traversal_range=0.45,
    gridsize=(1, 6),
    col_span=3,
    figure_size=(18, 3),
    show=True,
    save_as=figure_path,
    use_cdf=True,
)

# Experiment: Balance of Performance Accuracy and Explanation Interpretability

## Load Trained Model at low KL Divergence
```
Accuracy:0.979 Classification Loss:1.092 Robustness Loss:0.000 Concept Loss:1.022 Recon Loss: 0.719 KL Div: 0.076
```

In [None]:
# File name of the intermediate checkpoint loaded for this experiment.
model_name = "Epoch[4]-Step[1700].pt"

In [None]:
# Build fresh model components from the config for the second experiment.
# Every config entry is forwarded as a keyword argument to each constructor.
# NOTE: eval() turns the class-name strings into the imported classes; this is
# acceptable only because the config is defined inside this notebook.
component_kwargs = vars(config)

conceptizer_cls = eval(config.conceptizer)
conceptizer = conceptizer_cls(**component_kwargs)

parameterizer_cls = eval(config.parameterizer)
parameterizer = parameterizer_cls(**component_kwargs)

aggregator_cls = eval(config.aggregator)
aggregator = aggregator_cls(**component_kwargs)

In [None]:
# Assemble a new DiSENN model from the freshly built components.
model = DiSENN(conceptizer, parameterizer, aggregator)

In [None]:
# Locate the checkpoint selected by model_name and load it, mapping the stored
# tensors onto the configured device.
# NOTE(review): this cell reads from 'results' while the best-model cell
# earlier reads from '../senn/results' -- confirm which directory is intended.
checkpoint_path = Path('results')
model_file = checkpoint_path.joinpath(config.exp_name, "checkpoints", model_name)
checkpoint = torch.load(model_file, map_location=config.device)
print(f"Loading trained model: {model_file}")

In [None]:
# Restore the weights stored under the checkpoint's 'model_state' key.
model.load_state_dict(checkpoint['model_state'])

# Classification

In [None]:
# Take one validation batch and run it through the model for inspection.
# A freshly constructed module is in training mode, so switch to eval mode
# first: otherwise dropout (and any other train-only stochastic behaviour)
# makes the predictions and the reconstruction noisy. Autograd is disabled
# because no gradients are needed for inference.
xb, yb = next(iter(val_dl))
model.eval()
with torch.no_grad():
    y_pred, explanations, x_reconstruct = model(xb)

In [None]:
# Keep figures thumbnail-sized, then show channel 0 of the last input image
# of the batch in grayscale.
plt.rcParams.update({'figure.figsize': (2, 2)})
last_input = xb[-1][0]
plt.imshow(last_input, cmap='gray')

In [None]:
# Show channel 0 of the last reconstructed image; it is detached from autograd
# so that it can be converted to a NumPy array for plotting.
last_reconstruction = x_reconstruct[-1][0].detach()
plt.imshow(last_reconstruction.numpy(), cmap='gray')

In [None]:
# Move the model to the device named in the config before evaluating it.
model = model.to(config.device)

In [None]:
# Measure classification accuracy over the whole validation loader.
# Correct predictions are counted per sample and divided by the number of
# samples seen, so a smaller final batch is weighted properly. True division
# is required here: floor-dividing summed batch accuracies by the last batch
# index collapses the result to a whole number and is off by one.
num_correct = 0
num_seen = 0
model.eval()
with torch.no_grad():
    for x, labels in val_dl:
        x = x.float().to(config.device)
        labels = labels.long().to(config.device)
        y_pred, explanations, x_reconstructed = model(x)
        num_correct += (y_pred.argmax(dim=1) == labels).sum().item()
        num_seen += labels.size(0)
# Report NaN for an empty loader instead of failing on a division by zero.
accuracy = num_correct / num_seen if num_seen else float("nan")
print(f"Test Mean Accuracy: {accuracy * 100:.2f} %")

In [None]:
# Move the model back to the CPU; the explanation cells below feed it CPU
# tensors.
model = model.to(torch.device('cpu'))

# Explanation

## Generate Prototypes from Disentangled Concepts

In [None]:
# List the positions in the current batch whose label equals 3.
# NOTE(review): the hard-coded indices in the next cells are presumably picked
# from this output -- re-check them if the batch changes.
(yb==3).nonzero()

In [None]:
# Explain one sample of the batch with the low-KL checkpoint and save the
# figure under this experiment's results folder.
# NOTE(review): index 23 is presumably one of the label-3 positions listed by
# the previous cell -- this only holds if the validation batch is deterministic.
fname = "/exp-digit3.png"
x = xb[23].cpu()
figure_path = "results/" + config.exp_name + fname
model.explain(
    x,
    traversal_range=0.45,
    gridsize=(1, 6),
    col_span=3,
    figure_size=(18, 3),
    show=True,
    save_as=figure_path,
    use_cdf=True,
)

In [None]:
# Explain a second sample of the batch with the low-KL checkpoint and save the
# figure under this experiment's results folder.
# NOTE(review): index 27 is presumably another label-3 position from the
# nonzero() listing -- this only holds if the validation batch is deterministic.
fname = "/exp-digit3b.png"
x = xb[27].cpu()
figure_path = "results/" + config.exp_name + fname
model.explain(
    x,
    traversal_range=0.45,
    gridsize=(1, 6),
    col_span=3,
    figure_size=(18, 3),
    show=True,
    save_as=figure_path,
    use_cdf=True,
)