In [1]:
import pandas as pd
import numpy as np
import dgl.nn as dglnn
from dgl import from_networkx
import torch.nn as nn
import torch as th
import torch.nn.functional as F
import dgl.function as fn
from dgl.data.utils import load_graphs
import networkx as nx
import pandas as pd
import socket
import struct
import random
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

In [2]:
import networkx
import torch
import numpy as np
import pandas as pd
from sklearn.metrics import *
from torch_geometric.loader import NeighborSampler, NeighborLoader
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GATConv, ResGatedGraphConv, GATv2Conv, SAGEConv, GENConv, DeepGCNLayer, PairNorm, GINConv
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.impute import SimpleImputer
import torch.nn.functional as F
from imblearn.under_sampling import RandomUnderSampler
pd.options.mode.use_inf_as_na = True
from collections import Counter
from sklearn.feature_selection import SelectFromModel
import torch.nn as nn
import time
import pickle
from torch.nn import LayerNorm, Linear, ReLU
from torch_scatter import scatter
from tqdm import tqdm
from torch_geometric.loader import RandomNodeSampler
import math
import copy
from sklearn.metrics import f1_score
from torch.optim import lr_scheduler
from sklearn.manifold import TSNE

In [3]:
class GATLayer(nn.Module):
    def __init__(self, in_dim, out_dim, e_dim):
        super(GATLayer, self).__init__()
        self.w1 = nn.Linear(in_dim + e_dim, out_dim, bias=False)
        self.w2 = nn.Linear(in_dim + out_dim, out_dim, bias=False)
        self.w_att = nn.Linear(in_dim + out_dim, 1, bias=False)
        self.reset_parameters()
    def reset_parameters(self):
        gain = nn.init.calculate_gain('relu')
        nn.init.xavier_normal_(self.w1.weight, gain=gain)
        nn.init.xavier_normal_(self.w2.weight, gain=gain)
        nn.init.xavier_normal_(self.w_att.weight, gain=gain)
    def edge_attention(self, edges):
        z = torch.cat([edges.src['h'], edges.data['m']], -1)
        a = self.w_att(z)
        alpha = F.leaky_relu(a)
        return {'e': alpha}
    
    def msg1(self, edges):
        return {'m': self.w1(th.cat([edges.src['h'], edges.data['h']], -1))}
    
    def message_func(self, edges):
        return {'z': edges.data['m'], 'e': edges.data['e']}
    
    def reduce_func(self, nodes):
        alpha = F.softmax(nodes.mailbox['e'], dim=1) # 归一化每一条入边的注意力系数
        h = torch.mean(nodes.mailbox['z'], dim=1)
        return {'h_neigh': h}
    def forward(self, g_dgl, hfeat, efeat):
        with g_dgl.local_scope():
            g = g_dgl
            g.ndata['h'] = hfeat
            g.edata['h'] = efeat
            g.apply_edges(self.msg1)
            g.apply_edges(self.edge_attention) # 为每一条边获得其注意力系数
            g.update_all(self.message_func, self.reduce_func)
            g.ndata['h'] = F.relu(self.w2(th.cat([g.ndata['h'], g.ndata['h_neigh']], -1)))
            return g.ndata['h']
class Model(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, e_dim):
        super().__init__()
        self.layer1 = GATLayer(in_dim, hidden_dim, e_dim)
        self.layer2 = GATLayer(hidden_dim, out_dim, e_dim)
        self.pred = MLPPredictor(out_dim, 15)
    def forward(self, g, nfeats, efeats):
        nfeats = self.layer1(g, nfeats, efeats)
        nfeats = self.layer2(g, nfeats, efeats)
        return self.pred(g, nfeats)
class MLPPredictor(nn.Module):
    def __init__(self, in_features, out_classes):
        super().__init__()
        self.W = nn.Linear(in_features * 2, out_classes)
    def apply_edges(self, edges):
        h_u = edges.src['h']
        h_v = edges.dst['h']
        score = self.W(th.cat([h_u, h_v], 1))
        return {'score': score}
    def forward(self, graph, h):
        with graph.local_scope():
            graph.ndata['h'] = h
            graph.apply_edges(self.apply_edges)
            return graph.edata['score']

In [4]:
def message_func_init(edges):
    return {'u': edges.data['h']}

def reduce_func_init(nodes):
    h = torch.mean(nodes.mailbox['u'], dim=1)
    return {'h': h}

