In [68]:
import os
import pickle

import cv2
import numpy as np
import torch
import torch.nn as nn
from torchvision.models import vgg

In [69]:
def spike_function(input):
    """Heaviside step used as the spiking nonlinearity.

    Returns a tensor with the same shape, dtype and device as ``input`` that is
    1 where ``input > 0`` and 0 elsewhere.
    """
    # Stay on the input's device instead of forcing .cuda(), so this also runs on CPU.
    return (input > 0).to(input.dtype)


class svgg19(vgg.VGG):
    """Spiking VGG-19 built on torchvision's VGG and simulated with leaky integrate-and-fire neurons.

    The conv/linear layers are ordinary VGG-19 layers (no batch norm). In
    ``forward``, every ``nn.ReLU`` is replaced by a leaky integrate-and-fire
    (LIF) update:
    - the preceding layer's output is integrated into a leaky membrane potential;
    - a spike (1.0) is emitted where the potential exceeds that layer's ``threshold``;
    - the threshold is subtracted from every neuron that fired.
    The same input ``x`` is fed at every one of ``num_steps`` time steps.

    Args:
        num_steps: number of simulated time steps per forward pass (must be >= 1).
        leak_mem: factor the membrane potential is multiplied by at each step.
        img_size: nominal input side length. Stored for reference only; membrane
            potentials are sized from the actual activations, so other input
            sizes work too.
        num_cls: number of output classes; sizes the last classifier layer.
    """

    def __init__(self, num_steps=25, leak_mem=0.95, img_size=224, num_cls=50):
        cfg = [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M",
               512, 512, 512, 512, "M", 512, 512, 512, 512, "M"]
        # Pass num_cls through so the last classifier layer matches the
        # num_cls-sized output (VGG's own default is 1000 classes).
        super(svgg19, self).__init__(vgg.make_layers(cfg, False), num_classes=num_cls)
        self.num_steps = num_steps
        self.num_cls = num_cls
        self.batch_num = self.num_steps  # not read within this class; kept as a public attribute
        self.leak_mem = leak_mem
        self.img_size = img_size

        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                # Firing threshold read by the LIF update that follows this layer.
                m.threshold = 1.0
                torch.nn.init.xavier_uniform_(m.weight, gain=2)

    def _lif_step(self, mem, current, threshold):
        """One LIF update: leak, integrate ``current``, fire, soft-reset.

        Returns:
            (spikes, new_mem): ``spikes`` is 1.0 where ``mem / threshold - 1.0 > 0``
            after integration and 0.0 elsewhere. ``new_mem`` is the integrated
            potential with ``threshold`` subtracted wherever a spike occurred.
        """
        mem = self.leak_mem * mem + current
        spikes = ((mem / threshold) - 1.0 > 0).to(mem.dtype)
        return spikes, mem - spikes * threshold

    def _run_spiking(self, layers, inp, mems, spike_sums=None):
        """Propagate ``inp`` through ``layers`` for a single time step.

        Conv2d, MaxPool2d, Linear and Dropout modules are applied as-is. Each
        ReLU is replaced by an LIF update whose threshold is the ``threshold``
        attribute of the module just before it. Other module types are skipped.

        Args:
            layers: an ``nn.Sequential`` (``self.features`` or ``self.classifier``).
            inp: input to the first module.
            mems: dict mapping ReLU index -> membrane potential; updated in place
                and lazily initialized to zeros on first use.
            spike_sums: optional dict mapping ReLU index -> running sum of emitted
                spikes; updated in place when given.

        Returns:
            The output of the last applied module (spikes if it was a ReLU).
        """
        out = inp
        for idx, module in enumerate(layers):
            if isinstance(module, nn.ReLU):
                mem = mems.get(idx)
                if mem is None:
                    # Allocate from the activation itself so device, dtype and
                    # spatial size always match the current input.
                    mem = torch.zeros_like(out)
                out, mems[idx] = self._lif_step(mem, out, layers[idx - 1].threshold)
                if spike_sums is not None:
                    spike_sums[idx] = out if idx not in spike_sums else spike_sums[idx] + out
            elif isinstance(module, (nn.Conv2d, nn.MaxPool2d, nn.Linear, nn.Dropout)):
                out = module(out)
        return out

    def forward(self, x):
        """Simulate the network for ``num_steps`` time steps on the batch ``x``.

        Args:
            x: input batch (batch dimension first), presented unchanged at every step.

        Returns:
            (out_voltage, average_activation_spikemaps):
            - ``out_voltage`` is the classifier output summed over all steps and
              divided by ``num_steps``.
            - ``average_activation_spikemaps`` is a list with one tensor per spiking
              conv layer, in network order. Each tensor holds every neuron's spike
              count divided by ``num_steps`` and has the same shape as that
              layer's output.

        Raises:
            ValueError: if ``num_steps`` is smaller than 1.
        """
        if self.num_steps < 1:
            raise ValueError("num_steps must be at least 1, got " + str(self.num_steps))

        with torch.no_grad():
            batch_size = x.size(0)
            conv_mems, fc_mems, spike_sums = {}, {}, {}
            voltage_sum = None

            for _ in range(self.num_steps):
                features_out = self._run_spiking(self.features, x, conv_mems, spike_sums)
                flat = self.avgpool(features_out).reshape(batch_size, -1)
                step_out = self._run_spiking(self.classifier, flat, fc_mems)
                # The classifier's output at each step is accumulated without leak or reset.
                voltage_sum = step_out if voltage_sum is None else voltage_sum + step_out

            # Running spike sums replace storing every step's spike maps, which
            # kept num_steps full-resolution copies of each layer in memory.
            average_activation_spikemaps = [s / self.num_steps for s in spike_sums.values()]
            out_voltage = voltage_sum / self.num_steps

        return out_voltage, average_activation_spikemaps

