In [None]:
'''
=====Experiment=====
Dataset: BoT-IoT dataset

Check Train Test Dataset Coverage for
- 1. Random Split
- 2. Stratified Split
'''

from torch_geometric.utils import from_networkx, add_self_loops, degree
from torch_geometric.nn import MessagePassing
import torch_geometric.transforms as T
from torch_geometric.data import DataLoader
from torch_geometric.loader import NeighborSampler
import torch.nn as nn
import torch as th
import torch.nn.functional as F
# import dgl.function as fn
import networkx as nx
import pandas as pd
import socket
import struct
import matplotlib.pyplot as plt
import random
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.utils import class_weight
# import seaborn as sns
# import matplotlib.pyplot as plt
import numpy as np
import os
import sys
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
from torch_geometric.loader import DataLoader


project_root = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))
sys.path.append(project_root)

from Datasets.BoT_IoT.BoT_IoT_config import BoT_IoT_Config

In [None]:
csv_file_name = "all_raw"

data = pd.read_csv(os.path.join(project_root, "Datasets", f"BoT_IoT/All/{csv_file_name}.csv"))

DATASET_NAME = "BoT_IoT"
EXPERIMENT_NAME = "check_train_test_split"

SOURCE_FILE_ID_COL_NAME = BoT_IoT_Config.SOURCE_FILE_ID_COL_NAME

SOURCE_IP_COL_NAME = BoT_IoT_Config.SOURCE_IP_COL_NAME
DESTINATION_IP_COL_NAME = BoT_IoT_Config.DESTINATION_IP_COL_NAME
SOURCE_PORT_COL_NAME = BoT_IoT_Config.SOURCE_PORT_COL_NAME
DESTINATION_PORT_COL_NAME = BoT_IoT_Config.DESTINATION_PORT_COL_NAME

ATTACK_CLASS_COL_NAME = BoT_IoT_Config.ATTACK_CLASS_COL_NAME

BENIGN_CLASS_NAME = BoT_IoT_Config.BENIGN_CLASS_NAME

TIME_COLS = BoT_IoT_Config.TIME_COL_NAMES

print(data[ATTACK_CLASS_COL_NAME].value_counts())

MULTICLASS = True

label_col = ATTACK_CLASS_COL_NAME


saves_path = os.path.join(project_root, "Models/E_GraphSAGE/logs", DATASET_NAME, EXPERIMENT_NAME)

checkpoint_path = os.path.join(saves_path, f"checkpoints_{csv_file_name}.pth")
best_model_path = os.path.join(saves_path, f"best_model_{csv_file_name}.pth")

os.makedirs(saves_path, exist_ok=True)

Label
BENIGN                        2273097
DoS Hulk                       231073
PortScan                       158930
DDoS                           128027
DoS GoldenEye                   10293
FTP-Patator                      7938
SSH-Patator                      5897
DoS slowloris                    5796
DoS Slowhttptest                 5499
Bot                              1966
Web Attack - Brute Force         1507
Web Attack - XSS                  652
Infiltration                       36
Web Attack - Sql Injection         21
Heartbleed                         11
Name: count, dtype: int64


In [None]:
data.drop(columns=BoT_IoT_Config.DROP_COLS,inplace=True)
print(data.columns)

Index(['Flow ID', 'Source IP', 'Source Port', 'Destination IP',
       'Destination Port', 'Protocol', 'Timestamp', 'Flow Duration',
       'Total Length of Fwd Packets', 'Fwd Packet Length Mean',
       'Fwd Packet Length Std', 'Bwd Packet Length Min',
       'Bwd Packet Length Std', 'Flow Packets/s', 'Flow IAT Mean',
       'Flow IAT Std', 'Flow IAT Min', 'Fwd IAT Min', 'Bwd IAT Mean',
       'Fwd PSH Flags', 'SYN Flag Count', 'PSH Flag Count', 'ACK Flag Count',
       'Average Packet Size', 'Fwd Header Length.1', 'Subflow Fwd Packets',
       'Subflow Fwd Bytes', 'Subflow Bwd Bytes', 'Init_Win_bytes_forward',
       'Active Mean', 'Active Min', 'Label', 'source_file_id'],
      dtype='object')


