# GNN Explainer

The purpose of this notebook is to present GNN Explainer, a general, model-agnostic approach for providing inter-
pretable explanations for predictions of any GNN-based model on any graph-based
machine learning task.

The notebook is organized as follows:

* Brief theoretical recap on GNN-EXPLAINER
* Train your GNN-EXPLAINER to explain graph classification predictions
* Visualize and understand the propose explanation

## Model Description

In [13]:
## TODO

## Train your GNN Explainer!

In this section we will build together the main module used by the explainer, which is implemented in the **ExplainModule class**. This class contains several methods, in particular:

**construct_feat_mask**: initializes the feature mask that will be learned by the explainer

**construct edge_mask**: initializes the edge mask that will be learne by the explainer

**mask_adj**: computes the masked adjacency matrices of the graph whose prediction we want to explain

**mask_density**: computes mask density as (sum masked entried)/(original sum of entries)

**forward**: returns the model prediction (and edge attention, if available), based on the current edge and feature masks.

**loss**: computes the loss function as explained in the previous section

In [1]:
import math
import time
import os

import matplotlib
import matplotlib.colors as colors
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure

import networkx as nx
import numpy as np
import pandas as pd
import seaborn as sns
import tensorboardX.utils

import torch
import torch.nn as nn
from torch.autograd import Variable

