### SHAPLEY BASED MODEL INTERPRETATION

In [124]:
import shap
import matplotlib.pyplot as plt
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
import numpy as np
import torch.nn.functional as F
import scipy
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import StratifiedKFold as skf

In [125]:
## Model
import torch.nn as nn
from torch.autograd import Function
import torch.nn.functional as F


class ReverseLayerF(Function):

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha

        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        output = grad_output.neg() * ctx.alpha

        return output, None

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super(ResidualBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.activation1 = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride= 1, padding= 2)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.activation2 = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, stride= 1, padding= 2)
    
    def forward(self,x):
        residual  = torch.clone(x)
        x = self.bn1(x)
        x = self.activation1(x)
        x = self.conv1(x)
        x = self.conv2(self.activation2(self.bn2(x)))
        residual = residual.unsqueeze(0)
        residual = nn.functional.interpolate(residual, size = [x.shape[1], x.shape[2], x.shape[3]])
        residual = residual.squeeze(0)
        x = x.clone()  # Ensure that `x` is not a view or a shared tensor
        x += residual
        return x
    

class EEGNet(nn.Module):
    def __init__(self):
        super(EEGNet, self).__init__()
        self.T = 384
        
        # Layer 1
        self.conv1 = nn.Conv2d(1, 16, (1, 32), padding = 0)
        self.batchnorm1 = nn.BatchNorm2d(16, False)
        
        # Layer 2
        self.padding1 = nn.ZeroPad2d((16, 17, 0, 1))
        self.conv2 = nn.Conv2d(1, 4, (2, 32))
        self.batchnorm2 = nn.BatchNorm2d(4, False)
        self.pooling2 = nn.MaxPool2d(2, 4)
        
        # Layer 3
        self.padding2 = nn.ZeroPad2d((2, 1, 4, 3))
        self.conv3 = nn.Conv2d(4, 4, (8, 4))
        self.batchnorm3 = nn.BatchNorm2d(4, False)
        self.pooling3 = nn.MaxPool2d((2, 4))

    def forward(self, x):
        # Layer 1
        x = F.elu(self.conv1(x))
        x = self.batchnorm1(x)
        x = F.dropout(x, 0.25)
        x = x.permute(0, 3, 1, 2)

        # Layer 2
        x = self.padding1(x)
        x = F.elu(self.conv2(x))
        x = self.batchnorm2(x)
        x = F.dropout(x, 0.25)
        x = self.pooling2(x)
        
        # Layer 3
        x = self.padding2(x)
        x = F.elu(self.conv3(x))
        x = self.batchnorm3(x)
        x = F.dropout(x, 0.25)
        x = self.pooling3(x)
        # FC Layer
        x = x.reshape(-1, 4*2*24)
        return x
    
    
class CNNModel(nn.Module):

    def __init__(self):
        super(CNNModel, self).__init__()
        self.feature_image = nn.Sequential()
        self.feature_image.add_module('f_conv1', nn.Conv2d(in_channels = 5, out_channels = 32, kernel_size= 3, stride= 1, padding= 2))
        self.feature_image.add_module('f_resblock1', ResidualBlock(32,32,5))
        self.feature_image.add_module('f_resblock2', ResidualBlock(32,64,5))
        self.feature_image.add_module('f_resblock3', ResidualBlock(64,128,5))
        self.feature_image.add_module('f_adaptiveavgpool', nn.AdaptiveAvgPool2d((15,15)))

        self.feature_wave = nn.Sequential()
        self.feature_wave.add_module('f_EEGNet',EEGNet())

        self.class_classifier = nn.Sequential()
        self.class_classifier.add_module('c_fc1', nn.Linear(15 * 15 * 128 + 4*2*24, 1024))
        self.class_classifier.add_module('c_bn1', nn.BatchNorm1d(1024))
        self.class_classifier.add_module('c_relu1', nn.ReLU(True))
        self.class_classifier.add_module('c_drop1', nn.Dropout(0.2))
        self.class_classifier.add_module('c_fc2', nn.Linear(1024, 512))
        self.class_classifier.add_module('c_bn2', nn.BatchNorm1d(512))
        self.class_classifier.add_module('c_relu2', nn.ReLU(True))
        self.class_classifier.add_module('c_fc3', nn.Linear(512, 2))
        self.class_classifier.add_module('c_softmax', nn.Softmax(dim = 1))

        self.domain_classifier = nn.Sequential()
        self.domain_classifier.add_module('d_fc1', nn.Linear(15 * 15 * 128 + 4*2*24, 1024))
        self.domain_classifier.add_module('d_bn1', nn.BatchNorm1d(1024))
        self.domain_classifier.add_module('d_relu1', nn.ReLU(True))
        self.domain_classifier.add_module('d_fc2', nn.Linear(1024, 2))
        self.domain_classifier.add_module('d_softmax', nn.Softmax(dim = 1))
    

    def forward(self, input_data, input_wave):
        alpha = 0
        feature1 = self.feature_image(input_data)
        feature2 = self.feature_wave(input_wave)
        feature1 = feature1.view(-1, 15 * 15 * 128)
        feature = torch.cat((feature1,feature2),1)
        reverse_feature = ReverseLayerF.apply(feature, alpha)
        class_output = self.class_classifier(feature)
        domain_output = self.domain_classifier(reverse_feature)
        return class_output


