In [1]:
import os
import pickle
import time
import pandas as pd
import progressbar
from tqdm import tqdm
import torch.nn as nn
import argparse
import numpy as np
import random
import torch
from sklearn.metrics import f1_score, classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from torch_geometric.data import Data
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [46]:
random.seed(1)
np.random.seed(1)
device = 'cuda:0'
data_class = {"UNSW-NB15": 10,
              "Darknet": 9,
              "CSE-CIC": 7,
              "ToN-IoT": 10}
data_lr = {"UNSW-NB15": 0.007,
           "Darknet": 0.003,
           "CSE-CIC": 0.003,
           "ToN-IoT": 0.01}
test_size = {"UNSW-NB15": 210000,
             "Darknet": 45000,
             "CSE-CIC": 75000,
             "ToN-IoT": 140000}
from torch_geometric.nn import GATv2Conv, GATConv, ResGatedGraphConv,ClusterGCNConv


In [49]:
def load_gat(path, binary):
    # feature
    edge_feat = np.load(path + "edge_feat_scaled.npy")  # （n,f)
    edge_feat = torch.tensor(edge_feat, dtype=torch.float, device=device)

    # label
    if binary:
        label = np.load(path + "label_bi.npy", allow_pickle=True)  # (n,1)
    else:
        label = np.load(path + "label_mul.npy", allow_pickle=True)
    label = torch.tensor(label, dtype=torch.long, device=device)  # Cross entropy expects a long int

    # adjacency
    adj = np.load(path + "adj_random.npy", allow_pickle=True)
    with open(path + 'adj_random_list.dict', 'rb') as file:
        adj_lists = pickle.load(file)

    return edge_feat, label, adj, adj_lists


class Dataset:
    def __init__(self, edge_feat, adj, adj_lists):
        self.device = device
        self.edge_feat = edge_feat
        self.adj = adj
        self.adj_lists = adj_lists

    def get_data(self, edge_idx):
        source_nodes_ids, target_nodes_ids = [], []
        seen_edges = set()
        edges = set(edge_idx)
        edges_neigh = set(edge_idx)
        k = 2  # k-hop
        for i in range(k):
            source_nodes_ids, target_nodes_ids, seen_edges, edges_neigh = self.build_edge_index(source_nodes_ids,
                                                                                                target_nodes_ids,
                                                                                                seen_edges,
                                                                                                edges_neigh)
            edges = edges.union(edges_neigh)

        # Step 2: Construct Batch in_nodes_features
        in_nodes_features = self.edge_feat[list(edges)]
        # Step 3: Construct new mapping; unique_map converts the edges to (0, len(edges))
        unique_map = {}
        for idx, edge in enumerate(edges):
            unique_map[edge] = idx
        source_nodes_ids = [unique_map[ids] for ids in source_nodes_ids]
        target_nodes_ids = [unique_map[ids] for ids in target_nodes_ids]
        # Step 4: Build edge_index; shape = (2, E), where E is the number of edges in the graph
        edge_index = np.row_stack((source_nodes_ids, target_nodes_ids))
        edge_index = torch.tensor(edge_index, dtype=torch.int64, device=self.device)

        # Step 5: Mapped edge_idx
        map_edge_idx = [unique_map[ids] for ids in edge_idx]
        map_edge_idx = torch.tensor(map_edge_idx, dtype=torch.int64, device=self.device)
        # 如果用残差 就需要这一步
        # data = (in_nodes_features, in_nodes_features, edge_index, map_edge_idx) # 第一个是用于残差使用的，第二个是原始的节点特征
        # 没有残差返回这个
        data = (in_nodes_features, edge_index, map_edge_idx)
        return data

    def build_edge_index(self, source_nodes_ids, target_nodes_ids, seen_edges, edges_neigh):
        new_neigh = set()
        for edge in edges_neigh:
            nodes = self.adj[edge]
            for node in nodes:
                neigh = self.adj_lists.get(node)
                new_neigh = new_neigh.union(neigh)
                for edge_neigh in neigh:
                    if (edge, edge_neigh) not in seen_edges and \
                            (edge_neigh, edge) not in seen_edges:
                        # and \
                        # edge != edge_neigh:
                        source_nodes_ids.append(edge)
                        target_nodes_ids.append(edge_neigh)
                        seen_edges.add((edge, edge_neigh))

        return source_nodes_ids, target_nodes_ids, seen_edges, new_neigh


