# Tutorial for GNN Explainability

In this hand-on code tutorial, we will show how to apply our DIG xgraph APIs to explain a well-trained model. Specifically, we will demonstrate two methods **GNNExplainer** and **SubgraphX** on the node classification dataset BA-shapes.

## GNN explainability
Graph Neural Networks are usually treated as black-boxes and lacking explainability. Without reasoning the prediction procedures of GNNs, we do not understand GNN models and do not know whether the models work in our expected way, thus preventing their use in critical applications pertaining to fairness, privacy, and safety.

### Load the dataset BA-shapes

In [24]:

# import torchvision
# from torchvision.datasets import CIFAR10
# from torchvision import transforms
# PyTorch Lightning
# Google Colab does not have PyTorch Lightning installed by default. Hence, we do it here if necessary
# !pip install --quiet pytorch-lightning
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint


In [26]:
## Standard libraries
import os
import json
import math
import numpy as np 
import time

## Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline 
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()
sns.set()

## Progress bar
from tqdm.notebook import tqdm

## PyTorch
import torch
import torch.nn as nn
# import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
# Torchvision
# import torchvision
# from torchvision.datasets import CIFAR10
# from torchvision import transforms
# PyTorch Lightning
try:
    import pytorch_lightning as pl
except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning installed by default. Hence, we do it here if necessary
    !pip install --quiet pytorch-lightning>=1.4
    import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

# Path to the folder where the datasets are/should be downloaded (e.g. CIFAR10)
DATASET_PATH = "../data"
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = "../saved_models/tutorial7"

# Setting the seed
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("mps")
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

  set_matplotlib_formats('svg', 'pdf') # For export
Seed set to 42


cpu


In [7]:
import os
import sys
import os.path as osp
sys.path.insert(0, os.sep.join(os.path.abspath('').split(os.sep)[:-2]))

import torch
from torch_geometric.data import download_url, extract_zip

from dig.xgraph.dataset import SynGraphDataset
# from dig.xgraph.models import *
from dig.xgraph.utils.compatibility import compatible_state_dict
from dig.xgraph.utils.init import fix_random_seed

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

fix_random_seed(123)
# dataset = SynGraphDataset('./datasets', 'BA_shapes')
DATASET_PATH='../data'
dataset = torch_geometric.datasets.TUDataset(root=DATASET_PATH, name="MUTAG")
dataset.data.x = dataset.data.x.to(torch.float32)
dataset.data.x = dataset.data.x[:, :1]
dim_node = dataset.num_node_features
num_classes = dataset.num_classes

data = dataset[0]
# target_node_indices = torch.where(dataset[0].test_mask * dataset[0].y != 0)[0].tolist()

In [5]:
import torch_geometric

In [48]:
class GNNModel(nn.Module):
    
    def __init__(self, c_in, c_hidden, c_out, num_layers=2, layer_name="GraphConv", dp_rate=0.1, **kwargs):
        """
        Inputs:
            c_in - Dimension of input features
            c_hidden - Dimension of hidden features
            c_out - Dimension of the output features. Usually number of classes in classification
            num_layers - Number of "hidden" graph layers
            layer_name - String of the graph layer to use
            dp_rate - Dropout rate to apply throughout the network
            kwargs - Additional arguments for the graph layer (e.g. number of heads for GAT)
        """
        super().__init__()
        gnn_layer = gnn_layer_by_name[layer_name]
        
        layers = []
        in_channels, out_channels = c_in, c_hidden
        for l_idx in range(num_layers-1):
            layers += [
                gnn_layer(in_channels=in_channels, 
                          out_channels=out_channels,
                          **kwargs),
                nn.ReLU(inplace=True)
                # ,nn.Dropout(dp_rate)
            ]
            in_channels = c_hidden
        layers += [gnn_layer(in_channels=in_channels, 
                             out_channels=c_out,
                             **kwargs)]
        self.layers = nn.ModuleList(layers)
    
    def forward(self, x, edge_index):
        """
        Inputs:
            x - Input features per node
            edge_index - List of vertex index pairs representing the edges in the graph (PyTorch geometric notation)
        """
        for l in self.layers:
            # For graph layers, we need to add the "edge_index" tensor as additional input
            # All PyTorch Geometric graph layer inherit the class "MessagePassing", hence
            # we can simply check the class type.
            if isinstance(l, geom_nn.MessagePassing):
                x = l(x, edge_index)
            else:
                x = l(x)
        return x

