In [None]:
# Mount Google Drive so that the files under /content/drive (dataset cache,
# model checkpoint, saved figures) are reachable from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import numpy as np
import torch
import torchvision
import torch.nn as nn
import matplotlib.pyplot as plt
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import torch.nn.functional as F

# Seed both random number generators the notebook relies on (NumPy and
# PyTorch) so that a fresh "Restart & Run All" reproduces the same results.
np.random.seed(1)
torch.manual_seed(1)

In [None]:
# Notebook-wide settings.
NUMBER_OF_CLASSES = 10  # number of classes; sizes the per-class attribution arrays below
EPOCHS = 20  # not referenced in the visible cells -- presumably left over from training; TODO confirm
BATCH_SIZE = 128  # batch size handed to both DataLoaders
# VALIDATION_SPLIT = 0.2

In [None]:
# Load the Fashion-MNIST dataset (downloaded into the Drive folder on first use).
working_dir = '/content/drive/MyDrive/EECS 545 project/'
# Compose applies its transforms one after another, so they have to be kept
# in an ordered list; a set literal has no defined order as soon as a second
# transform (e.g. the Normalize below) is enabled.
transform_list = [
    torchvision.transforms.ToTensor(),
    # torchvision.transforms.Normalize(mean=(0), std=(1.0)),
]
transforms = torchvision.transforms.Compose(transform_list)
train_set = torchvision.datasets.FashionMNIST(train=True, root=working_dir + 'Fashion_MNIST', download = True, transform=transforms)
test_set = torchvision.datasets.FashionMNIST(train=False, root=working_dir + 'Fashion_MNIST', download = True, transform=transforms)

