Setting up the environment

In [None]:
# Clone the repository with the SENN code base; later cells `%cd` into it and
# read `configs/` and `results/` relative to its root.
# Note: `git clone` refuses to clone into an existing non-empty ./SENN directory,
# so re-running this cell after a successful clone just reports an error.
!git clone https://github.com/Ggenoni/SENN.git

In [None]:
# Install Miniconda into /usr/local/miniconda (quiet download, batch-mode install).
# The cell is written to be re-runnable:
#   * `-O` overwrites the installer instead of saving a stale "....sh.1" copy,
#   * `-u` lets the installer update an existing prefix instead of aborting,
#   * the bin directory is appended to PATH only once.
!wget -q -O /tmp/Miniconda3-latest-Linux-x86_64.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
!bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -u -p /usr/local/miniconda
import os

MINICONDA_BIN = "/usr/local/miniconda/bin"
# Appended (not prepended) so the notebook's own Python stays first on PATH;
# only `conda` needs to be found from here.
if MINICONDA_BIN not in os.environ["PATH"].split(os.pathsep):
    os.environ["PATH"] += os.pathsep + MINICONDA_BIN

In [None]:
# Sanity check: `conda` is now reachable through the PATH entry added by the install cell.
!conda --version

In [None]:
# Work from the repository root: `configs/` and `results/` are resolved relative to it.
%cd SENN

In [None]:
# Create the conda environment described by the repository's environment.yml
# (presumably named `senn`, the name activated in the following cells — confirm in the file).
!conda env create -f environment.yml

In [None]:
# Check that activating the `senn` env puts its interpreter first on PATH.
# The activation only lasts for this single command line: each `!` line runs in its own shell.
!source /usr/local/miniconda/bin/activate senn && which python

In [None]:
# ==> After running this cell in Colab, restart the runtime and switch to the "Python (senn)" kernel <==

# Every `!` line runs in a fresh subshell, so a standalone `!source ... activate senn`
# has no effect on the following lines (pip/python would be the system ones and the
# kernel would be registered against the wrong interpreter). Chain the commands so
# `pip` and `python` resolve inside the `senn` environment.
!source /usr/local/miniconda/bin/activate senn && python -m pip install ipykernel && python -m ipykernel install --user --name=senn --display-name "Python (senn)"

In [None]:
# Add the style file: a user matplotlib style named "seaborn-paper" (presumably the name
# senn.utils.plot_utils passes to plt.style.use; recent matplotlib releases renamed the
# built-in seaborn styles — confirm).
# Written from Python rather than `!echo -e ...`: `-e` is not portable, and a POSIX
# /bin/sh `echo` would write the literal "-e" into the style file.
# Must run before matplotlib is imported, since user styles are read at import time.
from pathlib import Path

style_dir = Path.home() / ".config" / "matplotlib" / "stylelib"
style_dir.mkdir(parents=True, exist_ok=True)
style_file = style_dir / "seaborn-paper.mplstyle"
style_file.write_text("axes.titlesize: 18\naxes.labelsize: 14\nfigure.dpi: 100\n");


Import libraries  
The first part of the code is adapted from https://github.com/AmanDaVinci/SENN, in particular from its `report.ipynb` notebook.

In [None]:
import json
import torch
import numpy as np
import torch.nn as nn
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
from pathlib import Path
from types import SimpleNamespace
from importlib import import_module

%matplotlib inline

In [None]:
from senn.models.losses import *
from senn.models.parameterizers import *
from senn.models.conceptizers import *
from senn.models.aggregators import *
from senn.models.senn import SENN, DiSENN

In [None]:
from senn.datasets.dataloaders import get_dataloader
from senn.utils.plot_utils import show_explainations, show_prototypes, plot_lambda_accuracy, get_comparison_plot

Utility functions

In [None]:
def get_config(filename, config_dir='configs'):
    """Load a JSON experiment config and expose its top-level keys as attributes.

    Args:
        filename: Name of the JSON file inside ``config_dir``.
        config_dir: Directory holding the config files. Defaults to ``'configs'``,
            resolved relative to the current working directory (the notebook
            ``%cd``s into the repository root before calling this).

    Returns:
        A ``SimpleNamespace`` with one attribute per top-level JSON key.

    Raises:
        FileNotFoundError: if the config file does not exist.
        TypeError: if the JSON document is not an object (it is splatted into kwargs).
    """
    config_file = Path(config_dir) / filename
    # JSON is UTF-8 by specification (RFC 8259); don't depend on the platform locale.
    with open(config_file, 'r', encoding='utf-8') as f:
        config = json.load(f)

    return SimpleNamespace(**config)

