In [None]:
# Environment setup: export the installed torch version as the TORCH
# environment variable so the shell commands below can interpolate it into the
# PyG wheel-index URL. This assignment must run before the pip lines.
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

# torch-scatter and torch-sparse are looked up on the wheel index page that
# matches the torch version exported above (-f adds that page as a find-links source).
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
# PyTorch Geometric is installed from the GitHub repository without a tag or
# commit, so the installed revision can differ between runs.
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
# NOTE(review): captum is unpinned, and transformers is not imported by any
# cell visible in this notebook -- confirm both lines are still wanted.
!pip install captum
!pip install transformers==3

In [None]:
# NOTE(review): presumably used to upload local helper files (e.g. the
# gcn_2l_model module imported in the next cell) into the Colab runtime -- confirm.
from google.colab import files
uploaded = files.upload()

In [None]:
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool
import matplotlib.pyplot as plt
import torch
from torch_geometric.datasets import TUDataset
import numpy as np
from termcolor import colored
from torchsummary import summary
from torch.autograd import Variable
from keras import backend as K
from statistics import mean
from sklearn import metrics
from copy import deepcopy
from captum.attr import Saliency
from scipy.spatial.distance import hamming
from itertools import zip_longest
from time import perf_counter
import csv
import torch.nn as nn
import torch_geometric.nn as gnn
from torch import Tensor
from torch_geometric.typing import OptPairTensor, Adj, OptTensor, Size
from typing import Callable, Union, Tuple
from torch_sparse import SparseTensor
from time import perf_counter
import random
import pandas
from torch_geometric.loader import DataLoader
import sklearn
import gcn_2l_model
#import model_train_test_gc
#from model_train_test_gc import Model_Train_Test

