In [1]:
'''
=====Experiment=====
Dataset: UNSW-NB15 dataset

Check Train Test Dataset Coverage for
- 1. Random Split
- 2. Stratified Split
'''

from torch_geometric.utils import from_networkx, add_self_loops, degree
from torch_geometric.nn import MessagePassing
import torch_geometric.transforms as T
from torch_geometric.data import DataLoader
from torch_geometric.loader import NeighborSampler
import torch.nn as nn
import torch as th
import torch.nn.functional as F
# import dgl.function as fn
import networkx as nx
import pandas as pd
import socket
import struct
import matplotlib.pyplot as plt
import random
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.utils import class_weight
# import seaborn as sns
# import matplotlib.pyplot as plt
import numpy as np
import os
import sys
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
from torch_geometric.loader import DataLoader


project_root = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))
sys.path.append(project_root)

from Datasets.UNSW_NB15.UNSW_NB15_config import UNSW_NB15_Config

In [2]:
csv_file_name = "all_raw"

data = pd.read_csv(os.path.join(project_root, "Datasets", f"UNSW_NB15/All/{csv_file_name}.csv"))

DATASET_NAME = "UNSW_NB15"
EXPERIMENT_NAME = "check_train_test_split"

SOURCE_FILE_ID_COL_NAME = UNSW_NB15_Config.SOURCE_FILE_ID_COL_NAME

SOURCE_IP_COL_NAME = UNSW_NB15_Config.SOURCE_IP_COL_NAME
DESTINATION_IP_COL_NAME = UNSW_NB15_Config.DESTINATION_IP_COL_NAME
SOURCE_PORT_COL_NAME = UNSW_NB15_Config.SOURCE_PORT_COL_NAME
DESTINATION_PORT_COL_NAME = UNSW_NB15_Config.DESTINATION_PORT_COL_NAME

ATTACK_CLASS_COL_NAME = UNSW_NB15_Config.ATTACK_CLASS_COL_NAME
IS_ATTACK_COL_NAME = UNSW_NB15_Config.IS_ATTACK_COL_NAME

BENIGN_CLASS_NAME = UNSW_NB15_Config.BENIGN_CLASS_NAME

TIME_COLS = UNSW_NB15_Config.TIME_COL_NAMES

print(data[ATTACK_CLASS_COL_NAME].value_counts())
print(data[IS_ATTACK_COL_NAME].value_counts())

MULTICLASS = True

if MULTICLASS:
    label_col = ATTACK_CLASS_COL_NAME
    data.drop(columns=[IS_ATTACK_COL_NAME], inplace=True)
else:
    label_col = IS_ATTACK_COL_NAME
    data.drop(columns=[ATTACK_CLASS_COL_NAME], inplace=True)


saves_path = os.path.join(project_root, "Models/E_GraphSAGE/logs", DATASET_NAME, EXPERIMENT_NAME)

checkpoint_path = os.path.join(saves_path, f"checkpoints_{csv_file_name}.pth")
best_model_path = os.path.join(saves_path, f"best_model_{csv_file_name}.pth")

os.makedirs(saves_path, exist_ok=True)

attack_cat
Normal            2218764
Generic            215481
Exploits            44525
Fuzzers             24246
DoS                 16353
Reconnaissance      13987
Analysis             2677
Backdoors            2329
Shellcode            1511
Worms                 174
Name: count, dtype: int64
label
0    2218764
1     321283
Name: count, dtype: int64


In [3]:
data.drop(columns=UNSW_NB15_Config.DROP_COLS,inplace=True)
print(data.columns)

Index(['srcip', 'sport', 'dstip', 'dsport', 'state', 'dur', 'sbytes', 'dbytes',
       'sttl', 'dttl', 'sloss', 'dloss', 'Sload', 'Dload', 'Spkts', 'Dpkts',
       'swin', 'dwin', 'stcpb', 'dtcpb', 'smeansz', 'dmeansz', 'trans_depth',
       'res_bdy_len', 'Sjit', 'Djit', 'Stime', 'Ltime', 'Sintpkt', 'Dintpkt',
       'tcprtt', 'synack', 'ackdat', 'is_sm_ips_ports', 'ct_state_ttl',
       'ct_flw_http_mthd', 'is_ftp_login', 'ct_ftp_cmd', 'ct_srv_src',
       'ct_srv_dst', 'ct_dst_ltm', 'ct_src_ltm', 'ct_src_dport_ltm',
       'ct_dst_sport_ltm', 'ct_dst_src_ltm', 'attack_cat', 'source_file_id'],
      dtype='object')


