# Imports

In [None]:
import os
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor, Grayscale
from load_data import CustomMNISTDataset, get_MalariaCellImagesDataset
from utils import plot_map, test, get_ci, infer_data, get_ci
from tqdm import tqdm
from models import SOMNetwork
from matplotlib import ticker
import matplotlib.pyplot as plt
# Environment tweak applied before any model code runs.
# NOTE(review): presumably silences the duplicate-OpenMP-runtime abort — confirm it is still needed.
os.environ.update({"KMP_DUPLICATE_LIB_OK": "TRUE"})

# Everything in this notebook runs on the CPU; data and checkpoints are
# resolved relative to the directory the kernel was started in.
device = 'cpu'
root = os.getcwd()

# Load Data

In [None]:
def _load_mnist(train, **extra):
    """Build a CustomMNISTDataset from the shared MNIST folder with a ToTensor transform."""
    return CustomMNISTDataset(
        root=f'{root}/data/MNIST',
        train=train,
        transform=ToTensor(),
        **extra,
    )


# Train split, test split, and the test split with `augmentation=True`.
training_data = _load_mnist(True)
testing_data = _load_mnist(False)
test_aug_data = _load_mnist(False, augmentation=True)

# training_data = datasets.FashionMNIST(
#     root=f"{root}/data",
#     train=True,
#     download=True,
#     transform=ToTensor()
# )

# testing_data = datasets.FashionMNIST(
#     root=f"{root}/data",
#     train=False,
#     download=True,
#     transform=ToTensor()
# )
# One batch size is shared by every loader below.
# NOTE(review): presumably chosen to be at least the dataset size so a single
# batch holds a whole split — confirm against the datasets.
batch_size = 60000

# The MNIST loaders get their own names. They used to be assigned to
# `train_dataloader` / `test_dataloader` and were then immediately overwritten
# by the malaria loaders, which made it look as if MNIST fed the cells below.
mnist_train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
mnist_test_dataloader = DataLoader(testing_data, batch_size=batch_size, shuffle=True)

# Malaria cell images: these are the loaders the rest of the notebook uses.
train_dataloader, valid_dataloader, test_dataloader = get_MalariaCellImagesDataset(
    root=f"{root}/data/cell_images/",
    resize=[28, 28],
    valid_size=0.0,
    test_size=0.2,
    batch_size=batch_size,
    shuffle=True,
)
# test_aug_dataloader = DataLoader(test_aug_data, batch_size=batch_size, shuffle=True)

# First training batch; X and y are reused by every analysis cell that follows.
X, y = next(iter(train_dataloader))

# Load Model

In [None]:
loss_fn = nn.CrossEntropyLoss()

# Network hyper-parameters.
# NOTE(review): presumably these have to match the checkpoint loaded below — confirm.
stride = 4
in_channels = X.shape[1]  # channel count taken from the loaded batch
out_channels = 2

# Alternative configuration kept for reference:
# filters = [(1, 6), (3, 1), (2, 1), (1, 1)]
# kernels = [(10, 10), (15, 15), (25, 25), (35, 35)]
# model_dict = torch.load(f'{root}/result_pth/our_300_97.19_mnist_4layer_3stride_18412111_0702.pth', map_location=device)

filters = [(2, 2), (1, 3), (3, 1), (1, 1)]
kernels = [(15, 15), (25, 25), (30, 30), (45, 45)]
checkpoint_path = f'{root}/result_pth/our_2000_85.44728724369442_malaria_4layer_4stride_22133111_15253045.pth'
model_dict = torch.load(checkpoint_path, map_location=device)

# Build the network and restore the saved weights.
model = SOMNetwork(
    stride=stride,
    in_channels=in_channels,
    out_channels=out_channels,
    kernels=kernels,
    filters=filters,
).to(device)
model.load_state_dict(model_dict)

# Report the metrics returned by `test` on both splits.
train_metrics = test(train_dataloader, model, loss_fn, device=device)
print("Train: \n\tAccuracy: {}, Avg loss: {} \n".format(*train_metrics))
test_metrics = test(test_dataloader, model, loss_fn, device=device)
print("Test: \n\tAccuracy: {}, Avg loss: {} \n".format(*test_metrics))

