In [47]:
from torch_geometric.utils import from_networkx, add_self_loops, degree
from torch_geometric.nn import MessagePassing
import torch_geometric.transforms as T
from torch_geometric.data import DataLoader
from torch_geometric.loader import NeighborSampler
import torch.nn as nn
import torch as th
import torch.nn.functional as F
# import dgl.function as fn
import networkx as nx
import pandas as pd
import socket
import struct
import random
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.utils import class_weight
# import seaborn as sns
# import matplotlib.pyplot as plt
import numpy as np
import os
import sys

project_root = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))
sys.path.append(project_root)

from Datasets.UNSW_NB15.UNSW_NB15_config import UNSW_NB15_Config

In [48]:
csv_file_name = "all_raw_downsampled"

data = pd.read_csv(os.path.join(project_root, "Datasets", f"UNSW_NB15/All/{csv_file_name}.csv"))

DATASET_NAME = "UNSW_NB15"

SOURCE_IP_COL_NAME = UNSW_NB15_Config.SOURCE_IP_COL_NAME
DESTINATION_IP_COL_NAME = UNSW_NB15_Config.DESTINATION_IP_COL_NAME
SOURCE_PORT_COL_NAME = UNSW_NB15_Config.SOURCE_PORT_COL_NAME
DESTINATION_PORT_COL_NAME = UNSW_NB15_Config.DESTINATION_PORT_COL_NAME

ATTACK_CLASS_COL_NAME = UNSW_NB15_Config.ATTACK_CLASS_COL_NAME
IS_ATTACK_COL_NAME = UNSW_NB15_Config.IS_ATTACK_COL_NAME

BENIGN_CLASS_NAME = UNSW_NB15_Config.BENIGN_CLASS_NAME

TIME_COLS = UNSW_NB15_Config.TIME_COL_NAMES

MULTICLASS = True
label_col = ATTACK_CLASS_COL_NAME if MULTICLASS else IS_ATTACK_COL_NAME

print(data[ATTACK_CLASS_COL_NAME].value_counts())
print(data[IS_ATTACK_COL_NAME].value_counts())

if MULTICLASS:
    data.drop(columns=[IS_ATTACK_COL_NAME], inplace=True)
else:
    data.drop(columns=[ATTACK_CLASS_COL_NAME], inplace=True)

checkpoint_path = os.path.join(project_root, "Models/E_GraphSAGE/logs", DATASET_NAME, f"main_window/checkpoints_{csv_file_name}.pth")
best_model_path = os.path.join(project_root, "Models/E_GraphSAGE/logs", DATASET_NAME, f"main_window/best_model_{csv_file_name}.pth")

os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
os.makedirs(os.path.dirname(best_model_path), exist_ok=True)

# Compute the class weights
CLASS_WEIGHTS = class_weight.compute_class_weight(
    class_weight='balanced',
    classes=np.unique(data[label_col]),
    y=data[label_col]
)

print("Class weights:", CLASS_WEIGHTS)

  data = pd.read_csv(os.path.join(project_root, "Datasets", f"UNSW_NB15/All/{csv_file_name}.csv"))


attack_cat
Normal            221876
Generic           215481
Exploits           44525
Fuzzers            24246
DoS                16353
Reconnaissance     13987
Analysis            2677
Backdoor            1795
Shellcode           1511
Backdoors            534
Worms                174
Name: count, dtype: int64
label
1    321283
0    221876
Name: count, dtype: int64
Class weights: [1.84453085e+01 2.75086857e+01 9.24683350e+01 3.01951268e+00
 1.10899699e+00 2.03654586e+00 2.29152876e-01 2.22548139e-01
 3.53028461e+00 3.26790807e+01 2.83782132e+02]


In [49]:
data.drop(columns=UNSW_NB15_Config.DROP_COLS,inplace=True)
print(data.columns)

Index(['srcip', 'sport', 'dstip', 'dsport', 'state', 'dur', 'sbytes', 'dbytes',
       'sttl', 'dttl', 'sloss', 'dloss', 'Sload', 'Dload', 'Spkts', 'Dpkts',
       'swin', 'dwin', 'stcpb', 'dtcpb', 'smeansz', 'dmeansz', 'trans_depth',
       'res_bdy_len', 'Sjit', 'Djit', 'Stime', 'Ltime', 'Sintpkt', 'Dintpkt',
       'tcprtt', 'synack', 'ackdat', 'is_sm_ips_ports', 'ct_state_ttl',
       'ct_flw_http_mthd', 'is_ftp_login', 'ct_ftp_cmd', 'ct_srv_src',
       'ct_srv_dst', 'ct_dst_ltm', 'ct_src_ltm', 'ct_src_dport_ltm',
       'ct_dst_sport_ltm', 'ct_dst_src_ltm', 'attack_cat'],
      dtype='object')


In [50]:
data[SOURCE_IP_COL_NAME] = data[SOURCE_IP_COL_NAME].apply(str)
data[DESTINATION_IP_COL_NAME] = data[DESTINATION_IP_COL_NAME].apply(str)

