# Explainable AI

In [1]:
import os
import sys
# Make the project package (two directories up, in `fmriDEEP`) importable from this notebook.
module_path = os.path.abspath('../../fmriDEEP')
already_listed = module_path in sys.path
if not already_listed:
    sys.path.append(module_path)

In [2]:
import torch
import numpy as np
from _core.networks.ConvNets import Simple2dCnnClassifier
from _utils.tools import compute_accuracy
from torch.utils.data import DataLoader
from zennit import composites
from torchvision.datasets import MNIST, FashionMNIST
from torchvision.transforms import ToTensor

import matplotlib.pyplot as plt

# Compute device used for training: the first CUDA GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

In [3]:
# set the random seed for reproducibility
def set_random_seed(seed):
    """Seed PyTorch, Python's `random` and NumPy with `seed`.

    Returns a freshly created `torch.Generator` seeded with the same value; it can be
    passed to a DataLoader so that sample selection with shuffle=True is reproducible.
    """
    import random

    # global RNGs of the three libraries, seeded in the same order as before
    for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
        seed_fn(seed)

    # Generator.manual_seed returns the generator itself
    return torch.Generator().manual_seed(seed)

# Seed all RNGs with 42 and keep the seeded generator; the brain-data DataLoaders
# further down receive it via `generator=g`.
g = set_random_seed(42)

In [4]:
def train_network(model, n_epochs, lr=.01, device=torch.device("cpu"), save_name=None):
    """Train `model` for `n_epochs` epochs, evaluating on the test loader after each epoch.

    The data is NOT passed in: the function reads the notebook-level DataLoaders
    `dl_train` and `dl_test`, so both must be defined (and point at the intended
    dataset) before it is called.

    Parameters
    ----------
    model
        Object exposing `fit(loader, lr=..., device=..., train=...)`, which returns
        `(loss, stats)`, and `save(name, save_full=...)`.
    n_epochs : int
        Number of passes over `dl_train`.
    lr : float
        Learning rate forwarded to `model.fit` for the training pass.
    device : torch.device
        Device forwarded to `model.fit`.
    save_name : str or None
        Name handed to `model.save` once training has finished. With the default
        `None` the model is not saved.
    """
    # per-epoch history of losses and accuracies (used for the progress print below)
    train_loss = np.zeros(n_epochs)
    test_loss = np.zeros_like(train_loss)
    train_acc = np.zeros_like(train_loss)
    test_acc = np.zeros_like(train_loss)

    for epoch in range(0, n_epochs):

        # One training pass. NOTE(review): model.fit presumably dispatches to the model's
        # training function (`standard_train` from `_utils.train_fns`) - confirm in the model class.
        train_loss[epoch], train_stats = model.fit(dl_train, lr=lr, device=device)
        # the last two columns of the returned stats are handed to compute_accuracy
        train_acc[epoch] = compute_accuracy(train_stats[:, -1], train_stats[:, -2])

        # Evaluation pass: no gradients are tracked and fit is told not to train (train=False).
        with torch.no_grad():
            test_loss[epoch], test_stats = model.fit(dl_test, device=device, train=False)
            test_acc[epoch] = compute_accuracy(test_stats[:, -1], test_stats[:, -2])

        print('epoch=%03d, train_loss=%1.3f, train_acc=%1.3f, test_loss=%1.3f, test_acc=%1.3f' % 
             (epoch, train_loss[epoch], train_acc[epoch], test_loss[epoch], test_acc[epoch]))

    # Only persist the model when the caller asked for it; previously the default
    # save_name=None was passed straight to model.save().
    if save_name is not None:
        model.save(save_name, save_full=True)

In [None]:
# FashionMNIST train and test splits, stored under ./data (downloaded if needed) and
# passed through ToTensor(); both are served in batches of 256, only the training
# loader shuffles.
dataset_kwargs = dict(root="data", download=True, transform=ToTensor())

training_data = FashionMNIST(train=True, **dataset_kwargs)
test_data = FashionMNIST(train=False, **dataset_kwargs)

dl_train = DataLoader(training_data, batch_size=256, shuffle=True)
dl_test = DataLoader(test_data, batch_size=256, shuffle=False)

In [None]:
# Show the first image of the test split as a sanity check.
first_test_image = test_data.data[0]
plt.imshow(first_test_image, cmap='Greys')