In [None]:
data[SOURCE_IP_COL_NAME] = data[SOURCE_IP_COL_NAME].apply(str)
data[DESTINATION_IP_COL_NAME] = data[DESTINATION_IP_COL_NAME].apply(str)

# # Combine Port and IP
data[SOURCE_PORT_COL_NAME] = data[SOURCE_PORT_COL_NAME].apply(str)
data[DESTINATION_PORT_COL_NAME] = data[DESTINATION_PORT_COL_NAME].apply(str)

data[SOURCE_IP_COL_NAME] = data[SOURCE_IP_COL_NAME] + ':' + data[SOURCE_PORT_COL_NAME]
data[DESTINATION_IP_COL_NAME] = data[DESTINATION_IP_COL_NAME] + ':' + data[DESTINATION_PORT_COL_NAME]
data.drop(columns=[SOURCE_PORT_COL_NAME,DESTINATION_PORT_COL_NAME],inplace=True)

data = pd.get_dummies(data, columns = BoT_IoT_Config.CATEGORICAL_COLS) # One Hot Encoding for categorical data
converted_categorical_cols = [col for col in data.columns if col.startswith(tuple(BoT_IoT_Config.CATEGORICAL_COLS))]

In [7]:
print(data.head)

<bound method NDFrame.head of                                          Flow ID              Source IP  \
0        192.168.10.5-104.16.207.165-54865-443-6   104.16.207.165_0:443   
1          192.168.10.5-104.16.28.216-55054-80-6     104.16.28.216_0:80   
2          192.168.10.5-104.16.28.216-55055-80-6     104.16.28.216_0:80   
3        192.168.10.16-104.17.241.25-46236-443-6    104.17.241.25_0:443   
4        192.168.10.5-104.19.196.102-54863-443-6   104.19.196.102_0:443   
...                                          ...                    ...   
2830738  192.168.10.16-199.244.48.55-41926-443-6  192.168.10.16_7:41926   
2830739  192.168.10.16-199.244.48.55-41934-443-6  192.168.10.16_7:41934   
2830740  192.168.10.16-199.244.48.55-41932-443-6  192.168.10.16_7:41932   
2830741  192.168.10.16-199.244.48.55-41930-443-6  192.168.10.16_7:41930   
2830742  192.168.10.16-199.244.48.55-41928-443-6  192.168.10.16_7:41928   

                Destination IP      Timestamp  Flow Duration  \
0    

In [8]:
data = data.reset_index()
data.replace([np.inf, -np.inf], np.nan,inplace = True)
data.fillna(0,inplace = True)
data.drop(columns=['index'],inplace=True)
print(data.head)

<bound method NDFrame.head of                                          Flow ID              Source IP  \
0        192.168.10.5-104.16.207.165-54865-443-6   104.16.207.165_0:443   
1          192.168.10.5-104.16.28.216-55054-80-6     104.16.28.216_0:80   
2          192.168.10.5-104.16.28.216-55055-80-6     104.16.28.216_0:80   
3        192.168.10.16-104.17.241.25-46236-443-6    104.17.241.25_0:443   
4        192.168.10.5-104.19.196.102-54863-443-6   104.19.196.102_0:443   
...                                          ...                    ...   
2830738  192.168.10.16-199.244.48.55-41926-443-6  192.168.10.16_7:41926   
2830739  192.168.10.16-199.244.48.55-41934-443-6  192.168.10.16_7:41934   
2830740  192.168.10.16-199.244.48.55-41932-443-6  192.168.10.16_7:41932   
2830741  192.168.10.16-199.244.48.55-41930-443-6  192.168.10.16_7:41930   
2830742  192.168.10.16-199.244.48.55-41928-443-6  192.168.10.16_7:41928   

                Destination IP      Timestamp  Flow Duration  \
