In [None]:
import random
import pickle
import numpy as np
import tqdm.auto as tqdm

from torch.nn.modules.loss import _WeightedLoss
from torch.optim.lr_scheduler import *
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import torch
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn import metrics

import matplotlib.pyplot as plt
import torchvision.models as models

from scipy.spatial import distance
from medpy import metric
from collections import Counter
import os


In [None]:
# !pip install captum

 ### Set the device: cpu or gpu

In [None]:
# Use the CUDA device PyTorch currently has selected when one is available,
# otherwise fall back to the CPU.
is_cuda = torch.cuda.is_available()

if is_cuda:
    device = torch.device(torch.cuda.current_device())
    # cuDNN auto-tuner flag; only set when running on the GPU.
    torch.backends.cudnn.benchmark = True
else:
    device = torch.device("cpu")

print(f'Preparing to use device {device}')

### Load the data

In [None]:
# Load the pre-scaled training images.
# The data directory can be overridden with the GOES_DATA_DIR environment
# variable; the original hard-coded location remains the default.
# NOTE(review): pickle.load can execute arbitrary code -- only open trusted files.
data_dir = os.environ.get('GOES_DATA_DIR', '/Users/pabharathi/Documents/GOES/Data/Data')
with open(os.path.join(data_dir, 'train_data_scaled.pkl'), 'rb') as f:
    x = pickle.load(f)

In [None]:
# Move axis 3 to position 1 (presumably (N, H, W, C) -> (N, C, H, W) -- confirm
# against the pickle's layout).
# NOTE: re-running this cell transposes `x` again.
print("Saved shape:", x.shape)
x = x.transpose(0, 3, 1, 2)
print("Reshaped to:", x.shape)
# Length of axis 1 after the transpose.
num_channels = x.shape[1]

In [None]:
# Load the lightning counts that serve as training targets.
# The data directory can be overridden with the GOES_DATA_DIR environment
# variable; the original hard-coded location remains the default.
# NOTE(review): pickle.load can execute arbitrary code -- only open trusted files.
data_dir = os.environ.get('GOES_DATA_DIR', '/Users/pabharathi/Documents/GOES/Data/Data')
with open(os.path.join(data_dir, 'train_counts.pkl'), 'rb') as f:
    y = pickle.load(f)

In [None]:
# Baseline tensor for attribution methods: every pixel of channel c is set to
# the mean of channel c over all images whose count equals zero.
# Intended as a candidate baseline for Integrated Gradients and DeepLift.
tensor = torch.tensor((), dtype=torch.float)
base_zero = torch.zeros([num_channels, 32, 32])
zero_idx = np.argwhere(y.squeeze() == 0).squeeze()
for channel in range(num_channels):
    # All pixel values of this channel across the zero-count images.
    channel_values = x[zero_idx, channel, :, :].flatten()
    currchannel_mean = np.mean(channel_values)
    base_zero[channel] = tensor.new_full((32, 32), currchannel_mean)

In [None]:
# Inspect the per-channel zero-class baseline tensor.
print(base_zero)

In [None]:
# Turn the counts into an explicit column: shape (n_samples, 1).
print("Saved shape:", y.shape)
y = y.reshape((y.shape[0], 1))
print("Reshaped to:", y.shape)

### Subsample the data 

In [None]:
# gt0 = list(np.where(y >= 1)[0])
# lt0 = list(np.where(y < 1)[0])
# take_a_sample = random.sample(lt0, 10)

# sample = gt0 + take_a_sample

# x = x[sample]
# y = y[sample]



### Split the data into train / test partitions (80 / 20 split)

In [None]:
# 80 / 20 split; the fixed random_state keeps the partition reproducible.
X_train, X_test, Y_train, Y_test = train_test_split(x, y, test_size=0.2, random_state=5000)

##### Create a histogram of the (log) counts

In [None]:
# Histogram of log(count) for the training targets.
# Zero counts are left out: log(0) is -inf, which lies outside every bin, so
# the plotted histogram is unchanged while the divide-by-zero RuntimeWarning
# that np.log raised on them is avoided.
positive_counts = Y_train[Y_train > 0]
plt.hist(np.log(positive_counts), bins=range(200))
plt.xlim([0, 50])

##### Create labels for binned lightning counts 

In [None]:
# Class labels derived from the raw counts:
#   0: count is not > 0      1: count > 0 (and below bins[0])
#   2: count >= bins[0]      3: count >= bins[1]
bins = [10, 100]


def _bin_counts(raw_counts):
    """Map raw counts to class labels using the thresholds in `bins`."""
    labels = np.where(raw_counts[:] > 0.0, 1, 0)
    # Later (larger) thresholds overwrite earlier ones, so label = position + 2.
    for label, lower_edge in enumerate(bins, start=2):
        labels = np.where(raw_counts[:] >= lower_edge, label, labels)
    return labels


y_train = _bin_counts(Y_train)
y_test = _bin_counts(Y_test)

In [None]:
# Report the shapes of the label and image partitions.
for partition in (y_train, y_test, X_train, X_test):
    print(partition.shape)

# The attribution section works on the test partition under these names.
X_IG, Y_IG = X_test, y_test

In [None]:
# Draw `test_classsize` test images per class (without replacement) for the
# attribution methods. Classes are sampled in the order 0, 1, 2, 3 from one
# seeded generator, so the selection is reproducible.
# Drop the lines of any class that does not exist in the current iteration.
rng = np.random.default_rng(2021) #Set the seed for rng
test_classsize = 30


def _sample_class(label):
    """Return (all test indices of `label`, a random subset of size test_classsize)."""
    class_idx = np.argwhere(y_test.squeeze() == label).squeeze()
    return class_idx, rng.choice(class_idx, size=test_classsize, replace=False)


zero_idx, zero_randomidx = _sample_class(0)
one_idx, one_randomidx = _sample_class(1)
two_idx, two_randomidx = _sample_class(2)
three_idx, three_randomidx = _sample_class(3)

test_randomidx = np.concatenate((zero_randomidx, one_randomidx, two_randomidx, three_randomidx))


### Create class weights based on counts 

In [None]:
# Number of training samples per class label.
counts = dict(Counter(row[0] for row in y_train))

# Rarer classes get larger weights: log1p(largest_class_size / class_size),
# then rescaled so the largest weight equals 1. Ordered by class label.
#weights = [1 - (counts[x] / sum(counts.values())) for x in sorted(counts.keys())]
_largest_class = max(counts.values())
weights = [np.log1p(_largest_class / counts[label]) for label in sorted(counts)]
_largest_weight = max(weights)
weights = [w / _largest_weight for w in weights]
weights = torch.FloatTensor(weights).to(device)

print(weights)

## Load a Model 

In [None]:
# The original ResNet model had to be modified to meet Captum's requirements for DeepLift.
# To overcome the resulting error, an implementation of ResNet18 was copied over from the
# torchvision GitHub repository and its BasicBlock module was modified so that each ReLU
# activation has its own module instance instead of one instance being reused for subsequent calls.
# The performance of this modified ResNet model is comparable to the original implementation.

