In [18]:
'''
=====Experiment=====
Dataset: BoT-IoT dataset

Check Train Test Dataset Coverage for
- 1. Random Split
- 2. Stratified Split
'''

from torch_geometric.utils import from_networkx, add_self_loops, degree
from torch_geometric.nn import MessagePassing
import torch_geometric.transforms as T
from torch_geometric.data import DataLoader
from torch_geometric.loader import NeighborSampler
import torch.nn as nn
import torch as th
import torch.nn.functional as F
# import dgl.function as fn
import networkx as nx
import pandas as pd
import socket
import struct
import matplotlib.pyplot as plt
import random
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.utils import class_weight
# import seaborn as sns
# import matplotlib.pyplot as plt
import numpy as np
import os
import sys
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
from torch_geometric.loader import DataLoader


project_root = os.path.abspath(os.path.join(os.getcwd(), '..', '..', '..'))
sys.path.append(project_root)

from Datasets.BoT_IoT.BoT_IoT_config import BoT_IoT_Config

In [19]:
csv_file_name = "all_raw"

data = pd.read_csv(os.path.join(project_root, "Datasets", f"BoT_IoT/All/{csv_file_name}.csv"))

DATASET_NAME = "BoT_IoT"
EXPERIMENT_NAME = "check_train_test_split"
WINDOW_SIZE = 2000

SOURCE_IP_COL_NAME = BoT_IoT_Config.SOURCE_IP_COL_NAME
DESTINATION_IP_COL_NAME = BoT_IoT_Config.DESTINATION_IP_COL_NAME
SOURCE_PORT_COL_NAME = BoT_IoT_Config.SOURCE_PORT_COL_NAME
DESTINATION_PORT_COL_NAME = BoT_IoT_Config.DESTINATION_PORT_COL_NAME

ATTACK_CLASS_COL_NAME = BoT_IoT_Config.ATTACK_CLASS_COL_NAME

BENIGN_CLASS_NAME = BoT_IoT_Config.BENIGN_CLASS_NAME

TIME_COLS = BoT_IoT_Config.TIME_COL_NAMES

print(data[ATTACK_CLASS_COL_NAME].value_counts())

MULTICLASS = True

label_col = ATTACK_CLASS_COL_NAME


saves_path = os.path.join(project_root, "Models/E_GraphSAGE/logs", DATASET_NAME, EXPERIMENT_NAME)

checkpoint_path = os.path.join(saves_path, f"checkpoints_{csv_file_name}.pth")
best_model_path = os.path.join(saves_path, f"best_model_{csv_file_name}.pth")

os.makedirs(saves_path, exist_ok=True)

category
DDoS              1926624
DoS               1650260
Reconnaissance      91082
Normal                477
Theft                  79
Name: count, dtype: int64


In [20]:
data.drop(columns=BoT_IoT_Config.DROP_COLS,inplace=True)
print(data.columns)

Index(['pkSeqID', 'stime', 'flgs_number', 'proto_number', 'saddr', 'sport',
       'daddr', 'dport', 'pkts', 'bytes', 'state_number', 'ltime', 'dur',
       'mean', 'stddev', 'sum', 'min', 'max', 'spkts', 'dpkts', 'sbytes',
       'dbytes', 'rate', 'srate', 'drate', 'TnBPSrcIP', 'TnBPDstIP',
       'TnP_PSrcIP', 'TnP_PDstIP', 'TnP_PerProto', 'TnP_Per_Dport',
       'AR_P_Proto_P_SrcIP', 'AR_P_Proto_P_DstIP', 'N_IN_Conn_P_DstIP',
       'N_IN_Conn_P_SrcIP', 'AR_P_Proto_P_Sport', 'AR_P_Proto_P_Dport',
       'Pkts_P_State_P_Protocol_P_DestIP', 'Pkts_P_State_P_Protocol_P_SrcIP',
       'attack', 'category'],
      dtype='object')


In [21]:
data[SOURCE_IP_COL_NAME] = data[SOURCE_IP_COL_NAME].apply(str)
data[DESTINATION_IP_COL_NAME] = data[DESTINATION_IP_COL_NAME].apply(str)

# # Combine Port and IP
data[SOURCE_PORT_COL_NAME] = data[SOURCE_PORT_COL_NAME].apply(str)
data[DESTINATION_PORT_COL_NAME] = data[DESTINATION_PORT_COL_NAME].apply(str)