In [2]:
class ExplainModule(nn.Module):
    def __init__(
        self, adj, x, model, label, args, graph_idx=0, writer=None, use_sigmoid=True
    ):
        super(ExplainModule, self).__init__()
        self.adj = adj
        self.x = x
        self.model = model
        self.label = label
        self.graph_idx = graph_idx
        self.args = args
        self.writer = writer
        self.mask_act = args.mask_act
        self.use_sigmoid = use_sigmoid
        self.graph_mode = True
        # Relative weights for the terms in the loss function.
        self.coeffs = {
            "size": 0.005,
            "feat_size": 1.0,
            "ent": 1.0,
            "feat_ent": 0.1,
            "grad": 0,
            "lap": 1.0,
        }
        num_nodes = adj.size()[1]
        init_strategy = "normal"
        # Initialize the edge mask to be optimized.
        self.mask, self.mask_bias = self.construct_edge_mask(
            num_nodes, init_strategy=init_strategy
        )
        # Initialize the feature mask to be optimized.
        self.feat_mask = self.construct_feat_mask(x.size(-1), init_strategy="constant")
        params = [self.mask, self.feat_mask]
        if self.mask_bias is not None:
            params.append(self.mask_bias)
        # For masking diagonal entries.
        self.diag_mask = torch.ones(num_nodes, num_nodes) - torch.eye(num_nodes)
        if args.gpu:
            self.diag_mask = self.diag_mask.cuda()

        self.scheduler, self.optimizer = train_utils.build_optimizer(args, params)

    def construct_feat_mask(self, feat_dim, init_strategy="normal"):
        """Initialize the feature mask. init_strategy is a string specifying
        the chosen initialization strategy (can be 'costant' or 'normal'.)
        """
        mask = nn.Parameter(torch.FloatTensor(feat_dim))
        if init_strategy == "normal":
            std = 0.1
            with torch.no_grad():
                mask.normal_(1.0, std)
        elif init_strategy == "constant":
            with torch.no_grad():
                nn.init.constant_(mask, 0.0)
        return mask

    def construct_edge_mask(self, num_nodes, init_strategy="normal", const_val=1.0):
        """Initialize the edge mask. init_strategy is a string specifying
        the chosen initialization strategy (eg 'costant' or 'normal'.)
        """
        mask = nn.Parameter(torch.FloatTensor(num_nodes, num_nodes))
        if init_strategy == "normal":
            std = nn.init.calculate_gain("relu") * math.sqrt(
                2.0 / (num_nodes + num_nodes)
            )
            with torch.no_grad():
                mask.normal_(1.0, std)
        elif init_strategy == "const":
            nn.init.constant_(mask, const_val)
        if self.args.mask_bias:
            mask_bias = nn.Parameter(torch.FloatTensor(num_nodes, num_nodes))
            nn.init.constant_(mask_bias, 0.0)
        else:
            mask_bias = None

        return mask, mask_bias

    def _masked_adj(self):
        """Computes the masked adjacency matrix of the graph. Since
        we work with undirected graphs, we make the mask symmetric.
        Self-loops are also excluded using a diagonal mask.
        """
        sym_mask = self.mask
        if self.mask_act == "sigmoid":
            sym_mask = torch.sigmoid(self.mask)
        elif self.mask_act == "ReLU":
            sym_mask = nn.ReLU()(self.mask)
        sym_mask = (sym_mask + sym_mask.t()) / 2
        adj = self.adj.cuda() if self.args.gpu else self.adj
        masked_adj = adj * sym_mask
        if self.args.mask_bias:
            bias = (self.mask_bias + self.mask_bias.t()) / 2
            bias = nn.ReLU6()(bias * 6) / 6
            masked_adj += (bias + bias.t()) / 2
        return masked_adj * self.diag_mask

    def mask_density(self):
        mask_sum = torch.sum(self._masked_adj()).cpu()
        adj_sum = torch.sum(self.adj)
        return mask_sum / adj_sum

    def forward(self, mask_features=True, marginalize=False):
        """Computes the model prediction on the masked graph with masked features.
        Returns the model predictions and adjacency attention (if available).
        """
        x = self.x.cuda() if self.args.gpu else self.x
        self.masked_adj = self._masked_adj()
        if mask_features:
            feat_mask = (
                torch.sigmoid(self.feat_mask) if self.use_sigmoid else self.feat_mask
            )
            if marginalize:
                std_tensor = torch.ones_like(x, dtype=torch.float) / 2
                mean_tensor = torch.zeros_like(x, dtype=torch.float) - x
                z = torch.normal(mean=mean_tensor, std=std_tensor)
                x = x + z * (1 - feat_mask)
            else:
                x = x * feat_mask

        ypred, adj_att = self.model(x, self.masked_adj)
        res = nn.Softmax(dim=0)(ypred[0])

        return res, adj_att

    def loss(self, pred, epoch):
        """
        Args:
            pred: prediction made by current model (with current mask).
            epoch: training epoch.
        """
        # Prediction loss.
        gt_label = self.label
        logit = pred[gt_label]
        pred_loss = -torch.log(logit)
        # Adjacency mask size loss.
        mask = self.mask
        if self.mask_act == "sigmoid":
            mask = torch.sigmoid(self.mask)
        elif self.mask_act == "ReLU":
            mask = nn.ReLU()(self.mask)
        size_loss = self.coeffs["size"] * torch.sum(mask)
        # Feature mask size loss.
        feat_mask = (
            torch.sigmoid(self.feat_mask) if self.use_sigmoid else self.feat_mask
        )
        feat_size_loss = self.coeffs["feat_size"] * torch.mean(feat_mask)
        # Adjacency mask entropy loss.
        mask_ent = -mask * torch.log(mask) - (1 - mask) * torch.log(1 - mask)
        mask_ent_loss = self.coeffs["ent"] * torch.mean(mask_ent)
        # Feature mask entropy loss.
        feat_mask_ent = -feat_mask * torch.log(feat_mask) - (1 - feat_mask) * torch.log(
            1 - feat_mask
        )
        feat_mask_ent_loss = self.coeffs["feat_ent"] * torch.mean(feat_mask_ent)
        # Total loss.
        loss = pred_loss + size_loss + mask_ent_loss + feat_size_loss

        # Log data to tensorboard.
        if self.writer is not None:
            self.writer.add_scalar("optimization/size_loss", size_loss, epoch)
            self.writer.add_scalar("optimization/feat_size_loss", feat_size_loss, epoch)
            self.writer.add_scalar("optimization/mask_ent_loss", mask_ent_loss, epoch)
            self.writer.add_scalar(
                "optimization/feat_mask_ent_loss", mask_ent_loss, epoch
            )
            self.writer.add_scalar("optimization/pred_loss", pred_loss, epoch)
            # self.writer.add_scalar("optimization/lap_loss", lap_loss, epoch)
            self.writer.add_scalar("optimization/overall_loss", loss, epoch)

        return loss