In [None]:
# Show the loaded network; as the cell's last expression its repr is rendered by the notebook.
model

# Load Params

In [None]:
# FMs maps a layer number (1-4) to the weights of that layer's first module.
FMs = {}
first_module = model.layer1[0]
if model.in_channels == 1:
    FMs[1] = first_module.weight
elif model.in_channels == 3:
    # Layer 1 keeps separate gray and RGB weights. The RGB weights are given two
    # trailing singleton axes and then repeated along each of them so that they
    # span the kernel window.
    rgb_weight = first_module.rgb_weight
    rgb_tiled = rgb_weight.reshape(rgb_weight.shape[0], rgb_weight.shape[1], 1, 1)
    rgb_tiled = torch.repeat_interleave(rgb_tiled, first_module.kernel_size[0], dim=2)
    rgb_tiled = torch.repeat_interleave(rgb_tiled, first_module.kernel_size[1], dim=3)
    FMs[1] = {
        'gray': first_module.gray_weight,
        'rgb': rgb_tiled,
    }
for depth in (2, 3, 4):
    FMs[depth] = getattr(model, f"layer{depth}")[0].weight

In [None]:
def _stack_layers(*layers):
    """Concatenate the sub-modules of several layers into one nn.Sequential."""
    return torch.nn.Sequential(*[module for layer in layers for module in layer])


# RMs maps a layer number (1-4) to that layer's response maps for the batch X,
# reshaped to (batch, -1, out_channels of the layer's first module).
# Layers 1-3 are read before the layer's last sub-module ([0:-1]); layer 4 is
# read after the full layer.
#
# These are pure inference passes over the whole batch, so they run under
# no_grad: without it every response map keeps an autograd graph alive.
RMs = {}
n_samples = X.shape[0]
with torch.no_grad():
    RMs[1] = model.layer1[0:-1](X).permute(0, 2, 3, 1).reshape(
        n_samples, -1, model.layer1[0].out_channels)
    RMs[2] = _stack_layers(model.layer1, model.layer2)[0:-1](X).squeeze().reshape(
        n_samples, -1, model.layer2[0].out_channels)
    RMs[3] = _stack_layers(model.layer1, model.layer2, model.layer3)[0:-1](X).squeeze().reshape(
        n_samples, -1, model.layer3[0].out_channels)
    RMs[4] = _stack_layers(model.layer1, model.layer2, model.layer3, model.layer4)(X).squeeze().reshape(
        n_samples, -1, model.layer4[0].out_channels)

In [None]:
def _chain_layers(*layers):
    """Concatenate the sub-modules of several layers into one nn.Sequential."""
    return torch.nn.Sequential(*[module for layer in layers for module in layer])


# Element-wise products of the `filter` attribute of sub-module [2] of the preceding layers.
filter_l1 = model.layer1[2].filter
filter_l12 = tuple(np.multiply(filter_l1, model.layer2[2].filter))
filter_l123 = tuple(np.multiply(np.multiply(filter_l1, model.layer2[2].filter), model.layer3[2].filter))

# CIs maps a layer number (1-4) to the result of `get_ci` for that layer.
# The slice end (4, 7, 10) cuts each stack right after index 3, 6 and 9.
# NOTE(review): this presumably relies on each layer holding three sub-modules,
# so the slice ends on the next layer's first module — confirm.
CIs = {}
CIs[1] = get_ci(
    X,
    layer=model.layer1[0],
    n_filters=model.layer1[0].weight.shape[0],
)
CIs[2] = get_ci(
    X,
    layer=_chain_layers(model.layer1, model.layer2)[0:4],
    sfm_filter=filter_l1,
    n_filters=model.layer2[0].weight.shape[0],
)
CIs[3] = get_ci(
    X,
    layer=_chain_layers(model.layer1, model.layer2, model.layer3)[0:7],
    sfm_filter=filter_l12,
    n_filters=model.layer3[0].weight.shape[0],
)
CIs[4] = get_ci(
    X,
    layer=_chain_layers(model.layer1, model.layer2, model.layer3, model.layer4)[0:10],
    sfm_filter=filter_l123,
    n_filters=model.layer4[0].weight.shape[0],
)

