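# 01 - Explanations

Load the trained heterogeneous GNN checkpoint and visualise Captum Integrated-Gradients attributions for test graphs whose `person` label is 1; node colour shows each node's share of the total absolute attribution.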

In [None]:
import os
import utils
import artificial_data

# Move to the parent of the notebook directory so relative paths such as
# "./default.yml" resolve. The starting directory is remembered on the first
# run, so re-executing this cell does not keep climbing one level higher.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, ".."))
print(os.getcwd())

In [None]:
import networkx as nx
import torch_geometric
from torch_geometric.explain import Explainer, GNNExplainer, CaptumExplainer
from yfiles_jupyter_graphs import GraphWidget
from random import choice, seed
from typing import Dict
import kuzu
import pathlib
import torch
import pandas as pd
import string
from utils import load_yaml


config = load_yaml("./default.yml")

# Per node type (keyed by its configured name): the raw table and its key column.
backend_data, backend_id = {}, {}
for spec in config["nodes"]:
    type_name = spec["name"]
    backend_data[type_name] = pd.read_parquet(spec["file"])
    backend_id[type_name] = spec["key"]

db = kuzu.Database(os.path.abspath("") + "/data/demo")

# Train/test split indices, read from the parquet files named in the task config.
train_inds, test_inds = (
    torch.tensor(pd.read_parquet(config["task"][split])["ids"].values, dtype=torch.long)
    for split in ("train_inds", "test_inds")
)

train_loader, test_loader = utils.get_loaders(
    db, 1, train_inds, test_inds, config["task"]["target_entity"]
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
from torch_geometric.nn import to_hetero
from torch import optim
import matplotlib.pyplot as plt
import copy

# One batch is pulled only to obtain the graph metadata (node/edge types) that
# `to_hetero` needs. NOTE: it is consumed from `its`, so the explanation cells
# below never see this first test graph.
its = iter(test_loader)
batch = next(its)

batch = batch.to(device)
metadata = batch.metadata()

# Instantiate the model class named in the config from the module config["script"],
# then convert it to a heterogeneous model for this metadata.
module = __import__(config["script"])
mdl = getattr(module, config["task"]["model"])
margs = config["task"]["model_args"]
hmargs = config["task"]["heteromodel_args"]
model = mdl(**margs)
model = to_hetero(model, metadata, **hmargs).to(device)

# NOTE(review): `optimizer` is not used anywhere in this notebook (no training
# step happens here); consider removing it.
opt = getattr(optim, config["task"]["optimizer"])
oargs = config["task"]["optimizer_args"]
optimizer = opt(model.parameters(), **oargs)

# map_location remaps the saved tensors onto `device`, so a checkpoint written on
# a GPU machine also loads on a CPU-only one (and vice versa).
checkpoint_path = os.path.join(os.path.abspath(""), "models", "checkpoint.pt")
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()

In [None]:
def create_explanation(target_entity: str = "person", target_label: float = 1.0, layout_seed=None):
    """Explain the next matching test graph and plot per-node attributions.

    Draws graphs from the module-level test iterator ``its`` until one whose
    ``target_entity`` store has ``y == target_label``. The model's raw output is
    attributed to every node feature with Captum Integrated Gradients. The graph
    is then drawn with each node coloured by its share of the total absolute
    attribution.

    Args:
        target_entity: node type whose ``y`` holds the graph label
            (default ``"person"``, the previously hard-coded value).
        target_label: label value to search for (default ``1.0``).
        layout_seed: optional seed for ``nx.spring_layout``, for a reproducible
            layout. ``None`` keeps the previous random layout.

    Raises:
        StopIteration: if ``its`` is exhausted before a matching graph is found.
    """
    # `.item()` assumes one label per batch (presumably batch size 1, see get_loaders(db, 1, ...)).
    batch = next(its)
    while batch[target_entity].y.item() != target_label:
        batch = next(its)

    explainer = Explainer(
        model=model,
        algorithm=CaptumExplainer('IntegratedGradients'),
        explanation_type='model',
        model_config=dict(
            mode='multiclass_classification',
            task_level='graph',
            return_type='raw',
        ),
        node_mask_type='attributes',
        edge_mask_type='object',
    )

    btch = batch.to(device)
    explanation = explainer(btch.x_dict, btch.edge_index_dict)

    # Masks are attached to a shallow copy of the batch rather than to `btch`.
    btch_ = copy.copy(btch)
    # Iterate over the batch's own node/edge types (public API) instead of the
    # config list, so a config type missing from this batch cannot crash.
    for node_type in btch_.node_types:
        # Collapse per-feature attributions into one non-negative score per node.
        btch_[node_type]["node_mask"] = explanation[node_type]["node_mask"].abs().sum(dim=1)
    for edge_type in btch_.edge_types:
        btch_[edge_type]["edge_mask"] = explanation[edge_type]["edge_mask"]

    homo_batch = btch_.to_homogeneous(node_attrs=["node_mask"])

    # Each node's share of the total attribution. This is what the previous
    # softmax(log(score)) computed; dividing directly avoids log(0). When every
    # attribution is zero it yields zeros instead of NaN.
    scores = homo_batch.node_mask
    total = scores.sum()
    scores = scores / total if total > 0 else torch.zeros_like(scores)

    # to_networkx adds nodes 0..N-1 in order, so `scores` aligns with g's node order.
    g = torch_geometric.utils.to_networkx(homo_batch)

    # plt.Figure() only built an unused Figure object; create a real figure/axes.
    fig, ax = plt.subplots()
    pos = nx.spring_layout(g, seed=layout_seed)
    nx.draw_networkx_edges(g, pos, alpha=0.2, ax=ax)
    nc = nx.draw_networkx_nodes(
        g, pos, node_color=scores.tolist(), node_size=800,
        cmap=plt.cm.Blues, edgecolors="k", vmin=0.0, ax=ax,
    )
    fig.colorbar(nc, ax=ax)
    ax.set_axis_off()
    plt.show()

In [None]:
# Explain and plot the next test graph whose `person` label is 1. Each call
# advances the shared iterator `its`, so re-running shows a different graph.
create_explanation()

In [None]:
# Same explanation pipeline as `create_explanation`, but each node is also
# labelled with its key value from the backend tables.
target_entity, target_label = "person", 1.0

# `.item()` assumes one label per batch (presumably batch size 1).
batch = next(its)
while batch[target_entity].y.item() != target_label:
    batch = next(its)

explainer = Explainer(
    model=model,
    algorithm=CaptumExplainer('IntegratedGradients'),
    explanation_type='model',
    model_config=dict(
        mode='multiclass_classification',
        task_level='graph',
        return_type='raw',
    ),
    node_mask_type='attributes',
    edge_mask_type='object',
)

btch = batch.to(device)
explanation = explainer(btch.x_dict, btch.edge_index_dict)

# Attach the masks to a shallow copy, iterating the batch's own types via the
# public HeteroData API (not the private `_node_store_dict`/`_edge_store_dict`).
btch_ = copy.copy(btch)
for node_type in btch_.node_types:
    # One non-negative attribution score per node.
    btch_[node_type]["node_mask"] = explanation[node_type]["node_mask"].abs().sum(dim=1)
for edge_type in btch_.edge_types:
    btch_[edge_type]["edge_mask"] = explanation[edge_type]["edge_mask"]

# `id` is carried along so every homogeneous node can be traced back to a table row.
homo_batch = btch_.to_homogeneous(node_attrs=["node_mask", "id"])

# Share of the total attribution (equal to the former softmax(log(score)),
# without log(0) and without NaN when all attributions are zero).
scores = homo_batch["node_mask"]
total = scores.sum()
scores = scores / total if total > 0 else torch.zeros_like(scores)

# to_networkx adds nodes 0..N-1 in order, so `scores` and `labels` index-align with g.
g = torch_geometric.utils.to_networkx(homo_batch)

# Homogeneous node index -> key value from that node type's backend table.
# Positions are mapped explicitly, so correctness does not depend on how
# to_homogeneous orders the node-type groups.
# NOTE(review): assumes `id` is a positional (iloc) row index into the parquet table.
type_names = homo_batch._node_type_names
node_type_codes = homo_batch["node_type"]
labels = {}
for t in node_type_codes.unique().tolist():
    type_name = type_names[t]
    is_t = node_type_codes == t
    positions = is_t.nonzero(as_tuple=True)[0].tolist()
    rows = homo_batch["id"][is_t].tolist()
    keys = backend_data[type_name].iloc[rows][backend_id[type_name]].tolist()
    labels.update(zip(positions, keys))

# plt.Figure() only built an unused Figure object; create a real figure/axes.
fig, ax = plt.subplots()
pos = nx.spring_layout(g)
nx.draw_networkx_edges(g, pos, alpha=0.2, ax=ax)
nc = nx.draw_networkx_nodes(
    g, pos, node_color=scores.tolist(), node_size=800,
    cmap=plt.cm.Blues, edgecolors="k", vmin=0.0, ax=ax,
)
nx.draw_networkx_labels(g, pos, labels, ax=ax)
fig.colorbar(nc, ax=ax)
ax.set_axis_off()
plt.show()

In [None]:
# Node-type names indexed by the integer codes in homo_batch["node_type"]. This is
# an underscore-prefixed (non-public) attribute on the result of to_homogeneous.
homo_batch._node_type_names