In [80]:
import networkx
import torch
import numpy as np
import pandas as pd
from sklearn.metrics import *
from torch_geometric.loader import NeighborSampler, NeighborLoader
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GATConv, ResGatedGraphConv, GATv2Conv, SAGEConv, GENConv, DeepGCNLayer, PairNorm, GINConv
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.impute import SimpleImputer
import torch.nn.functional as F
from imblearn.under_sampling import RandomUnderSampler
pd.options.mode.use_inf_as_na = True
from collections import Counter
from sklearn.feature_selection import SelectFromModel
import torch.nn as nn
import time
import pickle
from torch.nn import LayerNorm, Linear, ReLU
from torch_scatter import scatter
from tqdm import tqdm
from torch_geometric.loader import RandomNodeSampler
import math
import copy
from sklearn.metrics import f1_score
from torch.optim import lr_scheduler
from sklearn.manifold import TSNE

In [81]:
from densegat.DenseGAT import DenseGAT

In [82]:
class Transition_layer(torch.nn.Module):
    """DenseNet-style transition: normalise, activate, then project.

    Computes ``lin(act(norm(x)))``.

    Args:
        act: activation module applied after normalisation (e.g. ``ReLU()``).
        norm: normalisation module applied first.
        lin: linear projection applied last.
    """

    def __init__(self, act, norm, lin):
        super().__init__()
        # Registration order (act, norm, lin) is preserved so state_dict keys
        # and child ordering match previously saved checkpoints.
        self.act, self.norm, self.lin = act, norm, lin

    def forward(self, x):
        """Return ``lin(act(norm(x)))``."""
        return self.lin(self.act(self.norm(x)))
        
class DenseGAT(torch.nn.Module):
    """DenseNet-style graph network built from dense ``SAGEConv`` blocks.

    The network is ``num_blocks`` dense blocks, each followed by a transition
    layer, and a final linear classifier:

    * Block ``i`` uses growth rate ``k_i = 2**i * growth_rate``. Each of its
      ``num_layers`` layers is a ``DeepGCNLayer(block='dense')`` that
      concatenates its input with a ``SAGEConv -> LayerNorm -> ReLU`` output
      of width ``k_i``, so the feature width grows by ``k_i`` per layer.
    * The transition layer (LayerNorm -> ReLU -> Linear) compresses the block
      output width by the factor ``theta``.

    Despite the class name, no attention convolution is used.

    Args:
        in_channels: number of input node features.
        num_class: number of output logits per node.
        num_layers: dense layers per block.
        num_blocks: number of dense blocks.
        growth_rate: base growth rate ``k0``; block ``i`` uses ``2**i * k0``.
        theta: compression factor applied by each transition layer.
    """

    def __init__(self, in_channels, num_class, num_layers=6, num_blocks=3, growth_rate=10, theta=0.5):
        super().__init__()
        self.num_layers = num_layers
        self.blocks = num_blocks
        self.theta = theta
        # Overwritten per block below; after construction it holds the growth
        # rate of the last block.
        self.growth_rate = growth_rate
        self.base_rate = growth_rate
        self.in_channels = in_channels
        self.out_channels = num_class
        self.linear_layers = torch.nn.ModuleList()
        self.transition_layers = torch.nn.ModuleList()
        self.block_layers = torch.nn.ModuleList()
        # Input encoder. It stays registered (also as linear_layers[0]) so
        # existing state_dicts still load, but forward() does NOT apply it:
        # the first dense block consumes the raw `in_channels` features.
        self.node_encoder = Linear(in_channels, growth_rate)
        self.linear_layers.append(self.node_encoder)

        for i in range(self.blocks):
            # Growth rate doubles with every block: k_i = 2**i * k0.
            self.growth_rate = int(2 ** i * self.base_rate)

            layers = torch.nn.ModuleList()
            for j in range(self.num_layers):
                # Layer j sees the block input plus the j feature maps that
                # the preceding dense layers concatenated onto it.
                conv = SAGEConv(in_channels + j * self.growth_rate, self.growth_rate, aggr='mean')
                norm = LayerNorm(self.growth_rate)
                act = ReLU()
                layers.append(DeepGCNLayer(conv, norm, act, block='dense'))
            self.block_layers.append(layers)

            # Transition: compress the concatenated block output by theta.
            hidden_channels = in_channels + self.num_layers * self.growth_rate
            out_channels = int(hidden_channels * self.theta)
            transition_norm = LayerNorm(hidden_channels, elementwise_affine=True)
            transition_act = ReLU()
            transition_lin = Linear(hidden_channels, out_channels)
            self.transition_layers.append(
                Transition_layer(transition_act, transition_norm, transition_lin))
            in_channels = out_channels

        self.lin_last = Linear(in_channels, self.out_channels)
        self.linear_layers.append(self.lin_last)

    def forward(self, x, edge_index):
        """Return per-node logits of width ``num_class``.

        Args:
            x: node feature matrix whose last dimension is ``in_channels``.
            edge_index: graph connectivity passed to every ``SAGEConv``.
        """
        for block, transition in zip(self.block_layers, self.transition_layers):
            for layer in block:
                x = layer(x, edge_index)
            x = transition(x)
        return self.lin_last(x)