To train our explainer, we then use the trainExplainer function, implemented in the explainer.py file. 

In [7]:
import sys  

import src.GNNexplainer.configs_explainer as configs_explainer

In [11]:
prog_args = configs_explainer.explainer_arg_parse()

usage: ipykernel_launcher.py [-h] [--dataset DATASET] [--bmname BMNAME]
                             [--pkl PKL_FNAME] [--opt OPT]
                             [--opt-scheduler OPT_SCHEDULER]
                             [--opt-restart OPT_RESTART]
                             [--opt-decay-step OPT_DECAY_STEP]
                             [--opt-decay-rate OPT_DECAY_RATE] [--lr LR]
                             [--clip CLIP] [--clean-log] [--logdir LOGDIR]
                             [--ckptdir CKPTDIR] [--cuda CUDA] [--gpu]
                             [--epochs NUM_EPOCHS] [--hidden-dim HIDDEN_DIM]
                             [--output-dim OUTPUT_DIM]
                             [--num-gc-layers NUM_GC_LAYERS] [--bn]
                             [--dropout DROPOUT] [--nobias] [--no-writer]
                             [--mask-act MASK_ACT] [--mask-bias]
                             [--explain-node EXPLAIN_NODE]
                             [--graph-idx GRAPH_IDX] [--graph-mode]
   

SystemExit: 2

In [12]:
%tb

SystemExit: 2

In [None]:
# Load a configuration



if prog_args.gpu:
    os.environ["CUDA_VISIBLE_DEVICES"] = prog_args.cuda
    print("CUDA", prog_args.cuda)
else:
    print("Using CPU")

# Configure the logging directory
if prog_args.writer:
    path = os.path.join(prog_args.logdir, io_utils.gen_explainer_prefix(prog_args))
    if os.path.isdir(path) and prog_args.clean_log:
        print("Removing existing log dir: ", path)
        if (
            not input("Are you sure you want to remove this directory? (y/n): ")
            .lower()
            .strip()[:1]
            == "y"
        ):
            sys.exit(1)
        shutil.rmtree(path)
    writer = SummaryWriter(path)
else:
    writer = None

# Load a model checkpoint
ckpt = io_utils.load_ckpt(prog_args)
cg_dict = ckpt["cg"]  # get computation graph
input_dim = cg_dict["feat"].shape[2]
num_classes = cg_dict["pred"].shape[2]
print("Loaded model from {}".format(prog_args.ckptdir))
print("input dim: ", input_dim, "; num classes: ", num_classes)

# build model
print("Method: ", prog_args.method)

# Explain Graph prediction
model = models.GcnEncoderGraph(
    input_dim=input_dim,
    hidden_dim=prog_args.hidden_dim,
    embedding_dim=prog_args.output_dim,
    label_dim=num_classes,
    num_layers=prog_args.num_gc_layers,
    bn=prog_args.bn,
    args=prog_args,
)
if prog_args.gpu:
    model = model.cuda()

# Load state_dict (obtained by model.state_dict() when saving checkpoint)
model.load_state_dict(ckpt["model_state"])
model = model.eval()

# Extract the data relative to the chosen graph.
adj = cg_dict["adj"]
feat = cg_dict["feat"]
label = cg_dict["label"]
pred = cg_dict["pred"]

graph_idx = prog_args.graph_idx
sub_adj = adj[graph_idx]
sub_feat = feat[graph_idx, :]
sub_label = label[graph_idx]
neighbors = np.asarray(range(adj.shape[0]))

sub_adj = np.expand_dims(sub_adj, axis=0)
sub_feat = np.expand_dims(sub_feat, axis=0)

adj = torch.tensor(sub_adj, dtype=torch.float)
x = torch.tensor(sub_feat, requires_grad=True, dtype=torch.float)
label = torch.tensor(sub_label, dtype=torch.long)

# Create explainer
explainer = Explainer(
    model=model,
    adj=adj,
    x=x,
    label=label,
    args=prog_args,
    writer=writer,
    graph_idx=prog_args.graph_idx,
)

# Run explainer.
train_explainer(
    explainer=explainer, pred=pred, args=prog_args, graph_idx=prog_args.graph_idx
)
io_utils.plot_cmap_tb(writer, "tab20", 20, "tab20_cmap")


