In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from captum.attr import IntegratedGradients
import random
import numpy as np
import os
import logging

In [3]:
def set_seed(seed: int = 42) -> None:
    """Seed every random number generator this notebook touches.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    torch (CPU generator plus the current CUDA device), puts cuDNN into
    deterministic mode, records the seed in ``PYTHONHASHSEED`` and prints a
    confirmation line.

    Args:
        seed: Value fed to every generator. Defaults to 42.
    """
    # Library-level generators.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # cuDNN: force deterministic kernels and turn off autotuning, which could
    # otherwise pick different algorithms from run to run.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

    # Recorded for any interpreter launched later; the running process's
    # hash randomization is fixed at interpreter startup.
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")

In [4]:
# 4. Define the neural network model: three linear layers (two ReLU hidden layers) with a softmax output
class Net(nn.Module):
    """Fully connected classifier with two ReLU hidden layers and a softmax head.

    Layout: ``Linear(num_in, num_hidden) -> ReLU -> Linear(num_hidden,
    num_hidden) -> ReLU -> Linear(num_hidden, num_out) -> Softmax(dim=1)``.

    The constructor calls ``set_seed(0)`` before creating any layer, so every
    instance with the same sizes starts from identical weights (and a
    "Random seed set as 0" line is printed on each construction).

    Args:
        num_in: Number of input features.
        num_hidden: Width of both hidden layers.
        num_out: Number of output classes.
    """

    def __init__(self, num_in, num_hidden, num_out):
        super().__init__()
        # Reseed so weight initialisation is reproducible across instances.
        set_seed(0)
        self.num_in, self.num_hidden, self.num_out = num_in, num_hidden, num_out
        # Layers are created in this exact order: each consumes the seeded RNG
        # stream, so reordering them would change the initial weights.
        self.lin1 = nn.Linear(num_in, num_hidden)
        self.lin2 = nn.Linear(num_hidden, num_hidden)
        self.lin3 = nn.Linear(num_hidden, num_out)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, input):
        """Map a batch of features to class probabilities.

        Softmax is taken over dim 1, which matches the 2-D
        ``(batch, num_in)`` inputs used in this notebook; the result is then
        ``(batch, num_out)``.
        """
        hidden = F.relu(self.lin1(input))
        hidden = F.relu(self.lin2(hidden))
        logits = self.lin3(hidden)
        return self.softmax(logits)

def get_model(num_in=50, num_classes=3):
    """Build a ``Net`` whose two hidden layers are 20 units wide.

    Args:
        num_in: Number of input features (default 50).
        num_classes: Number of output classes (default 3).

    Returns:
        A freshly constructed ``Net``; its weights are seeded inside
        ``Net.__init__``.
    """
    hidden_width = 20
    return Net(num_in, hidden_width, num_classes)

def abs_diff(a, b):
    """Return the L1 distance ``sum(|a - b|)`` between two tensors as a Python float."""
    return (a - b).abs().sum().item()

In [5]:
# Toy input for the attribution tests: the ramp 0, step, 2*step, ... laid out
# as a `dim`-shaped batch, plus an all-zero baseline of the same shape.
dim = (5, 20)
num_elements = dim[0] * dim[1]
# Derive the step from `dim` instead of hard-coding 1/100, so the .view()
# below stays valid if `dim` changes.
step = 1.0 / num_elements
# Build the ramp from an integer index so it has exactly `num_elements`
# entries; a float-step arange can produce one extra element when 1/step
# does not round-trip exactly.
inp = (torch.arange(num_elements, dtype=torch.float64) * step).to(torch.get_default_dtype())
inp = inp.view(dim).requires_grad_()
baseline = torch.zeros_like(inp, requires_grad=True)

inp.shape, baseline.shape

(torch.Size([5, 20]), torch.Size([5, 20]))

In [6]:
def _compeletness_test(model, inp, baseline, attribution, target_class_index, tolerance = 1e-3):
    f_inp = model(inp)
    f_baseline = model(baseline)
    
    diff = f_inp - f_baseline
    num_predictions = diff.shape[0]
    # loop over all predictions 
    for prediction_id in range(num_predictions):
        diff_sum = diff[prediction_id][target_class_index]
        attribution_sum = attribution[prediction_id].sum()
        
        if (np.isclose(attribution_sum.item(), diff_sum.item(), atol=tolerance)):
            print(f"completness test: passed", abs_diff(diff_sum, attribution_sum))
        else:
            print(f"completness test: xxfailedxx", abs_diff(diff_sum, attribution_sum))