In [83]:
import warnings
# action参数可以设置为ignore，一位一次也不喜爱你是，once表示为只显示一次
warnings.filterwarnings(action='ignore')

In [84]:
# Pre-built graph of the training split.
# NOTE(review): absolute, user-specific path; torch.load unpickles the file
# (trusted input only), and on torch >= 2.6 a pickled PyG Data object may
# need `weights_only=False` — confirm against the installed version.
train_data = torch.load(f='/home/xiaoyujie/densegat/botlot/bot_lot_train_data')

In [85]:
# Pre-built graph of the test split.
# NOTE(review): absolute, user-specific path; torch.load unpickles the file
# (trusted input only), and on torch >= 2.6 a pickled PyG Data object may
# need `weights_only=False` — confirm against the installed version.
test_data = torch.load(f='/home/xiaoyujie/densegat/botlot/bot_lot_test_data')

In [None]:
# Compute device: CUDA device index 1 (the second GPU).
device = 'cuda:1'

In [86]:
from sklearn.utils import class_weight

# "Balanced" class weights for the weighted cross-entropy in train():
# w_c = n_samples / (n_classes * count_c), so rare classes weigh more.
# `classes` and `y` are passed by keyword: they are keyword-only since
# scikit-learn 1.0, where the positional call raises a TypeError.
class_weights = class_weight.compute_class_weight(
    'balanced',
    classes=np.unique(train_data.y.numpy()),
    y=train_data.y.numpy(),
)
print(class_weights)
class_weights = torch.FloatTensor(class_weights).to(device)
def train():
    """Run one training epoch over ``train_loader``.

    Relies on the notebook globals ``model``, ``optimizer``, ``train_loader``,
    ``class_weights`` and ``device``. Loss and accuracy are computed only on
    the first ``batch.batch_size`` nodes of each mini-batch (the seed nodes of
    the NeighborLoader); the remaining nodes are sampled context.

    Returns:
        tuple: ``(mean weighted cross-entropy, accuracy)`` over all seed nodes
        of the epoch.
    """
    # Put the model in training mode here so the epoch does not depend on
    # whether an earlier cell left it in eval mode.
    model.train()
    total_loss = total_correct = total_examples = 0
    for batch in train_loader:
        batch = batch.to(device)  # moves x, edge_index and y together
        n_seed = batch.batch_size
        y = batch.y[:n_seed]
        optimizer.zero_grad()
        y_hat = model(batch.x, batch.edge_index)[:n_seed]
        loss = F.cross_entropy(y_hat, y, weight=class_weights)
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * n_seed
        total_correct += int((y_hat.argmax(dim=-1) == y).sum())
        total_examples += n_seed
    return total_loss / total_examples, total_correct / total_examples

