In [1]:
from torch_geometric.utils import from_networkx, add_self_loops, degree
from torch_geometric.nn import MessagePassing
import torch_geometric.transforms as T
from torch_geometric.data import DataLoader
from torch_geometric.loader import NeighborSampler
import torch.nn as nn
import torch as th
import torch.nn.functional as F
# import dgl.function as fn
import networkx as nx
import pandas as pd
import socket
import struct
import random
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.utils import class_weight
# import seaborn as sns
# import matplotlib.pyplot as plt
import numpy as np
import os
import sys
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit

project_root = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))
sys.path.append(project_root)

from Datasets.UNSW_NB15.UNSW_NB15_config import UNSW_NB15_Config

In [2]:
csv_file_name = "all_raw_downsampled"

data = pd.read_csv(os.path.join(project_root, "Datasets", f"UNSW_NB15/All/{csv_file_name}.csv"))

DATASET_NAME = "UNSW_NB15"

SOURCE_IP_COL_NAME = UNSW_NB15_Config.SOURCE_IP_COL_NAME
DESTINATION_IP_COL_NAME = UNSW_NB15_Config.DESTINATION_IP_COL_NAME
SOURCE_PORT_COL_NAME = UNSW_NB15_Config.SOURCE_PORT_COL_NAME
DESTINATION_PORT_COL_NAME = UNSW_NB15_Config.DESTINATION_PORT_COL_NAME

ATTACK_CLASS_COL_NAME = UNSW_NB15_Config.ATTACK_CLASS_COL_NAME
IS_ATTACK_COL_NAME = UNSW_NB15_Config.IS_ATTACK_COL_NAME

BENIGN_CLASS_NAME = UNSW_NB15_Config.BENIGN_CLASS_NAME

TIME_COLS = UNSW_NB15_Config.TIME_COL_NAMES

MULTICLASS = True
label_col = ATTACK_CLASS_COL_NAME if MULTICLASS else IS_ATTACK_COL_NAME

print(data[ATTACK_CLASS_COL_NAME].value_counts())
print(data[IS_ATTACK_COL_NAME].value_counts())

if MULTICLASS:
    data.drop(columns=[IS_ATTACK_COL_NAME], inplace=True)
else:
    data.drop(columns=[ATTACK_CLASS_COL_NAME], inplace=True)

checkpoint_path = os.path.join(project_root, "Models/E_GraphSAGE/logs", DATASET_NAME, f"main_window/checkpoints_{csv_file_name}.pth")
best_model_path = os.path.join(project_root, "Models/E_GraphSAGE/logs", DATASET_NAME, f"main_window/best_model_{csv_file_name}.pth")

os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
os.makedirs(os.path.dirname(best_model_path), exist_ok=True)

# Compute the class weights
CLASS_WEIGHTS = class_weight.compute_class_weight(
    class_weight='balanced',
    classes=np.unique(data[label_col]),
    y=data[label_col]
)

print("Class weights:", CLASS_WEIGHTS)

  data = pd.read_csv(os.path.join(project_root, "Datasets", f"UNSW_NB15/All/{csv_file_name}.csv"))


attack_cat
Normal            221876
Generic           215481
Exploits           44525
Fuzzers            24246
DoS                16353
Reconnaissance     13987
Analysis            2677
Backdoor            1795
Shellcode           1511
Backdoors            534
Worms                174
Name: count, dtype: int64
label
1    321283
0    221876
Name: count, dtype: int64
Class weights: [1.84453085e+01 2.75086857e+01 9.24683350e+01 3.01951268e+00
 1.10899699e+00 2.03654586e+00 2.29152876e-01 2.22548139e-01
 3.53028461e+00 3.26790807e+01 2.83782132e+02]


In [3]:
data.drop(columns=UNSW_NB15_Config.DROP_COLS,inplace=True)
print(data.columns)

Index(['srcip', 'sport', 'dstip', 'dsport', 'state', 'dur', 'sbytes', 'dbytes',
       'sttl', 'dttl', 'sloss', 'dloss', 'Sload', 'Dload', 'Spkts', 'Dpkts',
       'swin', 'dwin', 'stcpb', 'dtcpb', 'smeansz', 'dmeansz', 'trans_depth',
       'res_bdy_len', 'Sjit', 'Djit', 'Stime', 'Ltime', 'Sintpkt', 'Dintpkt',
       'tcprtt', 'synack', 'ackdat', 'is_sm_ips_ports', 'ct_state_ttl',
       'ct_flw_http_mthd', 'is_ftp_login', 'ct_ftp_cmd', 'ct_srv_src',
       'ct_srv_dst', 'ct_dst_ltm', 'ct_src_ltm', 'ct_src_dport_ltm',
       'ct_dst_sport_ltm', 'ct_dst_src_ltm', 'attack_cat'],
      dtype='object')


In [4]:
data[SOURCE_IP_COL_NAME] = data[SOURCE_IP_COL_NAME].apply(str)
data[DESTINATION_IP_COL_NAME] = data[DESTINATION_IP_COL_NAME].apply(str)

# # Combine Port and IP
data[SOURCE_PORT_COL_NAME] = data[SOURCE_PORT_COL_NAME].apply(str)
data[DESTINATION_PORT_COL_NAME] = data[DESTINATION_PORT_COL_NAME].apply(str)

data[SOURCE_IP_COL_NAME] = data[SOURCE_IP_COL_NAME] + ':' + data[SOURCE_PORT_COL_NAME]
data[DESTINATION_IP_COL_NAME] = data[DESTINATION_IP_COL_NAME] + ':' + data[DESTINATION_PORT_COL_NAME]
data.drop(columns=[SOURCE_PORT_COL_NAME,DESTINATION_PORT_COL_NAME],inplace=True)