# # Combine Port and IP
data[SOURCE_PORT_COL_NAME] = data[SOURCE_PORT_COL_NAME].apply(str)
data[DESTINATION_PORT_COL_NAME] = data[DESTINATION_PORT_COL_NAME].apply(str)

data[SOURCE_IP_COL_NAME] = data[SOURCE_IP_COL_NAME] + ':' + data[SOURCE_PORT_COL_NAME]
data[DESTINATION_IP_COL_NAME] = data[DESTINATION_IP_COL_NAME] + ':' + data[DESTINATION_PORT_COL_NAME]
data.drop(columns=[SOURCE_PORT_COL_NAME,DESTINATION_PORT_COL_NAME],inplace=True)

# data[SOURCE_PORT_COL_NAME] = pd.to_numeric(data[SOURCE_PORT_COL_NAME], errors='coerce').fillna(0).astype(int)
# data[DESTINATION_PORT_COL_NAME] = pd.to_numeric(data[DESTINATION_PORT_COL_NAME], errors='coerce').fillna(0).astype(int)

In [51]:
print(data.head)

<bound method NDFrame.head of                      srcip               dstip state        dur  sbytes  \
0             10.40.85.1:0         224.0.0.5:0   INT  50.004341     384   
1          59.166.0.6:2142    149.171.126.4:53   CON   0.001134     132   
2       175.45.176.0:13284   149.171.126.16:80   FIN   2.390390    1362   
3         59.166.0.3:42587    149.171.126.8:25   FIN  34.077175   37358   
4            10.40.170.2:0       10.40.170.2:0   INT   0.000000      46   
...                    ...                 ...   ...        ...     ...   
543154  175.45.176.0:47439   149.171.126.10:53   INT   0.000001     114   
543155  175.45.176.0:47439   149.171.126.10:53   INT   0.000001     114   
543156    59.166.0.5:53521    149.171.126.7:21   CON   1.086072    1940   
543157  175.45.176.0:17293  149.171.126.17:110   CON   0.942984     574   
543158  175.45.176.0:17293  149.171.126.17:110   CON   0.942984     574   

        dbytes  sttl  dttl  sloss  dloss  ...  is_ftp_login  ct_ftp_c

In [52]:
data = pd.get_dummies(data, columns = UNSW_NB15_Config.CATEGORICAL_COLS) # One Hot Encoding for categorical data
converted_categorical_cols = [col for col in data.columns if col.startswith(tuple(UNSW_NB15_Config.CATEGORICAL_COLS))]

In [53]:
data = data.reset_index()
data.replace([np.inf, -np.inf], np.nan,inplace = True)
data.fillna(0,inplace = True)
data.drop(columns=['index'],inplace=True)
print(data.head)

<bound method NDFrame.head of                      srcip               dstip        dur  sbytes  dbytes  \
0             10.40.85.1:0         224.0.0.5:0  50.004341     384       0   
1          59.166.0.6:2142    149.171.126.4:53   0.001134     132     164   
2       175.45.176.0:13284   149.171.126.16:80   2.390390    1362     268   
3         59.166.0.3:42587    149.171.126.8:25  34.077175   37358    3380   
4            10.40.170.2:0       10.40.170.2:0   0.000000      46       0   
...                    ...                 ...        ...     ...     ...   
543154  175.45.176.0:47439   149.171.126.10:53   0.000001     114       0   
543155  175.45.176.0:47439   149.171.126.10:53   0.000001     114       0   
543156    59.166.0.5:53521    149.171.126.7:21   1.086072    1940    2404   
543157  175.45.176.0:17293  149.171.126.17:110   0.942984     574     676   
543158  175.45.176.0:17293  149.171.126.17:110   0.942984     574     676   

        sttl  dttl  sloss  dloss         Sloa

In [54]:
scaler = StandardScaler()
cols_to_norm = UNSW_NB15_Config.COLS_TO_NORM
print(data[cols_to_norm].describe()) # Check if there's any too large value

                 dur        sbytes        dbytes           sttl  \
count  543159.000000  5.431590e+05  5.431590e+05  543159.000000   
mean        0.706760  5.136572e+03  1.936909e+04     157.197364   
std        12.637229  1.202311e+05  1.390925e+05     108.452474   
min         0.000000  0.000000e+00  0.000000e+00       0.000000   
25%         0.000007  1.140000e+02  0.000000e+00      31.000000   
50%         0.000011  2.000000e+02  0.000000e+00     254.000000   
75%         0.072088  1.580000e+03  1.940000e+03     254.000000   
max      8760.776367  1.435577e+07  1.465753e+07     255.000000   

                dttl          sloss          dloss         Sload  \
count  543159.000000  543159.000000  543159.000000  5.431590e+05   
mean       38.850764       3.800661       8.729770  6.877595e+07   
std        77.034389      45.616565      50.136204  1.420534e+08   
min         0.000000       0.000000       0.000000  0.000000e+00   
25%         0.000000       0.000000       0.000000  3.70