In [7]:

def test_captum(num_classes = 3, inputs=None, baselines=None):
    """Run Captum's IntegratedGradients for each class and check completeness.

    Args:
        num_classes: Number of model output classes; one attribution run is
            done per class.
        inputs: Input batch. Defaults to the notebook-level ``inp``.
        baselines: Baseline batch. Defaults to the notebook-level ``baseline``.

    Returns:
        List with one attribution tensor per class, in class order.
    """
    inputs = inp if inputs is None else inputs
    baselines = baseline if baselines is None else baselines

    # Size the model from the data and the requested class count. A hard-coded
    # class count makes targets >= 3 index out of range.
    model = get_model(inputs.shape[-1], num_classes)
    # The explainer only wraps the model, so one instance serves every class.
    ig = IntegratedGradients(model)
    attrs = []
    for class_idx in range(num_classes):
        print(f'test for class id {class_idx} out of {num_classes} classes')
        # Pass the baseline explicitly so attribution and completeness check
        # use the same reference point.
        attributions = ig.attribute(inputs, baselines=baselines, target=class_idx)
        print('<<<', attributions.shape)
        attrs.append(attributions)
        _compeletness_test(model, inputs, baselines, attributions, class_idx)
        print('-' * 20)
    return attrs

captum_attrs = test_captum()

Random seed set as 0
test for class id 0 out of 3 classes
<<< torch.Size([5, 20])
completness test: passed 1.2736224113637975e-05
completness test: passed 4.937518793686653e-05
completness test: passed 8.622671248021373e-05
completness test: passed 0.00014114316453869673
completness test: passed 0.0002859088017215534
--------------------
test for class id 1 out of 3 classes
<<< torch.Size([5, 20])
completness test: passed 3.3146861255118395e-05
completness test: passed 3.483854318950003e-05
completness test: passed 5.550384621842014e-05
completness test: passed 0.00014157511015491786
completness test: passed 0.0004490187872618692
--------------------
test for class id 2 out of 3 classes
<<< torch.Size([5, 20])
completness test: passed 2.0440318786720153e-05
completness test: passed 1.4566380163906142e-05
completness test: passed 3.075261075933394e-05
completness test: passed 4.0210930551953794e-07
completness test: passed 0.00016308114762789497
--------------------


In [8]:
from integrated_gradients_paper import integrated_gradients, get_gradients_func


def test_paper(num_classes = 3, inputs=None, baselines=None):
    """Run the paper's reference IG implementation per class and check completeness.

    Args:
        num_classes: Number of model output classes; one attribution run is
            done per class.
        inputs: Input batch. Defaults to the notebook-level ``inp``.
        baselines: Baseline batch. Defaults to the notebook-level ``baseline``.

    Returns:
        List with one attribution tensor per class, in class order.
    """
    inputs = inp if inputs is None else inputs
    baselines = baseline if baselines is None else baselines

    # Size the model from the data and the requested class count instead of
    # hard-coding 20 inputs / 3 classes.
    model = get_model(inputs.shape[-1], num_classes)
    attrs = []
    for class_idx in range(num_classes):
        print(f'test for class id {class_idx} out of {num_classes} classes')
        # Paper implementation of integrated gradients on the softmax model.
        # The positional argument order is kept as in the original call to
        # integrated_gradients_paper.integrated_gradients.
        ig_paper = integrated_gradients(
            inputs,
            class_idx,
            get_gradients_func,
            baselines,
            model,
            steps=50)
        print('>>>', ig_paper.shape)
        attrs.append(ig_paper)
        _compeletness_test(model, inputs, baselines, ig_paper, class_idx)
        print('-' * 20)
    return attrs

paper_attrs = test_paper()

