In [1]:
from torch_geometric.utils import from_networkx, add_self_loops, degree
from torch_geometric.nn import MessagePassing
import torch_geometric.transforms as T
from torch_geometric.data import DataLoader
from torch_geometric.loader import NeighborSampler
import torch.nn as nn
import torch as th
import torch.nn.functional as F
# import dgl.function as fn
import networkx as nx
import pandas as pd
import socket
import struct
import random
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.utils import class_weight
# import seaborn as sns
# import matplotlib.pyplot as plt
import numpy as np
import os
import sys

project_root = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))
sys.path.append(project_root)

from Datasets.UNSW_NB15.UNSW_NB15_config import UNSW_NB15_Config

In [19]:
csv_file_name = "all_raw_downsampled"

data = pd.read_csv(os.path.join(project_root, "Datasets", f"UNSW_NB15/All/{csv_file_name}.csv"))

DATASET_NAME = "UNSW_NB15"

SOURCE_IP_COL_NAME = UNSW_NB15_Config.SOURCE_IP_COL_NAME
DESTINATION_IP_COL_NAME = UNSW_NB15_Config.DESTINATION_IP_COL_NAME
SOURCE_PORT_COL_NAME = UNSW_NB15_Config.SOURCE_PORT_COL_NAME
DESTINATION_PORT_COL_NAME = UNSW_NB15_Config.DESTINATION_PORT_COL_NAME

ATTACK_CLASS_COL_NAME = UNSW_NB15_Config.ATTACK_CLASS_COL_NAME
IS_ATTACK_COL_NAME = UNSW_NB15_Config.IS_ATTACK_COL_NAME

BENIGN_CLASS_NAME = UNSW_NB15_Config.BENIGN_CLASS_NAME

TIME_COLS = UNSW_NB15_Config.TIME_COL_NAMES

MULTICLASS = True
label_col = ATTACK_CLASS_COL_NAME if MULTICLASS else IS_ATTACK_COL_NAME

print(data[ATTACK_CLASS_COL_NAME].value_counts())
print(data[IS_ATTACK_COL_NAME].value_counts())

if MULTICLASS:
    data.drop(columns=[IS_ATTACK_COL_NAME], inplace=True)
else:
    data.drop(columns=[ATTACK_CLASS_COL_NAME], inplace=True)

checkpoint_path = os.path.join(project_root, "Models/E_GraphSAGE/logs", DATASET_NAME, f"tuned/checkpoints_{csv_file_name}_tuned.pth")
best_model_path = os.path.join(project_root, "Models/E_GraphSAGE/logs", DATASET_NAME, f"tuned/best_model_{csv_file_name}_tuned.pth")

os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
os.makedirs(os.path.dirname(best_model_path), exist_ok=True)

  data = pd.read_csv(os.path.join(project_root, "Datasets", f"UNSW_NB15/All/{csv_file_name}.csv"))


attack_cat
Normal            221876
Generic           215481
Exploits           44525
Fuzzers            24246
DoS                16353
Reconnaissance     13987
Analysis            2677
Backdoor            1795
Shellcode           1511
Backdoors            534
Worms                174
Name: count, dtype: int64
label
1    321283
0    221876
Name: count, dtype: int64


In [20]:
data.drop(columns=UNSW_NB15_Config.DROP_COLS,inplace=True)

In [21]:
data[SOURCE_IP_COL_NAME] = data[SOURCE_IP_COL_NAME].apply(str)
data[DESTINATION_IP_COL_NAME] = data[DESTINATION_IP_COL_NAME].apply(str)

# # Combine Port and IP
data[SOURCE_PORT_COL_NAME] = data[SOURCE_PORT_COL_NAME].apply(str)
data[DESTINATION_PORT_COL_NAME] = data[DESTINATION_PORT_COL_NAME].apply(str)

data[SOURCE_IP_COL_NAME] = data[SOURCE_IP_COL_NAME] + ':' + data[SOURCE_PORT_COL_NAME]
data[DESTINATION_IP_COL_NAME] = data[DESTINATION_IP_COL_NAME] + ':' + data[DESTINATION_PORT_COL_NAME]
data.drop(columns=[SOURCE_PORT_COL_NAME,DESTINATION_PORT_COL_NAME],inplace=True)

# data[SOURCE_PORT_COL_NAME] = pd.to_numeric(data[SOURCE_PORT_COL_NAME], errors='coerce').fillna(0).astype(int)
# data[DESTINATION_PORT_COL_NAME] = pd.to_numeric(data[DESTINATION_PORT_COL_NAME], errors='coerce').fillna(0).astype(int)

In [22]:
print(data.head)