In [55]:
def check_numeric_issues(df, cols_to_norm):
    for col in cols_to_norm:
        try:
            # Try to coerce to numeric
            df[col] = pd.to_numeric(df[col], errors='coerce')
            
            # Try to clip the column
            df[col] = df[col].clip(lower=-1e9, upper=1e9)
            
        except Exception as e:
            print(f"❌ Column '{col}' failed with error: {e}")
            print(f"  - Sample values: {df[col].dropna().unique()[:5]}")
            print(f"  - Data type: {df[col].dtype}")
            continue

    print("\n✅ All other columns processed successfully.")

check_numeric_issues(data, UNSW_NB15_Config.COLS_TO_NORM)


✅ All other columns processed successfully.


In [56]:
data[cols_to_norm] = scaler.fit_transform(data[cols_to_norm])

In [57]:
from sklearn.preprocessing import LabelEncoder

num_classes = 2
class_map = [0, 1]
if MULTICLASS:
    le = LabelEncoder()
    attack_labels = le.fit_transform(data[ATTACK_CLASS_COL_NAME])
    class_map = le.classes_
    print(class_map)
    print("Attack label mapping:", dict(zip(class_map, range(len(class_map)))))
    data[ATTACK_CLASS_COL_NAME] = attack_labels
    num_classes = len(class_map)
    class_dict = {le.inverse_transform([i])[0]: i for i in range(len(le.classes_))}


['Analysis' 'Backdoor' 'Backdoors' 'DoS' 'Exploits' 'Fuzzers' 'Generic'
 'Normal' 'Reconnaissance' 'Shellcode' 'Worms']
Attack label mapping: {'Analysis': 0, 'Backdoor': 1, 'Backdoors': 2, 'DoS': 3, 'Exploits': 4, 'Fuzzers': 5, 'Generic': 6, 'Normal': 7, 'Reconnaissance': 8, 'Shellcode': 9, 'Worms': 10}


In [58]:
# # Maintain the order of the rows in the original dataframe

feature_cols = UNSW_NB15_Config.COLS_TO_NORM + converted_categorical_cols

print('Feature Columns:', feature_cols)
num_features = len(feature_cols)
print('Number of Features:', num_features)

data['h'] = data[ feature_cols ].values.tolist()

Feature Columns: ['dur', 'sbytes', 'dbytes', 'sttl', 'dttl', 'sloss', 'dloss', 'Sload', 'Dload', 'Spkts', 'Dpkts', 'swin', 'dwin', 'stcpb', 'dtcpb', 'smeansz', 'dmeansz', 'trans_depth', 'res_bdy_len', 'Sjit', 'Djit', 'Sintpkt', 'Dintpkt', 'tcprtt', 'synack', 'ackdat', 'is_sm_ips_ports', 'ct_state_ttl', 'ct_flw_http_mthd', 'is_ftp_login', 'ct_ftp_cmd', 'ct_srv_src', 'ct_srv_dst', 'ct_dst_ltm', 'ct_src_ltm', 'ct_src_dport_ltm', 'ct_dst_sport_ltm', 'ct_dst_src_ltm', 'state_ACC', 'state_CLO', 'state_CON', 'state_ECO', 'state_FIN', 'state_INT', 'state_MAS', 'state_PAR', 'state_REQ', 'state_RST', 'state_TST', 'state_TXD', 'state_URH']
Number of Features: 51


In [59]:
def create_graph(df, source_ip_col, destination_ip_col, edge_attr, create_using=nx.MultiDiGraph(), **kwargs):
    G_nx = nx.from_pandas_edgelist(df, source_ip_col, destination_ip_col, edge_attr, create_using=create_using, **kwargs)
    G_pyg = from_networkx(G_nx)

    num_nodes = G_pyg.num_nodes
    num_edges = G_pyg.num_edges

    G_pyg.x = th.ones(num_nodes, len(df['h'].iloc[0])) 

    edge_attr_list = []
    edge_label_list = []

    for u, v, key, data in G_nx.edges(keys=True, data=True):
        edge_attr_list.append(data['h']) 
        edge_label_list.append(data[label_col]) 

    G_pyg.edge_attr = th.tensor(edge_attr_list, dtype=th.float32)
    G_pyg.edge_label = th.tensor(edge_label_list, dtype=th.long)

    return G_pyg

In [60]:
from collections import defaultdict

from tqdm import tqdm

class StratifiedGraphDataset:

    def __init__(self, graphs_by_label):
        self.graphs_by_label = graphs_by_label
        class_counts = {label: len(graphs) for label, graphs in self.graphs_by_label.items()}
        self.total_count = sum(class_counts.values())
        
    def graph_train_test_split(self, test_ratio: float = 0.15, random_state: int = 42):
        self.train_graphs_by_label = defaultdict(list)
        self.test_graphs_by_label = defaultdict(list)

        for label, graphs in self.graphs_by_label.items():
            train_graphs, test_graphs = train_test_split(
                graphs, test_size=test_ratio, random_state=random_state
            )
            self.train_graphs_by_label[label] = train_graphs
            self.test_graphs_by_label[label] = test_graphs

        return StratifiedGraphDataset(self.train_graphs_by_label), StratifiedGraphDataset(self.test_graphs_by_label)
    def __len__(self):
        return self.total_count

    def __iter__(self):
        # Calculate class weights and yield according to the weights

        all_graphs = [g for graphs in self.graphs_by_label.values() for g in graphs]
        random.shuffle(all_graphs)
        for g in all_graphs:
            yield g