In [7]:
class SAGELayer(nn.Module):
    def __init__(self, ndim_in, edims, ndim_out, activation):
        super(SAGELayer, self).__init__()
        ### force to outut fix dimensions
        self.W_msg = nn.Linear(ndim_in + edims, ndim_out)
        ### apply weight
        self.W_apply = nn.Linear(ndim_in + ndim_out, ndim_out)
        self.activation = activation

    def message_func(self, edges):
        return {'m': self.W_msg(th.cat([edges.src['h'], edges.data['h']], 2))}

    def forward(self, g_dgl, nfeats, efeats):
        with g_dgl.local_scope():
            g = g_dgl
            g.ndata['h'] = nfeats
            g.edata['h'] = efeats
            # Eq4
            g.update_all(self.message_func, fn.mean('m', 'h_neigh'))
            # Eq5          
            g.ndata['h'] = F.relu(self.W_apply(th.cat([g.ndata['h'], g.ndata['h_neigh']], 2)))
            return g.ndata['h']


class SAGE(nn.Module):
    def __init__(self, ndim_in, ndim_out, edim, activation, dropout):
        super(SAGE, self).__init__()
        self.layers = nn.ModuleList()
        self.layers.append(SAGELayer(ndim_in, edim, 128, activation))
        self.layers.append(SAGELayer(128, edim, ndim_out, activation))
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, g, nfeats, efeats):
        for i, layer in enumerate(self.layers):
            if i != 0:
                nfeats = self.dropout(nfeats)
            nfeats = layer(g, nfeats, efeats)
        return nfeats.sum(1)
    
class Model(nn.Module):
    def __init__(self, ndim_in, ndim_out, edim, activation, dropout):
        super().__init__()
        self.gnn = SAGE(ndim_in, ndim_out, edim, activation, dropout)
        self.pred = MLPPredictor(ndim_out, 15)
    def forward(self, g, nfeats, efeats):
        h = self.gnn(g, nfeats, efeats)
        return self.pred(g, h)

In [6]:
class SAGELayer(nn.Module):
    def __init__(self, ndim_in, edims, ndim_out, activation):
        super(SAGELayer, self).__init__()
        ### force to outut fix dimensions
        self.W_msg = nn.Linear(ndim_in + edims, ndim_out)
        ### apply weight
        self.W_apply = nn.Linear(ndim_in + ndim_out, ndim_out)
        self.activation = activation

    def message_func(self, edges):
        return {'m': self.W_msg(th.cat([edges.src['h'], edges.data['h']], 2))}

    def forward(self, g_dgl, nfeats, efeats):
        with g_dgl.local_scope():
            g = g_dgl
            g.ndata['h'] = nfeats
            g.edata['h'] = efeats
            # Eq4
            g.update_all(self.message_func, fn.mean('m', 'h_neigh'))
            # Eq5          
            g.ndata['h'] = F.relu(self.W_apply(th.cat([g.ndata['h'], g.ndata['h_neigh']], 2)))
            return g.ndata['h']


class SAGE_FOR_Binary(nn.Module):
    def __init__(self, ndim_in, ndim_out, edim, activation, dropout):
        super(SAGE_FOR_Binary, self).__init__()
        self.layers = nn.ModuleList()
        self.layers.append(SAGELayer(ndim_in, edim, 128, activation))
        self.layers.append(SAGELayer(128, edim, ndim_out, activation))
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, g, nfeats, efeats):
        for i, layer in enumerate(self.layers):
            if i != 0:
                nfeats = self.dropout(nfeats)
            nfeats = layer(g, nfeats, efeats)
        return nfeats.sum(1)
    
class Model(nn.Module):
    def __init__(self, ndim_in, ndim_out, edim, activation, dropout):
        super().__init__()
        self.gnn = SAGE_FOR_Binary(ndim_in, ndim_out, edim, activation, dropout)
        self.pred = MLPPredictor(ndim_out, 2)
    def forward(self, g, nfeats, efeats):
        h = self.gnn(g, nfeats, efeats)
        return self.pred(g, h)

In [8]:
G = load_graphs("./cic2017_Gtrain.bin")[0][0]

In [9]:
G_test = load_graphs("./cic2017_G_test.bin")[0][0]

In [10]:
G

