# 01 - Explanations: attributing graph-level GNN predictions to nodes and edges with Captum Integrated Gradients

In [None]:
import sys
sys.path.append("../")
from src import utils, artificial_data

In [None]:
import networkx as nx
import torch_geometric
from torch_geometric.explain import Explainer, GNNExplainer, CaptumExplainer
from yfiles_jupyter_graphs import GraphWidget
from random import choice, seed
from typing import Dict
import kuzu
import pathlib
import torch
import pandas as pd
import os
import string
from torch_geometric.nn import (MLP, BatchNorm, GraphConv, MultiAggregation,
                                SAGEConv, to_hetero)
from torch_geometric.nn import (MLP, BatchNorm, LayerNorm, GraphConv, MultiAggregation,
                                    SAGEConv, to_hetero)

# All artefacts live in ../data relative to the notebook's working directory.
NOTEBOOK_DIR = str(os.path.abspath(""))

db = kuzu.Database(NOTEBOOK_DIR + "/../data/demo")
persons = pd.read_parquet(NOTEBOOK_DIR + "/../data/persons.parquet")

# Index labels of the persons in each split, as long tensors for the loaders.
split_column = persons["mode"]
train_inds, test_inds = (
    torch.tensor(persons.index[(split_column == split).to_numpy()].tolist(), dtype=torch.long)
    for split in ("train", "test")
)
train_loader, test_loader = utils.get_loaders(db, 1, train_inds, test_inds)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Raw diagnosis / drug tables; later cells look node ids up in them to build labels.
diagnosis_data, drug_data = (
    pd.read_parquet(str(os.path.abspath("")) + relative_path)
    for relative_path in ("/../data/diagnoses.parquet", "/../data/drugs.parquet")
)

In [None]:
# Pull one test batch only to obtain the heterogeneous schema (node/edge types)
# that `to_hetero` needs to rebuild the model architecture.
its = iter(test_loader)
batch = next(its)

aggr = "max"  # cross-relation aggregation; presumably must match training — TODO confirm

batch = batch.to(device)
metadata = batch.metadata()

model = utils.GraphLevelGNN()
model = to_hetero(model, metadata, aggr=aggr, debug=True).to(device)

# map_location remaps stored tensors onto `device`, so a checkpoint saved on a GPU
# still loads when `device` fell back to CPU above.
model.load_state_dict(
    torch.load(str(os.path.abspath("")) + "/../models/checkpoint.pt", map_location=device)
)
model.eval()

# Static visualization of the explanation (networkx + matplotlib)

In [None]:
# Label of the person node in the currently loaded batch, as a Python scalar.
batch["person"].y.item()

In [None]:
# Restart iteration over the test set; the following cell consumes batches from
# `its` until it finds one whose person label is 1.
its = iter(test_loader)

In [None]:
# Advance to the next test graph whose person label is positive (y == 1) so the
# explanation is for a positive case. Fail with a clear message instead of a bare
# StopIteration when the iterator has no such batch left.
batch = next((candidate for candidate in its if candidate["person"].y.item() == 1.0), None)
if batch is None:
    raise RuntimeError("no remaining test batch with person label y == 1; re-create `its`")

# Integrated Gradients attributions for the model's raw graph-level output:
# one score per node *feature* ('attributes') and one per edge ('object').
explainer = Explainer(
    model=model,
    algorithm=CaptumExplainer('IntegratedGradients'),
    explanation_type='model',
    model_config=dict(
        mode='multiclass_classification',
        task_level='graph',
        return_type='raw',
    ),
    node_mask_type='attributes',
    edge_mask_type='object',
)

btch = batch.to(device)

explanation = explainer(
    btch.x_dict,
    btch.edge_index_dict,
)

In [None]:
# Display the explanation object returned by the explainer.
explanation

In [None]:
import copy

# Work on a copy so the explanation masks can be attached without adding
# attributes to `btch` itself.
btch_ = copy.copy(btch)

# Public equivalents of the private per-type store dicts' keys.
for node_type in btch_.node_types:
    btch_[node_type]["node_mask"] = explanation[node_type]["node_mask"]

for edge_type in btch_.edge_types:
    btch_[edge_type]["edge_mask"] = explanation[edge_type]["edge_mask"]

btch_

In [None]:
from collections import OrderedDict

# Node types and edge types of the explained batch, each in sorted order.
node_ents = sorted(btch_.node_types)
edge_ents = sorted(btch_.edge_types)

In [None]:
# Collapse each node's per-feature attributions into one score per node:
# log of the summed absolute attributions (a node with all-zero attributions gets -inf).
# The source is `explanation`, not `btch_` itself, so re-running this cell is
# idempotent. The previous version re-reduced an already 1-D mask and failed.
# Every node type is reduced, because `to_homogeneous` concatenates `node_mask`
# across all of them and needs matching shapes.
for node_type in btch_.node_types:
    btch_[node_type]["node_mask"] = explanation[node_type]["node_mask"].abs().sum(dim=1).log()

