In [1]:
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import StandardScaler

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import IterableDataset, DataLoader
import dgl
import dgl.nn as dglnn
from tqdm import trange

from collections import Counter

Using backend: pytorch


In [2]:
data = np.load('./phase1_gdata.npz')

node_feat = data['x']
node_label = data['y']
edge_pair = data['edge_index']
edge_type = data['edge_type']
edge_time = data['edge_timestamp']
    
train_mask = data['train_mask']
test_mask = data['test_mask']

g = dgl.graph(edge_pair.tolist())
# ss = StandardScaler()
# g.ndata['feat'] = torch.Tensor(ss.fit_transform(node_feat))
g.ndata['feat'] = torch.Tensor(node_feat)
g.ndata['label'] = torch.Tensor(node_label)
g.edata['type'] = torch.Tensor(edge_type)
g.edata['time'] = torch.Tensor(edge_time)
g

Graph(num_nodes=4059035, num_edges=4962032,
      ndata_schemes={'feat': Scheme(shape=(17,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.float32)}
      edata_schemes={'type': Scheme(shape=(), dtype=torch.float32), 'time': Scheme(shape=(), dtype=torch.float32)})

In [3]:
class NeighborSampler(object):
    def __init__(self, g, neighbor_nums):
        """
        g 为 DGLGraph；
        fanouts 为采样节点的数量，实验使用 10,25，指一阶邻居采样 10 个，二阶邻居采样 25 个。
        """
        self.g = g
        self.neighbor_nums = neighbor_nums

    def sample_blocks(self, nodes):
        nodes = torch.LongTensor(nodes)
        blocks = []
        for neighbor_num in self.neighbor_nums: 
            # sample_neighbors 可以对每一个种子的节点进行邻居采样并返回相应的子图
            # replace=True 表示用采样后的邻居节点代替所有邻居节点
            frontier = dgl.sampling.sample_neighbors(g, nodes, neighbor_num, replace=True)
            # 将图转变为可以用于消息传递的二部图（源节点和目的节点）
            # 其中源节点的 id 也可能包含目的节点的 id（原因上面说了）
            # 转变为二部图主要是为了方便进行消息传递
            block = dgl.to_block(frontier, nodes)
            # 获取新图的源节点作为种子节点，为下一层作准备
            # 之所以是从 src 中获取种子节点，是因为采样操作相对于聚合操作来说是一个逆向操作
            nodes = block.srcdata[dgl.NID]
            # 把这一层放在最前面。
            # PS：如果数据量大的话，插入操作是不是不太友好。
            blocks.insert(0, block)
            
        return blocks

In [4]:
train_mask_0 = train_mask[node_label[train_mask]==0]
train_mask_0_sample = np.random.choice(train_mask_0, size=int(len(train_mask_0)/3), replace=False)
train_mask_1 = train_mask[node_label[train_mask]==1]

train_mask_repeat = np.concatenate([train_mask_0_sample, train_mask_1])

In [5]:
batch_size = 256

np.random.shuffle(train_mask_repeat)
train_idx = train_mask_repeat[:int(len(train_mask_repeat)/100*80)]
val_idx = train_mask_repeat[int(len(train_mask_repeat)/100*80):]


sampler = NeighborSampler(g, [10, 5])
dataloader = DataLoader(
    dataset = train_idx,
    batch_size = batch_size,
    collate_fn = sampler.sample_blocks,
    shuffle = True,
    drop_last = False)

# blocks = iter(dataloader).next()

In [10]:
class GraphSAGE(nn.Module):
    def __init__(self, in_feats, n_hidden, n_classes, n_layers, dropout):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'pool'))
        for i in range(1, n_layers - 1):
            self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'pool'))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'pool'))
        self.dropout = nn.Dropout(dropout)

    def forward(self, blocks, x):
        # block 是我们采样获得的二部图，这里用于消息传播
        # x 为节点特征
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h_dst = h[:block.number_of_dst_nodes()]
            h = layer(block, (h, h_dst))
            if l != len(self.layers) - 1:
                h = F.relu(h)
                h = self.dropout(h)
        return h

    def inference(self, g, x, batch_size):
        # inference 用于评估测试，针对的是完全图
        # 目前会出现重复计算的问题，优化方案还在 to do list 上
        nodes = torch.arange(g.number_of_nodes())
        for l, layer in enumerate(self.layers):
            y = torch.zeros(g.number_of_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes)
            for start in trange(0, len(nodes), batch_size):
                end = start + batch_size
                batch_nodes = nodes[start:end]
                block = dgl.to_block(dgl.in_subgraph(g, batch_nodes), batch_nodes)
                input_nodes = block.srcdata[dgl.NID]
                h = torch.Tensor(x[input_nodes])
                h_dst = h[:block.number_of_dst_nodes()]
                h = layer(block, (h, h_dst))
                if l != len(self.layers) - 1:
                    h = F.relu(h)
                    h = self.dropout(h)
                y[start:end] = h
            x = y
        return y

