## Motivation for Signal Estimation and Attribution

Once we have used UntangleAI's uncertainty estimation to identify test points on which the model is uncertain, we would like to know why the model's decision is uncertain for those points. Relevant questions at this stage include:
 - What are the salient features in the input that the model relies on to make its decision?
 - How much relative importance does the model give to each of these features when framing its decision?
 - How robust is the model to adversarial attacks? Can a slight change to the input produce a drastically different decision?

As part of the UntangleAI Signal Estimation service, we provide these insights about the model, along with a way to visualize its decision process for computer vision applications. The service tries to separate the signal component from the noise in the training data and learns, for each class, which signals/features the model relies on and what it discards as noise.
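To make the first two questions concrete, here is the simplest attribution baseline, vanilla gradient saliency: the magnitude of the gradient of one class score with respect to each input element. This is a standard technique swapped in purely for illustration. It is not UntangleAI's signal estimation, and the helper name `gradient_saliency` is hypothetical.

```python
import torch
import torch.nn as nn


def gradient_saliency(model: nn.Module, x: torch.Tensor, target_class: int) -> torch.Tensor:
    """Hypothetical helper: |d logit_target / d x| for one batched input x of shape (1, ...)."""
    model.eval()
    x = x.clone().detach().requires_grad_(True)
    logits = model(x)                   # shape (1, num_classes)
    logits[0, target_class].backward()  # gradient of the chosen class score w.r.t. x
    return x.grad.detach().abs()[0]     # drop the batch dimension


# Toy check: for a linear model the saliency is exactly |W[target_class]|.
torch.manual_seed(0)
toy = nn.Linear(4, 3)
x = torch.randn(1, 4)
sal = gradient_saliency(toy, x, target_class=2)
print(sal)
```

Larger values mark input elements whose small changes move the class score most, which is one way to read "relative importance". For the MNIST model below, `x` would be a `(1, 1, 28, 28)` image and the result a `(1, 28, 28)` map.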

## Step 0 - Training a CNN to recognize the MNIST dataset

This step is optional. If you would like to train a CNN to recognize MNIST digits yourself, refer to [this tutorial](/tutorials/mnist_model_training), which trains a model for 10 epochs and saves the trained weights to `lenet_mnist_model.h5`.

Alternatively, you can download the trained weights from [here](https://untanglemodels.s3.amazonaws.com/lenet_mnist_model.h5).
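If you prefer to fetch the weights from the notebook itself, a minimal sketch using only the standard library is shown below. The helper name `download_if_missing` is hypothetical; it downloads the file only when it is not already present, so re-running the cell does not fetch it again.

```python
import os
import urllib.request

CKPT_URL = 'https://untanglemodels.s3.amazonaws.com/lenet_mnist_model.h5'


def download_if_missing(url: str, path: str) -> str:
    """Hypothetical helper: download `url` to `path` unless a file already exists there."""
    if not os.path.exists(path):
        urllib.request.urlretrieve(url, path)
    return path


# Example (performs a network request the first time it runs):
# ckpt_path = download_if_missing(CKPT_URL, 'lenet_mnist_model.h5')
```

Despite the `.h5` extension, the notebook later loads this file with `torch.load`, so it is expected to be a PyTorch state dict rather than an HDF5/Keras file.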

## Step 1 - Signal Estimation for each class

In this phase, we take the trained model and the training dataset, batched by class, estimate the signals the model has learned for each class during training, and save the resulting signal estimation statistics for each class.
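The exact statistics UntangleAI computes are not described here, so the sketch below is only a hypothetical stand-in that shows the shape of such a computation: iterate over a per-class loader, compute a per-sample quantity, aggregate it per class, and save the result. The statistic chosen here (the mean gradient of the class score with respect to the input) and the function name `per_class_mean_gradient` are assumptions for illustration, not UntangleAI's method.

```python
import torch
import torch.nn as nn


def per_class_mean_gradient(model, loader_fn, num_classes, save_path=None):
    """Hypothetical: for each class c, average d logit_c / d x over the samples from loader_fn(c)."""
    model.eval()
    stats = {}
    for c in range(num_classes):
        total, count = None, 0
        for x, _ in loader_fn(c):
            x = x.clone().detach().requires_grad_(True)
            logits = model(x)  # (batch, num_classes)
            # In eval mode the samples of a batch do not interact, so the gradient
            # of the summed scores gives each sample its own gradient.
            (grad,) = torch.autograd.grad(logits[:, c].sum(), x)
            batch_sum = grad.sum(dim=0)
            total = batch_sum if total is None else total + batch_sum
            count += x.shape[0]
        if count:  # skip classes with no samples
            stats[c] = total / count
    if save_path is not None:
        torch.save(stats, save_path)  # persist the per-class statistics
    return stats


# Toy usage: a linear "model" and two small batches per class.
torch.manual_seed(0)
toy = nn.Linear(4, 3)


def toy_loader(c):
    return [(torch.randn(5, 4), torch.full((5,), c)),
            (torch.randn(2, 4), torch.full((2,), c))]


stats = per_class_mean_gradient(toy, toy_loader, num_classes=3)
```

For a linear model the gradient of class score c is the weight row `W[c]` for every input, so each `stats[c]` equals `toy.weight[c]`; for a nonlinear model such as LeNet the per-class averages differ from class to class.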

```python
# Required imports
import os

import torch
import torch.nn as nn
from tqdm import tqdm
from untangle import UntangleAI

torch.set_printoptions(precision=8)

# Seed PyTorch for reproducible results
torch.manual_seed(42)
torch.cuda.manual_seed(42)
torch.backends.cudnn.deterministic = True
```

Define the same architecture that was used for training, then load the weights from the trained or downloaded checkpoint file.

```python
# Use the same model architecture that was used for training
class LeNet(nn.Module):
    # TODO: This isn't really a LeNet, but we implement this to be
    #  consistent with the Evidential Deep Learning paper
    def __init__(self):
        super(LeNet, self).__init__()
        # Feature extractor: two conv -> ReLU -> max-pool stages
        self.features = nn.Sequential(
            nn.Conv2d(1, 20, kernel_size=(5, 5)),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 2), stride=2),
            nn.Conv2d(20, 50, kernel_size=(5, 5)),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 2), stride=2),
        )
        # Classifier: a 28x28 input leaves 50 feature maps of size 4x4
        self.classifier = nn.Sequential(
            nn.Linear(4 * 4 * 50, 500),
            nn.ReLU(inplace=True),
            nn.Linear(500, 10),
        )

    def forward(self, x):
        output = self.features(x)
        output = output.view(x.shape[0], -1)  # flatten to (batch, 800)
        return self.classifier(output)


model_ckpt_path = 'lenet_mnist_model.h5'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = LeNet()
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine
ckpt = torch.load(model_ckpt_path, map_location=device)
model.load_state_dict(ckpt)
model = model.to(device)
model.eval()
```
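The `4*4*50` input size of the first `Linear` layer follows from the standard output-size formula for a convolution or pooling layer, floor((n + 2p - k) / s) + 1, with input size n, padding p, kernel size k and stride s. The small check below applies it layer by layer to a 28×28 MNIST image; `conv_out` is a helper written only for this illustration.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of a conv/pool layer along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


side = 28                      # MNIST images are 28x28
side = conv_out(side, 5)       # Conv2d(1, 20, 5)   -> 24
side = conv_out(side, 2, 2)    # MaxPool2d(2, 2)    -> 12
side = conv_out(side, 5)       # Conv2d(20, 50, 5)  -> 8
side = conv_out(side, 2, 2)    # MaxPool2d(2, 2)    -> 4
flat = 50 * side * side        # 50 channels of 4x4 maps
print(side, flat)              # 4 800
```

If you change the input resolution or the kernel sizes, this is the number the first `Linear` layer and the `view` in `forward` must agree on.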

Next, define the arguments needed for signal estimation.

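```python
class SignalEstimationArgs:
    mname = 'lenet'            # model name, used to name the signal store
    batch_size = 16
    num_classes = 10
    img_size = (1, 28, 28)     # (C, H, W)
    input_tensor = torch.randn(1, 1, 28, 28)    # provide your own input tensor, shape (1, C, H, W)
    input_tensor_true = torch.randn(28, 28, 1)  # provide your own true input tensor / ndarray / PIL Image obj, shape (H, W, C)
    data_class = None          # a class index to estimate a single class, or `None` to estimate all classes
    mode = 'estimate'          # one of `estimate`, `attribute`
    topk = 1                   # number of top predicted classes to attribute
    cmap = 'seismic'           # matplotlib colormap name, presumably used for the heatmaps
    json = False
    hm_diff = 'joint'

args = SignalEstimationArgs()
```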

Create the directories that will hold the estimated signal statistics (used later to attribute signals for a given test point) and the attribution results.

```python
# Parent directory of the current working directory
module_path = os.path.dirname(os.path.realpath('.'))
proj_path = os.path.abspath(os.path.join(module_path, os.pardir))
model_signal_data_path = os.path.join(module_path, 'model_signal_data/')
results_path = os.path.join(module_path, 'results')

# Create the output directories if they do not exist yet
os.makedirs(model_signal_data_path, exist_ok=True)
os.makedirs(results_path, exist_ok=True)

# Location where the per-class signal statistics are stored
signal_store_path = os.path.join(model_signal_data_path, '{}_signals'.format(args.mname))

# Create the UntangleAI object
untangle_ai = UntangleAI()
```

Call the UntangleAI API `estimate_signals` to learn and store the signal estimation statistics. It needs a function that returns a data loader over the training data of one class at a time. For MNIST, we provide the API `load_mnist_per_class` for this purpose.

```python
def train_loader_fun(class_i):
    # Return a loader over the MNIST training samples of class `class_i`
    loader, _ = untangle_ai.load_mnist_per_class(batch_size=args.batch_size, data_class=class_i)
    return loader

untangle_ai.estimate_signals(model, signal_store_path, train_loader_fun, args=args)
```
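For datasets other than MNIST you need your own equivalent of `train_loader_fun`. A minimal sketch with plain PyTorch utilities is shown below: it selects the indices of one class with `torch.utils.data.Subset` and wraps them in a `DataLoader`. The factory name `make_per_class_loader_fn` is hypothetical, and random tensors stand in for real images.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset


def make_per_class_loader_fn(dataset, labels, batch_size):
    """Hypothetical factory: returns loader_fn(class_i) -> DataLoader over samples of class_i."""
    labels = torch.as_tensor(labels)

    def loader_fn(class_i):
        idx = (labels == class_i).nonzero(as_tuple=True)[0].tolist()
        return DataLoader(Subset(dataset, idx), batch_size=batch_size, shuffle=False)

    return loader_fn


# Toy data standing in for MNIST: 20 single-channel 28x28 images with labels 0..9.
images = torch.randn(20, 1, 28, 28)
targets = torch.arange(20) % 10
loader_fn = make_per_class_loader_fn(TensorDataset(images, targets), targets, batch_size=16)
batch_x, batch_y = next(iter(loader_fn(3)))
```

With torchvision installed, `datasets.MNIST(root, train=True, download=True, transform=transforms.ToTensor())` could be passed as `dataset` and its `targets` attribute as `labels`. Whether that matches the preprocessing used by `load_mnist_per_class` (for example, normalization) is not documented here; the model should see inputs preprocessed the same way as during training.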

## Step 2 - Attributing signals for a test point

Now we use the signal statistics modelled in Step 1 to attribute and visualize the signals for an input test point.

We use the UntangleAI API `attribute_signals` to get signal information and visualizations with respect to the top k classes predicted by the model. As part of this API, we also generate a joint heatmap and differential heatmap images for the top k classes.
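The top k classes are the k classes with the highest predicted scores for the test point. To check which classes the generated images refer to, you can compute them directly from the model output. The helper `topk_classes` is hypothetical, and we assume (without having verified it) that the file index used by `attribute_signals` follows the same ranking, with index 0 being the top-1 prediction.

```python
import torch


def topk_classes(logits: torch.Tensor, k: int):
    """Hypothetical helper: (class indices, probabilities) of the k highest-scoring classes for one input."""
    probs = torch.softmax(logits, dim=-1)
    values, indices = torch.topk(probs, k, dim=-1)  # sorted in descending order
    return indices[0].tolist(), values[0].tolist()


logits = torch.tensor([[0.1, 2.0, -1.0, 3.0, 0.5]])
classes, probs = topk_classes(logits, k=3)
print(classes)  # [3, 1, 4]
```

For the MNIST model, `with torch.no_grad(): print(topk_classes(model(input_tensor_single), k=3))` would list the digits being explained, after moving the input to the model's device if the model runs on a GPU.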

This API works on one data point at a time. It takes two versions of the same point: `input_tensor`, the model input with shape (1, Channel, Height, Width), and `input_tensor_true`, the image itself with shape (Height, Width, Channel).
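The two layouts differ only in where the channel axis sits and whether a batch axis is present. The sketch below derives both from one image of a batched `(N, C, H, W)` tensor, mirroring the indexing used in the visualization code below; the random batch stands in for a batch from a real loader.

```python
import torch

# A batch in the model's layout: (batch, channel, height, width).
batch = torch.rand(16, 1, 28, 28)
idx = 0

input_tensor_single = batch[idx].unsqueeze(0)    # (1, C, H, W): model input
input_tensor_true = batch[idx].permute(1, 2, 0)  # (H, W, C): image layout

print(tuple(input_tensor_single.shape), tuple(input_tensor_true.shape))
# (1, 1, 28, 28) (28, 28, 1)
```

`batch[idx].unsqueeze(0)` is equivalent to the `batch[idx][None, :, :, :]` indexing used below; `permute` only reorders axes, so both tensors hold the same pixel values.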

Let us visualize the signals/features considered important for a random data point. For the top 3 classes predicted by the model, we show the heatmap of each class, along with the differential and inverse differential heatmaps.

```python
import random
import matplotlib.pyplot as plt

# scipy.misc.imread has been removed from SciPy; plt.imread reads JPEG files via Pillow
def show_image(path):
    print('Visualizing {}'.format(path))
    plt.imshow(plt.imread(path))
    plt.show()

# Map class IDs to display names; for MNIST the name is the digit itself
keys = [str(item) for item in range(args.num_classes)]
ID2Name_Map = dict(zip(keys, keys))

rand_class = random.randint(0, args.num_classes - 1)
print('Visualizing signal for a data point in class {}'.format(rand_class))

# Take the first batch of the chosen class and pick one random sample from it
input_tensor, _ = next(iter(train_loader_fun(rand_class)))
idx = random.randint(0, input_tensor.shape[0] - 1)
input_tensor_single = input_tensor[idx][None, :, :, :]  # (1, C, H, W)
input_tensor_true = input_tensor[idx].permute(1, 2, 0)  # expected shape (H, W, C)
out_prefix = os.path.join(results_path, '{}_signals'.format(idx))

for topk in range(1, 4):  # Try to get diff heatmaps for the top 3 classes
    args.topk = topk
    untangle_ai.attribute_signals(model, input_tensor_single, input_tensor_true, signal_store_path,
                                  ID2Name_Map, args, out_prefix)
    # Heatmap for each of the top-k classes
    for i in range(topk):
        show_image(out_prefix + '_class_{}.JPEG'.format(i))
    # Differential and inverse differential heatmaps (file index i >= 1)
    for i in range(1, topk):
        show_image(out_prefix + '_diff_class_{}.JPEG'.format(i))
        show_image(out_prefix + '_invDiff_class_{}.JPEG'.format(i))
```
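How UntangleAI computes the differential and inverse differential heatmaps is not described here. Judging only from the file names (`_diff_class_i` and `_invDiff_class_i` for i ≥ 1), one plausible reading is a signed difference between the top-1 heatmap and the heatmap of class i, and its negation. The sketch below implements that assumption on random arrays; `differential_heatmaps` and `show_signed` are hypothetical helpers. It also shows why a diverging colormap such as `seismic` (the `cmap` value in `SignalEstimationArgs`) with limits symmetric about zero suits signed maps: zero maps to white, positive values to red and negative values to blue.

```python
import numpy as np
import matplotlib.pyplot as plt


def differential_heatmaps(h_top1: np.ndarray, h_other: np.ndarray):
    """Hypothetical reading of 'diff' / 'invDiff': signed differences of two class heatmaps."""
    diff = h_top1 - h_other      # positive where the top-1 class has more evidence
    inv_diff = h_other - h_top1  # positive where the other class has more evidence
    return diff, inv_diff


def show_signed(h: np.ndarray, title: str):
    """Plot a signed map with a diverging colormap centred at zero."""
    lim = float(np.abs(h).max()) or 1.0  # avoid a zero-width color range
    fig, ax = plt.subplots()
    im = ax.imshow(h, cmap='seismic', vmin=-lim, vmax=lim)
    ax.set_title(title)
    fig.colorbar(im, ax=ax)
    return fig


rng = np.random.default_rng(0)
h0, h1 = rng.random((28, 28)), rng.random((28, 28))  # stand-ins for real heatmaps
diff, inv_diff = differential_heatmaps(h0, h1)
fig = show_signed(diff, 'top-1 minus top-2 (illustrative)')
```

Without symmetric `vmin`/`vmax`, matplotlib would stretch the colormap to the data range, and white would no longer mean "no difference".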