# data[SOURCE_PORT_COL_NAME] = pd.to_numeric(data[SOURCE_PORT_COL_NAME], errors='coerce').fillna(0).astype(int)
# data[DESTINATION_PORT_COL_NAME] = pd.to_numeric(data[DESTINATION_PORT_COL_NAME], errors='coerce').fillna(0).astype(int)

In [5]:
print(data.head)

<bound method NDFrame.head of                      srcip               dstip state        dur  sbytes  \
0             10.40.85.1:0         224.0.0.5:0   INT  50.004341     384   
1          59.166.0.6:2142    149.171.126.4:53   CON   0.001134     132   
2       175.45.176.0:13284   149.171.126.16:80   FIN   2.390390    1362   
3         59.166.0.3:42587    149.171.126.8:25   FIN  34.077175   37358   
4            10.40.170.2:0       10.40.170.2:0   INT   0.000000      46   
...                    ...                 ...   ...        ...     ...   
543154  175.45.176.0:47439   149.171.126.10:53   INT   0.000001     114   
543155  175.45.176.0:47439   149.171.126.10:53   INT   0.000001     114   
543156    59.166.0.5:53521    149.171.126.7:21   CON   1.086072    1940   
543157  175.45.176.0:17293  149.171.126.17:110   CON   0.942984     574   
543158  175.45.176.0:17293  149.171.126.17:110   CON   0.942984     574   

        dbytes  sttl  dttl  sloss  dloss  ...  is_ftp_login  ct_ftp_c

In [6]:
data = pd.get_dummies(data, columns = UNSW_NB15_Config.CATEGORICAL_COLS) # One Hot Encoding for categorical data
converted_categorical_cols = [col for col in data.columns if col.startswith(tuple(UNSW_NB15_Config.CATEGORICAL_COLS))]

In [7]:
data = data.reset_index()
data.replace([np.inf, -np.inf], np.nan,inplace = True)
data.fillna(0,inplace = True)
data.drop(columns=['index'],inplace=True)
print(data.head)

<bound method NDFrame.head of                      srcip               dstip        dur  sbytes  dbytes  \
0             10.40.85.1:0         224.0.0.5:0  50.004341     384       0   
1          59.166.0.6:2142    149.171.126.4:53   0.001134     132     164   
2       175.45.176.0:13284   149.171.126.16:80   2.390390    1362     268   
3         59.166.0.3:42587    149.171.126.8:25  34.077175   37358    3380   
4            10.40.170.2:0       10.40.170.2:0   0.000000      46       0   
...                    ...                 ...        ...     ...     ...   
543154  175.45.176.0:47439   149.171.126.10:53   0.000001     114       0   
543155  175.45.176.0:47439   149.171.126.10:53   0.000001     114       0   
543156    59.166.0.5:53521    149.171.126.7:21   1.086072    1940    2404   
543157  175.45.176.0:17293  149.171.126.17:110   0.942984     574     676   
543158  175.45.176.0:17293  149.171.126.17:110   0.942984     574     676   

        sttl  dttl  sloss  dloss         Sloa

In [8]:
scaler = StandardScaler()
cols_to_norm = UNSW_NB15_Config.COLS_TO_NORM
print(data[cols_to_norm].describe()) # Check if there's any too large value

                 dur        sbytes        dbytes           sttl  \
count  543159.000000  5.431590e+05  5.431590e+05  543159.000000   
mean        0.706760  5.136572e+03  1.936909e+04     157.197364   
std        12.637229  1.202311e+05  1.390925e+05     108.452474   
min         0.000000  0.000000e+00  0.000000e+00       0.000000   
25%         0.000007  1.140000e+02  0.000000e+00      31.000000   
50%         0.000011  2.000000e+02  0.000000e+00     254.000000   
75%         0.072088  1.580000e+03  1.940000e+03     254.000000   
max      8760.776367  1.435577e+07  1.465753e+07     255.000000   

                dttl          sloss          dloss         Sload  \
count  543159.000000  543159.000000  543159.000000  5.431590e+05   
mean       38.850764       3.800661       8.729770  6.877595e+07   
std        77.034389      45.616565      50.136204  1.420534e+08   
min         0.000000       0.000000       0.000000  0.000000e+00   
25%         0.000000       0.000000       0.000000  3.70

In [9]:
def check_numeric_issues(df, cols_to_norm):
    for col in cols_to_norm:
        try:
            # Try to coerce to numeric
            df[col] = pd.to_numeric(df[col], errors='coerce')
            
            # Try to clip the column
            df[col] = df[col].clip(lower=-1e9, upper=1e9)
            
        except Exception as e:
            print(f"❌ Column '{col}' failed with error: {e}")
            print(f"  - Sample values: {df[col].dropna().unique()[:5]}")
            print(f"  - Data type: {df[col].dtype}")
            continue

    print("\n✅ All other columns processed successfully.")

check_numeric_issues(data, UNSW_NB15_Config.COLS_TO_NORM)


✅ All other columns processed successfully.


In [10]:
data[cols_to_norm] = scaler.fit_transform(data[cols_to_norm])

In [11]:
from sklearn.preprocessing import LabelEncoder

num_classes = 2
class_map = [0, 1]
if MULTICLASS:
    le = LabelEncoder()
    attack_labels = le.fit_transform(data[ATTACK_CLASS_COL_NAME])
    class_map = le.classes_
    print(class_map)
    print("Attack label mapping:", dict(zip(class_map, range(len(class_map)))))
    data[ATTACK_CLASS_COL_NAME] = attack_labels
    num_classes = len(class_map)
    class_dict = {le.inverse_transform([i])[0]: i for i in range(len(le.classes_))}


['Analysis' 'Backdoor' 'Backdoors' 'DoS' 'Exploits' 'Fuzzers' 'Generic'
 'Normal' 'Reconnaissance' 'Shellcode' 'Worms']
