In [61]:
from torch_geometric.utils import from_networkx, add_self_loops, degree
from torch_geometric.nn import MessagePassing
import torch_geometric.transforms as T
from torch_geometric.data import DataLoader
from torch_geometric.loader import NeighborSampler
import torch.nn as nn
import torch as th
import torch.nn.functional as F
# import dgl.function as fn
import networkx as nx
import pandas as pd
import socket
import struct
import random
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.utils import class_weight
# import seaborn as sns
# import matplotlib.pyplot as plt
import numpy as np
import os
import sys

project_root = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))
sys.path.append(project_root)

from Datasets.UNSW_NB15.UNSW_NB15_config import UNSW_NB15_Config

In [None]:
data = pd.read_csv(os.path.join(project_root, "Datasets", "UNSW_NB15/All/all_raw_downsampled.csv"))

DATASET_NAME = "UNSW_NB15"

SOURCE_IP_COL_NAME = UNSW_NB15_Config.SOURCE_IP_COL_NAME
DESTINATION_IP_COL_NAME = UNSW_NB15_Config.DESTINATION_IP_COL_NAME
SOURCE_PORT_COL_NAME = UNSW_NB15_Config.SOURCE_PORT_COL_NAME
DESTINATION_PORT_COL_NAME = UNSW_NB15_Config.DESTINATION_PORT_COL_NAME

ATTACK_CLASS_COL_NAME = UNSW_NB15_Config.ATTACK_CLASS_COL_NAME
IS_ATTACK_COL_NAME = UNSW_NB15_Config.IS_ATTACK_COL_NAME

BENIGN_CLASS_NAME = UNSW_NB15_Config.BENIGN_CLASS_NAME

MULTICLASS = True
label_col = ATTACK_CLASS_COL_NAME if MULTICLASS else IS_ATTACK_COL_NAME

  data = pd.read_csv(os.path.join(project_root, "Datasets", "UNSW_NB15/All/all_raw_downsampled.csv"))


attack_cat
Normal            221876
Generic           215481
Exploits           44525
Fuzzers            24246
DoS                16353
Reconnaissance     13987
Analysis            2677
Backdoor            1795
Shellcode           1511
Backdoors            534
Worms                174
Name: count, dtype: int64


In [None]:
data.drop(columns=UNSW_NB15_Config.DROP_COLS,inplace=True)

In [None]:
print(data[ATTACK_CLASS_COL_NAME].value_counts())
print(data[IS_ATTACK_COL_NAME].value_counts())

label
1    321283
0    221876
Name: count, dtype: int64


In [65]:
data[SOURCE_IP_COL_NAME] = data[SOURCE_IP_COL_NAME].apply(str)
data[SOURCE_PORT_COL_NAME] = data[SOURCE_PORT_COL_NAME].apply(str)
data[DESTINATION_IP_COL_NAME] = data[DESTINATION_IP_COL_NAME].apply(str)
data[DESTINATION_PORT_COL_NAME] = data[DESTINATION_PORT_COL_NAME].apply(str)
data[SOURCE_IP_COL_NAME] = data[SOURCE_IP_COL_NAME] + ':' + data[SOURCE_PORT_COL_NAME]
data[DESTINATION_IP_COL_NAME] = data[DESTINATION_IP_COL_NAME] + ':' + data[DESTINATION_PORT_COL_NAME]
data.drop(columns=[SOURCE_PORT_COL_NAME,DESTINATION_PORT_COL_NAME],inplace=True)

In [66]:
print(data.head)

<bound method NDFrame.head of                      srcip               dstip state        dur  sbytes  \
0             10.40.85.1:0         224.0.0.5:0   INT  50.004341     384   
1          59.166.0.6:2142    149.171.126.4:53   CON   0.001134     132   
2       175.45.176.0:13284   149.171.126.16:80   FIN   2.390390    1362   
3         59.166.0.3:42587    149.171.126.8:25   FIN  34.077175   37358   
4            10.40.170.2:0       10.40.170.2:0   INT   0.000000      46   
...                    ...                 ...   ...        ...     ...   
543154  175.45.176.0:47439   149.171.126.10:53   INT   0.000001     114   
543155  175.45.176.0:47439   149.171.126.10:53   INT   0.000001     114   
543156    59.166.0.5:53521    149.171.126.7:21   CON   1.086072    1940   
543157  175.45.176.0:17293  149.171.126.17:110   CON   0.942984     574   
543158  175.45.176.0:17293  149.171.126.17:110   CON   0.942984     574   

        dbytes  sttl  dttl  sloss  dloss  ...  ct_ftp_cmd  ct_srv_src