def inferrence(model, subgraph_loader):
    """Predict the seed nodes of ``subgraph_loader`` and print metrics.

    Every mini-batch is moved to the notebook-global ``device`` and run in
    eval mode without autograd. Only the first ``batch.batch_size`` (seed)
    nodes of each batch are scored. A 4-digit classification report and the
    confusion matrix are printed. The model's previous train/eval mode is
    restored afterwards.

    Args:
        model: the trained network, already on ``device``.
        subgraph_loader: a PyG loader yielding batches with ``x``,
            ``edge_index``, ``y`` and ``batch_size``.

    Returns:
        tuple: ``(y_true, y_pred)`` as 1-D numpy arrays. Callers that ignored
        the former ``None`` return are unaffected.
    """
    was_training = model.training
    model.eval()
    logits, labels = [], []
    try:
        with torch.no_grad():
            for batch in tqdm(subgraph_loader, total=len(subgraph_loader)):
                # Move the whole batch. Previously only edge_index was moved,
                # leaving x on the CPU (device mismatch on GPU).
                batch = batch.to(device)
                n_seed = batch.batch_size
                logits.append(model(batch.x, batch.edge_index)[:n_seed].cpu())
                labels.append(batch.y[:n_seed].cpu())
    finally:
        model.train(was_training)
    y_pred = torch.cat(logits, 0).argmax(dim=-1).numpy()
    y_true = torch.cat(labels, 0).numpy()
    print(classification_report(y_true, y_pred, digits=4))
    print(confusion_matrix(y_true, y_pred))
    return y_true, y_pred

def get_metrics(y_true, y_pred):
    """Print a per-class classification report and the confusion matrix.

    Args:
        y_true: ground-truth class ids.
        y_pred: predicted class ids.

    The report is printed as a DataFrame with one row per class plus the
    'accuracy', 'macro avg' and 'weighted avg' rows. The confusion matrix is
    printed after it.
    """
    report = classification_report(y_true, y_pred, output_dict=True)
    report_df = pd.DataFrame(report).T
    matrix = confusion_matrix(y_true, y_pred)
    print(report_df)
    print(matrix)

[1.22151832e+02 5.83275000e-01 5.83275000e-01 6.40389762e-01
 7.40666667e+02]


In [87]:
# Quick look at both graphs (node/edge counts).
(train_data, test_data)

(Data(x=[233310, 10], edge_index=[2, 13380270], y=[233310], num_nodes=233310),
 Data(x=[58328, 10], edge_index=[2, 706918], y=[58328], num_nodes=58328))

In [9]:
# Optional binary experiment: collapse every non-zero class into class 1.
# New label tensors are built instead of writing through `.numpy()` views.
# The old version silently mutated the original `y` tensors in place, and
# failed if they lived on a GPU.
# NOTE(review): when this cell is used, `num_class` must be 2, and the class
# weights must be recomputed from these binary labels AFTER this cell. The
# class-weight cell above produces one weight per original class.
train_data.y = (train_data.y != 0).long()
test_data.y = (test_data.y != 0).long()

In [88]:
# Seed-node indices for the loaders: every node of each split (0..N-1).
train_num_nodes, test_num_nodes = (
    torch.arange(split.y.size(0)) for split in (train_data, test_data)
)

In [89]:
# Keep both full graphs on the CPU; mini-batches are moved to `device`
# inside train().
train_data, test_data = (split.to('cpu') for split in (train_data, test_data))

In [90]:
# Neighbour sampling: neighbours drawn per node per hop, and number of hops.
node_num_per_k, depth = 50, 10

In [91]:
# Fan-out list for NeighborLoader: one entry per hop.
# NOTE(review): 10 hops are sampled while the model stacks 50 message-passing
# layers; confirm this mismatch is intended.
hop = [node_num_per_k for _ in range(depth)]

In [92]:
# Mini-batches of 1024 shuffled seed nodes, each with its sampled `hop`
# neighbourhood.
# NOTE(review): newer PyG releases deprecate `directed` in favour of
# `subgraph_type` — confirm against the installed version.
train_loader = NeighborLoader(
    train_data,
    num_neighbors=hop,
    input_nodes=train_num_nodes,
    batch_size=1024,
    shuffle=True,
    directed=False,
)

In [107]:
in_channels = train_data.x.size(-1)
# Derive the number of classes from the labels instead of hard-coding 5:
# this yields 5 for the multi-class task and 2 after the optional label
# binarisation cell. It assumes labels are the contiguous ids
# 0..num_class-1, which F.cross_entropy requires anyway.
num_class = int(train_data.y.max()) + 1
model = DenseGAT(in_channels=in_channels, num_class=num_class,
                 num_layers=50, num_blocks=1, growth_rate=10, theta=0.8)