Attack label mapping: {'Analysis': 0, 'Backdoor': 1, 'Backdoors': 2, 'DoS': 3, 'Exploits': 4, 'Fuzzers': 5, 'Generic': 6, 'Normal': 7, 'Reconnaissance': 8, 'Shellcode': 9, 'Worms': 10}


In [12]:
# # Maintain the order of the rows in the original dataframe

feature_cols = UNSW_NB15_Config.COLS_TO_NORM + converted_categorical_cols

print('Feature Columns:', feature_cols)
num_features = len(feature_cols)
print('Number of Features:', num_features)

data['h'] = data[ feature_cols ].values.tolist()

Feature Columns: ['dur', 'sbytes', 'dbytes', 'sttl', 'dttl', 'sloss', 'dloss', 'Sload', 'Dload', 'Spkts', 'Dpkts', 'swin', 'dwin', 'stcpb', 'dtcpb', 'smeansz', 'dmeansz', 'trans_depth', 'res_bdy_len', 'Sjit', 'Djit', 'Sintpkt', 'Dintpkt', 'tcprtt', 'synack', 'ackdat', 'is_sm_ips_ports', 'ct_state_ttl', 'ct_flw_http_mthd', 'is_ftp_login', 'ct_ftp_cmd', 'ct_srv_src', 'ct_srv_dst', 'ct_dst_ltm', 'ct_src_ltm', 'ct_src_dport_ltm', 'ct_dst_sport_ltm', 'ct_dst_src_ltm', 'state_ACC', 'state_CLO', 'state_CON', 'state_ECO', 'state_FIN', 'state_INT', 'state_MAS', 'state_PAR', 'state_REQ', 'state_RST', 'state_TST', 'state_TXD', 'state_URH']
Number of Features: 51


In [13]:
def create_graph(df, source_ip_col, destination_ip_col, edge_attr, create_using=nx.MultiDiGraph(), **kwargs):
    G_nx = nx.from_pandas_edgelist(df, source_ip_col, destination_ip_col, edge_attr, create_using=create_using, **kwargs)
    G_pyg = from_networkx(G_nx)

    num_nodes = G_pyg.num_nodes
    num_edges = G_pyg.num_edges

    G_pyg.x = th.ones(num_nodes, len(df['h'].iloc[0])) 

    edge_attr_list = []
    edge_label_list = []

    for u, v, key, data in G_nx.edges(keys=True, data=True):
        edge_attr_list.append(data['h']) 
        edge_label_list.append(data[label_col]) 

    G_pyg.edge_attr = th.tensor(edge_attr_list, dtype=th.float32)
    G_pyg.edge_label = th.tensor(edge_label_list, dtype=th.long)

    return G_pyg

In [34]:
from sklearn.preprocessing import MultiLabelBinarizer

from tqdm import tqdm

class StratifiedGraphDataset:

    def __init__(self, X, y):
        self.X = X
        self.y = y
        self.total_count = len(self.y)

    def k_fold_split(self, k: int = 5, test_ratio: float = 0.15, random_state: int = 42):
        cv = MultilabelStratifiedShuffleSplit(test_size=test_ratio, random_state=random_state, n_splits=k)

        mlb = MultiLabelBinarizer()

        y_binary = mlb.fit_transform(self.y)

        return cv.split(np.zeros(len(self.X)), y_binary)

    def graph_train_test_split(self, test_ratio: float = 0.15, random_state: int = 42):
        train_idx, test_idx = next(self.k_fold_split(k = 1, test_ratio = test_ratio, random_state = random_state))
        
        X_train = [self.X[i] for i in train_idx]
        X_test = [self.X[i] for i in test_idx]

        y_train = [self.y[i] for i in train_idx]
        y_test = [self.y[i] for i in test_idx]

        return StratifiedGraphDataset(X_train, y_train), StratifiedGraphDataset(X_test, y_test)
    
    def __len__(self):
        return self.total_count

    def __iter__(self):
        for g in self.X:
            yield g


def generate_graph_datasets(
    df: pd.DataFrame, 
    window_size: int = 2000, 
    overlap_ratio: float = 0.5, 
    time_cols=TIME_COLS, 
    label_col=label_col,
    build_graph_func=create_graph,
    build_graph_func_kwargs: dict = None):

    print("Feature Columns", df.columns)
    print("Time Columns", time_cols)
    assert all(col in df.columns for col in time_cols), "All timestamp columns are required"
    assert label_col in df.columns, "Edge label column 'label' is required"
    
    df = df.sort_values(time_cols).reset_index(drop=True)
    df.drop(columns=time_cols, inplace=True)
    window_size = window_size
    stride = int(window_size * (1 - overlap_ratio))
    
    X = []
    y = []

    for start in tqdm(range(0, len(df), stride), desc="Generating graphs"):
        window_df = df.iloc[start: min(start + window_size, len(df))]

        window_df = df.iloc[start: start + window_size]
        contains_label = window_df[label_col].unique()
        G_pyg = build_graph_func(window_df, **build_graph_func_kwargs)

        X.append(G_pyg)
        y.append(contains_label.tolist())

    print(X)
    print(y)

    return StratifiedGraphDataset(X, y)

In [35]:
graph_dataset = generate_graph_datasets(data, 
                                        overlap_ratio=0,
                                        build_graph_func_kwargs={
                                            'source_ip_col': SOURCE_IP_COL_NAME, 
                                            'destination_ip_col': DESTINATION_IP_COL_NAME, 
                                            'edge_attr': ['h', label_col], 
                                            'create_using': nx.MultiDiGraph()})
train_graph_dataset, test_graph_dataset = graph_dataset.graph_train_test_split(test_ratio=0.15, random_state=42)
print("Number of training graphs:", len(train_graph_dataset))
print("Number of testing graphs:", len(test_graph_dataset))