Graph(num_nodes=82830, num_edges=849178,
      ndata_schemes={}
      edata_schemes={'daddr': Scheme(shape=(), dtype=torch.int64), 'saddr': Scheme(shape=(), dtype=torch.int64), 'h': Scheme(shape=(78,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64)})

In [11]:
G_test

Graph(num_nodes=52228, num_edges=363934,
      ndata_schemes={}
      edata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'h': Scheme(shape=(78,), dtype=torch.float32)})

In [12]:
G.edata['label'].shape

torch.Size([849178])

In [13]:
train_label = G.edata['label'].cpu().numpy()

In [14]:
test_label = G_test.edata['label'].cpu().numpy()

In [21]:
# 二分类
# train_label[train_label!=0] = 1
# test_label[test_label!=0] = 1

In [15]:
Counter(train_label)

Counter({1: 179234,
         9: 322174,
         13: 912,
         10: 14410,
         8: 7698,
         7: 8114,
         2: 222326,
         12: 2110,
         0: 70000,
         14: 30,
         5: 11110,
         6: 8256,
         3: 2738,
         4: 50,
         11: 16})

In [16]:
Counter(test_label)

Counter({1: 76816,
         9: 138074,
         2: 95282,
         5: 4760,
         12: 904,
         7: 3478,
         10: 6176,
         8: 3300,
         13: 392,
         0: 30000,
         14: 12,
         6: 3538,
         3: 1174,
         11: 6,
         4: 22})

In [17]:
G = G.to('cpu')

In [18]:
G.edata['label'] = torch.LongTensor(train_label)

In [19]:
G.ndata['h'] = th.ones(G.num_nodes(), G.edata['h'].shape[1]) 
G.edata['train_mask'] = th.ones(len(G.edata['h']), dtype= th.bool)

In [20]:
G.ndata['h'] = th.reshape(G.ndata['h'], (G.ndata['h'].shape[0], 1, G.ndata['h'].shape[1]))
G.edata['h'] = th.reshape(G.edata['h'], (G.edata['h'].shape[0], 1, G.edata['h'].shape[1]))

In [21]:
G_test.ndata['h'] = th.ones(G_test.num_nodes(), G_test.edata['h'].shape[1]) 

In [22]:
G_test.ndata['h'] = th.reshape(G_test.ndata['h'], (G_test.ndata['h'].shape[0], 1, G_test.ndata['h'].shape[1]))
G_test.edata['h'] = th.reshape(G_test.edata['h'], (G_test.edata['h'].shape[0], 1, G_test.edata['h'].shape[1]))

In [58]:
model = Model(G.ndata['h'].shape[2], 128, G.ndata['h'].shape[2], F.relu, 0.2)

In [24]:
def test(G_test, model):
    test_node_features = G_test.ndata['h']
    test_edge_features = G_test.edata['h']
    y_true = G_test.edata['label'].detach().numpy()
    pred = model(G_test, test_node_features, test_edge_features)
    y_pred = pred.detach().numpy()
    y_pred = np.argmax(y_pred, -1)
    cm = confusion_matrix(y_true, y_pred)
    cr = classification_report(y_true, y_pred, digits=4)
    return cm, cr

In [25]:
from sklearn.utils import class_weight
class_weights = class_weight.compute_class_weight('balanced',
                                                 np.unique(G.edata['label'].cpu().numpy()),
                                                 G.edata['label'].cpu().numpy())
class_weights = th.FloatTensor(class_weights)



In [26]:
class_weights

tensor([8.0874e-01, 3.1585e-01, 2.5463e-01, 2.0676e+01, 1.1322e+03, 5.0956e+00,
        6.8571e+00, 6.9771e+00, 7.3541e+00, 1.7572e-01, 3.9287e+00, 3.5382e+03,
        2.6830e+01, 6.2074e+01, 1.8871e+03])

In [21]:
# G.update_all(message_func_init, reduce_func_init)

In [28]:
!nvidia-smi

Fri Sep 23 20:17:26 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.73.01    Driver Version: 460.73.01    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-PCIE...  On   | 00000000:3B:00.0 Off |                    0 |
| N/A   38C    P0    38W / 250W |   3121MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-PCIE...  On   | 00000000:86:00.0 Off |                    0 |
| N/A   37C    P0    40W / 250W |   1415MiB / 16160MiB |      0%      Default |
|       

In [59]:
device = 'cuda:2'

In [30]:
G = G.to(device)