In [None]:
data = pd.get_dummies(data, columns = UNSW_NB15_Config.CATEGORICAL_COLS) # One Hot Encoding for categorical data

In [68]:
data = data.reset_index()
data.replace([np.inf, -np.inf], np.nan,inplace = True)
data.fillna(0,inplace = True)
data.drop(columns=['index'],inplace=True)
print(data.head)

<bound method NDFrame.head of                      srcip               dstip        dur  sbytes  dbytes  \
0             10.40.85.1:0         224.0.0.5:0  50.004341     384       0   
1          59.166.0.6:2142    149.171.126.4:53   0.001134     132     164   
2       175.45.176.0:13284   149.171.126.16:80   2.390390    1362     268   
3         59.166.0.3:42587    149.171.126.8:25  34.077175   37358    3380   
4            10.40.170.2:0       10.40.170.2:0   0.000000      46       0   
...                    ...                 ...        ...     ...     ...   
543154  175.45.176.0:47439   149.171.126.10:53   0.000001     114       0   
543155  175.45.176.0:47439   149.171.126.10:53   0.000001     114       0   
543156    59.166.0.5:53521    149.171.126.7:21   1.086072    1940    2404   
543157  175.45.176.0:17293  149.171.126.17:110   0.942984     574     676   
543158  175.45.176.0:17293  149.171.126.17:110   0.942984     574     676   

        sttl  dttl  sloss  dloss         Sloa

In [69]:
scaler = StandardScaler()
cols_to_norm = UNSW_NB15_Config.COLS_TO_NORM
print(data[cols_to_norm].describe()) # Check if there's any too large value

                 dur        sbytes        dbytes           sttl  \
count  543159.000000  5.431590e+05  5.431590e+05  543159.000000   
mean        0.706760  5.136572e+03  1.936909e+04     157.197364   
std        12.637229  1.202311e+05  1.390925e+05     108.452474   
min         0.000000  0.000000e+00  0.000000e+00       0.000000   
25%         0.000007  1.140000e+02  0.000000e+00      31.000000   
50%         0.000011  2.000000e+02  0.000000e+00     254.000000   
75%         0.072088  1.580000e+03  1.940000e+03     254.000000   
max      8760.776367  1.435577e+07  1.465753e+07     255.000000   

                dttl          sloss          dloss         Sload  \
count  543159.000000  543159.000000  543159.000000  5.431590e+05   
mean       38.850764       3.800661       8.729770  6.877595e+07   
std        77.034389      45.616565      50.136204  1.420534e+08   
min         0.000000       0.000000       0.000000  0.000000e+00   
25%         0.000000       0.000000       0.000000  3.70

In [None]:
def check_numeric_issues(df, cols_to_norm):
    for col in cols_to_norm:
        try:
            # Try to coerce to numeric
            df[col] = pd.to_numeric(df[col], errors='coerce')
            
            # Try to clip the column
            df[col] = df[col].clip(lower=-1e9, upper=1e9)
            
        except Exception as e:
            print(f"❌ Column '{col}' failed with error: {e}")
            print(f"  - Sample values: {df[col].dropna().unique()[:5]}")
            print(f"  - Data type: {df[col].dtype}")
            continue

    print("\n✅ All other columns processed successfully.")

check_numeric_issues(data, UNSW_NB15_Config.COLS_TO_NORM)

In [None]:
data[cols_to_norm] = scaler.fit_transform(data[cols_to_norm])

In [None]:
from sklearn.preprocessing import LabelEncoder

num_classes = 2
class_map = [0, 1]
if MULTICLASS:
    le = LabelEncoder()
    attack_labels = le.fit_transform(data[ATTACK_CLASS_COL_NAME])
    class_map = le.classes_
    print(class_map)
    print("Attack label mapping:", dict(zip(class_map, range(len(class_map)))))
    data[ATTACK_CLASS_COL_NAME] = attack_labels
    num_classes = len(class_map)


['Analysis' 'Backdoor' 'Backdoors' 'DoS' 'Exploits' 'Fuzzers' 'Generic'
 'Normal' 'Reconnaissance' 'Shellcode' 'Worms']
Attack label mapping: {'Analysis': 0, 'Backdoor': 1, 'Backdoors': 2, 'DoS': 3, 'Exploits': 4, 'Fuzzers': 5, 'Generic': 6, 'Normal': 7, 'Reconnaissance': 8, 'Shellcode': 9, 'Worms': 10}