Feature Columns Index(['srcip', 'dstip', 'dur', 'sbytes', 'dbytes', 'sttl', 'dttl', 'sloss',
       'dloss', 'Sload', 'Dload', 'Spkts', 'Dpkts', 'swin', 'dwin', 'stcpb',
       'dtcpb', 'smeansz', 'dmeansz', 'trans_depth', 'res_bdy_len', 'Sjit',
       'Djit', 'Stime', 'Ltime', 'Sintpkt', 'Dintpkt', 'tcprtt', 'synack',
       'ackdat', 'is_sm_ips_ports', 'ct_state_ttl', 'ct_flw_http_mthd',
       'is_ftp_login', 'ct_ftp_cmd', 'ct_srv_src', 'ct_srv_dst', 'ct_dst_ltm',
       'ct_src_ltm', 'ct_src_dport_ltm', 'ct_dst_sport_ltm', 'ct_dst_src_ltm',
       'attack_cat', 'state_ACC', 'state_CLO', 'state_CON', 'state_ECO',
       'state_FIN', 'state_INT', 'state_MAS', 'state_PAR', 'state_REQ',
       'state_RST', 'state_TST', 'state_TXD', 'state_URH', 'h'],
      dtype='object')
Time Columns ['Stime', 'Ltime']


Generating graphs: 100%|██████████| 272/272 [00:20<00:00, 13.22it/s]