class MyGNN(nn.Module):
    def __init__(self, in_features, out_features):
        super(MyGNN, self).__init__()
        self.hidden_size1 = 50
        self.hidden_size2 = 30
        self.hidden_size3 = 20
        self.hidden_size4 = 10
        self.conv1 = ClusterGCNConv(in_features, self.hidden_size1)
        self.conv2 = ClusterGCNConv(self.hidden_size1, self.hidden_size2)
        self.conv3 = ClusterGCNConv(self.hidden_size2, self.hidden_size3)
        self.conv4 = ClusterGCNConv(self.hidden_size3, self.hidden_size4)
        self.conv5 = ClusterGCNConv(self.hidden_size4, out_features)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        z = self.conv1(x, edge_index)
        z = self.conv2(z, edge_index)
        z = self.conv3(z, edge_index)
        z = self.conv4(z, edge_index)
        z = self.conv5(z, edge_index)
        output = F.log_softmax(z, 1)
        return output


class GNNTrain:
    def __init__(self, model, train_idx, val_idx, dataset, label, epochs):
        self.loss_fn = nn.CrossEntropyLoss().to(device)
        self.optimizer = torch.optim.Adam(model.parameters(), lr=0.003)
        self.model = model
        self.train_idx = train_idx
        self.val_idx = val_idx
        self.dataset = dataset
        self.epochs = epochs
        self.label = label

    def train(self):
        for epoch in range(self.epochs):
            random.shuffle(self.train_idx)
            print("epoch: ", epoch)
            for batch in range(int(len(self.train_idx) / 500)):
                start_time = time.time()
                batch_edges = self.train_idx[500 * batch:500 * (batch + 1)]  # 500 records per batch
                x, edge_index, map_edge_idx = self.dataset.get_data(batch_edges)
                data = Data(x=x, edge_index=edge_index, y=self.label[batch_edges]).to(device)
                self.model.train()
                output = self.model(data)
                train_output = output.index_select(0, map_edge_idx)  # 选择idx的行
                train_output_ = train_output.cpu()
                loss = self.loss_fn(train_output, self.label[batch_edges])
                self.optimizer.zero_grad()  # 权重清零
                loss.backward()
                self.optimizer.step()
                end_time = time.time()
                acc_train = f1_score(data.y.cpu(), torch.argmax(train_output_, dim=-1), average="weighted")

                print('batch: {:03d}'.format(batch + 1),
                      'loss_train: {:.4f}'.format(loss.item()),
                      'acc_train: {:.4f}'.format(acc_train.item()),
                      'time: {:.4f}s'.format(end_time - start_time))

                if batch >= 179:
                    break
                torch.save(self.model.state_dict(), './ClusterGCNConv.pth' + f'{epoch}' + '.pth')

    def evaluate(self):
        batch_size = 500
        predict_output = []
        model_path = './ResGatedGraphConv2.pth'
        if not os.path.exists(model_path):
            raise RuntimeError('没有已保存的模型')
        model_dict = torch.load(model_path)
        self.model.load_state_dict(model_dict)
        allloss = []
        for batch in tqdm(range(int(len(self.val_idx) / batch_size))):
            batch_edges = self.val_idx[batch_size * batch:batch_size * (batch + 1)]
            x, edge_index, map_edge_idx = self.dataset.get_data(batch_edges)
            data = Data(x=x, edge_index=edge_index, y=self.label[batch_edges]).to(device)
            output = self.model(data)
            batch_output = output.index_select(0, map_edge_idx)  # 选择idx的行
            batch_loss = self.loss_fn(batch_output, self.label[batch_edges])
            allloss.append(batch_loss.item())
            batch_output = torch.argmax(batch_output, dim=-1)
            predict_output.extend(batch_output.cpu())
        allloss = np.array(allloss)
        print(classification_report(self.label[self.val_idx].cpu(), predict_output, digits=4))
        print(confusion_matrix(self.label[self.val_idx].cpu(), predict_output))
        return predict_output, self.label[self.val_idx].cpu()
        
    


In [50]:
data_name = 'CSE-CIC'
path = "datasets/" + data_name + "/"
num_class = data_class[data_name]
binary = False
# ------------------------------------------------------------------------
edge_feat, label, adj, adj_lists = load_gat(path, binary)
in_features = len(edge_feat[-1])
num_edges = len(edge_feat)
# 分割数据集
alltrainidx = np.arange(num_edges)
label2 = np.array(label.cpu())
train_val, test = train_test_split(alltrainidx, test_size=test_size[data_name], stratify=label2)

# 加载模型
model = MyGNN(in_features, num_class).to(device)
dataset = Dataset(edge_feat=edge_feat, adj=adj, adj_lists=adj_lists)
# 开始训练
gnn = GNNTrain(model=model,
               train_idx=train_val, val_idx=test,
               dataset=dataset, label=label,
               epochs=3)