<bound method NDFrame.head of                      srcip               dstip state        dur  sbytes  \
0             10.40.85.1:0         224.0.0.5:0   INT  50.004341     384   
1          59.166.0.6:2142    149.171.126.4:53   CON   0.001134     132   
2       175.45.176.0:13284   149.171.126.16:80   FIN   2.390390    1362   
3         59.166.0.3:42587    149.171.126.8:25   FIN  34.077175   37358   
4            10.40.170.2:0       10.40.170.2:0   INT   0.000000      46   
...                    ...                 ...   ...        ...     ...   
543154  175.45.176.0:47439   149.171.126.10:53   INT   0.000001     114   
543155  175.45.176.0:47439   149.171.126.10:53   INT   0.000001     114   
543156    59.166.0.5:53521    149.171.126.7:21   CON   1.086072    1940   
543157  175.45.176.0:17293  149.171.126.17:110   CON   0.942984     574   
543158  175.45.176.0:17293  149.171.126.17:110   CON   0.942984     574   

        dbytes  sttl  dttl  sloss  dloss  ...  is_ftp_login  ct_ftp_c

In [23]:
data = pd.get_dummies(data, columns = UNSW_NB15_Config.CATEGORICAL_COLS) # One Hot Encoding for categorical data
converted_categorical_cols = [col for col in data.columns if col.startswith(tuple(UNSW_NB15_Config.CATEGORICAL_COLS))]

In [24]:
data = data.reset_index()
data.replace([np.inf, -np.inf], np.nan,inplace = True)
data.fillna(0,inplace = True)
data.drop(columns=['index'],inplace=True)
print(data.head)

<bound method NDFrame.head of                      srcip               dstip        dur  sbytes  dbytes  \
0             10.40.85.1:0         224.0.0.5:0  50.004341     384       0   
1          59.166.0.6:2142    149.171.126.4:53   0.001134     132     164   
2       175.45.176.0:13284   149.171.126.16:80   2.390390    1362     268   
3         59.166.0.3:42587    149.171.126.8:25  34.077175   37358    3380   
4            10.40.170.2:0       10.40.170.2:0   0.000000      46       0   
...                    ...                 ...        ...     ...     ...   
543154  175.45.176.0:47439   149.171.126.10:53   0.000001     114       0   
543155  175.45.176.0:47439   149.171.126.10:53   0.000001     114       0   
543156    59.166.0.5:53521    149.171.126.7:21   1.086072    1940    2404   
543157  175.45.176.0:17293  149.171.126.17:110   0.942984     574     676   
543158  175.45.176.0:17293  149.171.126.17:110   0.942984     574     676   

        sttl  dttl  sloss  dloss         Sloa

In [25]:
scaler = StandardScaler()
cols_to_norm = UNSW_NB15_Config.COLS_TO_NORM
print(data[cols_to_norm].describe()) # Check if there's any too large value

                 dur        sbytes        dbytes           sttl  \
count  543159.000000  5.431590e+05  5.431590e+05  543159.000000   
mean        0.706760  5.136572e+03  1.936909e+04     157.197364   
std        12.637229  1.202311e+05  1.390925e+05     108.452474   
min         0.000000  0.000000e+00  0.000000e+00       0.000000   
25%         0.000007  1.140000e+02  0.000000e+00      31.000000   
50%         0.000011  2.000000e+02  0.000000e+00     254.000000   
75%         0.072088  1.580000e+03  1.940000e+03     254.000000   
max      8760.776367  1.435577e+07  1.465753e+07     255.000000   

                dttl          sloss          dloss         Sload  \
count  543159.000000  543159.000000  543159.000000  5.431590e+05   
mean       38.850764       3.800661       8.729770  6.877595e+07   
std        77.034389      45.616565      50.136204  1.420534e+08   
min         0.000000       0.000000       0.000000  0.000000e+00   
25%         0.000000       0.000000       0.000000  3.70

In [26]:
def check_numeric_issues(df, cols_to_norm):
    for col in cols_to_norm:
        try:
            # Try to coerce to numeric
            df[col] = pd.to_numeric(df[col], errors='coerce')
            
            # Try to clip the column
            df[col] = df[col].clip(lower=-1e9, upper=1e9)
            
        except Exception as e:
            print(f"❌ Column '{col}' failed with error: {e}")
            print(f"  - Sample values: {df[col].dropna().unique()[:5]}")
            print(f"  - Data type: {df[col].dtype}")
            continue

    print("\n✅ All other columns processed successfully.")

check_numeric_issues(data, UNSW_NB15_Config.COLS_TO_NORM)


✅ All other columns processed successfully.


In [27]:
data[cols_to_norm] = scaler.fit_transform(data[cols_to_norm])

In [28]:
from sklearn.preprocessing import LabelEncoder

num_classes = 2
class_map = [0, 1]
if MULTICLASS:
    le = LabelEncoder()
    attack_labels = le.fit_transform(data[ATTACK_CLASS_COL_NAME])
    class_map = le.classes_
    print(class_map)
    print("Attack label mapping:", dict(zip(class_map, range(len(class_map)))))
    data[ATTACK_CLASS_COL_NAME] = attack_labels
    num_classes = len(class_map)
    class_dict = {le.inverse_transform([i])[0]: i for i in range(len(le.classes_))}