In [4]:
data[SOURCE_IP_COL_NAME] = data[SOURCE_IP_COL_NAME].apply(str)
data[DESTINATION_IP_COL_NAME] = data[DESTINATION_IP_COL_NAME].apply(str)

# # Combine Port and IP
data[SOURCE_PORT_COL_NAME] = data[SOURCE_PORT_COL_NAME].apply(str)
data[DESTINATION_PORT_COL_NAME] = data[DESTINATION_PORT_COL_NAME].apply(str)

data[SOURCE_IP_COL_NAME] = data[SOURCE_IP_COL_NAME] + ':' + data[SOURCE_PORT_COL_NAME]
data[DESTINATION_IP_COL_NAME] = data[DESTINATION_IP_COL_NAME] + ':' + data[DESTINATION_PORT_COL_NAME]
data.drop(columns=[SOURCE_PORT_COL_NAME,DESTINATION_PORT_COL_NAME],inplace=True)

data = pd.get_dummies(data, columns = UNSW_NB15_Config.CATEGORICAL_COLS) # One Hot Encoding for categorical data
converted_categorical_cols = [col for col in data.columns if col.startswith(tuple(UNSW_NB15_Config.CATEGORICAL_COLS))]

In [5]:
print(data.head)

<bound method NDFrame.head of                          srcip                    dstip        dur  sbytes  \
0              10.40.182.1_0:0            224.0.0.5_0:0  50.004337     384   
1               10.40.85.1_0:0            224.0.0.5_0:0  50.004341     384   
2              10.40.182.1_0:0            224.0.0.5_0:0  50.004337     384   
3               10.40.85.1_0:0            224.0.0.5_0:0  50.004341     384   
4        192.168.241.243_0:259  192.168.241.243_0:49320   0.000000    1780   
...                        ...                      ...        ...     ...   
2540042      59.166.0.0_3:2111       149.171.126.5_3:53   0.001035     146   
2540043     59.166.0.5_3:49044    149.171.126.3_3:30639   0.220630     424   
2540044     59.166.0.6_3:37717    149.171.126.7_3:35667   0.031576    2646   
2540045      59.166.0.2_3:1768    149.171.126.7_3:64122   0.096835    4862   
2540046      59.166.0.9_3:7045       149.171.126.7_3:25   0.201886   37552   

         dbytes  sttl  dttl  slos

In [6]:
data = data.reset_index()
data.replace([np.inf, -np.inf], np.nan,inplace = True)
data.fillna(0,inplace = True)
data.drop(columns=['index'],inplace=True)
print(data.head)

<bound method NDFrame.head of                          srcip                    dstip        dur  sbytes  \
0              10.40.182.1_0:0            224.0.0.5_0:0  50.004337     384   
1               10.40.85.1_0:0            224.0.0.5_0:0  50.004341     384   
2              10.40.182.1_0:0            224.0.0.5_0:0  50.004337     384   
3               10.40.85.1_0:0            224.0.0.5_0:0  50.004341     384   
4        192.168.241.243_0:259  192.168.241.243_0:49320   0.000000    1780   
...                        ...                      ...        ...     ...   
2540042      59.166.0.0_3:2111       149.171.126.5_3:53   0.001035     146   
2540043     59.166.0.5_3:49044    149.171.126.3_3:30639   0.220630     424   
2540044     59.166.0.6_3:37717    149.171.126.7_3:35667   0.031576    2646   
2540045      59.166.0.2_3:1768    149.171.126.7_3:64122   0.096835    4862   
2540046      59.166.0.9_3:7045       149.171.126.7_3:25   0.201886   37552   

         dbytes  sttl  dttl  slos

In [7]:
scaler = StandardScaler()
cols_to_norm = UNSW_NB15_Config.COLS_TO_NORM
print(data[cols_to_norm].describe()) # Check if there's any too large value

                dur        sbytes        dbytes          sttl          dttl  \
count  2.540047e+06  2.540047e+06  2.540047e+06  2.540047e+06  2.540047e+06   
mean   6.587916e-01  4.339600e+03  3.642759e+04  6.278197e+01  3.076681e+01   
std    1.392493e+01  5.640599e+04  1.610960e+05  7.462277e+01  4.285089e+01   
min    0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00   
25%    1.037000e-03  2.000000e+02  1.780000e+02  3.100000e+01  2.900000e+01   
50%    1.586100e-02  1.470000e+03  1.820000e+03  3.100000e+01  2.900000e+01   
75%    2.145545e-01  3.182000e+03  1.489400e+04  3.100000e+01  2.900000e+01   
max    8.786638e+03  1.435577e+07  1.465753e+07  2.550000e+02  2.540000e+02   

              sloss         dloss         Sload         Dload         Spkts  \