In [31]:
class_weights = class_weights.to(device)

In [32]:
node_features = G.ndata['h']
edge_features = G.edata['h']

edge_label = G.edata['label']
train_mask = G.edata['train_mask']

In [33]:
edge_label

tensor([1, 1, 1,  ..., 0, 0, 0], device='cuda:2')

In [34]:
in_dim = 78
hidden_dim = 30
out_dim = 10
e_dim = 78

In [60]:
opt = th.optim.Adam(model.parameters())

In [61]:
criterion = nn.CrossEntropyLoss()

In [49]:
import time

In [50]:
def compute_accuracy(pred, labels):
    return (pred.argmax(1) == labels).float().mean().item()

In [61]:
# weight = torch.load('egat_cic_model' + str(1500))

In [62]:
# model.load_state_dict(weight)

<All keys matched successfully>

In [62]:
model = model.to(device)
for epoch in range(0 ,10000):
    pred = model(G, node_features,edge_features)
    loss = criterion(pred[train_mask], edge_label[train_mask])
    opt.zero_grad()
    loss.backward()
    opt.step()
    if epoch % 200 == 0:
        print('loss',loss.item(),'Epoch:', epoch ,' Training acc:', compute_accuracy(pred[train_mask], edge_label[train_mask]))
        torch.save(model.state_dict(), 'egraphsage_multiclass_cic_model_' + str(epoch))

loss 2.7497963905334473 Epoch: 0  Training acc: 0.012806502170860767
loss 0.36444365978240967 Epoch: 200  Training acc: 0.8953753113746643
loss 0.3410758078098297 Epoch: 400  Training acc: 0.8958451747894287
loss 0.3313473165035248 Epoch: 600  Training acc: 0.8962326049804688
loss 0.32469668984413147 Epoch: 800  Training acc: 0.8964033722877502
loss 0.3203360438346863 Epoch: 1000  Training acc: 0.8964905142784119
loss 0.31776586174964905 Epoch: 1200  Training acc: 0.8965423107147217
loss 0.3156861662864685 Epoch: 1400  Training acc: 0.8966388702392578
loss 0.3139920234680176 Epoch: 1600  Training acc: 0.8967059850692749
loss 0.31237339973449707 Epoch: 1800  Training acc: 0.8966529965400696
loss 0.31133508682250977 Epoch: 2000  Training acc: 0.8967601656913757
loss 0.3099069893360138 Epoch: 2200  Training acc: 0.8966871500015259
loss 0.3091094493865967 Epoch: 2400  Training acc: 0.8967283964157104
loss 0.307559609413147 Epoch: 2600  Training acc: 0.8968544006347656
loss 0.30776628851890

In [63]:
weight = torch.load('egraphsage_multiclass_cic_model_' + str(3600))
model.load_state_dict(weight)

<All keys matched successfully>

In [64]:
model = model.to('cpu')
cm, cr = test(G_test, model)

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [65]:
print(cr)

              precision    recall  f1-score   support

           0     0.9384    0.9912    0.9641     30000
           1     0.9619    0.7290    0.8294     76816
           2     0.9923    0.9943    0.9933     95282
           3     0.9966    1.0000    0.9983      1174
           4     0.0000    0.0000    0.0000        22
           5     0.9901    0.8582    0.9194      4760
           6     0.9965    0.9743    0.9853      3538
           7     0.3984    0.0417    0.0755      3478
           8     0.3644    0.0273    0.0507      3300
           9     0.7980    0.9774    0.8786    138074
          10     0.3750    0.0015    0.0029      6176
          11     0.4286    1.0000    0.6000         6
          12     0.1111    0.0022    0.0043       904
          13     0.0000    0.0000    0.0000       392
          14     0.0000    0.0000    0.0000        12

    accuracy                         0.8913    363934
   macro avg     0.5567    0.5065    0.4868    363934
weighted avg     0.8825   

In [55]:
print(cr)

              precision    recall  f1-score   support

           0     0.9185    0.9825    0.9494     30000
           1     0.8873    0.7630    0.8204     76816
           2     0.9981    0.9829    0.9904     95282
           3     0.9868    0.9540    0.9701      1174
           4     0.0300    1.0000    0.0582        22
           5     0.9633    0.9931    0.9780      4760
           6     0.9591    0.9943    0.9764      3538
           7     0.1589    0.5894    0.2504      3478
           8     0.1213    0.7485    0.2087      3300
           9     0.8803    0.5160    0.6506    138074
          10     0.1165    0.5002    0.1890      6176
          11     0.0556    1.0000    0.1053         6
          12     0.0471    0.6704    0.0879       904
          13     0.0332    0.5077    0.0624       392
          14     0.0011    0.1667    0.0022        12

    accuracy                         0.7441    363934
   macro avg     0.4771    0.7579    0.4866    363934
