In [1]:
from torch_geometric.utils import from_networkx, add_self_loops, degree
from torch_geometric.nn import MessagePassing
import torch_geometric.transforms as T
from torch_geometric.data import DataLoader
from torch_geometric.loader import NeighborSampler
import torch.nn as nn
import torch as th
import torch.nn.functional as F
# import dgl.function as fn
import networkx as nx
import pandas as pd
import socket
import struct
import random
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.utils import class_weight
# import seaborn as sns
# import matplotlib.pyplot as plt
import numpy as np
import os
import sys

project_root = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))
sys.path.append(project_root)

from Datasets.UNSW_NB15.UNSW_NB15_config import UNSW_NB15_Config

In [2]:
csv_file_name = "all_raw_downsampled"

data = pd.read_csv(os.path.join(project_root, "Datasets", f"UNSW_NB15/All/{csv_file_name}.csv"))

DATASET_NAME = "UNSW_NB15"

SOURCE_IP_COL_NAME = UNSW_NB15_Config.SOURCE_IP_COL_NAME
DESTINATION_IP_COL_NAME = UNSW_NB15_Config.DESTINATION_IP_COL_NAME
SOURCE_PORT_COL_NAME = UNSW_NB15_Config.SOURCE_PORT_COL_NAME
DESTINATION_PORT_COL_NAME = UNSW_NB15_Config.DESTINATION_PORT_COL_NAME

ATTACK_CLASS_COL_NAME = UNSW_NB15_Config.ATTACK_CLASS_COL_NAME
IS_ATTACK_COL_NAME = UNSW_NB15_Config.IS_ATTACK_COL_NAME

BENIGN_CLASS_NAME = UNSW_NB15_Config.BENIGN_CLASS_NAME

TIME_COLS = UNSW_NB15_Config.TIME_COL_NAMES

MULTICLASS = True
label_col = ATTACK_CLASS_COL_NAME if MULTICLASS else IS_ATTACK_COL_NAME

print(data[ATTACK_CLASS_COL_NAME].value_counts())
print(data[IS_ATTACK_COL_NAME].value_counts())

if MULTICLASS:
    data.drop(columns=[IS_ATTACK_COL_NAME], inplace=True)
else:
    data.drop(columns=[ATTACK_CLASS_COL_NAME], inplace=True)

checkpoint_path = os.path.join(project_root, "Models/E_GraphSAGE/logs", DATASET_NAME, f"no_batch/checkpoints_{csv_file_name}_no_batch.pth")
best_model_path = os.path.join(project_root, "Models/E_GraphSAGE/logs", DATASET_NAME, f"no_batch/best_model_{csv_file_name}_no_batch.pth")
final_epoch_model_path = os.path.join(project_root, "Models/E_GraphSAGE/logs", DATASET_NAME, f"no_batch/final_epoch{csv_file_name}_no_batch.pth")

os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
os.makedirs(os.path.dirname(best_model_path), exist_ok=True)

  data = pd.read_csv(os.path.join(project_root, "Datasets", f"UNSW_NB15/All/{csv_file_name}.csv"))


attack_cat
Normal            221876
Generic           215481
Exploits           44525
Fuzzers            24246
DoS                16353
Reconnaissance     13987
Analysis            2677
Backdoor            1795
Shellcode           1511
Backdoors            534
Worms                174
Name: count, dtype: int64
label
1    321283
0    221876
Name: count, dtype: int64


In [3]:
data.drop(columns=UNSW_NB15_Config.DROP_COLS,inplace=True)

In [4]:
data[SOURCE_IP_COL_NAME] = data[SOURCE_IP_COL_NAME].apply(str)
data[DESTINATION_IP_COL_NAME] = data[DESTINATION_IP_COL_NAME].apply(str)

# # Combine Port and IP
data[SOURCE_PORT_COL_NAME] = data[SOURCE_PORT_COL_NAME].apply(str)
data[DESTINATION_PORT_COL_NAME] = data[DESTINATION_PORT_COL_NAME].apply(str)

data[SOURCE_IP_COL_NAME] = data[SOURCE_IP_COL_NAME] + ':' + data[SOURCE_PORT_COL_NAME]
data[DESTINATION_IP_COL_NAME] = data[DESTINATION_IP_COL_NAME] + ':' + data[DESTINATION_PORT_COL_NAME]
data.drop(columns=[SOURCE_PORT_COL_NAME,DESTINATION_PORT_COL_NAME],inplace=True)