In [126]:
## Load data
class Dataset(Dataset):
    def __init__(self, data_img,data_wave,info):
        #data loading
        self.x1 = data_img
        self.x2 = data_wave
        self.y = info
        self.n_samples = data_img.shape[0]


    def __getitem__(self,index):
        t1 = self.x1[index]
        t3 = self.x2[index]
        t2 = self.y[index]
        t1 = torch.tensor(t1)
        t1 = t1.permute((2,0,1))
        t2 = torch.tensor(t2)
        t3 = torch.tensor(t3)
        return (t1,t3,t2)
    
    def __len__(self):
        return self.n_samples
    
data = scipy.io.loadmat('/home/desktop/Desktop/22104412_Docs/EEG-COGMusic/DA-AFNet/datasets/Zscore_clipped/s15_datasets_Zscore_clipped.mat')

x_c=data['coh']
x_p = data['pli']
x_d = data['psd']
labels_skf = data['labels_kfold']
labels = data['valence']
de = data['EEGNet']

x_n = np.zeros((40,75,32,32,5))
for trial in range(x_n.shape[0]):
    for sample in range(x_n.shape[1]):
        x_n[trial,sample,:,:,0] = x_c[trial,sample,:,:,0]+np.transpose(x_p[trial,sample,:,:,0])
        x_n[trial,sample,:,:,1] = x_c[trial,sample,:,:,1]+np.transpose(x_p[trial,sample,:,:,1])
        x_n[trial,sample,:,:,2] = x_c[trial,sample,:,:,2]+np.transpose(x_p[trial,sample,:,:,2])
        x_n[trial,sample,:,:,3] = x_c[trial,sample,:,:,3]+np.transpose(x_p[trial,sample,:,:,3])
        x_n[trial,sample,:,:,4] = x_c[trial,sample,:,:,4]+np.transpose(x_p[trial,sample,:,:,4])
    print(f'Completed trial {trial}')

for trial in range(x_n.shape[0]):
    for sample in range(x_n.shape[1]):
        for i in range(32):
            x_n[trial,sample,i,i,:] = x_d[trial,sample,i,:]
    print(f'Completed trial {trial}')

Completed trial 0
Completed trial 1
Completed trial 2
Completed trial 3
Completed trial 4
Completed trial 5
Completed trial 6
Completed trial 7
Completed trial 8
Completed trial 9
Completed trial 10
Completed trial 11
Completed trial 12
Completed trial 13
Completed trial 14
Completed trial 15
Completed trial 16
Completed trial 17
Completed trial 18
Completed trial 19
Completed trial 20
Completed trial 21
Completed trial 22
Completed trial 23
Completed trial 24
Completed trial 25
Completed trial 26
Completed trial 27
Completed trial 28
Completed trial 29
Completed trial 30
Completed trial 31
Completed trial 32
Completed trial 33
Completed trial 34
Completed trial 35
Completed trial 36
Completed trial 37
Completed trial 38
Completed trial 39
Completed trial 0
Completed trial 1
Completed trial 2
Completed trial 3
Completed trial 4
Completed trial 5
Completed trial 6
Completed trial 7
Completed trial 8
Completed trial 9
Completed trial 10
Completed trial 11
Completed trial 12
Completed tri