def generate_graph_datasets(
    df: pd.DataFrame, 
    window_size: int = 2000, 
    overlap_ratio: float = 0.5, 
    time_cols=TIME_COLS, 
    label_col=label_col,
    build_graph_func=create_graph,
    build_graph_func_kwargs: dict = None):

    print("Feature Columns", df.columns)
    print("Time Columns", time_cols)
    assert all(col in df.columns for col in time_cols), "All timestamp columns are required"
    assert label_col in df.columns, "Edge label column 'label' is required"
    
    df = df.sort_values(time_cols).reset_index(drop=True)
    df.drop(columns=time_cols, inplace=True)
    window_size = window_size
    stride = int(window_size * (1 - overlap_ratio))
    
    graphs_by_label = defaultdict(list)
    
    for start in tqdm(range(0, len(df) - window_size + 1, stride), desc="Generating graphs"):

        window_df = df.iloc[start: start + window_size]
        contains_label = window_df[label_col].unique()
        G_pyg = build_graph_func(window_df, **build_graph_func_kwargs)

        for label in contains_label:
            graphs_by_label[label].append(G_pyg)

    return StratifiedGraphDataset(graphs_by_label)

In [61]:
graph_dataset = generate_graph_datasets(data, 
                                        overlap_ratio=0,
                                        build_graph_func_kwargs={
                                            'source_ip_col': SOURCE_IP_COL_NAME, 
                                            'destination_ip_col': DESTINATION_IP_COL_NAME, 
                                            'edge_attr': ['h', label_col], 
                                            'create_using': nx.MultiDiGraph()})
train_graph_dataset, test_graph_dataset = graph_dataset.graph_train_test_split(test_ratio=0.15, random_state=42)

Feature Columns Index(['srcip', 'dstip', 'dur', 'sbytes', 'dbytes', 'sttl', 'dttl', 'sloss',
       'dloss', 'Sload', 'Dload', 'Spkts', 'Dpkts', 'swin', 'dwin', 'stcpb',
       'dtcpb', 'smeansz', 'dmeansz', 'trans_depth', 'res_bdy_len', 'Sjit',
       'Djit', 'Stime', 'Ltime', 'Sintpkt', 'Dintpkt', 'tcprtt', 'synack',
       'ackdat', 'is_sm_ips_ports', 'ct_state_ttl', 'ct_flw_http_mthd',
       'is_ftp_login', 'ct_ftp_cmd', 'ct_srv_src', 'ct_srv_dst', 'ct_dst_ltm',
       'ct_src_ltm', 'ct_src_dport_ltm', 'ct_dst_sport_ltm', 'ct_dst_src_ltm',
       'attack_cat', 'state_ACC', 'state_CLO', 'state_CON', 'state_ECO',
       'state_FIN', 'state_INT', 'state_MAS', 'state_PAR', 'state_REQ',
       'state_RST', 'state_TST', 'state_TXD', 'state_URH', 'h'],
      dtype='object')
Time Columns ['Stime', 'Ltime']


Generating graphs: 100%|██████████| 271/271 [00:22<00:00, 12.11it/s]


In [62]:
class SAGELayerPyG(MessagePassing):
    def __init__(self, in_channels, edge_dim, out_channels, activation=F.relu):
        super().__init__(aggr='mean')  # mean aggregation
        self.W_msg = nn.Linear(in_channels + edge_dim, out_channels)
        self.W_apply = nn.Linear(in_channels + out_channels, out_channels)
        self.activation = activation

    def forward(self, x, edge_index, edge_attr):
        # x: [num_nodes, in_channels]
        # edge_attr: [num_edges, edge_dim]
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_j, edge_attr):
        # x_j: features of source nodes (neighbours)
        msg_input = th.cat([x_j, edge_attr], dim=1)
        return self.W_msg(msg_input)

    def update(self, aggr_out, x):
        # aggr_out: [num_nodes, out_channels]
        combined = th.cat([x, aggr_out], dim=1)
        out = self.W_apply(combined)
        return self.activation(out)
    
