# Explaining CFGNN's Treecycles dataset using PGExplainer

In [10]:
import pickle
from math import ceil

import torch
from torch_geometric.explain import Explainer
from torch_geometric.explain.algorithm import PGExplainer
from torch_geometric.explain.config import ModelConfig
from torch_geometric.explain.metric import fidelity
from torch_geometric.utils import k_hop_subgraph

from pyg_gcn import GNN

torch.manual_seed(42)

<torch._C.Generator at 0x7fd9eb6b06b0>

# Data

In [11]:
# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only run this cell on dataset files from a trusted source.
with open("../data/gnn_explainer/syn4.pickle", "rb") as file:
    data = pickle.load(file)
# squeeze() removes any size-1 dimensions from the stored arrays.
adj = torch.tensor(data["adj"], dtype=torch.float).squeeze()
features = torch.tensor(data["feat"], dtype=torch.float).squeeze()
labels = torch.tensor(data["labels"], dtype=torch.long).squeeze()
idx_train = torch.tensor(data["train_idx"], dtype=torch.long)
# The node indices used for evaluation live in a separate pickle file.
with open("../data/eval-sets/syn4.pickle", "rb") as file:
    idx_test = torch.tensor(pickle.load(file), dtype=torch.long)

In [12]:
# Indices of the non-zero entries of the dense adjacency (sparse COO layout);
# this is what gets passed to the model and the explainer as `edge_index`.
edge_index = adj.to_sparse().indices()

# PyG Model

In [13]:
# Size the network from the data: input width from the feature matrix and
# number of classes from the largest label id.
model = GNN(
    nfeat=features.size(1),
    nhid=20,
    nout=20,
    nclass=labels.max().item() + 1,
    dropout=0.0,
)

# Remap the checkpoint before loading: every entry whose name contains both
# "gc" and "weight" is renamed from "<a>.<b>" to "<a>.lin.<b>" and transposed.
state_dict = torch.load("../models/gcn_3layer_syn4.pt")
for key in tuple(state_dict):  # snapshot of the keys: the dict is edited in the loop
    if "gc" not in key or "weight" not in key:
        continue
    parts = key.split(".")
    new_key = f"{parts[0]}.lin.{parts[1]}"
    state_dict[new_key] = state_dict.pop(key).T
model.load_state_dict(state_dict)
# Last expression of the cell: switch to eval mode and display the model.
model.eval()

GNN(
  (gc1): GCNConv(10, 20)
  (gc2): GCNConv(20, 20)
  (gc3): GCNConv(20, 20)
  (lin): Linear(in_features=60, out_features=2, bias=True)
)

In [14]:
#* The PyG version seems to replicate the results without normalizing the adjacency.
# Forward pass over the whole graph; the predicted class is the arg-max along dim 1.
output = model(features, edge_index)
pred = output.argmax(dim=1)

In [15]:
# Percentage of nodes in each split whose predicted class matches its label.
train_acc = 100 * (pred[idx_train] == labels[idx_train]).sum() / pred[idx_train].size(0)
test_acc = 100 * (pred[idx_test] == labels[idx_test]).sum() / pred[idx_test].size(0)
print(
    f"Training accuracy: {train_acc:.2f}%",
    f"Test accuracy: {test_acc:.2f}%",
    sep="\n"
)

Training accuracy: 91.24%
Test accuracy: 100.00%


# Explanation

In [16]:
# Number of passes over the training nodes when fitting PGExplainer.
EPOCHS = 30
# Learning rate handed to PGExplainer.
LR = 0.003
# Value of the 'topk' threshold applied by the Explainer.
TOP_K = 6

In [17]:
# Wrap the trained model in a PyG Explainer driven by PGExplainer.
# Explanations are "phenomenon"-type (a target label is supplied at call time),
# use an object-level edge mask, and are thresholded with top-k = TOP_K.
explainer = Explainer(
    model=model,
    algorithm=PGExplainer(epochs=EPOCHS, lr=LR),
    explanation_type="phenomenon",
    edge_mask_type="object",
    model_config=ModelConfig(
        mode="multiclass_classification",
        task_level="node",
        return_type="log_probs"
    ),
    threshold_config=dict(threshold_type='topk', value=TOP_K),
)

In [18]:
# Fit PGExplainer: one training step per training node, repeated for EPOCHS epochs.
# Only the loss of the very last step survives in `loss`; it is not logged.
# Note that `index` stays defined as a notebook global after this loop.
for epoch in range(EPOCHS):
    for index in idx_train:
        loss = explainer.algorithm.train(
            epoch=epoch,
            model=model,
            x=features,
            edge_index=edge_index,
            target=labels,
            index=index.item()
        )

# Metrics

In [19]:
# Wall-clock time needed to explain every test node once.
# Each explanation overwrites the previous one; only the timing is reported.
from time import perf_counter
start = perf_counter()
for index in idx_test:
    explanation = explainer(features, edge_index, target=labels, index=index.item())
end = perf_counter()
print("Time elapsed:", round(end - start, 2), "seconds")

Time elapsed: 0.38 seconds


## Fidelity

In [20]:
def cal_fidelity(indices):
    """Summarise explanation fidelity over a set of nodes.

    Each node in ``indices`` is explained with the notebook-level ``explainer``
    and the first value returned by ``fidelity(explainer, explanation)`` is
    recorded.

    Args:
        indices: Iterable of node indices; ``.item()`` is called on each
            element, so 0-dim tensors are expected.

    Returns:
        Tuple ``(1 - mean, std)`` of the recorded values, as float64 tensors.
    """
    def _score(node):
        explanation = explainer(features, edge_index, target=labels, index=node.item())
        return fidelity(explainer, explanation)[0]

    scores = torch.tensor([_score(node) for node in indices], dtype=float)
    return 1 - scores.mean(), scores.std()