# Plot FM

In [None]:
# Layer 1 is only plotted for single-channel input; layers 2-4 are always plotted.
fm_depths = ([1] if in_channels == 1 else []) + [2, 3, 4]
for depth in fm_depths:
    first_module = getattr(model, f"layer{depth}")[0]
    # The filters are laid out on a square grid whose side is the integer
    # square root of the filter count.
    side = int(len(first_module.weight) ** 0.5)
    fm_grid = FMs[depth].permute(0, 2, 3, 1).reshape(side, side, *first_module.kernel_size, 1)
    plot_map(fm_grid.detach().cpu().numpy())

# Plot CI

In [None]:
# Height/width of one CI tile per layer: the layer-1 kernel size, scaled from
# layer 2 onwards by the element-wise product of the preceding layers' `filter` values.
base_kernel = model.layer1[0].kernel_size
ci_patch_sizes = {
    1: base_kernel,
    2: np.array(model.layer1[2].filter) * base_kernel,
    3: np.multiply(model.layer1[2].filter, model.layer2[2].filter) * base_kernel,
    4: np.multiply(np.multiply(model.layer1[2].filter, model.layer2[2].filter), model.layer3[2].filter) * base_kernel,
}

for depth, patch_size in ci_patch_sizes.items():
    # Square grid of filters, as in the FM plots.
    side = int(len(getattr(model, f"layer{depth}")[0].weight) ** 0.5)
    ci_grid = CIs[depth].reshape(side, side, in_channels, *patch_size)
    # Move the channel axis last before plotting.
    plot_map(ci_grid.permute(0, 1, 3, 4, 2))

# Explainability

### Plot the explanation process for a single image (MNIST and Fashion MNIST)

In [None]:
# For every class, walk through its samples and explain the first one whose
# prediction differs from the label.
for target in range(out_channels):
    candidate_rows = np.where(y == target)[0]
    # NOTE(review): `save` is presumably an output name for infer_data; the
    # second assignment disables it — confirm.
    save = f"mnist-16-{target}"
    save = None

    remaining = 1
    for row in candidate_rows:
        if remaining == 0:
            break
        input_x = X[row].unsqueeze(0)  # add the batch axis
        pred = model(input_x)
        if pred.argmax() == target:
            # Prediction matches the label: skip it.
            continue
        remaining -= 1
        np.set_printoptions(formatter={'float': '{: 0.3f}'.format})
        print(f"target: {target}, pred: {pred.argmax()}\nProb.: {pred.detach().numpy()}")
        infer_data(input_x, model, [CIs[depth].squeeze() for depth in (1, 2, 3, 4)], save=save)

### Plot the explanation process for a single image (Malaria)

In [None]:
target = 1 # 0: infected, 1: uninfected
count = 2  # number of samples to explain
np.set_printoptions(formatter={'float': '{: 0.3f}'.format})

# Positions of the target-class samples in the batch. Indexing X and the RMs by
# position avoids copying the whole class subset for every plot, which
# `X[y==target][idx]` did.
target_rows = torch.nonzero(y == target).flatten()