0    

In [None]:
scaler = StandardScaler()
cols_to_norm = BoT_IoT_Config.COLS_TO_NORM
print(data[cols_to_norm].describe()) # Check if there's any too large value

       Bwd Packet Length Min  Subflow Fwd Packets  \
count           2.830743e+06         2.830743e+06   
mean            4.104958e+01         9.361160e+00   
std             6.886260e+01         7.496728e+02   
min             0.000000e+00         1.000000e+00   
25%             0.000000e+00         2.000000e+00   
50%             0.000000e+00         2.000000e+00   
75%             7.700000e+01         5.000000e+00   
max             2.896000e+03         2.197590e+05   

       Total Length of Fwd Packets  Fwd Packet Length Mean  \
count                 2.830743e+06            2.830743e+06   
mean                  5.493024e+02            5.820194e+01   
std                   9.993589e+03            1.860912e+02   
min                   0.000000e+00            0.000000e+00   
25%                   1.200000e+01            6.000000e+00   
50%                   6.200000e+01            3.400000e+01   
75%                   1.870000e+02            5.000000e+01   
max                   1.29

In [None]:
def check_numeric_issues(df, cols_to_norm):
    for col in cols_to_norm:
        try:
            # Try to coerce to numeric
            df[col] = pd.to_numeric(df[col], errors='coerce')
            
            # Try to clip the column
            df[col] = df[col].clip(lower=-1e9, upper=1e9)
            
        except Exception as e:
            print(f"❌ Column '{col}' failed with error: {e}")
            print(f"  - Sample values: {df[col].dropna().unique()[:5]}")
            print(f"  - Data type: {df[col].dtype}")
            continue

    print("\n✅ All other columns processed successfully.")

check_numeric_issues(data, BoT_IoT_Config.COLS_TO_NORM)


✅ All other columns processed successfully.


In [11]:
data[cols_to_norm] = scaler.fit_transform(data[cols_to_norm])

In [12]:
from sklearn.preprocessing import LabelEncoder

num_classes = 2
class_map = [0, 1]
if MULTICLASS:
    le = LabelEncoder()
    attack_labels = le.fit_transform(data[ATTACK_CLASS_COL_NAME])
    class_map = le.classes_
    print(class_map)
    print("Attack label mapping:", dict(zip(class_map, range(len(class_map)))))
    data[ATTACK_CLASS_COL_NAME] = attack_labels
    num_classes = len(class_map)
    class_dict = {le.inverse_transform([i])[0]: i for i in range(len(le.classes_))}

BENIGN_CLASS_LABEL = le.transform([BENIGN_CLASS_NAME])[0] if MULTICLASS else 0
ADVERSARIAL_CLASS_LABEL = len(class_map)

['BENIGN' 'Bot' 'DDoS' 'DoS GoldenEye' 'DoS Hulk' 'DoS Slowhttptest'
 'DoS slowloris' 'FTP-Patator' 'Heartbleed' 'Infiltration' 'PortScan'
 'SSH-Patator' 'Web Attack - Brute Force' 'Web Attack - Sql Injection'
 'Web Attack - XSS']
Attack label mapping: {'BENIGN': 0, 'Bot': 1, 'DDoS': 2, 'DoS GoldenEye': 3, 'DoS Hulk': 4, 'DoS Slowhttptest': 5, 'DoS slowloris': 6, 'FTP-Patator': 7, 'Heartbleed': 8, 'Infiltration': 9, 'PortScan': 10, 'SSH-Patator': 11, 'Web Attack - Brute Force': 12, 'Web Attack - Sql Injection': 13, 'Web Attack - XSS': 14}


In [13]:
# # Maintain the order of the rows in the original dataframe

feature_cols = cols_to_norm + converted_categorical_cols