['Analysis' 'Backdoor' 'Backdoors' 'DoS' 'Exploits' 'Fuzzers' 'Generic'
 'Normal' 'Reconnaissance' 'Shellcode' 'Worms']
Attack label mapping: {'Analysis': 0, 'Backdoor': 1, 'Backdoors': 2, 'DoS': 3, 'Exploits': 4, 'Fuzzers': 5, 'Generic': 6, 'Normal': 7, 'Reconnaissance': 8, 'Shellcode': 9, 'Worms': 10}


In [29]:
# 70% train, 15% validation, 15% test
train_df, test_df = train_test_split(
     data, test_size=0.15, random_state=42, stratify=data[label_col])


# # Maintain the order of the rows in the original dataframe
# train_df = train_df.sort_values(by=UNSW_NB15_Config.TIME_COL_NAMES)
# test_df = test_df.sort_values(by=UNSW_NB15_Config.TIME_COL_NAMES)
train_df = train_df.drop(columns=UNSW_NB15_Config.TIME_COL_NAMES)
test_df = test_df.drop(columns=UNSW_NB15_Config.TIME_COL_NAMES)

feature_cols = UNSW_NB15_Config.COLS_TO_NORM + converted_categorical_cols

print('Feature Columns:', feature_cols)

train_df['h'] = train_df[ feature_cols ].values.tolist()
test_df['h'] = test_df[ feature_cols ].values.tolist()

# X_train = train_df.drop(columns=[label_col])
# X_val = val_df.drop(columns=[label_col])
# X_test = test_df.drop(columns=[label_col])

y_train = train_df[label_col]
y_test = test_df[label_col]

pd.set_option('display.max_columns', None)

print("Number of training samples:", len(train_df))
print(y_train.value_counts())
print("Number of test samples:", len(test_df))
print(y_test.value_counts())

print(train_df.head(5))

Feature Columns: ['dur', 'sbytes', 'dbytes', 'sttl', 'dttl', 'sloss', 'dloss', 'Sload', 'Dload', 'Spkts', 'Dpkts', 'swin', 'dwin', 'stcpb', 'dtcpb', 'smeansz', 'dmeansz', 'trans_depth', 'res_bdy_len', 'Sjit', 'Djit', 'Sintpkt', 'Dintpkt', 'tcprtt', 'synack', 'ackdat', 'is_sm_ips_ports', 'ct_state_ttl', 'ct_flw_http_mthd', 'is_ftp_login', 'ct_ftp_cmd', 'ct_srv_src', 'ct_srv_dst', 'ct_dst_ltm', 'ct_src_ltm', 'ct_src_dport_ltm', 'ct_dst_sport_ltm', 'ct_dst_src_ltm', 'state_ACC', 'state_CLO', 'state_CON', 'state_ECO', 'state_FIN', 'state_INT', 'state_MAS', 'state_PAR', 'state_REQ', 'state_RST', 'state_TST', 'state_TXD', 'state_URH']
Number of training samples: 461685
attack_cat
7     188595
6     183159
4      37846
5      20609
3      13900
8      11889
0       2275
1       1526
9       1284
2        454
10       148
Name: count, dtype: int64
Number of test samples: 81474
attack_cat
7     33281
6     32322
4      6679
5      3637
3      2453
8      2098
0       402
1       269
9       227

In [30]:
def create_graph(df, source_ip_col, destination_ip_col, edge_attr, create_using=nx.MultiDiGraph(), **kwargs):
    G_nx = nx.from_pandas_edgelist(df, source_ip_col, destination_ip_col, edge_attr, create_using=create_using, **kwargs)
    G_pyg = from_networkx(G_nx)

    num_nodes = G_pyg.num_nodes
    num_edges = G_pyg.num_edges

    G_pyg.x = th.ones(num_nodes, len(df['h'].iloc[0])) 

    edge_attr_list = []
    edge_label_list = []

    for u, v, key, data in G_nx.edges(keys=True, data=True):
        edge_attr_list.append(data['h']) 
        edge_label_list.append(data[label_col]) 

    G_pyg.edge_attr = th.tensor(edge_attr_list, dtype=th.float32)
    G_pyg.edge_label = th.tensor(edge_label_list, dtype=th.long)

    print("Number of edges in G_pyg:", num_edges)
    print("Number of node in G_pyg:", num_nodes)
    print("Shape of node in G_pyg:", G_pyg.x.shape)
    print("Shape of edge attr in G_pyg:", G_pyg.edge_attr.shape)
    print("Shape of edge label in G_pyg:", G_pyg.edge_label.shape)

    return G_nx, G_pyg

In [31]:
G_nx_train, G_pyg_train = create_graph(train_df, SOURCE_IP_COL_NAME, DESTINATION_IP_COL_NAME, ['h', label_col], create_using=nx.MultiDiGraph())
G_nx_test, G_pyg_test = create_graph(test_df, SOURCE_IP_COL_NAME, DESTINATION_IP_COL_NAME, ['h', label_col], create_using=nx.MultiDiGraph())