In [23]:
# Define the spiking VGG-19 and load the trained weights from svgg19_AwA2.pth.
NUM_CLASSES = 50
CHECKPOINT_PATH = "./drive/MyDrive/Thesis/SNN_grad/svgg19_AwA2.pth"

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model = svgg19(num_steps=25, leak_mem=0.95, img_size=224, num_cls=NUM_CLASSES)
# Ensure the final classifier layer has NUM_CLASSES outputs before loading the state dict.
model.classifier[6] = nn.Linear(4096, NUM_CLASSES)
# map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))

model.to(device)
model.eval()

svgg19(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), p

In [70]:
# Load w: one row of weights per class over the model's conv filters.
# NOTE(review): despite the .npy extension this file is read with pickle.load,
# which can execute arbitrary code -- only load files you trust.
with open('./drive/MyDrive/Thesis/generate_explanation/lambda20_w.npy', 'rb') as file_add:
    w = pickle.load(file_add)

# Decode w: for each class, split the indices of its non-zero entries into one
# list per conv layer, with indices made local to that layer. This assumes
# (TODO confirm) that the columns of w are the filters of all conv layers
# concatenated in network order.
list_of_conv_layers = [module for module in model.modules() if isinstance(module, nn.Conv2d)]

# Half-open [start, end) column range of w owned by each conv layer.
layer_bounds = []
layer_start = 0
for conv_layer in list_of_conv_layers:
    layer_end = layer_start + conv_layer.out_channels
    layer_bounds.append((layer_start, layer_end))
    layer_start = layer_end

w_decoded = []
for class_index in range(len(w)):
    # For a 1-D row, torch.nonzero returns a (k, 1) tensor, so each index below
    # is a 1-element tensor; that format is kept for the cells that consume it.
    non_zero_w_indices = torch.nonzero(torch.from_numpy(w[class_index]))
    w_decoded.append([
        [idx - start for idx in non_zero_w_indices if start <= idx < end]
        for start, end in layer_bounds
    ])

In [71]:
# Load image
from PIL import Image
import torchvision.transforms as transforms

# Replace this with the image from which you want to generate explanations
input_image_path = "./unnamed.jpg"
# The context manager closes the file handle Image.open keeps open lazily;
# convert() returns a separate, fully loaded image.
with Image.open(input_image_path) as opened_image:
    image = opened_image.convert("RGB")

# Resize to 224x224 and normalize with the ImageNet channel mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

input_tensor = transform(image)  # shape: [C, H, W]
input_batch = input_tensor.unsqueeze(0)  # add a batch dimension: [1, C, H, W]

# Put image on correct device
input_batch = input_batch.to(device)

In [72]:
# Run the spiking network once on the prepared batch. Autograd is disabled
# because only the predictions and the averaged spike maps are needed.
with torch.no_grad():
    model_outputs = model(input_batch)
predictions, average_activation_maps = model_outputs

In [73]:
# Select the relevant filters for the predicted class of the first image in the
# batch. Taking an explicit int avoids indexing a list with a (batch,)-shaped
# tensor, which only works when the batch has exactly one image.
predicted_class = torch.argmax(predictions, dim=1)[0].item()
relevant_filters = w_decoded[predicted_class]
print(relevant_filters)

[[tensor([42])], [tensor([1])], [], [], [], [], [], [], [], [], [], [], [tensor([469])], [tensor([327])], [tensor([241])], [tensor([21]), tensor([123]), tensor([153]), tensor([434]), tensor([481]), tensor([500])]]


In [74]:

### Helper functions for generating the heatmap superimposed on the input image ###

