In [37]:
from GNNModels.Models import *

import torch
from tqdm import tqdm
import numpy as np

# Importing the Model

In [2]:
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

dataset_name = "Cora"

dataset = Planetoid(root='/tmp/Planetoid', name=dataset_name, transform=NormalizeFeatures())
data = dataset[0]  # Get the first graph object.

In [23]:
# This is temporary model training, will be replaced with improting pretrained model, having problems with it currently

from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        torch.manual_seed(1)
        self.conv1 = GCNConv(dataset.num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = x.relu()
        #x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x

model = GCN(hidden_channels=16)

model = GCN(hidden_channels=32)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()

def train():
      model.train()
      optimizer.zero_grad()  # Clear gradients.
      out = model(data.x, data.edge_index)  # Perform a single forward pass.
      loss = criterion(out[data.train_mask], data.y[data.train_mask])  # Compute the loss solely based on the training nodes.
      val_loss = criterion(out[data.val_mask], data.y[data.val_mask])
      loss.backward()  # Derive gradients.
      optimizer.step()  # Update parameters based on gradients.
      return loss, val_loss

def test():
      model.eval()
      out = model(data.x, data.edge_index)
      pred = out.argmax(dim=1)  # Use the class with highest probability.
      test_correct = pred[data.test_mask] == data.y[data.test_mask]  # Check against ground-truth labels.
      test_acc = int(test_correct.sum()) / int(data.test_mask.sum())  # Derive ratio of correct predictions.
      return test_acc, out


for epoch in range(200):
    loss, val_loss = train()
    if epoch%200 == 0:
          print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Validation loss: {val_loss:.4f}')

test_acc, out = test()
print(f'Test Accuracy: {test_acc:.4f}')

Epoch: 000, Loss: 1.9467, Validation loss: 1.9455
Test Accuracy: 0.8080


# Explanation Methods

In [16]:
# %cd GraphXAI-main
# !pip install -e .

from graphxai.explainers import GNNExplainer, PGExplainer, IntegratedGradExplainer, PGMExplainer

# the ones below we want to use from different libraries
from graphxai.explainers import GNN_LRP, CAM

# need to also use subgraph x from DIG


In [29]:
# GNN Explainer - discrete mask of node imp, soft mask of edge imp

gnnexp = GNNExplainer(model)

def gnn_imp_nodes(node_idx):

    node_exp = gnnexp.get_explanation_node(node_idx = node_idx, x = data.x, edge_index = data.edge_index)

    imp_nodes = []

    for k in node_exp.node_reference.keys():

        if node_exp.node_imp[node_exp.node_reference[k]].item() == 1:

            imp_nodes.append(k)

    return imp_nodes

# PGE Explainer - discrete maks of node imp, discrete mask of edge imp

# needs name of emb layer of the model
pgex = PGExplainer(model, emb_layer_name = 'conv2',  max_epochs = 100, lr = 0.1)
pgex.train_explanation_model(data)

def pge_imp_nodes(node_idx):

    node_exp = pgex.get_explanation_node(node_idx = node_idx, x = data.x, edge_index = data.edge_index)

    imp_nodes = []

    for k in node_exp.node_reference.keys():

        if node_exp.node_imp[node_exp.node_reference[k]].item() == 1:

            imp_nodes.append(k)

    return imp_nodes

140it [00:00, 653.05it/s]
140it [00:00, 330.41it/s]
140it [00:00, 495.24it/s]
140it [00:00, 517.41it/s]
140it [00:00, 425.18it/s]
140it [00:00, 490.66it/s]
140it [00:00, 379.95it/s]
140it [00:00, 464.86it/s]
140it [00:00, 454.31it/s]
140it [00:00, 470.79it/s]
140it [00:00, 476.03it/s]
140it [00:00, 303.48it/s]
140it [00:00, 460.86it/s]
140it [00:00, 508.38it/s]
140it [00:00, 370.67it/s]
140it [00:00, 412.32it/s]
140it [00:00, 504.03it/s]
140it [00:00, 391.87it/s]
140it [00:00, 438.05it/s]
140it [00:00, 497.66it/s]
140it [00:00, 549.34it/s]
140it [00:00, 457.54it/s]
140it [00:00, 481.60it/s]
140it [00:00, 483.72it/s]
140it [00:00, 271.37it/s]
140it [00:00, 412.29it/s]
140it [00:00, 539.95it/s]
140it [00:00, 563.08it/s]
140it [00:00, 531.17it/s]
140it [00:00, 555.29it/s]
140it [00:00, 582.56it/s]
140it [00:00, 603.93it/s]
140it [00:00, 557.84it/s]
140it [00:00, 353.83it/s]
140it [00:00, 563.04it/s]
140it [00:00, 529.90it/s]
140it [00:00, 588.11it/s]
140it [00:00, 506.21it/s]
140it [00:00

training time is 33.788s





In [40]:
# Integrated gradients - soft mask of edge imp

igex = IntegratedGradExplainer(model, criterion=criterion)

def ig_imp_nodes(node_idx):

    node_exp = igex.get_explanation_node(node_idx = node_idx, x = data.x, edge_index = data.edge_index, y = data.y)

    imp_nodes = []

    mask = torch.sigmoid(node_exp.node_imp) >= 0.5

    for k in node_exp.node_reference.keys():

        if mask[node_exp.node_reference[k]].item() == 1:
        
            imp_nodes.append(k)

    return imp_nodes

# PGME Explainer - discrete mask of node imp, randomised, can get ranking as well by asking for top 1 then 2 and so on

pgm = PGMExplainer(model, explain_graph=False)

def pgm_imp_nodes(node_idx, top = None):

    np.random.seed(1998)

    if top == None:

        node_exp = pgm.get_explanation_node(node_idx = node_idx, x = data.x, edge_index = data.edge_index)

    else: 

        node_exp = pgm.get_explanation_node(node_idx = node_idx, x = data.x, edge_index = data.edge_index, top_k_nodes=top)


    imp_nodes = []

    mask = torch.sigmoid(node_exp.node_imp) >= 0.5

    for k in node_exp.node_reference.keys():

        if mask[node_exp.node_reference[k]].item() == 1:
        
            imp_nodes.append(k)

    return imp_nodes

# Metrics Calculation

In [4]:
# calculates jacard similarity of 2 lists

def jaccard(list1, list2):
    intersection = len(list(set(list1).intersection(list2)))
    union = (len(list1) + len(list2)) - intersection
    return float(intersection) / union

In [24]:
out = model(data.x, data.edge_index)

node_idx = 1

print("Prediction : {}".format(out[node_idx].argmax()))
print("True Class : {}".format(data.y[node_idx]))

Prediction : 4
True Class : 4


In [32]:
ig_imp_nodes(node_idx)

[1986]

In [33]:
gnn_imp_nodes(node_idx)

[1, 2, 332, 654, 1454, 1986]

In [35]:
pge_imp_nodes(node_idx)

[]

In [41]:
pgm_imp_nodes(node_idx)

  0%|          | 0/1000000 [00:00<?, ?it/s]

[1, 2, 332, 470, 652, 654, 1454, 1666, 1986]