class MLPPredictor(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(MLPPredictor, self).__init__()
        self.lin = nn.Linear(in_channels * 2, out_channels)

    def forward(self, data, z):
        row, col = data.edge_index
        # Concatenate the features of source and target nodes for each edge
        edge_feat = th.cat([z[row], z[col]], dim=1)
        return self.lin(edge_feat)

class EGraphSAGE(nn.Module):
    def __init__(self, node_in_channels, edge_in_channels, hidden_channels, out_channels):
        super(EGraphSAGE, self).__init__()
        self.conv1 = SAGELayerPyG(node_in_channels, edge_in_channels, hidden_channels)
        self.conv2 = SAGELayerPyG(hidden_channels, edge_in_channels, hidden_channels)
        self.mlp_predictor = MLPPredictor(hidden_channels, out_channels)

    def forward(self, data):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        x = F.relu(self.conv1(x, edge_index, edge_attr))
        x = F.dropout(x, p=0.2)
        x = self.conv2(x, edge_index, edge_attr)
        return self.mlp_predictor(data, x)

In [63]:
device = th.device("cuda:0" if th.cuda.is_available() else "cpu")
print(device)

cuda:0


In [64]:
th.cuda.empty_cache()

In [65]:


def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        nn.init.constant_(m.bias, 0)



In [66]:
from sklearn.metrics import f1_score
from sklearn.model_selection import StratifiedKFold

def compute_accuracy(pred, labels):
    return (pred.argmax(1) == labels).float().mean().item()

def grid_search(train_graph_dataset, epochs, learning_rates, hidden_dims, folds=3):
    global CLASS_WEIGHTS
    global num_features
    
    best_params = {}
    best_f1 = 0

    # Precompute the train and validation graphs for all folds
    folds_list = []
    for i in range(folds):
        train_graph_dataset, val_graph_dataset = train_graph_dataset.graph_train_test_split(test_ratio=0.15, random_state=i)
        folds_list.append((train_graph_dataset, val_graph_dataset))

    params_results = {}
    for lr in learning_rates:
        for hidden_dim in hidden_dims:
            print(f"Testing with learning rate: {lr}, hidden_dim: {hidden_dim}")
            fold_f1_scores = []

            for fold, (train_graph_dataset, val_graph_dataset) in enumerate(folds_list):
                print(f"Fold {fold + 1}")

                model = EGraphSAGE(node_in_channels=num_features,
                                   edge_in_channels=num_features,
                                   hidden_channels=hidden_dim,
                                   out_channels=num_classes).to(device)

                model.apply(init_weights)

                # Normalize to stabilize training
                class_weights = th.FloatTensor(CLASS_WEIGHTS).to(device)
                print("Class weights:", class_weights)

                criterion = nn.CrossEntropyLoss(weight=class_weights)
                optimizer = th.optim.Adam(model.parameters(), lr=lr)

                best_epoch_f1 = 0  # Track the best F1 score for this fold

                for epoch in range(epochs):
                    try:
                        total_train_loss = 0
                        total_val_loss = 0

                        model.train()
                        for G_pyg_train in tqdm(train_graph_dataset, desc="Training", leave=False):
                            G_pyg_train = G_pyg_train.to(device)

                            G_pyg_train.edge_label = G_pyg_train.edge_label.to(device)
                            G_pyg_train.edge_attr = G_pyg_train.edge_attr.to(device)
                            
                            out = model(G_pyg_train)
                            loss = criterion(out, G_pyg_train.edge_label)
                            total_train_loss += loss.item()

                            optimizer.zero_grad()
                            loss.backward()

                            optimizer.step()
                        
                        model.eval()
                        total_f1 = 0
                        with th.no_grad():
                            for G_pyg_val in tqdm(val_graph_dataset, desc="Validation", leave=False):

                                G_pyg_val.to(device)
                                G_pyg_val.edge_label = G_pyg_val.edge_label.to(device)
                                G_pyg_val.edge_attr = G_pyg_val.edge_attr.to(device)

                                out = model(G_pyg_val)
                                loss = criterion(out, G_pyg_val.edge_label)
                                total_val_loss += loss.item()

                                total_f1 += f1_score(G_pyg_val.edge_label.cpu(), out.argmax(dim=1).cpu(), average='weighted')

                        avg_f1 = total_f1 / len(val_graph_dataset)
                        if avg_f1 > best_epoch_f1:
                            best_epoch_f1 = avg_f1  # Update the best F1 score for this fold
                            print(f"Best F1 Score at epoch {epoch}: {best_epoch_f1:.4f}, Train Loss: {total_train_loss:.4f}, Validation Loss: {total_val_loss:.4f}, Parameters: lr={lr}, hidden_dim{hidden_dim}")

                        print(f'Epoch {epoch}, Train Loss: {total_train_loss:.4f}, Validation Loss: {total_val_loss:.4f}, Validation F1: {avg_f1:.4f}')

                    except Exception as e:
                        print(f"An error occurred at epoch {epoch}: {str(e)}")
                        break

                fold_f1_scores.append(best_epoch_f1)  # Append the best F1 score for this fold
            
            avg_f1 = sum(fold_f1_scores) / len(fold_f1_scores)
            params_results[(lr, hidden_dim)] = {'folds': fold_f1_scores, 'avg_f1': avg_f1}
            print(f"Average F1 Score for learning rate {lr}, hidden_dim {hidden_dim}: {avg_f1:.4f}")

            if avg_f1 > best_f1:
                best_f1 = avg_f1
                best_params = {'learning_rate': lr, 'hidden_dim': hidden_dim}

    print(f"Best Parameters: {best_params}, Best F1 Score: {best_f1:.4f}")
    print("All results:", params_results)


learning_rates = [0.005]
hidden_dims = [256, 512, 1024]

grid_search(train_graph_dataset, epochs=10, learning_rates=learning_rates, hidden_dims=hidden_dims)


Testing with learning rate: 0.005, hidden_dim: 256
Fold 1
Class weights: tensor([1.8445e+01, 2.7509e+01, 9.2468e+01, 3.0195e+00, 1.1090e+00, 2.0365e+00,
        2.2915e-01, 2.2255e-01, 3.5303e+00, 3.2679e+01, 2.8378e+02],
       device='cuda:0')


Training:   0%|          | 0/1445 [00:00<?, ?it/s]

                                                              

Best F1 Score at epoch 0: 0.8396, Train Loss: 1813.2893, Validation Loss: 233.7888, Parameters: lr=0.005, hidden_dim256
Epoch 0, Train Loss: 1813.2893, Validation Loss: 233.7888, Validation F1: 0.8396


                                                              

Epoch 1, Train Loss: 1521.4014, Validation Loss: 241.9680, Validation F1: 0.8391


                                                              

Epoch 2, Train Loss: 1375.7354, Validation Loss: 255.7132, Validation F1: 0.8310


                                                              

Best F1 Score at epoch 3: 0.8489, Train Loss: 1398.6467, Validation Loss: 234.7801, Parameters: lr=0.005, hidden_dim256
Epoch 3, Train Loss: 1398.6467, Validation Loss: 234.7801, Validation F1: 0.8489


                                                              

Epoch 4, Train Loss: 1423.4944, Validation Loss: 232.9382, Validation F1: 0.8433


                                                             

KeyboardInterrupt: 

In [None]:
# Extract the best parameters from the grid search
best_hidden_dim = 256  # Replace with the best hidden_dim found
best_learning_rate = 0.001  # Replace with the best learning_rate found

# Create the graph for the entire training dataset


# Initialize the model with the best parameters
model = EGraphSAGE(node_in_channels=num_features, 
                   edge_in_channels=num_features,
                   hidden_channels=best_hidden_dim,
                   out_channels=num_classes).to(device)

model.apply(init_weights)

# Normalize class weights
class_weights = th.FloatTensor(CLASS_WEIGHTS).to(device)
print("Class weights:", class_weights)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = th.optim.Adam(model.parameters(), lr=best_learning_rate)

lowest_loss = float('inf')
best_model_state = None

# Load checkpoint if exists
start_epoch = 0
epochs = 100

if os.path.exists(checkpoint_path):
    checkpoint = th.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    lowest_loss = checkpoint['lowest_loss']
    print(f"Resumed training from epoch {start_epoch}")

for epoch in range(start_epoch, epochs):
    model.train()

    total_loss = 0
    for G_pyg_train in tqdm(train_graph_dataset, desc="Training", leave=False):

        # Move the graph data to the device
        G_pyg_train = G_pyg_train.to(device)
        G_pyg_train.edge_label = G_pyg_train.edge_label.to(device)
        G_pyg_train.edge_attr = G_pyg_train.edge_attr.to(device)

        optimizer.zero_grad()
        out = model(G_pyg_train)
        loss = criterion(out, G_pyg_train.edge_label)
        
        total_loss += loss.item()
        loss.backward()
        optimizer.step()

    # Save the best model based on the lowest loss
    if total_loss < lowest_loss:  # Here, lowest_loss is used to track the lowest loss
        lowest_loss = total_loss
        best_model_state = model.state_dict()
        th.save(best_model_state, best_model_path)
        print("Saved best model. Lowest Loss:", lowest_loss)

    print(f'Epoch {epoch}, Loss: {total_loss:.4f}')

    # Save checkpoint
    th.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'lowest_loss': lowest_loss
    }, checkpoint_path)