loss_all = []  # per-epoch training loss history

10


In [149]:
# Adam with a very small learning rate (1e-6).
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-6)

In [161]:
# Move the model to CUDA device 1 for training (same value as set earlier).
device = 'cuda:1'
model = model.to(device=device)

In [162]:
# Train for up to 500 epochs, logging loss/accuracy/time and recording the
# loss history in `loss_all`.
model.train()
epochs = 500
for epoch in range(1, epochs + 1):
    tic = time.time()
    loss, acc = train()
    toc = time.time()
    print(f'epoch: {epoch:04d}',
          f'loss_train: {loss:.4f}',
          f'acc_train: {acc:.4f}',
          f'time: {toc - tic:.4f}s')
    loss_all.append(loss)

epoch: 0001 loss_train: 0.0288 acc_train: 0.9865 time: 27.0821s
epoch: 0002 loss_train: 0.0289 acc_train: 0.9863 time: 27.2424s
epoch: 0003 loss_train: 0.0287 acc_train: 0.9865 time: 27.8561s
epoch: 0004 loss_train: 0.0288 acc_train: 0.9863 time: 27.0258s
epoch: 0005 loss_train: 0.0288 acc_train: 0.9865 time: 26.7564s
epoch: 0006 loss_train: 0.0287 acc_train: 0.9866 time: 27.6620s
epoch: 0007 loss_train: 0.0287 acc_train: 0.9866 time: 27.4486s
epoch: 0008 loss_train: 0.0287 acc_train: 0.9866 time: 28.4370s
epoch: 0009 loss_train: 0.0287 acc_train: 0.9864 time: 28.4846s
epoch: 0010 loss_train: 0.0287 acc_train: 0.9865 time: 27.5198s
epoch: 0011 loss_train: 0.0287 acc_train: 0.9865 time: 27.6686s
epoch: 0012 loss_train: 0.0287 acc_train: 0.9865 time: 27.2923s
epoch: 0013 loss_train: 0.0287 acc_train: 0.9863 time: 28.2074s
epoch: 0014 loss_train: 0.0287 acc_train: 0.9865 time: 27.4391s
epoch: 0015 loss_train: 0.0286 acc_train: 0.9867 time: 28.3599s
epoch: 0016 loss_train: 0.0287 acc_train

KeyboardInterrupt: 

In [164]:
# Move the model back to the CPU for full-graph test inference.
device = 'cpu'
model = model.to(device=device)

In [163]:
# Persist only the learned weights (state_dict), not the module object.
# NOTE(review): this cell ran before the move to the CPU (In[163] < In[164]),
# so tensors may be saved on cuda:1 — load with `map_location` if needed.
torch.save(obj=model.state_dict(), f='bot-lot-dgl')

In [165]:
# Full-graph inference on the test graph (no neighbour sampling).
# eval() + no_grad() mean no autograd graph is kept. Without them the deep
# dense network would hold every intermediate activation in memory.
# The previous train/eval mode is restored afterwards.
was_training = model.training
model.eval()
with torch.no_grad():
    logits = model(test_data.x.to(device), test_data.edge_index.to(device))
model.train(was_training)
y_hat = np.argmax(logits.cpu().numpy(), -1)  # predicted class per test node
test_y = test_data.y.numpy()

In [155]:
# NOTE(review): the saved output below comes from an earlier execution
# (non-sequential In[] count) and may not match the current y_hat.
get_metrics(y_true=test_y, y_pred=y_hat)

              precision    recall  f1-score       support
0              0.894737  0.894737  0.894737     95.000000
1              0.983730  0.970450  0.977045  20000.000000
2              0.967128  0.985600  0.976277  20000.000000
3              0.998509  0.992480  0.995485  18217.000000
4              1.000000  0.875000  0.933333     16.000000
accuracy       0.982376  0.982376  0.982376      0.982376
macro avg      0.968821  0.943653  0.955375  58328.000000
weighted avg   0.982513  0.982376  0.982395  58328.000000
[[   85     0     3     7     0]
 [    0 19409   576    15     0]
 [    0   285 19712     3     0]
 [   10    36    91 18080     0]
 [    0     0     0     2    14]]