# Look at no more than the first 100 samples of the class, and never past its end.
for idx in range(min(100, len(target_rows))):
    if count == 0:
        break
    row = int(target_rows[idx])
    pred_p = model(X[row:row + 1])
    pred = pred_p.argmax()
    if pred == target: # correctly classified samples are skipped; only misclassified ones are shown
        continue
    # `pred` is already the arg-max class, so print it directly; calling
    # argmax() on it again always gave 0.
    print(f"target: {target}, pred: {pred.item()}\nProb.: {pred_p.detach().numpy()}")
    count -= 1

    plt.imshow(X[row].permute(1, 2, 0))
    plt.axis('off')
    plt.show()

    rm1, rm2, rm3, rm4 = (RMs[depth][row] for depth in (1, 2, 3, 4))

    # NOTE(review): the grid shapes, the 206 column split and the patch sizes
    # below are hard-coded; presumably they are specific to the malaria
    # checkpoint loaded earlier — confirm before using another model.
    plot_map(rm1.reshape(6, 6, 15, 15, 1).detach().numpy())
    plot_map(FMs[1]['gray'][torch.topk(rm1[:, 0:206], k=1, dim=1)[1]].reshape(6, 6, 5, 5, 1).detach().numpy())
    plot_map(FMs[1]['rgb'][torch.topk(rm1[:, 206:], k=1, dim=1)[1]].reshape(6, 6, 3, 5, 5).permute(0, 1, 3, 4, 2).detach().numpy())
    plot_map(rm2.reshape(3, 3, 25, 25, 1).detach().numpy())
    plot_map(CIs[2][torch.topk(rm2, k=1, dim=1)[1]].reshape(*model.layer2[2].shape, X.shape[1], 10, 10).permute(0, 1, 3, 4, 2))
    plot_map(rm3.reshape(3, 1, 30, 30, 1).detach().numpy())
    plot_map(CIs[3][torch.topk(rm3, k=1, dim=1)[1]].reshape(*model.layer3[2].shape, X.shape[1], 10, 30).permute(0, 1, 3, 4, 2))
    plt.imshow(rm4.reshape(45, 45, 1).detach().numpy())
    plt.axis('off')
    plt.show()
    # Average the CIs of the 20 strongest layer-4 responses.
    plot_map(CIs[4][torch.topk(rm4, k=20, dim=1)[1]].mean(1).reshape(1, 1, X.shape[1], 30, 30).permute(0, 1, 3, 4, 2))

# Explanation method for the fully connected layer

In [None]:
target_neuron = 1
k = 10
rm_idx = 0
n_FM4 = kernels[3][0]*kernels[3][1]
# Layer-4 response maps of the samples whose label equals the target neuron.
target_RM = RMs[4][y==target_neuron]

# Three heat maps on the layer-4 grid: one response map, the FC weights of the
# target neuron, and their element-wise product.
neuron_weights = model.fc1.weight[target_neuron]
shown_rm = target_RM[rm_idx]
for heatmap in (shown_rm, neuron_weights, shown_rm.reshape(-1, n_FM4) * neuron_weights):
    plt.imshow(heatmap.reshape(*kernels[3]).detach().numpy())
    plt.show()

# Count, over all target samples, how often each layer-4 position appears among
# the k largest weighted responses (wx) and the k largest raw responses (x).
wx_topk_list = {}
x_topk_list = {}
for i in tqdm(range(len(target_RM))):
    flat_rm = target_RM[i].reshape(-1, n_FM4)
    top_positions = (
        (wx_topk_list, torch.topk(flat_rm * model.fc1.weight[target_neuron], k)[1][0]),
        (x_topk_list, torch.topk(flat_rm, k)[1][0]),
    )
    for tally, positions in top_positions:
        for position in positions.tolist():
            tally[position] = tally.get(position, 0) + 1

# Re-order both tallies by ascending count (ties keep their insertion order).
wx_topk_list = {position: wx_topk_list[position] for position in sorted(wx_topk_list, key=wx_topk_list.get)}
x_topk_list = {position: x_topk_list[position] for position in sorted(x_topk_list, key=x_topk_list.get)}

In [None]:
# Position -> count for the top-k weighted responses (RM4 * FC weight), in ascending count order.
wx_topk_list

In [None]:
# Position -> count for the top-k raw RM4 responses, in ascending count order.
x_topk_list

In [None]:
# Top-k values and indices for the first target-class sample: its layer-4
# response map, the FC weights of the target neuron, and their product.
first_target_rm = RMs[4].reshape(-1, 1, 1, *kernels[3])[y==target_neuron][0][0][0]
target_weights = model.fc1.weight[target_neuron]

rm4_topk = torch.topk(first_target_rm.reshape(n_FM4), k)
fc_topk = torch.topk(target_weights, k)
product_topk = torch.topk((first_target_rm.reshape(-1, n_FM4) * target_weights), k)