# Save the trained model
print("Model training completed and saved.")

Class weights: tensor([1.8445e+01, 2.7509e+01, 9.2468e+01, 3.0195e+00, 1.1090e+00, 2.0365e+00,
        2.2915e-01, 2.2255e-01, 3.5303e+00, 3.2679e+01, 2.8378e+02],
       device='cuda:0')


                                                              

Saved best model. Lowest Loss: 1722.790598231113
Epoch 0, Loss: 1722.7906


                                                              

Saved best model. Lowest Loss: 1450.620578548294
Epoch 1, Loss: 1450.6206


                                                              

Saved best model. Lowest Loss: 1303.984880944688
Epoch 2, Loss: 1303.9849


                                                              

Saved best model. Lowest Loss: 1228.953192420356
Epoch 3, Loss: 1228.9532


                                                              

Saved best model. Lowest Loss: 1180.188838810619
Epoch 4, Loss: 1180.1888


                                                              

Saved best model. Lowest Loss: 1139.2070155418023
Epoch 5, Loss: 1139.2070


                                                              

Saved best model. Lowest Loss: 1121.0275495306862
Epoch 6, Loss: 1121.0275


                                                              

Saved best model. Lowest Loss: 1088.939328795434
Epoch 7, Loss: 1088.9393


                                                              