[Data(edge_index=[2, 2000], h=[2000, 51], attack_cat=[2000], num_nodes=2530, x=[2530, 51], edge_attr=[2000, 51], edge_label=[2000]), Data(edge_index=[2, 2000], h=[2000, 51], attack_cat=[2000], num_nodes=2382, x=[2382, 51], edge_attr=[2000, 51], edge_label=[2000]), Data(edge_index=[2, 2000], h=[2000, 51], attack_cat=[2000], num_nodes=2418, x=[2418, 51], edge_attr=[2000, 51], edge_label=[2000]), Data(edge_index=[2, 2000], h=[2000, 51], attack_cat=[2000], num_nodes=2534, x=[2534, 51], edge_attr=[2000, 51], edge_label=[2000]), Data(edge_index=[2, 2000], h=[2000, 51], attack_cat=[2000], num_nodes=2662, x=[2662, 51], edge_attr=[2000, 51], edge_label=[2000]), Data(edge_index=[2, 2000], h=[2000, 51], attack_cat=[2000], num_nodes=2608, x=[2608, 51], edge_attr=[2000, 51], edge_label=[2000]), Data(edge_index=[2, 2000], h=[2000, 51], attack_cat=[2000], num_nodes=1677, x=[1677, 51], edge_attr=[2000, 51], edge_label=[2000]), Data(edge_index=[2, 2000], h=[2000, 51], attack_cat=[2000], num_nodes=366, 




In [36]:
class SAGELayerPyG(MessagePassing):
    def __init__(self, in_channels, edge_dim, out_channels, activation=F.relu):
        super().__init__(aggr='mean')  # mean aggregation
        self.W_msg = nn.Linear(in_channels + edge_dim, out_channels)
        self.W_apply = nn.Linear(in_channels + out_channels, out_channels)
        self.activation = activation

    def forward(self, x, edge_index, edge_attr):
        # x: [num_nodes, in_channels]
        # edge_attr: [num_edges, edge_dim]
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_j, edge_attr):
        # x_j: features of source nodes (neighbours)
        msg_input = th.cat([x_j, edge_attr], dim=1)
        return self.W_msg(msg_input)

    def update(self, aggr_out, x):
        # aggr_out: [num_nodes, out_channels]
        combined = th.cat([x, aggr_out], dim=1)
        out = self.W_apply(combined)
        return self.activation(out)
    
class MLPPredictor(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(MLPPredictor, self).__init__()
        self.lin = nn.Linear(in_channels * 2, out_channels)

    def forward(self, data, z):
        row, col = data.edge_index
        # Concatenate the features of source and target nodes for each edge
        edge_feat = th.cat([z[row], z[col]], dim=1)
        return self.lin(edge_feat)

class EGraphSAGE(nn.Module):
    def __init__(self, node_in_channels, edge_in_channels, hidden_channels, out_channels):
        super(EGraphSAGE, self).__init__()
        self.conv1 = SAGELayerPyG(node_in_channels, edge_in_channels, hidden_channels)
        self.conv2 = SAGELayerPyG(hidden_channels, edge_in_channels, hidden_channels)
        self.mlp_predictor = MLPPredictor(hidden_channels, out_channels)

    def forward(self, data):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        x = F.relu(self.conv1(x, edge_index, edge_attr))
        x = F.dropout(x, p=0.2)
        x = self.conv2(x, edge_index, edge_attr)
        return self.mlp_predictor(data, x)

In [37]:
device = th.device("cuda:0" if th.cuda.is_available() else "cpu")
print(device)

cuda:0


In [38]:
th.cuda.empty_cache()

In [39]:


def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        nn.init.constant_(m.bias, 0)



In [40]:
from sklearn.metrics import f1_score
from sklearn.model_selection import StratifiedKFold

def compute_accuracy(pred, labels):
    return (pred.argmax(1) == labels).float().mean().item()

def grid_search(graph_dataset, epochs, learning_rates, hidden_dims, folds=3):
    global CLASS_WEIGHTS
    global num_features
    
    best_params = {}
    best_f1 = 0

    # Precompute the train and validation graphs for all folds
    folds_list = []
    for i in range(folds):
        train_graph_dataset, val_graph_dataset = graph_dataset.graph_train_test_split(test_ratio=0.15, random_state=i)
        folds_list.append((train_graph_dataset, val_graph_dataset))

    params_results = {}
    for lr in learning_rates:
        for hidden_dim in hidden_dims:
            print(f"Testing with learning rate: {lr}, hidden_dim: {hidden_dim}")
            fold_f1_scores = []

            for fold, (train_graph_dataset, val_graph_dataset) in enumerate(folds_list):
                print(f"Fold {fold + 1}")

                model = EGraphSAGE(node_in_channels=num_features,
                                   edge_in_channels=num_features,
                                   hidden_channels=hidden_dim,
                                   out_channels=num_classes).to(device)

                model.apply(init_weights)

                # Normalize to stabilize training
                class_weights = th.FloatTensor(CLASS_WEIGHTS).to(device)
                print("Class weights:", class_weights)

                criterion = nn.CrossEntropyLoss(weight=class_weights)
                optimizer = th.optim.Adam(model.parameters(), lr=lr)

                best_epoch_f1 = 0  # Track the best F1 score for this fold

                for epoch in range(epochs):
                    try:
                        total_train_loss = 0
                        total_val_loss = 0

                        model.train()
                        for G_pyg_train in tqdm(train_graph_dataset, desc="Training", leave=False):
                            G_pyg_train = G_pyg_train.to(device)

                            G_pyg_train.edge_label = G_pyg_train.edge_label.to(device)
                            G_pyg_train.edge_attr = G_pyg_train.edge_attr.to(device)
                            
                            out = model(G_pyg_train)
                            loss = criterion(out, G_pyg_train.edge_label)
                            total_train_loss += loss.item()

                            optimizer.zero_grad()
                            loss.backward()

                            optimizer.step()
                        
                        model.eval()
                        total_f1 = 0
                        with th.no_grad():
                            for G_pyg_val in tqdm(val_graph_dataset, desc="Validation", leave=False):

                                G_pyg_val.to(device)
                                G_pyg_val.edge_label = G_pyg_val.edge_label.to(device)
                                G_pyg_val.edge_attr = G_pyg_val.edge_attr.to(device)

                                out = model(G_pyg_val)
                                loss = criterion(out, G_pyg_val.edge_label)
                                total_val_loss += loss.item()

                                total_f1 += f1_score(G_pyg_val.edge_label.cpu(), out.argmax(dim=1).cpu(), average='weighted')

                        avg_f1 = total_f1 / len(val_graph_dataset)
                        if avg_f1 > best_epoch_f1:
                            best_epoch_f1 = avg_f1  # Update the best F1 score for this fold
                            print(f"Best F1 Score at epoch {epoch}: {best_epoch_f1:.4f}, Train Loss: {total_train_loss:.4f}, Validation Loss: {total_val_loss:.4f}, Parameters: lr={lr}, hidden_dim{hidden_dim}")

                        print(f'Epoch {epoch}, Train Loss: {total_train_loss:.4f}, Validation Loss: {total_val_loss:.4f}, Validation F1: {avg_f1:.4f}')

                    except Exception as e:
                        print(f"An error occurred at epoch {epoch}: {str(e)}")
                        break

                fold_f1_scores.append(best_epoch_f1)  # Append the best F1 score for this fold
            
            avg_f1 = sum(fold_f1_scores) / len(fold_f1_scores)
            params_results[(lr, hidden_dim)] = {'folds': fold_f1_scores, 'avg_f1': avg_f1}
            print(f"Average F1 Score for learning rate {lr}, hidden_dim {hidden_dim}: {avg_f1:.4f}")

            if avg_f1 > best_f1:
                best_f1 = avg_f1
                best_params = {'learning_rate': lr, 'hidden_dim': hidden_dim}

    print(f"Best Parameters: {best_params}, Best F1 Score: {best_f1:.4f}")
    print("All results:", params_results)


learning_rates = [0.005]
hidden_dims = [256, 512, 1024]

grid_search(train_graph_dataset, epochs=10, learning_rates=learning_rates, hidden_dims=hidden_dims)


Testing with learning rate: 0.005, hidden_dim: 256
Fold 1
Class weights: tensor([1.8445e+01, 2.7509e+01, 9.2468e+01, 3.0195e+00, 1.1090e+00, 2.0365e+00,
        2.2915e-01, 2.2255e-01, 3.5303e+00, 3.2679e+01, 2.8378e+02],
       device='cuda:0')


Training:   0%|          | 0/196 [00:00<?, ?it/s]

                                                            

Best F1 Score at epoch 0: 0.9131, Train Loss: 316.0273, Validation Loss: 41.3991, Parameters: lr=0.005, hidden_dim256
Epoch 0, Train Loss: 316.0273, Validation Loss: 41.3991, Validation F1: 0.9131


                                                            

Best F1 Score at epoch 1: 0.9174, Train Loss: 221.3896, Validation Loss: 34.7163, Parameters: lr=0.005, hidden_dim256
Epoch 1, Train Loss: 221.3896, Validation Loss: 34.7163, Validation F1: 0.9174


                                                            

Best F1 Score at epoch 2: 0.9193, Train Loss: 280.0376, Validation Loss: 43.1980, Parameters: lr=0.005, hidden_dim256
Epoch 2, Train Loss: 280.0376, Validation Loss: 43.1980, Validation F1: 0.9193


                                                            

Epoch 3, Train Loss: 219.7124, Validation Loss: 34.8417, Validation F1: 0.9127


                                                            

Epoch 4, Train Loss: 248.2892, Validation Loss: 37.7054, Validation F1: 0.9098


                                                            

Epoch 5, Train Loss: 260.3529, Validation Loss: 42.0671, Validation F1: 0.8900


                                                            

Epoch 6, Train Loss: 233.2375, Validation Loss: 41.9801, Validation F1: 0.8947


                                                            

Epoch 7, Train Loss: 250.7881, Validation Loss: 40.4123, Validation F1: 0.8989


                                                            

Epoch 8, Train Loss: 225.6525, Validation Loss: 37.0919, Validation F1: 0.9000


                                                            

Epoch 9, Train Loss: 215.1970, Validation Loss: 37.8326, Validation F1: 0.9011
Fold 2
Class weights: tensor([1.8445e+01, 2.7509e+01, 9.2468e+01, 3.0195e+00, 1.1090e+00, 2.0365e+00,
        2.2915e-01, 2.2255e-01, 3.5303e+00, 3.2679e+01, 2.8378e+02],
       device='cuda:0')


                                                            

Best F1 Score at epoch 0: 0.8453, Train Loss: 351.5944, Validation Loss: 39.8916, Parameters: lr=0.005, hidden_dim256
Epoch 0, Train Loss: 351.5944, Validation Loss: 39.8916, Validation F1: 0.8453


                                                            

Best F1 Score at epoch 1: 0.8565, Train Loss: 229.2264, Validation Loss: 51.3153, Parameters: lr=0.005, hidden_dim256
Epoch 1, Train Loss: 229.2264, Validation Loss: 51.3153, Validation F1: 0.8565


                                                            

Epoch 2, Train Loss: 236.3446, Validation Loss: 37.2584, Validation F1: 0.8557


                                                            

Best F1 Score at epoch 3: 0.8721, Train Loss: 214.6132, Validation Loss: 35.3189, Parameters: lr=0.005, hidden_dim256
Epoch 3, Train Loss: 214.6132, Validation Loss: 35.3189, Validation F1: 0.8721


                                                            

Epoch 4, Train Loss: 197.1173, Validation Loss: 32.4649, Validation F1: 0.8638


                                                            

Epoch 5, Train Loss: 248.1978, Validation Loss: 39.2082, Validation F1: 0.8404


                                                            

Epoch 6, Train Loss: 235.1632, Validation Loss: 38.9911, Validation F1: 0.8508


                                                            

Epoch 7, Train Loss: 209.7893, Validation Loss: 34.7382, Validation F1: 0.8571


                                                            

Epoch 8, Train Loss: 192.4975, Validation Loss: 32.6409, Validation F1: 0.8612


                                                           

KeyboardInterrupt: 

In [44]:
# Extract the best parameters from the grid search
best_hidden_dim = 256  # Replace with the best hidden_dim found
best_learning_rate = 0.001  # Replace with the best learning_rate found

print("Number of train graphs: ", len(train_graph_dataset))

# Initialize the model with the best parameters
model = EGraphSAGE(node_in_channels=num_features, 
                   edge_in_channels=num_features,
                   hidden_channels=best_hidden_dim,
                   out_channels=num_classes).to(device)

model.apply(init_weights)

# Normalize class weights
class_weights = th.FloatTensor(CLASS_WEIGHTS).to(device)
print("Class weights:", class_weights)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = th.optim.Adam(model.parameters(), lr=best_learning_rate)

lowest_loss = float('inf')
best_model_state = None

# Load checkpoint if exists
start_epoch = 0
epochs = 100

if os.path.exists(checkpoint_path):
    checkpoint = th.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    lowest_loss = checkpoint['lowest_loss']
    print(f"Resumed training from epoch {start_epoch}")

for epoch in range(start_epoch, epochs):
    model.train()

    total_loss = 0
    for G_pyg_train in tqdm(train_graph_dataset, desc="Training", leave=False):

        # Move the graph data to the device
        G_pyg_train = G_pyg_train.to(device)
        G_pyg_train.edge_label = G_pyg_train.edge_label.to(device)
        G_pyg_train.edge_attr = G_pyg_train.edge_attr.to(device)

        optimizer.zero_grad()
        out = model(G_pyg_train)
        loss = criterion(out, G_pyg_train.edge_label)
        
        total_loss += loss.item()
        loss.backward()
        optimizer.step()

    # Save the best model based on the lowest loss
    if total_loss < lowest_loss:  # Here, lowest_loss is used to track the lowest loss
        lowest_loss = total_loss
        best_model_state = model.state_dict()
        th.save(best_model_state, best_model_path)
        print("Saved best model. Lowest Loss:", lowest_loss)

    print(f'Epoch {epoch}, Loss: {total_loss:.4f}')

    # Save checkpoint
    th.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'lowest_loss': lowest_loss
    }, checkpoint_path)