gnn.train()
# 测试
# pred, label = gnn.evaluate()

epoch:  0
batch: 001 loss_train: 5.9969 acc_train: 0.0856 time: 5.9186s
batch: 002 loss_train: 4.0600 acc_train: 0.2980 time: 6.1450s
batch: 003 loss_train: 3.4513 acc_train: 0.4180 time: 5.9042s
batch: 004 loss_train: 2.2568 acc_train: 0.7147 time: 6.5370s
batch: 005 loss_train: 2.3909 acc_train: 0.7212 time: 5.4628s
batch: 006 loss_train: 1.8252 acc_train: 0.6933 time: 5.6928s
batch: 007 loss_train: 1.9634 acc_train: 0.7210 time: 6.1728s
batch: 008 loss_train: 1.3968 acc_train: 0.7502 time: 5.9159s
batch: 009 loss_train: 1.2552 acc_train: 0.7806 time: 5.9816s
batch: 010 loss_train: 1.0715 acc_train: 0.8080 time: 6.0314s
batch: 011 loss_train: 1.0083 acc_train: 0.7990 time: 5.8860s
batch: 012 loss_train: 1.5124 acc_train: 0.8073 time: 5.8352s
batch: 013 loss_train: 0.8878 acc_train: 0.8762 time: 5.7954s
batch: 014 loss_train: 0.6298 acc_train: 0.8768 time: 5.9676s
batch: 015 loss_train: 0.7318 acc_train: 0.8959 time: 6.4645s
batch: 016 loss_train: 0.5743 acc_train: 0.9039 time: 6.3321

batch: 133 loss_train: 0.3024 acc_train: 0.9111 time: 5.8860s
batch: 134 loss_train: 0.3002 acc_train: 0.9114 time: 6.5412s
batch: 135 loss_train: 0.2373 acc_train: 0.9330 time: 6.2365s
batch: 136 loss_train: 0.2983 acc_train: 0.9136 time: 6.0652s
batch: 137 loss_train: 0.2992 acc_train: 0.9203 time: 6.6995s
batch: 138 loss_train: 0.2638 acc_train: 0.9099 time: 6.3122s
batch: 139 loss_train: 0.2999 acc_train: 0.9268 time: 6.9365s
batch: 140 loss_train: 0.3088 acc_train: 0.9205 time: 6.6378s
batch: 141 loss_train: 0.4022 acc_train: 0.9121 time: 6.1180s
batch: 142 loss_train: 0.3662 acc_train: 0.9211 time: 6.2624s
batch: 143 loss_train: 0.4377 acc_train: 0.8930 time: 6.7155s
batch: 144 loss_train: 0.3350 acc_train: 0.9263 time: 5.9577s
batch: 145 loss_train: 0.3828 acc_train: 0.9078 time: 5.8601s
batch: 146 loss_train: 0.2930 acc_train: 0.8993 time: 5.7894s
batch: 147 loss_train: 0.3321 acc_train: 0.9030 time: 5.7466s
batch: 148 loss_train: 0.2647 acc_train: 0.9232 time: 6.0344s
batch: 1

batch: 085 loss_train: 0.2171 acc_train: 0.9337 time: 6.6189s
batch: 086 loss_train: 0.3174 acc_train: 0.9282 time: 5.8591s
batch: 087 loss_train: 0.2344 acc_train: 0.9292 time: 5.8511s
batch: 088 loss_train: 0.2833 acc_train: 0.9317 time: 6.1280s
batch: 089 loss_train: 0.2383 acc_train: 0.9096 time: 5.5225s
batch: 090 loss_train: 0.2520 acc_train: 0.9349 time: 5.9866s
batch: 091 loss_train: 0.3129 acc_train: 0.9255 time: 5.7137s
batch: 092 loss_train: 0.2616 acc_train: 0.9455 time: 5.3244s
batch: 093 loss_train: 0.2826 acc_train: 0.9142 time: 5.9617s
batch: 094 loss_train: 0.2761 acc_train: 0.9119 time: 5.9329s
batch: 095 loss_train: 0.2641 acc_train: 0.9136 time: 5.6370s
batch: 096 loss_train: 0.2671 acc_train: 0.9144 time: 5.7954s
batch: 097 loss_train: 0.3345 acc_train: 0.8839 time: 5.8999s
batch: 098 loss_train: 0.3346 acc_train: 0.8979 time: 5.5544s
batch: 099 loss_train: 0.2561 acc_train: 0.9334 time: 5.6918s
batch: 100 loss_train: 0.2716 acc_train: 0.9127 time: 5.5066s
batch: 1