count  2.540047e+06  2.540047e+06  2.540047e+06  2.540047e+06  2.540047e+06   
mean   5.163921e+00  1.632944e+01  3.695645e+07  2.450861e+06  3.328884e+01   
std    2.251707e+01  5.659474e+01  1.186043e+08  4.

In [8]:
def check_numeric_issues(df, cols_to_norm):
    for col in cols_to_norm:
        try:
            # Try to coerce to numeric
            df[col] = pd.to_numeric(df[col], errors='coerce')
            
            # Try to clip the column
            df[col] = df[col].clip(lower=-1e9, upper=1e9)
            
        except Exception as e:
            print(f"❌ Column '{col}' failed with error: {e}")
            print(f"  - Sample values: {df[col].dropna().unique()[:5]}")
            print(f"  - Data type: {df[col].dtype}")
            continue

    print("\n✅ All other columns processed successfully.")

check_numeric_issues(data, UNSW_NB15_Config.COLS_TO_NORM)


✅ All other columns processed successfully.


In [9]:
data[cols_to_norm] = scaler.fit_transform(data[cols_to_norm])

In [10]:
from sklearn.preprocessing import LabelEncoder

num_classes = 2
class_map = [0, 1]
if MULTICLASS:
    le = LabelEncoder()
    attack_labels = le.fit_transform(data[ATTACK_CLASS_COL_NAME])
    class_map = le.classes_
    print(class_map)
    print("Attack label mapping:", dict(zip(class_map, range(len(class_map)))))
    data[ATTACK_CLASS_COL_NAME] = attack_labels
    num_classes = len(class_map)
    class_dict = {le.inverse_transform([i])[0]: i for i in range(len(le.classes_))}

BENIGN_CLASS_LABEL = le.transform([BENIGN_CLASS_NAME])[0] if MULTICLASS else 0
ADVERSARIAL_CLASS_LABEL = len(class_map)

['Analysis' 'Backdoors' 'DoS' 'Exploits' 'Fuzzers' 'Generic' 'Normal'
 'Reconnaissance' 'Shellcode' 'Worms']
Attack label mapping: {'Analysis': 0, 'Backdoors': 1, 'DoS': 2, 'Exploits': 3, 'Fuzzers': 4, 'Generic': 5, 'Normal': 6, 'Reconnaissance': 7, 'Shellcode': 8, 'Worms': 9}


In [11]:
# # Maintain the order of the rows in the original dataframe

feature_cols = cols_to_norm + converted_categorical_cols

print('Feature Columns:', feature_cols)
num_features = len(feature_cols)
print('Number of Features:', num_features)

data['h'] = data[ feature_cols ].values.tolist()

Feature Columns: ['dur', 'sbytes', 'dbytes', 'sttl', 'dttl', 'sloss', 'dloss', 'Sload', 'Dload', 'Spkts', 'Dpkts', 'swin', 'dwin', 'stcpb', 'dtcpb', 'smeansz', 'dmeansz', 'trans_depth', 'res_bdy_len', 'Sjit', 'Djit', 'Sintpkt', 'Dintpkt', 'tcprtt', 'synack', 'ackdat', 'is_sm_ips_ports', 'ct_state_ttl', 'ct_flw_http_mthd', 'is_ftp_login', 'ct_ftp_cmd', 'ct_srv_src', 'ct_srv_dst', 'ct_dst_ltm', 'ct_src_ltm', 'ct_src_dport_ltm', 'ct_dst_sport_ltm', 'ct_dst_src_ltm', 'state_ACC', 'state_CLO', 'state_CON', 'state_ECO', 'state_ECR', 'state_FIN', 'state_INT', 'state_MAS', 'state_PAR', 'state_REQ', 'state_RST', 'state_TST', 'state_TXD', 'state_URH', 'state_URN', 'state_no']
Number of Features: 54


