## GNN4ID Heterogeneous Graph Model

This notebook shows how to use the heterogeneous graph models we developed. There are two architectures:

1. **Model without Edge Attributes**: edges only encode which nodes are connected, so the model learns from the structural relationships in the graph alone.
2. **Model with Edge Attributes**: edges also carry their own attributes (features) in addition to the connection information. The model can use this extra edge-level information, which may improve its performance.
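The difference between the two architectures can be illustrated with a single mean-aggregation step (the models below are built with `aggr="mean"`). This is a minimal plain-Python sketch, not the code of `HeteroGNN` or `HeteroGNN_Edge`. Appending the edge's features to the message is just one common way to use edge attributes and is an assumption here.

```python
def mean_aggregate(x_src, edge_index, num_dst, edge_attr=None):
    """One mean-aggregation step from source nodes into destination nodes.

    x_src:      list of feature vectors, one per source node
    edge_index: list of (src, dst) pairs
    edge_attr:  optional list of per-edge feature vectors (same order as edge_index)
    Illustrative sketch only: with edge_attr, each message is the source
    features concatenated with the edge's own features (an assumed design).
    """
    msg_dim = len(x_src[0]) + (len(edge_attr[0]) if edge_attr is not None else 0)
    sums = [[0.0] * msg_dim for _ in range(num_dst)]
    counts = [0] * num_dst
    for k, (s, d) in enumerate(edge_index):
        msg = list(x_src[s])              # without edge attributes: structure only
        if edge_attr is not None:
            msg += list(edge_attr[k])     # with edge attributes: add edge features
        sums[d] = [a + b for a, b in zip(sums[d], msg)]
        counts[d] += 1
    # Destination nodes with no incoming edges keep an all-zero vector
    return [[v / c for v in row] if c else row for row, c in zip(sums, counts)]


x = [[1.0, 0.0], [3.0, 2.0]]   # two source nodes
edges = [(0, 0), (1, 0)]       # both point to destination node 0
print(mean_aggregate(x, edges, num_dst=1))                            # [[2.0, 1.0]]
print(mean_aggregate(x, edges, num_dst=1, edge_attr=[[10.0], [20.0]])) # [[2.0, 1.0, 15.0]]
```

Without edge attributes both edges contribute only their source node's features; with them, the destination also sees the average of the edge features (15.0 here).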

In [1]:
import sys
sys.path.append("G:/")  # folder that contains the Utility package imported below

In [2]:
from Utility.Functions import *
from Utility.Model import *
from Utility.Training import *
from torch_geometric.loader import DataLoader
from tqdm import tqdm

# Explicit imports for names used directly below (glob, torch, sns, plt);
# the Utility star imports above may already provide some of them
import glob
import torch
import seaborn as sns
import matplotlib.pyplot as plt

### Reading Graph Objects

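**dir**: root directory of the graph dataset. It contains two subfolders, `raw` and `processed`; the processed graph objects are stored in the `processed` folder.

**Files**: list of raw CSV files (extracted flow-level and packet-level information) from which the graphs are built.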
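This `raw`/`processed` layout matches the PyTorch Geometric `Dataset` convention (`raw_dir` is `<root>/raw`, `processed_dir` is `<root>/processed`); that `NIDSDataset` follows it is an assumption based on the description above. Before loading with `skip_processing=True`, it can help to check what the folders actually contain. `describe_dataset_dir` is a hypothetical helper, not part of the `Utility` package.

```python
from pathlib import Path


def describe_dataset_dir(root):
    """List the raw CSVs and processed files under a dataset root.

    Hypothetical helper: assumes <root>/raw/*.csv for inputs and
    <root>/processed/ for the graph objects written after processing.
    """
    root = Path(root)

    def list_names(directory, pattern):
        # A missing folder is reported as empty instead of raising
        if not directory.is_dir():
            return []
        return sorted(p.name for p in directory.glob(pattern))

    raw_csvs = list_names(root / "raw", "*.csv")
    processed = list_names(root / "processed", "*")
    return raw_csvs, processed


if __name__ == "__main__":
    raw_csvs, processed = describe_dataset_dir("F:/GNN_Project/data/")
    print(f"{len(raw_csvs)} raw CSV files, {len(processed)} processed files")
```

An empty `processed` list means there are no stored graph objects to reuse, so skipping processing would have nothing to load.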

In [4]:
Dict_x = {'Benign': 0,
          'WebBased': 1,
          'Spoofing': 2,
          'Recon': 3,
          'Mirai': 4,
          'Dos': 5,
          'DDos': 6,
          'BruteForce': 7
         }

dir = "F:/GNN_Project/data/"  ## Root directory where the graph data is stored
Files = glob.glob("/scratch/user/yasir.ali/GNN_Project/data/raw/*.csv")  ## Raw CSV files with the extracted flow-level and packet-level information
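`Dict_x` maps class names to integer labels. For an 8-class output the labels should be the contiguous indices 0 to 7, and inverting the dictionary gives the class names in index order, the same order as the tick labels used in the confusion-matrix cell. A small sketch (`idx_to_class` and `class_names` are illustrative names):

```python
Dict_x = {'Benign': 0, 'WebBased': 1, 'Spoofing': 2, 'Recon': 3,
          'Mirai': 4, 'Dos': 5, 'DDos': 6, 'BruteForce': 7}

# Labels must be exactly 0..N-1 so they can index the model's output classes
assert sorted(Dict_x.values()) == list(range(len(Dict_x))), "labels must be 0..N-1"

# Index -> class name, to turn integer predictions back into names
idx_to_class = {idx: name for name, idx in Dict_x.items()}

# Class names ordered by label index (e.g. for reports and plot tick labels)
class_names = [idx_to_class[i] for i in range(len(Dict_x))]
print(class_names)
```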

In [4]:
data_Hetero = NIDSDataset(root=dir, label_dict=Dict_x, filename=Files, skip_processing=True, test=False, single_file=True)

### Initializing the Model

In [7]:
## Arguments for running the model
args = {
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    'hidden_size': 64,
    'epochs': 30,
    'weight_decay': 1e-5,
    'lr': 0.01,
    'attn_size': 32,
    'eps': 1.0,
}

In [8]:
## Take one graph from the dataset; it is passed to the model constructor below
data_model = data_Hetero[0].to(args['device'])

In [9]:
## Model without edge attributes
model = HeteroGNN(data_model, args, aggr="mean").to(args['device'])

## Model with Edge attributes
# model = HeteroGNN_Edge(data_model, args, aggr="mean").to(args['device'])
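The constructor receives `data_model`, a whole sample graph, rather than just layer sizes. A plausible reason, common in PyG heterogeneous models, is that the node types, edge types, and per-type feature sizes are read from a sample graph (PyG exposes the first two through `HeteroData.metadata()`). How `HeteroGNN` actually uses `data_model` is not shown in this notebook. The sketch mimics the idea with plain dictionaries; the `packet`/`flow` node types and the `belongs_to` relation are hypothetical.

```python
# Hypothetical sample graph: node-type and edge-type names are illustrative only
x_dict = {
    "packet": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],  # 2 nodes, 3 features each
    "flow": [[1.0, 0.0]],                           # 1 node, 2 features
}
edge_index_dict = {
    ("packet", "belongs_to", "flow"): [(0, 0), (1, 0)],
}


def hetero_metadata(x_dict, edge_index_dict):
    """Return (node types, edge types), the same shape as PyG's HeteroData.metadata()."""
    return list(x_dict), list(edge_index_dict)


def input_dims(x_dict):
    """Feature size per node type, which typed input layers need at construction time."""
    return {ntype: len(feats[0]) for ntype, feats in x_dict.items()}


print(hetero_metadata(x_dict, edge_index_dict))
print(input_dims(x_dict))
```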

### Training Loop


In [10]:
train_loader = DataLoader(data_Hetero, batch_size=64, shuffle=True)
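With `batch_size=64`, PyG's `DataLoader` does not pad graphs. It merges the 64 graphs into one large disconnected graph, shifting each graph's node indices by the number of nodes before it and recording graph membership in a `batch` vector (for heterogeneous graphs this happens per node type and edge type). A plain-Python sketch of that collation for a single node type:

```python
def collate_graphs(graphs):
    """Merge small graphs into one disjoint graph, as PyG mini-batching does.

    Each graph is (num_nodes, edge_list) with edge_list a list of (src, dst).
    Returns the merged edge list and a `batch` vector giving each node's graph id.
    """
    edge_index, batch, offset = [], [], 0
    for graph_id, (num_nodes, edges) in enumerate(graphs):
        # Shift node ids so they do not collide with earlier graphs
        edge_index.extend((s + offset, d + offset) for s, d in edges)
        batch.extend([graph_id] * num_nodes)
        offset += num_nodes
    return edge_index, batch


merged_edges, batch = collate_graphs([(2, [(0, 1)]), (3, [(0, 2), (1, 2)])])
print(merged_edges)  # [(0, 1), (2, 4), (3, 4)]
print(batch)         # [0, 0, 1, 1, 1]
```

Because the graphs stay disconnected, message passing never mixes nodes from different graphs; the `batch` vector is what a graph-level readout uses to pool each graph separately.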

In [None]:
# For training the model without edge attributes
train(train_loader, model, args, args["device"])

# # For training the model with edge attributes 
# train_with_edge_Att(train_loader, model, args, args["device"])

### Testing Loop

In [12]:
data_Hetero = NIDSDataset(root=dir, label_dict=Dict_x, filename=Files, skip_processing=True, test=True, single_file=True)

In [13]:
## For testing the model
test_loader = DataLoader(data_Hetero, batch_size=1, shuffle=False)

In [None]:
# For testing the model without edge attributes
acc, prediction, label = test_cm(test_loader,model)

# # For testing the model with edge attributes 
# acc, prediction, label = test_cm_with_edge_att(test_loader,model)
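Each test graph receives a single class label, so node embeddings have to be reduced to one vector per graph before classification. One common readout in PyG is `global_mean_pool`; whether `HeteroGNN` and `test_cm` use it is an assumption here. The sketch shows what mean pooling over the `batch` vector computes, followed by the argmax that turns class scores into a predicted label. For brevity the pooled vectors are treated directly as class scores; a real model would first map them to 8 logits with a linear layer.

```python
def mean_pool(node_feats, batch, num_graphs):
    """Average node feature vectors per graph (plain-Python analogue of global_mean_pool)."""
    dim = len(node_feats[0])
    sums = [[0.0] * dim for _ in range(num_graphs)]
    counts = [0] * num_graphs
    for feats, graph_id in zip(node_feats, batch):
        sums[graph_id] = [a + b for a, b in zip(sums[graph_id], feats)]
        counts[graph_id] += 1
    return [[v / c for v in row] if c else row for row, c in zip(sums, counts)]


def predict(scores):
    """Index of the highest score in each row (the predicted class id)."""
    return [max(range(len(row)), key=row.__getitem__) for row in scores]


pooled = mean_pool([[4.0, 1.0], [2.0, 3.0], [0.0, 5.0]], batch=[0, 0, 1], num_graphs=2)
print(pooled)           # [[3.0, 2.0], [0.0, 5.0]]
print(predict(pooled))  # [0, 1]
```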

#### Classification Report

In [None]:
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt

print(classification_report(label, prediction))
print()
print(f"Accuracy: {accuracy_score(label, prediction) * 100:.2f}%")
print()
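Each row of `classification_report` gives precision, recall, and F1 for one class, computed from true positives (TP), false positives (FP), and false negatives (FN): precision = TP / (TP + FP), recall = TP / (TP + FN), and F1 is their harmonic mean. The "macro avg" row is the unweighted mean over classes. A stdlib sketch of those formulas, with 0.0 used when a denominator is zero:

```python
def per_class_prf(labels, preds, num_classes):
    """(precision, recall, f1) for each class id 0..num_classes-1."""
    results = []
    for c in range(num_classes):
        tp = sum(1 for y, p in zip(labels, preds) if y == c and p == c)
        fp = sum(1 for y, p in zip(labels, preds) if y != c and p == c)
        fn = sum(1 for y, p in zip(labels, preds) if y == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        results.append((precision, recall, f1))
    return results


labels = [0, 0, 1, 1]
preds = [0, 1, 1, 1]
scores = per_class_prf(labels, preds, num_classes=2)
macro_f1 = sum(f1 for _, _, f1 in scores) / len(scores)
accuracy = sum(y == p for y, p in zip(labels, preds)) / len(labels)
print(scores, macro_f1, accuracy)
```

Macro averages treat every attack class equally, which matters for intrusion data where some classes (for example `Benign` or `DDos`) may dominate the sample count.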

#### Confusion Matrix

In [None]:
class_names = sorted(Dict_x, key=Dict_x.get)  # class names in label-index order
# normalize='true' divides each row by its total, so each cell is the fraction of that true class
# (shown as a percentage by fmt='.1%'; for raw counts, drop normalize and use fmt='d')
# labels=... keeps all 8 rows/columns even if a class is missing from this test set
cm = confusion_matrix(label, prediction, labels=list(range(len(class_names))), normalize='true')
plt.figure(figsize=(10, 10))
ax = plt.axes()
sns.heatmap(cm, annot=True, cmap='Blues', fmt='.1%', ax=ax,
            xticklabels=class_names, yticklabels=class_names)
ax.set_ylabel('True Label')
ax.set_xlabel('Predicted Label')
plt.show()
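With `normalize='true'`, every row of the matrix sums to 1, so the diagonal entry of each row is that class's recall, and the off-diagonal entries show which classes it is confused with. A stdlib sketch of the row normalization (rows with no samples are left as zeros here):

```python
def normalize_rows(cm):
    """Divide each row of a count matrix by its row total."""
    out = []
    for row in cm:
        total = sum(row)
        out.append([v / total if total else 0.0 for v in row])
    return out


counts = [[8, 2],
          [1, 9]]
norm = normalize_rows(counts)
recall_per_class = [norm[i][i] for i in range(len(norm))]
print(norm)              # [[0.8, 0.2], [0.1, 0.9]]
print(recall_per_class)  # [0.8, 0.9]
```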


#### Saving/Loading Model

In [None]:
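torch.save(model, '/scratch/user/yasir.ali/GNN_Project/Saved_Model/GNN4ID_8_Classes/model.pth')  # pickles the full model object
# To load it back, the Utility model classes must be importable; PyTorch >= 2.6 defaults to
# weights_only=True, which rejects full-model pickles, so pass weights_only=False explicitly:
# model = torch.load('/scratch/user/yasir.ali/GNN_Project/Saved_Model/GNN4ID_8_Classes/model.pth', weights_only=False)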