data[SOURCE_IP_COL_NAME] = data[SOURCE_IP_COL_NAME] + ':' + data[SOURCE_PORT_COL_NAME]
data[DESTINATION_IP_COL_NAME] = data[DESTINATION_IP_COL_NAME] + ':' + data[DESTINATION_PORT_COL_NAME]
data.drop(columns=[SOURCE_PORT_COL_NAME,DESTINATION_PORT_COL_NAME],inplace=True)

data = pd.get_dummies(data, columns = BoT_IoT_Config.CATEGORICAL_COLS) # One Hot Encoding for categorical data
converted_categorical_cols = [col for col in data.columns if col.startswith(tuple(BoT_IoT_Config.CATEGORICAL_COLS))]

In [22]:
print(data.head)

<bound method NDFrame.head of          pkSeqID         stime                  saddr                 daddr  \
0        3576925  1.526344e+09       192.168.100.3:80   192.168.100.55:8080   
1        3576926  1.526344e+09    192.168.100.46:3456      192.168.100.5:80   
2        3576919  1.526344e+09      192.168.100.46:80      192.168.100.5:80   
3        3576920  1.526344e+09      192.168.100.46:80      192.168.100.5:80   
4        3576922  1.526344e+09      192.168.100.7:365     192.168.100.3:565   
...          ...           ...                    ...                   ...   
3668517  3668517  1.529381e+09  192.168.100.150:35062      192.168.100.3:22   
3668518  3668518  1.529381e+09  192.168.100.150:35064      192.168.100.3:22   
3668519  3668519  1.529381e+09  192.168.100.150:35066      192.168.100.3:22   
3668520  3668520  1.529381e+09  192.168.100.150:35070      192.168.100.3:22   
3668521  3668521  1.529381e+09    192.168.100.3:43001  192.168.100.150:4433   

          pkts     by

In [23]:
data = data.reset_index()
data.replace([np.inf, -np.inf], np.nan,inplace = True)
data.fillna(0,inplace = True)
data.drop(columns=['index'],inplace=True)
print(data.head)

<bound method NDFrame.head of          pkSeqID         stime                  saddr                 daddr  \
0        3576925  1.526344e+09       192.168.100.3:80   192.168.100.55:8080   
1        3576926  1.526344e+09    192.168.100.46:3456      192.168.100.5:80   
2        3576919  1.526344e+09      192.168.100.46:80      192.168.100.5:80   
3        3576920  1.526344e+09      192.168.100.46:80      192.168.100.5:80   
4        3576922  1.526344e+09      192.168.100.7:365     192.168.100.3:565   
...          ...           ...                    ...                   ...   
3668517  3668517  1.529381e+09  192.168.100.150:35062      192.168.100.3:22   
3668518  3668518  1.529381e+09  192.168.100.150:35064      192.168.100.3:22   
3668519  3668519  1.529381e+09  192.168.100.150:35066      192.168.100.3:22   
3668520  3668520  1.529381e+09  192.168.100.150:35070      192.168.100.3:22   
3668521  3668521  1.529381e+09    192.168.100.3:43001  192.168.100.150:4433   

          pkts     by

In [24]:
scaler = StandardScaler()
cols_to_norm = BoT_IoT_Config.COLS_TO_NORM
print(data[cols_to_norm].describe()) # Check if there's any too large value

               pkts         bytes           dur          mean        stddev  \
count  3.668522e+06  3.668522e+06  3.668522e+06  3.668522e+06  3.668522e+06   
mean   7.725963e+00  8.690501e+02  2.033479e+01  2.231063e+00  8.871499e-01   
std    1.155876e+02  1.122667e+05  2.148764e+01  1.517728e+00  8.037139e-01   
min    1.000000e+00  6.000000e+01  0.000000e+00  0.000000e+00  0.000000e+00   
25%    5.000000e+00  4.200000e+02  1.256256e+01  1.819670e-01  3.001900e-02   
50%    7.000000e+00  6.000000e+02  1.550852e+01  2.690125e+00  7.938960e-01   
75%    9.000000e+00  7.700000e+02  2.709986e+01  3.565203e+00  1.745296e+00   
max    7.005700e+04  7.183334e+07  2.771485e+03  4.981882e+00  2.496763e+00   

                sum           min           max         spkts         dpkts  \