model = GraphSAGE(17, 128, 1, 2, 0.5)
model

GraphSAGE(
  (layers): ModuleList(
    (0): SAGEConv(
      (feat_drop): Dropout(p=0.0, inplace=False)
      (fc_pool): Linear(in_features=17, out_features=17, bias=True)
      (fc_self): Linear(in_features=17, out_features=128, bias=False)
      (fc_neigh): Linear(in_features=17, out_features=128, bias=False)
    )
    (1): SAGEConv(
      (feat_drop): Dropout(p=0.0, inplace=False)
      (fc_pool): Linear(in_features=128, out_features=128, bias=True)
      (fc_self): Linear(in_features=128, out_features=1, bias=False)
      (fc_neigh): Linear(in_features=128, out_features=1, bias=False)
    )
  )
  (dropout): Dropout(p=0.5, inplace=False)
)

In [11]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(50):
    for step, blocks in enumerate(dataloader):
        model.train()
        input_nodes = blocks[0].srcdata[dgl.NID]
        label_nodes = blocks[-1].dstdata[dgl.NID]
        
        batch_inputs = g.ndata['feat'][input_nodes]
        batch_labels = g.ndata['label'][label_nodes]
        
        batch_pred = model(blocks, batch_inputs).view(-1)
        loss = F.binary_cross_entropy_with_logits(batch_pred, batch_labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if step % 100 == 0:
            print(epoch, step, loss.item())
#             model.eval()
#             with torch.no_grad():
#                 train_auc = roc_auc_score(batch_labels.numpy(), batch_pred.numpy())
#                 print(epoch, step, train_auc)
               
    if epoch % 5 == 0:
        model.eval()
        with torch.no_grad():
            pred = model.inference(g, node_feat, 4096)
            train_pred = pred[train_idx]
            val_pred = pred[val_idx]
            
            train_labels = g.ndata['label'][train_idx]
            val_labels = g.ndata['label'][val_idx]
            
            train_auc = roc_auc_score(train_labels.numpy(), train_pred.numpy())
            val_auc = roc_auc_score(val_labels.numpy(), val_pred.numpy())
            print("---------------------", train_auc, val_auc)

0 0 2.649331569671631
0 100 0.2822313606739044
0 200 0.1855325996875763
0 300 0.1417270302772522
0 400 0.11734362691640854
0 500 0.18479108810424805
0 600 0.22598479688167572
0 700 0.1491628885269165
0 800 0.14096952974796295


100%|██████████| 991/991 [00:06<00:00, 162.58it/s]
100%|██████████| 991/991 [00:04<00:00, 222.80it/s]


--------------------- 0.7496318111612199 0.742103519485826
1 0 0.2351222038269043
1 100 0.10484462976455688
1 200 0.09510451555252075
1 300 0.16834023594856262
1 400 0.12449563294649124
1 500 0.19780775904655457
1 600 0.12789902091026306
1 700 0.09731359034776688
1 800 0.13891535997390747
2 0 0.12722547352313995
2 100 0.22566935420036316
2 200 0.11443990468978882
2 300 0.1254749596118927
2 400 0.11825525760650635
2 500 0.11100316047668457
2 600 0.13452881574630737
2 700 0.13583149015903473
2 800 0.1192898228764534
3 0 0.1690981388092041
3 100 0.11553408950567245
3 200 0.14309239387512207
3 300 0.15974250435829163
3 400 0.11459143459796906
3 500 0.1316060721874237
3 600 0.106626495718956
3 700 0.14556190371513367
3 800 0.13688163459300995
4 0 0.14346589148044586
4 100 0.10383469611406326
4 200 0.12822437286376953
4 300 0.16107340157032013
4 400 0.11760465055704117
4 500 0.174025297164917
4 600 0.12131816148757935
4 700 0.14380605518817902
4 800 0.15728707611560822
5 0 0.1342465430498123

100%|██████████| 991/991 [00:05<00:00, 166.91it/s]
100%|██████████| 991/991 [00:04<00:00, 217.32it/s]


--------------------- 0.7695994610117648 0.7615725405107457
6 0 0.0996476262807846
6 100 0.16668261587619781
6 200 0.17451409995555878
6 300 0.15087038278579712
6 400 0.1277744621038437
6 500 0.14731264114379883
6 600 0.13505993783473969
6 700 0.12625089287757874
6 800 0.13926541805267334
7 0 0.11633947491645813
7 100 0.1040167286992073
7 200 0.12729111313819885
7 300 0.09676625579595566
7 400 0.16745583713054657
7 500 0.16678522527217865
7 600 0.13185365498065948
7 700 0.13042543828487396
7 800 0.16642487049102783
8 0 0.11650822311639786
8 100 0.13537469506263733
8 200 0.08429985493421555
8 300 0.12027116119861603
8 400 0.255287766456604
8 500 0.11052834242582321
8 600 0.17491966485977173
8 700 0.11043275147676468
8 800 0.18348953127861023
9 0 0.1149219274520874
9 100 0.1331632137298584
9 200 0.15526238083839417
9 300 0.12986648082733154
9 400 0.10743719339370728
9 500 0.11312500387430191
9 600 0.09034992754459381
9 700 0.1953626275062561
9 800 0.1249198466539383
10 0 0.10431461781263

100%|██████████| 991/991 [00:05<00:00, 166.99it/s]
100%|██████████| 991/991 [00:04<00:00, 205.69it/s]


--------------------- 0.7731500260703368 0.7637537949190457
11 0 0.044151369482278824
11 100 0.21166634559631348
11 200 0.08172895759344101
11 300 0.12707360088825226
11 400 0.16621969640254974
11 500 0.14149954915046692
11 600 0.1490268111228943
11 700 0.11897815763950348
11 800 0.1512496918439865
12 0 0.13418959081172943
12 100 0.1025773286819458
12 200 0.13762035965919495
12 300 0.12240463495254517
12 400 0.11734871566295624
12 500 0.15668219327926636
12 600 0.15537357330322266
12 700 0.15545634925365448
12 800 0.1961548626422882
13 0 0.1270238161087036
13 100 0.1339651644229889
13 200 0.12010985612869263
13 300 0.12835094332695007
13 400 0.12859201431274414
13 500 0.1414613276720047
13 600 0.08590342104434967
13 700 0.20057174563407898
13 800 0.14641307294368744
14 0 0.09683316946029663
14 100 0.1714901179075241
14 200 0.17323285341262817
14 300 0.11665309965610504
14 400 0.10678181797266006
14 500 0.13454192876815796
14 600 0.22158849239349365
14 700 0.14773418009281158
14 800 0.1

100%|██████████| 991/991 [00:07<00:00, 138.95it/s]
100%|██████████| 991/991 [00:04<00:00, 219.45it/s]


--------------------- 0.7789325305400926 0.7685486712321729
16 0 0.11797630786895752
16 100 0.18863876163959503
16 200 0.10996104031801224
16 300 0.1113162711262703
16 400 0.12235059589147568
16 500 0.16366195678710938
16 600 0.1571805775165558
16 700 0.13379788398742676
16 800 0.1431184858083725
17 0 0.19147442281246185
17 100 0.14029201865196228
17 200 0.13789652287960052
17 300 0.12460345774888992
17 400 0.15465517342090607
17 500 0.08973785489797592
17 600 0.10484766215085983
17 700 0.10525118559598923
17 800 0.12608683109283447
18 0 0.12670694291591644
18 100 0.10912366211414337
18 200 0.0819721445441246
18 300 0.14650794863700867
18 400 0.1360531896352768
18 500 0.146173894405365
18 600 0.17236685752868652
18 700 0.10601635277271271
18 800 0.12060575187206268
19 0 0.10585392266511917
19 100 0.12032372504472733
19 200 0.17424027621746063
19 300 0.10197558254003525
19 400 0.1619439721107483
19 500 0.1622178852558136
19 600 0.13286764919757843
19 700 0.11838684976100922
19 800 0.162

100%|██████████| 991/991 [00:06<00:00, 144.60it/s]
100%|██████████| 991/991 [00:05<00:00, 178.51it/s]


--------------------- 0.7772773442961858 0.7676911095543855
21 0 0.11538077890872955
21 100 0.14996354281902313
21 200 0.15144968032836914
21 300 0.08596760779619217
21 400 0.17402248084545135
21 500 0.1497095227241516
21 600 0.2167794406414032
21 700 0.14149020612239838
21 800 0.12765882909297943
22 0 0.13425205647945404
22 100 0.14500251412391663
22 200 0.10337132215499878
22 300 0.1716977208852768
22 400 0.13674508035182953
22 500 0.20038631558418274
22 600 0.1656927466392517
22 700 0.18137435615062714
22 800 0.12233814597129822
23 0 0.14217476546764374
23 100 0.09936026483774185
23 200 0.1285620629787445
23 300 0.16026772558689117
23 400 0.09420030564069748
23 500 0.18484577536582947
23 600 0.14322546124458313
23 700 0.15138360857963562
23 800 0.18194055557250977
24 0 0.13224327564239502
24 100 0.16369511187076569
24 200 0.15132169425487518
24 300 0.14252053201198578
24 400 0.15863297879695892
24 500 0.17379674315452576
24 600 0.14987069368362427
24 700 0.09876183420419693
24 800 0

100%|██████████| 991/991 [00:05<00:00, 181.66it/s]
100%|██████████| 991/991 [00:04<00:00, 222.02it/s]


--------------------- 0.7792082463993264 0.7682906671519297
26 0 0.14541426301002502
26 100 0.12368693947792053
26 200 0.1484636515378952
26 300 0.1833227425813675
26 400 0.094183549284935
26 500 0.04770203307271004
26 600 0.1649903804063797
26 700 0.14518408477306366
26 800 0.11210904270410538
27 0 0.18484969437122345
27 100 0.15875868499279022
27 200 0.13988587260246277
27 300 0.13513147830963135
27 400 0.14433284103870392
27 500 0.12453325092792511
27 600 0.17627353966236115
27 700 0.13919147849082947
27 800 0.12471386790275574
28 0 0.14083173871040344
28 100 0.1071905717253685
28 200 0.19006389379501343
28 300 0.20336560904979706
28 400 0.16118867695331573
28 500 0.10841681808233261
28 600 0.11060267686843872
28 700 0.09694257378578186
28 800 0.09609711915254593
29 0 0.08873849362134933
29 100 0.16487136483192444
29 200 0.22290128469467163
29 300 0.10994070768356323
29 400 0.14258348941802979
29 500 0.15862159430980682
29 600 0.19332152605056763
29 700 0.13528810441493988
29 800 0.

100%|██████████| 991/991 [00:06<00:00, 163.53it/s]
100%|██████████| 991/991 [00:04<00:00, 221.49it/s]


--------------------- 0.7814463542561233 0.7678267111405812
31 0 0.08132623136043549
31 100 0.14134064316749573
31 200 0.09405093640089035
31 300 0.11197508126497269
31 400 0.138681098818779
31 500 0.13341306149959564
31 600 0.1030363067984581
31 700 0.0707242488861084
31 800 0.1716916561126709
32 0 0.16773372888565063
32 100 0.12723493576049805
32 200 0.12577253580093384
32 300 0.19024981558322906
32 400 0.06992729753255844
32 500 0.13403339684009552
32 600 0.1345643252134323
32 700 0.20047175884246826
32 800 0.15738332271575928
33 0 0.10119946300983429
33 100 0.09345526248216629
33 200 0.11855965852737427
33 300 0.07668382674455643
33 400 0.10544151812791824
33 500 0.1098172664642334
33 600 0.22157099843025208
33 700 0.145948126912117
33 800 0.1567160189151764
34 0 0.06894633173942566
34 100 0.15919072926044464
34 200 0.2225673496723175
34 300 0.15203352272510529
34 400 0.0703260526061058
34 500 0.15816769003868103
34 600 0.16181007027626038
34 700 0.09714831411838531
34 800 0.174262

100%|██████████| 991/991 [00:06<00:00, 164.28it/s]
100%|██████████| 991/991 [00:04<00:00, 202.77it/s]


--------------------- 0.7810051831381439 0.7691584798525389
36 0 0.09981655329465866
36 100 0.12771955132484436
36 200 0.1639297604560852
36 300 0.12175294756889343
36 400 0.13110509514808655
36 500 0.1873593032360077
36 600 0.15668797492980957
36 700 0.09050169587135315
36 800 0.09723052382469177
37 0 0.1895170509815216
37 100 0.1448000967502594
37 200 0.1722574383020401
37 300 0.12488602846860886
37 400 0.08081120252609253
37 500 0.13122360408306122
37 600 0.11361191421747208
37 700 0.12339938431978226
37 800 0.10991963744163513
38 0 0.1809743344783783
38 100 0.1541019082069397
38 200 0.14520329236984253
38 300 0.13526856899261475
38 400 0.1131841391324997
38 500 0.10074348747730255
38 600 0.09914568066596985
38 700 0.1363753229379654
38 800 0.05929453670978546
39 0 0.14641252160072327
39 100 0.09507414698600769
39 200 0.15911029279232025
39 300 0.1364162266254425
39 400 0.1709749847650528
39 500 0.07865433394908905
39 600 0.06944645196199417
39 700 0.11491407454013824
39 800 0.12025

100%|██████████| 991/991 [00:05<00:00, 186.01it/s]
100%|██████████| 991/991 [00:05<00:00, 178.57it/s]


--------------------- 0.7832590031064308 0.7700142602050493
41 0 0.10263333469629288
41 100 0.1932767629623413
41 200 0.17125754058361053
41 300 0.09048742800951004
41 400 0.08171497285366058
41 500 0.19767585396766663
41 600 0.13633784651756287
41 700 0.15666556358337402
41 800 0.13263624906539917
42 0 0.13657912611961365
42 100 0.20766639709472656
42 200 0.11306377500295639
42 300 0.15209420025348663
42 400 0.17322206497192383
42 500 0.10926420986652374
42 600 0.14336815476417542
42 700 0.0946086049079895
42 800 0.13220033049583435
43 0 0.152459979057312
43 100 0.1281658113002777
43 200 0.11363682150840759
43 300 0.09366551786661148
43 400 0.14726628363132477
43 500 0.08229728788137436
43 600 0.0882280170917511
43 700 0.08653338998556137
43 800 0.1154496818780899
44 0 0.08845773339271545
44 100 0.1791001707315445
44 200 0.10419468581676483
44 300 0.1289609670639038
44 400 0.11750152707099915
44 500 0.13356031477451324
44 600 0.15304362773895264
44 700 0.13251852989196777
44 800 0.127

100%|██████████| 991/991 [00:06<00:00, 163.82it/s]
100%|██████████| 991/991 [00:05<00:00, 173.76it/s]


--------------------- 0.7837961050243017 0.7702339921996199
46 0 0.1360456794500351
46 100 0.14236250519752502
46 200 0.1750805377960205
46 300 0.10294429212808609
46 400 0.10698813199996948
46 500 0.1897341012954712
46 600 0.2564755082130432
46 700 0.15048259496688843
46 800 0.14842361211776733
47 0 0.16837337613105774
47 100 0.1824779063463211
47 200 0.19011422991752625
47 300 0.13284869492053986
47 400 0.13099804520606995
47 500 0.17656975984573364
47 600 0.21511830389499664
47 700 0.17391541600227356
47 800 0.15661633014678955
48 0 0.17110596597194672
48 100 0.10329404473304749
48 200 0.09018789976835251
48 300 0.16751818358898163
48 400 0.10892155021429062
48 500 0.1501491367816925
48 600 0.09938539564609528
48 700 0.08473936468362808
48 800 0.172996386885643
49 0 0.05997982621192932
49 100 0.10714725404977798
49 200 0.09840206056833267
49 300 0.11082275211811066
49 400 0.1283702701330185
49 500 0.14794008433818817
49 600 0.11455395072698593
49 700 0.11975614726543427
49 800 0.128

In [8]:
with torch.no_grad():
    test_logits = torch.sigmoid(model.inference(g, node_feat, 4096)[test_mask]).numpy()
    
res = np.concatenate((1-test_logits, test_logits), axis=1)
np.save("sage_sample.npy", res)

100%|██████████| 991/991 [00:05<00:00, 175.45it/s]
100%|██████████| 991/991 [00:04<00:00, 214.38it/s]