batch: 037 loss_train: 0.2831 acc_train: 0.9212 time: 5.5046s
batch: 038 loss_train: 0.2493 acc_train: 0.9311 time: 6.2255s
batch: 039 loss_train: 0.2646 acc_train: 0.9125 time: 5.6619s
batch: 040 loss_train: 0.2423 acc_train: 0.9250 time: 6.0802s
batch: 041 loss_train: 0.3434 acc_train: 0.8806 time: 5.7605s
batch: 042 loss_train: 0.3273 acc_train: 0.9074 time: 5.7217s
batch: 043 loss_train: 0.2725 acc_train: 0.9135 time: 5.5893s
batch: 044 loss_train: 0.1934 acc_train: 0.9417 time: 5.6500s
batch: 045 loss_train: 0.3050 acc_train: 0.9214 time: 5.6679s
batch: 046 loss_train: 0.2484 acc_train: 0.9068 time: 6.0154s
batch: 047 loss_train: 0.2392 acc_train: 0.9126 time: 5.8482s
batch: 048 loss_train: 0.2309 acc_train: 0.9280 time: 5.7834s
batch: 049 loss_train: 0.3672 acc_train: 0.8861 time: 6.0981s
batch: 050 loss_train: 0.2207 acc_train: 0.9174 time: 5.7157s
batch: 051 loss_train: 0.3136 acc_train: 0.9064 time: 5.7287s
batch: 052 loss_train: 0.2660 acc_train: 0.9057 time: 5.9846s
batch: 0

batch: 170 loss_train: 0.2482 acc_train: 0.9231 time: 6.0961s
batch: 171 loss_train: 0.2036 acc_train: 0.9199 time: 5.9507s
batch: 172 loss_train: 0.3164 acc_train: 0.9088 time: 5.6809s
batch: 173 loss_train: 0.3485 acc_train: 0.8998 time: 5.9796s
batch: 174 loss_train: 0.2192 acc_train: 0.9304 time: 6.7493s
batch: 175 loss_train: 0.3677 acc_train: 0.9082 time: 6.7010s
batch: 176 loss_train: 0.4205 acc_train: 0.8883 time: 6.0411s
batch: 177 loss_train: 0.2960 acc_train: 0.9160 time: 7.2392s
batch: 178 loss_train: 0.3818 acc_train: 0.8867 time: 6.0847s
batch: 179 loss_train: 0.3023 acc_train: 0.9052 time: 6.3301s
batch: 180 loss_train: 0.3046 acc_train: 0.9017 time: 6.4292s


In [44]:
pred, label = gnn.evaluate()

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 150/150 [15:06<00:00,  6.04s/it]


              precision    recall  f1-score   support

           0     0.9597    0.9935    0.9763     52898
           1     0.9891    0.9952    0.9921      5460
           2     0.9451    0.9515    0.9483       742
           3     0.9908    0.9988    0.9948      9798
           4     0.8774    0.3357    0.4856       277
           5     0.9489    0.9261    0.9374      4114
           6     0.0000    0.0000    0.0000      1711

    accuracy                         0.9651     75000
   macro avg     0.8159    0.7430    0.7621     75000
weighted avg     0.9430    0.9651    0.9534     75000

[[52554    54    31    84    11   162     2]
 [   26  5434     0     0     0     0     0]
 [   33     3   706     0     0     0     0]
 [   10     1     1  9786     0     0     0]
 [  148     0     0     1    93    35     0]
 [  297     1     0     4     2  3810     0]
 [ 1691     1     9     2     0     8     0]]


In [33]:
pred, label = gnn.evaluate()

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 150/150 [14:19<00:00,  5.73s/it]


              precision    recall  f1-score   support

           0     0.9294    0.9807    0.9543     52898
           1     0.9454    0.9989    0.9714      5460
           2     0.9282    0.9232    0.9257       742
           3     0.9525    0.9960    0.9738      9798
           4     0.7025    0.4007    0.5103       277
           5     0.8988    0.4925    0.6363      4114
           6     0.0000    0.0000    0.0000      1711

    accuracy                         0.9321     75000
   macro avg     0.7653    0.6846    0.7103     75000
weighted avg     0.9098    0.9321    0.9170     75000

[[51875   270    40   451    47   197    18]
 [    5  5454     1     0     0     0     0]
 [   21    36   685     0     0     0     0]
 [   35     4     0  9759     0     0     0]
 [  163     2     1     0   111     0     0]
 [ 2085     0     0     3     0  2026     0]
 [ 1633     3    11    33     0    31     0]]