Number of edges in G_pyg: 461685
Number of node in G_pyg: 261376
Shape of node in G_pyg: torch.Size([261376, 51])
Shape of edge attr in G_pyg: torch.Size([461685, 51])
Shape of edge label in G_pyg: torch.Size([461685])
Number of edges in G_pyg: 81474
Number of node in G_pyg: 52521
Shape of node in G_pyg: torch.Size([52521, 51])
Shape of edge attr in G_pyg: torch.Size([81474, 51])
Shape of edge label in G_pyg: torch.Size([81474])


In [32]:
class EGraphSAGEConv(MessagePassing):
    def __init__(self, node_in_channels, edge_in_channels, out_channels):
        super(EGraphSAGEConv, self).__init__(aggr='mean')  # mean aggregation
        self.lin_node = nn.Linear(node_in_channels, out_channels)
        self.lin_edge = nn.Linear(edge_in_channels, out_channels)
        self.lin_update = nn.Linear(node_in_channels + out_channels, out_channels) # out_channels * 2

    def forward(self, x, edge_index, edge_attr):
        # x: Node features, edge_attr: Edge features, edge_index: Connectivity
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        if edge_attr is not None:
            if edge_attr.size(0) != edge_index.size(1):
                loop_attr = th.zeros((edge_index.size(1) - edge_attr.size(0), edge_attr.size(1))).to(edge_attr.device)
                edge_attr = th.cat([edge_attr, loop_attr], dim=0)
        else:
            print("edge_attr is unexist")
        
        # Propagate and aggregate neighbor information
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_j, edge_attr):
        # x_j represents the adjacent nodes of x
        # Compute messages by combining node and edge features
        return self.lin_node(x_j) + self.lin_edge(edge_attr)

    def update(self, aggr_out, x):
        # Update node features after message passing
        return self.lin_update(th.cat([x, aggr_out], dim=1))

class MLPPredictor(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(MLPPredictor, self).__init__()
        self.lin = nn.Linear(in_channels * 2, out_channels)

    def forward(self, data, z):
        row, col = data.edge_index
        # Concatenate the features of source and target nodes for each edge
        edge_feat = th.cat([z[row], z[col]], dim=1)
        return self.lin(edge_feat)

class EGraphSAGE(nn.Module):
    def __init__(self, node_in_channels, edge_in_channels, hidden_channels, out_channels):
        super(EGraphSAGE, self).__init__()
        self.conv1 = EGraphSAGEConv(node_in_channels, edge_in_channels, hidden_channels)
        self.conv2 = EGraphSAGEConv(hidden_channels, edge_in_channels, hidden_channels)
        self.mlp_predictor = MLPPredictor(hidden_channels, out_channels)

    def forward(self, data):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        x = F.relu(self.conv1(x, edge_index, edge_attr))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index, edge_attr)
        return self.mlp_predictor(data, x)

In [33]:
device = th.device("cuda:0" if th.cuda.is_available() else "cpu")
print(device)

cuda:0


In [34]:
th.cuda.empty_cache()

In [35]:


def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        nn.init.constant_(m.bias, 0)



In [26]:
from sklearn.metrics import f1_score
from sklearn.model_selection import StratifiedKFold

def compute_accuracy(pred, labels):
    return (pred.argmax(1) == labels).float().mean().item()