In [12]:
def create_graph(df):

    G_nx = nx.from_pandas_edgelist(df, SOURCE_IP_COL_NAME, DESTINATION_IP_COL_NAME, ['h', label_col], create_using=nx.MultiDiGraph())
    G_pyg = from_networkx(G_nx)

    num_nodes = G_pyg.num_nodes
    num_edges = G_pyg.num_edges

    G_pyg.x = th.ones(num_nodes, len(df['h'].iloc[0])) 

    edge_attr_list = []
    edge_label_list = []

    for u, v, key, data in G_nx.edges(keys=True, data=True):
        edge_attr_list.append(data['h']) 
        edge_label_list.append(data[label_col]) 

    G_pyg.edge_attr = th.tensor(edge_attr_list, dtype=th.float32)
    G_pyg.edge_label = th.tensor(edge_label_list, dtype=th.long)

    return G_pyg

In [13]:
from collections import defaultdict
import heapq

import tqdm

class Downsampler:
    def __init__(self, downsample_classes=[BENIGN_CLASS_LABEL], downsample_ratios=[0.1]):
        """
        downsample_classes: list of class names to downsample
        downsample_ratio: keep no more than this ratio for each class
        """
        assert len(downsample_classes) == len(downsample_ratios)
        self.downsample_classes = downsample_classes
        self.downsample_ratio = downsample_ratios

    def downsample(self, label_counts_list, X, y):
        total_counts = defaultdict(int)

        class_heaps = defaultdict(list)
        for i, lc in enumerate(label_counts_list):
            for cls in self.downsample_classes:
                class_label_count = lc.get(cls, 0)
                total_counts[cls] += class_label_count
                heapq.heappush(class_heaps[cls], (-class_label_count, i))

        class_target = {
            cls: total_counts[cls] * self.downsample_ratio[i] for i, cls in enumerate(self.downsample_classes)
        }
                
        indices_to_remove = set()
        class_counts = total_counts

        # 3. For each class, remove top contributing samples until threshold reached
        for cls in self.downsample_classes:
            target = class_target[cls]
            heap = class_heaps[cls]

            pbar = tqdm(desc=f"Downsampling '{cls}'", total=len(heap))
            while class_counts[cls] > target and heap:
                _, idx = heapq.heappop(heap)
                if idx in indices_to_remove:
                    continue
                # For each class in this sample, if it's a downsample class, decrement the count
                for sample_cls, count in label_counts_list[idx].items():
                    if sample_cls in self.downsample_classes:
                        class_counts[sample_cls] -= count
                indices_to_remove.add(idx)
                pbar.update(1)
                pbar.set_postfix(class_label=cls, remaining=class_counts[cls], target=target)
            pbar.close()

        # 4. Apply filter
        keep_mask = [i for i in range(len(X)) if i not in indices_to_remove]
        X_new = [X[i] for i in keep_mask]
        y_new = [y[i] for i in keep_mask]

        return X_new, y_new

downsampler = Downsampler()

In [14]:
from collections import defaultdict
from typing import Counter
from sklearn.preprocessing import MultiLabelBinarizer

from tqdm import tqdm
class RandomSplitGraphDataset:

    def __init__(self, X, y):
        self.X = X
        self.y = y
        self.total_count = len(self.y)

        # Compute class weights
        labels = []

        for graph in self.X:
            labels.append(graph.edge_label.tolist())

        labels = np.concatenate(labels)

        self.class_counts = Counter(labels)

        # Compute the class weights
        self.class_weights = class_weight.compute_class_weight(
            class_weight='balanced',
            classes=np.unique(labels),
            y=labels
        )

    def graph_train_test_split(self, test_ratio: float = 0.15, random_state: int = 42):
        train_idx, test_idx = train_test_split(
            np.arange(len(self.X)), 
            test_size=test_ratio, 
            random_state=random_state,
        )

        X_train = [self.X[i] for i in train_idx]
        X_test = [self.X[i] for i in test_idx]

        y_train = [self.y[i] for i in train_idx]
        y_test = [self.y[i] for i in test_idx]

        return RandomSplitGraphDataset(X_train, y_train), RandomSplitGraphDataset(X_test, y_test)