In [27]:
class GraphGNNModel(nn.Module):
    
    def __init__(self, c_in, c_hidden, c_out, dp_rate_linear=0.5, **kwargs):
        """
        Inputs:
            c_in - Dimension of input features
            c_hidden - Dimension of hidden features
            c_out - Dimension of output features (usually number of classes)
            dp_rate_linear - Dropout rate before the linear layer (usually much higher than inside the GNN)
            kwargs - Additional arguments for the GNNModel object
        """
        super().__init__()
        self.GNN = GNNModel(c_in=c_in, 
                            c_hidden=c_hidden, 
                            c_out=c_hidden, # Not our prediction output yet!
                            **kwargs)
        self.head = nn.Sequential(
            # nn.Dropout(dp_rate_linear),
            nn.Linear(c_hidden, c_out)
        )

    def forward(self, x, edge_index, batch_idx):
        """
        Inputs:
            x - Input features per node
            edge_index - List of vertex index pairs representing the edges in the graph (PyTorch geometric notation)
            batch_idx - Index of batch element for each node
        """
        x = self.GNN(x, edge_index)
        # x = geom_nn.global_mean_pool(x, batch_idx) # Average pooling
        x = geom_nn.global_max_pool(x, batch_idx) # Average pooling
        # x = geom_nn.global_add_pool(x, batch_idx) # sum pooling
        x = self.head(x)
        return x

In [28]:
class GraphLevelGNN(pl.LightningModule):

    def __init__(self, **model_kwargs):
        super().__init__()
        # Saving hyperparameters
        self.save_hyperparameters()
        
        self.model = GraphGNNModel(**model_kwargs)
        self.loss_module = nn.BCEWithLogitsLoss() if self.hparams.c_out == 1 else nn.CrossEntropyLoss()

    def forward(self, data, mode="train"):
        x, edge_index, batch_idx = data.x, data.edge_index, data.batch
        x = self.model(x, edge_index, batch_idx)
        x = x.squeeze(dim=-1)
        
        if self.hparams.c_out == 1:
            preds = (x > 0).float()
            try: 
                data.y = data.y.float()
            except: pass
        else:
            preds = x.argmax(dim=-1)
        try: 
            loss = self.loss_module(x, data.y)
            acc = (preds == data.y).sum().float() / preds.shape[0]
        except:
            loss = 0
            acc = 0

        return loss,acc,preds


    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=1e-2, weight_decay=0.0) # High lr because of small dataset and small model
        return optimizer

    def training_step(self, batch, batch_idx):
        loss, acc,_ = self.forward(batch, mode="train")
        self.log('train_loss', loss)
        self.log('train_acc', acc)
        return loss

    # def validation_step(self, batch, batch_idx):
    #     _, acc,preds = self.forward(batch, mode="val")
    #     self.log('val_acc', acc)
    #     # self.log('val_pred', preds)

    def test_step(self, batch, batch_idx):
        _, acc,_ = self.forward(batch, mode="test")
        self.log('test_acc', acc)
        # self.log('test_pred', pred_y)

In [30]:
tu_dataset = torch_geometric.datasets.TUDataset(root=DATASET_PATH, name="ENZYMES")

Downloading https://www.chrsmrrs.com/graphkerneldatasets/ENZYMES.zip
Processing...
Done!


In [31]:
tu2=tu_dataset.copy()
type(tu2.data)



torch_geometric.data.data.Data

In [33]:
for i in range(len(tu2.y)): 
    tu2.y[i]=max(torch.bincount(tu2[i].edge_index[0,:]))
    # tu2.y[i]=sum(torch.bincount(tu2[i].edge_index[0,:]))