homo_batch = btch_.to_homogeneous(node_attrs=["node_mask"])

In [None]:
# Convert the homogeneous batch to a networkx graph, copying the per-node scalar
# `node_mask` onto each networkx node as an attribute.
g = torch_geometric.utils.to_networkx(homo_batch, node_attrs=["node_mask"])

In [None]:
# Turn the per-node log-scores into a softmax distribution over all nodes and store
# it on `g` for colouring. The scores come from `homo_batch`, not from `g`'s own
# attribute, so re-running the cell is idempotent. Reading back from `g` applied
# softmax again on every run. Keys are real node ids (`zip` with `g.nodes`), not a
# separate enumerate counter.
nx.set_node_attributes(
    g,
    dict(zip(g.nodes, homo_batch["node_mask"].softmax(0).tolist())),
    "node_mask",
)

In [None]:
# Per-node scores currently stored on `g`, in node iteration order.
[attrs['node_mask'] for _, attrs in g.nodes(data=True)]

In [None]:
import matplotlib.pyplot as plt

# Draw onto an explicit figure/axes. `plt.Figure()` only built a detached Figure
# that nothing was drawn on. The spring layout is seeded so the picture is
# reproducible across runs.
fig, ax = plt.subplots()
pos = nx.spring_layout(g, seed=42)
ec = nx.draw_networkx_edges(g, pos, ax=ax, alpha=0.2)
nc = nx.draw_networkx_nodes(
    g,
    pos,
    ax=ax,
    node_color=[g.nodes[n]['node_mask'] for n in g.nodes],
    node_size=800,
    cmap=plt.cm.Blues,
    edgecolors="k",
    vmin=0.0,
)
fig.colorbar(nc, ax=ax, label="node importance")
ax.set_title("Integrated Gradients node importance")
ax.axis('off')
plt.show()

In [None]:
# Only `btch_` (the copy that masks were attached to) carries `node_mask`.
# `batch` never received one, so requesting it from `batch` raised.
homo_batch = btch_.to_homogeneous(node_attrs=["node_mask"])
g = torch_geometric.utils.to_networkx(homo_batch)

In [None]:
# Toy drawing: an edgeless graph with one node per node of `btch_`, with two
# hand-picked groups of node ids highlighted in red and blue.
G = nx.Graph()
G.add_nodes_from(range(btch_.num_nodes))

# draw_networkx_nodes requires explicit positions (the previous calls omitted `pos`
# and raised TypeError). Compute one seeded layout and share it across all layers.
demo_pos = nx.spring_layout(G, seed=42)

options = {"edgecolors": "tab:gray", "node_size": 800, "alpha": 0.9}
# Base layer first, then the highlighted groups on top of it.
nx.draw(G, demo_pos, node_color="lightgray", **options)
# Only highlight ids that exist, so batches with fewer than 8 nodes do not fail.
nx.draw_networkx_nodes(G, demo_pos, nodelist=[n for n in (0, 1, 2, 3) if n in G], node_color="tab:red", **options)
nx.draw_networkx_nodes(G, demo_pos, nodelist=[n for n in (4, 5, 6, 7) if n in G], node_color="tab:blue", **options)

# Interactive visualization of the explanation (yFiles GraphWidget)

In [None]:
its = iter(test_loader)
batch = next(its)
batch = next(its)  # skip the first test graph and visualise the second one

homo_batch = batch.to_homogeneous()
g = torch_geometric.utils.to_networkx(homo_batch)

w = GraphWidget(graph=g)

def custom_color_mapping(index: int, node: Dict):
    """Return a random hex colour, ignoring the node (alternative mapping, unused here)."""
    return "#"+''.join([choice('0123456789abcdef') for j in range(6)])

def custom_color_mapping2(index: int, node: Dict):
    """Colour a node by its integer node-type id, one fixed colour per type.

    Colours cycle if there are more node types than palette entries, instead of
    raising IndexError.
    """
    colors = ["#ff0000", "#00ff00", "#0000ff"]
    return colors[int(homo_batch["node_type"][index]) % len(colors)]

w.set_node_color_mapping(custom_color_mapping2)

# NOTE(review): `res` / `restable` are not used anywhere in this notebook.
res = pd.DataFrame({"text": [i for i in string.ascii_letters[:26]]})
restable = utils.encode_strings(res["text"])

# Same Integrated Gradients setup as above, but keeping only the top-200 mask entries.
explainer = Explainer(
    model=model,
    algorithm=CaptumExplainer('IntegratedGradients'),
    explanation_type='model',
    model_config=dict(
        mode='multiclass_classification',
        task_level='graph',
        return_type='raw',
    ),
    node_mask_type='attributes',
    edge_mask_type='object',
    threshold_config=dict(
        threshold_type='topk',
        value=200,
    ),
)

btch = batch.to(device)

explanation = explainer(
    btch.x_dict,
    btch.edge_index_dict,
)