In [160]:
# NOTE(review): the saved output below comes from an earlier execution
# (non-sequential In[] count) and may not match the current y_hat.
get_metrics(y_true=test_y, y_pred=y_hat)

              precision    recall  f1-score       support
0              0.894737  0.894737  0.894737     95.000000
1              0.984521  0.969950  0.977181  20000.000000
2              0.966642  0.986700  0.976568  20000.000000
3              0.998564  0.992150  0.995347  18217.000000
4              1.000000  0.875000  0.933333     16.000000
accuracy       0.982478  0.982478  0.982478      0.982478
macro avg      0.968893  0.943707  0.955433  58328.000000
weighted avg   0.982634  0.982478  0.982498  58328.000000
[[   85     0     3     7     0]
 [    0 19399   587    14     0]
 [    0   263 19734     3     0]
 [   10    42    91 18074     0]
 [    0     0     0     2    14]]


In [141]:
# NOTE(review): the saved output below comes from an earlier execution
# (non-sequential In[] count) and may not match the current y_hat.
get_metrics(y_true=test_y, y_pred=y_hat)

              precision    recall  f1-score       support
0              0.867347  0.894737  0.880829     95.000000
1              0.980753  0.973250  0.976987  20000.000000
2              0.970321  0.982450  0.976348  20000.000000
3              0.997792  0.992425  0.995101  18217.000000
4              1.000000  0.875000  0.933333     16.000000
accuracy       0.982238  0.982238  0.982238      0.982238
macro avg      0.963243  0.943572  0.952520  58328.000000
weighted avg   0.982318  0.982238  0.982257  58328.000000
[[   85     0     3     7     0]
 [    0 19465   507    28     0]
 [    0   348 19649     3     0]
 [   13    34    91 18079     0]
 [    0     0     0     2    14]]


In [126]:
# NOTE(review): the saved output below comes from an earlier execution
# (non-sequential In[] count) and may not match the current y_hat.
get_metrics(y_true=test_y, y_pred=y_hat)

              precision    recall  f1-score       support
0              0.865979  0.884211  0.875000     95.000000
1              0.983824  0.957900  0.970689  20000.000000
2              0.957373  0.988200  0.972542  20000.000000
3              0.994309  0.987923  0.991106  18217.000000
4              1.000000  0.875000  0.933333     16.000000
accuracy       0.977524  0.977524  0.977524      0.977524
macro avg      0.960297  0.938647  0.948534  58328.000000
weighted avg   0.977841  0.977524  0.977535  58328.000000
[[   84     0     3     8     0]
 [    0 19158   770    72     0]
 [    0   215 19764    21     0]
 [   13   100   107 17997     0]
 [    0     0     0     2    14]]


In [113]:
# NOTE(review): the saved output below comes from an earlier execution
# (non-sequential In[] count) and may not match the current y_hat.
get_metrics(y_true=test_y, y_pred=y_hat)

              precision    recall  f1-score       support
0              0.876289  0.894737  0.885417     95.000000
1              0.982945  0.950950  0.966683  20000.000000
2              0.953314  0.990350  0.971479  20000.000000
3              0.991266  0.984410  0.987826  18217.000000
4              1.000000  0.875000  0.933333     16.000000
accuracy       0.974798  0.974798  0.974798      0.974798
macro avg      0.960763  0.939089  0.948948  58328.000000
weighted avg   0.975215  0.974798  0.974789  58328.000000
[[   85     0     3     7     0]
 [    0 19019   874   107     0]
 [    0   151 19807    42     0]
 [   12   179    93 17933     0]
 [    0     0     0     2    14]]


In [100]:
# NOTE(review): the saved output below comes from an earlier execution
# (non-sequential In[] count) and may not match the current y_hat.
get_metrics(y_true=test_y, y_pred=y_hat)

              precision    recall  f1-score       support