# data[SOURCE_PORT_COL_NAME] = pd.to_numeric(data[SOURCE_PORT_COL_NAME], errors='coerce').fillna(0).astype(int)
# data[DESTINATION_PORT_COL_NAME] = pd.to_numeric(data[DESTINATION_PORT_COL_NAME], errors='coerce').fillna(0).astype(int)

In [5]:
print(data.head)

<bound method NDFrame.head of                      srcip               dstip state        dur  sbytes  \
0             10.40.85.1:0         224.0.0.5:0   INT  50.004341     384   
1          59.166.0.6:2142    149.171.126.4:53   CON   0.001134     132   
2       175.45.176.0:13284   149.171.126.16:80   FIN   2.390390    1362   
3         59.166.0.3:42587    149.171.126.8:25   FIN  34.077175   37358   
4            10.40.170.2:0       10.40.170.2:0   INT   0.000000      46   
...                    ...                 ...   ...        ...     ...   
543154  175.45.176.0:47439   149.171.126.10:53   INT   0.000001     114   
543155  175.45.176.0:47439   149.171.126.10:53   INT   0.000001     114   
543156    59.166.0.5:53521    149.171.126.7:21   CON   1.086072    1940   
543157  175.45.176.0:17293  149.171.126.17:110   CON   0.942984     574   
543158  175.45.176.0:17293  149.171.126.17:110   CON   0.942984     574   

        dbytes  sttl  dttl  sloss  dloss  ...  is_ftp_login  ct_ftp_c

In [6]:
data = pd.get_dummies(data, columns = UNSW_NB15_Config.CATEGORICAL_COLS) # One Hot Encoding for categorical data

In [7]:
data = data.reset_index()
data.replace([np.inf, -np.inf], np.nan,inplace = True)
data.fillna(0,inplace = True)
data.drop(columns=['index'],inplace=True)
print(data.head)

<bound method NDFrame.head of                      srcip               dstip        dur  sbytes  dbytes  \
0             10.40.85.1:0         224.0.0.5:0  50.004341     384       0   
1          59.166.0.6:2142    149.171.126.4:53   0.001134     132     164   
2       175.45.176.0:13284   149.171.126.16:80   2.390390    1362     268   
3         59.166.0.3:42587    149.171.126.8:25  34.077175   37358    3380   
4            10.40.170.2:0       10.40.170.2:0   0.000000      46       0   
...                    ...                 ...        ...     ...     ...   
543154  175.45.176.0:47439   149.171.126.10:53   0.000001     114       0   
543155  175.45.176.0:47439   149.171.126.10:53   0.000001     114       0   
543156    59.166.0.5:53521    149.171.126.7:21   1.086072    1940    2404   
543157  175.45.176.0:17293  149.171.126.17:110   0.942984     574     676   
543158  175.45.176.0:17293  149.171.126.17:110   0.942984     574     676   

        sttl  dttl  sloss  dloss         Sloa

In [8]:
scaler = StandardScaler()
cols_to_norm = UNSW_NB15_Config.COLS_TO_NORM
print(data[cols_to_norm].describe()) # Check if there's any too large value

                 dur        sbytes        dbytes           sttl  \
count  543159.000000  5.431590e+05  5.431590e+05  543159.000000   
mean        0.706760  5.136572e+03  1.936909e+04     157.197364   
std        12.637229  1.202311e+05  1.390925e+05     108.452474   
min         0.000000  0.000000e+00  0.000000e+00       0.000000   
25%         0.000007  1.140000e+02  0.000000e+00      31.000000   
50%         0.000011  2.000000e+02  0.000000e+00     254.000000   
75%         0.072088  1.580000e+03  1.940000e+03     254.000000   
max      8760.776367  1.435577e+07  1.465753e+07     255.000000   

                dttl          sloss          dloss         Sload  \
count  543159.000000  543159.000000  543159.000000  5.431590e+05   
mean       38.850764       3.800661       8.729770  6.877595e+07   
std        77.034389      45.616565      50.136204  1.420534e+08   
min         0.000000       0.000000       0.000000  0.000000e+00   
25%         0.000000       0.000000       0.000000  3.70

In [9]:
def check_numeric_issues(df, cols_to_norm):
    for col in cols_to_norm:
        try:
            # Try to coerce to numeric
            df[col] = pd.to_numeric(df[col], errors='coerce')
            
            # Try to clip the column
            df[col] = df[col].clip(lower=-1e9, upper=1e9)
            
        except Exception as e:
            print(f"❌ Column '{col}' failed with error: {e}")
            print(f"  - Sample values: {df[col].dropna().unique()[:5]}")
            print(f"  - Data type: {df[col].dtype}")
            continue

    print("\n✅ All other columns processed successfully.")