count  3.668522e+06  3.668522e+06  3.668522e+06  3.668522e+06  3.668522e+06   
mean   7.721635e+00  1.017540e+00  3.020015e+00  7.314146e+00  4.118173e-01   
std    7.616199e+00  1.483688e+00  1.860877e+00  7.

In [25]:
def check_numeric_issues(df, cols_to_norm):
    for col in cols_to_norm:
        try:
            # Try to coerce to numeric
            df[col] = pd.to_numeric(df[col], errors='coerce')
            
            # Try to clip the column
            df[col] = df[col].clip(lower=-1e9, upper=1e9)
            
        except Exception as e:
            print(f"❌ Column '{col}' failed with error: {e}")
            print(f"  - Sample values: {df[col].dropna().unique()[:5]}")
            print(f"  - Data type: {df[col].dtype}")
            continue

    print("\n✅ All other columns processed successfully.")

check_numeric_issues(data, BoT_IoT_Config.COLS_TO_NORM)


✅ All other columns processed successfully.


In [26]:
data[cols_to_norm] = scaler.fit_transform(data[cols_to_norm])

In [27]:
from sklearn.preprocessing import LabelEncoder

num_classes = 2
class_map = [0, 1]
if MULTICLASS:
    le = LabelEncoder()
    attack_labels = le.fit_transform(data[ATTACK_CLASS_COL_NAME])
    class_map = le.classes_
    print(class_map)
    print("Attack label mapping:", dict(zip(class_map, range(len(class_map)))))
    data[ATTACK_CLASS_COL_NAME] = attack_labels
    num_classes = len(class_map)
    class_dict = {le.inverse_transform([i])[0]: i for i in range(len(le.classes_))}

BENIGN_CLASS_LABEL = le.transform([BENIGN_CLASS_NAME])[0] if MULTICLASS else 0
ADVERSARIAL_CLASS_LABEL = len(class_map)

['DDoS' 'DoS' 'Normal' 'Reconnaissance' 'Theft']
Attack label mapping: {'DDoS': 0, 'DoS': 1, 'Normal': 2, 'Reconnaissance': 3, 'Theft': 4}


In [28]:
# # Maintain the order of the rows in the original dataframe

feature_cols = cols_to_norm + converted_categorical_cols

print('Feature Columns:', feature_cols)
num_features = len(feature_cols)
print('Number of Features:', num_features)

data['h'] = data[ feature_cols ].values.tolist()

Feature Columns: ['pkts', 'bytes', 'dur', 'mean', 'stddev', 'sum', 'min', 'max', 'spkts', 'dpkts', 'sbytes', 'dbytes', 'rate', 'srate', 'drate', 'TnBPSrcIP', 'TnBPDstIP', 'TnP_PSrcIP', 'TnP_PDstIP', 'TnP_PerProto', 'TnP_Per_Dport', 'AR_P_Proto_P_SrcIP', 'AR_P_Proto_P_DstIP', 'N_IN_Conn_P_DstIP', 'N_IN_Conn_P_SrcIP', 'AR_P_Proto_P_Sport', 'AR_P_Proto_P_Dport', 'Pkts_P_State_P_Protocol_P_DestIP', 'Pkts_P_State_P_Protocol_P_SrcIP', 'flgs_number_1', 'flgs_number_2', 'flgs_number_3', 'flgs_number_4', 'flgs_number_5', 'flgs_number_6', 'flgs_number_7', 'flgs_number_8', 'flgs_number_9', 'state_number_1', 'state_number_2', 'state_number_3', 'state_number_4', 'state_number_5', 'state_number_6', 'state_number_7', 'state_number_8', 'state_number_9', 'state_number_10', 'state_number_11', 'proto_number_1', 'proto_number_2', 'proto_number_3', 'proto_number_4', 'proto_number_5']
Number of Features: 54


In [29]:
def create_graph(df):

    G_nx = nx.from_pandas_edgelist(df, SOURCE_IP_COL_NAME, DESTINATION_IP_COL_NAME, ['h', label_col], create_using=nx.MultiDiGraph())
    G_pyg = from_networkx(G_nx)

    num_nodes = G_pyg.num_nodes
    num_edges = G_pyg.num_edges

    G_pyg.x = th.ones(num_nodes, len(df['h'].iloc[0])) 

    edge_attr_list = []
    edge_label_list = []

    for u, v, key, data in G_nx.edges(keys=True, data=True):
        edge_attr_list.append(data['h']) 
        edge_label_list.append(data[label_col]) 

    G_pyg.edge_attr = th.tensor(edge_attr_list, dtype=th.float32)
    G_pyg.edge_label = th.tensor(edge_label_list, dtype=th.long)

    return G_pyg