print(f"Values and indices from top k of RM4:\n{rm4_topk}\n")
print(f"Values and indices from top k of FC layer:\n{fc_topk}\n")
print(f"Values and indices from top k of RM4*weight:\n{product_topk}\n")

# RBF

In [None]:
def traingle(x, m, w):
    """Triangular basis function: 1 - min(d, w) / w, with d = torch.dist(x, m).

    Args:
        x: input tensor.
        m: centre tensor.
        w: width tensor; distances of w or more map to 0.

    Returns:
        Tensor shaped like the distance, equal to 1 at the centre.
    """
    distance = torch.dist(x, m)
    clipped = torch.minimum(distance, w)  # cap the distance at the width
    return torch.ones_like(distance) - clipped / w

def trapezoidal(x, m, w, b):
    """Triangular basis function with every value at or below `b` set to zero.

    Args:
        x: input tensor.
        m: centre tensor.
        w: width tensor of the underlying triangle.
        b: cut-off tensor; triangle values <= b become 0.

    Returns:
        The triangle value where it exceeds `b`, otherwise 0.
    """
    distance = torch.dist(x, m)
    triangle = torch.ones_like(distance) - torch.minimum(distance, w) / w
    # Zero out the low part of the triangle, keep the rest unchanged.
    return torch.where(triangle <= b, torch.zeros_like(triangle), triangle)

def gaussian(x, m, w):
    """Gaussian basis function: exp(-d^2 / (2 * w^2)), with d = torch.dist(x, m).

    Args:
        x: input tensor.
        m: centre tensor.
        w: width tensor (the sigma of the bell curve).

    Returns:
        Tensor shaped like the distance, equal to 1 at the centre.
    """
    squared_distance = torch.dist(x, m) ** 2
    denominator = 2 * w ** 2
    return torch.exp(-squared_distance / denominator)

def cReLU(x, b):
    """Thresholded ReLU: keep `x` where x >= b, output zero elsewhere."""
    passes_threshold = (x >= b).float()
    return x * passes_threshold
def ReLU(x):
    """ReLU written as a mask: keep `x` where x >= 0, output zero elsewhere."""
    is_non_negative = (x >= 0).float()
    return x * is_non_negative

    
def _sample_curve(fn, points, *params):
    """Evaluate `fn(point, *params)` for every point and return the values as a list."""
    return [fn(point, *params) for point in points]


# Centre, width and cut-off shared by the first three curves.
m = torch.tensor(0.0)
w = torch.tensor(5)
b = torch.tensor(0.4)

# 1000 points spanning two widths on either side of the centre.
x = torch.linspace(m - 2 * w, m + 2 * w, 1000)
y0 = _sample_curve(gaussian, x, m, w)
y1 = _sample_curve(traingle, x, m, w)
y2 = _sample_curve(trapezoidal, x, m, w, b)

# Gaussians with different centres and widths for the comparison figure.
y3 = _sample_curve(gaussian, x, torch.tensor(0.0), torch.tensor(0.447))
y4 = _sample_curve(gaussian, x, torch.tensor(0.0), torch.tensor(1.0))
y5 = _sample_curve(gaussian, x, torch.tensor(0.0), torch.tensor(2.23))
y6 = _sample_curve(gaussian, x, torch.tensor(-2.0), torch.tensor(0.707))

# cReLU is sampled separately from 0.4 upwards and up to just below 0.4.
x1 = torch.linspace(0.4, m + 2 * w, 1000)
cr1 = _sample_curve(cReLU, x1, b)
x2 = torch.linspace(m - 2 * w, 0.4-(1e-7), 1000)
cr2 = _sample_curve(cReLU, x2, b)
relu = _sample_curve(ReLU, x)

# One figure per basis function: Gaussian, triangular, thresholded triangular.
basis_figures = [
    (y0, f"m = {m.item()}, σ = {w.item()}"),
    (y1, f"m = {m.item()}, w = {w.item()}"),
    (y2, f"m = {m.item()}, w = {w.item()}, b = {0.4}"),
]
for curve, figure_title in basis_figures:
    plt.plot(x, curve)
    plt.title(figure_title)
    plt.xlabel('x')
    plt.ylabel('ϕ(x)')
    plt.show()