In [None]:
X_train, X_test, y_train, y_test = train_test_split(
     data, data[label_col], test_size=0.3, random_state=42, stratify=data[label_col])
print(len(X_train))
print(len(X_test))
print(y_train.value_counts())
print(y_test.value_counts())

325895
217264
attack_cat
Normal            133125
Generic           129289
Exploits           26715
Fuzzers            14548
DoS                 9812
Reconnaissance      8392
Analysis            1606
Backdoor            1077
Shellcode            907
Backdoors            320
Worms                104
Name: count, dtype: int64
attack_cat
Normal            88751
Generic           86192
Exploits          17810
Fuzzers            9698
DoS                6541
Reconnaissance     5595
Analysis           1071
Backdoor            718
Shellcode           604
Backdoors           214
Worms                70
Name: count, dtype: int64


In [73]:
X_train['h'] = X_train[ cols_to_norm ].values.tolist()

In [None]:
# Convert NetworkX graph to PyG graph

G_nx = nx.from_pandas_edgelist(X_train, SOURCE_IP_COL_NAME, DESTINATION_IP_COL_NAME, ['h', label_col], create_using=nx.MultiDiGraph())
G_pyg = from_networkx(G_nx)

num_nodes = G_pyg.num_nodes
num_edges = G_pyg.num_edges

G_pyg.x = th.ones(num_nodes, len(X_train['h'].iloc[0])) 

edge_attr_list = []
edge_label_list = []

for u, v, key, data in G_nx.edges(keys=True, data=True):
    edge_attr_list.append(data['h']) 
    edge_label_list.append(data[label_col]) 

G_pyg.edge_attr = th.tensor(edge_attr_list, dtype=th.float32)
G_pyg.edge_label = th.tensor(edge_label_list, dtype=th.long)

print("Number of edges in G_pyg:", G_pyg.num_edges)
print("Number of node in G_pyg:", G_pyg.num_nodes)
print("Shape of node in G_pyg:", G_pyg.x.shape)
print("Shape of edge attr in G_pyg:", G_pyg.edge_attr.shape)
print("Shape of edge label in G_pyg:", G_pyg.edge_label.shape)

Number of edges in G_pyg: 325895
Number of node in G_pyg: 192101
Shape of node in G_pyg: torch.Size([192101, 40])
Shape of edge attr in G_pyg: torch.Size([325895, 40])
Shape of edge class in G_pyg: torch.Size([325895])


In [75]:

class EGraphSAGEConv(MessagePassing):
    def __init__(self, node_in_channels, edge_in_channels, out_channels):
        super(EGraphSAGEConv, self).__init__(aggr='mean')  # mean aggregation
        self.lin_node = nn.Linear(node_in_channels, out_channels)
        self.lin_edge = nn.Linear(edge_in_channels, out_channels)
        self.lin_update = nn.Linear(node_in_channels + out_channels, out_channels) # out_channels * 2

    def forward(self, x, edge_index, edge_attr):
        # x: Node features, edge_attr: Edge features, edge_index: Connectivity
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        if edge_attr is not None:
            if edge_attr.size(0) != edge_index.size(1):
                loop_attr = th.zeros((edge_index.size(1) - edge_attr.size(0), edge_attr.size(1))).to(edge_attr.device)
                edge_attr = th.cat([edge_attr, loop_attr], dim=0)
        else:
            print("edge_attr is unexist")
        
        # Propagate and aggregate neighbor information
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_j, edge_attr):
        # x_j represents the adjacent nodes of x
        # Compute messages by combining node and edge features
        return self.lin_node(x_j) + self.lin_edge(edge_attr)

    def update(self, aggr_out, x):
        # Update node features after message passing
        return self.lin_update(th.cat([x, aggr_out], dim=1))

class MLPPredictor(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(MLPPredictor, self).__init__()
        self.lin = nn.Linear(in_channels * 2, out_channels)

    def forward(self, data, z):
        row, col = data.edge_index
        # Concatenate the features of source and target nodes for each edge
        edge_feat = th.cat([z[row], z[col]], dim=1)
        return self.lin(edge_feat)

class EGraphSAGE(nn.Module):
    def __init__(self, node_in_channels, edge_in_channels, hidden_channels, out_channels):
        super(EGraphSAGE, self).__init__()
        self.conv1 = EGraphSAGEConv(node_in_channels, edge_in_channels, hidden_channels)
        self.conv2 = EGraphSAGEConv(hidden_channels, edge_in_channels, hidden_channels)
        self.mlp_predictor = MLPPredictor(hidden_channels, out_channels)

    def forward(self, data):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        x = F.relu(self.conv1(x, edge_index, edge_attr))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index, edge_attr)
        return self.mlp_predictor(data, x)