In [30]:
from collections import defaultdict
import heapq

import tqdm

class Downsampler:
    def __init__(self, downsample_classes=[BENIGN_CLASS_LABEL], downsample_ratios=[0.1]):
        """
        downsample_classes: list of class names to downsample
        downsample_ratio: keep no more than this ratio for each class
        """
        assert len(downsample_classes) == len(downsample_ratios)
        self.downsample_classes = downsample_classes
        self.downsample_ratio = downsample_ratios

    def downsample(self, label_counts_list, X, y):
        total_counts = defaultdict(int)

        class_heaps = defaultdict(list)
        for i, lc in enumerate(label_counts_list):
            for cls in self.downsample_classes:
                class_label_count = lc.get(cls, 0)
                total_counts[cls] += class_label_count
                heapq.heappush(class_heaps[cls], (-class_label_count, i))

        class_target = {
            cls: total_counts[cls] * self.downsample_ratio[i] for i, cls in enumerate(self.downsample_classes)
        }
                
        indices_to_remove = set()
        class_counts = total_counts

        # 3. For each class, remove top contributing samples until threshold reached
        for cls in self.downsample_classes:
            target = class_target[cls]
            heap = class_heaps[cls]

            pbar = tqdm(desc=f"Downsampling '{cls}'", total=len(heap))
            while class_counts[cls] > target and heap:
                _, idx = heapq.heappop(heap)
                if idx in indices_to_remove:
                    continue
                # For each class in this sample, if it's a downsample class, decrement the count
                for sample_cls, count in label_counts_list[idx].items():
                    if sample_cls in self.downsample_classes:
                        class_counts[sample_cls] -= count
                indices_to_remove.add(idx)
                pbar.update(1)
                pbar.set_postfix(class_label=cls, remaining=class_counts[cls], target=target)
            pbar.close()

        # 4. Apply filter
        keep_mask = [i for i in range(len(X)) if i not in indices_to_remove]
        X_new = [X[i] for i in keep_mask]
        y_new = [y[i] for i in keep_mask]

        return X_new, y_new

downsampler = Downsampler()

In [31]:
from collections import defaultdict
from typing import Counter
from sklearn.preprocessing import MultiLabelBinarizer

from tqdm import tqdm
class RandomSplitGraphDataset:

    def __init__(self, X, y):
        self.X = X
        self.y = y
        self.total_count = len(self.y)

        # Compute class weights
        labels = []

        for graph in self.X:
            labels.append(graph.edge_label.tolist())

        labels = np.concatenate(labels)

        self.class_counts = Counter(labels)

        # Compute the class weights
        self.class_weights = class_weight.compute_class_weight(
            class_weight='balanced',
            classes=np.unique(labels),
            y=labels
        )

    def graph_train_test_split(self, test_ratio: float = 0.15, random_state: int = 42):
        train_idx, test_idx = train_test_split(
            np.arange(len(self.X)), 
            test_size=test_ratio, 
            random_state=random_state,
        )

        X_train = [self.X[i] for i in train_idx]
        X_test = [self.X[i] for i in test_idx]

        y_train = [self.y[i] for i in train_idx]
        y_test = [self.y[i] for i in test_idx]

        return RandomSplitGraphDataset(X_train, y_train), RandomSplitGraphDataset(X_test, y_test)


