In [6]:
import shap
import numpy as np
import torch
from CNN import Net
import pandas as pd
import math, random, os

In [8]:
# Retrieved saved model + re-open data train and test
is_cov = False
batch_size = 50

def load_graphs(input_dir, class_dict, is_cov) :

    data, data_labels = [], [] # data contains the graphs as tensors and data_labels the associated seizure type labels
    i = 0
    for szr_type in class_dict.keys() :
        szr_label = class_dict[szr_type]
        for _, _, files in os.walk(os.path.join(input_dir,szr_type)) :

            for npy_file in files :
                A = np.load(os.path.join(input_dir,szr_type,npy_file))
                # Normalise A (already normalised depending on the input)
                A = A/np.amax(A.flatten())

                if is_cov : 
                    L = torch.tensor(A).view(1,20,20)
                else : 
                    L = torch.tensor(np.diag(A*np.ones((A.shape[0],1)))-A).view(1,20,20)

                data.append(L)
                data_labels.append(szr_label)

    return np.array(data, dtype=object), np.array(data_labels)

def train_test_data(input_dir, class_dict, is_cov) :

    train, train_labels = load_graphs(os.path.join(input_dir,'train'), class_dict, is_cov)
    test, test_labels = load_graphs(os.path.join(input_dir,'dev'), class_dict, is_cov)

    return train, test, train_labels, test_labels

def to_set(train, test, train_labels, test_labels) :

    # Oversampling (train set only) to have balanced classification without dropping information
    PD = pd.DataFrame(train_labels, columns=['label'])
    no_0, no_1 = len(PD[PD['label']==0]), len(PD[PD['label']==1])
    R = math.floor(no_0/no_1) # Multiply the dataset by this ratio, then add (no_0 - R*no_1) randomly selected entries from the smallest dataset

    trainset, testset = [], []
    for i in range(len(train)) :
        if train_labels[i] == 1 : # Under-represented class :
            # The dataloader later shuffles the data
            for r in range(R) :
                trainset.append((train[i],train_labels[i]))
        else :
            trainset.append((train[i],train_labels[i]))
    
    # Compensate the remaining imbalance => draw (no_0 - R*no_1) elements from already present elements
    Add = random.sample(PD[PD['label']==1].index.to_list(),no_0 - R*no_1)
    for idx in Add :
        trainset.append((train[idx],train_labels[idx]))

    for j in range(len(test)) :
        testset.append((test[j],test_labels[j]))

    for j in range(len(test)) :
        testset.append((test[j],test_labels[j]))

    return trainset, testset

In [9]:
classes = ['FNSZ','GNSZ']

class_dict = {}
for i, szr_type in enumerate(classes) :
    class_dict[szr_type] = i

# Retrieve trained CNN
model = torch.load('../classifier/low_lapl_50_CNN.pt')
model.eval()

input_dir = '../data/v1.5.2/graph_lapl_nolow'

train, test, train_labels, test_labels = train_test_data(input_dir, class_dict, is_cov)

trainset, testset = to_set(train, test, train_labels, test_labels)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

The input object of type 'Tensor' is an array-like implementing one of the corresponding protocols (`__array__`, `__array_interface__` or `__array_struct__`); but not a sequence (or 0-D). In the future, this object will be coerced as if it was first converted using `np.array(obj)`. To retain the old behaviour, you have to either modify the type 'Tensor', or assign to an empty array created with `np.empty(correct_shape, dtype=object)`.


In [28]:
print(test_images.size())

torch.Size([2, 1, 20, 20])


In [27]:
batch = next(iter(testloader))
images, _ = batch

background = images[:25].float()
test_images = images[20:22].float()

e = shap.DeepExplainer(model, background)
shap_values = e.shap_values(test_images)

"""
CNN.float()
X, y = X.float(), y.type(torch.LongTensor)
"""

RuntimeError: The size of tensor a (9) must match the size of tensor b (3) at non-singleton dimension 3

In [24]:
print(background.size())

torch.Size([20, 1, 20, 20])


In [12]:
shap_numpy = [np.swapaxes(np.swapaxes(s, 1, -1), 1, 2) for s in shap_values]
test_numpy = np.swapaxes(np.swapaxes(test_images.numpy(), 1, -1), 1, 2)

In [13]:
# plot the feature attributions
shap.image_plot(shap_numpy, -test_numpy)

ValueError: Number of rows must be a positive integer, not 0

<Figure size 648x180 with 0 Axes>