In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from captum.attr import IntegratedGradients
import random
import numpy as np
import os
import logging

In [2]:
def set_seed(seed: int = 42) -> None:
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Set a fixed value for the hash seed
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")


In [3]:
# 4. Define the 2-layer neural network model
class Net(nn.Module):
    def __init__(self, num_in, num_hidden, num_out):
        super().__init__()
        set_seed(0)
        self.num_in = num_in
        self.num_hidden = num_hidden
        self.num_out = num_out
        self.lin1 = nn.Linear(num_in, num_hidden)
        self.lin2 = nn.Linear(num_hidden, num_hidden)
        self.lin3 = nn.Linear(num_hidden, num_out)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, input):
        lin1 = F.relu(self.lin1(input))
        lin2 = F.relu(self.lin2(lin1))
        lin3 = self.lin3(lin2)
        return self.softmax(lin3)



In [4]:
# create input and baseline
num_in = 50
inp = torch.arange(0.0, 1.0, 0.02, requires_grad=True).unsqueeze(0)
baseline = torch.zeros_like(inp, requires_grad=True)

In [5]:
def compeletness_test(model, inp, baseline, attribution, target_class_index, tolerance = 1e-3):
    f_inp = model(inp)
    f_baseline = model(baseline)
    
    diff = (f_inp[0][target_class_index] - f_baseline[0][target_class_index]).sum()
    attr_sum = attribution.sum()
    print(f"sum of attributions {attr_sum.item()}")
    print(f"difference of network output at input as baseline {diff.item()}")
    print(f"approximation error {torch.abs((diff-attr_sum)).item()}")
    
    assert torch.abs(attr_sum - diff) <= tolerance, "failed to pass completness axiom of integrated gradients"
    
    print(f"completness test: passed")

In [6]:
# captum IG method
model = Net(num_in, 20, 2)
target_class_index = 1

# applying integrated gradients on the SoftmaxModel and input data point
ig = IntegratedGradients(model)
attributions, approximation_error = ig.attribute(inp, target=target_class_index,
                                    return_convergence_delta=True)

compeletness_test(model, inp, baseline, attributions, target_class_index)

Random seed set as 0
sum of attributions 0.004937011425495175
difference of network output at input as baseline 0.004918217658996582
approximation error 1.879376649859335e-05
completness test: passed


In [7]:
from integrated_gradients_helmholtz import run_integrated_jacobian_scanvi

# Helmholtz method
model = Net(num_in, 20, 2)
num_in = 50
inp = torch.arange(0.0, 1.0, 0.02).unsqueeze(0)
baseline = torch.zeros_like(inp)

batches = [{"X":inp, "batch":1}]
target_class_index = 1
# applying integrated gradients on the SoftmaxModel and input data point
ig_helmholtz = run_integrated_jacobian_scanvi(model, batches, n_steps=50)
compeletness_test(model, inp, baseline, ig_helmholtz[..., target_class_index], target_class_index)

Random seed set as 0
sum of attributions 0.0018184136133641005
difference of network output at input as baseline 0.004918217658996582
approximation error 0.0030998040456324816


AssertionError: failed to pass completness axiom of integrated gradients

In [8]:
from integrated_gradients_paper import integrated_gradients, get_gradients_func
# paper method
num_in = 50
model = Net(num_in, 20, 2)
inp = torch.arange(0.0, 1.0, 0.02, requires_grad=True).unsqueeze(0)
baseline = torch.zeros_like(inp, requires_grad=True)
target_class_index = 1
# applying integrated gradients on the SoftmaxModel and input data point
ig_paper = integrated_gradients(
    inp, 
    target_class_index,
    get_gradients_func,
    baseline,
    model,
    steps=50)

compeletness_test(model, inp, baseline, ig_paper, target_class_index)

Random seed set as 0
sum of attributions 0.004963119514286518
difference of network output at input as baseline 0.004918217658996582
approximation error 4.4901855289936066e-05
completness test: passed


In [9]:
ig_paper