In [None]:
# Mount Google Drive at /content/drive. SA_GC.loading_config reads its model
# checkpoint from a path under "/content/drive/My Drive/", so this cell has to
# run before an SA_GC instance is created with a non-zero start_epoch.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
class SA_GC(object):
  """Gradient-saliency node occlusion for a graph-classification GCN.

  On construction the class loads (or freshly builds) a ``GCN_2Layer_Model``,
  assigns every node of every graph in ``graph`` a normalized saliency score,
  and stores copies of the graphs in which the feature rows of all nodes whose
  score lies inside ``importance_range`` are set to zero.

  Attributes:
    GCN_model: model returned by ``load_model``.
    criterion: ``torch.nn.CrossEntropyLoss`` used for every loss computation.
    importance_dict: ``{graph_index: {node_index: bool}}``; ``True`` marks a
      node whose features were zeroed.
    new_graph: list of occluded graph copies, in the order of ``graph``.
  """

  # Folder on the mounted Google Drive that holds the saved checkpoints.
  CHECKPOINT_ROOT = "/content/drive/My Drive/Explainability Methods/"

  # ``with_respect`` codes that target a fixed class index instead of graph.y.
  _FIXED_CLASS_TARGETS = {2: 0, 3: 1}

  def __init__(self, task, method, graph, importance_range, start_epoch, input_dim, hid_dim, output_dim):
    """Load the model and immediately occlude the salient nodes of ``graph``.

    Args:
      task: task name; part of the checkpoint path.
      method: method name; part of the checkpoint path.
      graph: indexable, sized collection of graphs exposing ``x``,
        ``edge_index``, ``batch`` and ``y``.
      importance_range: ``(start, end)`` inclusive score interval; nodes whose
        score falls inside it are occluded.
      start_epoch: checkpoint index to load; ``0`` builds an untrained model.
      input_dim: node feature dimension passed to the model.
      hid_dim: hidden dimension passed to the model.
      output_dim: output dimension passed to the model.
    """
    self.GCN_model = self.load_model(task, method, load_index=start_epoch, input_dim=input_dim, hid_dim=hid_dim, output_dim=output_dim)
    self.criterion = torch.nn.CrossEntropyLoss()
    self.importance_dict = {}
    self.new_graph = self.drop_important_nodes(self.GCN_model, graph, importance_range)

  def load_model(self, task, method, load_index, input_dim, hid_dim, output_dim):
    """Return a checkpointed model when ``load_index != 0``, else a fresh one."""
    if load_index != 0:
      GCN_model, _optimizer, _epoch = self.loading_config(task, method, load_index, input_dim, hid_dim, output_dim)
      return GCN_model
    return gcn_2l_model.GCN_2Layer_Model(model_level='graph', dim_node=input_dim, dim_hidden=hid_dim, dim_output=output_dim)

  def loading_config(self, task, method, load_index, input_dim, hid_dim, output_dim):
    """Build the model and an Adam optimizer and restore both from a checkpoint.

    Returns:
      ``(model, optimizer, epoch)`` where ``epoch`` is the value stored in the
      checkpoint.
    """
    GCN_model = gcn_2l_model.GCN_2Layer_Model(model_level='graph', dim_node=input_dim, dim_hidden=hid_dim, dim_output=output_dim)
    optimizer = torch.optim.Adam(params=GCN_model.parameters(), lr=1e-4)

    run_name = str(method) + " on " + str(task)
    checkpoint_path = self.CHECKPOINT_ROOT + run_name + "/Model/" + run_name + " classifier model_" + str(load_index) + ".pt"
    # map_location lets a checkpoint written on a GPU machine load on a
    # CPU-only runtime; the freshly built model lives on the CPU anyway.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    GCN_model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']

    return GCN_model, optimizer, epoch

  def loss_calculations(self, preds, gtruth):
    """Return ``self.criterion(preds, gtruth)``."""
    return self.criterion(preds, gtruth)

  def remove_nones(self, sample_grads):
    """Replace ``None`` gradients with a scalar zero tensor and copy the rest.

    Args:
      sample_grads: iterable of per-graph gradient sequences as returned by
        ``torch.autograd.grad(..., allow_unused=True)``.

    Returns:
      List of lists; every real gradient is a detached copy with
      ``requires_grad=True``, every ``None`` becomes ``torch.tensor(0)``.
    """
    cleaned = []
    for graph_grads in sample_grads:
      cleaned.append([
        # Equivalent to torch.tensor(grad, requires_grad=True) without the
        # double copy and the copy-construct UserWarning.
        grad.detach().clone().requires_grad_(True) if grad is not None else torch.tensor(0)
        for grad in graph_grads
      ])
    return cleaned

  def compute_grad(self, model, graph, with_respect):
    """Gradient of the cross-entropy loss w.r.t. every parameter of ``model``.

    Args:
      model: callable as ``model(x, edge_index, batch)``.
      graph: object exposing ``x``, ``edge_index``, ``batch`` and ``y``.
      with_respect: ``1`` uses ``graph.y`` as target, ``2`` uses class 0,
        ``3`` uses class 1.

    Returns:
      Tuple with one entry per parameter; entries are ``None`` for parameters
      that did not take part in the forward pass.

    Raises:
      ValueError: if ``with_respect`` is not 1, 2 or 3.
    """
    if with_respect != 1 and with_respect not in self._FIXED_CLASS_TARGETS:
      raise ValueError("with_respect must be 1, 2 or 3, got %r" % (with_respect,))

    prediction = model(graph.x, graph.edge_index, graph.batch)
    if with_respect == 1:
      target = graph.y
    else:
      # One target per prediction row, on the prediction's device, so batched
      # or non-CPU inputs work as well as the single-graph CPU case.
      target = torch.full((prediction.shape[0],), self._FIXED_CLASS_TARGETS[with_respect], dtype=torch.long, device=prediction.device)
    loss = self.loss_calculations(prediction, target)
    # Differentiate against the model that produced the prediction; using a
    # different model's parameters would silently yield all-None gradients.
    return torch.autograd.grad(loss, list(model.parameters()), allow_unused=True)

  def compute_sample_grads(self, model, test_dataset, with_respect):
    """Per-parameter gradients stacked over all graphs of ``test_dataset``.

    Returns:
      List with one tensor per model parameter; the leading dimension indexes
      the graphs. Empty when ``test_dataset`` is empty.
    """
    per_graph = [self.compute_grad(model, graph, with_respect) for graph in test_dataset]
    per_graph = self.remove_nones(per_graph)
    return [torch.stack(shards) for shards in zip(*per_graph)]

  def _squared_first_param_grads(self, model, dataset, with_respect):
    """Squared per-graph gradients of the model's first parameter, as nested lists."""
    stacked = self.compute_sample_grads(model, dataset, with_respect)
    if not stacked:
      # Empty dataset (or parameter-free model): nothing to score.
      return []
    return torch.square(stacked[0]).detach().tolist()

  def compute_square_gradients(self, model, dataset):
    """Squared first-parameter gradients for the three supported targets.

    Returns:
      ``(wrt_graph_label, wrt_class_zero, wrt_class_one)``, each a nested list
      with one entry per graph.
    """
    square_grads_wrt_graph_label = self._squared_first_param_grads(model, dataset, 1)
    square_grads_wrt_class_zero = self._squared_first_param_grads(model, dataset, 2)
    square_grads_wrt_class_one = self._squared_first_param_grads(model, dataset, 3)
    return square_grads_wrt_graph_label, square_grads_wrt_class_zero, square_grads_wrt_class_one

  def saliency(self, dataset, gradients):
    """Normalized saliency score for every node of every graph.

    A node's raw score is the sum, over the entries of ``gradients[i]``, of the
    summed element-wise product between the node's feature vector and that
    entry. Raw scores are divided by ``max(raw scores of the graph) + 1e-16``.

    Args:
      dataset: indexable, sized collection of graphs exposing ``x``.
      gradients: per-graph gradient lists aligned with ``dataset``.

    Returns:
      List (one entry per graph) of lists of normalized node scores.
    """
    Saliency_Nodes = []
    for i in range(len(dataset)):
      graph_grads = gradients[i]
      raw_scores = [
        sum(sum(np.multiply(node, grad_list)) for grad_list in graph_grads)
        for node in dataset[i].x.detach().numpy()
      ]
      if not raw_scores:
        # A graph without nodes has no scores (and no maximum to divide by).
        Saliency_Nodes.append([])
        continue
      # The maximum is constant per graph; compute it once, not once per node.
      denominator = max(raw_scores) + 1e-16
      Saliency_Nodes.append([float(score) / denominator for score in raw_scores])
    return Saliency_Nodes

  def is_salient(self, index, score, importance_range):
    """Return ``True`` when ``score`` lies in the inclusive ``importance_range``.

    ``index`` is accepted for interface compatibility and is not used.
    """
    start, end = importance_range
    return bool(start <= score <= end)

  def drop_important_nodes(self, model, graph, importance_range):
    """Return copies of the graphs with the features of salient nodes zeroed.

    Scores come from the squared first-parameter gradients of the loss w.r.t.
    each graph's own label. ``self.importance_dict[i][j]`` records whether node
    ``j`` of graph ``i`` was zeroed. The input graphs are not modified.
    """
    # Only the gradients w.r.t. the graph label are needed here, so the
    # class-0 / class-1 passes are not computed.
    square_grads_wrt_graph_label = self._squared_first_param_grads(model, graph, 1)
    SA_attribution_scores = self.saliency(graph, square_grads_wrt_graph_label)

    occluded_GNNgraph_list = []
    for i, node_scores in enumerate(SA_attribution_scores):
      sample_graph = deepcopy(graph[i])
      graph_dict = {
        j: self.is_salient(j, node_scores[j], importance_range)
        for j in range(len(sample_graph.x))
      }
      salient_rows = [j for j, flagged in graph_dict.items() if flagged]
      if salient_rows:
        sample_graph.x[salient_rows] = 0
      self.importance_dict[i] = graph_dict
      occluded_GNNgraph_list.append(sample_graph)

    return occluded_GNNgraph_list