print('Feature Columns:', feature_cols)
num_features = len(feature_cols)
print('Number of Features:', num_features)

data['h'] = data[ feature_cols ].values.tolist()

Feature Columns: ['Bwd Packet Length Min', 'Subflow Fwd Packets', 'Total Length of Fwd Packets', 'Fwd Packet Length Mean', 'Total Length of Fwd Packets', 'Fwd Packet Length Std', 'Fwd IAT Min', 'Flow IAT Min', 'Flow IAT Mean', 'Bwd Packet Length Std', 'Subflow Fwd Bytes', 'Flow Duration', 'Flow IAT Std', 'Active Min', 'Active Mean', 'Bwd IAT Mean', 'Subflow Bwd Bytes', 'Init_Win_bytes_forward', 'ACK Flag Count', 'Fwd PSH Flags', 'SYN Flag Count', 'Flow Packets/s', 'PSH Flag Count', 'Average Packet Size', 'Protocol_0', 'Protocol_6', 'Protocol_17']
Number of Features: 27


In [14]:
def create_graph(df):

    G_nx = nx.from_pandas_edgelist(df, SOURCE_IP_COL_NAME, DESTINATION_IP_COL_NAME, ['h', label_col], create_using=nx.MultiDiGraph())
    G_pyg = from_networkx(G_nx)

    num_nodes = G_pyg.num_nodes
    num_edges = G_pyg.num_edges

    G_pyg.x = th.ones(num_nodes, len(df['h'].iloc[0])) 

    edge_attr_list = []
    edge_label_list = []

    for u, v, key, data in G_nx.edges(keys=True, data=True):
        edge_attr_list.append(data['h']) 
        edge_label_list.append(data[label_col]) 

    G_pyg.edge_attr = th.tensor(edge_attr_list, dtype=th.float32)
    G_pyg.edge_label = th.tensor(edge_label_list, dtype=th.long)

    return G_pyg

In [15]:
from collections import defaultdict
import heapq

import tqdm

class Downsampler:
    def __init__(self, downsample_classes=[BENIGN_CLASS_LABEL], downsample_ratios=[0.1]):
        """
        downsample_classes: list of class names to downsample
        downsample_ratio: keep no more than this ratio for each class
        """
        assert len(downsample_classes) == len(downsample_ratios)
        self.downsample_classes = downsample_classes
        self.downsample_ratio = downsample_ratios

    def downsample(self, label_counts_list, X, y):
        total_counts = defaultdict(int)

        class_heaps = defaultdict(list)
        for i, lc in enumerate(label_counts_list):
            for cls in self.downsample_classes:
                class_label_count = lc.get(cls, 0)
                total_counts[cls] += class_label_count
                heapq.heappush(class_heaps[cls], (-class_label_count, i))

        class_target = {
            cls: total_counts[cls] * self.downsample_ratio[i] for i, cls in enumerate(self.downsample_classes)
        }
                
        indices_to_remove = set()
        class_counts = total_counts

        # 3. For each class, remove top contributing samples until threshold reached
        for cls in self.downsample_classes:
            target = class_target[cls]
            heap = class_heaps[cls]

            pbar = tqdm(desc=f"Downsampling '{cls}'", total=len(heap))
            while class_counts[cls] > target and heap:
                _, idx = heapq.heappop(heap)
                if idx in indices_to_remove:
                    continue
                # For each class in this sample, if it's a downsample class, decrement the count
                for sample_cls, count in label_counts_list[idx].items():
                    if sample_cls in self.downsample_classes:
                        class_counts[sample_cls] -= count
                indices_to_remove.add(idx)
                pbar.update(1)
                pbar.set_postfix(class_label=cls, remaining=class_counts[cls], target=target)
            pbar.close()

        # 4. Apply filter
        keep_mask = [i for i in range(len(X)) if i not in indices_to_remove]
        X_new = [X[i] for i in keep_mask]
        y_new = [y[i] for i in keep_mask]

        return X_new, y_new