weighted avg     0.8881   

In [89]:
print(cr)

              precision    recall  f1-score   support

           0     0.9958    0.9840    0.9899     30000
           1     0.9986    0.9996    0.9991    333934

    accuracy                         0.9983    363934
   macro avg     0.9972    0.9918    0.9945    363934
weighted avg     0.9983    0.9983    0.9983    363934



In [90]:
cm

array([[ 29520,    480],
       [   125, 333809]])

In [85]:
print(cr)

              precision    recall  f1-score   support

           0     0.9657    0.9884    0.9769     30000
           1     0.9990    0.9968    0.9979    333934

    accuracy                         0.9962    363934
   macro avg     0.9823    0.9926    0.9874    363934
weighted avg     0.9962    0.9962    0.9962    363934



In [78]:
cm

array([[ 29646,    354],
       [  1090, 332844]])

In [63]:
print(cr)

              precision    recall  f1-score   support

           0     0.9764    0.9858    0.9811     30000
           1     0.9422    0.7419    0.8301     76816
           2     0.9916    0.9961    0.9939     95282
           3     0.9923    0.8731    0.9289      1174
           4     1.0000    0.0455    0.0870        22
           5     0.9859    0.9523    0.9688      4760
           6     0.9912    0.9596    0.9752      3538
           7     0.5116    0.0127    0.0247      3478
           8     0.3851    0.0188    0.0358      3300
           9     0.8001    0.9727    0.8780    138074
          10     0.0000    0.0000    0.0000      6176
          11     1.0000    1.0000    1.0000         6
          12     0.1790    0.0454    0.0724       904
          13     0.0000    0.0000    0.0000       392
          14     0.0000    0.0000    0.0000        12

    accuracy                         0.8927    363934
   macro avg     0.6504    0.5069    0.5184    363934
weighted avg     0.8772   

In [48]:
print(cr)

              precision    recall  f1-score   support

           0     0.9589    0.9877    0.9731     30000
           1     0.9505    0.7372    0.8304     76816
           2     0.9924    0.9950    0.9937     95282
           3     0.9966    0.9949    0.9957      1174
           4     0.0000    0.0000    0.0000        22
           5     0.9898    0.8754    0.9291      4760
           6     0.9936    0.9666    0.9799      3538
           7     0.6000    0.0052    0.0103      3478
           8     0.4071    0.0173    0.0331      3300
           9     0.7988    0.9771    0.8790    138074
          10     1.0000    0.0003    0.0006      6176
          11     1.0000    1.0000    1.0000         6
          12     0.2222    0.0044    0.0087       904
          13     0.0000    0.0000    0.0000       392
          14     0.0000    0.0000    0.0000        12

    accuracy                         0.8925    363934
   macro avg     0.6607    0.5041    0.5089    363934
weighted avg     0.8953   

In [42]:
print(cr)

              precision    recall  f1-score   support

           0     0.9681    0.9873    0.9776     30000
           1     0.9489    0.7380    0.8303     76816
           2     0.9920    0.9950    0.9935     95282
           3     0.9963    0.9259    0.9598      1174
           4     0.0000    0.0000    0.0000        22
           5     0.9816    0.9435    0.9622      4760
           6     0.9874    0.9771    0.9822      3538
           7     0.4219    0.0388    0.0711      3478
           8     0.4565    0.0064    0.0126      3300
           9     0.7999    0.9731    0.8780    138074
          10     0.5000    0.0006    0.0013      6176
          11     1.0000    1.0000    1.0000         6
          12     0.1490    0.0575    0.0830       904
          13     0.0548    0.0204    0.0297       392
          14     0.0000    0.0000    0.0000        12

    accuracy                         0.8923    363934
   macro avg     0.6171    0.5109    0.5188    363934
weighted avg     0.8860   