0              0.906250  0.915789  0.910995     95.000000
1              0.987116  0.953850  0.970198  20000.000000
2              0.955473  0.990300  0.972575  20000.000000
3              0.992787  0.989790  0.991286  18217.000000
4              0.933333  0.875000  0.903226     16.000000
accuracy       0.977489  0.977489  0.977489      0.977489
macro avg      0.954992  0.944946  0.949656  58328.000000
weighted avg   0.977891  0.977489  0.977484  58328.000000
[[   87     0     0     8     0]
 [    0 19077   828    95     0]
 [    0   168 19806    26     0]
 [    9    81    95 18031     1]
 [    0     0     0     2    14]]


In [69]:
# NOTE(review): the saved output below reports 87492 test nodes, not the
# 58328 of the test graph loaded above — it comes from a different split.
get_metrics(y_true=test_y, y_pred=y_hat)

              precision    recall  f1-score       support
0              0.864865  0.895105  0.879725    143.000000
1              0.887375  0.963867  0.924041  30000.000000
2              0.957683  0.992000  0.974540  30000.000000
3              0.995858  0.862324  0.924293  27325.000000
4              0.954545  0.875000  0.913043     24.000000
accuracy       0.941663  0.941663  0.941663      0.941663
macro avg      0.932065  0.917659  0.923128  87492.000000
weighted avg   0.945345  0.941663  0.941359  87492.000000
[[  128     0     1    13     1]
 [    0 28916  1004    80     0]
 [    0   237 29760     3     0]
 [   20  3433   309 23563     0]
 [    0     0     1     2    21]]


In [39]:
# NOTE(review): the saved output below reports 87492 test nodes, not the
# 58328 of the test graph loaded above — it comes from a different split.
get_metrics(y_true=test_y, y_pred=y_hat)

              precision    recall  f1-score       support
0              0.766667  0.965035  0.854489    143.000000
1              0.884249  0.961267  0.921151  30000.000000
2              0.958572  0.988767  0.973435  30000.000000
3              0.997175  0.865361  0.926604  27325.000000
4              0.536585  0.916667  0.676923     24.000000
accuracy       0.940737  0.940737  0.940737      0.940737
macro avg      0.828649  0.939419  0.870520  87492.000000
weighted avg   0.944714  0.940737  0.940605  87492.000000
[[  138     0     1     4     0]
 [    1 28838  1102    59     0]
 [    2   329 29663     3     3]
 [   38  3446   179 23646    16]
 [    1     0     0     1    22]]


In [24]:
# NOTE(review): the saved output below reports 87492 test nodes, not the
# 58328 of the test graph loaded above — it comes from a different split.
get_metrics(y_true=test_y, y_pred=y_hat)

              precision    recall  f1-score       support
0              0.859873  0.944056  0.900000    143.000000
1              0.893402  0.968567  0.929467  30000.000000
2              0.925063  0.980567  0.952006  30000.000000
3              0.997912  0.839488  0.911870  27325.000000
4              0.875000  0.875000  0.875000     24.000000
accuracy       0.932302  0.932302  0.932302      0.932302
macro avg      0.910250  0.921535  0.913669  87492.000000
weighted avg   0.936838  0.932302  0.931637  87492.000000
[[  135     1     1     6     0]
 [    0 29057   908    35     0]
 [    0   578 29417     5     0]
 [   22  2888  1473 22939     3]
 [    0     0     1     2    21]]


In [95]:
# NOTE(review): the saved output below has only 2 classes — it was produced
# with the label-binarisation cell applied, not by the 5-class model.
get_metrics(y_true=test_y, y_pred=y_hat)

              precision    recall  f1-score       support
0              0.907216  0.926316  0.916667     95.000000
1              0.999880  0.999845  0.999863  58233.000000
accuracy       0.999726  0.999726  0.999726      0.999726
macro avg      0.953548  0.963081  0.958265  58328.000000
weighted avg   0.999729  0.999726  0.999727  58328.000000
[[   88     7]
 [    9 58224]]


In [85]:
# NOTE(review): the saved output below has only 2 classes — it was produced
# with the label-binarisation cell applied, not by the 5-class model.
get_metrics(y_true=test_y, y_pred=y_hat)

              precision    recall  f1-score       support