# Save the trained model
print("Model training completed and saved.")

Number of train graphs:  231
Class weights: tensor([1.8445e+01, 2.7509e+01, 9.2468e+01, 3.0195e+00, 1.1090e+00, 2.0365e+00,
        2.2915e-01, 2.2255e-01, 3.5303e+00, 3.2679e+01, 2.8378e+02],
       device='cuda:0')
Resumed training from epoch 2


Training:   0%|          | 0/231 [00:00<?, ?it/s]

                                                            

Saved best model. Lowest Loss: 205.85997266638515
Epoch 2, Loss: 205.8600


                                                            

Saved best model. Lowest Loss: 198.5308462749963
Epoch 3, Loss: 198.5308


                                                            

Saved best model. Lowest Loss: 190.5529824016021
Epoch 4, Loss: 190.5530


                                                            

Saved best model. Lowest Loss: 187.81704581041186
Epoch 5, Loss: 187.8170


                                                            

Saved best model. Lowest Loss: 182.94741404174601
Epoch 6, Loss: 182.9474


                                                            

Saved best model. Lowest Loss: 181.86611750505693
Epoch 7, Loss: 181.8661


                                                            

Saved best model. Lowest Loss: 177.8723961001374
Epoch 8, Loss: 177.8724


                                                            

Epoch 9, Loss: 178.2035


                                                            

Saved best model. Lowest Loss: 173.689441958702
Epoch 10, Loss: 173.6894


                                                            

Saved best model. Lowest Loss: 171.77075338117638
Epoch 11, Loss: 171.7708


                                                            

Saved best model. Lowest Loss: 170.24774501076342
Epoch 12, Loss: 170.2477


                                                            