downsampler = Downsampler()

In [16]:
from collections import defaultdict
from typing import Counter
from sklearn.preprocessing import MultiLabelBinarizer

from tqdm import tqdm
class RandomSplitGraphDataset:

    def __init__(self, X, y):
        self.X = X
        self.y = y
        self.total_count = len(self.y)

        # Compute class weights
        labels = []

        for graph in self.X:
            labels.append(graph.edge_label.tolist())

        labels = np.concatenate(labels)

        self.class_counts = Counter(labels)

        # Compute the class weights
        self.class_weights = class_weight.compute_class_weight(
            class_weight='balanced',
            classes=np.unique(labels),
            y=labels
        )

    def graph_train_test_split(self, test_ratio: float = 0.15, random_state: int = 42):
        train_idx, test_idx = train_test_split(
            np.arange(len(self.X)), 
            test_size=test_ratio, 
            random_state=random_state,
        )

        X_train = [self.X[i] for i in train_idx]
        X_test = [self.X[i] for i in test_idx]

        y_train = [self.y[i] for i in train_idx]
        y_test = [self.y[i] for i in test_idx]

        return RandomSplitGraphDataset(X_train, y_train), RandomSplitGraphDataset(X_test, y_test)


class StratifiedGraphDataset:

    def __init__(self, X, y):
        self.X = X
        self.y = y
        self.total_count = len(self.y)

        # Compute class weights
        labels = []

        for graph in self.X:
            labels.append(graph.edge_label.tolist())

        labels = np.concatenate(labels)

        self.class_counts = Counter(labels)

        # Compute the class weights
        self.class_weights = class_weight.compute_class_weight(
            class_weight='balanced',
            classes=np.unique(labels),
            y=labels
        )

    def k_fold_split(self, k: int = 5, test_ratio: float = 0.15, random_state: int = 42):
        cv = MultilabelStratifiedShuffleSplit(test_size=test_ratio, random_state=random_state, n_splits=k)

        mlb = MultiLabelBinarizer()

        y_binary = mlb.fit_transform(self.y)

        return cv.split(np.zeros(len(self.X)), y_binary)

    def graph_train_test_split(self, test_ratio: float = 0.15, random_state: int = 42):
        train_idx, test_idx = next(self.k_fold_split(k = 1, test_ratio = test_ratio, random_state = random_state))
        
        X_train = [self.X[i] for i in train_idx]
        X_test = [self.X[i] for i in test_idx]

        y_train = [self.y[i] for i in train_idx]
        y_test = [self.y[i] for i in test_idx]

        return StratifiedGraphDataset(X_train, y_train), StratifiedGraphDataset(X_test, y_test)
    
    def print_class_distribution_and_weights(self):
        # Use the label encoder to inverse transform the class labels
        class_counts_named = {cls: count for cls, count in self.class_counts.items()}
        class_weights_named = {cls: weight for cls, weight in enumerate(self.class_weights)}
        print("Class Counts and Weights:")
        for cls_label in class_counts_named.keys():
            count = class_counts_named[cls_label]
            weight = class_weights_named[cls_label]
            print(f"{cls_label:<2}  {le.inverse_transform([cls_label])[0]:<15}: Count = {count:<10}, Weight = {weight:<10.4f}")
    
    def __len__(self):
        return self.total_count

    def __iter__(self):
        for g in self.X:
            yield g
            
    def __getitem__(self, idx):
        if isinstance(idx, int):
            return self.X[idx], self.y[idx]
        elif isinstance(idx, slice):
            return [self.X[i] for i in range(len(self.X))][idx], [self.y[i] for i in range(len(self.y))][idx]
        else:
            raise TypeError("Index must be an integer or a slice.")

