In [1]:
import time
from sklearn.metrics import recall_score, f1_score, accuracy_score, precision_score, roc_curve
import torch
import torch.nn.functional as F
import numpy as np
from xgnn_src.shared_networks import OnlineKG, NaiveTeacher
from xgnn_src.graph.gcn import GCN, GCN_MLP
from xgnn_src.graph.utils import load_data, get_mask, draw_simple_graph
import pickle as pkl
from sklearn.model_selection import StratifiedKFold
from xgnn_src.graph.dataloader import GCDataLoader

In [2]:
import collections
import dgl
import networkx as nx

In [3]:
import matplotlib.pyplot as plt
# Default size (width, height in inches) for every figure drawn in this notebook.
plt.rcParams["figure.figsize"] = [8, 8]

In [4]:
def get_mask(g, base, explainer, undir=True, threshold=0.5):
    """Return one thresholded explainer score per edge of ``g``.

    Both networks are switched to eval mode. ``base`` is run on ``g`` with the
    ``'attr'`` node features, after which ``g.ndata['emb']`` is read
    (presumably written by ``base``'s forward pass — TODO confirm) and pushed
    through ``explainer.edge_mask`` (``compute_adj`` -> ``edge_mlp`` ->
    ``concrete`` with ``beta=5.``) to obtain a score per edge.

    NOTE(review): this definition shadows the ``get_mask`` imported from
    ``xgnn_src.graph.utils`` at the top of the notebook.

    Args:
        g: graph exposing ``ndata``, ``num_nodes()``, ``edges()`` and
            ``local_scope()``.
        base: network called as ``base(g, features)``.
        explainer: network exposing an ``edge_mask`` sub-module.
        undir: when False, an edge scoring below ``threshold`` also zeroes the
            entry of the reverse direction; when True only the edge itself is
            zeroed. NOTE(review): this reads as inverted for a flag named
            ``undir`` — confirm the intent before relying on ``undir=False``.
        threshold: scores strictly below this value are replaced by 0.0.

    Returns:
        ``numpy.ndarray`` holding one float per edge, in ``g.edges()`` order.
        Values are looked up per (src, dst) pair after all edges have been
        processed, so parallel edges share the value written last.
    """
    base.eval()
    explainer.eval()
    with torch.no_grad():
        base(g, g.ndata['attr'])
        node_emb = g.ndata['emb']
        masker = explainer.edge_mask
        edge_scores = masker.edge_mlp(masker.compute_adj(g, node_emb))
        mask = masker.concrete(edge_scores, beta=5.)
    with g.local_scope():
        n = g.num_nodes()
        # Dense (n x n) score table addressed as dense[src][dst].
        dense = [[0.] * n for _ in range(n)]
        src, dst = g.edges()
        pairs = [(u.item(), v.item()) for u, v in zip(src, dst)]
        for edge_id, (u, v) in enumerate(pairs):
            score = mask[edge_id].item()
            if score < threshold:
                dense[u][v] = 0.0
                if not undir:
                    # Also wipe the opposite direction (see the docstring note).
                    dense[v][u] = 0.0
            else:
                dense[u][v] = score
        # Read the table back in edge order.
        mask = np.array([dense[u][v] for u, v in pairs])
    return mask

def obtain_predictions(tpr, fpr, threshold, preds):
    """Binarise scores at the ROC cut-off that maximises ``tpr - fpr``.

    Args:
        tpr: true-positive rates, one per candidate threshold (array or list).
        fpr: false-positive rates, aligned with ``tpr``.
        threshold: candidate thresholds, aligned with ``tpr`` / ``fpr``.
        preds: iterable of scores to binarise.

    Returns:
        A list of ints: 1 where the score is >= the selected cut-off, else 0.
    """
    # np.asarray lets plain lists through; the original required ndarrays
    # because it evaluated `tpr - fpr` directly.
    gain = np.asarray(tpr, dtype=float) - np.asarray(fpr, dtype=float)
    # argmax returns the FIRST maximum, which is what the former stable
    # descending sort selected, without sorting every candidate. If every
    # entry is NaN it yields index 0.
    optimal_proba_cutoff = threshold[int(np.argmax(gain))]
    return [1 if score >= optimal_proba_cutoff else 0 for score in preds]

In [5]:
# Load the 'BA' dataset through the project helper. The call returns three
# values, unpacked here as `dataset`, `dim_nfeats` and `gclasses`
# (presumably the node-feature dimension and the number of graph classes — TODO confirm).
dataset, dim_nfeats, gclasses = load_data('BA')

Done loading data from cached files.


- Low KL term (< 1) => focus on class 0
- KL term >= 1 => focus on both classes
- KL term too large => (TODO: observation not recorded)

In [6]:
# Restore the pickled pair (dataloader, idx); idx itself unpacks into the
# train / validation indices.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file —
# only open checkpoints that come from a trusted source.
with open('./graph/ckpt/gcn/ba_f5.dat', 'rb') as f:
    dataloader, idx = pkl.load(f)
train_idx, valid_idx = idx

