In [1]:
from GNNModels.Models import *

import torch
from tqdm import tqdm
import numpy as np

import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


# Importing the Model

In [2]:
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

# Name of the Planetoid dataset used throughout the notebook.
dataset_name = "Cora"

# Build the dataset under /tmp/Planetoid with the NormalizeFeatures
# transform attached.
dataset = Planetoid(
    root='/tmp/Planetoid',
    name=dataset_name,
    transform=NormalizeFeatures(),
)

# Take the first graph object out of the dataset.
data = dataset[0]

In [3]:
# # This is temporary model training, will be replaced with importing pretrained model, having problems with it currently

# from torch_geometric.nn import GCNConv

# class GCN(torch.nn.Module):
#     def __init__(self, hidden_channels):
#         super().__init__()
#         torch.manual_seed(1)
#         self.conv1 = GCNConv(dataset.num_features, hidden_channels)
#         self.conv2 = GCNConv(hidden_channels, dataset.num_classes)

#     def forward(self, x, edge_index):
#         x = self.conv1(x, edge_index)
#         x = x.relu()
#         #x = F.dropout(x, p=0.5, training=self.training)
#         x = self.conv2(x, edge_index)
#         return x

# model = GCN(hidden_channels=16)

# model = GCN(hidden_channels=32)
# optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
# criterion = torch.nn.CrossEntropyLoss()

# def train():
#       model.train()
#       optimizer.zero_grad()  # Clear gradients.
#       out = model(data.x, data.edge_index)  # Perform a single forward pass.
#       loss = criterion(out[data.train_mask], data.y[data.train_mask])  # Compute the loss solely based on the training nodes.
#       val_loss = criterion(out[data.val_mask], data.y[data.val_mask])
#       loss.backward()  # Derive gradients.
#       optimizer.step()  # Update parameters based on gradients.
#       return loss, val_loss

# def test():
#       model.eval()
#       out = model(data.x, data.edge_index)
#       pred = out.argmax(dim=1)  # Use the class with highest probability.
#       test_correct = pred[data.test_mask] == data.y[data.test_mask]  # Check against ground-truth labels.
#       test_acc = int(test_correct.sum()) / int(data.test_mask.sum())  # Derive ratio of correct predictions.
#       return test_acc, out


# for epoch in range(200):
#     loss, val_loss = train()
#     if epoch%200 == 0:
#           print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Validation loss: {val_loss:.4f}')

# test_acc, out = test()
# print(f'Test Accuracy: {test_acc:.4f}')

In [4]:
# Using pretrained model from GNNModels.
# The dataset is passed as `dataset_name` (defined in the data-loading cell)
# instead of repeating the literal 'Cora', so the loaded model and the loaded
# graph cannot drift apart if the dataset is switched.
model = get_model_pretrained('GCN', dataset_name, path='GNNModels/checkpoints/')

# Loss function handed to explainers that take a criterion
# (IntegratedGradExplainer further down).
criterion = torch.nn.CrossEntropyLoss()


# Explanation Methods

In [5]:
# %cd GrapthXAI-main
# !pip install -e .

from graphxai.explainers import GNNExplainer, PGExplainer, IntegratedGradExplainer, PGMExplainer

# the ones below we want to use from different libraries
from graphxai.explainers import GNN_LRP, CAM

# need to also use subgraph x from DIG


## GNN Explainer and PGE Explainer

In [6]:
# GNNExplainer -- per the notebook's notes: discrete node-importance mask,
# soft edge-importance mask.

gnnexp = GNNExplainer(model)

def gnn_imp_nodes(node_idx):
    """Return the nodes GNNExplainer flags as important for `node_idx`.

    Runs `gnnexp.get_explanation_node` on the global `data` graph and collects
    every key of the explanation's `node_reference` whose corresponding
    `node_imp` entry equals 1.

    Args:
        node_idx: index of the node whose prediction is explained.

    Returns:
        list of `node_reference` keys, in iteration order, with importance 1.
    """
    explanation = gnnexp.get_explanation_node(
        node_idx=node_idx,
        x=data.x,
        edge_index=data.edge_index,
    )
    reference = explanation.node_reference
    importance = explanation.node_imp

    # reference[key] is the position of `key` inside the importance tensor.
    return [
        key
        for key in reference.keys()
        if importance[reference[key]].item() == 1
    ]

# PGExplainer -- per the notebook's notes: discrete node-importance mask,
# discrete edge-importance mask.

# The explainer must be given the name of the model's embedding layer.
pgex = PGExplainer(
    model,
    emb_layer_name='conv2',
    max_epochs=500,
    lr=0.01,
)
# Fit the explanation model on the graph before asking for explanations.
pgex.train_explanation_model(data)