class StratifiedGraphDataset:

    def __init__(self, X, y):
        self.X = X
        self.y = y
        self.total_count = len(self.y)

        # Compute class weights
        labels = []

        for graph in self.X:
            labels.append(graph.edge_label.tolist())

        labels = np.concatenate(labels)

        self.class_counts = Counter(labels)

        # Compute the class weights
        self.class_weights = class_weight.compute_class_weight(
            class_weight='balanced',
            classes=np.unique(labels),
            y=labels
        )

    def k_fold_split(self, k: int = 5, test_ratio: float = 0.15, random_state: int = 42):
        cv = MultilabelStratifiedShuffleSplit(test_size=test_ratio, random_state=random_state, n_splits=k)

        mlb = MultiLabelBinarizer()

        y_binary = mlb.fit_transform(self.y)

        return cv.split(np.zeros(len(self.X)), y_binary)

    def graph_train_test_split(self, test_ratio: float = 0.15, random_state: int = 42):
        train_idx, test_idx = next(self.k_fold_split(k = 1, test_ratio = test_ratio, random_state = random_state))
        
        X_train = [self.X[i] for i in train_idx]
        X_test = [self.X[i] for i in test_idx]

        y_train = [self.y[i] for i in train_idx]
        y_test = [self.y[i] for i in test_idx]

        return StratifiedGraphDataset(X_train, y_train), StratifiedGraphDataset(X_test, y_test)
    
    def print_class_distribution_and_weights(self):
        # Use the label encoder to inverse transform the class labels
        class_counts_named = {cls: count for cls, count in self.class_counts.items()}
        class_weights_named = {cls: weight for cls, weight in enumerate(self.class_weights)}
        print("Class Counts and Weights:")
        for cls_label in class_counts_named.keys():
            count = class_counts_named[cls_label]
            weight = class_weights_named[cls_label]
            print(f"{cls_label:<2}  {le.inverse_transform([cls_label])[0]:<15}: Count = {count:<10}, Weight = {weight:<10.4f}")
    
    def __len__(self):
        return self.total_count

    def __iter__(self):
        for g in self.X:
            yield g
            
    def __getitem__(self, idx):
        if isinstance(idx, int):
            return self.X[idx], self.y[idx]
        elif isinstance(idx, slice):
            return [self.X[i] for i in range(len(self.X))][idx], [self.y[i] for i in range(len(self.y))][idx]
        else:
            raise TypeError("Index must be an integer or a slice.")

def generate_graph_datasets(
    df: pd.DataFrame, 
    window_size: int = 200, 
    # overlap_ratio: float = 0, 
    feature_cols=feature_cols,
    ordering_cols=[SOURCE_FILE_ID_COL_NAME] + TIME_COLS + [ATTACK_CLASS_COL_NAME], 
    label_col=label_col,
    build_graph_func=create_graph,
    downsampler=downsampler
    ):

    print("All Columns: ", df.columns)
    print("Ordering Columns: ", ordering_cols)
    assert all(col in df.columns for col in ordering_cols), "All timestamp columns are required"
    assert label_col in df.columns, "Edge label column 'label' is required"
    
    df = df.sort_values(ordering_cols).reset_index(drop=True)
    window_size = int(window_size)
    
    df.drop(columns=set(df.columns) - set(feature_cols) - set(label_col))

    print("Final Columns: ", df.columns)
    
    label_counts_list = []
    X = []
    y = []

    progress_bar = tqdm(range(0, len(df), window_size), desc=f"Generating graphs")
    for start in progress_bar:
        window_df = df[start: min(start + window_size, len(df))]
        contains_label = window_df[label_col].unique()

        G_pyg = build_graph_func(window_df)

        label_counts = window_df[label_col].value_counts()

        label_counts_list.append(label_counts)
        X.append(G_pyg)
        y.append(contains_label.tolist())

    X, y = downsampler.downsample(label_counts_list, X, y)
        
    return X, y

In [None]:
X, y = generate_graph_datasets(data)

All Columns:  Index(['srcip', 'dstip', 'dur', 'sbytes', 'dbytes', 'sttl', 'dttl', 'sloss',
       'dloss', 'Sload', 'Dload', 'Spkts', 'Dpkts', 'swin', 'dwin', 'stcpb',
       'dtcpb', 'smeansz', 'dmeansz', 'trans_depth', 'res_bdy_len', 'Sjit',
       'Djit', 'Stime', 'Ltime', 'Sintpkt', 'Dintpkt', 'tcprtt', 'synack',
       'ackdat', 'is_sm_ips_ports', 'ct_state_ttl', 'ct_flw_http_mthd',
       'is_ftp_login', 'ct_ftp_cmd', 'ct_srv_src', 'ct_srv_dst', 'ct_dst_ltm',
       'ct_src_ltm', 'ct_src_dport_ltm', 'ct_dst_sport_ltm', 'ct_dst_src_ltm',
       'attack_cat', 'source_file_id', 'state_ACC', 'state_CLO', 'state_CON',
       'state_ECO', 'state_ECR', 'state_FIN', 'state_INT', 'state_MAS',
       'state_PAR', 'state_REQ', 'state_RST', 'state_TST', 'state_TXD',
       'state_URH', 'state_URN', 'state_no', 'h'],
      dtype='object')