0              0.805556  0.915789  0.857143     95.000000
1              0.999863  0.999639  0.999751  58233.000000
accuracy       0.999503  0.999503  0.999503      0.999503
macro avg      0.902709  0.957714  0.928447  58328.000000
weighted avg   0.999546  0.999503  0.999519  58328.000000
[[   87     8]
 [   21 58212]]


In [80]:
# NOTE(review): the saved output below has only 2 classes — it was produced
# with the label-binarisation cell applied, not by the 5-class model.
get_metrics(y_true=test_y, y_pred=y_hat)

              precision    recall  f1-score       support
0              0.798165  0.915789  0.852941     95.000000
1              0.999863  0.999622  0.999742  58233.000000
accuracy       0.999486  0.999486  0.999486      0.999486
macro avg      0.899014  0.957706  0.926342  58328.000000
weighted avg   0.999534  0.999486  0.999503  58328.000000
[[   87     8]
 [   22 58211]]


In [56]:
# NOTE(review): the saved output below comes from an earlier execution
# (non-sequential In[] count) and may not match the current y_hat.
get_metrics(y_true=test_y, y_pred=y_hat)

              precision    recall  f1-score       support
0              0.682171  0.926316  0.785714     95.000000
1              0.972191  0.957900  0.964993  20000.000000
2              0.956580  0.991400  0.973679  20000.000000
3              0.998197  0.972279  0.985067  18217.000000
4              0.666667  0.875000  0.756757     16.000000
accuracy       0.973803  0.973803  0.973803      0.973803
macro avg      0.855161  0.944579  0.893242  58328.000000
weighted avg   0.974404  0.973803  0.973892  58328.000000
[[   88     0     1     5     1]
 [    7 19158   815    20     0]
 [    1   166 19828     5     0]
 [   33   382    84 17712     6]
 [    0     0     0     2    14]]


In [35]:
# NOTE(review): the saved output below comes from an earlier execution
# (non-sequential In[] count) and may not match the current y_hat.
get_metrics(y_true=test_y, y_pred=y_hat)

              precision    recall  f1-score       support
0              0.769912  0.915789  0.836538     95.000000
1              0.985939  0.960650  0.973130  20000.000000
2              0.957549  0.991350  0.974156  20000.000000
3              0.998388  0.986222  0.992268  18217.000000
4              0.518519  0.875000  0.651163     16.000000
accuracy       0.979067  0.979067  0.979067      0.979067
macro avg      0.846061  0.945802  0.885451  58328.000000
weighted avg   0.979612  0.979067  0.979148  58328.000000
[[   87     0     2     6     0]
 [    1 19213   768    18     0]
 [    0   169 19827     4     0]
 [   24   105   109 17966    13]
 [    1     0     0     1    14]]


In [30]:
# NOTE(review): the saved output below comes from an earlier execution
# (non-sequential In[] count) and may not match the current y_hat.
get_metrics(y_true=test_y, y_pred=y_hat)

              precision    recall  f1-score       support
0              0.774775  0.905263  0.834951     95.000000
1              0.979166  0.961100  0.970049  20000.000000
2              0.957923  0.988050  0.972753  20000.000000
3              0.998272  0.983148  0.990652  18217.000000
4              0.875000  0.875000  0.875000     16.000000
accuracy       0.977112  0.977112  0.977112      0.977112
macro avg      0.917027  0.942512  0.928681  58328.000000
weighted avg   0.977488  0.977112  0.977165  58328.000000
[[   86     0     2     7     0]
 [    1 19222   758    19     0]
 [    0   235 19761     4     0]
 [   23   174   108 17910     2]
 [    1     0     0     1    14]]


In [24]:
# NOTE(review): the saved output below comes from an earlier execution
# (non-sequential In[] count) and may not match the current y_hat.
get_metrics(y_true=test_y, y_pred=y_hat)

              precision    recall  f1-score       support
0              0.765217  0.926316  0.838095     95.000000
1              0.969307  0.964800  0.967048  20000.000000
2              0.965089  0.971700  0.968383  20000.000000
3              0.994321  0.989954  0.992133  18217.000000
4              0.437500  0.875000  0.583333     16.000000
accuracy       0.974935  0.974935  0.974935      0.974935
macro avg      0.826287  0.945554  0.869799  58328.000000
weighted avg   0.975195  0.974935  0.975025  58328.000000
[[   88     0     1     5     1]
 [    1 19296   613    90     0]
 [    0   559 19434     7     0]
 [   25    52    89 18034    17]
 [    1     0     0     1    14]]