In [None]:
# Fresh 2D CNN classifier. NOTE(review): the arguments are presumably (input_dims, n_classes),
# i.e. 28x28 images and 10 classes, mirroring BrainStateClassifier3d - confirm in ConvNets.
model = Simple2dCnnClassifier((28, 28), 10)

In [None]:
# Train for 10 epochs with the function defaults (lr=.01, CPU) on whatever
# dl_train / dl_test currently point at.
train_network(model, 10)

In [None]:
import copy
from zennit.attribution import Gradient

shape = (1, 1, 28, 28)

# Bounds handed to the composite: the lowest (0) and highest (1) pixel value for ZBox.
composite_kwargs = {
    'low': torch.zeros(*shape, device=torch.device("cpu")),
    'high': torch.ones(*shape, device=torch.device("cpu")),
}

test = composites.COMPOSITES['epsilon_gamma_box'](**composite_kwargs)

# Freeze the weights: only the gradient w.r.t. the input is needed below.
for param in model.parameters():
    param.requires_grad_(False)

# index (within each batch) of the image that gets plotted
pick_img = 1

fig, axes = plt.subplots(1, 10, figsize=(20, 8), sharex=True, sharey=True)

# One subplot per batch: the first 10 batches of dl_test are explained and image
# `pick_img` of each is drawn with its attribution overlaid. Pairing the axes with the
# loader stops the loop after 10 batches.
# (zennit.attribution.Gradient, imported above, was tried as an alternative and is not used.)
with test.context(model) as modified_model:
    for ax, (data, target) in zip(axes, dl_test):
        data_with_grad = data.clone().requires_grad_()

        # one row of a 10x10 identity matrix per sample, selected by the target label
        output_relevance = torch.eye(10, device=torch.device("cpu"))[target]

        out = modified_model(data_with_grad)
        predicted = out.detach().numpy().argmax(axis=1)

        # the backward pass through the modified model leaves the attribution in .grad
        torch.autograd.backward(out, output_relevance)

        image = data[pick_img].squeeze().cpu().numpy()
        relevance = data_with_grad.grad[pick_img].squeeze().cpu().numpy()

        ax.imshow(image, cmap='Greys')
        ax.imshow(relevance, cmap='coolwarm', alpha=.5)
        # title reads "<predicted>-<target>"
        ax.set_title(f'{predicted[pick_img]}-{target[pick_img]}')

## Explain brain data

In [5]:
import nibabel as nib
from _utils.train_fns import standard_train
from _core.networks.ConvNets import BrainStateClassifier3d
from _core.datasets.NiftiDataset import NiftiDataset
import _utils.tools as utils

In [6]:
# The five motor conditions; their order here is the order passed to the datasets and
# used for len(labels) when the classifier is built.
labels = ['handleft', 'handright', 'footleft', 'footright', 'tongue']

# NOTE(review): the third NiftiDataset argument (150 / 20) is presumably the number of
# volumes to use - confirm in NiftiDataset. Both loaders shuffle with the shared seeded
# generator `g` and serve batches of 4.
train_set = NiftiDataset('data/brain_data/train', labels, 150, DEVICE, transform=utils.ToTensor())
dl_train = DataLoader(train_set, batch_size=4, shuffle=True, generator=g)

test_set = NiftiDataset('data/brain_data/test', labels, 20, DEVICE, transform=utils.ToTensor())
dl_test = DataLoader(test_set, batch_size=4, shuffle=True, generator=g)

In [11]:
# 3D classifier for 91x109x91 volumes with one output per entry in `labels`, moved to DEVICE.
# This rebinds `model`, replacing the 2D FashionMNIST classifier from the first part.
model = BrainStateClassifier3d((91,109,91), len(labels)).to(DEVICE)

In [12]:
# Inspect the module's training flag; the recorded output below shows True for this model.
model.training

True

In [None]:
# Train the 3D classifier for 100 epochs with lr=1e-5 on DEVICE, using the brain-data
# loaders now bound to dl_train / dl_test, then save it under the name 'motor-mapper'.
train_network(model, 100, lr=.00001, device=DEVICE, save_name='motor-mapper')