def pge_imp_nodes(node_idx):
    """Return the nodes PGExplainer flags as important for `node_idx`.

    Args:
        node_idx: index of the node whose prediction is explained.

    Returns:
        list of `node_reference` keys whose `node_imp` entry equals 1.
    """
    result = pgex.get_explanation_node(
        node_idx=node_idx,
        x=data.x,
        edge_index=data.edge_index,
    )

    selected = []
    for key in result.node_reference.keys():
        # Look up where this node sits in the importance tensor.
        position = result.node_reference[key]
        if result.node_imp[position].item() == 1:
            selected.append(key)
    return selected

140it [00:00, 1252.20it/s]
140it [00:00, 974.16it/s]
140it [00:00, 721.94it/s]
140it [00:00, 616.34it/s]
140it [00:00, 827.37it/s]
140it [00:00, 865.31it/s]
140it [00:00, 664.02it/s]
140it [00:00, 633.78it/s]
140it [00:00, 834.34it/s]
140it [00:00, 951.10it/s]
140it [00:00, 968.33it/s]
140it [00:00, 819.82it/s]
140it [00:00, 886.23it/s]
140it [00:00, 984.67it/s]
140it [00:00, 949.69it/s]
140it [00:00, 889.28it/s]
140it [00:00, 904.95it/s]
140it [00:00, 959.74it/s]
140it [00:00, 959.50it/s]
140it [00:00, 784.38it/s]
140it [00:00, 635.76it/s]
140it [00:00, 875.31it/s]
140it [00:00, 984.07it/s] 
140it [00:00, 910.10it/s]
129it [00:00, 920.77it/s]


KeyboardInterrupt: 

## Integrated Gradients and PGM Explainer

In [13]:
# Integrated Gradients -- importance scores are soft; they are binarised
# below by thresholding sigmoid(node_imp) at 0.5.

igex = IntegratedGradExplainer(model, criterion=criterion)

def ig_imp_nodes(node_idx):
    """Return the nodes Integrated Gradients marks as important for `node_idx`.

    The explanation's `node_imp` scores are passed through a sigmoid and a
    node is kept when the result is at least 0.5.

    Args:
        node_idx: index of the node whose prediction is explained.

    Returns:
        list of `node_reference` keys whose thresholded importance is true.
    """
    explanation = igex.get_explanation_node(
        node_idx=node_idx,
        x=data.x,
        edge_index=data.edge_index,
        y=data.y,
    )

    # Boolean mask over the explanation's nodes.
    keep = torch.sigmoid(explanation.node_imp) >= 0.5
    reference = explanation.node_reference

    return [key for key in reference.keys() if keep[reference[key]].item() == 1]

# PGMExplainer -- per the notebook's notes: discrete node-importance mask,
# randomised; a ranking can be obtained by asking for the top 1, then 2, ...

pgm = PGMExplainer(model, explain_graph=False)

def pgm_imp_nodes(node_idx, top = None):
    """Return the nodes PGMExplainer flags as important for `node_idx`.

    Args:
        node_idx: index of the node whose prediction is explained.
        top: optional value forwarded as `top_k_nodes`; when None the
            explainer is called without that argument.

    Returns:
        list of `node_reference` keys whose `node_imp` entry equals 1.
    """
    # Re-seed NumPy's global RNG on every call (presumably so the randomised
    # explainer gives repeatable results -- note this resets global state).
    np.random.seed(1998)

    explain_kwargs = dict(node_idx=node_idx, x=data.x, edge_index=data.edge_index)
    # Identity test: `top == None` would invoke __eq__ on whatever is passed.
    if top is not None:
        explain_kwargs['top_k_nodes'] = top

    node_exp = pgm.get_explanation_node(**explain_kwargs)

    reference = node_exp.node_reference
    return [
        key
        for key in reference.keys()
        if node_exp.node_imp[reference[key]].item() == 1
    ]

## CAM

In [8]:
# CAM -- soft mask of node importance, binarised below by thresholding
# sigmoid(node_imp) at 0.5.

camex = CAM(model)

def cam_imp_nodes(node_idx):
    """Return the nodes CAM marks as important for `node_idx`.

    Args:
        node_idx: index of the node whose prediction is explained.

    Returns:
        list of `node_reference` keys whose sigmoid-ed importance is >= 0.5.
    """
    cam_result = camex.get_explanation_node(
        node_idx=node_idx,
        x=data.x,
        edge_index=data.edge_index,
        y=data.y,
    )

    above_threshold = torch.sigmoid(cam_result.node_imp) >= 0.5

    chosen = []
    for node in cam_result.node_reference.keys():
        slot = cam_result.node_reference[node]
        if above_threshold[slot].item() == 1:
            chosen.append(node)
    return chosen