def grid_search(data, epochs, learning_rates, hidden_dims):
    skf = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)
    best_params = {}
    best_f1 = 0

    # Precompute the train and validation graphs for all folds
    folds = []
    for train_idx, val_idx in skf.split(data, data[label_col]):
        train_df = data.iloc[train_idx]
        val_df = data.iloc[val_idx]

        G_nx_train, G_pyg_train = create_graph(train_df, SOURCE_IP_COL_NAME, DESTINATION_IP_COL_NAME, ['h', label_col], create_using=nx.MultiDiGraph())
        G_nx_val, G_pyg_val = create_graph(val_df, SOURCE_IP_COL_NAME, DESTINATION_IP_COL_NAME, ['h', label_col], create_using=nx.MultiDiGraph())

        folds.append((G_pyg_train, G_pyg_val))

    for lr in learning_rates:
        for hidden_dim in hidden_dims:
            print(f"Testing with learning rate: {lr}, hidden_dim: {hidden_dim}")
            fold_f1_scores = []

            for fold, (G_pyg_train, G_pyg_val) in enumerate(folds):
                print(f"Fold {fold + 1}")

                model = EGraphSAGE(node_in_channels=G_pyg_train.num_node_features,
                                   edge_in_channels=G_pyg_train.num_edge_features,
                                   hidden_channels=hidden_dim,
                                   out_channels=num_classes).to(device)

                model.apply(init_weights)

                labels = G_pyg_train.edge_label.cpu().numpy()
                class_weights = class_weight.compute_class_weight('balanced',
                                                                  classes=np.unique(labels),
                                                                  y=labels)

                # Normalize to stabilize training
                class_weights = class_weights / np.mean(class_weights)
                class_weights = th.FloatTensor(class_weights).to(device)
                print("Class weights:", class_weights)

                criterion = nn.CrossEntropyLoss(weight=class_weights)
                optimizer = th.optim.Adam(model.parameters(), lr=lr)

                G_pyg_train.to(device)
                G_pyg_val.to(device)

                G_pyg_train.edge_label = G_pyg_train.edge_label.to(device)
                G_pyg_train.edge_attr = G_pyg_train.edge_attr.to(device)

                G_pyg_val.edge_label = G_pyg_val.edge_label.to(device)
                G_pyg_val.edge_attr = G_pyg_val.edge_attr.to(device)

                best_epoch_f1 = 0  # Track the best F1 score for this fold

                for epoch in range(epochs):
                    train_loss = 0
                    val_loss = 0

                    try:
                        model.train()
                        out = model(G_pyg_train)
                        loss = criterion(out, G_pyg_train.edge_label)
                        train_loss = loss.item()

                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                        model.eval()
                        with th.no_grad():
                            out = model(G_pyg_val)
                            loss = criterion(out, G_pyg_val.edge_label)
                            val_loss = loss.item()

                        val_f1 = f1_score(G_pyg_val.edge_label.cpu(), out.argmax(dim=1).cpu(), average='weighted')

                        if val_f1 > best_epoch_f1:
                            best_epoch_f1 = val_f1  # Update the best F1 score for this fold
                            print(f"Best F1 Score at epoch {epoch}: {best_epoch_f1:.4f}, Parameters: lr={lr}, hidden_dim{hidden_dim}")

                        if epoch % 100 == 0:
                            print(f'Epoch {epoch}, Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, Validation F1: {val_f1:.4f}')

                    except Exception as e:
                        print(f"An error occurred at epoch {epoch}: {str(e)}")
                        break

                fold_f1_scores.append(best_epoch_f1)  # Append the best F1 score for this fold

            avg_f1 = sum(fold_f1_scores) / len(fold_f1_scores)
            print(f"Average F1 Score for learning rate {lr}, hidden_dim {hidden_dim}: {avg_f1:.4f}")

            if avg_f1 > best_f1:
                best_f1 = avg_f1
                best_params = {'learning_rate': lr, 'hidden_dim': hidden_dim}

    print(f"Best Parameters: {best_params}, Best F1 Score: {best_f1:.4f}")


learning_rates = [0.001, 0.005, 0.01]
hidden_dims = [64, 128, 256]

grid_search(train_df, epochs=2000, learning_rates=learning_rates, hidden_dims=hidden_dims)


Number of edges in G_pyg: 307790
Number of node in G_pyg: 182726
Shape of node in G_pyg: torch.Size([182726, 53])
Shape of edge attr in G_pyg: torch.Size([307790, 53])
Shape of edge label in G_pyg: torch.Size([307790])
Number of edges in G_pyg: 153895
Number of node in G_pyg: 96060
Shape of node in G_pyg: torch.Size([96060, 53])
Shape of edge attr in G_pyg: torch.Size([153895, 53])
Shape of edge label in G_pyg: torch.Size([153895])
Number of edges in G_pyg: 307790
Number of node in G_pyg: 182512
Shape of node in G_pyg: torch.Size([182512, 53])
Shape of edge attr in G_pyg: torch.Size([307790, 53])
Shape of edge label in G_pyg: torch.Size([307790])
Number of edges in G_pyg: 153895
Number of node in G_pyg: 96341
Shape of node in G_pyg: torch.Size([96341, 53])
Shape of edge attr in G_pyg: torch.Size([153895, 53])
Shape of edge label in G_pyg: torch.Size([153895])
Number of edges in G_pyg: 307790
Number of node in G_pyg: 182467
Shape of node in G_pyg: torch.Size([182467, 53])
Shape of edge 

In [37]:
# Extract the best parameters from the grid search
best_hidden_dim = 256  # Replace with the best hidden_dim found
best_learning_rate = 0.005  # Replace with the best learning_rate found

# Create the graph for the entire training dataset
G_nx_train_full, G_pyg_train_full = create_graph(train_df, SOURCE_IP_COL_NAME, DESTINATION_IP_COL_NAME, ['h', label_col], create_using=nx.MultiDiGraph())

# Initialize the model with the best parameters
model = EGraphSAGE(node_in_channels=G_pyg_train_full.num_node_features,
                   edge_in_channels=G_pyg_train_full.num_edge_features,
                   hidden_channels=best_hidden_dim,
                   out_channels=num_classes).to(device)

model.apply(init_weights)

# Compute class weights for the training dataset
labels = G_pyg_train_full.edge_label.cpu().numpy()
class_weights = class_weight.compute_class_weight('balanced',
                                                  classes=np.unique(labels),
                                                  y=labels)

# Normalize class weights
class_weights = class_weights / np.mean(class_weights)
class_weights = th.FloatTensor(class_weights).to(device)
print("Class weights:", class_weights)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = th.optim.Adam(model.parameters(), lr=best_learning_rate)