In [19]:
# Reload the trained model from 'motor-mapper/model.pth' onto the CPU.
# map_location remaps the stored tensors while loading, so a checkpoint saved from a
# CUDA model can also be opened on a machine without a GPU (a plain torch.load followed
# by .to("cpu") fails there before the move happens).
# SECURITY: torch.load unpickles arbitrary Python objects - only load checkpoints that
# you created yourself or otherwise trust.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True, which
# rejects fully pickled modules; if loading fails there, pass weights_only=False
# explicitly - confirm against the installed version.
model = torch.load('motor-mapper/model.pth', map_location=torch.device("cpu"))

In [10]:
shape = (1, 1, 91, 109, 91)

# Bounds handed to the composite: the lowest (0) and highest (1) input value for ZBox.
# NOTE(review): this assumes the brain volumes fed to the model lie in [0, 1] - confirm
# against what NiftiDataset returns.
composite_kwargs = {
    'low': torch.zeros(*shape, device=torch.device("cpu")),
    'high': torch.ones(*shape, device=torch.device("cpu")),
}

test = composites.COMPOSITES['epsilon_gamma_box'](**composite_kwargs)

# Freeze the weights: only the gradient w.r.t. the input volume is needed afterwards.
for param in model.parameters():
    param.requires_grad_(False)

In [None]:
# Relevance maps for the test volumes, one condition (label) at a time.
# For every label j the per-volume attributions and their mean are written to
# 'motor-mapper/lrp_<label>.nii.gz' and 'motor-mapper/lrp_avg_<label>.nii.gz'.
cpu = torch.device("cpu")
n_classes = model.config['n_classes']

# Attribution must run with dropout inactive (model.config lists dropout=0.5),
# otherwise the maps differ from run to run.
model.eval()

for j in range(len(labels)):
    # NOTE(review): the meaning of the 5th positional NiftiDataset argument (3) is not
    # visible here - confirm in NiftiDataset.
    dl = DataLoader(
        NiftiDataset('data/brain_data/test', [labels[j]], 20, cpu, 3),
        batch_size=1, shuffle=False, num_workers=0)

    avg = np.zeros((91, 109, 91))
    indi = np.zeros((91, 109, 91, len(dl)))

    # The composite context is entered once per label, outside the per-volume loop, so
    # the canonizers and hooks are not registered and removed for every single volume.
    with test.context(model) as modified_model:
        for i, (volume, _target) in enumerate(dl):
            data_norm = volume.float().to(cpu)
            data_norm.requires_grad_()

            # One-hot relevance of shape (batch, n_classes) selecting the output of class j.
            # The loader above is built from the single-element list [labels[j]], so the
            # target it yields is not guaranteed to be this label's index in the full
            # `labels` list the model was trained with; j is that index.
            class_index = torch.full((data_norm.shape[0],), j, dtype=torch.long)
            output_relevance = torch.eye(n_classes, device=cpu)[class_index]

            out = modified_model(data_norm)
            # a simple backward pass will accumulate the relevance in data_norm.grad
            torch.autograd.backward((out,), (output_relevance,))
            indi[:, :, :, i] = data_norm.grad.squeeze().cpu().numpy()
            avg += indi[:, :, :, i]

    avg /= len(dl)
    utils.save_in_mni(indi, os.path.join('motor-mapper', 'lrp_%s.nii.gz' % labels[j]))
    utils.save_in_mni(avg, os.path.join('motor-mapper', 'lrp_avg_%s.nii.gz' % labels[j]))

In [15]:
# Display the configuration dictionary stored on the model (architecture settings,
# input dimensions, number of classes, training function); see the output below.
model.config

{'channels': [1, 8, 16, 32, 64],
 'kernel_size': 5,
 'pooling_kernel': 2,
 'lin_neurons': [128, 64],
 'dropout': 0.5,
 'input_dims': (91, 109, 91),
 'n_classes': 5,
 'train_fn': <function _utils.train_fns.standard_train(model, train_data: torch.utils.data.dataloader.DataLoader, loss_fn=CrossEntropyLoss(), optimizer=<class 'torch.optim.adam.Adam'>, lr: float = 1e-05, device: torch.device = device(type='cpu'), train=True, **optimizer_kwargs) -> Tuple[numpy.ndarray, numpy.ndarray]>,
 'last_cnn_dims': [5, 6, 5]}