def write_image(write_add, A):
    """Save an array with values in [0, 1] to ``write_add`` as an 8-bit image.

    ``A`` is clipped to [0, 1] (so out-of-range values cannot wrap around),
    scaled by 255 and truncated to uint8. Channel order is passed to OpenCV
    unchanged; cv2.imwrite treats 3-channel data as BGR.

    Raises:
        IOError: if OpenCV reports that the file could not be written
            (e.g. the target directory does not exist).
    """
    written = cv2.imwrite(write_add, np.uint8(np.clip(A, 0.0, 1.0) * 255))
    # cv2.imwrite signals failure by returning False instead of raising.
    if not written:
        raise IOError("Could not write image to " + str(write_add))

def superimposing(image, heatmap):
    """Blend a colored heatmap (colored feature map) over the input image.

    The blend is 0.6 * heatmap + 0.4 * image (via cv2.addWeighted). It is then
    divided by its maximum value so the brightest pixel becomes 1.
    """
    blended = cv2.addWeighted(heatmap, 0.6, image, 0.4, 0)
    peak = np.max(blended)
    return blended / peak


def apply_colormap(A):
    """Colorize a (resized) feature map with values in [0, 1] using OpenCV's JET colormap.

    Returns a float32 3-channel image in OpenCV's BGR channel order, scaled to [0, 1].
    """
    as_bytes = np.uint8(A * 255)
    colored = cv2.applyColorMap(as_bytes, cv2.COLORMAP_JET)
    return colored.astype(np.float32) / 255

def normalize_numpy(A):
    """Min-max normalize a map ``A`` to the range [0, 1].

    A constant array (including all zeros) has no range to scale by. In that
    case an all-zero array of the same shape and dtype is returned instead of
    dividing by zero, which would produce NaN.
    """
    a_min = np.min(A)
    a_max = np.max(A)
    if a_max == a_min:
        return np.zeros_like(A)
    return (A - a_min) / (a_max - a_min)

def read_image(image_address, shape):
    """Read an image from disk, resize it to ``shape`` x ``shape`` and scale it to [0, 1].

    Args:
        image_address: '/'-separated path to the image file.
        shape: side length in pixels of the square output.

    Returns:
        (image, image_name):
        - ``image`` is a contiguous float32 array in RGB channel order, min-max
          normalized to [0, 1].
        - ``image_name`` is the last '/'-separated path component.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
    """
    image_name = image_address.split('/')[-1]
    image = cv2.imread(image_address)
    # cv2.imread signals failure by returning None rather than raising.
    if image is None:
        raise FileNotFoundError("Could not read image: " + str(image_address))

    image = cv2.resize(image, dsize=(shape, shape))
    image = normalize_numpy(np.float32(image))
    # OpenCV loads BGR; reverse the channels to RGB. Make the reversed view
    # contiguous so it can be passed directly to further OpenCV calls.
    image = np.ascontiguousarray(image[:, :, ::-1])
    return image, image_name

In [85]:
# Generate explanatory heatmaps. For every conv layer that has relevant filters:
# - sum those filters' average spike maps for the first image;
# - upsample the sum to the input size and colorize it;
# - overlay it on the input image and save it as <layer_idx>.jpg.
HEATMAP_SIZE = 224
OUTPUT_DIR = './drive/MyDrive/Thesis/generate_explanation/'

# The background image is the same for every layer, so read it once.
# read_image returns RGB; swap back to BGR to match OpenCV's colormap output and imwrite.
base_image, image_name = read_image(input_image_path, HEATMAP_SIZE)
base_image = cv2.cvtColor(base_image, cv2.COLOR_RGB2BGR)

os.makedirs(OUTPUT_DIR, exist_ok=True)
for layer_idx, relevant_filters_per_layer in enumerate(relevant_filters):
    if len(relevant_filters_per_layer) == 0:
        continue

    # Per-filter average spike maps of the first image in the batch.
    layer_maps = average_activation_maps[layer_idx][0]
    # int() accepts both 1-element tensors and plain ints as filter indices.
    filter_indices = [int(f) for f in relevant_filters_per_layer]
    summed_activation_map = layer_maps[filter_indices].sum(dim=0)

    # Build and save the overlay once per layer, after all relevant filters are summed.
    upsampled_map = cv2.resize(summed_activation_map.cpu().detach().numpy(),
                               dsize=(HEATMAP_SIZE, HEATMAP_SIZE))
    upsampled_map = normalize_numpy(upsampled_map)
    heatmap = apply_colormap(upsampled_map)
    vis = superimposing(base_image, heatmap)
    write_image(os.path.join(OUTPUT_DIR, str(layer_idx) + '.jpg'), vis)