# Move the graph data to the device
G_pyg_train_full.to(device)
G_pyg_train_full.edge_label = G_pyg_train_full.edge_label.to(device)
G_pyg_train_full.edge_attr = G_pyg_train_full.edge_attr.to(device)

best_f1 = 0
best_model_state = None

# Load checkpoint if exists
start_epoch = 0
epochs = 5000

if os.path.exists(checkpoint_path):
    checkpoint = th.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    best_f1 = checkpoint['best_f1']
    print(f"Resumed training from epoch {start_epoch}")

for epoch in range(epochs):
    model.train()
    optimizer.zero_grad()
    out = model(G_pyg_train_full)
    loss = criterion(out, G_pyg_train_full.edge_label)
    loss.backward()
    optimizer.step()

    if epoch % 100 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item():.4f}')

    # Save the best model based on the lowest loss
    if loss.item() < best_f1:  # Here, best_f1 is used to track the lowest loss
        best_f1 = loss.item()
        best_model_state = model.state_dict()
        th.save(best_model_state, best_model_path)
        print("Saved best model. Lowest Loss:", best_f1)

    # Save checkpoint
    th.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'best_f1': best_f1
    }, checkpoint_path)

# Save the trained model
print("Model training completed and saved.")

Number of edges in G_pyg: 461685
Number of node in G_pyg: 261376
Shape of node in G_pyg: torch.Size([261376, 51])
Shape of edge attr in G_pyg: torch.Size([461685, 51])
Shape of edge label in G_pyg: torch.Size([461685])
Class weights: tensor([4.3659e-01, 6.5088e-01, 2.1878e+00, 7.1456e-02, 2.6244e-02, 4.8195e-02,
        5.4228e-03, 5.2665e-03, 8.3543e-02, 7.7355e-01, 6.7111e+00],
       device='cuda:0')
Epoch 0, Loss: 2.7431
Epoch 100, Loss: 0.9523
Epoch 200, Loss: 0.9031
Epoch 300, Loss: 0.8835
Epoch 400, Loss: 0.8941
Epoch 500, Loss: 1.1878
Epoch 600, Loss: 0.8931
Epoch 700, Loss: 0.8834
Epoch 800, Loss: 0.8759
Epoch 900, Loss: 0.8713
Epoch 1000, Loss: 0.8679
Epoch 1100, Loss: 0.8656
Epoch 1200, Loss: 0.8627
Epoch 1300, Loss: 0.8624
Epoch 1400, Loss: 0.8593
Epoch 1500, Loss: 0.8582
Epoch 1600, Loss: 0.9085
Epoch 1700, Loss: 0.8568
Epoch 1800, Loss: 0.8554
Epoch 1900, Loss: 0.8553
Epoch 2000, Loss: 0.8560
Epoch 2100, Loss: 1.7647
Epoch 2200, Loss: 0.9908
Epoch 2300, Loss: 0.9090
Epoch

In [45]:
th.save(model.state_dict(), best_model_path)

In [41]:
from torch_geometric.utils import subgraph
from torch_geometric.data import Data
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report

def compute_accuracy(pred, labels):
    return (pred.argmax(1) == labels).float().mean().item()

def eval(model, dataframe, adversarial=False):
    G_nx_test = nx.from_pandas_edgelist(dataframe, SOURCE_IP_COL_NAME, DESTINATION_IP_COL_NAME, ['h', label_col], create_using=nx.MultiDiGraph())

    G_pyg_test = from_networkx(G_nx_test)

    test_num_nodes = G_pyg_test.num_nodes
    test_num_edges = G_pyg_test.num_edges

    G_pyg_test.x = th.ones(test_num_nodes, len(test_df['h'].iloc[0]))

    test_edge_attr_list = []
    test_edge_label_list = []

    for u, v, key, data in G_nx_test.edges(keys=True, data=True):
        test_edge_attr_list.append(data['h']) 
        test_edge_label_list.append(data[label_col]) 

    G_pyg_test.edge_attr = th.tensor(test_edge_attr_list, dtype=th.float32)
    G_pyg_test.edge_label = th.tensor(test_edge_label_list, dtype=th.long)

    G_pyg_test = G_pyg_test.to(device)
    G_pyg_test.edge_label = G_pyg_test.edge_label.to(device)
    G_pyg_test.edge_attr = G_pyg_test.edge_attr.to(device)

    print("Number of edges in G_pyg_test:", G_pyg_test.num_edges)
    print("Number of node in G_pyg_test:", G_pyg_test.num_nodes)
    print("Shape of node in G_pyg_test:", G_pyg_test.x.shape)
    print("Shape of edge attr in G_pyg_test:", G_pyg_test.edge_attr.shape)
    print("Shape of edge label in G_pyg_test:", G_pyg_test.edge_label.shape)

    model.eval()

    print("inference start")
    with th.no_grad():
            
        try:
            out = model(G_pyg_test)
            
        except Exception as forward_error:
            print(f"Error during forward/backward pass at {forward_error}")

    print("inference done")

    test_accuracy = compute_accuracy(out, G_pyg_test.edge_label)
    print(f'Test Accuracy: {test_accuracy:.4f}')
    
    pred_labels = out.argmax(dim=1).cpu()
    all_test_labels = G_pyg_test.edge_label.cpu()

    
    global class_map
    class_map_2 = class_map
    if adversarial:
        class_map_2 = np.append(class_map, "Adversarial")

    # Generate a report
    cm = confusion_matrix(all_test_labels, pred_labels, labels=range(len(class_map_2)))
    print(cm)
    report = classification_report(all_test_labels, pred_labels, target_names=class_map_2, digits=4)
    print(report)