class ResNet(nn.Module):
    """ResNet backbone adapted from torchvision for multi-channel image input.

    The network is stem (7x7 conv, BN, ReLU, max-pool) -> four residual stages
    -> global average pool -> linear classifier. ``forward`` returns the raw
    output of the final ``nn.Linear`` (logits); no softmax is applied.

    Args:
        block: residual block class; must expose an ``expansion`` attribute.
        layers: sequence of four ints, the number of blocks in each stage.
        num_classes: size of the classifier output.
        zero_init_residual: if True, zero the last BN weight of every residual
            block so each block starts as an identity mapping.
        groups, width_per_group: forwarded to every block.
        replace_stride_with_dilation: None or three booleans; True replaces the
            stride of the corresponding stage (2, 3, 4) by dilation.
        norm_layer: normalisation layer factory (default ``nn.BatchNorm2d``).
        in_channels: number of input image channels. When None (default) the
            notebook-level ``num_channels`` variable is used, as before.

    Raises:
        ValueError: if ``replace_stride_with_dilation`` does not have 3 elements.
    """

    def __init__(self, block, layers, num_classes=4, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, in_channels=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
        if in_channels is None:
            # Backward-compatible default: the global defined when the data is loaded.
            in_channels = num_channels

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        # The stem must be a 2-D convolution. It was declared as nn.Conv1d with
        # 2-D kernel/stride/padding tuples, which recent PyTorch rejects for 4-D
        # image batches and which the Conv2d-only Kaiming initialisation below
        # skipped. The weight shape (64, in_channels, 7, 7) is the same, so
        # existing checkpoints still load.
        self.conv1 = nn.Conv2d(in_channels, self.inplanes, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)
                elif type(m).__name__ == "Bottleneck":
                    # No Bottleneck class is defined in this notebook; matching
                    # by name avoids the NameError the isinstance check raised.
                    nn.init.constant_(m.bn3.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one residual stage of ``blocks`` blocks with ``planes`` base channels.

        Only the first block may change resolution or width; it gets a 1x1
        conv + norm shortcut whenever stride or channel count changes.
        Updates ``self.inplanes`` (and ``self.dilation`` when ``dilate``).
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # Trade the stride for dilation in this stage.
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        """Run stem, the four stages, pooling and the classifier; returns logits."""
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x):
        """Forward pass; see ``_forward_impl``."""
        return self._forward_impl(x)
    
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Bias-free 3x3 Conv2d whose padding equals its dilation."""
    conv_options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_options)

def conv1x1(in_planes, out_planes, stride=1):
    """Bias-free 1x1 Conv2d (used for the residual shortcut projection)."""
    conv_options = dict(kernel_size=1, stride=stride, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)

def _resnet(arch, block, layers, num_classes, pretrained, progress, **kwargs):
    """Build a ResNet and optionally load the ImageNet ResNet-18 weights.

    Args:
        arch: architecture name; accepted for API parity, not used here.
        block: residual block class passed to ``ResNet``.
        layers: number of blocks per stage.
        num_classes: size of the classifier output.
        pretrained: if True, download and load the ResNet-18 checkpoint.
        progress: show a download progress bar when ``pretrained`` is True.
        **kwargs: forwarded to ``ResNet``.

    Returns:
        The constructed ``ResNet`` instance.
    """
    model = ResNet(block, layers, num_classes=num_classes, **kwargs)
    if pretrained:
        # This helper was used without ever being imported in the notebook,
        # so pretrained=True raised NameError; import it where it is needed.
        from torch.hub import load_state_dict_from_url
        # NOTE(review): load_state_dict is strict, so the downloaded checkpoint
        # only loads when the model's parameter shapes match it exactly.
        state_dict = load_state_dict_from_url('https://download.pytorch.org/models/resnet18-5c106cde.pth',
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model


def resnet18(pretrained=False, progress=False, num_classes = 4, **kwargs):
    r"""Construct the ResNet-18 variant (two BasicBlocks in each of the four stages).

    Architecture from `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, load the downloadable ResNet-18 checkpoint
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): Size of the classifier output
        **kwargs: Forwarded to the ``ResNet`` constructor
    """
    blocks_per_stage = [2, 2, 2, 2]
    return _resnet('resnet18', BasicBlock, blocks_per_stage, num_classes,
                   pretrained, progress, **kwargs)

class BasicBlock(nn.Module):
    """Two-convolution residual block with a separate ReLU module per activation.

    Unlike the torchvision block, the activation after the residual addition
    uses its own ``relu2`` module instead of reusing ``relu``.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        # This block only implements the plain (ungrouped, undilated) variant.
        if not (groups == 1 and base_width == 64):
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        # When stride != 1, both conv1 and the downsample branch reduce resolution.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride
        # Dedicated activation for the post-addition ReLU.
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return relu2(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))."""
        # Shortcut branch: identity unless a projection was supplied.
        shortcut = x

        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))

        if self.downsample is not None:
            shortcut = self.downsample(x)

        residual += shortcut

        # Second, separate ReLU module (not self.relu).
        return self.relu2(residual)
    
# class ResNet(nn.Module):
#     def __init__(self, fcl_layers = [], dr = 0.0, output_size = 1, resnet_model = 18, pretrained = True):
#         super(ResNet, self).__init__()
#         self.pretrained = pretrained
#         self.resnet_model = resnet_model 
#         if self.resnet_model == 18:
#             resnet = models.resnet18(pretrained=self.pretrained)
#         elif self.resnet_model == 34:
#             resnet = models.resnet34(pretrained=self.pretrained)
#         elif self.resnet_model == 50:
#             resnet = models.resnet50(pretrained=self.pretrained)
#         elif self.resnet_model == 101:
#             resnet = models.resnet101(pretrained=self.pretrained)
#         elif self.resnet_model == 152:
#             resnet = models.resnet152(pretrained=self.pretrained)
#         resnet.conv1 = torch.nn.Conv1d(4, 64, (7, 7), (2, 2), (3, 3), bias=False)
#         modules = list(resnet.children())[:-1]      # delete the last fc layer.
#         self.resnet_output_dim = resnet.fc.in_features
#         self.resnet = nn.Sequential(*modules)
#         self.fcn = self.make_fcn(self.resnet_output_dim, output_size, fcl_layers, dr)
        
#     def make_fcn(self, input_size, output_size, fcl_layers, dr):
#         if len(fcl_layers) > 0:
#             fcn = [
# #                 nn.Dropout(dr),
#                 nn.Linear(input_size, fcl_layers[0]),
# #                 nn.BatchNorm1d(fcl_layers[0]),
#                 torch.nn.ReLU()
#             ]
#             if len(fcl_layers) == 1:
#                 fcn.append(nn.Linear(fcl_layers[0], output_size))
#             else:
#                 for i in range(len(fcl_layers)-1):
#                     fcn += [
#                         nn.Linear(fcl_layers[i], fcl_layers[i+1]),
# #                         nn.BatchNorm1d(fcl_layers[i+1]),
#                         torch.nn.ReLU(),
# #                         nn.Dropout(dr)
#                     ]
#                 fcn.append(nn.Linear(fcl_layers[i+1], output_size))
#         else:
#             fcn = [
# #                 nn.Dropout(dr),
#                 nn.Linear(input_size, output_size)
#             ]
#         if output_size > 1:
#             fcn.append(torch.nn.LogSoftmax(dim=1))
#         return nn.Sequential(*fcn)

#     def _make_layer(self, block, planes, num_blocks, stride):
#         strides = [stride] + [1]*(num_blocks-1)
#         layers = []
#         for stride in strides:
#             layers.append(block(self.in_planes, planes, stride))
#             self.in_planes = planes * block.expansion
#         return nn.Sequential(*layers)

#     def forward(self, x):
#         x = self.resnet(x)
#         x = x.view(x.size(0), -1)  # flatten
#         x = self.fcn(x)
#         return x

In [None]:
# Settings that this cell defines but does not use itself.
fcl_layers = []
dropout = 0.5

# One output per class weight.
output_size = len(weights)
model = resnet18(num_classes=output_size).to(device)

### Test model to make sure the architecture is consistent on a batch size of 2

In [None]:
# Smoke test: push two training images through the network.
X = torch.from_numpy(X_train[:2]).float().to(device)
print(X.shape)
# The model ends in a plain nn.Linear and therefore returns raw logits, not
# log-probabilities: softmax (not exp) turns them into class probabilities.
# eval() + no_grad() keep this check from updating the BatchNorm running
# statistics or building an autograd graph; the previous mode is restored.
was_training = model.training
model.eval()
with torch.no_grad():
    g = F.softmax(model(X), dim=1)
model.train(was_training)
print(g)
#print(torch.argmax(g,1))

### Load Optimizers

In [None]:
# AdamW hyper-parameters.
learning_rate, weight_decay = 1e-03, 1e-09
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)

## Load Custom Loss Function
The modified loss function comes with label smoothing

In [None]:
class SmoothCrossEntropyLoss(_WeightedLoss):
    """Cross-entropy on logits with optional label smoothing and class weights.

    Args:
        weight: optional 1-D tensor of per-class weights; the log-probabilities
            are multiplied by it before being combined with the targets.
        reduction: 'mean', 'sum', or anything else for the per-sample losses.
        smoothing: probability mass moved away from the true class; must
            satisfy 0 <= smoothing < 1.
    """

    def __init__(self, weight=None, reduction='mean', smoothing=0.0):
        super().__init__(weight=weight, reduction=reduction)
        self.smoothing = smoothing
        self.weight = weight
        self.reduction = reduction

    def k_one_hot(self, targets:torch.Tensor, n_classes:int, smoothing=0.0):
        """Return smoothed one-hot targets of shape (len(targets), n_classes).

        The true class receives 1 - smoothing; every other class receives
        smoothing / (n_classes - 1).
        """
        with torch.no_grad():
            off_value = smoothing / (n_classes - 1)
            soft = torch.empty(size=(targets.size(0), n_classes), device=targets.device)
            soft.fill_(off_value)
            soft.scatter_(1, targets.data.unsqueeze(1), 1. - smoothing)
        return soft

    def reduce_loss(self, loss):
        """Apply the configured reduction to the per-sample losses."""
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss

    def forward(self, inputs, targets):
        """Compute the loss for logits ``inputs`` and integer class ``targets``."""
        assert 0 <= self.smoothing < 1

        soft_targets = self.k_one_hot(targets, inputs.size(-1), self.smoothing)
        log_preds = F.log_softmax(inputs, -1)

        if self.weight is not None:
            # Broadcast the class weights over the batch dimension.
            log_preds = log_preds * self.weight.unsqueeze(0)

        per_sample = -(soft_targets * log_preds).sum(dim=-1)
        return self.reduce_loss(per_sample)

In [None]:
# Evaluation uses the stock unweighted cross-entropy; training uses the custom
# loss with class weights (smoothing keeps its default of 0.0).
test_criterion = torch.nn.CrossEntropyLoss()
train_criterion = SmoothCrossEntropyLoss(weight=weights)  # alternative: add smoothing = 0.1

## Set up Top-K Accuracy Calculations

In [None]:
def torch_accuracy(output, target, topk=(1,)):
    """
    Top-k accuracy of a batch of class scores.

    An example counts as correct for a given k when its true label is among
    the k highest-scoring classes.

    Args:
        output: [B, n_classes] tensor of scores.
        target: tensor holding B true labels (any shape that flattens to B).
        topk: iterable of k values to evaluate.

    Returns:
        List of floats, one accuracy (fraction of the batch) per entry of topk.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_examples = target.size(0)

        # Labels of the largest_k best scores per example: [B, largest_k] -> [largest_k, B]
        top_labels = output.topk(k=largest_k, dim=1)[1].t()

        # Repeat the truth along the rank axis so it lines up with top_labels.
        truth = target.view(1, -1).expand_as(top_labels)

        # hits[r, b] is True when the r-th ranked guess for example b is correct;
        # at most one rank can match per example.
        hits = (top_labels == truth)

        return [
            (hits[:k].reshape(-1).float().sum(dim=0, keepdim=True) / n_examples).item()
            for k in topk
        ]

In [None]:
# Training schedule.
epochs = 300               # upper bound on the number of epochs
batches_per_epoch = 500    # cap on training mini-batches per epoch
patience = 50              # early-stopping patience, in epochs

# Mini-batch sizes.
train_batch_size, valid_batch_size = 5, 128

# k values reported by the top-k accuracy metric.
topk = (1, 2)

In [None]:
# Lower the learning rate when the value passed to lr_scheduler.step() stops improving.
scheduler_options = dict(patience=10, min_lr=1.0e-10, verbose=True)
lr_scheduler = ReduceLROnPlateau(optimizer, **scheduler_options)

### Data Iterators for Model Loading

In [None]:
# Wrap the NumPy partitions as tensor datasets.
train_dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(X_train), torch.from_numpy(y_train))
test_dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(X_test), torch.from_numpy(y_test))

# Training batches are reshuffled on every pass; evaluation order is fixed.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=train_batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=valid_batch_size, shuffle=False)

In [None]:
# Train for up to `epochs` epochs. After every epoch the model is evaluated on
# the test loader, the learning-rate scheduler is stepped, the best model so
# far is checkpointed to "best.pt", and training stops early once `patience`
# epochs pass without a new best.
epoch_test_losses = []

for epoch in range(epochs):

    ### Train the model
    model.train()

    batch_loss = []
    accuracy = {k: [] for k in topk}

    # Cap the epoch at `batches_per_epoch` mini-batches; shuffling is done by
    # train_loader itself.
    train_batches_per_epoch = min(batches_per_epoch, len(train_loader))

    # custom tqdm so we can see the progress
    batch_group_generator = tqdm.tqdm(
        enumerate(train_loader),
        total=train_batches_per_epoch,
        leave=True
    )

    # Loop variables are named batch_x / batch_y so the dataset arrays `x` and
    # `y` defined earlier in the notebook are not overwritten.
    for step, (batch_x, batch_y) in batch_group_generator:

        # Cast and move the batch to the training device
        inputs = batch_x.float().to(device)
        lightning_counts = batch_y.long().to(device)

        # Clear gradient
        optimizer.zero_grad()

        # get output from the model, given the inputs
        pred_lightning_counts = model(inputs)

        # get loss for the predicted output
        loss = train_criterion(pred_lightning_counts, lightning_counts.squeeze(-1))

        # compute the top-k accuracy
        acc = torch_accuracy(pred_lightning_counts.cpu(), lightning_counts.cpu(), topk = topk)
        for i, l in enumerate(topk):
            accuracy[l] += [acc[i]]

        # get gradients w.r.t to parameters
        loss.backward()
        batch_loss.append(loss.item())

        # update parameters
        optimizer.step()

        # Refresh the progress-bar text. Iterating the tqdm wrapper already
        # advances the bar, so no explicit update() call is made.
        to_print = "Epoch {} train_loss: {:.4f}".format(epoch, np.mean(batch_loss))
        for l in sorted(accuracy.keys()):
            to_print += " top-{}_acc: {:.4f}".format(l, np.mean(accuracy[l]))
        to_print += " lr: {:.12f}".format(optimizer.param_groups[0]['lr'])
        batch_group_generator.set_description(to_print)

        # Stop after exactly train_batches_per_epoch batches (step is 0-based).
        if step + 1 >= train_batches_per_epoch:
            break

    batch_group_generator.close()
    torch.cuda.empty_cache()

    ### Test the model
    model.eval()
    with torch.no_grad():

        batch_loss = []
        accuracy = {k: [] for k in topk}

        # len(test_loader) also counts the final, possibly smaller, batch.
        batch_group_generator = tqdm.tqdm(
            test_loader,
            total=len(test_loader),
            leave=True
        )

        for batch_x, batch_y in batch_group_generator:
            # Cast and move the batch to the evaluation device
            inputs = batch_x.float().to(device)
            lightning_counts = batch_y.long().to(device)
            # get output from the model, given the inputs
            pred_lightning_counts = model(inputs)
            # get loss for the predicted output
            loss = test_criterion(pred_lightning_counts, lightning_counts.squeeze(-1))
            batch_loss.append(loss.item())
            # compute the accuracy
            acc = torch_accuracy(pred_lightning_counts, lightning_counts, topk = topk)
            for i, l in enumerate(topk):
                accuracy[l] += [acc[i]]
            # update tqdm text
            to_print = "Epoch {} test_loss: {:.4f}".format(epoch, np.mean(batch_loss))
            for l in sorted(accuracy.keys()):
                to_print += " top-{}_acc: {:.4f}".format(l, np.mean(accuracy[l]))
            batch_group_generator.set_description(to_print)

    # Model selection metric: top-1 error, i.e. 1 - mean of per-batch top-1 accuracy.
    # NOTE(review): this is an unweighted mean over batches; a smaller final
    # batch counts as much as a full one.
    test_loss = 1 - np.mean(accuracy[1])
    epoch_test_losses.append(test_loss)

    # Lower the learning rate if we are not improving
    lr_scheduler.step(test_loss)

    # Save the model if its the best so far.
    if test_loss == min(epoch_test_losses):
        state_dict = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Plain Python float: keeps NumPy scalar objects out of the
            # checkpoint so torch.load's restricted unpickler can read it.
            'loss': float(test_loss)
        }
        torch.save(state_dict, "best.pt")

    # Stop training if we have not improved after `patience` epochs
    best_epoch = [i for i, j in enumerate(epoch_test_losses) if j == min(epoch_test_losses)][-1]
    offset = epoch - best_epoch
    if offset >= patience:
        break

##### Load the best model

In [None]:
# Restore the best checkpoint written by the training loop.
checkpoint = torch.load(
    "best.pt",
    map_location=lambda storage, loc: storage
)
best_epoch = checkpoint["epoch"]
# Rebuild the network with the class count that was actually trained (rows of
# the saved classifier weight) instead of resnet18()'s default of 4; otherwise
# load_state_dict fails whenever the number of classes is not 4.
n_saved_classes = checkpoint["model_state_dict"]["fc.weight"].shape[0]
model = resnet18(num_classes=n_saved_classes).to(device)
model.load_state_dict(checkpoint["model_state_dict"])

##### Predict on the test dataset with the best model

In [None]:
# Evaluate the restored model on the whole test loader, collecting per-batch
# loss / top-k accuracy and the true and top-1 predicted labels.
topk = (1,2)

model.eval()
with torch.no_grad():
    y_true = []
    y_pred = []
    batch_loss = []
    accuracy = {k: [] for k in topk}

    # len(test_loader) also counts the final, possibly smaller, batch, so the
    # progress bar total matches the number of iterations.
    batch_group_generator = tqdm.tqdm(
        test_loader,
        total=len(test_loader),
        leave=True
    )

    # batch_x / batch_y keep the dataset arrays `x` and `y` from being overwritten.
    for batch_x, batch_y in batch_group_generator:
        # Cast and move the batch to the evaluation device
        inputs = batch_x.float().to(device)
        lightning_counts = batch_y.long().to(device)
        # get output from the model, given the inputs
        pred_lightning_counts = model(inputs)
        # get loss for the predicted output
        loss = test_criterion(pred_lightning_counts, lightning_counts.squeeze(-1))
        batch_loss.append(loss.item())
        # compute the accuracy
        acc = torch_accuracy(pred_lightning_counts, lightning_counts, topk = topk)
        for i, k in enumerate(topk):
            accuracy[k] += [acc[i]]

        y_true.append(lightning_counts.squeeze(-1))
        # Taking the top-1 answer here, but here is where we could compute the average predicted rather than take top-1
        y_pred.append(torch.argmax(pred_lightning_counts, 1))

# One flat label tensor each for the truth and the predictions.
y_true = torch.cat(y_true, dim = 0)
y_pred = torch.cat(y_pred, dim = 0)

In [None]:
# Convert both label tensors to NumPy arrays on the host for scikit-learn.
# NOTE: not re-runnable -- after one run these names are already ndarrays.
y_true, y_pred = y_true.cpu().numpy(), y_pred.cpu().numpy()

In [None]:
# Mean per-batch loss and top-k accuracies from the evaluation run.
print("val_loss", np.mean(batch_loss))
for k in topk:
    mean_topk_acc = np.mean(accuracy[k])
    print(f"top-{k} {mean_topk_acc}")

In [None]:
# Overall top-1 accuracy as a fraction (normalize=True), shown as the cell output.
metrics.accuracy_score(y_true,y_pred,normalize=True)

### Compute accuracy on all labels

In [None]:
# For every class present in the truth, print the fraction of its samples
# that were predicted correctly.
for label in list(set(y_true)):
    is_label = (y_true == label)
    class_accuracy = (y_true[is_label] == y_pred[is_label]).mean()
    print(label, class_accuracy)

In [None]:
# Distinct classes the model predicted at least once.
list(set(y_pred))

### Plot the confusion matrix 

In [None]:
import seaborn as sn
import pandas as pd

In [None]:
# Confusion matrix normalised over the true labels (normalize='true').
cm = metrics.confusion_matrix( y_true, y_pred, normalize = 'true')
print(cm)

In [None]:
# Heat-map of the confusion matrix.
# metrics.confusion_matrix orders its rows/columns by the sorted union of the
# labels in y_true and y_pred, so the axis labels are built from that same
# union. Using y_true alone raised a shape mismatch whenever a class appeared
# only in the predictions.
class_labels = np.unique(np.concatenate([y_true, y_pred]))
df_cm = pd.DataFrame(cm, index = class_labels, columns = class_labels)
plt.figure(figsize = (10,7))
sn.heatmap(df_cm, annot=True)
# Axis labels: "lighting" was a typo for "lightning".
plt.ylabel("Actual lightning class", fontsize = 12)
plt.xlabel("Predicted lightning class", fontsize = 14)
plt.savefig('allbands_Confusion.png')

In [None]:
# Per-class classification report, printed with 3 decimal digits.
print(metrics.classification_report(y_true, y_pred, digits =3))


## Deeplift

In [None]:
import PIL
from PIL import Image


import torchvision
from torchvision import models
from torchvision import transforms

from matplotlib.colors import LinearSegmentedColormap
import matplotlib.cm, matplotlib.colors
from collections import defaultdict



# import captum


# from captum.attr import LayerLRP

# from captum.attr import GradientShap
# from captum.attr import Occlusion
# from captum.attr import NoiseTunnel

from captum.attr import DeepLift
from captum.attr import IntegratedGradients

from captum.attr import visualization as viz

from enum import Enum



In [None]:
#Normalization of attributions in XAI methods
#Modified version of the one in Captum's visualization package

# Which part of the attribution values to keep when normalising;
# members are looked up by name, e.g. VisualizeSign["positive"].
VisualizeSign = Enum(
    "VisualizeSign",
    [("positive", 1), ("absolute_value", 2), ("negative", 3), ("all", 4)],
)

def _prepare_image(attr_visual):
    return np.clip(attr_visual.astype(int), 0, 255)


def _normalize_scale(attr, scale_factor):
    assert scale_factor != 0, "Cannot normalize by scale factor = 0"
    if abs(scale_factor) < 1e-5:
        warnings.warn(
            "Attempting to normalize by value approximately 0, visualized results"
            "may be misleading. This likely means that attribution values are all"
            "close to 0."
        )
    attr_norm = attr / scale_factor
    return np.clip(attr_norm, -1, 1)


def _cumulative_sum_threshold(values, percentile):
    # given values should be non-negative
    assert percentile >= 0 and percentile <= 100, (
        "Percentile for thresholding must be " "between 0 and 100 inclusive."
    )
    sorted_vals = np.sort(values.flatten())
    cum_sums = np.cumsum(sorted_vals)
    threshold_id = np.where(cum_sums >= cum_sums[-1] * 0.01 * percentile)[0][0]
    return sorted_vals[threshold_id]


def _normalize_image_attr(
    attr, sign: str, outlier_perc: float = 2
):
    """Collapse attributions over axis 2 and rescale them to [-1, 1].

    Args:
        attr: attribution array; it is summed over axis 2 (presumably the
            channel axis of an (H, W, C) array -- confirm against callers).
        sign: name of a ``VisualizeSign`` member ("positive", "absolute_value",
            "negative" or "all") selecting which values are kept.
        outlier_perc: percentage of the cumulative attribution mass treated as
            outliers when picking the scale. Defaults to 2; the original
            signature wrote ``outlier_perc:2``, which is an annotation rather
            than a default and made the argument mandatory.

    Returns:
        2-D array of normalised attributions clipped to [-1, 1].

    Raises:
        KeyError: if ``sign`` is not the name of a ``VisualizeSign`` member.
    """
    attr_combined = np.sum(attr, axis=2)
    sign_kind = VisualizeSign[sign]
    keep_perc = 100 - outlier_perc
    # Choose appropriate signed values and rescale, removing given outlier percentage.
    if sign_kind == VisualizeSign.all:
        threshold = _cumulative_sum_threshold(np.abs(attr_combined), keep_perc)
    elif sign_kind == VisualizeSign.positive:
        # Zero out the negative entries.
        attr_combined = (attr_combined > 0) * attr_combined
        threshold = _cumulative_sum_threshold(attr_combined, keep_perc)
    elif sign_kind == VisualizeSign.negative:
        # Zero out the positive entries; the scale is negated so that the
        # normalised negative attributions come out positive.
        attr_combined = (attr_combined < 0) * attr_combined
        threshold = -1 * _cumulative_sum_threshold(
            np.abs(attr_combined), keep_perc
        )
    elif sign_kind == VisualizeSign.absolute_value:
        attr_combined = np.abs(attr_combined)
        threshold = _cumulative_sum_threshold(attr_combined, keep_perc)
    else:
        raise AssertionError("Visualize Sign type is not valid.")
    return _normalize_scale(attr_combined, threshold)


In [None]:
# Gather the randomly sampled test images into one array.
all_images = X_IG[test_randomidx]
print(all_images.shape)

In [None]:
# Compute the model's class probabilities for every sampled image.
# predicted_dict maps each predicted class to the positions (in all_images) of
# the images assigned to it; total_outputs keeps the probability vectors.
predicted_dict = defaultdict(list)
total_outputs = []
model.eval()
# Inference only: no autograd graph is needed for the stored outputs.
with torch.no_grad():
    for idx, input1 in enumerate(all_images):
        # Same dtype/device handling as the training and evaluation loops: the
        # model was fed float32 tensors on `device` there, so do the same here.
        input1 = torch.tensor(input1, dtype=torch.float32, device=device).unsqueeze(0)
        output = F.softmax(model(input1), dim =1)
        prediction_score, pred_label_idx = torch.topk(output, 1)
        predicted_dict[pred_label_idx.item()].append(idx)
        total_outputs.append(output)

In [None]:
# Run DeepLift on every sampled image, explaining the class the model
# predicted for that image (target = predicted label).
# The current DeepLift implementation raises a warning here, which is expected
# to go away in future versions.
dl = DeepLift(model)
attributions_mult = []
for idx, input1 in enumerate(all_images):
    # Same dtype/device handling as the training and evaluation loops: the
    # model was fed float32 tensors on `device` there, so do the same here.
    input1 = torch.tensor(input1, dtype=torch.float32, device=device).unsqueeze(0)
    prediction_score, pred_label_idx = torch.topk(total_outputs[idx], 1)
    attributions_dl = dl.attribute(input1, target = pred_label_idx)
    # (C, H, W) -> (H, W, C): the plotting code expects channels last.
    attributions_mult.append(np.transpose(attributions_dl.squeeze().cpu().detach().numpy(), (1,2,0)))
attributions_mult = np.array(attributions_mult)

In [None]:
# Make sure that the shape of the attributions aligns with that of the transposed images.
# We need to transpose because the visualization packages require channels to be the last dim.
print(attributions_mult.shape, np.transpose(all_images, (0, 2, 3, 1)).shape)

### Various plots depicting attributions against various parameters

In [None]:
# 2-D density histogram of pixel values (x) against DeepLift attributions (y)
# samp_all is the channels-last view of all_images that the plotting cells share
samp_all = np.transpose(all_images, (0, 2, 3, 1))
plt.figure(figsize=(7.5, 5))
hi_ = plt.hist2d(
    samp_all.flatten(),
    attributions_mult.flatten(),
    bins=(50, 50),
    range=[[0.0, 1.0], [-0.15, 0.15]],
    density=True,
    cmap=plt.cm.jet,
    norm=matplotlib.colors.LogNorm(),
)
plt.title("Density Heatmap of Inputs and Attributions", fontsize=16, pad=15)
plt.xlabel("Pixel Value", fontsize=14)
plt.ylabel("Deeplift Attribution", fontsize=14)
plt.colorbar()
plt.savefig("DL_InpVsAttr.png")
plt.show()


In [None]:
# Report how many entries `weights` holds.
# NOTE(review): `weights` is not assigned anywhere in this section — confirm it is still defined on a fresh run.
print(len(weights))

In [None]:
# Turn constrained layout off globally; the plotting cells below position their
# axes explicitly with plt.subplots_adjust / plt.tight_layout instead.
plt.rcParams['figure.constrained_layout.use'] = False


In [None]:
## 2-d Histograms showing how input values as well as attributions vary across true bins
def attr_classviz(attributions, Range, predictions = None, size = 4, Share = True, Pred = False, y_label = "Integrated Gradient Attributions"):
    """Plot one pixel-value (x) vs. attribution (y) 2-D histogram per class.

    Reads the notebook globals ``samp_all`` (indexed ``[sample, :, :, :]``) and,
    when ``Pred`` is False, ``test_classsize``.

    Args:
        attributions: Array indexed like ``samp_all`` (``[sample, :, :, :]``).
        Range: Forwarded to ``hist2d`` as ``range``.
        predictions: Required when ``Pred`` is True. Either a mapping from class
            index to a list of sample indices (such as ``predicted_dict``) or a
            sequence of index lists addressed by panel number.
        size: Number of panels. 1, 2 and 3 are stacked in a single column,
            4 is drawn as a 2x2 grid.
        Share: Forwarded to ``plt.subplots`` as ``sharey``.
        Pred: If False, panel ``c`` shows samples
            ``test_classsize*c : test_classsize*(c+1)``
            (NOTE(review): this presumes samples are stored in equal-sized blocks
            ordered by true class — confirm against how ``samp_all`` is built).
            If True, each panel shows the samples listed in ``predictions`` for
            one class; with a mapping, the classes present are taken in sorted order.
        y_label: Label of the y axis of every panel.

    Returns:
        The ``matplotlib.pyplot`` module, so callers can call ``.savefig`` on it.

    Raises:
        ValueError: If ``size`` is not 1-4, or ``Pred`` is True without ``predictions``.
    """
    # size -> (rows, cols, figsize); figsize is (width, height) in inches
    layouts = {
        1: (1, 1, (9, 6)),
        2: (2, 1, (9, 12)),
        3: (3, 1, (8.5, 19)),
        4: (2, 2, (18, 12)),
    }
    if size not in layouts:
        raise ValueError(f"size must be one of {sorted(layouts)}, got {size!r}")
    if Pred and predictions is None:
        raise ValueError("predictions must be provided when Pred is True")

    m, n, figsize = layouts[size]
    # squeeze=False always yields a 2-D axes array, so one loop serves every layout
    fig, axs = plt.subplots(m, n, figsize = figsize, sharey = Share, squeeze = False)
    classifications = ['No Lightning', '1-10 Lightning', '10-100 Lightning', '100+ Lightning']

    if not Pred:
        panel_classes = list(range(m * n))
    elif hasattr(predictions, "keys"):
        # Use the class labels that are really present; indexing 0..size-1 would
        # mislabel non-contiguous classes and insert empty keys into a defaultdict.
        panel_classes = sorted(predictions)[:m * n]
    else:
        panel_classes = list(range(min(len(predictions), m * n)))

    hist_kwargs = dict(range = Range, density = True, bins = (50, 50),
                       cmap = plt.cm.jet, norm = matplotlib.colors.LogNorm())

    all_axes = list(axs.flat)  # row-major panel order
    for ax, cls in zip(all_axes, panel_classes):
        if Pred:
            rows = predictions[cls]
        else:
            rows = slice(test_classsize * cls, test_classsize * (cls + 1))
        if 0 <= cls < len(classifications):
            label = classifications[cls]
        else:
            label = f"Class {cls}"

        ax.set_title("Inputs and Attributions for " + label, fontsize = 16, pad = 15)
        ax.set_xlabel("Pixel Value", fontsize = 14)
        ax.set_ylabel(y_label, fontsize = 14)
        hi_ = ax.hist2d(samp_all[rows, :, :, :].flatten(),
                        attributions[rows, :, :, :].flatten(),
                        **hist_kwargs)
        # hist2d returns (counts, xedges, yedges, image); the image drives the colorbar
        fig.colorbar(hi_[3], ax = ax)

    # Hide panels that have no class to show
    for ax in all_axes[len(panel_classes):]:
        ax.axis('off')

    plt.subplots_adjust(left=0.2,
                    bottom=0.2,
                    right=0.9,
                    top=0.9,
                    wspace=0.2,
                    hspace=0.3)
    return plt

In [None]:
# DeepLift: inputs vs. attributions, one panel per true class
_ = attr_classviz(
    attributions=attributions_mult,
    y_label="Deeplift Attributions",
    Share=False,
    Range=[[0.0, 1.0], [-0.15, 0.15]],
)
_.savefig("DLact_class")

In [None]:
# DeepLift: inputs vs. attributions, one panel per predicted class.
# The model usually predicts fewer distinct classes than it has outputs.
_ = attr_classviz(
    attributions=attributions_mult,
    predictions=predicted_dict,
    Pred=True,
    size=len(predicted_dict),
    y_label="Deeplift Attributions",
    Share=False,
    Range=[[0.0, 1.0], [-0.15, 0.15]],
)
_.savefig("DLpred_class")

In [None]:
# Number of keys in predicted_dict, i.e. classes the model predicted for the sampled images.
# NOTE: predicted_dict is a defaultdict, so looking up a missing class adds an empty entry and inflates this count.
len(predicted_dict)

#### An example of standard DeepLift attributions for a bin 3 image

In [None]:
# Colormap that ramps from white at 0 to black at 0.25 and stays black up to 1 (256 levels).
# NOTE(review): default_cmap is not passed to the call below — confirm whether it is used elsewhere.
default_cmap = LinearSegmentedColormap.from_list('custom blue', 
                                                 [(0, '#ffffff'),
                                                  (0.25, '#000000'),
                                                  (1, '#000000')], N=256)

# Show the first band of sample 34 next to its attribution heat map
# (slicing with 0:1 keeps a trailing channel axis of length 1).
_ = viz.visualize_image_attr_multiple(attributions_mult[34,:,:, 0:1],
                                          samp_all[34,:, :, 0:1],
                                          ["original_image", "heat_map"],
                                          ["all", "all"],
                                          show_colorbar=True,
                                            outlier_perc=1)


In [None]:
# Check that dimensions are the same for samples and attributions (w.r.t. bands).
# Fixed sample/band indices are used so this cell does not depend on loop
# variables left behind by other cells; the slice shapes are identical for any
# valid sample and band.
print(samp_all[34, :, :, 0:1].shape, attributions_mult[34, :, :, 0:1].shape)

In [None]:
def attributions_visualization (attributions, image, cmap=None, method=None):
    """Draw each band of one image next to its normalized attribution map.

    The figure has four rows (channels 0-3, titled Band 8, 9, 10 and 14) and
    two columns: the input band on the left, the attribution map on the right.

    Args:
        attributions: Attribution array for one sample, indexed ``[:, :, channel]``.
        image: Input image for the same sample, indexed ``[:, :, channel]``;
            displayed on a fixed 0-1 colour scale.
        cmap: Colormap for the attribution panels; ``plt.cm.PuOr`` when None.
        method: Attribution method name prefixed to the attribution titles;
            omitted from the titles when None or empty.

    Returns:
        The matplotlib Figure that was created.
    """
    band_names = ('Band 8', 'Band 9', 'Band 10', 'Band 14')
    attr_cmap = plt.cm.PuOr if cmap is None else cmap
    # Avoid "None + str" when no method name is supplied
    title_prefix = f"{method} " if method else ""

    fig, axs = plt.subplots(nrows=len(band_names), ncols=2, squeeze=False, figsize=(16, 16))
    for row, band in enumerate(band_names):
        channel = slice(row, row + 1)  # keeps a trailing channel axis of length 1
        image_ax, attr_ax = axs[row, 0], axs[row, 1]

        image_ax.set_title(band + ' Image', fontsize =16)
        shown = image_ax.imshow(image[:, :, channel], vmin = 0, vmax = 1)
        image_ax.axis('off')
        fig.colorbar(shown, ax = image_ax)

        attr_ax.set_title(title_prefix + band + ' Attributions', fontsize =16)
        normalized = _normalize_image_attr(attr=attributions[:, :, channel], sign= "all", outlier_perc= 2.0)
        # Attribution panels use a fixed colour scale of [-1, 1]
        shown = attr_ax.imshow(normalized, cmap=attr_cmap, vmin = -1, vmax = 1)
        attr_ax.axis('off')
        fig.colorbar(shown, ax = attr_ax)

    plt.tight_layout()

    return fig
    

In [None]:
# Save one four-band DeepLift figure per selected sample, named "DL<index>"
for i in [27, 34, 81, 110]:
    title = f"DL{i}"
    _ = attributions_visualization(
        image=samp_all[i],
        attributions=attributions_mult[i],
        method="Deeplift",
    )
    _.savefig(title)

In [None]:
#Older visualization method
# For each selected sample k and each of the four bands i, show the band image
# next to its attribution heat map.
# NOTE(review): this list uses sample 44 while the other sample lists in this notebook use 34 — confirm which is intended.
for k in ([27, 44, 81, 110]):
    for i in range(4):
        # i:i+1 keeps a trailing channel axis of length 1
        _ = viz.visualize_image_attr_multiple(attributions_mult[k,:,:, i:i+1],
                                          samp_all[k,:, :, i:i+1],
                                          ["original_image", "heat_map"],
                                          ["all", "all"],
                                          show_colorbar=True,
                                            outlier_perc=1)


In [None]:
#Attributions Plotting for the four bands
# Move the channel axis first so attributions_channels[c] holds every attribution of band c
attributions_channels = attributions_mult.transpose(3, 1, 2, 0)
fig, axs = plt.subplots(2, 2, figsize = (18,12), sharey = False)
l =["Upper-level Tropospheric Water Vapor (Band 8)", "Mid-level Tropospheric Water Vapor (Band 9)",
    "Lower-level Tropospheric Water Vapor (Band 10)", "Infrared Longwave Window (Band 14)"]
it = 0
histogram_store = []
for i in range(2):
    for j in range (2):
        ax = axs[i, j]
        ax.set_title("Deeplift Attribution for " + l[it], fontsize=16, pad=15 )
        ax.set_xlabel("Attribution Values", fontsize = 14)
        # NOTE(review): density=True plots a normalized density on a log axis, not raw counts
        ax.set_ylabel("Log Counts", fontsize=14)
        _ = ax.hist(attributions_channels[it].flatten(), bins = 200, density = True, range = (-0.2, 0.2), log=True)
        # hist returns (bin heights, bin edges, patches); keep the heights per band
        histogram_store.append(_[0])
        it += 1

plt.tight_layout()
plt.savefig("DL_attrbChanels")
plt.show()
    

# Do similar analysis using Integrated Gradients

In [None]:
# Attribute every sampled image with Integrated Gradients (50 steps), targeting
# the class the model predicted for it. This rebinds attributions_mult, so the
# DeepLift attributions computed earlier are replaced from here on.
integrated_gradients = IntegratedGradients(model)
attributions_mult = []
for idx, input1 in enumerate(all_images):
    input1 = torch.tensor(input1).unsqueeze(0)
    prediction_score, pred_label_idx = torch.topk(total_outputs[idx], 1)
    attributions_ig = integrated_gradients.attribute(input1, n_steps=50, target=pred_label_idx)
    # Squeeze out size-1 axes, then move the channel axis last for plotting
    channels_first = attributions_ig.squeeze().cpu().detach().numpy()
    attributions_mult.append(channels_first.transpose(1, 2, 0))
attributions_mult = np.array(attributions_mult)

In [None]:
# 2-D density histogram of pixel values (x) against Integrated Gradients attributions (y).
# samp_all (channels-last images) is reused from the DeepLift section.
plt.figure(figsize=(7.5, 5))
hi_ = plt.hist2d(
    samp_all.flatten(),
    attributions_mult.flatten(),
    bins=(50, 50),
    range=[[0, 1.0], [-0.6, 0.6]],
    density=True,
    cmap=plt.cm.jet,
    norm=matplotlib.colors.LogNorm(),
)
plt.title("Density Heatmap of Inputs and Attributions", fontsize=16, pad=15)
plt.xlabel("Pixel Values", fontsize=14)
plt.ylabel("Integrated Gradient Attribution", fontsize=14)
plt.colorbar()
plt.savefig("IG_InpvsAttrb")
plt.show()

In [None]:
# Integrated Gradients: inputs vs. attributions, one panel per true class
_ = attr_classviz(
    attributions=attributions_mult,
    Share=False,
    Range=[[0.0, 1.0], [-0.6, 0.6]],
)
_.savefig('IGact_class')


In [None]:
# Integrated Gradients: inputs vs. attributions, one panel per predicted class
_ = attr_classviz(
    attributions=attributions_mult,
    predictions=predicted_dict,
    Pred=True,
    size=len(predicted_dict),
    Share=False,
    Range=[[0.0, 1.0], [-0.6, 0.6]],
)
_.savefig("IGpred_class")

In [None]:
# Disabled: per-band image / heat-map visualization, kept for reference only (nothing in this cell runs).
# for k in ([27, 44, 81, 110]):
#     for i in range(4):
#         _ = viz.visualize_image_attr_multiple(attributions_mult[k,:,:, i:i+1],
#                                           samp_all[k,:, :, i:i+1],
#                                           ["original_image", "heat_map"],
#                                           ["all", "all"],
#                                           show_colorbar=True,
#                                             outlier_perc=1)


In [None]:
# Save one four-band Integrated Gradients figure per selected sample, named "IG<index>"
for i in [27, 34, 81, 110]:
    title = f"IG{i}"
    _ = attributions_visualization(
        image=samp_all[i],
        attributions=attributions_mult[i],
        method="Integrated Gradients",
    )
    _.savefig(title)

In [None]:
# Per-band histograms of the Integrated Gradients attributions.
# Move the channel axis first so attributions_channels[c] holds every attribution of band c
attributions_channels = attributions_mult.transpose(3, 1, 2, 0)
fig, axs = plt.subplots(2, 2, figsize = (18,12), sharey = False)
l =["Upper-level Tropospheric Water Vapor (Band 8)", "Mid-level Tropospheric Water Vapor (Band 9)",
    "Lower-level Tropospheric Water Vapor (Band 10)", "Infrared Longwave Window (Band 14)"]
it = 0
histogram_store = []
for i in range(2):
    for j in range (2):
        ax = axs[i, j]
        ax.set_title("Integrated Gradients Attribution for " + l[it], fontsize=16, pad=10 )
        ax.set_xlabel("Attribution Values", fontsize = 14)
        # NOTE(review): density=True plots a normalized density on a log axis, not raw counts
        ax.set_ylabel("Log Counts", fontsize=14)
        _ = ax.hist(attributions_channels[it].flatten(), bins = 200, density = True, range = (-0.6, 0.6), log=True)
        # hist returns (bin heights, bin edges, patches); keep the heights per band
        histogram_store.append(_[0])
        it += 1

plt.tight_layout()
plt.savefig("IG_attrbChanels")
plt.show()

### Repeat of previous analysis with the Baseline set to the average bin 0 image 
#### Implemented for Integrated Gradients

In [None]:
# Integrated Gradients with the channel-wise mean of the class-0 images
# (base_zero) as the baseline instead of the default one.
integrated_gradients = IntegratedGradients(model)
attributionsBase_mult = []

# enumerate keeps idx in step with input1, so each image is attributed against
# its own predicted class; a bare loop would reuse whatever idx an earlier cell left behind.
for idx, input1 in enumerate(all_images):
    input1 = torch.tensor(input1).unsqueeze(0)  # add a leading batch axis
    prediction_score, pred_label_idx = torch.topk(total_outputs[idx], 1)
    attributionsBase_ig = integrated_gradients.attribute(input1, target = pred_label_idx, n_steps = 250, baselines= base_zero.unsqueeze(0) )
    # Squeeze out size-1 axes, then move the channel axis last for plotting
    attributionsBase_mult.append(np.transpose(attributionsBase_ig.squeeze().cpu().detach().numpy(), (1,2,0)))
attributionsBase_mult = np.array(attributionsBase_mult)

In [None]:
# 2-D density histogram of pixel values against the baseline-adjusted IG attributions
hi_ = plt.hist2d(
    samp_all.flatten(),
    attributionsBase_mult.flatten(),
    bins=(50, 50),
    density=True,
    cmap=plt.cm.jet,
    norm=matplotlib.colors.LogNorm(),
)
plt.title("2D Histogram of Input Values vs. Attribution Values ")
plt.xlabel("Pixel Value")
plt.ylabel("Integrated Gradient Attribution")
plt.colorbar()
plt.show()


In [None]:
# Shape of the stacked baseline attributions: one channels-last attribution map per image
attributionsBase_mult.shape

In [None]:
# One 2-D histogram per band: pixel values of that band against its baseline-adjusted IG attributions
fig, axs = plt.subplots(2, 2, figsize=(18, 12), sharey=True)
l = [
    "Band 8 Attributions",
    "Band 9 Attributions",
    "Band 10 Attributions",
    "Band 14 Attributions",
]

it = 0
for i in range(2):
    for j in range(2):
        ax = axs[i, j]
        ax.set_title("2D Histogram of Input Values vs. " + l[it])
        ax.set_xlabel("Pixel Value")
        ax.set_ylabel("Integrated Gradient Attribution")
        # Last axis is the channel; `it` selects the band for this panel
        hi_ = ax.hist2d(
            samp_all[..., it].flatten(),
            attributionsBase_mult[..., it].flatten(),
            bins=(50, 50),
            density=True,
            cmap=plt.cm.jet,
            norm=matplotlib.colors.LogNorm(),
        )
        fig.colorbar(hi_[3], ax=ax)
        it += 1

plt.subplots_adjust(left=0.2, bottom=0.2, right=0.9, top=0.9, wspace=0.2, hspace=0.3)
# plt.savefig("bandwise2d.png")
plt.show()

In [None]:
# One 2-D histogram per true class: pixel values against baseline-adjusted IG attributions.
# Panel `it` uses the block of samples random_size*it : random_size*(it+1).
# NOTE(review): random_size is not assigned in this section, and attr_classviz slices
# the same way with test_classsize — confirm both names refer to the same block size.
fig, axs = plt.subplots(2, 2, figsize=(18, 12), sharey=True)
classifications = ['0 Bin', '1 Bin', '2 Bin', '3 Bin']

it = 0
for i in range(2):
    for j in range(2):
        ax = axs[i, j]
        rows = slice(random_size * it, random_size * (it + 1))
        ax.set_title("Input Values vs. Attributions for " + classifications[it])
        ax.set_xlabel("Pixel Value")
        ax.set_ylabel("Integrated Gradient Attributions")
        hi_ = ax.hist2d(
            samp_all[rows, :, :, :].flatten(),
            attributionsBase_mult[rows, :, :, :].flatten(),
            bins=(50, 50),
            density=True,
            cmap=plt.cm.jet,
            norm=matplotlib.colors.LogNorm(),
        )
        fig.colorbar(hi_[3], ax=ax)
        it += 1

plt.subplots_adjust(left=0.2, bottom=0.2, right=0.9, top=0.9, wspace=0.2, hspace=0.3)

plt.show()

In [None]:
# One 2-D histogram per predicted class: pixel values against baseline-adjusted IG attributions.
fig, axs = plt.subplots(2, 2, figsize = (18,12), sharey = True)
classifications = ['0 Bin', '1 Bin', '2 Bin', '3 Bin']

# Walk the class labels that are really present in predicted_dict (sorted), one
# per panel. Counting 0, 1, 2, ... would mislabel panels when the predicted
# classes are not contiguous and would insert empty entries into the defaultdict.
# zip stops at whichever runs out first, so unused panels stay empty.
for ax, it in zip(axs.flat, sorted(predicted_dict)):
    rows = predicted_dict[it]  # positions of the images predicted as class `it`
    ax.set_title("Input Values vs. Attributions for " +
                 classifications[it])
    ax.set_xlabel("Pixel Value")
    ax.set_ylabel("Integrated Gradient Attributions")
    hi_ = ax.hist2d(samp_all[rows, :, :, :].flatten(),
                    attributionsBase_mult[rows, :, :, :].flatten(),
                    density = True, bins= (50, 50), cmap = plt.cm.jet, norm = matplotlib.colors.LogNorm())
    fig.colorbar(hi_[3], ax = ax)

plt.subplots_adjust(left=0.2,
                    bottom=0.2,
                    right=0.9,
                    top=0.9,
                    wspace=0.2,
                    hspace=0.3)

plt.show()

In [None]:
# Per-band histograms of the baseline-adjusted Integrated Gradients attributions.
# Move the channel axis first so attributions_channels[c] holds every attribution of band c
attributions_channels = attributionsBase_mult.transpose(3, 1, 2, 0)
fig, axs = plt.subplots(2, 2, figsize = (18,12), sharey = True)
l =["Upper-level Tropospheric Water Vapor (Band 8)", "Mid-level Tropospheric Water Vapor (Band 9)",
    "Lower-level Tropospheric Water Vapor (Band 10)", "Infrared Longwave Window (Band 14)"]
it = 0
histogram_store = []
for i in range(2):
    for j in range (2):
        ax = axs[i, j]
        ax.set_title("Integrated Gradients for " + l[it] )
        ax.set_xlabel("Attribution Values")
        # NOTE(review): density=True plots a normalized density on a log axis, not raw counts
        ax.set_ylabel("Counts")
        # NOTE(review): 557 bins differs from the 200 used in the other band histograms — confirm it is intended
        _ = ax.hist(attributions_channels[it].flatten(), bins = 557, density = True, log=True, range=(-0.6, 0.6))
        # hist returns (bin heights, bin edges, patches); keep the heights per band
        histogram_store.append(_[0])
        it += 1

plt.subplots_adjust(left=0.2,
                    bottom=0.2,
                    right=0.9,
                    top=0.9,
                    wspace=0.2,
                    hspace=0.2)
plt.show()