## Visualization

In [None]:
This notebook is designed to visualize the results of the GNN Explainer.

Use it after one has trained the model using train.py, and has run the explainer optimization (explainer_main.py).
The main purpose is to visualize the trained mask by interactively tuning the threshold. In many scientific applications, the explanation size is unknown a priori. This tool can help user visualize the selected subgraph, with respect to different values of the thresholds, and find the right size for a good explanation.

In [1]:
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

import numpy as np
import os
import networkx as nx
import matplotlib.pyplot as plt
import json

%matplotlib inline

Configuring the experiment you want to visualize. These values should match the configuration:

> TODO: Unify configuration of experiments in yaml

In [5]:
logdir = '../src/GNNexplainer/log/'
expdir = 'REDDIT-BINARY_base_h20_o20_explain'

In [6]:
# Load the produced masks

In [14]:
dirs = os.listdir(os.path.join(logdir, expdir))
dirs = os.listdir(logdir)

In [16]:
masks = []
# This would print all the files and directories
for file in dirs:
    if file.split('.')[-1] == 'npy':
        print(file)
        masks.append(file)

masked_adj_REDDIT-BINARY_base_h20_o20_explainnode_idx_0graph_idx_3.npy


Utility to save masks:

In [17]:
from networkx.readwrite import json_graph

def save_mask(G, fname, fmt='json', suffix=''):
    pth = os.path.join(logdir, expdir, fname+'-filt-'+suffix+'.'+fmt)
    if fmt == 'json':
        dt = json_graph.node_link_data(G)
        with open(pth, 'w') as f:
            json.dump(dt, f)
    elif fmt == 'pdf':
        plt.savefig(pth)
    elif fmt == 'npy':
        np.save(pth, nx.to_numpy_array(G))

Plotting utilities:

In [18]:
def show_adjacency_full(mask, ax=None):
    adj = np.load(os.path.join(logdir, expdir, mask), allow_pickle=True)
    if ax is None:
        plt.figure()
        plt.imshow(adj);
    else:
        ax.imshow(adj)
    return adj

In [19]:
def read_adjacency_full(mask, ax=None):
    adj = np.load(os.path.join(logdir, expdir, mask), allow_pickle=True)
    return adj

In [21]:
filt_adj = read_adjacency_full(masks[0])
@interact
def filter_adj(thresh=0.5):
    filt_adj[filt_adj<thresh] = 0
    return filt_adj

interactive(children=(FloatSlider(value=0.5, description='thresh', max=1.5, min=-0.5), Output()), _dom_classes…

Weight-based threshold:

In [22]:
# EDIT THIS INDEX
MASK_IDX = 0
# EDIT THIS INDEX

m = masks[MASK_IDX]
adj = read_adjacency_full(m)


@interact(thresh=widgets.FloatSlider(value=0.5, min=0.0, max=1.0, step=0.01))
def plot_interactive(thresh=0.5):
    filt_adj = read_adjacency_full(m)
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15,5))
    plt.title(str(m));

    # Full adjacency
    ax1.set_title('Full Adjacency mask')
    adj = show_adjacency_full(m, ax=ax1);
    
    # Filtered adjacency
    filt_adj[filt_adj<thresh] = 0
    ax2.set_title('Filtered Adjacency mask');
    ax2.imshow(filt_adj);
    
    # Plot subgraph
    ax3.set_title("Subgraph")
    G_ = nx.from_numpy_array(adj)
    G  = nx.from_numpy_array(filt_adj)
    G.remove_nodes_from(list(nx.isolates(G)))
    nx.draw(G, ax=ax3)
    save_mask(G, fname=m, fmt='json')
    
    print("Removed {} edges -- K = {} remain.".format(G_.number_of_edges()-G.number_of_edges(), G.number_of_edges()))
    print("Removed {} nodes -- K = {} remain.".format(G_.number_of_nodes()-G.number_of_nodes(), G.number_of_nodes()))


interactive(children=(FloatSlider(value=0.5, description='thresh', max=1.0, step=0.01), Output()), _dom_classe…