In [62]:
# NOTE(review): the saved output below comes from an earlier execution
# (non-sequential In[] count) and may not match the current y_hat.
get_metrics(y_true=test_y, y_pred=y_hat)

              precision    recall  f1-score       support
0              0.826923  0.905263  0.864322     95.000000
1              0.793031  0.984350  0.878394  20000.000000
2              0.973565  0.931750  0.952198  20000.000000
3              0.998385  0.780644  0.876190  18217.000000
4              1.000000  0.875000  0.933333     16.000000
accuracy       0.902534  0.902534  0.902534      0.902534
macro avg      0.918381  0.895402  0.900887  58328.000000
weighted avg   0.919182  0.902534  0.903004  58328.000000
[[   86     0     2     7     0]
 [    0 19687   306     7     0]
 [    0  1358 18635     7     0]
 [   18  3780   198 14221     0]
 [    0     0     0     2    14]]


In [54]:
# NOTE(review): the saved output below comes from an earlier execution
# (non-sequential In[] count) and may not match the current y_hat.
get_metrics(y_true=test_y, y_pred=y_hat)

              precision    recall  f1-score       support
0              0.860000  0.905263  0.882051     95.000000
1              0.932349  0.978500  0.954867  20000.000000
2              0.974677  0.943000  0.958577  20000.000000
3              0.994461  0.975737  0.985010  18217.000000
4              1.000000  0.875000  0.933333     16.000000
accuracy       0.965317  0.965317  0.965317      0.965317
macro avg      0.952297  0.935500  0.942768  58328.000000
weighted avg   0.966162  0.965317  0.965429  58328.000000
[[   86     0     2     7     0]
 [    0 19570   358    72     0]
 [    0  1122 18860    18     0]
 [   14   298   130 17775     0]
 [    0     0     0     2    14]]


In [46]:
# NOTE(review): the saved output below comes from an earlier execution
# (non-sequential In[] count) and may not match the current y_hat.
get_metrics(y_true=test_y, y_pred=y_hat)

              precision    recall  f1-score       support
0              0.775701  0.873684  0.821782     95.000000
1              0.951279  0.961600  0.956411  20000.000000
2              0.946963  0.951650  0.949300  20000.000000
3              0.987590  0.969808  0.978619  18217.000000
4              0.875000  0.875000  0.875000     16.000000
accuracy       0.960585  0.960585  0.960585      0.960585
macro avg      0.907306  0.926349  0.916223  58328.000000
weighted avg   0.960833  0.960585  0.960667  58328.000000
[[   83     0     4     8     0]
 [    0 19232   625   143     0]
 [    0   897 19033    70     0]
 [   23    88   437 17667     2]
 [    1     0     0     1    14]]


In [29]:
# NOTE(review): the saved output below has only 2 classes — it was produced
# with the label-binarisation cell applied, not by the 5-class model.
get_metrics(y_true=test_y, y_pred=y_hat)

              precision    recall  f1-score       support
0              0.696000  0.915789  0.790909     95.000000
1              0.999863  0.999347  0.999605  58233.000000
accuracy       0.999211  0.999211  0.999211      0.999211
macro avg      0.847931  0.957568  0.895257  58328.000000
weighted avg   0.999368  0.999211  0.999265  58328.000000
[[   87     8]
 [   38 58195]]


In [21]:
# NOTE(review): the saved output below has only 2 classes — it was produced
# with the label-binarisation cell applied, not by the 5-class model.
get_metrics(y_true=test_y, y_pred=y_hat)

              precision    recall  f1-score       support
0              0.724638  0.526316  0.609756     95.000000
1              0.999228  0.999674  0.999451  58233.000000
accuracy       0.998903  0.998903  0.998903      0.998903
macro avg      0.861933  0.762995  0.804603  58328.000000
weighted avg   0.998780  0.998903  0.998816  58328.000000
[[   50    45]
 [   19 58214]]