Saved best model. Lowest Loss: 1075.976038286139
Epoch 8, Loss: 1075.9760


                                                              

Saved best model. Lowest Loss: 1046.3192692346408
Epoch 9, Loss: 1046.3193


                                                              

Saved best model. Lowest Loss: 1033.120556626297
Epoch 10, Loss: 1033.1206


                                                              

Saved best model. Lowest Loss: 1024.8647757539704
Epoch 11, Loss: 1024.8648


                                                              

Saved best model. Lowest Loss: 1004.7317750002775
Epoch 12, Loss: 1004.7318


                                                              

Saved best model. Lowest Loss: 995.4333515589212
Epoch 13, Loss: 995.4334


                                                              

Saved best model. Lowest Loss: 990.2069995362152
Epoch 14, Loss: 990.2070


                                                              

Saved best model. Lowest Loss: 980.1719515753673
Epoch 15, Loss: 980.1720


                                                              

Saved best model. Lowest Loss: 979.3049681654712
Epoch 16, Loss: 979.3050


                                                              

Saved best model. Lowest Loss: 958.3948005879774
Epoch 17, Loss: 958.3948


                                                              

Epoch 18, Loss: 960.5073


                                                              

Epoch 19, Loss: 976.7364


                                                              

Saved best model. Lowest Loss: 935.939386700312
Epoch 20, Loss: 935.9394


                                                              

Epoch 21, Loss: 940.0386


                                                              

Epoch 22, Loss: 939.6366


                                                              

Epoch 23, Loss: 936.4316


                                                              

Saved best model. Lowest Loss: 925.791865437683
Epoch 24, Loss: 925.7919


                                                              

Epoch 25, Loss: 928.8374


                                                              

Saved best model. Lowest Loss: 915.0794921074121
Epoch 26, Loss: 915.0795


                                                              

Saved best model. Lowest Loss: 910.6001539880505
Epoch 27, Loss: 910.6002


                                                             

KeyboardInterrupt: 

In [None]:
th.save(model.state_dict(), best_model_path)

In [79]:
from torch_geometric.utils import subgraph
from torch_geometric.data import Data
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report

def compute_accuracy(pred, labels):
    return (pred.argmax(1) == labels).float().mean().item()

def eval(dataset, adversarial=False):

    best_model = EGraphSAGE(node_in_channels=num_features, 
                       edge_in_channels=num_features,
                       hidden_channels=best_hidden_dim, 
                       out_channels=num_classes).to(device)

    print("Loading model from", best_model_path)
    best_model.load_state_dict(th.load(best_model_path, weights_only=True))

    best_model.eval()

    print("inference start")
    with th.no_grad():
        all_pred_logits = []
        all_test_labels = []
        for G_pyg in tqdm(dataset, desc="Evaluation", leave=False):
            try:
                # Move the graph data to the device
                G_pyg = G_pyg.to(device)
                G_pyg.edge_label = G_pyg.edge_label.to(device)
                G_pyg.edge_attr = G_pyg.edge_attr.to(device)
                out = best_model(G_pyg)
                
            except Exception as forward_error:
                print(f"Error during forward/backward pass at {forward_error}")

            all_pred_logits.append(out.cpu())
            all_test_labels.append(G_pyg.edge_label.cpu())

        all_pred_logits = th.cat(all_pred_logits).to(device)
        all_test_labels = th.cat(all_test_labels).to(device)
        test_accuracy = compute_accuracy(all_pred_logits, all_test_labels)
        print(f'Test Accuracy: {test_accuracy:.4f}')

        pred_labels = all_pred_logits.argmax(dim=1).cpu()
        all_test_labels = all_test_labels.cpu()
    
    global class_map
    class_map_2 = class_map
    if adversarial:
        class_map_2 = np.append(class_map, "Adversarial")

    # Generate a report
    cm = confusion_matrix(all_test_labels, pred_labels, labels=range(len(class_map_2)))
    print(cm)
    report = classification_report(all_test_labels, pred_labels, target_names=class_map_2, digits=4)
    print(report)

eval(test_graph_dataset)


Loading model from /vol/bitbucket/shc20/FYP/GNN-Adversarial-Attack/Models/E_GraphSAGE/logs/UNSW_NB15/main_window/best_model_all_raw_downsampled.pth
inference start


                                                              

Test Accuracy: 0.8990
[[  1803   1206    329     93     14      1      0      0     13     19
       0]
 [     1   1585     43    152     54      0      0      0     70     86
      25]
 [     0      7    388      0      6      0      0      0      2     14
       0]
 [    45   9860    526   1899   2650    166     21      0    642    774
     689]
 [   454  11058   1102   2665  29188    681     15     23   2187   1887
    4939]
 [   112   1497    682    207    199  30906      0      0    199    105
     446]
 [   219   1453     48    381   2651    317 277054      8    359    818
     734]
 [   703      3      0     31    281   3889      0 199903     45     90
      39]
 [    79   1575     98    307   1394     69      1      0  14390    243
     821]
 [     0      1      0     11      3      8      0      0    150   1848
      14]
 [     1      0      0      0     30      3      0      0      0      5
     188]]
                precision    recall  f1-score   support

      Analysis    