tensor([[ 0.0000e+00,  9.8100e-05, -2.0390e-04, -1.5306e-04,  3.2287e-04,
         -1.8544e-04, -3.9908e-05, -1.4650e-04,  4.9033e-04, -2.1914e-04,
          9.2016e-04,  1.0996e-03, -5.3383e-04,  2.6344e-04,  8.8657e-04,
         -1.5459e-03, -9.0614e-05, -5.8357e-04, -8.4320e-04, -1.4649e-04,
         -5.0780e-05, -1.9186e-03, -1.7589e-03,  1.1854e-03, -7.0733e-04,
          7.9240e-04,  1.7168e-03,  2.7485e-03,  3.1179e-04, -1.1015e-03,
         -1.7919e-04, -1.8624e-04,  8.9024e-04,  9.2335e-04,  2.5725e-04,
          1.6627e-03,  1.2797e-03, -3.3386e-04, -1.5338e-04,  6.7018e-04,
         -3.7697e-05,  3.2621e-03, -1.9959e-05, -2.8561e-04, -2.0533e-03,
          8.0743e-04,  2.3874e-03, -1.7385e-03, -1.7422e-03, -1.0547e-03]])

In [10]:
attributions

tensor([[ 0.0000e+00,  9.6289e-05, -1.9685e-04, -1.5218e-04,  3.1501e-04,
         -1.8034e-04, -4.1268e-05, -1.3794e-04,  4.8004e-04, -2.1699e-04,
          9.1168e-04,  1.0814e-03, -5.3959e-04,  2.7289e-04,  8.7898e-04,
         -1.5232e-03, -9.4753e-05, -5.5742e-04, -8.6005e-04, -1.4349e-04,
         -5.4616e-05, -1.9041e-03, -1.7386e-03,  1.1762e-03, -6.7799e-04,
          7.8964e-04,  1.6698e-03,  2.6912e-03,  3.6027e-04, -1.0905e-03,
         -1.6895e-04, -1.8129e-04,  8.5915e-04,  9.2207e-04,  2.6397e-04,
          1.6563e-03,  1.2744e-03, -3.6178e-04, -1.2589e-04,  6.5605e-04,
         -6.6314e-06,  3.1829e-03, -2.0143e-05, -2.3274e-04, -2.0229e-03,
          7.2358e-04,  2.3485e-03, -1.7478e-03, -1.6721e-03, -1.0231e-03]],
       dtype=torch.float64, grad_fn=<MulBackward0>)

In [11]:
ig_helmholtz[..., 1]

tensor([[-0.0000e+00,  3.4380e-05, -1.0807e-04, -6.4972e-05,  1.5246e-04,
         -1.1495e-04, -1.8439e-06, -1.3310e-04,  2.4499e-04, -8.6860e-05,
          5.0418e-04,  4.1214e-04, -3.3020e-04,  3.1484e-04,  4.8420e-04,
         -6.8562e-04, -1.4832e-05, -4.4818e-05, -1.1595e-04,  4.5136e-05,
         -3.9720e-05, -1.0080e-03, -9.8241e-04,  2.1138e-04, -6.1308e-04,
          3.4862e-04,  1.3197e-03,  8.1323e-04,  3.7401e-04, -5.5644e-04,
          4.1998e-04,  2.6268e-04,  6.1181e-04,  4.2450e-04,  1.3267e-04,
          5.8322e-04,  3.0566e-04, -7.8697e-04, -3.9634e-04,  5.6278e-04,
         -2.6588e-04,  1.5068e-03, -3.4203e-04,  3.2421e-04, -1.6146e-03,
          8.0128e-04,  1.3107e-03, -5.3677e-06, -1.4235e-03, -9.5160e-04]])

In [12]:
torch.sum(torch.abs(ig_paper - ig_helmholtz[..., 1]))

tensor(0.0226)

In [13]:
torch.sum(torch.abs(attributions - ig_helmholtz[..., 1]))

tensor(0.0220, dtype=torch.float64, grad_fn=<SumBackward0>)

In [14]:
torch.sum(torch.abs(attributions - ig_paper))

tensor(0.0010, dtype=torch.float64, grad_fn=<SumBackward0>)