In [1]:
import os, time
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import dgl
import dgl.function as fn
import pickle as pkl
import torch
from sklearn.metrics import recall_score, precision_score, roc_auc_score, roc_curve

In [2]:
from xgnn_src.shared_networks import MLP, MLP_PRED
from xgnn_src.node.online_kg2 import AllOnlineKG
import torch.nn.functional as F

In [3]:
import matplotlib.pyplot as plt
# Default size (inches) applied to every figure created later in the notebook.
plt.rcParams.update({"figure.figsize": [8, 8]})

In [4]:
from utils import *
from xgnn_src.node.eval import *

In [5]:
import networkx as nx
import matplotlib.pyplot as plt
import collections

In [6]:
from xgnn_src.node.eval import predict, evaluate_dataset, extract_true_motif, explain_test
from xgnn_src.node.draw import draw_simple_graph

In [7]:
# Minimal stand-in for the command-line options of online_kg2.py (--teacher-name, --n-layers,
# --n-hidden, --all-layer-dp, --skip-norm, ...) consumed by init_teacher; fields are positional.
Arg = collections.namedtuple(
    "Arg",
    "teacher_name hidden_sizes n_layers dropout n_hidden all_layer_dp skip_norm",
)

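- Nodes of label 0 belong to the base BA graph
- Nodes of labels 1, 2 and 3 are, respectively, at the middle, bottom and top of the houses
- (This describes the BA-Shapes labeling; the tree-cycle and tree-grid datasets below are loaded with `num_classes = 2`.)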

### Using bidirected graph (default)

```
python online_kg2.py --dataset TRG --temp 2 --n-epochs 1000 --gpu 0 --teacher-name gcn2 --student-type gcn --n-hidden 64 --n-layers 5 --lr 0.01 --all-layer-dp --skip-norm --add-reverse --teacher-pretrain ./ckpt/gcn/tree_circles_feat_bidir_kl4.pt --sl-factor 1

python online_kg2.py --dataset TRG --temp 2 --n-epochs 1000 --gpu 0 --teacher-name gcn2 --student-type gcn --n-hidden 64 --n-layers 5 --lr 0.01 --all-layer-dp --skip-norm --add-reverse --teacher-pretrain ./ckpt/gcn/tree_circles_feat_bidir_kl4.pt --sl-factor 2

python online_kg2.py --dataset TRG --temp 2 --n-epochs 1000 --gpu 0 --teacher-name gcn2 --student-type gcn --n-hidden 64 --n-layers 5 --lr 0.01 --all-layer-dp --skip-norm --add-reverse --teacher-pretrain ./ckpt/gcn/tree_circles_feat_bidir_kl4.pt --sl-factor 4

python online_kg2.py --dataset TRG --temp 2 --n-epochs 1000 --gpu 0 --teacher-name gcn2 --student-type gcn --n-hidden 64 --n-layers 5 --lr 0.01 --all-layer-dp --skip-norm --add-reverse --teacher-pretrain ./ckpt/gcn/tree_circles_feat_bidir_kl4.pt --sl-factor 5

python online_kg2.py --dataset TRG --temp 2 --n-epochs 1000 --gpu 0 --teacher-name gcn2 --student-type gcn --n-hidden 64 --n-layers 5 --lr 0.01 --all-layer-dp --skip-norm --add-reverse --teacher-pretrain ./ckpt/gcn/tree_circles_feat_bidir_kl4.pt --sl-factor 10
```

In [20]:
# Checkpoint suffix: loads ./ckpt/gcn/tree_circles_feat_bidir_kl<kl>.pt
kl = "2"
# (alternative files once used here: ba_shape1.g, ba_shape.pt)
# NOTE: pickle.load can execute arbitrary code -- only load dataset files you trust.
with open('./datasets/tree_cycle_bidir.g', 'rb') as f:
    g = pkl.load(f)

num_classes = 2
test_labels = g.ndata['label'][g.ndata['test_mask']]
feats = g.ndata['feat'].size()[1]  # input feature dimension
# Teacher config, in Arg field order: gcn2, hidden_sizes=[64], n_layers=5, dropout=0.5,
# n_hidden=64, all_layer_dp=True, skip_norm=True.
arg = Arg("gcn2", [64], 5, 0.5, 64, True, True)
base = init_teacher(arg, g, feats, num_classes)
graph_std = init_graph_student("gcn", g, feats, num_classes, 0.5, n_hidden=64,
                               n_layers=5, hidden_sizes=None, all_layer_dp=True, skip_norm=True)
mlp = MLP(feats, [64], num_classes, F.relu, 0.5, batch_norm=True, norm_type='bn')
online_mode = AllOnlineKG(base, graph_std, mlp, graph_student_name="gcn")
# Map checkpoint tensors to the GPU when one exists, otherwise to CPU, so the cell also runs on
# CPU-only machines; load_state_dict copies the values into the modules' own parameters.
ckpt_device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = torch.load('./ckpt/gcn/tree_circles_feat_bidir_kl%s.pt' % kl, map_location=ckpt_device)
online_mode.load_state_dict(model)

using norm in graph False
GCN2(
  (layers): ModuleList(
    (0): GraphConv(in=3, out=64, normalization=both, activation=None)
    (1): GraphConv(in=64, out=64, normalization=both, activation=None)
    (2): GraphConv(in=64, out=64, normalization=both, activation=None)
    (3): GraphConv(in=64, out=64, normalization=both, activation=None)
    (4): GraphConv(in=64, out=64, normalization=both, activation=None)
    (5): GraphConv(in=64, out=2, normalization=both, activation=None)
  )
  (batch_norms): ModuleList()
  (dropout): Dropout(p=0.5, inplace=False)
)
using norm in graph False
norm type: bn
norm type: bn
norm type: bn


<All keys matched successfully>

In [21]:
# Run the teacher (base) and the student explainer (graph_std) on g; predict() prints
# their accuracies, agreement and KL score (see output below).
pred_pair = predict(base, graph_std, g)
b_preds, e_preds = pred_pair

Base accuracy: 0.9770, Explainer accuracy: 0.9770
Agreement score: 0.9770, KL Score: 0.0048


In [22]:
# One query node per motif: start just past node index 510 (the boundary passed to
# extract_true_motif) and step by the motif size of 6.
# NOTE: `test_pentagon` is kept for compatibility; these motifs have 6 nodes, not 5.
cycle_base_end, cycle_size = 510, 6
selected_node = list(range(cycle_base_end + 1, 871, cycle_size))
test_pentagon = [extract_true_motif(g, node, cycle_base_end, cycle_size) for node in selected_node]

In [23]:
# Explanation quality on the tree-cycles motifs: prints precision/recall/F1, then the
# elapsed wall-clock time in seconds.
s = time.time()
all_pre, all_rec, _ = evaluate_dataset(g, selected_node, test_pentagon, 6, 10, dumping_factor=0.55, num_iter=10)
# Guard the harmonic mean: precision and recall can both be 0, which would divide by zero.
denom = all_pre + all_rec
f1 = (2 * all_pre * all_rec / denom) if denom > 0 else 0.0
print("Test for tree circles - Precision: %.4f & Recall: %.4f F1 Score: %.4f"
      % (all_pre, all_rec, f1))
print(time.time() - s)

60it [00:02, 23.56it/s]

Test for tree circles - Precision: 0.9825 & Recall: 0.9333 F1 Score: 0.9573
2.553373098373413





Results per checkpoint (`kl` value, then precision/recall/F1, then runtime in seconds):

0.0
Test for tree circles - Precision: 0.9972 & Recall: 1.0000 F1 Score: 0.9986
1.9186439514160156

0.1
Test for tree circles - Precision: 0.9917 & Recall: 1.0000 F1 Score: 0.9959
2.2164316177368164

0.3
Test for tree circles - Precision: 0.9677 & Recall: 1.0000 F1 Score: 0.9836
2.8235068321228027

0.5
Test for tree circles - Precision: 0.9648 & Recall: 0.9889 F1 Score: 0.9767
2.4095003604888916

1
Test for tree circles - Precision: 0.9816 & Recall: 0.8889 F1 Score: 0.9329
2.7337543964385986

2
Test for tree circles - Precision: 0.9825 & Recall: 0.9333 F1 Score: 0.9573
2.553373098373413

4
Test for tree circles - Precision: 0.9832 & Recall: 0.9778 F1 Score: 0.9805
2.6498923301696777

5 
Test for tree circles - Precision: 0.9721 & Recall: 0.9667 F1 Score: 0.9694
2.901827335357666

10
Test for tree circles - Precision: 0.9697 & Recall: 0.8889 F1 Score: 0.9275
3.024421453475952

## TREE GRID

In [9]:
# Checkpoint suffix: loads ./ckpt/gcn/tree_grid_feat_bidir_kl<kl>.pt
kl = "01"
# NOTE: pickle.load can execute arbitrary code -- only load dataset files you trust.
with open('./datasets/tree_grid_bidir.g', 'rb') as f:
    g1 = pkl.load(f)
# Variant left disabled here: g1 = g1.add_self_loop()
num_classes = 2
test_labels1 = g1.ndata['label'][g1.ndata['test_mask']]
feats1 = g1.ndata['feat'].size()[1]  # input feature dimension
# Teacher config, in Arg field order: gcn2, hidden_sizes=[64], n_layers=5, dropout=0.5,
# n_hidden=64, all_layer_dp=True, skip_norm=True.
arg1 = Arg("gcn2", [64], 5, 0.5, 64, True, True)
base1 = init_teacher(arg1, g1, feats1, num_classes)
graph_std1 = init_graph_student("gcn", g1, feats1, num_classes, 0.5, n_hidden=64,
                                n_layers=5, hidden_sizes=None, all_layer_dp=True, skip_norm=True)
mlp1 = MLP(feats1, [64], num_classes, F.relu, 0.5, batch_norm=True, norm_type='bn')
online_mode1 = AllOnlineKG(base1, graph_std1, mlp1, graph_student_name="gcn")
# Map checkpoint tensors to the GPU when one exists, otherwise to CPU, so the cell also runs on
# CPU-only machines; load_state_dict copies the values into the modules' own parameters.
ckpt_device1 = "cuda:0" if torch.cuda.is_available() else "cpu"
model1 = torch.load('./ckpt/gcn/tree_grid_feat_bidir_kl%s.pt' % kl, map_location=ckpt_device1)
online_mode1.load_state_dict(model1)

using norm in graph False
GCN2(
  (layers): ModuleList(
    (0): GraphConv(in=3, out=64, normalization=both, activation=None)
    (1): GraphConv(in=64, out=64, normalization=both, activation=None)
    (2): GraphConv(in=64, out=64, normalization=both, activation=None)
    (3): GraphConv(in=64, out=64, normalization=both, activation=None)
    (4): GraphConv(in=64, out=64, normalization=both, activation=None)
    (5): GraphConv(in=64, out=2, normalization=both, activation=None)
  )
  (batch_norms): ModuleList()
  (dropout): Dropout(p=0.5, inplace=False)
)
using norm in graph False
norm type: bn
norm type: bn
norm type: bn


<All keys matched successfully>

In [10]:
# Run the teacher (base1) and the student explainer (graph_std1) on the tree-grid graph;
# predict() prints their accuracies, agreement and KL score (see output below).
grid_pred_pair = predict(base1, graph_std1, g1)
b, e = grid_pred_pair

Base accuracy: 0.9593, Explainer accuracy: 0.9431
Agreement score: 0.9675, KL Score: 0.0022


In [11]:
# Every node after index 510 (the boundary passed to extract_true_motif) up to 799,
# paired with its ground-truth motif of size 9.
grid_base_end, grid_motif_size = 510, 9
selected = list(range(grid_base_end + 1, 800))
test_graphs = [extract_true_motif(g1, node, grid_base_end, grid_motif_size) for node in selected]

In [12]:
# Explanation quality on the tree-grid motifs: prints precision/recall/F1, then the
# elapsed wall-clock time in seconds.
t = time.time()
all_pre, all_rec, options = evaluate_dataset(g1, selected, test_graphs, 10, 12, dumping_factor=0.9, num_iter=5)
# Guard the harmonic mean: precision and recall can both be 0, which would divide by zero.
denom = all_pre + all_rec
f1 = (2 * all_pre * all_rec / denom) if denom > 0 else 0.0
print("Test for tree grid - Precision: %.4f & Recall: %.4f F1: %.4f"
      % (all_pre, all_rec, f1))
# Distinct name for the end time: `e` holds the explainer predictions from predict() and
# must not be overwritten here.
t_end = time.time()
print(t_end - t)

289it [00:05, 49.88it/s]


Test for tree grid - Precision: 0.9716 & Recall: 0.9178 F1: 0.9440
5.814428091049194


Results per checkpoint (`kl` value, then precision/recall/F1, then runtime in seconds):

0.0
Test for tree grid - Precision: 0.9578 & Recall: 0.8414 F1: 0.8959
6.869914293289185

0.1
Test for tree grid - Precision: 0.9716 & Recall: 0.9178 F1: 0.9440
5.814428091049194

0.3
Test for tree grid - Precision: 0.9691 & Recall: 0.9048 F1: 0.9359
6.149960517883301

0.5
Test for tree grid - Precision: 0.9690 & Recall: 0.9014 F1: 0.9340
5.5872581005096436

1
Test for tree grid - Precision: 0.9711 & Recall: 0.9100 F1: 0.9396
5.588150262832642

2
Test for tree grid - Precision: 0.9669 & Recall: 0.8939 F1: 0.9290
5.943668842315674

4
Test for tree grid - Precision: 0.9650 & Recall: 0.8932 F1: 0.9277
6.02161431312561

5
Test for tree grid - Precision: 0.9703 & Recall: 0.9057 F1: 0.9369
6.0125508308410645

10
Test for tree grid - Precision: 0.9678 & Recall: 0.8961 F1: 0.9306
5.863856077194214