In [35]:
thres=6

In [36]:
tu2.data.y=(tu2.y>thres).long()



In [37]:
tu2.data.y

tensor([1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0,
        0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0,
        1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1,
        1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0,
        0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0,
        0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
        0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0,
        0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1,
        0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
        0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1,
        0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,

In [38]:
tu2.data.x=tu2.x[:,:1]
# tu2.data.x=torch.ones(tu2.x.shape)
tu2.data.x=torch.randn(tu2.x.shape)

In [39]:
# torch.manual_seed(42)
# tu_dataset.shuffle()
# train_dataset = tu_dataset[:500]
# test_dataset = tu_dataset[500:]
torch.manual_seed(42)
tu2_shuffle=tu2.shuffle()
train_dataset = tu2_shuffle[:500]
test_dataset = tu2_shuffle[500:]

When using a data loader, we encounter a problem with batching $N$ graphs. Each graph in the batch can have a different number of nodes and edges, and hence we would require a lot of padding to obtain a single tensor. Torch geometric uses a different, more efficient approach: we can view the $N$ graphs in a batch as a single large graph with concatenated node and edge list. As there is no edge between the $N$ graphs, running GNN layers on the large graph gives us the same output as running the GNN on each graph separately. Visually, this batching strategy is visualized below (figure credit - PyTorch Geometric team, [tutorial here](https://colab.research.google.com/drive/1I8a0DfQ3fI7Njc62__mVXUlcAleUclnb?usp=sharing#scrollTo=2owRWKcuoALo)).

<center width="100%"><img src="torch_geometric_stacking_graphs.png" width="600px"></center>

The adjacency matrix is zero for any nodes that come from two different graphs, and otherwise according to the adjacency matrix of the individual graph. Luckily, this strategy is already implemented in torch geometric, and hence we can use the corresponding data loader:

In [41]:
import torch_geometric.nn as geom_nn
import torch_geometric.data as geom_data
gnn_layer_by_name = {
    "GCN": geom_nn.GCNConv,
    "GAT": geom_nn.GATConv,
    "GIN": geom_nn.GINConv,
    "GraphConv": geom_nn.GraphConv
}

In [42]:
graph_train_loader = geom_data.DataLoader(train_dataset, batch_size=15, shuffle=True)
graph_val_loader = geom_data.DataLoader(test_dataset, batch_size=15) # Additional loader if you want to change to a larger dataset
graph_test_loader = geom_data.DataLoader(test_dataset, batch_size=15)



Let's load a batch below to see the batching in action:

In [43]:
batch = next(iter(graph_test_loader))
print("Batch:", batch)
print("Labels:", batch.y[:10])
print("Batch indices:", batch.batch[:40])

Batch: DataBatch(edge_index=[2, 2192], x=[569, 1], y=[15], batch=[569], ptr=[16])
Labels: tensor([0, 0, 1, 1, 1, 1, 0, 0, 0, 0])
Batch indices: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])


In [44]:
def train_graph_classifier(model_name,train_loader=graph_train_loader,test_loader=graph_test_loader, **model_kwargs):
    pl.seed_everything(42)
    
    # Create a PyTorch Lightning trainer with the generation callback
    root_dir = os.path.join(CHECKPOINT_PATH, "GraphLevel" + model_name)
    os.makedirs(root_dir, exist_ok=True)
    trainer = pl.Trainer(default_root_dir=root_dir,
                         # callbacks=[ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc")],
                         accelerator="cpu",# if str(device).startswith("cuda") else "cpu",
                         devices=1,
                         max_epochs=500,
                         enable_progress_bar=False)
    trainer.logger._default_hp_metric = None # Optional logging argument that we don't need

    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, f"GraphLevel{model_name}.ckpt")
    if os.path.isfile(pretrained_filename):
        # print("Found pretrained model, loading...")
        # model = GraphLevelGNN.load_from_checkpoint(pretrained_filename)
        pl.seed_everything(42)
        model = GraphLevelGNN(c_in=tu2.num_node_features, 
                              c_out=1 if tu2.num_classes==2 else tu2.num_classes, 
                              **model_kwargs)
        trainer.fit(model, graph_train_loader, graph_val_loader)
        # model = GraphLevelGNN.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
    else:
        pl.seed_everything(42)
        model = GraphLevelGNN(c_in=tu2.num_node_features, 
                              c_out=1 if tu2.num_classes==2 else tu2.num_classes, 
                              **model_kwargs)
        trainer.fit(model, graph_train_loader, graph_val_loader)
        # model = GraphLevelGNN.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
    # Test best model on validation and test set
    train_result = trainer.test(model, train_loader, verbose=False)
    test_result = trainer.test(model, test_loader, verbose=False)
    # test_pred = trainer.predict(model, graph_test_loader, return_predictions=True)
    result = {"test": test_result[0]['test_acc'], "train": train_result[0]['test_acc']
              # ,"pred_y": test_pred
            } 
    return model, result

### Load the well-trained GCN model

In [50]:
def check_checkpoints(root='./'):
    if osp.exists(osp.join(root, 'checkpoints')):
        return
    url = ('https://github.com/divelab/DIG_storage/raw/main/xgraph/checkpoints.zip')
    path = download_url(url, root)
    extract_zip(path, root)
    os.unlink(path)


# model = GCN_2l(model_level='node', dim_node=dim_node, dim_hidden=300, num_classes=num_classes)
# model.to(device)
# check_checkpoints()
# ckpt_path = osp.join('checkpoints', 'ba_shapes', 'GCN_2l', '0', 'GCN_2l_best.ckpt')
model = GraphLevelGNN(c_in=tu2.num_node_features, 
                              c_out=1 if tu2.num_classes==2 else tu2.num_classes,
                      c_hidden=16)
                              # **model_kwargs)
ckpt_path='/Users/beatrixwen/Downloads/checkpoint.ckpt'
state_dict = compatible_state_dict(torch.load(ckpt_path, map_location='cpu')['model_state_dict'])
model.load_state_dict(state_dict)

RuntimeError: Error(s) in loading state_dict for GraphLevelGNN:
	Unexpected key(s) in state_dict: "model.GNN.layers.4.lin_rel.weight", "model.GNN.layers.4.lin_rel.bias", "model.GNN.layers.4.lin_root.weight". 

In [None]:
import os 
os.getcwd()

### GNNExplainer

GNNExplainer learns soft masks for edges and node features to identify important input information. The masks are randomly initialized and updated to maximize the mutual information between original predictions and perturbed graphs.
Here we setup **GNNExplainer** and take it to explain the predictions for 20 target nodes.

In [None]:
from dig.xgraph.method import GNNExplainer
from dig.xgraph.evaluation import XCollector

gnnexplainer = GNNExplainer(model, epochs=100, lr=0.01, explain_graph=False)

### SubgraphX
SubgraphX explores subgraph-level explanations for graph neural networks. It employs the Monte Carlo Tree Search algorithm to efficiently explore different subgraphs via node pruning and select the most important subgraph as the explanation. The importance of subgraphs is measured by Shapley values.

In [None]:
from dig.xgraph.method import SubgraphX

subgraphx_explainer = SubgraphX(model,
                                num_classes=num_classes,
                                device=device,
                                explain_graph=False,
                                reward_method='nc_mc_l_shapley')

## Visualization
We provide the visualization APIs for Synthetic, Text2Graph and Molecules datasets to provide human-intelligible explanation visualizations.

In [None]:
from dig.xgraph.method.subgraphx import PlotUtils
from dig.xgraph.method.subgraphx import MCTS
from torch_geometric.utils import to_networkx
from torch_geometric.data import Data
from torch_geometric.utils import add_self_loops

data = dataset[0].to(device)
node_indices = torch.where(dataset[0].test_mask * dataset[0].y != 0)[0].tolist()
node_idx = node_indices[20]

subgraph_x, subgraph_edge_index, subset, edge_mask, kwargs = \
    MCTS.__subgraph__(node_idx, data.x, data.edge_index, num_hops=2)
new_node_idx = torch.argwhere(subset == node_idx)[0]

subgraph_y = data.y[subset].to('cpu')
vis_graph = to_networkx(Data(x=subgraph_x, edge_index=subgraph_edge_index))
plotutils = PlotUtils(dataset_name='ba_shapes')
plotutils.plot(vis_graph, nodelist=[], figname=None, y=subgraph_y, node_idx=new_node_idx)

### Visualization results for the GNNExplainer

In [None]:
sparsity = 0.7
import matplotlib.pyplot as plt

# Visualization
gnnexplainer_related_preds = \
    gnnexplainer(data.x, data.edge_index, sparsity=sparsity, num_classes=num_classes, node_idx=node_idx)
ax, G = gnnexplainer.visualize_graph(node_idx=node_idx, 
                                     edge_index=data.edge_index, 
                                     edge_mask=gnnexplainer_related_preds[1][prediction], 
                                     y=data.y)
plt.show()

### Visualization results for the SubgraphX

In [None]:
from dig.xgraph.method.subgraphx import PlotUtils
from dig.xgraph.method.subgraphx import find_closest_node_result

# Visualization
max_nodes = 5
node_idx = node_indices[20]
print(f'explain graph node {node_idx}')
data.to(device)
logits = model(data.x, data.edge_index)
prediction = logits[node_idx].argmax(-1).item()

_, explanation_results, related_preds = \
    subgraphx_explainer(data.x, data.edge_index, node_idx=node_idx, max_nodes=max_nodes)

explanation_results = explanation_results[prediction]
explanation_results = subgraphx_explainer.read_from_MCTSInfo_list(explanation_results)

plotutils = PlotUtils(dataset_name='ba_shapes', is_show=True)
subgraphx_explainer.visualization(explanation_results,
                        max_nodes=max_nodes,
                        plot_utils=plotutils,
                        y=data.y)

### Metric
Here we apply two metric Fidelity and Sparsity to compare these two explainability methods.
Fidelity is the probability difference between the original prediction and new prediction when masking the important explanation results. In addition, Sparsity is the ratio of the number of edges not in the final explanation of edges.

 Next, we take GNNExplainer and SubgraphX to explain the same model predictions. Here we control similar sparsity scores and compare the fidelity score.

In [None]:
sparsity = 0.15
gnnexplainer_collector = XCollector()

max_nodes = 5
subgraphx_collector = XCollector()

for data_idx, node_idx in enumerate(target_node_indices[:20]):
    data.to(device)

    logits = model(data.x, data.edge_index)
    prediction = logits[node_idx].argmax(-1).item()

    gnnexplainer_edge_masks, \
    gnnexplainer_hard_edge_masks, \
    gnnexplainer_related_preds = \
        gnnexplainer(data.x, data.edge_index,
                     sparsity=sparsity,
                     num_classes=num_classes,
                     node_idx=node_idx)

    _, subgraphx_explanation_results, subgraphx_related_preds = \
        subgraphx_explainer(data.x, data.edge_index, node_idx=node_idx, max_nodes=max_nodes)
    subgraphx_explanation_results = subgraphx_explanation_results[prediction]

    gnnexplainer_collector.collect_data(
        gnnexplainer_hard_edge_masks, gnnexplainer_related_preds, prediction)
    subgraphx_collector.collect_data(
        subgraphx_explanation_results[0]['coalition'], subgraphx_related_preds, label=prediction)

In [None]:
print(f'GNNExplainer Fidelity: {gnnexplainer_collector.fidelity:.4f}\n',
      f'GNNExplainer Sparsity: {gnnexplainer_collector.sparsity:.4f}')
print(f'SubgraphX Fidelity: {subgraphx_collector.fidelity:.4f}\n',
      f'SubgraphX Sparsity: {subgraphx_collector.sparsity:.4f}')