In [76]:
device = th.device("cuda:0" if th.cuda.is_available() else "cpu")
print(device)

cuda:0


In [77]:
th.cuda.empty_cache()

In [None]:
model = EGraphSAGE(node_in_channels=G_pyg.num_node_features, 
                   edge_in_channels=G_pyg.num_edge_features,
                   hidden_channels=128, 
                   out_channels=num_classes).to(device)

def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        nn.init.constant_(m.bias, 0)

model.apply(init_weights)

labels = G_pyg.edge_label.cpu().numpy()
class_weights = class_weight.compute_class_weight('balanced',
                                                  classes=num_classes,
                                                  y=labels)

class_weights = th.FloatTensor(class_weights).cuda()
print("Class weights:", class_weights)
criterion = nn.CrossEntropyLoss(weight = class_weights)
optimizer = th.optim.Adam(model.parameters(), lr=0.001)

In [None]:
from torch_geometric.utils import subgraph
from torch_geometric.data import Data

def compute_accuracy(pred, labels):
    return (pred.argmax(1) == labels).float().mean().item()

G_pyg.edge_label = G_pyg.edge_label.to(device)
G_pyg.edge_attr = G_pyg.edge_attr.to(device)

def generate_edge_based_batches_with_node_expansion(graph, batch_size, min_nodes):
    num_edges = graph.edge_index.size(1)  # Get total number of edges
    edge_indices = th.arange(num_edges)   # Create list of edge indices
    num_edges_processed = 0
    
    while num_edges_processed < num_edges:
        # Select a batch of edges
        batch_edge_indices = edge_indices[num_edges_processed : min(num_edges_processed + batch_size, num_edges)]
        edge_index = graph.edge_index[:, batch_edge_indices]
        
        # Update the number of edges processed
        num_edges_processed += batch_size
        
        # Get the unique nodes associated with these edges
        batch_nodes = th.cat([edge_index[0], edge_index[1]]).unique()

        # Check if the batch has enough unique nodes
        while batch_nodes.size(0) < min_nodes:
            # Sample additional neighboring nodes to ensure diversity
            additional_edges = int(batch_size / 8)  # Ensure additional_edges is an integer
            batch_edge_indices = th.cat([batch_edge_indices, edge_indices[num_edges_processed : min(num_edges_processed + additional_edges, num_edges)]])
            edge_index = graph.edge_index[:, batch_edge_indices]
            batch_nodes = th.cat([edge_index[0], edge_index[1]]).unique()
            num_edges_processed += additional_edges

            # Avoid potential infinite loops by breaking if no more edges can be added
            if num_edges_processed >= num_edges:
                break

        # Create subgraph from the selected nodes and edges
        edge_index, _, edge_mask = subgraph(batch_nodes, graph.edge_index, relabel_nodes=True, return_edge_mask=True)

        # Use edge_mask to select edge attributes and labels
        edge_attr = graph.edge_attr[edge_mask]
        edge_label = graph.edge_label[edge_mask]

        yield batch_nodes, edge_index, edge_attr, edge_label