# Four Gaussians on shared axes; legend labels show rounded widths.
gaussian_curves = [
    (y3, {'label': 'm =  0, σ = 0.5'}),
    (y4, {'label': 'm =  0, σ = 1.0', 'color': "red"}),
    (y5, {'label': 'm =  0, σ = 2.2'}),
    (y6, {'label': 'm = -2, σ = 0.7'}),
]
for curve, line_style in gaussian_curves:
    plt.plot(x, curve, **line_style)


def _mark_on_unit_gaussian(px, offset):
    """Plot and annotate the point of the m = 0, σ = 1 curve at x = px; return it."""
    point = (px, gaussian(torch.tensor(float(px)), torch.tensor(0.0), torch.tensor(1.0)))
    plt.plot(*point, 'go')
    # The annotation text is the point with its height rounded to one decimal,
    # nudged by `offset` in both directions.
    plt.annotate(f"{point[0], round(point[1].item(), 1)}", (point[0] + offset, round(point[1].item(), 1) + offset))
    return point


p1 = _mark_on_unit_gaussian(-1, -0.1)
p2 = _mark_on_unit_gaussian(2, 0.1)

plt.xlim([-5, 5])
plt.ylim([0, 1])
plt.legend()
plt.xlabel('x')
plt.ylabel('ϕ(x)')
plt.show()

# Two figures with identical axes: the thresholded ReLU (drawn as two segments
# that meet at the threshold) and the plain ReLU.
activation_figures = [
    ([(x1, cr1), (x2, cr2)], f"c = {0.4}", 'c ReLU (x)'),
    ([(x, relu)], None, 'ReLU (x)'),
]
for segments, figure_title, y_label in activation_figures:
    plt.plot([0, 0], [-10, 10], color="black")  # vertical line at x = 0
    for xs, ys in segments:
        plt.plot(xs, ys, color="blue")
    if figure_title is not None:
        plt.title(figure_title)
    plt.xlim([-1, 1])
    plt.ylim([-0.1, 1])
    ax = plt.gca()

    # Major x ticks every 0.2.
    ax.xaxis.set_major_locator(ticker.MultipleLocator(0.2))
    plt.xlabel('x')
    plt.ylabel(y_label)
    plt.show()


# Linear decision boundary demo: points are labelled by sign(w . x + b).
# Note: `w`, `b` and `x` are re-bound here and no longer hold the tensors used
# for the basis-function plots earlier in this cell.
w = np.array([1, -2])
b = 0

num_points = 100
# A seeded local generator makes the scatter identical on every run; the points
# were previously drawn from the unseeded global NumPy RNG.
rng = np.random.default_rng(0)
x = np.vstack([
    rng.uniform(-5, 5, (num_points, 2)),
    [[10, 2], [1, -2]],  # two hand-picked points that get annotated below
])

labels = np.sign(np.dot(x, w) + b)


def _boundary_y(x_coord):
    """y coordinate of the line w[0]*x + w[1]*y + b = 0 at the given x."""
    return (-w[0] * x_coord - b) / w[1]


plt.scatter(x[:, 0], x[:, 1], c=labels)
plt.plot([-5, 10], [_boundary_y(-5), _boundary_y(10)], 'r')
plt.annotate('(1, -2)', xy=(1, -2), xytext=(3, -1), arrowprops=dict(arrowstyle='->'))
plt.annotate('(10, 2)', xy=(10, 2), xytext=(9, 0), arrowprops=dict(arrowstyle='->'))
plt.xlabel('x')
plt.ylabel('y')

# Point on the boundary that the 'Decision Boundary' arrow points at.
hyperplane_x = 0
hyperplane_y = _boundary_y(hyperplane_x)

plt.annotate('Decision Boundary', xy=(hyperplane_x, hyperplane_y), xytext=(1, 5), arrowprops=dict(arrowstyle='->'))

plt.show()