check_numeric_issues(data, UNSW_NB15_Config.COLS_TO_NORM)


✅ All other columns processed successfully.


In [10]:
data[cols_to_norm] = scaler.fit_transform(data[cols_to_norm])

In [11]:
from sklearn.preprocessing import LabelEncoder

num_classes = 2
class_map = [0, 1]
if MULTICLASS:
    le = LabelEncoder()
    attack_labels = le.fit_transform(data[ATTACK_CLASS_COL_NAME])
    class_map = le.classes_
    print(class_map)
    print("Attack label mapping:", dict(zip(class_map, range(len(class_map)))))
    data[ATTACK_CLASS_COL_NAME] = attack_labels
    num_classes = len(class_map)
    class_dict = {le.inverse_transform([i])[0]: i for i in range(len(le.classes_))}


['Analysis' 'Backdoor' 'Backdoors' 'DoS' 'Exploits' 'Fuzzers' 'Generic'
 'Normal' 'Reconnaissance' 'Shellcode' 'Worms']
Attack label mapping: {'Analysis': 0, 'Backdoor': 1, 'Backdoors': 2, 'DoS': 3, 'Exploits': 4, 'Fuzzers': 5, 'Generic': 6, 'Normal': 7, 'Reconnaissance': 8, 'Shellcode': 9, 'Worms': 10}


In [12]:
# 70% train, 15% validation, 15% test
train_df, temp_df = train_test_split(
     data, test_size=0.3, random_state=42, stratify=data[label_col])

val_df, test_df = train_test_split(
     temp_df, test_size=0.5, random_state=42, stratify=temp_df[label_col])


# Maintain the order of the rows in the original dataframe
train_df = train_df.sort_values(by=UNSW_NB15_Config.TIME_COL_NAMES)
val_df = val_df.sort_values(by=UNSW_NB15_Config.TIME_COL_NAMES)
test_df = test_df.sort_values(by=UNSW_NB15_Config.TIME_COL_NAMES)

feature_cols = [col for col in data.columns if col not in [label_col, SOURCE_IP_COL_NAME, DESTINATION_IP_COL_NAME]]

print('Feature Columns:', feature_cols)

train_df['h'] = train_df[ feature_cols ].values.tolist()
val_df['h'] = val_df[ feature_cols ].values.tolist()
test_df['h'] = test_df[ feature_cols ].values.tolist()

# X_train = train_df.drop(columns=[label_col])
# X_val = val_df.drop(columns=[label_col])
# X_test = test_df.drop(columns=[label_col])

y_train = train_df[label_col]
y_val = test_df[label_col]
y_test = test_df[label_col]

pd.set_option('display.max_columns', None)

print("Number of training samples:", len(train_df))
print(y_train.value_counts())
print("Number of validation samples:", len(val_df))
print(y_val.value_counts())
print("Number of test samples:", len(test_df))
print(y_test.value_counts())

print(train_df.head(5))

Feature Columns: ['dur', 'sbytes', 'dbytes', 'sttl', 'dttl', 'sloss', 'dloss', 'Sload', 'Dload', 'Spkts', 'Dpkts', 'swin', 'dwin', 'stcpb', 'dtcpb', 'smeansz', 'dmeansz', 'trans_depth', 'res_bdy_len', 'Sjit', 'Djit', 'Stime', 'Ltime', 'Sintpkt', 'Dintpkt', 'tcprtt', 'synack', 'ackdat', 'is_sm_ips_ports', 'ct_state_ttl', 'ct_flw_http_mthd', 'is_ftp_login', 'ct_ftp_cmd', 'ct_srv_src', 'ct_srv_dst', 'ct_dst_ltm', 'ct_src_ltm', 'ct_src_dport_ltm', 'ct_dst_sport_ltm', 'ct_dst_src_ltm', 'state_ACC', 'state_CLO', 'state_CON', 'state_ECO', 'state_FIN', 'state_INT', 'state_MAS', 'state_PAR', 'state_REQ', 'state_RST', 'state_TST', 'state_TXD', 'state_URH']
Number of training samples: 380211
attack_cat
7     155313
6     150837
4      31167
5      16972
3      11447
8       9791
0       1874
1       1256
9       1058
2        374
10       122
Name: count, dtype: int64
Number of validation samples: 81474
attack_cat
7     33282
6     32322
4      6679
5      3637
3      2453
8      2098
0       402