# Run the saliency-occlusion pipeline on the MUTAG graph-classification dataset.
dataset = TUDataset(root='data/TUDataset', name='MUTAG')

# start_epoch=200 selects the checkpoint to load; hid_dim must match the
# architecture that checkpoint was saved with. The input and output dimensions
# are read from the dataset instead of peeking into the first graph or
# hard-coding the number of classes.
new_output = SA_GC(
  task="Graph Classification",
  method="SA",
  graph=dataset,
  importance_range=(0.5, 1),
  start_epoch=200,
  input_dim=dataset.num_node_features,
  hid_dim=7,
  output_dim=dataset.num_classes,
)

# Occluded node features of the last graph next to the untouched originals.
print(new_output.new_graph[-1].x, dataset[-1].x)

# {graph index: {node index: True when that node's features were zeroed}}
print(new_output.importance_dict)


In [None]:
# Re-run of the saliency-occlusion pipeline on MUTAG: load the dataset, build
# SA_GC from the epoch-200 "SA" checkpoint, and print the results.
dataset = TUDataset(root='data/TUDataset', name='MUTAG')

# Constructor arguments gathered in one place so they are easy to tweak.
sa_settings = dict(
  task="Graph Classification",
  method="SA",
  graph=dataset,
  importance_range=(0.5, 1),
  start_epoch=200,
  input_dim=len(dataset[0].x[0]),
  hid_dim=7,
  output_dim=2,
)
new_output = SA_GC(**sa_settings)

# Last graph: occluded node features followed by the original node features.
print(new_output.new_graph[-1].x, dataset[-1].x)

# Per-graph map of node index -> whether that node's features were zeroed.
print(new_output.importance_dict)