batch_size = 64
for epoch in range(200):
    print(f'epoch : {epoch}')
    all_preds = []
    all_labels = []
    
    try:
        for batch_idx, (batch_nodes, edge_index, edge_attr, edge_label) in enumerate(generate_edge_based_batches_with_node_expansion(G_pyg, batch_size, 20)):
            # print(f"Processing epoch {epoch}, batch {batch_idx} with {batch_nodes.size(0)} nodes and {edge_index.size(1)} edges")
            batch = Data(x=G_pyg.x[batch_nodes], edge_index=edge_index, edge_attr=edge_attr, edge_label=edge_label)
            
            if batch.edge_index.size(1) == 0 or batch.edge_label.size(0) == 0:
                print(f"Warning: Empty batch at batch {batch_idx}")
                continue
                
            if batch is None or batch.num_nodes == 0:
                print(f"Warning: Empty batch at Batch {batch_idx}")
                continue 
    
            if th.isnan(batch.x).any() or th.isinf(batch.x).any() or th.isnan(batch.edge_attr).any() or th.isinf(batch.edge_attr).any():
                print(f"Warning: batch x and edge_attr contains NaN or Inf at Batch {batch_idx}")
                continue 
                
            try:
                batch = batch.to(device)
            except Exception as batch_error:
                print(f"Error moving batch to device at Batch {batch_idx}: {batch_error}")
                continue
            
            try:
                out = model(batch)
    
                if th.isnan(out).any() or th.isinf(out).any():
                    print(f"Warning: out contains NaN or Inf at Batch {batch_idx}")
                    continue 
                all_preds.append(out.argmax(dim=1))
                all_labels.append(batch.edge_label)
    
                loss = criterion(out, batch.edge_label)
                if th.isnan(loss):
                    print(f"loss: {loss}")
                    print(f"out: {out}")
                    print(f"edge_labels: {batch.edge_label}")
                    
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
            except Exception as forward_error:
                print(f"Error during forward/backward pass at Epoch {epoch}, Batch {batch_idx}: {forward_error}")
                continue
        
        all_preds = th.cat(all_preds)
        all_labels = th.cat(all_labels)
        
        epoch_accuracy = compute_accuracy(all_preds, all_labels)
        print(f'Epoch {epoch}, Loss: {loss:.4f}, Accuracy: {epoch_accuracy:.4f}')
        print(all_labels.shape)

    except Exception as e:
        print(f"An error occurred at epoch {epoch}, batch {batch_idx}: {str(e)}")
print("Training is over")