Saved best model. Lowest Loss: 168.90820198744325
Epoch 13, Loss: 168.9082


                                                            

Saved best model. Lowest Loss: 164.93952312786388
Epoch 14, Loss: 164.9395


                                                            

Saved best model. Lowest Loss: 162.7693509720039
Epoch 15, Loss: 162.7694


                                                            

Saved best model. Lowest Loss: 161.92309366298161
Epoch 16, Loss: 161.9231


                                                            

Epoch 17, Loss: 164.9943


                                                            

Saved best model. Lowest Loss: 161.3563362832986
Epoch 18, Loss: 161.3563


                                                            

Saved best model. Lowest Loss: 157.84928398214146
Epoch 19, Loss: 157.8493


                                                            

Saved best model. Lowest Loss: 154.13038092648318
Epoch 20, Loss: 154.1304


                                                            

Epoch 21, Loss: 155.9118


                                                            

Epoch 22, Loss: 154.6428


                                                            

Saved best model. Lowest Loss: 152.1936210675257
Epoch 23, Loss: 152.1936


                                                            

Saved best model. Lowest Loss: 149.3834681076325
Epoch 24, Loss: 149.3835


                                                            

Epoch 25, Loss: 150.2028


                                                            

Epoch 26, Loss: 151.4116


                                                            

Epoch 27, Loss: 149.9590


                                                            

Saved best model. Lowest Loss: 145.69245694624283
Epoch 28, Loss: 145.6925


                                                            

Saved best model. Lowest Loss: 144.16323033155112
Epoch 29, Loss: 144.1632


                                                            

Epoch 30, Loss: 145.7490


                                                            

Saved best model. Lowest Loss: 142.42284342605748
Epoch 31, Loss: 142.4228


                                                            

Saved best model. Lowest Loss: 140.4948887602967
Epoch 32, Loss: 140.4949


                                                            

Saved best model. Lowest Loss: 138.9841586634858
Epoch 33, Loss: 138.9842


                                                            

Epoch 34, Loss: 139.2231


                                                            

Saved best model. Lowest Loss: 136.1045540743787
Epoch 35, Loss: 136.1046


                                                            

Epoch 36, Loss: 142.3975


                                                            

Saved best model. Lowest Loss: 134.96302094593423
Epoch 37, Loss: 134.9630


                                                            

Epoch 38, Loss: 136.0231


                                                            

Saved best model. Lowest Loss: 134.39009580527105
Epoch 39, Loss: 134.3901


                                                            

Epoch 40, Loss: 134.5490


                                                            

Saved best model. Lowest Loss: 133.24340115670532
Epoch 41, Loss: 133.2434


                                                            

Saved best model. Lowest Loss: 130.24287296859242
Epoch 42, Loss: 130.2429


                                                            

Epoch 43, Loss: 130.3389


                                                            

Epoch 44, Loss: 149.7709


                                                            

Epoch 45, Loss: 134.5361


                                                            

Saved best model. Lowest Loss: 129.77003862615763
Epoch 46, Loss: 129.7700


                                                            

Saved best model. Lowest Loss: 128.02166829517768
Epoch 47, Loss: 128.0217


                                                            

Saved best model. Lowest Loss: 126.94937335503266
Epoch 48, Loss: 126.9494


                                                            

Saved best model. Lowest Loss: 126.05136575173641
Epoch 49, Loss: 126.0514


                                                            

Saved best model. Lowest Loss: 124.60657049660281
Epoch 50, Loss: 124.6066


                                                            

Epoch 51, Loss: 132.4133


                                                            

Epoch 52, Loss: 128.4624


                                                            

Epoch 53, Loss: 126.4392


                                                            

Saved best model. Lowest Loss: 123.99367274388982
Epoch 54, Loss: 123.9937


                                                            

Saved best model. Lowest Loss: 123.94449099977308
Epoch 55, Loss: 123.9445


                                                            

Saved best model. Lowest Loss: 123.31996661040067
Epoch 56, Loss: 123.3200


                                                            

Saved best model. Lowest Loss: 123.23184221528486
Epoch 57, Loss: 123.2318


                                                            

Epoch 58, Loss: 123.8234


                                                            

Saved best model. Lowest Loss: 122.74235260021553
Epoch 59, Loss: 122.7424


                                                            

Epoch 60, Loss: 123.8818


                                                            

Epoch 61, Loss: 124.0355


                                                            

Epoch 62, Loss: 123.0241


                                                            

Epoch 63, Loss: 123.1453


                                                            

Saved best model. Lowest Loss: 121.72864008365742
Epoch 64, Loss: 121.7286


                                                            

Epoch 65, Loss: 129.6205


                                                            

Epoch 66, Loss: 122.5817


                                                            

Epoch 67, Loss: 122.9991


                                                            

Saved best model. Lowest Loss: 120.86050204053112
Epoch 68, Loss: 120.8605


                                                            

Saved best model. Lowest Loss: 119.92221228322987
Epoch 69, Loss: 119.9222


                                                            

Epoch 70, Loss: 120.6253


                                                            

Epoch 71, Loss: 120.3192


                                                            

Epoch 72, Loss: 120.2114


                                                            

Saved best model. Lowest Loss: 119.61392538402963
Epoch 73, Loss: 119.6139


                                                            

Epoch 74, Loss: 120.5759


                                                            

Saved best model. Lowest Loss: 118.46795713229801
Epoch 75, Loss: 118.4680


                                                            

Epoch 76, Loss: 119.0346


                                                            

Saved best model. Lowest Loss: 116.72479962429064
Epoch 77, Loss: 116.7248


                                                            

Saved best model. Lowest Loss: 116.15532604922359
Epoch 78, Loss: 116.1553


                                                            

Epoch 79, Loss: 119.4063


                                                            

Epoch 80, Loss: 121.9028


                                                            

Epoch 81, Loss: 116.9083


                                                            