def generate_graph_datasets(
    df: pd.DataFrame, 
    window_size: int = 200, 
    # overlap_ratio: float = 0, 
    feature_cols=feature_cols,
    ordering_cols=[SOURCE_FILE_ID_COL_NAME] + TIME_COLS + [ATTACK_CLASS_COL_NAME], 
    label_col=label_col,
    build_graph_func=create_graph,
    downsampler=downsampler
    ):

    print("All Columns: ", df.columns)
    print("Ordering Columns: ", ordering_cols)
    assert all(col in df.columns for col in ordering_cols), "All timestamp columns are required"
    assert label_col in df.columns, "Edge label column 'label' is required"
    
    df = df.sort_values(ordering_cols).reset_index(drop=True)
    window_size = int(window_size)
    
    df.drop(columns=set(df.columns) - set(feature_cols) - set(label_col))

    print("Final Columns: ", df.columns)
    
    label_counts_list = []
    X = []
    y = []

    progress_bar = tqdm(range(0, len(df), window_size), desc=f"Generating graphs")
    for start in progress_bar:
        window_df = df[start: min(start + window_size, len(df))]
        contains_label = window_df[label_col].unique()

        G_pyg = build_graph_func(window_df)

        label_counts = window_df[label_col].value_counts()

        label_counts_list.append(label_counts)
        X.append(G_pyg)
        y.append(contains_label.tolist())

    X, y = downsampler.downsample(label_counts_list, X, y)
        
    return X, y

In [17]:
X, y = generate_graph_datasets(data)

All Columns:  Index(['Flow ID', 'Source IP', 'Destination IP', 'Timestamp', 'Flow Duration',
       'Total Length of Fwd Packets', 'Fwd Packet Length Mean',
       'Fwd Packet Length Std', 'Bwd Packet Length Min',
       'Bwd Packet Length Std', 'Flow Packets/s', 'Flow IAT Mean',
       'Flow IAT Std', 'Flow IAT Min', 'Fwd IAT Min', 'Bwd IAT Mean',
       'Fwd PSH Flags', 'SYN Flag Count', 'PSH Flag Count', 'ACK Flag Count',
       'Average Packet Size', 'Fwd Header Length.1', 'Subflow Fwd Packets',
       'Subflow Fwd Bytes', 'Subflow Bwd Bytes', 'Init_Win_bytes_forward',
       'Active Mean', 'Active Min', 'Label', 'source_file_id', 'Protocol_0',
       'Protocol_6', 'Protocol_17', 'h'],
      dtype='object')