In [7]:
# Instantiate the base network, the explainer and the teacher, then wrap them
# in OnlineKG and place the wrapper on the CPU.
# NOTE(review): the positional literals are not documented in this notebook;
# 10 and 2 presumably mirror `dim_nfeats` and `gclasses` returned by
# load_data — TODO confirm against the GCN / GCN_MLP signatures.
base = GCN(10, 64, 2, 5, 0.5, 'max', 'last')
explainer = GCN_MLP(10, 64, 2, 5, 0.5, 64 * 2, 'max', 'last', 'sigmoid', False, 'bn')
teacher = NaiveTeacher(2, 'mean')
online_mode = OnlineKG(base, explainer, teacher).to(torch.device('cpu'))

norm type: bn
norm type: bn


In [9]:
# Suffix selecting which checkpoint to evaluate; it is interpolated into the
# './graph/ckpt/gcn/ba_kl%s.pt' path when the weights are loaded.
kl = "01"

In [10]:
# Load the state dict of the checkpoint selected by `kl` into the wrapper.
# `online_mode` was placed on the CPU, so the tensors are mapped to the CPU too:
# map_location='cuda:0' made this cell fail on machines without a GPU and only
# added a device round-trip before load_state_dict copied the values back.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
model = torch.load('./graph/ckpt/gcn/ba_kl%s.pt'%kl, map_location='cpu')
online_mode.load_state_dict(model)

<All keys matched successfully>

In [11]:
from scipy.sparse import coo_matrix, csr_matrix
def get_edge_labels(adj, insert=20, span=5):
    """Label every stored entry of ``adj`` as ground truth (1) or not (0).

    An entry (row, col) is labelled 1 when both indices fall in the half-open
    range ``[insert, insert + span)``, otherwise 0.
    NOTE(review): that range presumably covers the nodes of the inserted
    motif — confirm against the dataset generator.

    Args:
        adj: object exposing aligned ``row`` and ``col`` index sequences
            (e.g. a SciPy COO matrix).
        insert: first node id of the labelled range.
        span: number of consecutive node ids in the labelled range. Defaults
            to 5, the value that used to be hard-coded.

    Returns:
        A list of 0/1 ints, one per (row, col) pair, in the order of
        ``adj.row`` / ``adj.col``.
    """
    upper = insert + span
    return [
        1 if insert <= r < upper and insert <= c < upper else 0
        for r, c in zip(adj.row, adj.col)
    ]

In [12]:
def symmetrize_weight(g, weight):
    """Return, for every edge (u, v) of ``g``, the sum w(u, v) + w(v, u).

    The previous version returned ``.data`` of the CSR sum ``A + A.T``. That
    array is laid out in CSR order and omits entries whose sum is exactly 0,
    so it could be permuted relative to (or shorter than) the edge list it is
    later compared with element-wise. The values are now read back at the
    original (row, col) positions, giving one value per edge in edge order.

    Args:
        g: graph exposing ``adj(scipy_fmt="coo")``.
        weight: one score per stored entry of that adjacency matrix, in the
            same order as its ``row`` / ``col`` arrays.

    Returns:
        1-D ``numpy.ndarray`` with the same length and order as ``weight``.
        Missing reverse edges contribute 0; parallel edges are summed.

    Raises:
        ValueError: if ``weight`` does not hold exactly one value per entry.
    """
    adj = g.adj(scipy_fmt="coo")
    weight = np.asarray(weight)
    if weight.shape != adj.row.shape:
        raise ValueError(
            "weight must hold one value per edge: expected shape %s, got %s"
            % (adj.row.shape, weight.shape)
        )
    if weight.size == 0:
        # Nothing to symmetrise; also avoids sparse fancy-indexing with empty indices.
        return weight.copy()
    adj.data = weight
    sym = (adj + adj.transpose()).tocsr()
    # Sample the symmetric matrix at the original edge positions; np.asarray +
    # ravel flattens the (1, E) matrix that SciPy returns for this lookup.
    return np.asarray(sym[adj.row, adj.col]).ravel()

In [13]:
# Graphs to evaluate: dataset indices 0-99 and 500-599 (200 graphs in total).
# presumably 100 graphs from each class — TODO confirm against the dataset layout.
eval_graph_idxs = list(range(0,100)) + list(range(500,600))

In [14]:
# Edge-level evaluation of the explainer on the selected graphs: per graph,
# score the edges, symmetrise the scores, binarise them at a ROC-derived
# cut-off, then pool all edges and report precision / recall / F1.
s = time.perf_counter()  # monotonic clock, better suited to timing than time.time()
reals, preds = [], []
for i in eval_graph_idxs:
    g = dataset[i][0]
    # threshold=0.0 only zeroes negative scores; binarisation happens below.
    weight = get_mask(g, base, explainer, True, threshold=0.0)
    weight = symmetrize_weight(g, weight)
    adj = g.adj(scipy_fmt='coo')
    label = get_edge_labels(adj)
    # NOTE(review): the cut-off is chosen per graph from that graph's own
    # ground-truth labels, so the pooled metrics are optimistic.
    fpr, tpr, threshold = roc_curve(label, weight)
    pred = obtain_predictions(tpr, fpr, threshold, weight)
    reals.extend(label)
    preds.extend(pred)

pr = precision_score(reals, preds)
re = recall_score(reals, preds)
# f1_score equals the harmonic mean of pr and re but is well defined when
# both are 0, where 2 * pr * re / (pr + re) yields NaN or a ZeroDivisionError.
f1 = f1_score(reals, preds)
print("Precision %.4f, Recall %.4f F1 %.4f" % (pr, re, f1))
print(time.perf_counter() - s)

Precision 0.5707, Recall 0.4660 F1 0.5131
1.551431655883789