Epoch 82, Loss: 118.0421


                                                            

Epoch 83, Loss: 119.0530


                                                            

Saved best model. Lowest Loss: 115.19917683588221
Epoch 84, Loss: 115.1992


                                                            

Saved best model. Lowest Loss: 114.16575751947828
Epoch 85, Loss: 114.1658


                                                            

Saved best model. Lowest Loss: 113.37052055385264
Epoch 86, Loss: 113.3705


                                                            

Epoch 87, Loss: 114.2373


                                                            

Epoch 88, Loss: 115.5538


                                                            

Epoch 89, Loss: 114.5814


                                                            

Epoch 90, Loss: 115.1920


                                                            

Saved best model. Lowest Loss: 112.96326940480972
Epoch 91, Loss: 112.9633


                                                            

Epoch 92, Loss: 116.9985


                                                            

Epoch 93, Loss: 114.3252


                                                            

Saved best model. Lowest Loss: 112.47395187413515
Epoch 94, Loss: 112.4740


                                                            

Saved best model. Lowest Loss: 111.566766040054
Epoch 95, Loss: 111.5668


                                                            

Epoch 96, Loss: 115.0825


                                                            

Epoch 97, Loss: 113.2589


                                                            

Epoch 98, Loss: 113.5656


                                                            

Epoch 99, Loss: 112.0957
Model training completed and saved.


In [45]:
th.save(model.state_dict(), best_model_path)

In [46]:
from torch_geometric.utils import subgraph
from torch_geometric.data import Data
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report

def compute_accuracy(pred, labels):
    return (pred.argmax(1) == labels).float().mean().item()

def eval(dataset, adversarial=False):

    best_model = EGraphSAGE(node_in_channels=num_features, 
                       edge_in_channels=num_features,
                       hidden_channels=best_hidden_dim, 
                       out_channels=num_classes).to(device)

    print("Loading model from", best_model_path)
    best_model.load_state_dict(th.load(best_model_path, weights_only=True))

    best_model.eval()

    print("inference start")
    with th.no_grad():
        all_pred_logits = []
        all_test_labels = []
        for G_pyg in tqdm(dataset, desc="Evaluation", leave=False):
            try:
                # Move the graph data to the device
                G_pyg = G_pyg.to(device)
                G_pyg.edge_label = G_pyg.edge_label.to(device)
                G_pyg.edge_attr = G_pyg.edge_attr.to(device)
                out = best_model(G_pyg)
                
            except Exception as forward_error:
                print(f"Error during forward/backward pass at {forward_error}")

            all_pred_logits.append(out.cpu())
            all_test_labels.append(G_pyg.edge_label.cpu())

        all_pred_logits = th.cat(all_pred_logits).to(device)
        all_test_labels = th.cat(all_test_labels).to(device)
        test_accuracy = compute_accuracy(all_pred_logits, all_test_labels)
        print(f'Test Accuracy: {test_accuracy:.4f}')

        pred_labels = all_pred_logits.argmax(dim=1).cpu()
        all_test_labels = all_test_labels.cpu()
    
    global class_map
    class_map_2 = class_map
    if adversarial:
        class_map_2 = np.append(class_map, "Adversarial")

    # Generate a report
    cm = confusion_matrix(all_test_labels, pred_labels, labels=range(len(class_map_2)))
    print(cm)
    report = classification_report(all_test_labels, pred_labels, target_names=class_map_2, digits=4)
    print(report)

print("Number of train graphs: ", len(train_graph_dataset))
eval(train_graph_dataset)
print("Number of test graphs: ", len(test_graph_dataset))
eval(test_graph_dataset)


Number of train graphs:  231
Loading model from /vol/bitbucket/shc20/FYP/GNN-Adversarial-Attack/Models/E_GraphSAGE/logs/UNSW_NB15/main_window/best_model_all_raw_downsampled.pth
inference start


Evaluation:  23%|██▎       | 53/231 [00:00<00:00, 528.40it/s]

                                                              

Test Accuracy: 0.8888
[[   506   1292    242     76      5      1      0      0      3      8
       2]
 [     1   1387      0     15     12      7      0      0     10     33
      22]
 [     0     15    263      5      3      5      0      0      1      2
       5]
 [    26   9698    280    907   1279    150     26      0    178    375
     869]
 [   233  12275    562   1223  15131    609     16      0   1016    902
    5515]
 [    27   1482    494    100    114  17270      5      0     33     41
     765]
 [    95   1539     10    242   1335    173 179408      1    149    212
    1032]
 [   251     17      0      6     99   2114      2 186360     62     31
      54]
 [    23   1714     12    112    531     86      5      0   8201    245
     924]
 [     0     86      0      0      8     33      0      0     72   1054
      29]
 [     1      0      0      1     15      5      0      0      1      2
     126]]
                precision    recall  f1-score   support

      Analysis    

                                                  

Test Accuracy: 0.8712
[[  106   208   226     0     0     2     0     0     0     0     0]
 [    4   277     0     1     8     5     0     0     1     4     8]
 [    0     6   228     0     1     0     0     0     0     0     0]
 [   29  1712   229    80   225    50     0     0    31    59   150]
 [   88  2179   458    69  2717   141     3     0   151   173  1064]
 [   16   258   452    50    30  2887     1     0     0    10   211]
 [   13   255     1    22   255    39 30430     0    19    46   205]
 [   59     1     0     2    16   418     1 32366     2     6     9]
 [    8   366     2     6   101    12     1     0  1423    52   163]
 [    0    15     0     1     0    10     0     0    13   186     4]
 [    0     2     0     0     8     3     0     0     1     0     9]]
                precision    recall  f1-score   support

      Analysis     0.3282    0.1956    0.2451       542
      Backdoor     0.0525    0.8994    0.0992       308
     Backdoors     0.1429    0.9702    0.2490    