Ordering Columns:  ['source_file_id', 'Timestamp', 'Label']
Final Columns:  Index(['Flow ID', 'Source IP', 'Destination IP', 'Timestamp', 'Flow Duration',
       'Total Length of Fwd Packets', 'Fwd Packet Length Mean',
       'Fwd Packet Length Std', 'Bwd Packet Length Min',
   

Generating graphs: 100%|██████████| 14154/14154 [01:47<00:00, 131.19it/s]
Downsampling '0':  72%|███████▏  | 10229/14154 [00:05<00:02, 1819.18it/s, class_label=0, remaining=227297, target=2.27e+5]


In [18]:

stratified_graph_daataset = StratifiedGraphDataset(X, y)
random_split_graph_dataset = RandomSplitGraphDataset(X, y)


In [22]:
num_tests = 100
for i in range(num_tests):
    
    stratified_train_graph_dataset, stratified_test_graph_dataset = stratified_graph_daataset.graph_train_test_split(test_ratio=0.15, random_state=i)
    random_train_graph_dataset, random_test_graph_dataset = random_split_graph_dataset.graph_train_test_split(test_ratio=0.15, random_state=i)

    # Compute the Jensen-Shannon Divergence of train and test datasets
    def compute_js_divergence(counter_p, counter_q):
        """
        Compute the Jensen-Shannon Divergence between two class count dictionaries (Counter).
        """
        # Get the union of all class labels
        all_keys = sorted(set(counter_p.keys()).union(set(counter_q.keys())))
        # Convert to arrays in the same order
        p = np.array([counter_p.get(k, 0) for k in all_keys], dtype=np.float64)
        q = np.array([counter_q.get(k, 0) for k in all_keys], dtype=np.float64)

        # Normalize the distributions
        p = p / np.sum(p) if np.sum(p) > 0 else p
        q = q / np.sum(q) if np.sum(q) > 0 else q

        # Compute the average distribution
        m = 0.5 * (p + q)

        # Compute the Kullback-Leibler divergence
        kl_p_m = np.sum(np.where(p != 0, p * np.log(p / m), 0))
        kl_q_m = np.sum(np.where(q != 0, q * np.log(q / m), 0))

        # Return the Jensen-Shannon Divergence
        jsd = 0.5 * (kl_p_m + kl_q_m)
        return jsd


    def KL_divergence(counter_p, counter_q):
        """
        Compute the Kullback-Leibler Divergence between two distributions.
        """
        # Get the union of all class labels
        all_keys = sorted(set(counter_p.keys()).union(set(counter_q.keys())))

        p = np.array([counter_p.get(k, 0) for k in all_keys], dtype=np.float64)
        q = np.array([counter_q.get(k, 0) for k in all_keys], dtype=np.float64)

        # Normalize the distributions
        p = p / np.sum(p) if np.sum(p) > 0 else p
        q = q / np.sum(q) if np.sum(q) > 0 else q

        print("KL-Divergence: ", np.sum(np.where(p != 0, p * np.log(p / q), 0))) 

    total_stratified_js = 0
    total_random_js = 0


    total_stratified_js += compute_js_divergence(
        stratified_train_graph_dataset.class_counts,
        stratified_test_graph_dataset.class_counts
    )

    total_random_js += compute_js_divergence(
        random_train_graph_dataset.class_counts,
        random_test_graph_dataset.class_counts
    )
    

print(f"Stratified Split Jensen-Shannon Divergence: {total_stratified_js / num_tests}")
print(f"Random Split Jensen-Shannon Divergence: {total_random_js / num_tests}")

  kl_q_m = np.sum(np.where(q != 0, q * np.log(q / m), 0))
  kl_q_m = np.sum(np.where(q != 0, q * np.log(q / m), 0))
  kl_q_m = np.sum(np.where(q != 0, q * np.log(q / m), 0))
  kl_q_m = np.sum(np.where(q != 0, q * np.log(q / m), 0))
  kl_q_m = np.sum(np.where(q != 0, q * np.log(q / m), 0))
  kl_q_m = np.sum(np.where(q != 0, q * np.log(q / m), 0))
  kl_q_m = np.sum(np.where(q != 0, q * np.log(q / m), 0))
  kl_q_m = np.sum(np.where(q != 0, q * np.log(q / m), 0))
  kl_q_m = np.sum(np.where(q != 0, q * np.log(q / m), 0))
  kl_q_m = np.sum(np.where(q != 0, q * np.log(q / m), 0))
  kl_q_m = np.sum(np.where(q != 0, q * np.log(q / m), 0))
  kl_q_m = np.sum(np.where(q != 0, q * np.log(q / m), 0))
  kl_q_m = np.sum(np.where(q != 0, q * np.log(q / m), 0))
  kl_q_m = np.sum(np.where(q != 0, q * np.log(q / m), 0))
  kl_q_m = np.sum(np.where(q != 0, q * np.log(q / m), 0))
  kl_q_m = np.sum(np.where(q != 0, q * np.log(q / m), 0))
  kl_q_m = np.sum(np.where(q != 0, q * np.log(q / m), 0))
  kl_q_m = np.

Stratified Split Jensen-Shannon Divergence: 1.5798754262960515e-06
Random Split Jensen-Shannon Divergence: 1.1077078595328155e-05


  kl_q_m = np.sum(np.where(q != 0, q * np.log(q / m), 0))
  kl_q_m = np.sum(np.where(q != 0, q * np.log(q / m), 0))