In [82]:
print(cr)

              precision    recall  f1-score   support

           0     0.9301    0.9899    0.9591     30000
           1     0.9561    0.7338    0.8303     76816
           2     0.9954    0.9958    0.9956     95282
           3     0.9941    0.9966    0.9953      1174
           4     0.0000    0.0000    0.0000        22
           5     0.9913    0.7183    0.8330      4760
           6     0.9965    0.9763    0.9863      3538
           7     0.4238    0.0184    0.0353      3478
           8     0.3156    0.0612    0.1025      3300
           9     0.7989    0.9769    0.8790    138074
          10     0.7500    0.0005    0.0010      6176
          11     1.0000    1.0000    1.0000         6
          12     0.0000    0.0000    0.0000       904
          13     0.2000    0.0051    0.0100       392
          14     0.0000    0.0000    0.0000        12

    accuracy                         0.8907    363934
   macro avg     0.6235    0.4982    0.5085    363934
weighted avg     0.8879   

In [78]:
print(cr)

              precision    recall  f1-score   support

           0     0.9709    0.9845    0.9777     30000
           1     0.9401    0.7395    0.8278     76816
           2     0.9863    0.9959    0.9911     95282
           3     0.9973    0.6363    0.7769      1174
           4     0.0000    0.0000    0.0000        22
           5     0.9794    0.9674    0.9734      4760
           6     0.9795    0.9845    0.9820      3538
           7     0.4087    0.0135    0.0262      3478
           8     0.4079    0.0094    0.0184      3300
           9     0.7995    0.9708    0.8769    138074
          10     0.3333    0.0002    0.0003      6176
          11     1.0000    1.0000    1.0000         6
          12     0.3333    0.0022    0.0044       904
          13     0.0000    0.0000    0.0000       392
          14     0.0000    0.0000    0.0000        12

    accuracy                         0.8908    363934
   macro avg     0.6091    0.4869    0.4970    363934
weighted avg     0.8797   

In [89]:
print(cr)

              precision    recall  f1-score   support

           0     0.6506    0.9932    0.7862     30000
           1     0.7520    0.8983    0.8187     12000
           2     0.9643    0.9752    0.9698     12000
           3     0.9459    0.9838    0.9645      1174
           4     0.0000    0.0000    0.0000        22
           5     0.1119    0.0067    0.0127      4760
           6     0.2583    0.0088    0.0169      3538
           7     0.9711    0.0483    0.0920      3478
           8     1.0000    0.0064    0.0126      3300
           9     0.5672    0.6318    0.5978     12000
          10     0.3750    0.0044    0.0086      6176
          11     0.0000    0.0000    0.0000         6
          12     0.3071    0.6261    0.4121       904
          13     0.1183    0.1173    0.1178       392
          14     0.0000    0.0000    0.0000        12

    accuracy                         0.6897     89762
   macro avg     0.4681    0.3534    0.3206     89762
weighted avg     0.6550   

In [46]:
print(cr)

              precision    recall  f1-score   support

           0     0.9266    0.9888    0.9567     30000
           1     0.8360    0.9143    0.8734     12000
           2     0.9902    0.9562    0.9729     12000
           3     0.9551    0.9966    0.9754      1174
           4     0.0000    0.0000    0.0000        22
           5     0.9824    0.9239    0.9523      4760
           6     0.9742    0.6614    0.7879      3538
           7     0.9039    0.4166    0.5704      3478
           8     0.5370    0.8273    0.6512      3300
           9     0.7853    0.6808    0.7293     12000
          10     0.7950    0.5494    0.6498      6176
          11     0.5455    1.0000    0.7059         6
          12     0.3376    0.7566    0.4669       904
          13     0.1063    0.4158    0.1693       392
          14     0.0000    0.0000    0.0000        12

    accuracy                         0.8535     89762
   macro avg     0.6450    0.6725    0.6307     89762
weighted avg     0.8752   

In [37]:
print(cr)

              precision    recall  f1-score   support

           0     0.8527    0.9954    0.9185     30000
           1     0.9289    0.8697    0.8983     12000
           2     0.9927    0.8769    0.9312     12000
           3     0.9725    0.9932    0.9827      1174
           4     0.0000    0.0000    0.0000        22
           5     0.9967    0.8206    0.9001      4760
           6     0.9937    0.4921    0.6582      3538
           7     0.6724    0.7890    0.7260      3478
           8     0.6487    0.6233    0.6358      3300
           9     0.9148    0.5754    0.7065     12000
          10     0.6090    0.8960    0.7252      6176
          11     0.0000    0.0000    0.0000         6
          12     0.3107    0.2987    0.3046       904
          13     0.1930    0.6301    0.2955       392
          14     0.0000    0.0000    0.0000        12

    accuracy                         0.8399     89762
   macro avg     0.6057    0.5907    0.5788     89762
