In [1]:
import os.path as osp

import torch

from torch_geometric.datasets import TUDataset
from torch_geometric.utils import k_hop_subgraph, to_networkx

from models import GCN
from explainer import HeterExplainer

import numpy as np

In [4]:
# Load the TU dataset and the trained GCN checkpoint that will be explained.
dataset_name = 'MUTAG'
path = 'data/TU'

dataset = TUDataset(path, name=dataset_name)
# Checkpoint file is named '<dataset>_GCN.pt' under checkpoints/.
ckpt_name = '_'.join((dataset_name, 'GCN'))
ckpt_path = osp.join('checkpoints', ckpt_name + '.pt')


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN(dataset.num_features, dataset.num_classes).to(device)
# Explanations query the model for predictions, so use inference mode
# (disables any dropout / batch-norm statistic updates the GCN may contain).
model.eval()

# map_location remaps stored tensors onto the current device; without it a
# checkpoint saved on GPU raises an error when loaded on a CPU-only machine.
checkpoint = torch.load(ckpt_path, map_location=device)
model.load_state_dict(checkpoint['net'])

<All keys matched successfully>

In [5]:
# Hyperparameters forwarded to explainer.explain / factual_synMLE /
# counterfactual_synMLE in the next cell. Exact semantics are defined by
# HeterExplainer — NOTE(review): confirm the intended ranges there.
num_samples = 1000      # passed as num_samples=
p_threshold = 0.05      # passed as p_threshold= (explain only)
p_perturb = 0.91        # passed as p_perturb= (the explainer may adjust it at runtime)
pred_threshold = 0.09   # passed as pred_threshold= (explain only)
k = 10                  # passed as k=

# Build the explainer around the loaded model; MUTAG_dataset supplies the graphs.
explainer = HeterExplainer(model, dataset_name, MUTAG_dataset=dataset, device=device)

Explainer (homo) set up on cpu


In [7]:
# Dataset indices of graphs in the mutagenic class (label 1); only these are explained.
test_case = [i for i in range(len(dataset)) if dataset[i].y.item() == 1]

# Position within test_case (NOT a raw dataset index) of the graph to explain.
target_pos = 53
# Validate the position explicitly: the previous `target not in test_case`
# check could never fail because target was taken from test_case itself.
if not 0 <= target_pos < len(test_case):
    raise ValueError(
        f'target position {target_pos} out of range: only {len(test_case)} mutagenic graphs.'
    )
target = test_case[target_pos]

data = dataset[target]

# Base explanation. explain returns (S, raw_feature_exp, feature_exp, time_used);
# S is presumably the selected structure (e.g. node set) — confirm in HeterExplainer.
S, raw_feature_exp, feature_exp, time_used = explainer.explain(target, num_samples=num_samples, k=k, p_perturb=p_perturb, p_threshold=p_threshold, pred_threshold=pred_threshold)
print(explainer.homoCalFidelity(target, S, feature_exp, evaluator=False))

# Refine the base explanation into factual and counterfactual variants and
# report fidelity for each.
# NOTE(review): these two calls use homoCalFidelity's default `evaluator`
# while the call above passes evaluator=False — confirm the asymmetry is intended.
factual_S, factual_feat_exp = explainer.factual_synMLE(target, S, raw_feature_exp, num_samples=num_samples, k=k, p_perturb=p_perturb)
print(explainer.homoCalFidelity(target, factual_S, factual_feat_exp))

counterfactual_S, counterfactual_feat_exp = explainer.counterfactual_synMLE(target, S, raw_feature_exp, num_samples=num_samples, k=k, p_perturb=p_perturb)
print(explainer.homoCalFidelity(target, counterfactual_S, counterfactual_feat_exp))

Generating 1000 samples on target: 73
new p perturb: 0.0827391304347826
(0.01086801290512085, 0.9717560289427638, 0.0057013751938939095, 0.0062111801242235995, 0.043478260869565216, 0.14285714285714285, {4: 0.14285714285714285})
(0.01086801290512085, 0.9717560289427638, 0.0057013751938939095, 0.0062111801242235995, 0.043478260869565216, 0.14285714285714285, {4: 0.14285714285714285})
(0.01086801290512085, 0.9717560289427638, 0.0057013751938939095, 0.0062111801242235995, 0.043478260869565216, 0.14285714285714285, {4: 0.14285714285714285})