eval(model, test_df)


Number of edges in G_pyg_test: 81474
Number of node in G_pyg_test: 52521
Shape of node in G_pyg_test: torch.Size([52521, 51])
Shape of edge attr in G_pyg_test: torch.Size([81474, 51])
Shape of edge label in G_pyg_test: torch.Size([81474])
inference start
inference done
Test Accuracy: 0.8823
[[   60   126    57   116     2     0     0     0     8     0    33]
 [    1   132     3   101     6     1     0     0     1    20     4]
 [    1     3    71     0     1     0     0     0     0     3     1]
 [   37   956   108   875   166    18    12     3    34    60   184]
 [  377  1188   213  1212  1964   176    28     4   232   199  1086]
 [   32   254   147   123    91  2843     4     4    16     1   122]
 [   60   152    15   149   181     8 31454     2    39    35   227]
 [   21    41     2     1    11   389     0 32774     7     3    32]
 [   48   178    10   136    31     0     4     0  1472     6   213]
 [    0     1     1     0     1     0     0     0     0   224     0]
 [    6     0     

In [42]:
def attack_attacker(dataframe, ratio, num_injected_nodes):
    attack_eval = dataframe[dataframe[label_col] != class_dict[UNSW_NB15_Config.BENIGN_CLASS_NAME]]
    num_injected = int(ratio * len(attack_eval))

    # Sample attack rows
    sampled_attack_flows = attack_eval.sample(n=num_injected, random_state=42).copy().reset_index(drop=True)
    
    injected_rows = sampled_attack_flows.copy()
    print("Sampled attack flows:", len(sampled_attack_flows))
    print("Labels of sampled attack flows:", sampled_attack_flows[label_col].value_counts())

    node_ips = [f"192.168.1.{i+1}" for i in range(num_injected_nodes)]
    injected_rows[UNSW_NB15_Config.DESTINATION_IP_COL_NAME] = injected_rows[UNSW_NB15_Config.SOURCE_IP_COL_NAME] # Target the Real Attacker Nodes
    injected_rows[UNSW_NB15_Config.SOURCE_IP_COL_NAME] = [f"{node_ips[i % len(node_ips)]}:{random.randint(1024, 65535)}" for i in range(num_injected)]
    # injected_rows['pkSeqID'] = [f'Injected-{i}' for i in range(num_injected)]
    injected_rows[label_col] = len(class_map) # Assign a new class for injected samples
    print(injected_rows[0:5])

    # Append and reorder
    combined_df = pd.concat([dataframe, injected_rows], ignore_index=True)

    return combined_df

# Inject adversarial samples
attack_attacker_df = attack_attacker(test_df, 0.1, num_injected_nodes=1)
eval(model, attack_attacker_df, adversarial=True)

Sampled attack flows: 4819
Labels of sampled attack flows: attack_cat
6     3230
4      675
5      347
3      236
8      217
0       43
1       30
9       28
2        7
10       6
Name: count, dtype: int64
               srcip               dstip       dur    sbytes    dbytes  \
0  192.168.1.1:62742  175.45.176.1:47439 -0.055926 -0.041774 -0.139253   
1  192.168.1.1:35952  175.45.176.2:37472  0.009629 -0.037316 -0.137327   
2  192.168.1.1:25978   175.45.176.1:1043 -0.055927 -0.041774 -0.139253   
3  192.168.1.1:43514      175.45.176.0:0 -0.055926 -0.041059 -0.139253   
4   192.168.1.1:8026   175.45.176.3:1043 -0.055927 -0.041774 -0.139253   

       sttl      dttl     sloss     dloss     Sload     Dload     Spkts  \
0  0.892582 -0.504331 -0.083318 -0.174121 -0.083775 -0.367084 -0.180233   
1  0.892582  2.766939 -0.039474 -0.154176 -0.555115 -0.366392 -0.101742   
2  0.892582 -0.504331 -0.083318 -0.174121  3.215930 -0.367084 -0.180233   
3  0.892582 -0.504331 -0.083318 -0.174121  0.1799

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [43]:
def normalise_attacker(dataframe, ratio, num_injected_nodes):

    normal_eval = dataframe[dataframe[label_col] == class_dict[UNSW_NB15_Config.BENIGN_CLASS_NAME]]
    attack_eval = dataframe[dataframe[label_col] != class_dict[UNSW_NB15_Config.BENIGN_CLASS_NAME]]
    print("Normal Flows:", len(normal_eval))
    print("Attack Flows:", len(attack_eval))
    num_injected = int(ratio * len(attack_eval))


    sampled_normal_flows = normal_eval.sample(n=num_injected, random_state=42).copy().reset_index(drop=True)
    sampled_attack_flows = attack_eval.sample(n=num_injected, random_state=42).copy().reset_index(drop=True)

    print("Sampled attack flows:", len(sampled_attack_flows))
    print("Labels of sampled attack flows:", sampled_attack_flows[label_col].value_counts())

    injected_rows = sampled_normal_flows.copy()
    node_ips = [f"192.168.1.{i+1}" for i in range(num_injected_nodes)]
    injected_rows[UNSW_NB15_Config.DESTINATION_IP_COL_NAME] = sampled_attack_flows[UNSW_NB15_Config.SOURCE_IP_COL_NAME] # Direct BENGIN Traffic to the Real Attacker Nodes
    injected_rows[UNSW_NB15_Config.SOURCE_IP_COL_NAME] = [f"{node_ips[i % len(node_ips)]}:{random.randint(1024, 65535)}" for i in range(num_injected)]
    injected_rows[label_col] = len(class_map)
    print(injected_rows[0:5])

    combined_df = pd.concat([dataframe, injected_rows], ignore_index=True)

    return combined_df

# Inject adversarial samples
normalise_attacker_df = normalise_attacker(test_df, 0.1, 1)
eval(model, normalise_attacker_df, adversarial=True)

Normal Flows: 33281
Attack Flows: 48193
Sampled attack flows: 4819
Labels of sampled attack flows: attack_cat
6     3230
4      675
5      347
3      236
8      217
0       43
1       30
9       28
2        7
10       6
Name: count, dtype: int64
               srcip               dstip       dur    sbytes    dbytes  \
0  192.168.1.1:50827  175.45.176.1:47439  0.036655 -0.028716 -0.066151   
1  192.168.1.1:16443  175.45.176.2:37472 -0.024538 -0.029914 -0.127434   
2  192.168.1.1:62621   175.45.176.1:1043 -0.055926 -0.040527 -0.139253   
3  192.168.1.1:39996      175.45.176.0:0  0.378282  0.166541  7.645416   
4  192.168.1.1:14725   175.45.176.3:1043 -0.053969 -0.018320 -0.112350   

       sttl      dttl     sloss     dloss     Sload     Dload     Spkts  \
0 -1.163620 -0.127875 -0.017552 -0.074393 -0.555073 -0.346088 -0.062497   
1 -1.163620 -0.127875  0.004370 -0.094338 -0.554921 -0.357070 -0.042875   
2 -0.896222 -0.504331 -0.083318 -0.174121 -0.009346 -0.367084 -0.180233   
3 -1.1636

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [44]:
def random_connection(dataframe, ratio, num_injected_nodes):

    normal_eval = dataframe[dataframe[label_col] == class_dict[UNSW_NB15_Config.BENIGN_CLASS_NAME]]

    num_injected = int(ratio * len(dataframe))
    print("injected rows:", num_injected)

    sampled_normal_flows = normal_eval.sample(n=num_injected, random_state=42).copy().reset_index(drop=True)

    injected_rows = sampled_normal_flows.copy()
    node_ips = [f"192.168.1.{i+1}" for i in range(num_injected_nodes)]
    injected_rows[UNSW_NB15_Config.DESTINATION_IP_COL_NAME] = [node_ips[i % len(node_ips)] for i in range(num_injected)]
    injected_rows[UNSW_NB15_Config.SOURCE_IP_COL_NAME] = [node_ips[(i + 1) % len(node_ips)] for i in range(num_injected)]
    injected_rows[label_col] = len(class_map)

    combined_df = pd.concat([dataframe, injected_rows], ignore_index=True)

    return combined_df

# Inject adversarial samples
random_connection_df = random_connection(test_df, 0.1, 1)
eval(model, random_connection_df, adversarial=True)

injected rows: 8147
Number of edges in G_pyg_test: 89621
Number of node in G_pyg_test: 52522
Shape of node in G_pyg_test: torch.Size([52522, 51])
Shape of edge attr in G_pyg_test: torch.Size([89621, 51])
Shape of edge label in G_pyg_test: torch.Size([89621])
inference start
inference done
Test Accuracy: 0.8021
[[   60   126    57   116     2     0     0     0     8     0    33     0]
 [    1   132     3   101     6     1     0     0     1    20     4     0]
 [    1     3    71     0     1     0     0     0     0     3     1     0]
 [   37   956   108   875   166    18    12     3    34    60   184     0]
 [  377  1188   213  1212  1964   176    28     4   232   199  1086     0]
 [   32   254   147   123    91  2843     4     4    16     1   122     0]
 [   60   152    15   149   181     8 31454     2    39    35   227     0]
 [   21    41     2     1    11   389     0 32774     7     3    32     0]
 [   48   178    10   136    31     0     4     0  1472     6   213     0]
 [    0     1

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