weighted avg     0.8647   

In [45]:
print(cr)

              precision    recall  f1-score   support

           0     0.7756    0.9981    0.8729     30000
           1     0.9407    0.8642    0.9008     12000
           2     0.8917    0.7551    0.8177     12000
           3     0.7778    0.0119    0.0235      1174
           4     1.0000    1.0000    1.0000        22
           5     0.9924    0.9071    0.9479      4760
           6     0.0000    0.0000    0.0000      3538
           7     0.9186    0.4543    0.6079      3478
           8     0.7692    0.4000    0.5263      3300
           9     0.7922    0.7003    0.7434     12000
          10     0.5300    0.9616    0.6834      6176
          11     0.0000    0.0000    0.0000         6
          12     0.7047    0.2323    0.3494       904
          13     0.0000    0.0000    0.0000       392
          14     0.0000    0.0000    0.0000        12

    accuracy                         0.7930     89762
   macro avg     0.6062    0.4857    0.4982     89762
weighted avg     0.7806   

In [38]:
G_test.ndata['h'] = th.ones(G_test.num_nodes(), G_test.edata['h'].shape[1]) 

In [76]:
def test(G_test, model):
    test_node_features = G_test.ndata['h']
    test_edge_features = G_test.edata['h']
    y_true = G_test.edata['label'].detach().numpy()
    pred = model(G_test, test_node_features, test_edge_features)
    y_pred = pred.detach().numpy()
    y_pred = np.argmax(y_pred, -1)
    cm = confusion_matrix(y_true, y_pred)
    cr = classification_report(y_true, y_pred, digits=4)
    return cm, cr

In [69]:
G_test = load_graphs("./cic2017_test_data_rus.bin")[0][0]

In [72]:
G_test

Graph(num_nodes=52228, num_edges=363934,
      ndata_schemes={}
      edata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'h': Scheme(shape=(78,), dtype=torch.float32)})

In [75]:
G_test.ndata['h'] = th.reshape(G_test.ndata['h'], (G_test.ndata['h'].shape[0], 1, G_test.ndata['h'].shape[1]))
G_test.edata['h'] = th.reshape(G_test.edata['h'], (G_test.edata['h'].shape[0], 1, G_test.edata['h'].shape[1]))

In [72]:
G_test.update_all(message_func_init, reduce_func_init)

In [71]:
G_test.ndata['h'] = th.ones(G_test.num_nodes(), G_test.edata['h'].shape[1]) 

In [44]:
Counter(G_test.edata['label'].numpy())

Counter({0: 30000,
         10: 6176,
         7: 3478,
         9: 12000,
         1: 12000,
         8: 3300,
         12: 904,
         13: 392,
         2: 12000,
         14: 12,
         3: 1174,
         5: 4760,
         6: 3538,
         11: 6,
         4: 22})

In [None]:
class GATLayer(nn.Module):
    def __init__(self, in_dim , out_dim):
        super(GATLayer, self).__init__()
        self.W = nn.Linear(in_dim, out_dim, bias=False)
        self.attn_fc = nn.Linear(3 * out_dim, 1, bias=False)
        self.reset_parameters()
    def reset_parameters(self):
        gain = nn.init.calculate_gain('relu')
        nn.init.xavier_normal_(self.W.weight, gain=gain)
        nn.init.xavier_normal_(self.attn_fc.weight, gain=gain)
    def edge_attention(self, edges):
        z = torch.cat([edges.src['z'], edges.dst['z'], edges.data['z']], -1)
        a = self.attn_fc(z)
        alpha = F.leaky_relu(a)
        return {'e': alpha}
    def message_func(self, edges):
        return {'z': edges.src['z'], 'e': edges.data['e'] ,'m': edges.data['z']}
    def reduce_func(self, nodes):
        alpha = F.softmax(nodes.mailbox['e'], dim=1) # 归一化每一条入边的注意力系数
        h = torch.sum(alpha * nodes.mailbox['m'], dim=1)
        return {'h': F.relu(h)}
    def forward(self, g_dgl, hfeat, efeat):
        with g_dgl.local_scope():
            g = g_dgl
            z1 = self.W(hfeat)
            z2 = self.W(efeat)
            g.ndata['z'] = z1 # 每个节点的特征
            g.edata['z'] = z2 # 每条边的特征
            g.apply_edges(self.edge_attention) # 为每一条边获得其注意力系数
            g.update_all(self.message_func, self.reduce_func)
            return g.ndata['h'], g.edata['z']