In [127]:
v = labels_skf[:,0]
indices = np.where((v>5.5)|(v<4.5))[0]
indices.shape

(35,)

In [128]:
dt = x_n[indices]
labels = labels[indices]

In [129]:
l_skf = np.zeros([40,1])
l_skf[np.where(labels_skf[:,0]>5)[0]] = 1
l_skf = l_skf[indices]

In [130]:
de = de.transpose((0,1,3,2))
de = de[indices,:,np.newaxis,:,:]
de.shape

(35, 75, 1, 384, 32)

In [131]:
kf = skf(n_splits = 10)
log_pred_dict = {}
best_Acc = 0
for k,(train_index,test_index) in enumerate(kf.split(dt, l_skf)):
    if k == 4:
        print(f'Fold {k+1} running')
        deTr,deV = np.concatenate(de[train_index],0), np.concatenate(de[test_index],0)
        dataTr, dataV = np.concatenate(dt[train_index],0), np.concatenate(dt[test_index],0)
        labelsTr, labelsV = np.concatenate(labels[train_index],0), np.concatenate(labels[test_index],0)
        ## parameters
        bs = 300
        image_size = 32
        n_epoch = 50
    
        testDS = Dataset(dataV,deV,labelsV)
        testDL = DataLoader(dataset = testDS, batch_size = bs, shuffle=True)

Fold 5 running


In [132]:
## Load Model and test data loader
model = CNNModel()
model.load_state_dict(torch.load('/home/desktop/Desktop/22104412_Docs/EEG-COGMusic/DA-AFNet/models/model_best_49_4_92.33333333333333.pth'))

<All keys matched successfully>

In [133]:
# since shuffle=True, this is a random sample of test data
batch = next(iter(testDL))
images,waves, labels = batch

In [134]:
import os
import torch.backends.cudnn as cudnn
import torch.utils.data


def test(dataset_loader,epoch,model,name):
    cuda = True
    cudnn.benchmark = True
    image_size = 32
    alpha = 0

    dataloader = dataset_loader

    """ training """

    my_net = model
    my_net = my_net.eval()

    if cuda:
        my_net = my_net.cuda()

    len_dataloader = len(dataloader)
    data_target_iter = iter(dataloader)

    i = 0
    n_total = 0
    n_correct = 0

    while i < len_dataloader:

        # test model using target data
        data_target = next(data_target_iter)
        test_img,test_wave, test_label = data_target

        batch_size = len(test_label)

        if cuda:
            test_img = test_img.cuda()
            test_wave = test_wave.cuda()
            test_label = test_label.cuda()

        class_output = my_net(input_data=test_img.float(), input_wave = test_wave.float())
        _,pred = torch.max(class_output, dim=1)
        _,gt = torch.max(test_label,dim=1)
        n_correct += (pred == gt).sum().item()
        n_total += batch_size

        i += 1

    accu = 100*n_correct/ n_total
    print(f'Epoch {epoch} for {name} dataset is {accu}')
    return accu

test(testDL, 100, model, 'target')

Epoch 100 for target dataset is 51.666666666666664


51.666666666666664

In [135]:
images.shape, waves.shape, labels.shape

(torch.Size([300, 5, 32, 32]),
 torch.Size([300, 1, 384, 32]),
 torch.Size([300, 2]))

In [136]:
l = np.where(labels == 1)[1]
np.where(l ==  1)[0]

array([  0,   5,   6,   7,   9,  11,  17,  18,  19,  20,  21,  22,  23,
        24,  25,  26,  29,  30,  32,  34,  36,  37,  42,  43,  45,  47,
        49,  53,  56,  57,  58,  59,  62,  63,  66,  68,  74,  82,  86,
        89,  90,  97,  98,  99, 101, 106, 108, 111, 112, 114, 116, 119,
       120, 121, 122, 123, 125, 128, 130, 131, 132, 134, 135, 136, 140,
       141, 142, 144, 147, 151, 152, 155, 156, 159, 161, 164, 166, 168,
       169, 170, 171, 172, 175, 176, 177, 178, 179, 180, 181, 182, 183,
       184, 185, 186, 188, 189, 190, 191, 192, 194, 195, 196, 197, 198,
       199, 200, 201, 203, 206, 207, 209, 210, 217, 219, 223, 227, 229,
       230, 231, 234, 236, 239, 240, 242, 243, 244, 245, 246, 248, 249,
       252, 256, 260, 263, 265, 266, 269, 270, 275, 276, 278, 282, 283,
       287, 290, 291, 293, 294, 296, 298])