In [21]:
# Fidelity is computed on the test nodes only; the train-split call is disabled.
# fidelity_mean_train, fidelity_std_train = cal_fidelity(idx_train)
fidelity_mean_test, fidelity_std_test = cal_fidelity(idx_test)

In [22]:
# Report the test-split fidelity (the train-split report is disabled).
# print(f"Average training fidelity: {fidelity_mean_train:.4f}, std={fidelity_std_train:.4f}")
print(f"Average test fidelity: {fidelity_mean_test:.4f}, std={fidelity_std_test:.4f}")

Average test fidelity: 0.3472, std=0.4794


## Size

In [23]:
# Reported straight from the top-k threshold setting rather than measured.
print(f"Average explanation size: {TOP_K}")

Average explanaiton size: 6


## Accuracy

In [24]:
def cal_acc(indices):
    """Score explanations by the share of their edges joining non-zero-label nodes.

    Each node in ``indices`` is explained with the notebook-level ``explainer``.
    An edge kept by the explanation mask counts as a hit when both its source
    and its destination have a label different from 0; the node's score is
    hits divided by the number of kept edges.

    Args:
        indices: Iterable of node indices; ``.item()`` is called on each
            element, so 0-dim tensors are expected.

    Returns:
        Tuple ``(mean, std)`` of the per-node scores, as float64 tensors.

    Raises:
        ZeroDivisionError: If an explanation keeps no edges at all.
    """
    per_node = []
    for node in indices:
        explanation = explainer(features, edge_index, target=labels, index=node.item())
        # Restrict the edge list to the edges selected by the explanation mask.
        kept = explanation.edge_mask.to(bool)
        selected = explanation.edge_index[:, kept]
        # An edge is a hit only when both of its endpoints carry a non-zero label.
        src_nonzero = labels[selected[0]] != 0
        dst_nonzero = labels[selected[1]] != 0
        hits = int((src_nonzero & dst_nonzero).sum())
        per_node.append(hits / selected.size(1))
    scores = torch.tensor(per_node, dtype=float)
    return scores.mean(), scores.std()

In [25]:
# Explanation accuracy is computed on the test nodes only; the train-split call is disabled.
# acc_mean_train, acc_std_train = cal_acc(idx_train)
acc_mean_test, acc_std_test = cal_acc(idx_test)

In [26]:
# Report the test-split explanation accuracy (the train-split report is disabled).
# print(f"Average train accuracy: {acc_mean_train:.4f}, std={acc_std_train:.4f}")
print(f"Average test accuracy: {acc_mean_test:.4f}, std={acc_std_test:.4f}")

Average test accuracy: 0.7685, std=0.2791


## Sparsity

In [27]:
def cal_sparsity(indices, is_undirected: bool = True):
    """Measure how small each explanation is relative to its computation subgraph.

    For every node in ``indices`` the k-hop subgraph seen by ``model`` is
    extracted and its edges counted. The node is then explained with the
    notebook-level ``explainer`` and its sparsity is
    ``1 - explanation_size / num_edges``.

    Args:
        indices: Iterable of node indices; ``.item()`` is called on each
            element, so 0-dim tensors are expected.
        is_undirected: When true, the subgraph edge count is halved (rounded
            up) so that each undirected edge is counted once.

    Returns:
        Tuple ``(mean, std)`` of the per-node sparsities, as float64 tensors.

    Raises:
        ZeroDivisionError: If a node's k-hop subgraph contains no edges.
    """
    sparsity_t = list()
    for node_index in indices:
        # Only the edge mask of the k-hop subgraph is needed here.
        __, __, __, edge_mask = k_hop_subgraph(
            node_idx=node_index.item(),
            num_hops=PGExplainer._num_hops(model),
            edge_index=edge_index,
            num_nodes=features.size(0),
            flow=PGExplainer._flow(model),
        )
        # Number of edges in the subgraph.
        num_edges = edge_mask.nonzero().size(0)
        # Account for undirected edges being stored in both directions.
        if is_undirected:
            num_edges = ceil(num_edges / 2)
        # Explain the node currently being processed. This must be
        # `node_index`; a stale global `index` would explain the same node
        # on every iteration.
        explanation = explainer(features, edge_index, target=labels, index=node_index.item())
        edges_involved = explanation.edge_index[:, explanation.edge_mask.to(bool)]
        # NOTE(review): this counts mask entries as-is, while num_edges above
        # may have been halved for undirected graphs - confirm this is intended.
        explanation_size = edges_involved.size(1)
        # Surface explanations whose size differs from the configured top-k.
        if explanation_size != TOP_K:
            print(f"{node_index}: {explanation_size}: {num_edges}")
        sparsity = 1 - (explanation_size / num_edges)
        sparsity_t.append(sparsity)
    sparsity_t = torch.tensor(sparsity_t, dtype=float)
    return sparsity_t.mean(), sparsity_t.std()

In [28]:
# Sparsity is computed on the test nodes only; the train-split call is disabled.
# sparsity_mean_train, sparsity_std_train = cal_sparsity(idx_train, True)
sparsity_mean_test, sparsity_std_test = cal_sparsity(idx_test, True)

In [29]:
# Report the test-split sparsity (the train-split report is disabled).
# print(f"Average train sparsity: {sparsity_mean_train:.4f}, std={sparsity_std_train:.4f}")
print(f"Average test sparsity: {sparsity_mean_test:.4f}, std={sparsity_std_test:.4f}")

Average test sparsity: 0.3388, std=0.1840