In [None]:
# Batched iterators over both splits. Shuffling is disabled, so the samples
# are served in dataset order on every pass.
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
)
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [None]:
class SimpleCNN(nn.Module):
    """Small CNN: two strided 4x4 convolutions followed by two linear layers.

    The first linear layer expects 1600 flattened features, which is what the
    two convolutions produce for a single-channel 28x28 input (64 x 5 x 5).

    Args:
        num_classes: Number of output scores produced per sample. Defaults to
            10, so ``SimpleCNN()`` builds the same network as before and
            existing checkpoints keep loading.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(4, 4), stride=(2, 2))
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(4, 4), stride=(2, 2))
        self.drop1 = nn.Dropout(p=0.25)
        self.lin1 = nn.Linear(1600, 128)
        self.drop2 = nn.Dropout(p=0.5)
        self.output = nn.Linear(128, num_classes)

    def forward(self, input_batch):
        """Return the raw class scores (no softmax is applied) for a batch.

        Args:
            input_batch: Image batch; the layer sizes above fit inputs of
                shape (batch, 1, 28, 28).

        Returns:
            Tensor with one row per sample and ``num_classes`` columns.
        """
        x = F.relu(self.conv1(input_batch))
        x = F.relu(self.conv2(x))
        x = self.drop1(x)
        # Collapse (channels, height, width) into one feature axis for lin1.
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.lin1(x))
        x = self.drop2(x)
        x = self.output(x)
        return x

In [None]:
# Rebuild the network on the CPU and restore the trained weights.
device = torch.device("cpu")
cnn_mnist= SimpleCNN().to(device)
# map_location makes the checkpoint tensors land on `device` regardless of
# the device they were saved from, so loading also works on a CPU-only runtime.
cnn_mnist.load_state_dict(torch.load(working_dir + "cnn_fashion_mnist_model_cpu.pth", map_location=device))
# Switch to evaluation mode so the dropout layers are inactive during the
# forward passes below.
cnn_mnist.eval()

# DeepSHAP Implementation

In [None]:
# Marker cell: the hand-written DeepSHAP computation starts below.
"""
Begin
"""

In [None]:
# Trainable layers of the network in forward order (the dropout layers are
# left out).
layers = [getattr(cnn_mnist, layer_name) for layer_name in ("conv1", "conv2", "lin1", "output")]
num_layers = len(layers)

In [None]:
# Pick a sample from the test set
print(test_loader.dataset.class_to_idx)
idx = [19, 2, 1, 13, 6, 8, 4, 9, 18, 0]  # Example indices

# Preview every candidate sample on a 3x4 grid.
fig, axs = plt.subplots(3, 4)
for ax, sample_index in zip(axs.flat, idx):
    sample_image, _ = test_loader.dataset[sample_index]
    ax.imshow(sample_image.squeeze(), cmap='gray', vmin=0, vmax=1)
# The grid has more cells than there are candidates; hide the unused ones
# instead of leaving empty framed axes in the figure.
for ax in axs.flat[len(idx):]:
    ax.axis("off")

# The first candidate is the sample that gets explained.
index = idx[0]
features, label = test_loader.dataset[index]
features = features.unsqueeze(0)  # add a leading batch dimension for the model
# Plain inference: no autograd graph is needed for the prediction.
with torch.no_grad():
    pre_label = cnn_mnist(features).argmax(dim=1)
print("Predict label: ", pre_label.item(), "True label: ", label)

fig, axs = plt.subplots(1,1)
axs.imshow(features.reshape(28,28), cmap='gray', vmin=0, vmax=1)

In [None]:
# Choose the background (reference) input for the attribution.
# The first 1000 raw test images get a channel axis and are divided by 255,
# which puts them on the same scale as the ToTensor-transformed samples.
reference_data = test_loader.dataset.data[:1000].unsqueeze(1) / 255.0

# Two ways of turning the background set into reference activations:
#   method 1 (not used): propagate every background image through the network
#       and average the activations of each layer afterwards;
#   method 2 (used here): average the images first and propagate that single
#       mean image. Because of the ReLUs the two are not equivalent in general.

## method 2
ref = reference_data.mean(dim=0, keepdim=True)

## Conv1: pre-activation and ReLU output of the reference
conv1_ref_pre = cnn_mnist.conv1(ref)
conv1_ref = torch.relu(conv1_ref_pre)

## Conv2 (the flattened copy is what lin1 consumes)
conv2_ref_pre = cnn_mnist.conv2(conv1_ref)
conv2_ref = torch.relu(conv2_ref_pre)
conv2_ref_flat = conv2_ref.flatten(start_dim=1)

## lin1
lin1_ref_pre = cnn_mnist.lin1(conv2_ref_flat)
lin1_ref = torch.relu(lin1_ref_pre)

## lin2 (output layer)
output_ref = cnn_mnist.output(lin1_ref)

# Show the mean background image.
fig, axs = plt.subplots(1,1)
axs.imshow(ref.reshape(28,28), cmap='RdBu', vmin = -1, vmax=1)

In [None]:
## Forward pass of the explained sample, layer by layer. Next to every
## activation, its difference to the matching reference activation is stored.

## Input
x_delta = features - ref

## Conv1
conv1_pre = cnn_mnist.conv1(features)
conv1_data = torch.relu(conv1_pre)
delta_conv1_pre = conv1_pre - conv1_ref_pre
delta_conv1_data = conv1_data - conv1_ref

## Conv2 (the flattened copy is what lin1 consumes)
conv2_pre = cnn_mnist.conv2(conv1_data)
conv2_data = torch.relu(conv2_pre)
conv2_flat = conv2_data.flatten(start_dim=1)
delta_conv2_pre = conv2_pre - conv2_ref_pre
delta_conv2_data = conv2_data - conv2_ref
delta_conv2_flat = conv2_flat - conv2_ref_flat

## lin1
lin1_pre = cnn_mnist.lin1(conv2_flat)
lin1_data = torch.relu(lin1_pre)
delta_lin1_pre = lin1_pre - lin1_ref_pre
delta_lin1_data = lin1_data - lin1_ref

## lin2 (output layer)
output_data = cnn_mnist.output(lin1_data)
delta_output_data = output_data - output_ref


In [None]:
eps = 1e-4  # |delta| below this threshold is treated as "no change" by the rescale rule


def _relu_rescale(pre_act, delta_pre, delta_post, upstream):
    """Propagate multipliers backwards through one ReLU (DeepLIFT rescale rule).

    Args:
        pre_act: ReLU input of the explained sample.
        delta_pre: ReLU input of the sample minus ReLU input of the reference.
        delta_post: ReLU output of the sample minus ReLU output of the reference.
        upstream: Multipliers with respect to the ReLU output; same shape as
            the other arguments.

    Returns:
        Multipliers with respect to the ReLU input, i.e.
        ``upstream * delta_post / delta_pre``. Where ``|delta_pre| < eps`` the
        ratio is numerically unstable, so the ordinary ReLU gradient evaluated
        at the sample's own pre-activation is used there instead.
    """
    pre_act = pre_act.detach()
    delta_pre = delta_pre.detach()
    delta_post = delta_post.detach()
    near_zero = delta_pre.abs() < eps
    # Put a harmless 1 in the denominator wherever the ratio is not used, so
    # the discarded branch of torch.where can never produce inf/nan.
    safe_delta = torch.where(near_zero, torch.ones_like(delta_pre), delta_pre)
    # The gradient has to be taken at the actual input of the ReLU; the sign
    # of a delta that is (almost) zero carries no information.
    relu_grad = (pre_act > 0).to(delta_pre.dtype)
    return upstream * torch.where(near_zero, relu_grad, delta_post / safe_delta)


def _affine_vjp(affine_fn, at, upstream):
    """Return ``upstream @ J`` where ``J`` is the Jacobian of ``affine_fn`` at ``at``.

    One vector-Jacobian product replaces building ``J`` row by row with a
    separate backward pass per output unit. ``torch.autograd.grad`` only
    returns the gradient for ``probe``; it does not accumulate ``.grad`` on
    the model parameters.

    Args:
        affine_fn: Callable applying the layer (conv / linear) to a tensor.
        at: Tensor giving the shape (and point) at which to differentiate.
        upstream: Multipliers with respect to the layer output; must have the
            shape of ``affine_fn(at)``.

    Returns:
        Multipliers with respect to the layer input, shaped like ``at``.
    """
    probe = at.detach().clone().requires_grad_(True)
    (vjp,) = torch.autograd.grad(affine_fn(probe), probe, grad_outputs=upstream)
    return vjp


def _flatten_then_lin1(conv_features):
    """Apply the flatten step followed by lin1, as in the network's forward pass."""
    return cnn_mnist.lin1(torch.flatten(conv_features, start_dim=1))