In [13]:
def create_graph(df, source_ip_col, destination_ip_col, edge_attr, create_using=nx.MultiDiGraph(), **kwargs):
    G_nx = nx.from_pandas_edgelist(df, source_ip_col, destination_ip_col, edge_attr, create_using=create_using, **kwargs)
    G_pyg = from_networkx(G_nx)

    num_nodes = G_pyg.num_nodes
    num_edges = G_pyg.num_edges

    G_pyg.x = th.ones(num_nodes, len(df['h'].iloc[0])) 

    edge_attr_list = []
    edge_label_list = []

    for u, v, key, data in G_nx.edges(keys=True, data=True):
        edge_attr_list.append(data['h']) 
        edge_label_list.append(data[label_col]) 

    G_pyg.edge_attr = th.tensor(edge_attr_list, dtype=th.float32)
    G_pyg.edge_label = th.tensor(edge_label_list, dtype=th.long)

    print("Number of edges in G_pyg:", num_edges)
    print("Number of node in G_pyg:", num_nodes)
    print("Shape of node in G_pyg:", G_pyg.x.shape)
    print("Shape of edge attr in G_pyg:", G_pyg.edge_attr.shape)
    print("Shape of edge label in G_pyg:", G_pyg.edge_label.shape)

    return G_nx, G_pyg

In [14]:
G_nx_train, G_pyg_train = create_graph(train_df, SOURCE_IP_COL_NAME, DESTINATION_IP_COL_NAME, ['h', label_col], create_using=nx.MultiDiGraph())
G_nx_val, G_pyg_val = create_graph(val_df, SOURCE_IP_COL_NAME, DESTINATION_IP_COL_NAME, ['h', label_col], create_using=nx.MultiDiGraph())

Number of edges in G_pyg: 380211
Number of node in G_pyg: 220487
Shape of node in G_pyg: torch.Size([220487, 53])
Shape of edge attr in G_pyg: torch.Size([380211, 53])
Shape of edge label in G_pyg: torch.Size([380211])
Number of edges in G_pyg: 81474
Number of node in G_pyg: 52568
Shape of node in G_pyg: torch.Size([52568, 53])
Shape of edge attr in G_pyg: torch.Size([81474, 53])
Shape of edge label in G_pyg: torch.Size([81474])


In [15]:
class EGraphSAGEConv(MessagePassing):
    def __init__(self, node_in_channels, edge_in_channels, out_channels):
        super(EGraphSAGEConv, self).__init__(aggr='mean')  # mean aggregation
        self.lin_node = nn.Linear(node_in_channels, out_channels)
        self.lin_edge = nn.Linear(edge_in_channels, out_channels)
        self.lin_update = nn.Linear(node_in_channels + out_channels, out_channels) # out_channels * 2

    def forward(self, x, edge_index, edge_attr):
        # x: Node features, edge_attr: Edge features, edge_index: Connectivity
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        if edge_attr is not None:
            if edge_attr.size(0) != edge_index.size(1):
                loop_attr = th.zeros((edge_index.size(1) - edge_attr.size(0), edge_attr.size(1))).to(edge_attr.device)
                edge_attr = th.cat([edge_attr, loop_attr], dim=0)
        else:
            print("edge_attr is unexist")
        
        # Propagate and aggregate neighbor information
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_j, edge_attr):
        # x_j represents the adjacent nodes of x
        # Compute messages by combining node and edge features
        return self.lin_node(x_j) + self.lin_edge(edge_attr)

    def update(self, aggr_out, x):
        # Update node features after message passing
        return self.lin_update(th.cat([x, aggr_out], dim=1))

class MLPPredictor(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(MLPPredictor, self).__init__()
        self.lin = nn.Linear(in_channels * 2, out_channels)

    def forward(self, data, z):
        row, col = data.edge_index
        # Concatenate the features of source and target nodes for each edge
        edge_feat = th.cat([z[row], z[col]], dim=1)
        return self.lin(edge_feat)

class EGraphSAGE(nn.Module):
    def __init__(self, node_in_channels, edge_in_channels, hidden_channels, out_channels):
        super(EGraphSAGE, self).__init__()
        self.conv1 = EGraphSAGEConv(node_in_channels, edge_in_channels, hidden_channels)
        self.conv2 = EGraphSAGEConv(hidden_channels, edge_in_channels, hidden_channels)
        self.mlp_predictor = MLPPredictor(hidden_channels, out_channels)

    def forward(self, data):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        x = F.relu(self.conv1(x, edge_index, edge_attr))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index, edge_attr)
        return self.mlp_predictor(data, x)

