In [1]:
import os.path as osp

import torch

from torch_geometric.datasets import TUDataset
from torch_geometric.utils import to_networkx

from models import GCN
from explainer import HeterExplainer

import numpy as np

In [2]:
# Dataset and checkpoint locations for the MUTAG graph-classification task.
dataset_name = 'MUTAG'
path = 'data/TU'

dataset = TUDataset(path, name=dataset_name)

# Checkpoint file is named '<dataset>_<model>.pt', e.g. checkpoints/MUTAG_GCN.pt.
ckpt_name = '_'.join((dataset_name, 'GCN'))
ckpt_path = osp.join('checkpoints', ckpt_name + '.pt')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN(dataset.num_features, dataset.num_classes).to(device)
# Explanations should be computed with inference-mode behaviour (this matters
# if GCN contains dropout / batch-norm layers -- TODO confirm in models.py).
# The train/eval flag does not interact with load_state_dict below.
model.eval()

# map_location remaps tensors saved on another device (e.g. a checkpoint
# written on a CUDA machine) onto `device`; without it torch.load restores
# them on their original device and fails when CUDA is unavailable.
checkpoint = torch.load(ckpt_path, map_location=device)
# The checkpoint stores the model weights under the 'net' key; the returned
# key-matching report is left as the cell's displayed output.
model.load_state_dict(checkpoint['net'])

<All keys matched successfully>

In [3]:
# Hyper-parameters shared by every explainer call in the next cell.
num_samples = 1000      # number of perturbed samples generated per target
p_threshold = 0.05      # forwarded to explain(); presumably a p-value cut-off -- confirm in explainer.py
p_perturb = 0.91        # forwarded to explain()/synMLE; perturbation probability (assumed -- confirm)
pred_threshold = 0.09   # forwarded to explain(); semantics defined in explainer.py
k = 10                  # forwarded to every explainer call; semantics defined in explainer.py

# Wrap the trained model; the explainer reports which device it was set up on.
explainer = HeterExplainer(model, dataset_name, MUTAG_dataset=dataset, device=device)

Explainer (homo) set up on cpu


In [8]:
# Only graphs labelled mutagenic (y == 1) are candidates for explanation.
test_case = [idx for idx in range(len(dataset)) if dataset[idx].y.item() == 1]

# Choose the graph to explain. The membership guard matters when `target`
# is typed in by hand instead of being taken from `test_case`.
target = test_case[1]
if target not in test_case:
    raise ValueError('target not in test case.')

data = dataset[target]

# Step 1: base explanation. S and raw_feature_exp seed both the factual and
# the counterfactual searches below.
S, raw_feature_exp, feature_exp, time_used = explainer.explain(
    target,
    num_samples=num_samples,
    k=k,
    p_perturb=p_perturb,
    p_threshold=p_threshold,
    pred_threshold=pred_threshold,
)

# Step 2: factual explanation and its fidelity metrics
# (index 2 is reported as the F-effect, indices 3-5 as densities).
factual_S, factual_feat_exp = explainer.factual_synMLE(
    target, S, raw_feature_exp,
    num_samples=num_samples, k=k, p_perturb=p_perturb,
)
F_metric = explainer.homoCalFidelity(target, factual_S, factual_feat_exp)
print(f"F^Hence-X\nF-effect: {F_metric[2]} O-Dens: {F_metric[3]}   T-Dens: {F_metric[4]}  F-Dens: {F_metric[5]}")

# Step 3: counterfactual explanation and its fidelity metrics
# (index 0 is reported as the CF-effect, indices 3-5 as densities).
counterfactual_S, counterfactual_feat_exp = explainer.counterfactual_synMLE(
    target, S, raw_feature_exp,
    num_samples=num_samples, k=k, p_perturb=p_perturb,
)
CF_metric = explainer.homoCalFidelity(target, counterfactual_S, counterfactual_feat_exp)
print(f"CF^Hence-X\nCF-effect: {CF_metric[0]} O-Dens: {CF_metric[3]}   T-Dens: {CF_metric[4]}  F-Dens: {CF_metric[5]}")

Generating 1000 samples on target: 3
CF^Hence-X
CF-effect: 0.2016252875328064 O-Dens: 0.03759398496240602   T-Dens: 0.2631578947368421  F-Dens: 0.14285714285714285
