In [1]:
import os.path as osp
import random

import numpy as np
import torch
import torch_geometric.transforms as T

from util import IMDB
from models import HAN
from explainer import HeterExplainer


In [5]:
dataset_name = 'IMDB'
path = 'data/IMDB/'

# Project the heterogeneous graph onto movie nodes through two meta-paths:
# movie-actor-movie (metapath_0) and movie-director-movie (metapath_1).
metapaths = [
    [('movie', via), (via, 'movie')]
    for via in ('actor', 'director')
]

# NOTE(review): newer PyG releases rename these kwargs to
# `drop_orig_edge_types` / `drop_unconnected_node_types` (the old names are
# still accepted with a deprecation warning) — confirm against the installed
# torch_geometric version before renaming.
transform = T.AddMetaPaths(
    metapaths=metapaths,
    drop_orig_edges=True,
    drop_unconnected_nodes=True,
)

dataset = IMDB(path, transform=transform)
data = dataset[0]
print(data)

HeteroData(
  metapath_dict={
    (movie, metapath_0, movie)=[2],
    (movie, metapath_1, movie)=[2]
  },
  [1mmovie[0m={
    x=[4278, 3066],
    y=[4278],
    train_mask=[4278],
    val_mask=[4278],
    test_mask=[4278]
  },
  [1m(movie, metapath_0, movie)[0m={ edge_index=[2, 85358] },
  [1m(movie, metapath_1, movie)[0m={ edge_index=[2, 17446] }
)


In [7]:
# Model hyper-parameters. They must match the ones the checkpoint was trained
# with, because they are also encoded in the checkpoint file name below.
hidden_channels = 128
out_channels = 3  # presumably the number of movie classes — confirm against data['movie'].y
num_heads = 8
metadata = data.metadata()

# e.g. 'checkpoints/IMDB_hDim_128_nHead_8.pt'
ckpt_name = '_'.join((dataset_name, 'hDim', str(hidden_channels), 'nHead', str(num_heads)))
ckpt_path = osp.join('checkpoints', ckpt_name + '.pt')

# in_channels=-1: presumably lets HAN infer per-node-type input sizes lazily
# (PyG convention) — confirm in models.HAN.
model = HAN(in_channels=-1, out_channels=out_channels, hidden_channels=hidden_channels,
            heads=num_heads, metadata=metadata)

# map_location='cpu' lets a checkpoint saved on a GPU be loaded on a CPU-only
# machine; the model is moved to the working device in the next cell.
checkpoint = torch.load(ckpt_path, map_location='cpu')
model.load_state_dict(checkpoint['net'])

# The model is only used for inference/explanation from here on: switch to
# eval mode so train-time layers (e.g. dropout, if HAN has any) are disabled
# and predictions are not randomly perturbed.
model.eval()
print('Trained model loaded.')

Trained model loaded.


In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data, model = data.to(device), model.to(device)

explainer = HeterExplainer(model, dataset_name, x_dict=data.x_dict,
                           edge_index_dict=data.edge_index_dict, device=device)

# Explanation hyper-parameters, passed to the explainer calls in the next cell.
num_samples = 1000
p_threshold = .05
p_perturb = 0.5
pred_threshold = .01
k = 10

# Index of the movie node to explain.
target = 15

movie_x = data['movie'].x
num_movies = movie_x.size(0)
# Fail early with a clear message instead of deep inside the explainer.
if not 0 <= target < num_movies:
    raise IndexError(f"target {target} is out of range for {num_movies} movie nodes")

# Movie nodes whose feature vector is entirely zero; such a target is rejected.
# `(x == 0).all(...)` is used instead of `x.sum(...) == 0` so that rows whose
# entries merely cancel out are not misclassified as zero-feature rows.
zero_feature_cases = (movie_x == 0).all(dim=-1).nonzero(as_tuple=True)[0].tolist()

if target in zero_feature_cases:
    raise ValueError("target in zero feature cases")


Explainer (heter) set up on cpu


In [None]:
# The explainer draws random perturbation samples (num_samples / p_perturb).
# Seed every common RNG at the top of this cell so re-running it reproduces
# the same explanations; it is not visible here which RNG HeterExplainer uses.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# 1) Initial explanation: S (presumably the selected node set) plus raw and
#    processed feature explanations for the target node.
S, raw_feature_exp, feature_exp, time_used = explainer.explain(
    target, num_samples=num_samples, k=k, p_perturb=p_perturb,
    p_threshold=p_threshold, pred_threshold=pred_threshold)
print('explain fidelity:', explainer.calFidelity(target, S, feature_exp))

# 2) Factual explanation derived from S and the raw feature explanation.
factual_S, factual_feat_exp = explainer.factual_synMLE(
    target, S, raw_feature_exp, num_samples=num_samples, k=k, p_perturb=p_perturb)
print('factual fidelity:', explainer.calFidelity(target, factual_S, factual_feat_exp))
explainer.printMeaningIMDB(factual_S, factual_feat_exp)

# 3) Counterfactual explanation derived from S and the raw feature explanation.
counterfactual_S, counterfactual_feat_exp = explainer.counterfactual_synMLE(
    target, S, raw_feature_exp, num_samples=num_samples, k=k, p_perturb=p_perturb)
print('counterfactual fidelity:',
      explainer.calFidelity(target, counterfactual_S, counterfactual_feat_exp))
explainer.printMeaningIMDB(counterfactual_S, counterfactual_feat_exp)


Generating 4650 samples on target: 15