class Model(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.layer1 = GATLayer(in_dim, hidden_dim)
        self.layer2 = GATLayer(hidden_dim, out_dim)
        self.pred = MLPPredictor(out_dim, 5)
    def forward(self, g, nfeats, efeats):
        nfeats, efeats = self.layer1(g, nfeats, efeats)
#         nfeats, efeats = self.layer2(g, nfeats, efeats)
        return self.pred(g, nfeats)
class MLPPredictor(nn.Module):
    def __init__(self, in_features, out_classes):
        super().__init__()
        self.W = nn.Linear(in_features * 2, out_classes)
    def apply_edges(self, edges):
        h_u = edges.src['h']
        h_v = edges.dst['h']
        score = self.W(th.cat([h_u, h_v], 1))
        return {'score': score}
    def forward(self, graph, h):
        with graph.local_scope():
            graph.ndata['h'] = h
            graph.apply_edges(self.apply_edges)
            return graph.edata['score']

In [32]:
class GATLayer(nn.Module):
    def __init__(self, in_dim, out_dim, e_dim):
        super(GATLayer, self).__init__()
        self.w1 = nn.Linear(in_dim + e_dim, out_dim, bias=False)
        self.w2 = nn.Linear(in_dim + out_dim, out_dim, bias=False)
        self.w_att = nn.Linear(in_dim + out_dim, 1, bias=False)
        self.reset_parameters()
    def reset_parameters(self):
        gain = nn.init.calculate_gain('relu')
        nn.init.xavier_normal_(self.w1.weight, gain=gain)
        nn.init.xavier_normal_(self.w2.weight, gain=gain)
        nn.init.xavier_normal_(self.w_att.weight, gain=gain)
    def edge_attention(self, edges):
        z = torch.cat([edges.src['h'], edges.data['m']], -1)
        a = self.w_att(z)
        alpha = F.leaky_relu(a)
        return {'e': alpha}
    def msg1(self, edges):
        return {'m': self.w1(th.cat([edges.src['h'], edges.data['h']], -1))}
    
    def message_func(self, edges):
        return {'z': edges.data['m'], 'e': edges.data['e']}
    def reduce_func(self, nodes):
        alpha = F.softmax(nodes.mailbox['e'], dim=1) # 归一化每一条入边的注意力系数
        h = torch.mean(nodes.mailbox['z'], dim=1)
        return {'h_neigh': h}
    def forward(self, g_dgl, hfeat, efeat):
        with g_dgl.local_scope():
            g = g_dgl
            g.ndata['h'] = hfeat
            g.edata['h'] = efeat
            g.apply_edges(self.msg1)
            g.apply_edges(self.edge_attention) # 为每一条边获得其注意力系数
            g.update_all(self.message_func, self.reduce_func)
            g.ndata['h'] = F.relu(self.w2(th.cat([g.ndata['h'], g.ndata['h_neigh']], -1)))
            return g.ndata['h']
class Model(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, e_dim):
        super().__init__()
        self.layer1 = GATLayer(in_dim, hidden_dim, e_dim)
        self.layer2 = GATLayer(hidden_dim, out_dim, e_dim)
        self.pred = MLPPredictor(out_dim, 15)
    def forward(self, g, nfeats, efeats):
        nfeats = self.layer1(g, nfeats, efeats)
        nfeats = self.layer2(g, nfeats, efeats)
        return self.pred(g, nfeats)
class MLPPredictor(nn.Module):
    def __init__(self, in_features, out_classes):
        super().__init__()
        self.W = nn.Linear(in_features * 2, out_classes)
    def apply_edges(self, edges):
        h_u = edges.src['h']
        h_v = edges.dst['h']
        score = self.W(th.cat([h_u, h_v], 1))
        return {'score': score}
    def forward(self, graph, h):
        with graph.local_scope():
            graph.ndata['h'] = h
            graph.apply_edges(self.apply_edges)
            return graph.edata['score']