class StratifiedGraphDataset:

    def __init__(self, X, y):
        self.X = X
        self.y = y
        self.total_count = len(self.y)

        # Compute class weights
        labels = []

        for graph in self.X:
            labels.append(graph.edge_label.tolist())

        labels = np.concatenate(labels)

        self.class_counts = Counter(labels)

        # Compute the class weights
        self.class_weights = class_weight.compute_class_weight(
            class_weight='balanced',
            classes=np.unique(labels),
            y=labels
        )

    def k_fold_split(self, k: int = 5, test_ratio: float = 0.15, random_state: int = 42):
        cv = MultilabelStratifiedShuffleSplit(test_size=test_ratio, random_state=random_state, n_splits=k)

        mlb = MultiLabelBinarizer()

        y_binary = mlb.fit_transform(self.y)

        return cv.split(np.zeros(len(self.X)), y_binary)

    def graph_train_test_split(self, test_ratio: float = 0.15, random_state: int = 42):
        train_idx, test_idx = next(self.k_fold_split(k = 1, test_ratio = test_ratio, random_state = random_state))
        
        X_train = [self.X[i] for i in train_idx]
        X_test = [self.X[i] for i in test_idx]

        y_train = [self.y[i] for i in train_idx]
        y_test = [self.y[i] for i in test_idx]

        return StratifiedGraphDataset(X_train, y_train), StratifiedGraphDataset(X_test, y_test)
    
    def print_class_distribution_and_weights(self):
        # Use the label encoder to inverse transform the class labels
        class_counts_named = {cls: count for cls, count in self.class_counts.items()}
        class_weights_named = {cls: weight for cls, weight in enumerate(self.class_weights)}
        print("Class Counts and Weights:")
        for cls_label in class_counts_named.keys():
            count = class_counts_named[cls_label]
            weight = class_weights_named[cls_label]
            print(f"{cls_label:<2}  {le.inverse_transform([cls_label])[0]:<15}: Count = {count:<10}, Weight = {weight:<10.4f}")
    
    def __len__(self):
        return self.total_count

    def __iter__(self):
        for g in self.X:
            yield g
            
    def __getitem__(self, idx):
        if isinstance(idx, int):
            return self.X[idx], self.y[idx]
        elif isinstance(idx, slice):
            return [self.X[i] for i in range(len(self.X))][idx], [self.y[i] for i in range(len(self.y))][idx]
        else:
            raise TypeError("Index must be an integer or a slice.")

def generate_graph_datasets(
    df: pd.DataFrame, 
    window_size: int = WINDOW_SIZE, 
    # overlap_ratio: float = 0, 
    feature_cols=feature_cols,
    ordering_cols=TIME_COLS, 
    label_col=label_col,
    build_graph_func=create_graph,
    ):

    print("All Columns: ", df.columns)
    print("Ordering Columns: ", ordering_cols)
    assert all(col in df.columns for col in ordering_cols), "All timestamp columns are required"
    assert label_col in df.columns, "Edge label column 'label' is required"
    
    df = df.sort_values(ordering_cols).reset_index(drop=True)
    window_size = int(window_size)
    
    df.drop(columns=set(df.columns) - set(feature_cols) - set(label_col))

    print("Final Columns: ", df.columns)
    
    label_counts_list = []
    X = []
    y = []

    progress_bar = tqdm(range(0, len(df), window_size), desc=f"Generating graphs")
    for start in progress_bar:
        window_df = df[start: min(start + window_size, len(df))]
        contains_label = window_df[label_col].unique()

        G_pyg = build_graph_func(window_df)

        label_counts = window_df[label_col].value_counts()

        label_counts_list.append(label_counts)
        X.append(G_pyg)
        y.append(contains_label.tolist())

    return X, y

In [32]:
X, y = generate_graph_datasets(data)