In [16]:
device = th.device("cuda:0" if th.cuda.is_available() else "cpu")
print(device)

cuda:0


In [17]:
th.cuda.empty_cache()

In [18]:
model = EGraphSAGE(node_in_channels=G_pyg_train.num_node_features, 
                   edge_in_channels=G_pyg_train.num_edge_features,
                   hidden_channels=128, 
                   out_channels=num_classes).to(device)

def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        nn.init.constant_(m.bias, 0)

model.apply(init_weights)

labels = G_pyg_train.edge_label.cpu().numpy()
class_weights = class_weight.compute_class_weight('balanced',
                                                  classes=np.unique(labels),
                                                  y=labels)

# Normalise to stabilise training
class_weights = class_weights / np.mean(class_weights)

class_weights = th.FloatTensor(class_weights).to(device)
print("Class weights:", class_weights)
criterion = nn.CrossEntropyLoss(weight = class_weights)
optimizer = th.optim.Adam(model.parameters(), lr=0.001)

Class weights: tensor([4.3677e-01, 6.5168e-01, 2.1885e+00, 7.1504e-02, 2.6262e-02, 4.8227e-02,
        5.4265e-03, 5.2701e-03, 8.3598e-02, 7.7364e-01, 6.7091e+00],
       device='cuda:0')


In [19]:
from sklearn.metrics import f1_score

def compute_accuracy(pred, labels):
    return (pred.argmax(1) == labels).float().mean().item()

G_pyg_train.to(device)
G_pyg_val.to(device)

G_pyg_train.edge_label = G_pyg_train.edge_label.to(device)
G_pyg_train.edge_attr = G_pyg_train.edge_attr.to(device)

G_pyg_val.edge_label = G_pyg_val.edge_label.to(device)
G_pyg_val.edge_attr = G_pyg_val.edge_attr.to(device)

best_f1 = 0
best_model_state = None

# Load checkpoint if exists
start_epoch = 0
epochs = 15000

if os.path.exists(checkpoint_path):
    checkpoint = th.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    best_f1 = checkpoint['best_f1']
    print(f"Resumed training from epoch {start_epoch}")

for epoch in range(start_epoch, epochs):
    train_loss = 0

    val_loss = 0
    
    try:
        model.train()
        
        try:
            
            out = model(G_pyg_train)

            loss = criterion(out, G_pyg_train.edge_label)
            train_loss = loss.item()
            if th.isnan(loss):
                print(f"loss: {loss}")
                print(f"out: {out}")
                print(f"edge_labels: {G_pyg_train.edge_label}")
                
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        except Exception as forward_error:
            print(f"Error during forward/backward pass at Epoch {epoch}: {forward_error}")
            continue
        
        model.eval()
        with th.no_grad():
                
            try:
                out = model(G_pyg_val)

                loss = criterion(out, G_pyg_val.edge_label)
                val_loss = loss.item()
                
            except Exception as forward_error:
                print(f"Error during validation at Epoch {epoch}: {forward_error}")
                continue

        val_f1 = f1_score(G_pyg_val.edge_label.cpu(), out.argmax(dim=1).cpu(), average='weighted')
        if epoch % 100 == 0:
            print(f'Epoch {epoch}, Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, Validation F1: {val_f1:.4f}')

        # Save the best model
        if val_f1 > best_f1:
            best_f1 = val_f1
            best_model_state = model.state_dict()
            th.save(best_model_state, best_model_path)
            print("Saved best model. F1 Score:", best_f1)

        # Save checkpoint
        th.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_f1': best_f1
        }, checkpoint_path)

    except Exception as e:
        print(f"An error occurred at epoch {epoch} {str(e)}")
print("Training is over")

Resumed training from epoch 10056
Epoch 10100, Train Loss: 0.8234, Validation Loss: 0.9552, Validation F1: 0.8974
Epoch 10200, Train Loss: 0.8214, Validation Loss: 0.9376, Validation F1: 0.9023
Epoch 10300, Train Loss: 0.8228, Validation Loss: 0.9565, Validation F1: 0.8982
Epoch 10400, Train Loss: 0.8234, Validation Loss: 0.9589, Validation F1: 0.8974
Epoch 10500, Train Loss: 0.8237, Validation Loss: 0.9433, Validation F1: 0.9022
Epoch 10600, Train Loss: 0.8208, Validation Loss: 0.9515, Validation F1: 0.9024
Epoch 10700, Train Loss: 0.8216, Validation Loss: 0.9494, Validation F1: 0.8981
Saved best model. F1 Score: 0.9191974782695332
Epoch 10800, Train Loss: 0.8230, Validation Loss: 0.9397, Validation F1: 0.8999
Epoch 10900, Train Loss: 0.8252, Validation Loss: 0.9178, Validation F1: 0.8963
Epoch 11000, Train Loss: 0.8233, Validation Loss: 0.9412, Validation F1: 0.8987
Epoch 11100, Train Loss: 0.8225, Validation Loss: 0.9463, Validation F1: 0.9024
Epoch 11200, Train Loss: 0.8197, Valida