In [None]:
def load_checkpoint(config):
    """Load the best saved checkpoint of the experiment named in ``config``.

    The file is read from ``results/<config.exp_name>/checkpoints/best_model.pt``
    (relative to the working directory), with its tensors mapped onto
    ``config.device``.

    Returns:
        Whatever object was stored with ``torch.save``; the notebook reads its
        ``'model_state'`` entry.
    """
    checkpoint_path = Path('results').joinpath(config.exp_name, "checkpoints", "best_model.pt")
    checkpoint = torch.load(checkpoint_path, map_location=config.device)
    return checkpoint

In [None]:
def accuracy(model, dataloader, config):
    """Print (and return) the classification accuracy of ``model`` on ``dataloader``.

    Accuracy is computed over samples (correct predictions / total samples), so a
    smaller final batch does not weigh as much as a full one. The model is switched
    to eval mode and left in it.

    Args:
        model: SENN-style model whose forward returns
            ``(y_pred, (concepts, relevances), extra)``; ``y_pred`` is reduced with
            ``argmax`` over dim 1 to get the predicted class.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        config: Namespace providing ``device``; inputs and labels are moved there.

    Returns:
        The accuracy as a float in [0, 1], or ``nan`` if the dataloader yields no samples.
    """
    n_correct = 0
    n_total = 0
    model.eval()
    with torch.no_grad():
        for x, labels in dataloader:
            x = x.float().to(config.device)
            labels = labels.long().to(config.device)
            y_pred, (_concepts, _relevances), _ = model(x)
            n_correct += (y_pred.argmax(dim=1) == labels).sum().item()
            n_total += labels.numel()
    if n_total == 0:
        # Avoid numpy's mean-of-empty warning; report explicitly instead.
        print("Test Mean Accuracy: no samples in dataloader")
        return float('nan')
    acc = n_correct / n_total
    print(f"Test Mean Accuracy: {acc * 100: .3f} %")
    return acc

Load MNIST data and config

In [None]:
mnist_config = get_config("config.json")
# load_checkpoint() maps tensors onto mnist_config.device and accuracy() moves inputs
# there, so a CUDA device on a CPU-only runtime would fail. Fall back to CPU in that
# case (replaces the former commented-out manual override).
if str(mnist_config.device).startswith("cuda") and not torch.cuda.is_available():
    mnist_config.device = "cpu"

In [None]:
# get_dataloader returns three loaders (presumably train / validation / test — confirm);
# only the last one, used for evaluation below, is kept.
_, _, mnist_test_dl = get_dataloader(mnist_config)

Load our trained model

In [None]:
# Build the model skeleton from the config: each component receives every config
# entry as a keyword argument. Its weights are filled in from the checkpoint afterwards.
# NOTE(review): the model is never moved to mnist_config.device here, while accuracy()
# sends its inputs to config.device — unless the components place themselves on that
# device, evaluation only works when device is "cpu". Confirm.
model_kwargs = vars(mnist_config)
conceptizer = ConvConceptizer(**model_kwargs)
parameterizer = ConvParameterizer(**model_kwargs)
aggregator = SumAggregator(**model_kwargs)

mnist_SENN = SENN(
    conceptizer,
    parameterizer,
    aggregator,
)

In [None]:
# Load the saved weights (tensors mapped onto mnist_config.device) into the freshly built model.
# load_state_dict is the cell's last expression, so Jupyter shows its key-matching report.
mnist_checkpoint = load_checkpoint(mnist_config)
mnist_SENN.load_state_dict(mnist_checkpoint['model_state'])

Test accuracy

In [None]:
# Classification accuracy of the loaded model on the MNIST test loader (printed by accuracy()).
accuracy(mnist_SENN, mnist_test_dl, mnist_config)

Show explanations

In [None]:
# Plot the model's explanations for examples from the test loader (presumably concept
# relevances per prediction — see senn.utils.plot_utils). The function name is spelled
# "explainations" in the library itself.
show_explainations(mnist_SENN, mnist_test_dl, 'mnist')

In [None]:
# Visualise concept prototypes drawn from the test loader; 'activation' selects how
# prototypes are chosen (presumably by highest concept activation — confirm in plot_utils).
show_prototypes(mnist_SENN, mnist_test_dl, 'activation')

**TODO** — not implemented yet:  
- Integrated Gradients  
- LIME  

AI explainability