## Calculating SHAP values
# One map of input multipliers per class; multiplied by (input - reference)
# they give the per-pixel contributions.
shap_values_mul = np.zeros((NUMBER_OF_CLASSES, features.shape[2], features.shape[3]))
# Backward pass
for clas in range(NUMBER_OF_CLASSES):
    # output layer: gradient of the selected class score w.r.t. the lin1 activations
    probe = delta_lin1_data.detach().clone().requires_grad_(True)
    (delta_lin1_data_grad,) = torch.autograd.grad(cnn_mnist.output(probe)[0, clas], probe)

    # lin1 layer: through its ReLU, then through flatten + linear map
    delta_lin1_pre_grad = _relu_rescale(lin1_pre, delta_lin1_pre, delta_lin1_data, delta_lin1_data_grad)
    delta_conv2_data_grad = _affine_vjp(_flatten_then_lin1, delta_conv2_data, delta_lin1_pre_grad)

    # conv2 layer
    delta_conv2_pre_grad = _relu_rescale(conv2_pre, delta_conv2_pre, delta_conv2_data, delta_conv2_data_grad)
    delta_conv1_data_grad = _affine_vjp(cnn_mnist.conv2, delta_conv1_data, delta_conv2_pre_grad)

    # conv1 layer
    delta_conv1_pre_grad = _relu_rescale(conv1_pre, delta_conv1_pre, delta_conv1_data, delta_conv1_data_grad)
    x_delta_grad = _affine_vjp(cnn_mnist.conv1, x_delta, delta_conv1_pre_grad)

    # Drop the batch and channel axes to store an (H, W) multiplier map.
    shap_values_mul[clas, :, :] = x_delta_grad.squeeze().numpy()


In [None]:
# Per-pixel contribution = multiplier * (input - reference); stored with one
# row per pixel and one column per class, i.e. shape (784, 10).
pixel_delta = x_delta.numpy().reshape(1, 784)
class_multipliers = shap_values_mul.reshape(10, 784)
explaination = (class_multipliers * pixel_delta).T
# explaination = (shap_values_mul.reshape(10, 784) * features.numpy().reshape(1, 784)).T  # shape (784, 10)

# Sanity check: the contributions summed over all pixels (first print) can be
# compared with the change of the network output w.r.t. the reference (second).
print(explaination.sum(axis=0))
print(delta_output_data.squeeze().detach().numpy())

In [None]:
# Normalization: centre every pixel's attributions across the classes by
# subtracting that pixel's mean over the class axis.
explaination = explaination - explaination.mean(axis=1, keepdims=True)

