In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import importlib

import cbig.osama2024.dataloader as dataloader
import cbig.osama2024.model as model
# Force a re-import of the project model module -- presumably so that edits
# made to it on disk are picked up without restarting the kernel.
# NOTE(review): a later cell rebinds the name `model` to a ToySoftmaxModel
# instance, so after that cell runs `model` no longer refers to this module.
importlib.reload(model)

#from cbig.osama2024.model import osama

import cbig.osama2024.misc as misc
from cbig.osama2024.model import MODEL_DICT


# load data
import pickle
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split  

## ToySoftmax model

In [2]:


class ToySoftmaxModel(nn.Module):
    r"""
    Small fully connected classifier that outputs class probabilities.

    Model architecture from:

    https://adventuresinmachinelearning.com/pytorch-tutorial-deep-learning/

    Two ReLU-activated hidden layers of width ``num_hidden`` feed a linear
    output layer whose scores are normalised with a softmax over the last
    (class) dimension.

    Args:
        num_in: number of input features (size of the last input dimension).
        num_hidden: width of both hidden layers.
        num_out: number of output classes.
    """

    def __init__(self, num_in, num_hidden, num_out):
        super().__init__()
        self.num_in = num_in
        self.num_hidden = num_hidden
        self.num_out = num_out
        self.lin1 = nn.Linear(num_in, num_hidden)
        self.lin2 = nn.Linear(num_hidden, num_hidden)
        self.lin3 = nn.Linear(num_hidden, num_out)
        # nn.Linear acts on the last dimension, so the class scores always
        # live there. Normalising over dim=-1 (instead of a hard-coded dim=1)
        # gives the same result for (batch, features) inputs and additionally
        # handles unbatched 1-D inputs and inputs with extra leading dims.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input):
        """
        Compute class probabilities for ``input``.

        Args:
            input: tensor whose last dimension has size ``num_in``.

        Returns:
            Tensor with the same leading dimensions as ``input`` and a last
            dimension of size ``num_out``; entries along the last dimension
            are non-negative and sum to 1.
        """
        lin1 = F.relu(self.lin1(input))
        lin2 = F.relu(self.lin2(lin1))
        lin3 = self.lin3(lin2)
        return self.softmax(lin3)

In [None]:
from captum.attr import IntegratedGradients

# Fix the RNG seed so the randomly initialised model weights -- and therefore
# the attributions computed below -- are the same on every Restart & Run All.
torch.manual_seed(0)

num_in = 40
# A single sample of shape (1, num_in) holding the values 0.0 ... num_in - 1.
# NOTE(review): `input` shadows the Python builtin; the name is kept because
# the following cell reads it.
input = torch.arange(0.0, num_in * 1.0, requires_grad=True).unsqueeze(0)

# 10-class classification model
# NOTE(review): this rebinds `model`, which the imports cell bound to the
# cbig.osama2024.model module; re-import the module if it is needed later.
model = ToySoftmaxModel(num_in, 20, 10)

# attribution score will be computed with respect to target class
target_class_index = 5

# applying integrated gradients on the SoftmaxModel and input data point
# (no baseline is passed, so Captum's default baseline is used).
# With return_convergence_delta=True the call returns a pair:
# the attributions and the approximation error (delta).
ig = IntegratedGradients(model)
attributions, approximation_error = ig.attribute(input, target=target_class_index,
                                    return_convergence_delta=True)

# The input and the attributions returned for it have the same shape and
# dimensionality.

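# Example output (illustrative only: the model weights are randomly
# initialised, so the actual numbers will differ):

# ...................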

# attributions: (tensor([[ 0.0000,  0.0014,  0.0012,  0.0019,  0.0034,  0.0020, -0.0041,  
#           0.0085, -0.0016,  0.0111, -0.0114, -0.0053, -0.0054, -0.0095,  0.0097, -0.0170,
#           0.0067,  0.0036, -0.0296,  0.0244,  0.0091, -0.0287,  0.0270,  0.0073,
#          -0.0287,  0.0008, -0.0150, -0.0188, -0.0328, -0.0080, -0.0337,  0.0422,
#           0.0450,  0.0423, -0.0238,  0.0216, -0.0601,  0.0114,  0.0418, -0.0522]],
#        grad_fn=<MulBackward0>),)

# approximation_error (aka delta): 0.00013834238052368164

# assert attributions.shape == input.shape

In [None]:
# Show the per-feature attributions followed by the convergence delta, one
# per line, then check that there is exactly one attribution per input entry.
print(attributions, approximation_error, sep="\n")
assert input.shape == attributions.shape