All Columns:  Index(['pkSeqID', 'stime', 'saddr', 'daddr', 'pkts', 'bytes', 'ltime', 'dur',
       'mean', 'stddev', 'sum', 'min', 'max', 'spkts', 'dpkts', 'sbytes',
       'dbytes', 'rate', 'srate', 'drate', 'TnBPSrcIP', 'TnBPDstIP',
       'TnP_PSrcIP', 'TnP_PDstIP', 'TnP_PerProto', 'TnP_Per_Dport',
       'AR_P_Proto_P_SrcIP', 'AR_P_Proto_P_DstIP', 'N_IN_Conn_P_DstIP',
       'N_IN_Conn_P_SrcIP', 'AR_P_Proto_P_Sport', 'AR_P_Proto_P_Dport',
       'Pkts_P_State_P_Protocol_P_DestIP', 'Pkts_P_State_P_Protocol_P_SrcIP',
       'attack', 'category', 'flgs_number_1', 'flgs_number_2', 'flgs_number_3',
       'flgs_number_4', 'flgs_number_5', 'flgs_number_6', 'flgs_number_7',
       'flgs_number_8', 'flgs_number_9', 'state_number_1', 'state_number_2',
       'state_number_3', 'state_number_4', 'state_number_5', 'state_number_6',
       'state_number_7', 'state_number_8', 'state_number_9', 'state_number_10',
       'state_number_11', 'proto_number_1', 'proto_number_2', 'proto_number_3',
    

Final Columns:  Index(['pkSeqID', 'stime', 'saddr', 'daddr', 'pkts', 'bytes', 'ltime', 'dur',
       'mean', 'stddev', 'sum', 'min', 'max', 'spkts', 'dpkts', 'sbytes',
       'dbytes', 'rate', 'srate', 'drate', 'TnBPSrcIP', 'TnBPDstIP',
       'TnP_PSrcIP', 'TnP_PDstIP', 'TnP_PerProto', 'TnP_Per_Dport',
       'AR_P_Proto_P_SrcIP', 'AR_P_Proto_P_DstIP', 'N_IN_Conn_P_DstIP',
       'N_IN_Conn_P_SrcIP', 'AR_P_Proto_P_Sport', 'AR_P_Proto_P_Dport',
       'Pkts_P_State_P_Protocol_P_DestIP', 'Pkts_P_State_P_Protocol_P_SrcIP',
       'attack', 'category', 'flgs_number_1', 'flgs_number_2', 'flgs_number_3',
       'flgs_number_4', 'flgs_number_5', 'flgs_number_6', 'flgs_number_7',
       'flgs_number_8', 'flgs_number_9', 'state_number_1', 'state_number_2',
       'state_number_3', 'state_number_4', 'state_number_5', 'state_number_6',
       'state_number_7', 'state_number_8', 'state_number_9', 'state_number_10',
       'state_number_11', 'proto_number_1', 'proto_number_2', 'proto_number_3',
  

Generating graphs: 100%|██████████| 1835/1835 [03:03<00:00, 10.00it/s]


In [33]:

stratified_graph_daataset = StratifiedGraphDataset(X, y)
random_split_graph_dataset = RandomSplitGraphDataset(X, y)


In [34]:
num_tests = 100

# Compute the Jensen-Shannon Divergence of train and test datasets
def compute_js_divergence(counter_p, counter_q, epsilon=1e-10):
    """
    Compute the Jensen-Shannon Divergence and Kullback-Leibler Divergence between two class count dictionaries (Counter).
    Returns (jsd, kld)
    """
    # Get the union of all class labels
    all_keys = sorted(set(counter_p.keys()).union(set(counter_q.keys())))
    # Convert to arrays in the same order
    p = np.array([counter_p.get(k, 0) for k in all_keys], dtype=np.float64)
    q = np.array([counter_q.get(k, 0) for k in all_keys], dtype=np.float64)

    p = p + epsilon
    q = q + epsilon

    # Normalize the distributions
    p = p / np.sum(p) if np.sum(p) > 0 else p
    q = q / np.sum(q) if np.sum(q) > 0 else q

    # Compute the average distribution
    m = 0.5 * (p + q)

    # Compute the Kullback-Leibler divergence
    kl_p_m = np.sum(np.where(p != 0, p * np.log(p / m), 0))
    kl_q_m = np.sum(np.where(q != 0, q * np.log(q / m), 0))
    kl_p_q = np.sum(np.where(p != 0, p * np.log(p / q), 0))

    # Jensen-Shannon Divergence
    jsd = 0.5 * (kl_p_m + kl_q_m)
    return jsd

total_stratified_js = 0
total_random_js = 0

for i in range(num_tests):
    
    stratified_train_graph_dataset, stratified_test_graph_dataset = stratified_graph_daataset.graph_train_test_split(test_ratio=0.15, random_state=i)
    random_train_graph_dataset, random_test_graph_dataset = random_split_graph_dataset.graph_train_test_split(test_ratio=0.15, random_state=i)

    

    total_stratified_js += compute_js_divergence(
        stratified_train_graph_dataset.class_counts,
        stratified_test_graph_dataset.class_counts
    )

    total_random_js += compute_js_divergence(
        random_train_graph_dataset.class_counts,
        random_test_graph_dataset.class_counts
    )
    

print(f"Stratified Split Jensen-Shannon Divergence: {total_stratified_js / num_tests}")
print(f"Random Split Jensen-Shannon Divergence: {total_random_js / num_tests}")

Stratified Split Jensen-Shannon Divergence: 2.077331673164704e-05
Random Split Jensen-Shannon Divergence: 0.0012107016013186479