In [20]:
th.save(model.state_dict(), final_epoch_model_path)

In [21]:
from torch_geometric.utils import subgraph
from torch_geometric.data import Data
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report

def compute_accuracy(pred, labels):
    return (pred.argmax(1) == labels).float().mean().item()



def eval(dataframe, adversarial=False):
    G_nx_test = nx.from_pandas_edgelist(dataframe, SOURCE_IP_COL_NAME, DESTINATION_IP_COL_NAME, ['h', label_col], create_using=nx.MultiDiGraph())

    G_pyg_test = from_networkx(G_nx_test)

    test_num_nodes = G_pyg_test.num_nodes
    test_num_edges = G_pyg_test.num_edges

    G_pyg_test.x = th.ones(test_num_nodes, len(test_df['h'].iloc[0]))

    test_edge_attr_list = []
    test_edge_label_list = []

    for u, v, key, data in G_nx_test.edges(keys=True, data=True):
        test_edge_attr_list.append(data['h']) 
        test_edge_label_list.append(data[label_col]) 

    G_pyg_test.edge_attr = th.tensor(test_edge_attr_list, dtype=th.float32)
    G_pyg_test.edge_label = th.tensor(test_edge_label_list, dtype=th.long)

    G_pyg_test = G_pyg_test.to(device)
    G_pyg_test.edge_label = G_pyg_test.edge_label.to(device)
    G_pyg_test.edge_attr = G_pyg_test.edge_attr.to(device)

    print("Number of edges in G_pyg_test:", G_pyg_test.num_edges)
    print("Number of node in G_pyg_test:", G_pyg_test.num_nodes)
    print("Shape of node in G_pyg_test:", G_pyg_test.x.shape)
    print("Shape of edge attr in G_pyg_test:", G_pyg_test.edge_attr.shape)
    print("Shape of edge label in G_pyg_test:", G_pyg_test.edge_label.shape)
    
    new_model_2 = EGraphSAGE(node_in_channels=G_pyg_test.num_node_features, 
                       edge_in_channels=G_pyg_test.num_edge_features,
                       hidden_channels=128, 
                       out_channels=num_classes).to(device)

    new_model_2.load_state_dict(th.load(best_model_path, weights_only=True))

    model.eval()

    print("inference start")
    with th.no_grad():
            
        try:
            out = model(G_pyg_test)
            
        except Exception as forward_error:
            print(f"Error during forward/backward pass at {forward_error}")

    print("inference done")

    test_accuracy = compute_accuracy(out, G_pyg_test.edge_label)
    print(f'Test Accuracy: {test_accuracy:.4f}')
    
    pred_labels = out.argmax(dim=1).cpu()
    all_test_labels = G_pyg_test.edge_label.cpu()

    
    global class_map
    class_map_2 = class_map
    if adversarial:
        class_map_2 = np.append(class_map, "Adversarial")

    # Generate a report
    cm = confusion_matrix(all_test_labels, pred_labels, labels=range(len(class_map_2)))
    print(cm)
    report = classification_report(all_test_labels, pred_labels, target_names=class_map_2, digits=4)
    print(report)

eval(test_df)