In [None]:
# # Masking
# for clas in range(NUMBER_OF_CLASSES):
#     sort_idx = np.argsort(abs(explaination[:, clas]))
#     mask = sort_idx[0:round(explaination.shape[0] * 0.85)]
#     explaination[mask, clas] = 0

In [None]:
# Figure with three rows:
#   row 0: the input, the reference image and the attribution map of the true class
#   row 1: the attribution map of every class in number_compare_list
#   row 2: the input with the 20% lowest-attribution pixels of that class blanked
number_compare_list = [1, 2, 3, 4, 6]
name_of_class = ['Trouser', 'Pullover', 'Dress', 'Coat', 'Shirt']
fig, axs = plt.subplots(3, len(number_compare_list))
number = label
input_img = features.numpy().reshape(1, 784)
expected_X = ref.numpy().reshape(1, 784)
# One symmetric colour range shared by all attribution maps keeps them comparable.
limit_common = abs(explaination).max()
heat_style = dict(cmap='RdBu', vmin=-limit_common, vmax=limit_common)
gray_style = dict(cmap='gray', vmin=0, vmax=1)

axs[0, 0].imshow(input_img.reshape(28,28), **gray_style)
axs[0, 1].imshow(expected_X.reshape(28,28), **gray_style)
axs[0, 2].imshow(explaination[:, number].reshape(28,28), **heat_style)
for top_ax in axs[0, :]:
    top_ax.axis("off")

pixels_to_mask = round(explaination.shape[0] * 0.2)
for column, (number_compare, class_name) in enumerate(zip(number_compare_list, name_of_class)):
    class_attribution = explaination[:, number_compare]
    axs[1, column].imshow(class_attribution.reshape(28,28), **heat_style)
    axs[1, column].axis("off")

    # argsort is ascending, so the leading entries are the lowest attributions.
    mask = np.argsort(class_attribution)[:pixels_to_mask]
    input_img_masked = input_img.copy()
    input_img_masked[0, mask] = 0
    axs[2, column].imshow(input_img_masked.reshape(28,28), **gray_style)
    axs[2, column].text(6, 35, class_name)
    axs[2, column].axis("off")

plt.savefig(working_dir + "fashion-mnist-cnn-relu.pdf")


# fig, axs = plt.subplots(1,4)
# input_img = features.numpy().reshape(1, 784)
# number = label
# number_compare = 7
# # mean = explaination.mean()
# # axs[0].imshow(explaination[:, number].reshape(28,28), cmap='gray', vmin = explaination[:, number].min(), vmax=explaination[:, number].max())
# axs[0].imshow(input_img.reshape(28,28), cmap='RdBu', vmin = -1, vmax=1)
# limit_common = max(abs(explaination[:, number]).max(), abs(explaination[:, number_compare]).max())
# limit = abs(explaination[:, number]).max()
# axs[1].imshow(explaination[:, number].reshape(28,28), cmap='RdBu', vmin = -limit_common, vmax=limit_common)
# limit = abs(explaination[:, number_compare]).max()
# axs[2].imshow(explaination[:, number_compare].reshape(28,28), cmap='RdBu', vmin = -limit_common, vmax=limit_common)
# sort_idx = np.argsort(explaination[:, number_compare])
# mask = sort_idx[0:round(explaination.shape[0] * 0.2)]
# # mask = (explaination[:, number_compare] < 0)
# input_img_masked = input_img.copy()
# input_img_masked[0, mask] = 0
# axs[3].imshow(input_img_masked.reshape(28,28), cmap='RdBu', vmin = -1, vmax=1)


In [None]:
# Value range of the (normalised) attributions over all pixels and classes.
print(explaination.max())
print(explaination.min())

In [None]:
%pip install shap
import shap

In [None]:
# One (1, 28, 28, 1) attribution array per class -- the list layout that is
# handed to shap.image_plot.
shap_values = [
    explaination[:, class_index].reshape(1, 28, 28, 1)
    for class_index in range(NUMBER_OF_CLASSES)
]

In [None]:
# Draw the per-class attribution maps next to the explained input image.
shap.image_plot(shap_values, input_img.reshape(1, 28, 28, 1))

In [None]:
# print(explaination[:, number])

In [None]:
# Raw network scores of the explained sample (one value per class).
print(output_data)

In [None]:
# Marker cell: end of the hand-written DeepSHAP computation.
"""
End
"""