In [10]:
import time
from sklearn.metrics import recall_score, f1_score, accuracy_score, precision_score, roc_curve
import torch
import torch.nn.functional as F
import numpy as np
from xgnn_src.shared_networks import OnlineKG, NaiveTeacher
from xgnn_src.graph.gcn import GCN, GCN_MLP, GCN_MLP2
from xgnn_src.graph.utils import load_data
import pickle as pkl
from sklearn.model_selection import StratifiedKFold
from xgnn_src.graph.dataloader import GCDataLoader
from xgnn_src.graph.utils import draw_mutag, get_mask

In [2]:
import collections
import dgl
import networkx as nx
import matplotlib.pyplot as plt

In [3]:
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = [8, 8]

In [15]:
def get_mask(g, base, explainer, undir=True, threshold=0.5):
    base.eval()
    explainer.eval()
    with torch.no_grad():
        base(g, g.ndata['attr'])
        embedding = g.ndata['emb']
        edge_weight = explainer.edge_mask.compute_adj(g, embedding)
        edge_weight = explainer.edge_mask.edge_mlp(edge_weight)
        mask = explainer.edge_mask.concrete(edge_weight, beta=5.)
#         print(mask)
#         mask = F.sigmoid(edge_weight)
#         print(mask)
    with g.local_scope():
        num_nodes = g.num_nodes()
        adj = [[0.] * num_nodes for _ in range(num_nodes)] 
        src, dst = g.edges()
        for i, (s, d) in enumerate(zip(src, dst)):
            s, d = s.item(), d.item()
            m = mask[i].item()
            if m < threshold:
                adj[s][d] = 0.0
                if undir:
                    continue
                adj[d][s] = 0.0
            else:
                adj[s][d] = m
        weight = []
        for s, d in zip(src, dst):
            s, d = s.item(), d.item()
            weight.append(adj[s][d])
        mask = np.array(weight)
    return mask

In [26]:
def obtain_predictions(tpr, fpr, threshold, preds):
    optimal_proba_cutoff = sorted(zip(np.abs(tpr - fpr), threshold), key=lambda i: i[0], reverse=True)[0][1]
    roc_predictions = [1 if i >= optimal_proba_cutoff else 0 for i in preds]
    return roc_predictions

In [6]:
mutag_labels = ['C', 'O', 'Cl', 'H', 'N', 'F', 'Br', 'S', 'P', 'I', 'Na', 'K', 'Li', 'Ca']

In [8]:
kl = "4"

In [11]:
base = GCN(14, 64, 2, 5, 0.0, 'max', 'last')
explainer = GCN_MLP2(14, 64, 2, 5, 0.5, 64 * 2, 'max', 'last', 'sigmoid', False, 'bn', 0.0)
teacher = NaiveTeacher(2, 'mean')
online_mode = OnlineKG(base, explainer, teacher)
model = torch.load('./graph/ckpt/gcn/pgmutag_rand5_kl%s_noemb.pt'%kl)
online_mode.load_state_dict(model)

norm type: bn
norm type: bn
norm type: bn


<All keys matched successfully>

In [12]:
dataset, dim_nfeats, gclasses = load_data('Mutagenicity', './graph/datasets/dgl_mutagenicity.pkl', neg_ratio=1.)

In [13]:
graphs = []
for g, l in dataset:
    if l == 0 and g.edata['edge_labels'].sum() > 0:
        graphs.append(g)

In [16]:
s = time.time()
reals, preds = [], []
for g in graphs:
    weight = get_mask(g, base, explainer, True, threshold=0.0)
    label = g.edata['edge_labels'].numpy()
    fpr, tpr, threshold = roc_curve(label, weight)
    pred = obtain_predictions(tpr, fpr, threshold, weight)    
    reals.extend(label)
    preds.extend(pred)

pr = precision_score(reals, preds)
re = recall_score(reals, preds)
f1 = 2 * pr * re / (pr + re)
print("Precision %.4f, Recall %.4f F1 %.4f" % (pr, re, f1))
print(time.time() - s)

Precision 0.0000, Recall 0.0000 F1 nan
7.055383205413818


In [28]:
reals, preds = [], []
for g in graphs:
    weight = get_mask(g, base, explainer, True, threshold=0.0)
    label = g.edata['edge_labels'].numpy()
    fpr, tpr, threshold = roc_curve(label, weight)
#     pred = obtain_predictions(tpr, fpr, threshold, weight)    
    pred = weight.astype(np.int)
    reals.extend(label)
    preds.extend(pred)

pr = precision_score(reals, preds)
re = recall_score(reals, preds)
f1 = 2 * pr * re / (pr + re)
print("Precision %.4f, Recall %.4f F1 %.4f" % (pr, re, f1))

Precision 0.0980, Recall 1.0000 F1 0.1784