Ordering Columns:  ['source_file_id', 'Stime', 'Ltime', 'attack_cat']
Final Columns:  Index(['srcip', 'dstip', 'dur', 'sbytes', 'dbytes', 'sttl', 'dttl', 's

Generating graphs: 100%|██████████| 12701/12701 [02:01<00:00, 104.84it/s]
Downsampling '6':  85%|████████▌ | 10853/12701 [00:07<00:01, 1498.50it/s, class_label=6, remaining=221852, target=2.22e+5]


Stratified Train Dataset Size: Counter({np.int64(6): 188281, np.int64(5): 89484, np.int64(3): 15147, np.int64(2): 10218, np.int64(4): 4292, np.int64(7): 2911, np.int64(0): 1752, np.int64(1): 1531, np.int64(8): 196, np.int64(9): 35})
Stratified Train Dataset Size: Counter({np.int64(6): 188277, np.int64(5): 89392, np.int64(3): 15334, np.int64(2): 10014, np.int64(4): 4362, np.int64(7): 2897, np.int64(0): 1778, np.int64(1): 1561, np.int64(8): 196, np.int64(9): 36})


In [None]:

stratified_graph_daataset = StratifiedGraphDataset(X, y)
random_split_graph_dataset = RandomSplitGraphDataset(X, y)


In [35]:
num_tests = 10
for i in range(num_tests):

    total_stratified_js = 0
    total_random_js = 0

    stratified_train_graph_dataset, stratified_test_graph_dataset = stratified_graph_daataset.graph_train_test_split(test_ratio=0.15, random_state=i)
    random_train_graph_dataset, random_test_graph_dataset = random_split_graph_dataset.graph_train_test_split(test_ratio=0.15, random_state=i)

    # Compute the Jensen-Shannon Divergence of train and test datasets
    def compute_js_divergence(counter_p, counter_q):
        """
        Compute the Jensen-Shannon Divergence between two class count dictionaries (Counter).
        """
        # Get the union of all class labels
        all_keys = sorted(set(counter_p.keys()).union(set(counter_q.keys())))
        # Convert to arrays in the same order
        p = np.array([counter_p.get(k, 0) for k in all_keys], dtype=np.float64)
        q = np.array([counter_q.get(k, 0) for k in all_keys], dtype=np.float64)

        # Normalize the distributions
        p = p / np.sum(p) if np.sum(p) > 0 else p
        q = q / np.sum(q) if np.sum(q) > 0 else q

        # Compute the average distribution
        m = 0.5 * (p + q)

        # Compute the Kullback-Leibler divergence
        kl_p_m = np.sum(np.where(p != 0, p * np.log(p / m), 0))
        kl_q_m = np.sum(np.where(q != 0, q * np.log(q / m), 0))

        # Return the Jensen-Shannon Divergence
        jsd = 0.5 * (kl_p_m + kl_q_m)
        return jsd


    def KL_divergence(counter_p, counter_q):
        """
        Compute the Kullback-Leibler Divergence between two distributions.
        """
        # Get the union of all class labels
        all_keys = sorted(set(counter_p.keys()).union(set(counter_q.keys())))

        p = np.array([counter_p.get(k, 0) for k in all_keys], dtype=np.float64)
        q = np.array([counter_q.get(k, 0) for k in all_keys], dtype=np.float64)

        # Normalize the distributions
        p = p / np.sum(p) if np.sum(p) > 0 else p
        q = q / np.sum(q) if np.sum(q) > 0 else q

        print("KL-Divergence: ", np.sum(np.where(p != 0, p * np.log(p / q), 0))) 

    total_stratified_js += compute_js_divergence(
        stratified_train_graph_dataset.class_counts,
        stratified_test_graph_dataset.class_counts
    )

    total_random_js += compute_js_divergence(
        random_train_graph_dataset.class_counts,
        random_test_graph_dataset.class_counts
    )


print(f"Stratified Split Jensen-Shannon Divergence: {total_stratified_js / num_tests}")
print(f"Random Split Jensen-Shannon Divergence: {total_random_js / num_tests}")

Stratified Split Jensen-Shannon Divergence: 1.546106727616257e-06
Random Split Jensen-Shannon Divergence: 3.0448183797230593e-05