epoch : 0
Epoch 0, Loss: 0.7250, Accuracy: 0.6076
torch.Size([361903])
epoch : 1
Epoch 1, Loss: 0.7162, Accuracy: 0.6582
torch.Size([361903])
epoch : 2
Epoch 2, Loss: 0.7039, Accuracy: 0.6064
torch.Size([361903])
epoch : 3
Epoch 3, Loss: 0.7159, Accuracy: 0.6851
torch.Size([361903])
epoch : 4
Epoch 4, Loss: 0.7101, Accuracy: 0.6290
torch.Size([361903])
epoch : 5
Epoch 5, Loss: 0.7064, Accuracy: 0.6851
torch.Size([361903])
epoch : 6
Epoch 6, Loss: 0.7151, Accuracy: 0.6845
torch.Size([361903])
epoch : 7
Epoch 7, Loss: 0.7085, Accuracy: 0.6080
torch.Size([361903])
epoch : 8
Epoch 8, Loss: 0.7047, Accuracy: 0.6843
torch.Size([361903])
epoch : 9
Epoch 9, Loss: 0.7010, Accuracy: 0.6672
torch.Size([361903])
epoch : 10
Epoch 10, Loss: 0.7049, Accuracy: 0.6845
torch.Size([361903])
epoch : 11
Epoch 11, Loss: 0.7044, Accuracy: 0.6078
torch.Size([361903])
epoch : 12
Epoch 12, Loss: 0.7128, Accuracy: 0.6193
torch.Size([361903])
epoch : 13
Epoch 13, Loss: 0.7057, Accuracy: 0.6081
torch.Size([361903]

In [None]:
th.save(model.state_dict(), f"./Weights/GNN_model_weights_{DATASET_NAME}_subset_2.pth")

In [None]:
X_test['h'] = X_test[ cols_to_norm ].values.tolist()

G_nx_test = nx.from_pandas_edgelist(X_test, SOURCE_IP_COL_NAME, DESTINATION_IP_COL_NAME, ['h', label_col], create_using=nx.MultiDiGraph())

G_pyg_test = from_networkx(G_nx_test)

test_num_nodes = G_pyg_test.num_nodes
test_num_edges = G_pyg_test.num_edges

G_pyg_test.x = th.ones(test_num_nodes, len(X_test['h'].iloc[0]))

test_edge_attr_list = []
test_edge_label_list = []

for u, v, key, data in G_nx_test.edges(keys=True, data=True):
    test_edge_attr_list.append(data['h']) 
    test_edge_label_list.append(data[label_col]) 

G_pyg_test.edge_attr = th.tensor(test_edge_attr_list, dtype=th.float32)
G_pyg_test.edge_label = th.tensor(test_edge_label_list, dtype=th.long)

print("Number of edges in G_pyg_test:", G_pyg_test.num_edges)
print("Number of node in G_pyg_test:", G_pyg_test.num_nodes)
print("Shape of node in G_pyg_test:", G_pyg_test.x.shape)
print("Shape of edge attr in G_pyg_test:", G_pyg_test.edge_attr.shape)
print("Shape of edge label in G_pyg_test:", G_pyg_test.edge_label.shape)

Number of edges in G_pyg_test: 217264
Number of node in G_pyg_test: 132746
Shape of node in G_pyg_test: torch.Size([132746, 40])
Shape of edge attr in G_pyg_test: torch.Size([217264, 40])
Shape of edge label in G_pyg_test: torch.Size([217264])
Shape of edge class in G_pyg_test: torch.Size([217264])


In [None]:
from torch_geometric.utils import subgraph
from torch_geometric.data import Data

def compute_accuracy(pred, labels):
    return (pred.argmax(1) == labels).float().mean().item()

new_model_2 = EGraphSAGE(node_in_channels=G_pyg_test.num_node_features, 
                       edge_in_channels=G_pyg_test.num_edge_features,
                       hidden_channels=128, 
                       out_channels=num_classes).to(device)

new_model_2.load_state_dict(th.load(f"./Weights/GNN_model_weights_{DATASET_NAME}_subset_2.pth", weights_only=True))

new_model_2.eval()

all_test_preds = []
all_test_labels = []
all_test_classes = []

batch_size = 64

print("inference start")
with th.no_grad():
    for batch_idx, (batch_nodes, edge_index, edge_attr, edge_label) in enumerate(generate_edge_based_batches_with_node_expansion(G_pyg_test, batch_size, 20)):
        # print(f"Processing batch {batch_idx} with {batch_nodes.size(0)} nodes and {edge_index.size(1)} edges")
        batch = Data(x=G_pyg_test.x[batch_nodes], edge_index=edge_index, edge_attr=edge_attr, edge_label=edge_label)
        
        if batch.edge_index.size(1) == 0 or batch.edge_label.size(0) == 0:
            print(f"Warning: Empty batch at batch {batch_idx}")
            continue
            
        if batch is None or batch.num_nodes == 0:
            print(f"Warning: Empty batch at Batch {batch_idx}")
            continue

        if th.isnan(batch.x).any() or th.isinf(batch.x).any() or th.isnan(batch.edge_attr).any() or th.isinf(batch.edge_attr).any():
            print(f"Warning: batch x and edge_attr contains NaN or Inf at Batch {batch_idx}")
            continue
            
        try:
            batch = batch.to(device)
        except Exception as batch_error:
            print(f"Error moving batch to device at Batch {batch_idx}: {batch_error}")
            continue
        
        try:
            out = new_model_2(batch)

            if th.isnan(out).any() or th.isinf(out).any():
                print(f"Warning: out contains NaN or Inf at Batch {batch_idx}")
                continue 
            
            all_test_preds.append(out.argmax(dim=1))
            all_test_labels.append(edge_label)

        except Exception as forward_error:
            print(f"Error during forward/backward pass at Batch {batch_idx}: {forward_error}")
            continue

print("inference done")
all_test_preds = th.cat(all_test_preds).to(device)
all_test_labels = th.cat(all_test_labels).to(device)

test_accuracy = compute_accuracy(all_test_preds, all_test_labels)
print(f'Test Accuracy: {test_accuracy:.4f}')



inference start


inference done
Test Accuracy: 0.4756


In [None]:
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report

pred_labels = all_test_preds.cpu()
all_test_labels = all_test_labels.cpu()

cm = confusion_matrix(all_test_labels, pred_labels)
print(cm)

# Generate a classification report
report = classification_report(all_test_labels, pred_labels, target_names=class_map, digits=4)
print(report)

[[  206     0     0     0    58   495     0     0   617     1     1]
 [    3    11     0     0    21   388     0     0   527    13     7]
 [    2     0     0     0     2   199     0     0    17     0     0]
 [   88    21     0    16   891  3098     0     9  4620   102    55]
 [  362    40     8    11  8632  4582     0    12  6427   366   467]
 [   21     0    82     0   448  8104     0     7  1377   157    17]
 [   41     1     0     9   861 62296     0 27887   751    80    82]
 [  299     3    12     2   123  1037     0 89859   110   110     0]
 [   14     1     3     2   284  1349     0     0  3671   507   229]
 [    0     0     1     0     1   173     0     0   201   228     0]
 [    4     0     0     0    30    12     0     0     1     1    22]]
                precision    recall  f1-score   support

      Analysis     0.1981    0.1495    0.1704      1378
      Backdoor     0.1429    0.0113    0.0210       970
     Backdoors     0.0000    0.0000    0.0000       220
           DoS 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