Number of edges in G_pyg_test: 81474
Number of node in G_pyg_test: 52280
Shape of node in G_pyg_test: torch.Size([52280, 53])
Shape of edge attr in G_pyg_test: torch.Size([81474, 53])
Shape of edge label in G_pyg_test: torch.Size([81474])
inference start
inference done
Test Accuracy: 0.8798
[[   41   157    68    36    26     0     0     0     0     0    74]
 [    4   165     8    69    10     0     0     0     0     6     7]
 [    0     4    72     1     1     0     0     0     0     1     1]
 [  231  1257   116   339   203    13    17     1    40    47   189]
 [  475  1422   236   621  2225    47    61     0   186   165  1241]
 [   40   182   153    80   139  2831     0     0    15     0   197]
 [   39   167    11   111   200     8 31487     2    28    29   240]
 [    0     6     1     2    30   367     0 32812     0     0    64]
 [   41   217    10    68    74     1     4     0  1471     9   203]
 [    0     0     2     0     0     1     0     0     0   222     1]
 [    0     0     

In [40]:
def attack_attacker(dataframe, ratio, num_injected_nodes):
    attack_eval = dataframe[dataframe[label_col] != class_dict[UNSW_NB15_Config.BENIGN_CLASS_NAME]]
    num_injected = int(ratio * len(dataframe))

    # Sample attack rows
    sampled_attack_flows = attack_eval.sample(n=num_injected, random_state=42).copy().reset_index(drop=True)
    
    injected_rows = sampled_attack_flows.copy()
    print("Sampled attack flows:", len(sampled_attack_flows))
    print("Labels of sampled attack flows:", sampled_attack_flows[label_col].value_counts())

    node_ips = [f"192.168.1.{i+1}" for i in range(num_injected_nodes)]
    injected_rows[UNSW_NB15_Config.DESTINATION_IP_COL_NAME] = injected_rows[UNSW_NB15_Config.SOURCE_IP_COL_NAME] # Target the Real Attacker Nodes
    injected_rows[UNSW_NB15_Config.SOURCE_IP_COL_NAME] = [f"{node_ips[i % len(node_ips)]}:{random.randint(1024, 65535)}" for i in range(num_injected)]
    # injected_rows['pkSeqID'] = [f'Injected-{i}' for i in range(num_injected)]
    injected_rows[label_col] = len(class_map) # Assign a new class for injected samples
    print(injected_rows[0:5])

    # Append and reorder
    combined_df = pd.concat([dataframe, injected_rows], ignore_index=True)

    # Sort using this datetime column
    combined_df = combined_df.sort_values(by=UNSW_NB15_Config.TIME_COL_NAMES).reset_index(drop=True)

    return combined_df

# Inject adversarial samples
attack_attacker_df = attack_attacker(test_df, 0.1, num_injected_nodes=1)
eval(attack_attacker_df, adversarial=True)

Sampled attack flows: 8147
Labels of sampled attack flows: attack_cat
6     5483
4     1136
5      602
3      405
8      361
0       66
1       46
9       35
2        8
10       5
Name: count, dtype: int64
               srcip               dstip       dur    sbytes    dbytes  \
0  192.168.1.1:42518   175.45.176.1:1043 -0.055927 -0.041774 -0.139253   
1  192.168.1.1:26892  175.45.176.1:47439 -0.055926 -0.041774 -0.139253   
2  192.168.1.1:44663  175.45.176.1:47439 -0.055926 -0.041774 -0.139253   
3  192.168.1.1:58368   175.45.176.1:1043 -0.055927 -0.041774 -0.139253   
4  192.168.1.1:41066  175.45.176.1:47439 -0.055926 -0.041774 -0.139253   

       sttl      dttl     sloss     dloss     Sload     Dload     Spkts  \
0  0.892582 -0.504331 -0.083318 -0.174121  0.387611 -0.367084 -0.180233   
1  0.892582 -0.504331 -0.083318 -0.174121 -0.083775 -0.367084 -0.180233   
2  0.892582 -0.504331 -0.083318 -0.174121 -0.178053 -0.367084 -0.180233   
3  0.892582 -0.504331 -0.083318 -0.174121  0.7018

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [41]:
def normalise_attacker(dataframe, ratio, num_injected_nodes):

    normal_eval = dataframe[dataframe[label_col] == class_dict[UNSW_NB15_Config.BENIGN_CLASS_NAME]]
    attack_eval = dataframe[dataframe[label_col] != class_dict[UNSW_NB15_Config.BENIGN_CLASS_NAME]]
    print("Normal Flows:", len(normal_eval))
    print("Attack Flows:", len(attack_eval))
    num_injected = int(ratio * len(dataframe))

    sampled_normal_flows = normal_eval.sample(n=num_injected, random_state=42).copy().reset_index(drop=True)
    sampled_attack_flows = attack_eval.sample(n=num_injected, random_state=42).copy().reset_index(drop=True)

    injected_rows = sampled_normal_flows.copy()
    node_ips = [f"192.168.1.{i+1}" for i in range(num_injected_nodes)]
    injected_rows[UNSW_NB15_Config.DESTINATION_IP_COL_NAME] = sampled_attack_flows[UNSW_NB15_Config.SOURCE_IP_COL_NAME] # Direct BENGIN Traffic to the Real Attacker Nodes
    injected_rows[UNSW_NB15_Config.SOURCE_IP_COL_NAME] = [f"{node_ips[i % len(node_ips)]}:{random.randint(1024, 65535)}" for i in range(num_injected)]
    injected_rows[label_col] = len(class_map)
    print(injected_rows[0:5])

    combined_df = pd.concat([dataframe, injected_rows], ignore_index=True)

    # Sort using this datetime column
    combined_df = combined_df.sort_values(by=UNSW_NB15_Config.TIME_COL_NAMES).reset_index(drop=True)

    return combined_df

# Inject adversarial samples
normalise_attacker_df = normalise_attacker(test_df, 0.1, 1)
eval(normalise_attacker_df, adversarial=True)

Normal Flows: 33282
Attack Flows: 48192
               srcip               dstip       dur    sbytes    dbytes  \
0  192.168.1.1:34193   175.45.176.1:1043 -0.026753  0.266765 -0.115701   
1  192.168.1.1:29250  175.45.176.1:47439 -0.052112  0.031535 -0.136205   
2   192.168.1.1:5090  175.45.176.1:47439 -0.026270 -0.001003 -0.117110   
3  192.168.1.1:33072   175.45.176.1:1043  0.060097 -0.010717 -0.002050   
4   192.168.1.1:3145  175.45.176.1:47439 -0.055836 -0.041508 -0.137974   

      sttl      dttl     sloss     dloss     Sload     Dload     Spkts  \
0 -1.16362 -0.127875  0.311276 -0.014556 -0.548613 -0.344918  0.310332   
1 -1.16362 -0.127875  0.004370 -0.154176 -0.543784 -0.347401 -0.062497   
2 -1.16362 -0.127875  0.070136 -0.034501 -0.554308 -0.346760  0.074861   
3 -1.16362 -0.127875  0.157823  0.125064 -0.554992 -0.334432  0.290709   
4 -1.16362 -0.127875 -0.083318 -0.174121 -0.550951 -0.168612 -0.180233   

      Dpkts      swin      dwin     stcpb     dtcpb   smeansz   dmeans

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [42]:
def random_connection(dataframe, ratio, num_injected_nodes):

    normal_eval = dataframe[dataframe[label_col] == class_dict[UNSW_NB15_Config.BENIGN_CLASS_NAME]]

    num_injected = int(ratio * len(dataframe))
    print("injected rows:", num_injected)

    sampled_normal_flows = normal_eval.sample(n=num_injected, random_state=42).copy().reset_index(drop=True)

    injected_rows = sampled_normal_flows.copy()
    node_ips = [f"192.168.1.{i+1}" for i in range(num_injected_nodes)]
    injected_rows[UNSW_NB15_Config.DESTINATION_IP_COL_NAME] = [node_ips[i % len(node_ips)] for i in range(num_injected)]
    injected_rows[UNSW_NB15_Config.SOURCE_IP_COL_NAME] = [node_ips[(i + 1) % len(node_ips)] for i in range(num_injected)]
    injected_rows[label_col] = len(class_map)

    combined_df = pd.concat([dataframe, injected_rows], ignore_index=True)

    # Sort using this datetime column
    combined_df = combined_df.sort_values(by=UNSW_NB15_Config.TIME_COL_NAMES).reset_index(drop=True)

    return combined_df

# Inject adversarial samples
random_connection_df = random_connection(test_df, 0.1, 1)
eval(random_connection_df, adversarial=True)

injected rows: 8147
Number of edges in G_pyg_test: 89621
Number of node in G_pyg_test: 52281
Shape of node in G_pyg_test: torch.Size([52281, 53])
Shape of edge attr in G_pyg_test: torch.Size([89621, 53])
Shape of edge label in G_pyg_test: torch.Size([89621])
inference start
inference done
Test Accuracy: 0.7998
[[   41   157    68    36    26     0     0     0     0     0    74     0]
 [    4   165     8    69    10     0     0     0     0     6     7     0]
 [    0     4    72     1     1     0     0     0     0     1     1     0]
 [  231  1257   116   339   203    13    17     1    40    47   189     0]
 [  475  1422   236   621  2225    47    61     0   186   165  1241     0]
 [   40   182   153    80   139  2831     0     0    15     0   197     0]
 [   39   167    11   111   200     8 31487     2    28    29   240     0]
 [    0     6     1     2    30   367     0 32812     0     0    64     0]
 [   41   217    10    68    74     1     4     0  1471     9   203     0]
 [    0     0

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