In [137]:
background_images = images
test_images = images
background_waves = waves
test_waves = waves

# If the model outputs a tuple and you want to use the first output for SHAP
explainer = shap.GradientExplainer(
    model, 
    [background_images.float().to('cuda'), background_waves.float().to('cuda')] # This will specify the first output if the model returns a tuple
)


In [138]:
shap_values = explainer.shap_values([test_images.float().to('cuda'), test_waves.float().to('cuda')])

In [139]:
shap_values[0].shape, shap_values[1].shape

((300, 5, 32, 32, 2), (300, 1, 384, 32, 2))

In [140]:
background_images.shape, background_waves.shape

(torch.Size([300, 5, 32, 32]), torch.Size([300, 1, 384, 32]))

In [141]:
# import matplotlib.pyplot as plt


# shap.image_plot(shap_values[0][0, 0, :, :, :], test_images[0, 0, :, :].numpy(), cmap='viridis')
# shap.image_plot(shap_values[0][0, 1, :, :, :], test_images[0, 1, :, :].numpy(), cmap='viridis')
# shap.image_plot(shap_values[0][0, 2, :, :, :], test_images[0, 2, :, :].numpy(), cmap='viridis')
# shap.image_plot(shap_values[0][0, 3, :, :, :], test_images[0, 3, :, :].numpy(), cmap='viridis')
# shap.image_plot(shap_values[0][0, 4, :, :, :], test_images[0, 4, :, :].numpy(), cmap='viridis')  


# shap.image_plot(
#     shap_values[1][0, 0, 0:128, :, :].transpose((1, 0, 2)), 
#     test_waves[0, 0, 0:128, :].numpy().transpose((1, 0)), 
#     cmap='viridis'
# )

In [142]:
## 5 channel wise features

shap_values_images = shap_values[0]
shap_values_images = shap_values_images.reshape((300,5,-1,2))
shap_values_images = np.sum(shap_values_images, axis = 2)
shap_values_images.shape


(300, 5, 2)

In [143]:
## 1 spatio-temporal feature

shap_values_waves = shap_values[1]
shap_values_waves = shap_values_waves.reshape((300,1,-1,2))
shap_values_waves = np.sum(shap_values_waves, axis = 2)
shap_values_waves.shape

(300, 1, 2)

In [144]:
shap_values = np.concatenate((shap_values_images, shap_values_waves), axis = 1)
shap_values.shape

(300, 6, 2)

In [145]:
## reshape inputs for the shap plots

test_images = test_images.reshape((300,5,-1))
test_waves = test_waves.reshape((300,1,-1))
test_images.shape, test_waves.shape

(torch.Size([300, 5, 1024]), torch.Size([300, 1, 12288]))

In [146]:
test_images = np.sum(test_images.numpy(), axis = 2)
test_waves = np.sum(test_waves.numpy(), axis = 2)

test_images.shape, test_waves.shape

((300, 5), (300, 1))

In [147]:
X_test = np.concatenate((test_images, test_waves), axis = 1)
X_test.shape

(300, 6)

In [151]:
plt.figure()
shap.summary_plot(shap_values[:,:,0], X_test, feature_names=["Theta band", "Alpha band", "Beta_low band", "Beta_high band", "Gamma band", "Raw_EEG"],show=False)
plt.title("SHAP summary plot for low valence class")
plt.savefig("c0.pdf", format="pdf", bbox_inches="tight")

In [153]:
plt.figure()
shap.summary_plot(shap_values[:,:,1], X_test, feature_names=["Theta band", "Alpha band", "Beta_low band", "Beta_high band", "Gamma band", "Raw_EEG"], show=False)
plt.title("SHAP summary plot for high valence class")
plt.savefig("c1.pdf", format="pdf", bbox_inches="tight")