Random seed set as 0
test for class id 0 out of 3 classes
grads torch.Size([51, 5, 20])
>>> torch.Size([5, 20])
completness test: passed 1.9069993868470192e-06
completness test: passed 5.526444874703884e-05
completness test: passed 5.243765190243721e-05
completness test: passed 7.273256778717041e-05
completness test: passed 0.0002926653251051903
--------------------
test for class id 1 out of 3 classes
grads torch.Size([51, 5, 20])
>>> torch.Size([5, 20])
completness test: passed 4.033674485981464e-06
completness test: passed 5.8634206652641296e-05
completness test: passed 7.259286940097809e-05
completness test: passed 0.00028870999813079834
completness test: passed 0.00018340349197387695
--------------------
test for class id 2 out of 3 classes
grads torch.Size([51, 5, 20])
>>> torch.Size([5, 20])
completness test: passed 5.910871550440788e-06
completness test: passed 3.398861736059189e-06
completness test: passed 2.0128674805164337e-05
completness test: passed 0.00021600071340799332


In [14]:
from integrated_gradients_helmholtz import run_integrated_jacobian_scanvi

# Helmholtz method: rebuild the same ramp input and zero baseline (this time
# without requires_grad), then split the input into one single-row batch dict
# per sample.
dim = (5, 20)
num_elements = dim[0] * dim[1]
# Derive the step from `dim` so the .view() below stays valid if `dim` changes.
step = 1.0 / num_elements
# Integer-index ramp: always exactly `num_elements` entries, unlike a
# float-step arange that can overshoot by one.
inp = (torch.arange(num_elements, dtype=torch.float64) * step).to(torch.get_default_dtype())
inp = inp.view(dim)
baseline = torch.zeros_like(inp)

# One dict per sample: "X" is that sample as a (1, num_features) tensor and
# "batch" is its row index. This is presumably the format
# run_integrated_jacobian_scanvi expects (TODO confirm against that module).
batches = [{"X": inp[i].unsqueeze(0), "batch": i} for i in range(inp.shape[0])]
print(inp.shape)
def test_helmholtz(num_classes = 3):
    """Run the Helmholtz integrated-Jacobian method and check completeness per class.

    Uses the notebook-level ``inp``, ``baseline`` and ``batches``.

    Args:
        num_classes: Number of model output classes to check.

    Returns:
        The tensor returned by ``run_integrated_jacobian_scanvi``.
    """
    # Size the model from the data and the requested class count instead of
    # hard-coding 20 inputs / 3 classes.
    model = get_model(inp.shape[-1], num_classes)
    # The integrated Jacobian does not depend on the class being checked, so
    # compute it once. This assumes run_integrated_jacobian_scanvi is
    # deterministic and does not mutate `model` (TODO confirm).
    ig_helmholtz = run_integrated_jacobian_scanvi(model, batches, n_steps=50)
    for class_idx in range(num_classes):
        print(f'test for class id {class_idx} out of {num_classes} classes')
        print('>', ig_helmholtz.shape)
        # NOTE(review): the recorded output shape is (15, 20, 20) for a
        # (5, 20) input and 3 classes. `[..., class_idx]` therefore selects
        # along a size-20 axis, and its 15 rows do not line up with the
        # 5 predictions. This is a likely cause of the failing checks; verify
        # the layout returned by run_integrated_jacobian_scanvi.
        _compeletness_test(model, inp, baseline, ig_helmholtz[..., class_idx], class_idx)
        print('-' * 20)
    return ig_helmholtz

# Assign the result so Jupyter does not dump the full tensor repr.
helmholtz_attrs = test_helmholtz()

torch.Size([5, 20])
Random seed set as 0
test for class id 0 out of 3 classes
> torch.Size([15, 20, 20])
completness test: xxfailedxx 0.005629558581858873
completness test: xxfailedxx 0.011026146821677685
completness test: xxfailedxx 0.002407208550721407
completness test: xxfailedxx 0.003915990702807903
completness test: xxfailedxx 0.03486831486225128
--------------------
test for class id 1 out of 3 classes
> torch.Size([15, 20, 20])
completness test: xxfailedxx 0.005276530981063843
completness test: xxfailedxx 0.010742351412773132
completness test: xxfailedxx 0.006659789942204952
completness test: xxfailedxx 0.027215586975216866
completness test: xxfailedxx 0.0380849689245224
--------------------
test for class id 2 out of 3 classes
> torch.Size([15, 20, 20])
completness test: xxfailedxx 0.010305995121598244
completness test: xxfailedxx 0.005571513902395964
completness test: xxfailedxx 0.01960938610136509
completness test: xxfailedxx 0.008095739409327507
completness test: xxfailedxx 