# Create labels: one per node, in homogeneous node order. `to_homogeneous` numbers
# node types by their position in `batch.node_types` and concatenates nodes type by
# type, so walking the types by id yields labels aligned with node indices.
# Types are resolved by name instead of being guessed from node counts or PID
# uniqueness. The guess mislabelled e.g. batches with a single drug or diagnosis.
label_sources = {
    "drug": (drug_data, "drug_short", "drug_age"),
    "diagnosis": (diagnosis_data, "diagnosis_short", "diagnosis_age"),
}
labels = []
for type_id, type_name in enumerate(batch.node_types):
    rows = homo_batch["id"][homo_batch["node_type"] == type_id].tolist()
    if type_name == "person":
        labels.extend(["PID"] * len(rows))
    elif type_name in label_sources:
        table, name_col, age_col = label_sources[type_name]
        # assumes node `id` is the positional row in the table (as before) — TODO confirm
        picked = table.iloc[rows]
        # NOTE(review): only the first character of the age's string form is shown
        # (kept from the original) — confirm this truncation is intended.
        labels.extend(
            f"{type_name}: {name} {str(age)[0]}"
            for name, age in zip(picked[name_col].values.tolist(), picked[age_col].values.tolist())
        )
    else:
        labels.extend([type_name] * len(rows))
assert len(labels) == homo_batch.num_nodes, "expected exactly one label per node"

def labels_func(index: int, node: Dict):
    """Return the precomputed label of the node at homogeneous index `index`."""
    return labels[index]

w.set_node_label_mapping(labels_func)

w

In [None]:
w = GraphWidget(graph=g)

# Homogeneous index of the first node of each type. Nodes are concatenated type by
# type in `batch.node_types` order, so the offset of type t is the number of nodes
# with a smaller type id. This replaces hard-coded offsets (0/12/22) that only
# matched one particular batch.
type_offsets = [
    int((homo_batch["node_type"] < type_id).sum()) for type_id in range(len(batch.node_types))
]

def custom_color_mapping3(index: int, node: Dict):
    """Shade a node from white (no attribution) to red (high attribution).

    The score is the sum of absolute feature attributions of this node in
    `explanation[<node type>]["node_mask"]`, divided by a hand-picked normaliser
    and clamped to [0, 1].
    """
    max_val = 16.0  # hand-picked: a summed |attribution| of 16 or more maps to full red
    type_id = int(homo_batch["node_type"][index])
    type_name = batch.node_types[type_id]
    local_index = index - type_offsets[type_id]
    scalar = explanation[type_name]["node_mask"][local_index].abs().sum().item() / max_val

    # Clamp to [0, 1]. Clamping to 255 let scores above `max_val` produce negative
    # channel values and invalid colours such as "#ff-ff-ff".
    scalar = min(1.0, max(0.0, scalar))
    channel = 255 - int(scalar * 255)
    return '#ff' + ('%02x' % channel) + ('%02x' % channel)

w.set_node_color_mapping(custom_color_mapping3)

# Create labels: one per node, in homogeneous node order, resolving node types by
# name (see the type-offset note above for why type-id order matches node order).
label_sources = {
    "drug": (drug_data, "drug_short", "drug_age"),
    "diagnosis": (diagnosis_data, "diagnosis_short", "diagnosis_age"),
}
labels = []
for type_id, type_name in enumerate(batch.node_types):
    rows = homo_batch["id"][homo_batch["node_type"] == type_id].tolist()
    if type_name == "person":
        labels.extend(["KVNR"] * len(rows))
    elif type_name in label_sources:
        table, name_col, age_col = label_sources[type_name]
        # assumes node `id` is the positional row in the table (as before) — TODO confirm
        picked = table.iloc[rows]
        # NOTE(review): only the first character of the age's string form is shown
        # (kept from the original) — confirm this truncation is intended.
        labels.extend(
            f"{type_name}: {name} {str(age)[0]}"
            for name, age in zip(picked[name_col].values.tolist(), picked[age_col].values.tolist())
        )
    else:
        labels.extend([type_name] * len(rows))
assert len(labels) == homo_batch.num_nodes, "expected exactly one label per node"

def labels_func(index: int, node: Dict):
    """Return the precomputed label of the node at homogeneous index `index`."""
    return labels[index]

w.set_node_label_mapping(labels_func)

w

In [13]:
from rich import print
import yaml
from yaml.loader import SafeLoader

# Load the experiment configuration with the safe loader (plain YAML types only)
# and display it.
with open('../default.yml') as config_file:
    config = yaml.safe_load(config_file)
print(config)

In [15]:
# Name of the target entity, read from the config's `task` section.
target_entity = config["task"]["target_entity"]

In [18]:
# NOTE(review): this lookup raises `TypeError: list indices must be integers or
# slices, not str` (see the cell output). So `config["nodes"]` is a list, not a
# mapping keyed by entity name, and the entry for `target_entity` has to be located
# inside that list. TODO: confirm the schema of default.yml and fix the lookup.
config["nodes"][target_entity]["mode"]

TypeError: list indices must be integers or slices, not str