In [None]:
def attack_attacker(dataframe, ratio, num_injected_nodes):
    attack_eval = dataframe[dataframe[label_col] != class_dict[UNSW_NB15_Config.BENIGN_CLASS_NAME]]
    num_injected = int(ratio * len(attack_eval))

    # Sample attack rows
    sampled_attack_flows = attack_eval.sample(n=num_injected, random_state=42).copy().reset_index(drop=True)
    
    injected_rows = sampled_attack_flows.copy()
    print("Sampled attack flows:", len(sampled_attack_flows))
    print("Labels of sampled attack flows:", sampled_attack_flows[label_col].value_counts())

    node_ips = [f"192.168.1.{i+1}" for i in range(num_injected_nodes)]
    injected_rows[UNSW_NB15_Config.DESTINATION_IP_COL_NAME] = injected_rows[UNSW_NB15_Config.SOURCE_IP_COL_NAME] # Target the Real Attacker Nodes
    injected_rows[UNSW_NB15_Config.SOURCE_IP_COL_NAME] = [f"{node_ips[i % len(node_ips)]}:{random.randint(1024, 65535)}" for i in range(num_injected)]
    # injected_rows['pkSeqID'] = [f'Injected-{i}' for i in range(num_injected)]
    injected_rows[label_col] = len(class_map) # Assign a new class for injected samples
    print(injected_rows[0:5])

    # Append and reorder
    combined_df = pd.concat([dataframe, injected_rows], ignore_index=True)

    return combined_df

# Inject adversarial samples
attack_attacker_df = attack_attacker(test_df, 0.1, num_injected_nodes=1)
eval(attack_attacker_df, adversarial=True)

In [None]:
def normalise_attacker(dataframe, ratio, num_injected_nodes):

    normal_eval = dataframe[dataframe[label_col] == class_dict[UNSW_NB15_Config.BENIGN_CLASS_NAME]]
    attack_eval = dataframe[dataframe[label_col] != class_dict[UNSW_NB15_Config.BENIGN_CLASS_NAME]]
    print("Normal Flows:", len(normal_eval))
    print("Attack Flows:", len(attack_eval))
    num_injected = int(ratio * len(attack_eval))


    sampled_normal_flows = normal_eval.sample(n=num_injected, random_state=42).copy().reset_index(drop=True)
    sampled_attack_flows = attack_eval.sample(n=num_injected, random_state=42).copy().reset_index(drop=True)

    print("Sampled attack flows:", len(sampled_attack_flows))
    print("Labels of sampled attack flows:", sampled_attack_flows[label_col].value_counts())

    injected_rows = sampled_normal_flows.copy()
    node_ips = [f"192.168.1.{i+1}" for i in range(num_injected_nodes)]
    injected_rows[UNSW_NB15_Config.DESTINATION_IP_COL_NAME] = sampled_attack_flows[UNSW_NB15_Config.SOURCE_IP_COL_NAME] # Direct BENGIN Traffic to the Real Attacker Nodes
    injected_rows[UNSW_NB15_Config.SOURCE_IP_COL_NAME] = [f"{node_ips[i % len(node_ips)]}:{random.randint(1024, 65535)}" for i in range(num_injected)]
    injected_rows[label_col] = len(class_map)
    print(injected_rows[0:5])

    combined_df = pd.concat([dataframe, injected_rows], ignore_index=True)

    return combined_df

# Inject adversarial samples
normalise_attacker_df = normalise_attacker(test_df, 0.1, 1)
eval(normalise_attacker_df, adversarial=True)

In [None]:
def random_connection(dataframe, ratio, num_injected_nodes):

    normal_eval = dataframe[dataframe[label_col] == class_dict[UNSW_NB15_Config.BENIGN_CLASS_NAME]]

    num_injected = int(ratio * len(dataframe))
    print("injected rows:", num_injected)

    sampled_normal_flows = normal_eval.sample(n=num_injected, random_state=42).copy().reset_index(drop=True)

    injected_rows = sampled_normal_flows.copy()
    node_ips = [f"192.168.1.{i+1}" for i in range(num_injected_nodes)]
    injected_rows[UNSW_NB15_Config.DESTINATION_IP_COL_NAME] = [node_ips[i % len(node_ips)] for i in range(num_injected)]
    injected_rows[UNSW_NB15_Config.SOURCE_IP_COL_NAME] = [node_ips[(i + 1) % len(node_ips)] for i in range(num_injected)]
    injected_rows[label_col] = len(class_map)

    combined_df = pd.concat([dataframe, injected_rows], ignore_index=True)

    return combined_df

# Inject adversarial samples
random_connection_df = random_connection(test_df, 0.1, 1)
eval(model, random_connection_df, adversarial=True)