# Imports

In [None]:
import data_handling
from torch_geometric.transforms import AddSelfLoops
import os
import torch
from torch_geometric.loader import DataLoader
from tqdm import tqdm
import json
import numpy as np
from util_scripts import gnn_architectures
from util_scripts.gnn_training import train_batch, validate_batch, calculate_multiclass_metrics, calculate_multiclass_test_metrics, export_pretty_confusion_matrix, deduplicate_multiclass_sliding_window_results, balanced_temporal_undersampler
import pandas as pd
import time
import pickle

# Run Experiments

## Experiment 1: General GNNs from scratch Parameter Optimization

### Experiment Parameters

In [9]:
# Paths
RAW_DATA_PATH = 'data/raw'
LANDED_DATA_PATH = 'data/landed'
INGESTED_DATA_PATH = 'data/ingested'
UTILS_PATH = 'data/utils'
SAVED_MODELS_PATH = 'saved_models'
CONFIG_PATH = 'configs'

# General parameters
dataset_name = 'NF_ToN_IoT' # Pick from 'NF_ToN_IoT', 'NF_BoT_IoT' and 'NF_UNSW_NB15'
truncate = True
gnn_type = 'temporal' # Pick from 'temporal' and 'static'
temporal = True if gnn_type == 'temporal' else False

# Training parameters
batch_size = 10
num_epochs = 2
weighted_loss = True

# Model Parameters. Make list of all the options you want to try
gnn_layer_options = [2] # [2,3]
window_size_options = [10, 1] # [10, 30]
self_loops_options = [False]
save_epoch_every = 1

# Specific for if temporal
window_memory_options = [3, 5] # [3, 5]
flow_memory = 20

### Experiment Runs

In [10]:
data_processor = data_handling.DataPreprocessor(INGESTED_DATA_PATH, UTILS_PATH)
graph_builder = data_handling.GraphBuilder()

attack_mapping = data_processor.load_attack_mapping(dataset_name)
train_raw, val_raw, = data_processor.load_mixed_train(dataset_name), data_processor.load_mixed_val(dataset_name)
(train_attrs, train_labels), (val_attrs, val_labels) = data_processor.preprocess_NF(dataset_name, train_raw, keep_IPs_and_timestamp=True, binary=False, remove_minority_labels=False, only_attacks=False, scale=True, truncate=truncate), \
                    data_processor.preprocess_NF(dataset_name, val_raw, keep_IPs_and_timestamp=True, binary=False, remove_minority_labels=False, only_attacks=False, scale=True, truncate=truncate)

# If this dataset, we undersampxle benign flows a bit for performance. Only on training set though, not on eval!
if dataset_name == 'NF_UNSW_NB15':
    train_attrs, train_labels = data_processor.randomly_drop_benign_flows(train_attrs, train_labels, 0.7)

for gnn_layers in gnn_layer_options:
    gnn_hidden_channels = 128
    classifier_layers = 2
    classifier_hidden_channels = 128

    for window_size in window_size_options:
        for window_memory in window_memory_options:
            window_stride = window_size
            # batch_size = max(1,int((1/(window_size*window_memory))*200)+10) # A metric for having small enough batch size when windows are getting big
            for self_loops in self_loops_options:

                features = train_attrs.columns
                features = [feat for feat in features if feat not in ['Dst IP', 'Dst Port', 'Flow Duration Graph Building', 'Src IP', 'Src Port', 'Timestamp']]
                train_windows, val_windows = graph_builder.time_window_with_flow_duration(train_attrs, window_size, window_stride), \
                                                            graph_builder.time_window_with_flow_duration(val_attrs, window_size, window_stride)
                if gnn_type == 'temporal':
                    train_graphs, _ = graph_builder.build_spatio_temporal_pyg_graphs(train_windows, train_attrs, train_labels, window_memory, flow_memory, False, features, attack_mapping, True)
                    val_graphs, val_window_indices_for_classification = graph_builder.build_spatio_temporal_pyg_graphs(val_windows, val_attrs, val_labels, window_memory, flow_memory, False, features, attack_mapping, True)
                elif gnn_type == 'static':
                    train_graphs = graph_builder.build_static_pyg_graphs(train_windows, train_attrs, train_labels, False, features, attack_mapping, True)
                    val_graphs = graph_builder.build_static_pyg_graphs(val_windows, val_attrs, val_labels, False, features, attack_mapping, True)

                metadata = train_graphs[0].metadata()
                sample_graph = train_graphs[0]

                if gnn_type == 'temporal':
                    gnn_base = gnn_architectures.TemporalPlusConv_v2(metadata, gnn_hidden_channels, gnn_layers)
                elif gnn_type == 'static':
                    gnn_base = gnn_architectures.SAGE(metadata, gnn_hidden_channels, gnn_layers)
                else:
                    raise ValueError('Unknown GNN type')

                model = gnn_architectures.multiclass_NIDS_model(gnn_base, len(attack_mapping), classifier_hidden_channels, classifier_layers, temporal)

                os.makedirs(os.path.join(SAVED_MODELS_PATH, dataset_name, 'experiments', gnn_type), exist_ok=True)
                model_dir = os.path.join(SAVED_MODELS_PATH, dataset_name, 'experiments', gnn_type)
                experiment_idx = len(os.listdir(model_dir))
                experiment_dir = os.path.join(model_dir, f'experiment_{experiment_idx}')
                os.makedirs(experiment_dir)
                experiment_dict = {'dataset': dataset_name, 'gnn_type': gnn_type, 'gnn_layers': gnn_layers, 'gnn_hidden_channels': gnn_hidden_channels, 'classifier_layers': classifier_layers, 'classifier_hidden_channels': classifier_hidden_channels, 'self_loops': self_loops, 'window_size': window_size, 'window_stride': window_stride, 'include_port': False, 'window_memory': window_memory, 'batch_size': batch_size, 'num_epochs': num_epochs, 'weighted_loss': weighted_loss, 'truncate': truncate, 'flow_memory': flow_memory}
                train_losses = []
                val_losses = []
                train_weighted_f1 = []
                val_weighted_f1 = []
                train_macro_f1 = []
                val_macro_f1 = []

                # Setup optimizer
                optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
                loader = DataLoader(train_graphs, batch_size=batch_size, shuffle=True)
                loader_val = DataLoader(val_graphs, batch_size=batch_size, shuffle=False)

                # Set model to device
                device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
                model.to(device)

                # Setup Loss Criterion
                if weighted_loss:
                    num_classes = len(attack_mapping)
                    class_counts = np.zeros(num_classes)
                    for batch in loader:
                        target = batch['con'].y
                        class_counts += target.sum(dim=0).cpu().numpy()

                    # print(f'class counts in training data: {attack_mapping.keys()}:{class_counts}')
                    total_samples = class_counts.sum()
                    class_weights = total_samples / (num_classes * class_counts)
                    weights = torch.tensor(class_weights, dtype=torch.float32).to(device)
                    criterion = torch.nn.CrossEntropyLoss(weight=weights)

                else:
                    criterion = torch.nn.CrossEntropyLoss()

                # Train the model
                best_model_weights_macro = None
                best_model_weights_weighted = None
                best_model_val_f1_macro = 0
                best_model_val_f1_weighted = 0
                best_model_weights_macro_epoch = 0
                best_model_weights_weighted_epoch = 0

                for epoch in tqdm(range(num_epochs)):
                    total_train_loss = 0
                    epoch_preds = np.array([])
                    epoch_targets = np.array([])
                    for batch in loader:
                        train_data = batch.to(device)
                        batch_loss, batch_preds, batch_targets = train_batch(model, train_data, optimizer, criterion)
                        total_train_loss += batch_loss
                        epoch_preds = np.concatenate((epoch_preds, batch_preds))
                        epoch_targets = np.concatenate((epoch_targets, batch_targets))

                    # Calculate metrics
                    epoch_train_accuracy, epoch_train_f1_weighted, epoch_train_f1_macro = calculate_multiclass_metrics(epoch_preds, epoch_targets, attack_mapping)
                    train_losses.append(total_train_loss)
                    train_weighted_f1.append(epoch_train_f1_weighted)
                    train_macro_f1.append(epoch_train_f1_macro)
                    print('Epoch:', epoch, 'Train Loss:', total_train_loss, 'Train Accuracy:', epoch_train_accuracy.item(), 'Train Multiclass Weighted F1:', epoch_train_f1_weighted.item())

                    total_val_loss = 0
                    epoch_preds = np.array([])
                    epoch_targets = np.array([])
                    for batch in loader_val:
                        val_data = batch.to(device)
                        batch_loss, batch_preds, batch_targets = validate_batch(model, val_data, criterion)
                        total_val_loss += batch_loss
                        epoch_preds = np.concatenate((epoch_preds, batch_preds))
                        epoch_targets = np.concatenate((epoch_targets, batch_targets))

                    val_accuracy, val_f1_weighted, val_f1_macro = calculate_multiclass_metrics(epoch_preds, epoch_targets, attack_mapping)
                    val_losses.append(total_val_loss)
                    val_weighted_f1.append(val_f1_weighted)
                    val_macro_f1.append(val_f1_macro)
                    if val_f1_macro > best_model_val_f1_macro:
                        best_model_val_f1_macro = val_f1_macro
                        best_model_weights_macro = model.state_dict()
                        best_model_weights_macro_epoch = epoch
                    if val_f1_weighted > best_model_val_f1_weighted:
                        best_model_val_f1_weighted = val_f1_weighted
                        best_model_weights_weighted = model.state_dict()
                        best_model_weights_weighted_epoch = epoch
                    if (epoch % save_epoch_every == 0) and (epoch != 0):
                        torch.save(model.state_dict, os.path.join(experiment_dir, f'model_weights_checkpoint_epoch_{epoch}.pth'))

                    print('Validation Loss:', total_val_loss, 'Validation Accuracy:', val_accuracy.item(), 'Validation Multiclass Weighted F1:', val_f1_weighted.item(), 'Validation Multiclass Macro F1:', val_f1_macro.item())

                # Save the best models
                torch.save(best_model_weights_macro, os.path.join(experiment_dir, f'best_model_weights_macro_f1_epoch_{best_model_weights_macro_epoch}.pth'))
                torch.save(best_model_weights_weighted, os.path.join(experiment_dir, f'best_model_weights_weighted_f1_epoch_{best_model_weights_weighted_epoch}.pth'))

                # Save the experiment metadata
                experiments_results = {'train_losses': train_losses, 'val_losses': val_losses, 'train_weighted_f1': train_weighted_f1, 'val_weighted_f1': val_weighted_f1, 'train_macro_f1': train_macro_f1, 'val_macro_f1': val_macro_f1}
                with open(os.path.join(experiment_dir, 'results.json'), 'w') as f:
                    json.dump(experiments_results, f)
                with open(os.path.join(experiment_dir, 'experiment_metadata.json'), 'w') as f:
                    json.dump(experiment_dict, f)

- Loading in mixed training set...
- Loading in mixed validation set...
-- Min-Max scaling numerical columns...
-- Min-Max scaling numerical columns...


100%|██████████| 2390/2390 [01:48<00:00, 22.08it/s]
100%|██████████| 586/586 [00:33<00:00, 17.30it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

dict_keys(['BENIGN', 'backdoor', 'ddos', 'dos', 'injection', 'mitm', 'password', 'ransomware', 'scanning', 'xss'])
[[168716   7729      0  13249     88   1627   5093   4889      0      0]
 [  1011  24780      0      1     20      1    286      3      0      0]
 [ 19561   3355      0   2564    184      0    120   2716      0      0]
 [  8510    529      0  52605      0    316   1528     93      0      0]
 [ 12635    273      0  19625      0   1267  44407  15050      0      0]
 [   556    121      0    378      0    512    136    412      0      0]
 [ 38571    985      0  19517    303    196  31220   9626      0      0]
 [  2379    296      0    100      0    101    124   1244      0      0]
 [ 14411      0      0      0      0      0      0      0      0      0]
 [  6605      0      0    818      0      0   3369    872      0      0]]
Epoch: 0 Train Loss: 463.7584934234619 Train Accuracy: 0.5114269640065753 Train Multiclass Weighted F1: 0.43325596069296
dict_keys(['BENIGN', 'backdoor', 

 50%|█████     | 1/2 [00:26<00:26, 26.62s/it]

[[41844     7     0  2775     0  2858  8479  2294     0     0]
 [   43  7716     0     0     0     4     0     1     0     0]
 [    6     0     0    14     0     4  8599    26     0     0]
 [    2     0     0 25337     0     5     0     0     0     0]
 [    1     0     0   428     0  1601 33722  2621     0     0]
 [   13     0     0     3     0   423     0    99     0     0]
 [    0     0     0     8     0    64 38843   380     0     0]
 [   40     0     0     0     0    13     5   782     0     0]
 [    4     0     0    35     0     0  7771     0     0     0]
 [   63     0     0   536     0    27  2164  1753     0     0]]
Validation Loss: 94.26711678504944 Validation Accuracy: 0.6005078025003526 Validation Multiclass Weighted F1: 0.5336303263300253 Validation Multiclass Macro F1: 0.3651171142762194
dict_keys(['BENIGN', 'backdoor', 'ddos', 'dos', 'injection', 'mitm', 'password', 'ransomware', 'scanning', 'xss'])
[[177189     83      0   2532      1   5109  13537   2940      0      0]
 

100%|██████████| 2/2 [00:51<00:00, 25.75s/it]

[[42519    52     0  1578  2293  4249  6166  1400     0     0]
 [   30  7714     0     1     0    17     0     2     0     0]
 [ 1499     0     0     2  5174     2  1972     0     0     0]
 [    0     0     0 25343     0     1     0     0     0     0]
 [    0     0     0    10 29240   948  7756   419     0     0]
 [    1     0     0     0     0   525     0    12     0     0]
 [    0     0     0     0  6384    24 32865    22     0     0]
 [   10     0     0     0     0    24     4   802     0     0]
 [   78     0     0     9   217     7  7499     0     0     0]
 [  559     0     0   303  1528  1415   691    47     0     0]]
Validation Loss: 92.14862620830536 Validation Accuracy: 0.7262202671709863 Validation Multiclass Weighted F1: 0.7027551308313615 Validation Multiclass Macro F1: 0.4757027702919466



100%|██████████| 1434/1434 [01:39<00:00, 14.40it/s]
100%|██████████| 352/352 [00:29<00:00, 11.78it/s]
  0%|          | 0/2 [00:13<?, ?it/s]


KeyboardInterrupt: 

## Experiment 2: GNNs from pre-trained

#### Experiment Parameters

In [13]:
# Paths
RAW_DATA_PATH = 'data/raw'
LANDED_DATA_PATH = 'data/landed'
INGESTED_DATA_PATH = 'data/ingested'
UTILS_PATH = 'data/utils'
SAVED_MODELS_PATH = 'saved_models'
CONFIG_PATH = 'configs'

# General parameters
dataset_name = 'NF_ToN_IoT' # 'NF_UNSW_NB15', 'NF_ToN_IoT', 'NF_BoT_IoT'
pretraining_strategy = 'no_pretraining' # 'in_context','out_context','no_pretraining'
undersample_fracs = [0.05] # [0.1, 0.2, 0.5, 0.8, 1.0] # K-shot learning

# Pretrained model parameters (which one to select)
experiment_idx = 0
checkpoint_idx = 2

# Training parameters
batch_size = 10
num_epochs = 5
learning_rate = 0.001
weighted_loss = True
classifier_layers = 2
classifier_hidden_channels = 128
truncate = True # Truncate the extreme numerical values in standardization

# For if pretraining_strategy=='no_pretraining'. Else, ignore (we'll load the best model from the pretraining)
flow_memory = 20
window_size = 5
window_memory = 5
window_stride = 5
use_ports=False
self_loops=False
gnn_hidden_channels = 128
gnn_type = 'temporal' # 'temporal', 'static'

#### Experiment Runs

In [14]:
data_processor = data_handling.DataPreprocessor(INGESTED_DATA_PATH, UTILS_PATH)
graph_builder = data_handling.GraphBuilder()

if pretraining_strategy != 'no_pretraining':
  pre_trained_gnn_dir = f'{SAVED_MODELS_PATH}/{dataset_name}/pretraining_experiments/{gnn_type}/experiment_{experiment_idx}'
  weights_path = f'{pre_trained_gnn_dir}/checkpoint_{checkpoint_idx}_gnnbase.pt'
  metadata_path = f'{pre_trained_gnn_dir}/experiment_metadata.json'
  with open(metadata_path, 'r') as f:
      metadata = json.load(f)

  gnn_type, flow_memory, gnn_layers, window_size, window_memory, include_port, self_loops, gnn_hidden_channels, classifier_layers, classifier_hidden_channels = metadata["graph_type"],metadata["flow_memory"],metadata["gnn_layer"],metadata["window_size"],metadata["window_memory"],metadata["include_port"],metadata["self_loops"],metadata["gnn_hidden_channels"], metadata["classifier_layers"], metadata["classifier_hidden_channels"]
  window_stride = window_size
  temporal = True if gnn_type == 'temporal' else False

attack_mapping = data_processor.load_attack_mapping(dataset_name)
truncate = True
val_raw, test_raw = data_processor.load_mixed_val(dataset_name), data_processor.load_mixed_test(dataset_name)

if pretraining_strategy != 'in_context':
  (val_attrs, val_labels), (test_attrs, test_labels) = data_processor.preprocess_NF('all', val_raw, keep_IPs_and_timestamp=True, binary=False, remove_minority_labels=False, only_attacks=False, scale=True, truncate=True), \
                            data_processor.preprocess_NF('all', test_raw, keep_IPs_and_timestamp=True, binary=False, remove_minority_labels=False, only_attacks=False, scale=True, truncate=True)

else:
  (val_attrs, val_labels), (test_attrs, test_labels) = data_processor.preprocess_NF(dataset_name, val_raw, keep_IPs_and_timestamp=True, binary=False, remove_minority_labels=False, only_attacks=False, scale=True, truncate=True), \
                            data_processor.preprocess_NF(dataset_name, test_raw, keep_IPs_and_timestamp=True, binary=False, remove_minority_labels=False, only_attacks=False, scale=True, truncate=True)

features = val_attrs.columns
features = [feat for feat in features if feat not in ['Dst IP', 'Dst Port', 'Flow Duration Graph Building', 'Src IP', 'Src Port', 'Timestamp']]
val_windows = graph_builder.time_window_with_flow_duration(val_attrs, window_size, window_stride)
if gnn_type == 'temporal':
    val_graphs, val_window_indices_for_classification = graph_builder.build_spatio_temporal_pyg_graphs(val_windows, val_attrs, val_labels, window_memory, flow_memory, include_port, features, attack_mapping, True)
elif gnn_type == 'static':
    val_graphs = graph_builder.build_static_pyg_graphs(val_windows, val_attrs, val_labels, include_port, features, attack_mapping, True)
else:
    raise ValueError('GNN type not recognized!')
if self_loops:
    val_graph = [AddSelfLoops()(graph) for graph in val_graphs]

# Undersample the training data for K-shot learning
os.makedirs(os.path.join(SAVED_MODELS_PATH, dataset_name, 'fine_tuned_experiments', gnn_type), exist_ok=True)
experiment_idx = len(os.listdir(os.path.join(SAVED_MODELS_PATH, dataset_name, 'fine_tuned_experiments', gnn_type)))

for undersample_frac in undersample_fracs:
    train_raw = data_processor.load_mixed_train(dataset_name)

    if pretraining_strategy != 'in_context':
        train_attrs, train_labels = data_processor.preprocess_NF('all', train_raw, keep_IPs_and_timestamp=True, binary=False, remove_minority_labels=False, only_attacks=False, scale=True, truncate=True)

    else:
       train_attrs, train_labels= data_processor.preprocess_NF(dataset_name, train_raw, keep_IPs_and_timestamp=True, binary=False, remove_minority_labels=False, only_attacks=False, scale=True, truncate=True)

    train_attrs, train_labels, practical_fraction = balanced_temporal_undersampler(train_attrs, train_labels, undersample_frac)

    features = train_attrs.columns
    features = [feat for feat in features if feat not in ['Dst IP', 'Dst Port', 'Flow Duration Graph Building', 'Src IP', 'Src Port', 'Timestamp']]

    train_windows = graph_builder.time_window_with_flow_duration(train_attrs, window_size, window_stride)

    if gnn_type == 'temporal':
        train_graphs, _ = graph_builder.build_spatio_temporal_pyg_graphs(train_windows, train_attrs, train_labels, window_memory, flow_memory, include_port, features, attack_mapping, True)
    elif gnn_type == 'static':
        train_graphs = graph_builder.build_static_pyg_graphs(train_windows, train_attrs, train_labels, include_port, features, attack_mapping, True)
    else:
        raise ValueError('GNN type not recognized!')

    if self_loops:
        train_graph = [AddSelfLoops()(graph) for graph in train_graphs]

    metadata = train_graphs[0].metadata()
    sample_graph = train_graphs[0]

    if gnn_type == 'temporal':
        gnn_base = gnn_architectures.TemporalPlusConv_v2(metadata, gnn_hidden_channels, gnn_layers)
    elif gnn_type == 'static':
        gnn_base = gnn_architectures.SAGE(metadata, gnn_hidden_channels, gnn_layers)
    else:
        raise ValueError('Unknown GNN type')

    if pretraining_strategy != 'no_pretraining':
      # Load the pre-trained model
      with torch.no_grad():  # Initialize lazy modules.
          out = gnn_base(sample_graph.x_dict, sample_graph.edge_index_dict)
      gnn_base.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
      print(f'Pretrained GNN loaded successfully!')

    model = gnn_architectures.multiclass_NIDS_model(gnn_base, len(attack_mapping), classifier_hidden_channels, classifier_layers)

    train_losses = []
    val_losses = []
    train_weighted_f1 = []
    val_weighted_f1 = []
    train_macro_f1 = []
    val_macro_f1 = []

    # Setup optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Setup dataloader
    loader = DataLoader(train_graphs, batch_size=batch_size, shuffle=True)
    loader_val = DataLoader(val_graphs, batch_size=batch_size, shuffle=False)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Setup Loss Criterion
    if weighted_loss:
        num_classes = len(attack_mapping)
        class_counts = np.zeros(num_classes)
        for batch in loader:
            target = batch['con'].y
            class_counts += target.sum(dim=0).cpu().numpy()

        # print(f'class counts in training data: {attack_mapping.keys()}:{class_counts}')
        total_samples = class_counts.sum()
        class_weights = total_samples / (num_classes * class_counts)
        # print(f'Class Weights: {attack_mapping.keys()}')
        # print(class_weights)
        weights = torch.tensor(class_weights, dtype=torch.float32).to(device)
        criterion = torch.nn.CrossEntropyLoss(weight=weights)

    else:
        criterion = torch.nn.CrossEntropyLoss()

    # Set model to device
    model.to(device)

    # Train the model
    best_model_weights = None
    best_model_val_f1 = 0
    best_model_weights_epoch = 0

    # get the start time
    st = time.time()

    for epoch in tqdm(range(num_epochs)):
        total_train_loss = 0
        epoch_preds = np.array([])
        epoch_targets = np.array([])
        for batch in loader:
            train_data = batch.to(device)
            batch_loss, batch_preds, batch_targets = train_batch(model, train_data, optimizer, criterion)
            total_train_loss += batch_loss
            epoch_preds = np.concatenate((epoch_preds, batch_preds))
            epoch_targets = np.concatenate((epoch_targets, batch_targets))

        # Calculate metrics
        epoch_train_accuracy, epoch_train_f1_weighted, epoch_train_f1_macro = calculate_multiclass_metrics(epoch_preds, epoch_targets, attack_mapping)
        train_losses.append(total_train_loss)
        train_weighted_f1.append(epoch_train_f1_weighted)
        train_macro_f1.append(epoch_train_f1_macro)
        print('Epoch:', epoch, 'Train Loss:', total_train_loss, 'Train Accuracy:', epoch_train_accuracy.item(), 'Train Multiclass Weighted F1:', epoch_train_f1_weighted.item(), 'Train Multiclass Macro F1:', epoch_train_f1_macro.item())

        total_val_loss = 0
        epoch_preds = np.array([])
        epoch_targets = np.array([])
        for batch in loader_val:
            val_data = batch.to(device)
            batch_loss, batch_preds, batch_targets = validate_batch(model, val_data, criterion)
            total_val_loss += batch_loss
            epoch_preds = np.concatenate((epoch_preds, batch_preds))
            epoch_targets = np.concatenate((epoch_targets, batch_targets))

        val_accuracy, val_f1_weighted, val_f1_macro = calculate_multiclass_metrics(epoch_preds, epoch_targets, attack_mapping)
        val_losses.append(total_val_loss)
        val_weighted_f1.append(val_f1_weighted)
        val_macro_f1.append(val_f1_macro)
        if val_f1_macro > best_model_val_f1:
            best_model_val_f1 = val_f1_macro
            best_model_weights = model.state_dict()
            best_model_weights_epoch = epoch
        print('Validation Loss:', total_val_loss, 'Validation Accuracy:', val_accuracy.item(), 'Validation Multiclass Weighted F1:', val_f1_weighted.item(), 'Validation Multiclass Macro F1:', val_f1_macro.item())

    # get the end time
    et = time.time()
    # get the execution time
    elapsed_time = et - st
    print('Execution time:', elapsed_time, 'seconds')

    experiment_dir = os.path.join(SAVED_MODELS_PATH, dataset_name, 'fine_tuned_experiments', gnn_type, f'experiment_{experiment_idx}', str(undersample_frac))
    os.makedirs(experiment_dir, exist_ok=True)
    experiment_dict = {'dataset': dataset_name, 'K-shot-dataset_frac': practical_fraction, 'pretrain_strategy': pretraining_strategy, 'gnn_type': gnn_type, 'gnn_layers': gnn_layers, 'gnn_hidden_channels': gnn_hidden_channels, 'classifier_layers': classifier_layers, 'classifier_hidden_channels': classifier_hidden_channels, 'self_loops': self_loops, 'window_size': window_size, 'window_stride': window_stride, 'window_memory': window_memory, 'batch_size': batch_size, 'num_epochs': num_epochs, 'execution_time': elapsed_time}

    # Save the best model
    torch.save(best_model_weights, os.path.join(experiment_dir, f'best_model_weights_epoch_{best_model_weights_epoch}.pth'))

    # Save the experiment metadata
    experiments_results = {'train_losses': train_losses, 'val_losses': val_losses, 'train_weighted_f1': train_weighted_f1, 'val_weighted_f1': val_weighted_f1, 'train_macro_f1': train_macro_f1, 'val_macro_f1': val_macro_f1}
    with open(os.path.join(experiment_dir, 'results.json'), 'w') as f:
        json.dump(experiments_results, f)
    with open(os.path.join(experiment_dir, 'experiment_metadata.json'), 'w') as f:
        json.dump(experiment_dict, f)

- Loading in mixed validation set...
- Loading in mixed test set...


  ohe[f'{attribute}_{value}'] = 0


-- Min-Max scaling numerical columns...
-- Min-Max scaling numerical columns...


100%|██████████| 694/694 [00:22<00:00, 31.19it/s]


- Loading in mixed training set...
-- Min-Max scaling numerical columns...
Original Dataset Subsampled in balanced temporal way to 0.04997435531577669 % of the original dataset


100%|██████████| 134/134 [00:03<00:00, 35.69it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

dict_keys(['BENIGN', 'backdoor', 'ddos', 'dos', 'injection', 'mitm', 'password', 'ransomware', 'scanning', 'xss'])
[[   0    0  965    0    0 1557 6272    0    0    0]
 [   0    0   66    0    0  255 1069    0    0    0]
 [   0    0    0    0    0   41 2447    0    0    0]
 [   0    0  302    0    0  951  580    0    0    0]
 [   0    0    0    0    0    0 2763    0    0    0]
 [   0    0  112    0    0   74  769    0    0    0]
 [   0    0 1477    0    0    1 1121    0    0    0]
 [   0    0  112    0    0  254 1192    0    0    0]
 [   0    0    0    0    0    0 1298    0    0    0]
 [   0    0    0    0    0   62 2646    0    0    0]]
Epoch: 0 Train Loss: 32.24578642845154 Train Accuracy: 0.04528916849844614 Train Multiclass Weighted F1: 0.010995236880890913 Train Multiclass Macro F1: 0.01341861169409577


 20%|██        | 1/5 [00:02<00:11,  2.99s/it]

dict_keys(['BENIGN', 'backdoor', 'ddos', 'dos', 'injection', 'mitm', 'password', 'ransomware', 'scanning', 'xss'])
[[    0     0     0     0     0 61700     0     0     0     0]
 [    0     0     0     0     0 10364     0     0     0     0]
 [    0     0     0     0     0  8721     0     0     0     0]
 [    0     0     0     0     0 25716     0     0     0     0]
 [    0     0     0     0     0 39319     0     0     0     0]
 [    0     0     0     0     0   769     0     0     0     0]
 [    0     0     0     0     0 40186     0     0     0     0]
 [    0     0     0     0     0   879     0     0     0     0]
 [    0     0     0     0     0 10175     0     0     0     0]
 [    0     0     0     0     0  4789     0     0     0     0]]
Validation Loss: 160.66599893569946 Validation Accuracy: 0.0037953192707459358 Validation Multiclass Weighted F1: 2.8699971180101235e-05 Validation Multiclass Macro F1: 0.0007561938570311771
dict_keys(['BENIGN', 'backdoor', 'ddos', 'dos', 'injection', 'm

 40%|████      | 2/5 [00:05<00:08,  2.87s/it]

dict_keys(['BENIGN', 'backdoor', 'ddos', 'dos', 'injection', 'mitm', 'password', 'ransomware', 'scanning', 'xss'])
[[ 2053 59212     0     0     0   435     0     0     0     0]
 [    1 10363     0     0     0     0     0     0     0     0]
 [    2  8719     0     0     0     0     0     0     0     0]
 [  978 24191     0     0     0   547     0     0     0     0]
 [  532 25004     0     0     0 13783     0     0     0     0]
 [    8   744     0     0     0    17     0     0     0     0]
 [  267 25965     0     0     0 13954     0     0     0     0]
 [   73   727     0     0     0    79     0     0     0     0]
 [   13 10162     0     0     0     0     0     0     0     0]
 [  598  2347     0     0     0  1844     0     0     0     0]]
Validation Loss: 159.4394612312317 Validation Accuracy: 0.06136177437345152 Validation Multiclass Weighted F1: 0.024846830372379048 Validation Multiclass Macro F1: 0.017965308296023517
dict_keys(['BENIGN', 'backdoor', 'ddos', 'dos', 'injection', 'mitm', 

 60%|██████    | 3/5 [00:09<00:06,  3.09s/it]

dict_keys(['BENIGN', 'backdoor', 'ddos', 'dos', 'injection', 'mitm', 'password', 'ransomware', 'scanning', 'xss'])
[[30693 30923     0     0     0    14     0    70     0     0]
 [    3 10361     0     0     0     0     0     0     0     0]
 [ 8607   114     0     0     0     0     0     0     0     0]
 [22047  3669     0     0     0     0     0     0     0     0]
 [31285   424     0     0     0    30     0  7580     0     0]
 [  289   480     0     0     0     0     0     0     0     0]
 [25161   879     0     0     0    54     0 14092     0     0]
 [  665   141     0     0     0     0     0    73     0     0]
 [10081    94     0     0     0     0     0     0     0     0]
 [ 3291    41     0     0     0     0     0  1457     0     0]]
Validation Loss: 155.204097032547 Validation Accuracy: 0.20297801774768284 Validation Multiclass Weighted F1: 0.11490675950399622 Validation Multiclass Macro F1: 0.06832038661682807
dict_keys(['BENIGN', 'backdoor', 'ddos', 'dos', 'injection', 'mitm', 'pa

 80%|████████  | 4/5 [00:12<00:03,  3.01s/it]

dict_keys(['BENIGN', 'backdoor', 'ddos', 'dos', 'injection', 'mitm', 'password', 'ransomware', 'scanning', 'xss'])
[[38935 21798     0     0     0     0     0   967     0     0]
 [    6 10358     0     0     0     0     0     0     0     0]
 [ 8301   188     0     0     0     0     0   232     0     0]
 [21111  1807     0     0     0     0     0  2798     0     0]
 [ 7050   218     0     0     0     0     0 32051     0     0]
 [  456   306     0     0     0     0     0     7     0     0]
 [ 1469   807     0     0     0     0     0 37910     0     0]
 [  113    99     0     0     0     0     0   667     0     0]
 [ 9933   111     0     0     0     0     0   131     0     0]
 [  897    18     0     0     0     0     0  3874     0     0]]
Validation Loss: 144.52745473384857 Validation Accuracy: 0.24657236770671906 Validation Multiclass Weighted F1: 0.18118516355517894 Validation Multiclass Macro F1: 0.0985634733411892
dict_keys(['BENIGN', 'backdoor', 'ddos', 'dos', 'injection', 'mitm', 'p

100%|██████████| 5/5 [00:14<00:00,  3.00s/it]

dict_keys(['BENIGN', 'backdoor', 'ddos', 'dos', 'injection', 'mitm', 'password', 'ransomware', 'scanning', 'xss'])
[[50130  2810     0     4     0   210     0  8546     0     0]
 [   49 10315     0     0     0     0     0     0     0     0]
 [  864    20     0     0     0     0     0  7837     0     0]
 [22952   157     0  1905     0   190     0   512     0     0]
 [ 4763    15     0     2     0    66     0 34473     0     0]
 [  574    46     0     0     0    82     0    67     0     0]
 [  610   198     0     0     0   219     0 39159     0     0]
 [  104    18     0     0     0     0     0   757     0     0]
 [ 3461    28     0     0     0     0     0  6686     0     0]
 [  402     0     0     0     0     0     0  4387     0     0]]
Validation Loss: 130.96599411964417 Validation Accuracy: 0.31186271703402463 Validation Multiclass Weighted F1: 0.2716682208559149 Validation Multiclass Macro F1: 0.18085149084970759
Execution time: 14.97993016242981 seconds





## Experiment 3: Reference MLP Baselines

#### Experiment Parameters

In [9]:
# Paths
RAW_DATA_PATH = 'data/raw'
LANDED_DATA_PATH = 'data/landed'
INGESTED_DATA_PATH = 'data/ingested'
UTILS_PATH = 'data/utils'
SAVED_MODELS_PATH = 'saved_models'
CONFIG_PATH = 'configs'

# General Experiment parameters
dataset_name = 'NF_ToN_IoT' # 'NF_ToN_IoT', 'NF_BoT_IoT' or 'NF_UNSW_NB15'
batch_size = 128
num_epochs = 10
weighted_loss = True
classifier_layers = 2
classifier_hidden_channels = 128

#### Experiment Run

In [10]:
data_processor = data_handling.DataPreprocessor(INGESTED_DATA_PATH, UTILS_PATH)
graph_builder = data_handling.GraphBuilder()

attack_mapping = data_processor.load_attack_mapping(dataset_name)

train_raw, val_raw, test_raw = data_processor.load_mixed_train(dataset_name), data_processor.load_mixed_val(dataset_name), data_processor.load_mixed_test(dataset_name)


(train_attrs, train_labels), (val_attrs, val_labels), (test_attrs, test_labels) = data_processor.preprocess_NF(dataset_name, train_raw, keep_IPs_and_timestamp=False, binary=False, remove_minority_labels=False, only_attacks=False, scale=True, truncate=True), \
            data_processor.preprocess_NF(dataset_name, val_raw, keep_IPs_and_timestamp=False, binary=False, remove_minority_labels=False, only_attacks=False, scale=True, truncate=True), \
            data_processor.preprocess_NF(dataset_name, test_raw, keep_IPs_and_timestamp=False, binary=False, remove_minority_labels=False, only_attacks=False, scale=True, truncate=True)

# Setup optimizer
model = gnn_architectures.MLP(train_attrs.shape[1], classifier_hidden_channels, len(attack_mapping), classifier_layers)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Make sure attrs in alphabetical order and then in tensor form
train_attrs = train_attrs[train_attrs.columns.sort_values()]
val_attrs = val_attrs[val_attrs.columns.sort_values()]
test_attrs = test_attrs[test_attrs.columns.sort_values()]

train_attrs = torch.tensor(train_attrs.values, dtype=torch.float32)
train_labels = train_labels.to_numpy()
labels_torch = torch.Tensor([attack_mapping[attack] for attack in train_labels])

val_attrs = torch.tensor(val_attrs.values, dtype=torch.float32)
val_labels = val_labels.to_numpy()
labels_torch_val = torch.Tensor([attack_mapping[attack] for attack in val_labels])

test_attrs = torch.tensor(test_attrs.values, dtype=torch.float32)
test_labels = test_labels.to_numpy()
labels_torch_test = torch.Tensor([attack_mapping[attack] for attack in test_labels])

tensor_dataset = torch.utils.data.TensorDataset(train_attrs.float(), labels_torch.float())
tensor_dataset_val = torch.utils.data.TensorDataset(val_attrs.float(), labels_torch_val.float())

train_loader = torch.utils.data.DataLoader(tensor_dataset, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(tensor_dataset_val, batch_size=batch_size, shuffle=False)

# Setup Loss Criterion
if weighted_loss:
    num_classes = len(attack_mapping)
    class_counts = np.zeros(num_classes)
    for batch in train_loader:
        target = batch[1]
        class_counts += target.sum(dim=0).numpy()

    # print(f'class counts in training data: {attack_mapping.keys()}:{class_counts}')

    total_samples = class_counts.sum()
    class_weights = total_samples / (num_classes * class_counts)

    weights = torch.tensor(class_weights, dtype=torch.float32).to(device)
    criterion = torch.nn.CrossEntropyLoss(weight=weights)

else:
    criterion = torch.nn.CrossEntropyLoss()

# Set model to device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Train the model
best_model_weights = None
best_model_val_f1 = 0
best_model_weights_epoch = 0

train_losses = []
val_losses = []
train_weighted_f1 = []
val_weighted_f1 = []
train_macro_f1 = []
val_macro_f1 = []

for epoch in tqdm(range(num_epochs)):
    total_train_loss = 0
    epoch_preds = np.array([])
    epoch_targets = np.array([])
    for batch in train_loader:
        train_data, train_targets = batch
        train_data, train_targets = train_data.to(device), train_targets.to(device)
        optimizer.zero_grad()
        batch_preds = model(train_data)
        batch_loss = criterion(batch_preds, train_targets)
        batch_loss.backward()
        optimizer.step()
        total_train_loss += batch_loss.item()
        epoch_preds = np.concatenate((epoch_preds, batch_preds.argmax(dim=1).cpu().detach().numpy()))
        epoch_targets = np.concatenate((epoch_targets, train_targets.argmax(dim=1).cpu().detach().numpy()))

    # Calculate metrics
    epoch_train_accuracy, epoch_train_f1_weighted, epoch_train_f1_macro = calculate_multiclass_metrics(epoch_preds, epoch_targets, attack_mapping)
    train_losses.append(total_train_loss)
    train_weighted_f1.append(epoch_train_f1_weighted)
    train_macro_f1.append(epoch_train_f1_macro)
    print('Epoch:', epoch, 'Train Loss:', total_train_loss, 'Train Accuracy:', epoch_train_accuracy.item(), 'Train Multiclass Weighted F1:', epoch_train_f1_weighted.item(), 'Train Multiclass Macro F1:', epoch_train_f1_macro.item())

    total_val_loss = 0
    epoch_preds = np.array([])
    epoch_targets = np.array([])
    for batch in val_loader:
        val_data, val_targets = batch
        val_data, val_targets = val_data.to(device), val_targets.to(device)
        batch_preds = model(val_data)
        batch_loss = criterion(batch_preds, val_targets)
        total_val_loss += batch_loss.item()
        epoch_preds = np.concatenate((epoch_preds, batch_preds.argmax(dim=1).cpu().detach().numpy()))
        epoch_targets = np.concatenate((epoch_targets, val_targets.argmax(dim=1).cpu().detach().numpy()))

    val_accuracy, val_f1_weighted, val_f1_macro = calculate_multiclass_metrics(epoch_preds, epoch_targets, attack_mapping)
    val_losses.append(total_val_loss)
    val_weighted_f1.append(val_f1_weighted)
    val_macro_f1.append(val_f1_macro)
    if val_f1_macro > best_model_val_f1:
        best_model_val_f1 = val_f1_macro
        best_model_weights = model.state_dict()
        best_model_weights_epoch = epoch
    print('Validation Loss:', total_val_loss, 'Validation Accuracy:', val_accuracy.item(), 'Validation Multiclass Weighted F1:', val_f1_weighted.item(), 'Validation Multiclass Macro F1:', val_f1_macro.item())

os.makedirs(os.path.join(SAVED_MODELS_PATH, dataset_name, 'baselines'), exist_ok=True)
experiment_idx = len(os.listdir(os.path.join(SAVED_MODELS_PATH, dataset_name, 'baselines')))
experiment_dir = os.path.join(SAVED_MODELS_PATH, dataset_name, 'baselines', f'experiment_{experiment_idx}')
os.makedirs(experiment_dir, exist_ok=True)

experiment_dict = {'dataset': dataset_name, 'gnn_type': 'MLP', 'gnn_layers': 0, 'gnn_hidden_channels': 0, 'classifier_layers': classifier_layers, 'classifier_hidden_channels': classifier_hidden_channels, 'batch_size': batch_size, 'num_epochs': num_epochs, 'weighted_loss': weighted_loss, 'truncate': True}

# Save the best model
torch.save(best_model_weights, os.path.join(experiment_dir, f'best_model_weights_epoch_{best_model_weights_epoch}.pth'))

# Save the experiment metadata
experiments_results = {'train_losses': train_losses, 'val_losses': val_losses, 'train_weighted_f1': train_weighted_f1, 'val_weighted_f1': val_weighted_f1, 'train_macro_f1': train_macro_f1, 'val_macro_f1': val_macro_f1}
with open(os.path.join(experiment_dir, 'results.json'), 'w') as f:
    json.dump(experiments_results, f)
with open(os.path.join(experiment_dir, 'experiment_metadata.json'), 'w') as f:
    json.dump(experiment_dict, f)

- Loading in mixed training set...
- Loading in mixed validation set...
- Loading in mixed test set...
-- Min-Max scaling numerical columns...
-- Min-Max scaling numerical columns...
-- Min-Max scaling numerical columns...


  0%|          | 0/10 [00:00<?, ?it/s]

dict_keys(['BENIGN', 'backdoor', 'ddos', 'dos', 'injection', 'mitm', 'password', 'ransomware', 'scanning', 'xss'])
[[63076    13  8611  2459  3921  7472  5113  2735   248  1493]
 [    7  8581    32     2     2    23     2    22     2    12]
 [   90     5 11376   269   135   208   185   132    29   190]
 [  261    11   177 29009     3  1181     4    27    10   643]
 [  299     2  1801   163 37266   467  2217   577    56  1943]
 [   54     2    16    26    18   481    16    16     0    32]
 [  549    55   839    56  3196   963 34911   825  2034   861]
 [   29     1    29     0    12    43    37  1690     4    30]
 [    4     0    14     0     3    10   288    37 14047     7]
 [   25     0    22   172   100   216    31   137     2  4811]]
Epoch: 0 Train Loss: 936.7169901132584 Train Accuracy: 0.7915067890927181 Train Multiclass Weighted F1: 0.8126846699606932 Train Multiclass Macro F1: 0.6999639039766464


 10%|█         | 1/10 [00:13<02:05, 13.95s/it]

dict_keys(['BENIGN', 'backdoor', 'ddos', 'dos', 'injection', 'mitm', 'password', 'ransomware', 'scanning', 'xss'])
[[21628     8   312   162   735  1852  2155   208    87    33]
 [    1  2560    11     1     1     6     2     2     0     1]
 [  285     0  3693     0   251     3    55     0     0     1]
 [   29     1     3 12303     0   145     1     0     4     2]
 [   14     0   185   168 15589  1209  1151     0     2   354]
 [    5     0     4     1     3   111    14     2     0     1]
 [    2     0   150     0  5635    10 13378     0     0     1]
 [    2     0    18     0     0    12    16   354     0     0]
 [    4     0     0     0     0     3    63    11  5512     0]
 [    2     0     6    10   215     4    29    31     0  1939]]
Validation Loss: 359.02946863044053 Validation Accuracy: 0.8308125181919126 Validation Multiclass Weighted F1: 0.8458514326090888 Validation Multiclass Macro F1: 0.7803253074687739
dict_keys(['BENIGN', 'backdoor', 'ddos', 'dos', 'injection', 'mitm', 'pas

 20%|██        | 2/10 [00:27<01:51, 13.99s/it]

dict_keys(['BENIGN', 'backdoor', 'ddos', 'dos', 'injection', 'mitm', 'password', 'ransomware', 'scanning', 'xss'])
[[20308    10  3163    85   980  1302   597   599   102    34]
 [    2  2560    12     1     0     3     2     4     0     1]
 [   15     0  3922     0   317     2    30     1     0     1]
 [   21     3     0 12291     0   156     0     2    13     2]
 [  747     0   113   168 15003  1002   448     8     2  1181]
 [   13     0     0     1     7   103     1    15     0     1]
 [ 2795     0     6     0  2749  2514  8470     1     0  2641]
 [    1     0     4     0     0     0    20   377     0     0]
 [    4     0     0     0     0     0    59    14  5516     0]
 [   34     0     6    10   204     4     5    29     0  1944]]
Validation Loss: 355.70869114995 Validation Accuracy: 0.7599529974881685 Validation Multiclass Weighted F1: 0.7781100743975269 Validation Multiclass Macro F1: 0.6853978938858795
dict_keys(['BENIGN', 'backdoor', 'ddos', 'dos', 'injection', 'mitm', 'passwo

 30%|███       | 3/10 [00:40<01:34, 13.50s/it]

dict_keys(['BENIGN', 'backdoor', 'ddos', 'dos', 'injection', 'mitm', 'password', 'ransomware', 'scanning', 'xss'])
[[25045     7   287   153   128   548   609   164   202    37]
 [    6  2560     6     1     0     6     0     2     2     2]
 [  285     0  3715     0   259     2    26     0     0     1]
 [   97     0     3 12305     0    77     0     0     1     5]
 [  683     0   134   168 13325  1395  2319     1     2   645]
 [   11     0     4     2     3   104    14     1     0     2]
 [ 1097     0   150     0  5580    88 12166     0    48    47]
 [   33     0    16     0     0     1    14   338     0     0]
 [   12     0     0     0     0     3    67     3  5508     0]
 [   63     0     6    10    69     3    12     2     0  2071]]
Validation Loss: 384.6580617837608 Validation Accuracy: 0.8315671456754455 Validation Multiclass Weighted F1: 0.8375500827415227 Validation Multiclass Macro F1: 0.7782221926299646
dict_keys(['BENIGN', 'backdoor', 'ddos', 'dos', 'injection', 'mitm', 'pass

 40%|████      | 4/10 [00:55<01:23, 13.89s/it]

dict_keys(['BENIGN', 'backdoor', 'ddos', 'dos', 'injection', 'mitm', 'password', 'ransomware', 'scanning', 'xss'])
[[24608    10   161   125   746   590   606   149   142    43]
 [   12  2560     1     1     0     6     0     2     2     1]
 [  239     0  3700     0   318     2    27     0     1     1]
 [   97     0     0 12307     0    79     0     0     1     4]
 [  526     0    99   168 13910  1088  1976     3     2   900]
 [    9     0     1     1     7   108    13     1     0     1]
 [   26     0     2     0  5416    10 13685     0     0    37]
 [    6     0    14     0     0    11     6   365     0     0]
 [    8     0     0     0     0     9    59     1  5516     0]
 [   32     0    13     2   107     2    18    19     0  2043]]
Validation Loss: 339.562127432524 Validation Accuracy: 0.8495164993909078 Validation Multiclass Weighted F1: 0.8578154353079428 Validation Multiclass Macro F1: 0.7917252954356615
dict_keys(['BENIGN', 'backdoor', 'ddos', 'dos', 'injection', 'mitm', 'passw

 50%|█████     | 5/10 [01:09<01:09, 13.97s/it]

dict_keys(['BENIGN', 'backdoor', 'ddos', 'dos', 'injection', 'mitm', 'password', 'ransomware', 'scanning', 'xss'])
[[21567     9  2633    23    72   781  1731   232    94    38]
 [    1  2560     0     1     0    18     2     2     0     1]
 [   86     0  3923     0   245     2    29     2     0     1]
 [  104     0     5 11820     0   556     2     0     1     0]
 [  269     0   115     4 13482  3185   859    45     1   712]
 [    7     0     1     0     7   111     8     7     0     0]
 [    3     0   150     0  5302    15 13705     0     0     1]
 [   14     0    14     0     0     4     2   368     0     0]
 [   17     0     0     0     0     0    67     1  5508     0]
 [  473     0    15     1   149     5    27    39     0  1527]]
Validation Loss: 418.94554751052056 Validation Accuracy: 0.8039046582076519 Validation Multiclass Weighted F1: 0.8260191368513293 Validation Multiclass Macro F1: 0.7393502480542192
dict_keys(['BENIGN', 'backdoor', 'ddos', 'dos', 'injection', 'mitm', 'pas

 60%|██████    | 6/10 [01:22<00:54, 13.54s/it]

dict_keys(['BENIGN', 'backdoor', 'ddos', 'dos', 'injection', 'mitm', 'password', 'ransomware', 'scanning', 'xss'])
[[25612    11   175    97    86   438   406   219    94    42]
 [   13  2560     1     1     0     4     3     2     0     1]
 [  282     0  3698     0   277     3    27     0     0     1]
 [   43     0     4 12361     1    78     0     0     0     1]
 [  297     0    95   168 12854  1878  2736     3     4   637]
 [   18     0     0     1     2   104    12     2     0     2]
 [  119     0     4     0  5205   256 13578     0     0    14]
 [   10     0    16     0     0     0     2   367     0     7]
 [   15     0     0     0     0     0    11     3  5564     0]
 [  303     0     8     2    68     4     2     3     0  1846]]
Validation Loss: 389.62129147478845 Validation Accuracy: 0.8467351580944578 Validation Multiclass Weighted F1: 0.8568010633468014 Validation Multiclass Macro F1: 0.7831872620092286
dict_keys(['BENIGN', 'backdoor', 'ddos', 'dos', 'injection', 'mitm', 'pas

 70%|███████   | 7/10 [01:34<00:39, 13.24s/it]

dict_keys(['BENIGN', 'backdoor', 'ddos', 'dos', 'injection', 'mitm', 'password', 'ransomware', 'scanning', 'xss'])
[[24693    35   128    38    55  1559   397   150    80    45]
 [   11  2560     0     1     0     8     2     2     0     1]
 [  284     0  3679     0   294     5    25     0     0     1]
 [   99     4     4 12302     0    79     0     0     0     0]
 [   24     0    98   168 12527  5092   359     0     2   402]
 [    7     0     0     1     1   118    13     1     0     0]
 [ 2591     0     4     0  2687   176 13683     0     0    35]
 [   39     8     0     0     0     1     2   352     0     0]
 [   18     0     0     0     0     0    62     1  5512     0]
 [  455     0     6    11    57     3     0     3     0  1701]]
Validation Loss: 500.9046512860805 Validation Accuracy: 0.8314593417492265 Validation Multiclass Weighted F1: 0.8587695510377109 Validation Multiclass Macro F1: 0.7860670112784321
dict_keys(['BENIGN', 'backdoor', 'ddos', 'dos', 'injection', 'mitm', 'pass

 80%|████████  | 8/10 [01:47<00:26, 13.04s/it]

dict_keys(['BENIGN', 'backdoor', 'ddos', 'dos', 'injection', 'mitm', 'password', 'ransomware', 'scanning', 'xss'])
[[23196     7   159    57   696  1396  1236   302    94    37]
 [    0  2560     1     1     0    18     2     2     0     1]
 [  253     0  3702     0   304     2    26     0     0     1]
 [   21     0     4 12357     0   104     0     0     0     2]
 [   97     0    95   168 13848  1856  1984     1     1   622]
 [    3     0     0     1     4   115    13     2     0     3]
 [    8     0     3     0  5366    12 13785     0     0     2]
 [    0     0     0     0    14     4     2   382     0     0]
 [   17     0     0     0     0     0    72     1  5503     0]
 [   16     0     1    11   381     2     2    19     0  1804]]
Validation Loss: 418.2168078743125 Validation Accuracy: 0.8328068908269639 Validation Multiclass Weighted F1: 0.8492184289715499 Validation Multiclass Macro F1: 0.7754650499667142
dict_keys(['BENIGN', 'backdoor', 'ddos', 'dos', 'injection', 'mitm', 'pass

 90%|█████████ | 9/10 [02:00<00:12, 12.99s/it]

dict_keys(['BENIGN', 'backdoor', 'ddos', 'dos', 'injection', 'mitm', 'password', 'ransomware', 'scanning', 'xss'])
[[25349    93   136    55    73   468   544   334    88    40]
 [    5  2570     1     1     0     3     0     4     0     1]
 [  286     0  3690     0   251    19    41     0     0     1]
 [  103     0     2 12302     0    73     6     0     0     2]
 [  314     0   811   168 12490  2699  1490    32     1   667]
 [   15     0     1     1     6   101     8     7     0     2]
 [   57     0    16     0  4943   112 14047     0     0     1]
 [   14     4     0     0     1     9     2   372     0     0]
 [   14     0     0     0     0     0    65    13  5501     0]
 [   25     0     7     2    63     3     7     8     0  2121]]
Validation Loss: 426.2092094120453 Validation Accuracy: 0.846724377701836 Validation Multiclass Weighted F1: 0.8603761827958402 Validation Multiclass Macro F1: 0.772579820937524
dict_keys(['BENIGN', 'backdoor', 'ddos', 'dos', 'injection', 'mitm', 'passwo

100%|██████████| 10/10 [02:12<00:00, 13.29s/it]

dict_keys(['BENIGN', 'backdoor', 'ddos', 'dos', 'injection', 'mitm', 'password', 'ransomware', 'scanning', 'xss'])
[[23905    16   272    29   669  1644   266   213   134    32]
 [    2  2560    10     1     0    10     0     0     2     0]
 [  239     0  3614     0   316     5   114     0     0     0]
 [   15     0     1 12117     1   349     0     0     0     5]
 [  347     0    94    85 11820  4574  1203    58     1   490]
 [    0     0     0     0     6   119     9     7     0     0]
 [  972     0     1     0   874    27 17231     0    64     7]
 [    1     0     4     0     0    14     2   381     0     0]
 [   17     0     0     0     0     0    63     1  5512     0]
 [  264     0     0     2    40   553     1     8     0  1368]]
Validation Loss: 709.8308229913237 Validation Accuracy: 0.8476299306820755 Validation Multiclass Weighted F1: 0.8764538220000032 Validation Multiclass Macro F1: 0.7763328629386927





# Evaluation

### Get All Model Predictions On Test Set

In [19]:
dataset_to_evaluate = 'NF_ToN_IoT' # Choose from 'NF_ToN_IoT', 'NF_BoT_IoT', 'NF_UNSW_NB15
model_type_to_evaluate = 'temporal' # Choose from 'static', 'temporal'
weights_to_select = 'best_macro' # Choose from 'best_macro', 'best_weighted', or give a checkpoint number

# Set the experiment directories
gnn_from_scratch_experiment_dir = os.path.join(SAVED_MODELS_PATH, dataset_to_evaluate, 'experiments', model_type_to_evaluate)
gnn_finetune_experiment_dir = os.path.join(SAVED_MODELS_PATH, dataset_to_evaluate, 'fine_tuned_experiments', model_type_to_evaluate)
baseline_experiment_dir = os.path.join(SAVED_MODELS_PATH, dataset_to_evaluate, 'baselines')

# Check if the experiment directories exist
os.makedirs(gnn_from_scratch_experiment_dir, exist_ok=True)
os.makedirs(gnn_finetune_experiment_dir, exist_ok=True)
os.makedirs(baseline_experiment_dir, exist_ok=True)

## 1) Testing routine for all GNNs from scratch ----------

for experiment_path in os.listdir(gnn_from_scratch_experiment_dir):
    # Check which experiments have already been tested. Skip if test results already exist
    experiment_path = os.path.join(gnn_from_scratch_experiment_dir, experiment_path)
    files_in_experiment = os.listdir(experiment_path)
    if 'test_set_results.pkl' in files_in_experiment:
        print(f'Skipping {experiment_path} as test results already exist')
        continue
    if len(files_in_experiment) == 0:
        print(f'Skipping {experiment_path} as no files in experiment')
        continue

    # Get Experiment Metadata
    with open(os.path.join(experiment_path, 'experiment_metadata.json'), 'r') as f:
        experiment_dict = json.load(f)

    # Get data according to the experiment
    test_data = data_processor.load_mixed_test(experiment_dict['dataset'])
    attack_mapping = data_processor.load_attack_mapping(experiment_dict['dataset'])
    if 'NF' in experiment_dict['dataset']:
        test_attrs, test_labels = data_processor.preprocess_NF(experiment_dict['dataset'], test_data, keep_IPs_and_timestamp=True, binary=False, remove_minority_labels=False, only_attacks=False, scale=True, truncate=experiment_dict['truncate'])
    else:
        raise ValueError('Unknown dataset name')

    # Get the graph list
    test_windows = graph_builder.time_window_with_flow_duration(test_attrs, experiment_dict['window_size'], experiment_dict['window_stride'])
    temporal = False
    features = test_attrs.columns
    features = [feat for feat in features if feat not in ['Dst IP', 'Dst Port', 'Flow Duration Graph Building', 'Src IP', 'Src Port', 'Timestamp']]
    if model_type_to_evaluate == 'temporal':
        temporal = True
        test_graphs, test_window_indices_for_classification = graph_builder.build_spatio_temporal_pyg_graphs(test_windows, test_attrs, test_labels, experiment_dict['window_memory'], experiment_dict['flow_memory'], experiment_dict['include_port'], features, attack_mapping)
    elif model_type_to_evaluate == 'static':
        test_graphs = graph_builder.build_static_pyg_graphs(test_windows, test_attrs, test_labels, experiment_dict['include_port'], features, attack_mapping)
    if experiment_dict['self_loops']:
        test_graphs = [AddSelfLoops()(graph) for graph in test_graphs]
    sample_graph = test_graphs[0]

    # Load the model
    if model_type_to_evaluate == 'temporal':
        gnn_base = gnn_architectures.TemporalPlusConv_v2(sample_graph.metadata(), experiment_dict['gnn_hidden_channels'], experiment_dict['gnn_layers'])
    elif model_type_to_evaluate == 'static':
        gnn_base = gnn_architectures.SAGE(sample_graph.metadata(), experiment_dict['gnn_hidden_channels'], experiment_dict['gnn_layers'])
    else:
        raise ValueError('Unknown GNN type')

    model = gnn_architectures.multiclass_NIDS_model(gnn_base, len(attack_mapping), experiment_dict['classifier_hidden_channels'], experiment_dict['classifier_layers'], temporal)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # lazy init
    with torch.no_grad():
        _ = model(test_graphs[0].x_dict, test_graphs[0].edge_index_dict)

    # Load the best model weights
    if weights_to_select == 'best_macro':
        best_model_weights = [f for f in os.listdir(experiment_path) if 'macro' in f][0]
    elif weights_to_select == 'best_weighted':
        best_model_weights = [f for f in os.listdir(experiment_path) if 'weighted' in f][0]
    else:
        best_model_weights = f'model_weights_checkpoint_epoch_{weights_to_select}.pth'
    best_model_weights = torch.load(os.path.join(experiment_path, best_model_weights), map_location=torch.device('cpu'))
    model.load_state_dict(best_model_weights)
    model.to(device)

    # Evaluate the model
    test_loader = DataLoader(test_graphs, batch_size=experiment_dict['batch_size'], shuffle=False)
    model.eval()
    test_preds = np.array([])
    test_targets = np.array([])
    for idx, batch in enumerate(test_loader):
        test_data = batch.to(device)
        with torch.no_grad():
            out = model(test_data.x_dict, test_data.edge_index_dict)
            preds = torch.argmax(out, dim=1)

            if idx == 0:
                test_probs = out.cpu().numpy()
            else:
                test_probs = np.concatenate((test_probs, out.cpu().numpy()), axis=0)
            test_preds = np.concatenate((test_preds, preds.cpu().numpy()), axis=0)
            test_targets = np.concatenate((test_targets, torch.argmax(test_data['con'].y, dim = 1).cpu().numpy()), axis=0)

    # Deduplicate the results if temporal model (cause in temporal model, we have reoccurig flows connected to different windows)
    if experiment_dict['gnn_type'] == 'temporal':
        test_preds, test_targets, test_probs = deduplicate_multiclass_sliding_window_results(test_preds, test_targets, test_probs, test_window_indices_for_classification)

    # Save test_preds, test_targets and test_probs in pickle file
    with open(os.path.join(experiment_path, 'test_set_results.pkl'), 'wb') as f:
        pickle.dump({'test_preds': test_preds, 'test_targets': test_targets, 'test_probs': test_probs}, f)

## 2) Testing routine for all baselines --------

# Change gnn_fine_tune_experiment_dir list to include all experiments including the k-shot ones that have subdirs
experiments_dirs_in_fine_tune = []
files_in_fine_tune = os.listdir(gnn_finetune_experiment_dir)
for f in files_in_fine_tune:
    if 'experiment_metadata.json' in os.listdir(os.path.join(gnn_finetune_experiment_dir, f)):
        experiments_dirs_in_fine_tune.append(os.path.join(gnn_finetune_experiment_dir, f))
        continue
    subdirs = os.listdir(os.path.join(gnn_finetune_experiment_dir, f))
    for subdir in subdirs:
        if 'experiment_metadata.json' in os.listdir(os.path.join(gnn_finetune_experiment_dir, f, subdir)):
            experiments_dirs_in_fine_tune.append(os.path.join(gnn_finetune_experiment_dir, f, subdir))

# NOw evaluate the fine-tuned models
for experiment_path in experiments_dirs_in_fine_tune:
    flow_memory = 20
    files_in_experiment = os.listdir(experiment_path)
    # Check if already a fine-tuned model (having metadata) or a k-shot learning directory having more submodules
    if 'experiment_metadata.json' in files_in_experiment:
        if 'test_set_results.pkl' in files_in_experiment:
            print(f'Skipping {experiment_path} as test results already exist')
            continue
        if len(files_in_experiment) == 0:
            print(f'Skipping {experiment_path} as no files in experiment')
            continue

    # Get Experiment Metadata
    with open(os.path.join(experiment_path, 'experiment_metadata.json'), 'r') as f:
        experiment_dict = json.load(f)

    # Get data according to the experiment
    test_data = data_processor.load_mixed_test(experiment_dict['dataset'])
    attack_mapping = data_processor.load_attack_mapping(experiment_dict['dataset'])

    cross_data_preprocessing = False
    if 'pretrain_strategy' in experiment_dict.keys():
        if experiment_dict["pretrain_strategy"] != 'in_context':
            cross_data_preprocessing = True

    if cross_data_preprocessing:
        test_attrs, test_labels = data_processor.preprocess_NF('all', test_data, keep_IPs_and_timestamp=True, binary=False, remove_minority_labels=False, only_attacks=False, scale=True, truncate=True)
    else:
        test_attrs, test_labels = data_processor.preprocess_NF(dataset_to_evaluate, test_data, keep_IPs_and_timestamp=True, binary=False, remove_minority_labels=False, only_attacks=False, scale=True, truncate=True)

    # Get the graph list
    test_windows = graph_builder.time_window_with_flow_duration(test_attrs, experiment_dict['window_size'], experiment_dict['window_stride'])
    features = test_attrs.columns
    features = [feat for feat in features if feat not in ['Dst IP', 'Dst Port', 'Flow Duration Graph Building', 'Src IP', 'Src Port', 'Timestamp']]
    if model_type_to_evaluate == 'temporal':
        temporal = True
        test_graphs, test_window_indices_for_classification = graph_builder.build_spatio_temporal_pyg_graphs(test_windows, test_attrs, test_labels, experiment_dict['window_memory'], flow_memory, False, features, attack_mapping, True)
    elif model_type_to_evaluate == 'static':
        temporal = False
        test_graphs = graph_builder.build_static_pyg_graphs(test_windows, test_attrs, test_labels, experiment_dict['include_port'], features, attack_mapping)
    if experiment_dict['self_loops']:
        test_graphs = [AddSelfLoops()(graph) for graph in test_graphs]
    sample_graph = test_graphs[0]

    # Load the model
    if model_type_to_evaluate == 'temporal':
        gnn_base = gnn_architectures.TemporalPlusConv_v2(sample_graph.metadata(), experiment_dict['gnn_hidden_channels'], experiment_dict['gnn_layers'])
    elif model_type_to_evaluate == 'static':
        gnn_base = gnn_architectures.SAGE(sample_graph.metadata(), experiment_dict['gnn_hidden_channels'], experiment_dict['gnn_layers'])
    else:
        raise ValueError('Unknown GNN type')

    model = gnn_architectures.multiclass_NIDS_model(gnn_base, len(attack_mapping), experiment_dict['classifier_hidden_channels'], experiment_dict['classifier_layers'], temporal)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # lazy init
    with torch.no_grad():
        _ = model(test_graphs[0].x_dict, test_graphs[0].edge_index_dict)

    # Load the best model weights
    best_model_weights = [f for f in os.listdir(experiment_path) if 'best_model_weights' in f][0]
    best_model_weights = torch.load(os.path.join(experiment_path, best_model_weights), map_location=torch.device('cpu'))
    model.load_state_dict(best_model_weights)
    model.to(device)

    # Evaluate the model
    test_loader = DataLoader(test_graphs, batch_size=experiment_dict['batch_size'])
    model.eval()
    test_preds = np.array([])
    test_targets = np.array([])
    for idx,batch in enumerate(test_loader):
        test_data = batch.to(device)
        with torch.no_grad():
            out = model(test_data.x_dict, test_data.edge_index_dict)
            preds = torch.argmax(out, dim=1)

            if idx == 0:
                test_probs = out.cpu().numpy()
            else:
                test_probs = np.concatenate((test_probs, out.cpu().numpy()), axis=0)

            test_preds = np.concatenate((test_preds, preds.cpu().numpy()), axis=0)
            test_targets = np.concatenate((test_targets, torch.argmax(test_data['con'].y, dim = 1).cpu().numpy()), axis=0)

    if experiment_dict['gnn_type'] == 'static':
        test_preds, test_targets, test_probs = deduplicate_multiclass_sliding_window_results(test_preds, test_targets, test_probs, test_window_indices_for_classification)

    # Save test_preds, test_targets and test_probs in pickle file
    with open(os.path.join(experiment_path, 'test_set_results.pkl'), 'wb') as f:
        pickle.dump({'test_preds': test_preds, 'test_targets': test_targets, 'test_probs': test_probs}, f)

## 3) Testing routine for all baselines --------

for experiment_path in os.listdir(baseline_experiment_dir):
    experiment_path = os.path.join(baseline_experiment_dir, experiment_path)
    files_in_experiment = os.listdir(experiment_path)
    if 'test_set_results.pkl' in files_in_experiment:
        print(f'Skipping {experiment_path} as test results already exist')
        continue
    if len(files_in_experiment) == 0:
        print(f'Skipping {experiment_path} as no files in experiment')
        continue

    # Get Experiment Metadata
    with open(os.path.join(experiment_path, 'experiment_metadata.json'), 'r') as f:
        experiment_dict = json.load(f)

    if experiment_dict['gnn_type'] == 'MLP':

        # Get data according to the experiment
        test_data = data_processor.load_mixed_test(experiment_dict['dataset'])
        attack_mapping = data_processor.load_attack_mapping(experiment_dict['dataset'])
        test_attrs, test_labels = data_processor.preprocess_NF(experiment_dict['dataset'], test_data, keep_IPs_and_timestamp=False, binary=False, remove_minority_labels=False, only_attacks=False, scale=True, truncate=True)
        
        # Setup model
        model = gnn_architectures.MLP(test_attrs.shape[1], experiment_dict['classifier_hidden_channels'], len(attack_mapping), experiment_dict['classifier_layers'])

        # Load the best model weights
        best_model_weights = [f for f in os.listdir(experiment_path) if 'best_model_weights' in f][0]
        best_model_weights = torch.load(os.path.join(experiment_path, best_model_weights), map_location=torch.device('cpu'))
        model.load_state_dict(best_model_weights)

        # Make sure attrs in alphabetical order and then in tensor form
        test_attrs = test_attrs[test_attrs.columns.sort_values()]
        test_attrs = torch.tensor(test_attrs.values, dtype=torch.float32)
        test_labels = test_labels.to_numpy()
        labels_torch_test = torch.Tensor([attack_mapping[attack] for attack in test_labels])
        tensor_dataset = torch.utils.data.TensorDataset(test_attrs.float(), labels_torch_test.float())
        test_loader = torch.utils.data.DataLoader(tensor_dataset, batch_size=experiment_dict['batch_size'], shuffle=False)

        # Evaluate the model
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model.to(device)
        model.eval()
        test_preds_list = np.array([])
        test_targets_list = np.array([])
        for idx, batch in enumerate(test_loader):
            test_data, test_targets = batch
            test_data, test_targets = test_data.to(device), test_targets.to(device)
            with torch.no_grad():
                out = torch.nn.functional.softmax(model(test_data), dim=1)
                preds = torch.argmax(out, dim=1).cpu().numpy()
                targets = torch.argmax(test_targets, dim = 1).cpu().numpy()

                if idx == 0:
                    test_probs_list = out.cpu().numpy()
                else:
                    test_probs_list = np.concatenate((test_probs_list, out.cpu().numpy()), axis=0)

                test_preds_list = np.concatenate((test_preds_list, preds), axis=0)
                test_targets_list = np.concatenate((test_targets_list, targets), axis=0)


        # Save test_preds, test_targets and test_probs in pickle file
        with open(os.path.join(experiment_path, 'test_set_results.pkl'), 'wb') as f:
            pickle.dump({'test_preds': test_preds_list, 'test_targets': test_targets_list, 'test_probs': test_probs_list}, f)

- Loading in mixed test set...
-- Min-Max scaling numerical columns...


100%|██████████| 541/541 [00:12<00:00, 41.70it/s]


- Loading in mixed test set...


  ohe[f'{attribute}_{value}'] = 0


-- Min-Max scaling numerical columns...


100%|██████████| 644/644 [00:19<00:00, 33.17it/s]


- Loading in mixed test set...
-- Min-Max scaling numerical columns...


100%|██████████| 644/644 [00:19<00:00, 33.32it/s]


- Loading in mixed test set...
-- Min-Max scaling numerical columns...


### Aggregate Test Set Predictions to Evaluation Metrics across All Models and Datasets

In [20]:
# Load pickle file 
datasets_to_evaluate = ['NF_ToN_IoT'] # Choose from 'NF_ToN_IoT', 'NF_BoT_IoT', 'NF_UNSW_NB15
model_types_to_evaluate = ['temporal'] # Choose from 'temporal' and 'static'
weights_to_select = 'best_macro' # Choose from 'best_macro', 'best_weighted', or give a checkpoint number

gnn_from_scratch_df = pd.DataFrame(columns=['dataset', 'model_type', 'window_size', 'window_memory', 'multiclass_acc', 'multiclass_f1_weighted', 'multiclass_f1_macro', 'multiclass_roc_auc_macro_ovr', 'multiclass_roc_auc_macro_ovo', 'multiclass_roc_auc_weighted_ovr', 'multiclass_roc_auc_weighted_ovo', 'binary_macro_f1', 'binary_weighted_f1'])
baselines_df = pd.DataFrame(columns=['dataset', 'model_type', 'multiclass_acc', 'multiclass_f1_weighted', 'multiclass_f1_macro', 'multiclass_roc_auc_macro_ovr', 'multiclass_roc_auc_macro_ovo', 'multiclass_roc_auc_weighted_ovr', 'multiclass_roc_auc_weighted_ovo', 'binary_macro_f1', 'binary_weighted_f1'])
fine_tuned_gnn_df = pd.DataFrame(columns=['dataset', 'model_type', 'window_size', 'window_memory', 'multiclass_acc', 'multiclass_f1_weighted', 'multiclass_f1_macro', 'multiclass_roc_auc_macro_ovr', 'multiclass_roc_auc_macro_ovo', 'multiclass_roc_auc_weighted_ovr', 'multiclass_roc_auc_weighted_ovo', 'binary_macro_f1', 'binary_weighted_f1'])
k_shot_learning_df = pd.DataFrame(columns=['dataset', 'model_type', 'k_shot_frac','pretrain_strategy', 'window_size', 'window_memory','multiclass_acc', 'multiclass_f1_weighted', 'multiclass_f1_macro', 'multiclass_roc_auc_macro_ovr', 'multiclass_roc_auc_macro_ovo', 'multiclass_roc_auc_weighted_ovr', 'multiclass_roc_auc_weighted_ovo', 'binary_macro_f1', 'binary_weighted_f1', 'best_train_macro_f1', 'best_train_weighted_f1', 'best_val_macro_f1', 'best_val_weighted_f1'])

for dataset_name in datasets_to_evaluate:
    
    for model_type_to_evaluate in model_types_to_evaluate:
        # Set the experiment directories
        gnn_from_scratch_experiment_dir = os.path.join(SAVED_MODELS_PATH, dataset_name, 'experiments', model_type_to_evaluate)
        gnn_finetune_experiment_dir = os.path.join(SAVED_MODELS_PATH, dataset_name, 'fine_tuned_experiments', model_type_to_evaluate)
        baseline_experiment_dir = os.path.join(SAVED_MODELS_PATH, dataset_name, 'baselines')

        # Check if the experiment directories exist
        os.makedirs(gnn_from_scratch_experiment_dir, exist_ok=True)
        os.makedirs(gnn_finetune_experiment_dir, exist_ok=True)
        os.makedirs(baseline_experiment_dir, exist_ok=True)

        for experiment_path in os.listdir(gnn_from_scratch_experiment_dir):
            experiment_path = os.path.join(gnn_from_scratch_experiment_dir, experiment_path)
            with open(f'{experiment_path}/experiment_metadata.json', 'r') as f:
                experiment_dict = json.load(f)

            if 'test_set_results.pkl' not in os.listdir(experiment_path):
                print(f'Skipping {experiment_path} with window size {experiment_dict["window_size"]} and window memory {experiment_dict["window_memory"]} as test results do not exist')
                continue
            
            with open(f'{experiment_path}/test_set_results.pkl', 'rb') as f:
                results = pickle.load(f)


            gnn_type, window_size, window_memory = experiment_dict['gnn_type'], experiment_dict['window_size'], experiment_dict['window_memory']
            if window_memory == 5:
                multiclass_acc, multiclass_f1_weighted, multiclass_f1_macro, multiclass_roc_auc_macro_ovr, multiclass_roc_auc_macro_ovo, multiclass_roc_auc_weighted_ovr, multiclass_roc_auc_weighted_ovo, binary_macro_f1, binary_weighted_f1 = calculate_multiclass_test_metrics(results['test_preds'], results['test_targets'], results['test_probs'])
                gnn_from_scratch_df = gnn_from_scratch_df.append({'dataset':dataset_name ,'model_type': gnn_type, 'window_size': window_size, 'window_memory': window_memory, 'multiclass_acc': multiclass_acc, 'multiclass_f1_weighted': multiclass_f1_weighted, 'multiclass_f1_macro': multiclass_f1_macro, 'multiclass_roc_auc_macro_ovr': multiclass_roc_auc_macro_ovr, 'multiclass_roc_auc_macro_ovo': multiclass_roc_auc_macro_ovo, 'multiclass_roc_auc_weighted_ovr': multiclass_roc_auc_weighted_ovr, 'multiclass_roc_auc_weighted_ovo': multiclass_roc_auc_weighted_ovo, 'binary_macro_f1': binary_macro_f1, 'binary_weighted_f1': binary_weighted_f1}, ignore_index=True)

        for experiment_path in os.listdir(baseline_experiment_dir):

            experiment_path = os.path.join(baseline_experiment_dir, experiment_path)
            with open(f'{experiment_path}/experiment_metadata.json', 'r') as f:
                experiment_dict = json.load(f)

            if 'test_set_results.pkl' not in os.listdir(experiment_path):
                print(f'Skipping {experiment_path} as test results do not exist')
                continue
            
            with open(f'{experiment_path}/test_set_results.pkl', 'rb') as f:
                results = pickle.load(f)

            multiclass_acc, multiclass_f1_weighted, multiclass_f1_macro, multiclass_roc_auc_macro_ovr, multiclass_roc_auc_macro_ovo, multiclass_roc_auc_weighted_ovr, multiclass_roc_auc_weighted_ovo, binary_macro_f1, binary_weighted_f1 = calculate_multiclass_test_metrics(results['test_preds'], results['test_targets'], results['test_probs'])
            baselines_df = baselines_df.append({'dataset':dataset_name ,'model_type': experiment_dict['gnn_type'], 'multiclass_acc': multiclass_acc, 'multiclass_f1_weighted': multiclass_f1_weighted, 'multiclass_f1_macro': multiclass_f1_macro, 'multiclass_roc_auc_macro_ovr': multiclass_roc_auc_macro_ovr, 'multiclass_roc_auc_macro_ovo': multiclass_roc_auc_macro_ovo, 'multiclass_roc_auc_weighted_ovr': multiclass_roc_auc_weighted_ovr, 'multiclass_roc_auc_weighted_ovo': multiclass_roc_auc_weighted_ovo, 'binary_macro_f1': binary_macro_f1, 'binary_weighted_f1': binary_weighted_f1}, ignore_index=True)
        
        # Change gnn_fine_tune_experiment_dir list to include all experiments including the k-shot ones that have subdirs
        # Change gnn_fine_tune_experiment_dir list to include all experiments including the k-shot ones that have subdirs
        experiments_dirs_in_fine_tune = []
        files_in_fine_tune = os.listdir(gnn_finetune_experiment_dir)
        for f in files_in_fine_tune:
            if 'experiment_metadata.json' in os.listdir(os.path.join(gnn_finetune_experiment_dir, f)):
                experiments_dirs_in_fine_tune.append(os.path.join(gnn_finetune_experiment_dir, f))
                continue
            subdirs = os.listdir(os.path.join(gnn_finetune_experiment_dir, f))
            for subdir in subdirs:
                if 'experiment_metadata.json' in os.listdir(os.path.join(gnn_finetune_experiment_dir, f, subdir)):
                    experiments_dirs_in_fine_tune.append(os.path.join(gnn_finetune_experiment_dir, f, subdir))
        for experiment_path in experiments_dirs_in_fine_tune:
        
            with open(f'{experiment_path}/experiment_metadata.json', 'r') as f:
                experiment_dict = json.load(f)

            k_shot_model = False
            if 'pretrain_strategy' in experiment_dict.keys():
                k_shot_model = True
                with open(f'{experiment_path}/results.json', 'r') as f:
                    train_results = json.load(f)

            if 'test_set_results.pkl' not in os.listdir(experiment_path):
                print(f'Skipping {experiment_path} with window size {experiment_dict["window_size"]} and window memory {experiment_dict["window_memory"]} as test results do not exist')
                continue
            
            with open(f'{experiment_path}/test_set_results.pkl', 'rb') as f:
                results = pickle.load(f)

            gnn_type, window_size, window_memory = experiment_dict['gnn_type'], experiment_dict['window_size'], experiment_dict['window_memory']

            if k_shot_model:
                k_shot_frac = experiment_dict['K-shot-dataset_frac']
                pretrain_strategy = experiment_dict['pretrain_strategy']
                multiclass_acc, multiclass_f1_weighted, multiclass_f1_macro, multiclass_roc_auc_macro_ovr, multiclass_roc_auc_macro_ovo, multiclass_roc_auc_weighted_ovr, multiclass_roc_auc_weighted_ovo, binary_macro_f1, binary_weighted_f1 = calculate_multiclass_test_metrics(results['test_preds'], results['test_targets'], results['test_probs'])
                best_train_macro_f1, best_val_macro_f1 = max(train_results['train_macro_f1']), max(train_results['val_macro_f1'])
                best_train_weighted_f1, best_val_weighted_f1 = max(train_results['train_weighted_f1']), max(train_results['val_weighted_f1'])

                k_shot_learning_df = k_shot_learning_df.append({'dataset':dataset_name ,'model_type': gnn_type, 'k_shot_frac': k_shot_frac, 'pretrain_strategy': pretrain_strategy, 'window_size': window_size, 'window_memory': window_memory, 'multiclass_acc': multiclass_acc, 'multiclass_f1_weighted': multiclass_f1_weighted, 'multiclass_f1_macro': multiclass_f1_macro, 'multiclass_roc_auc_macro_ovr': multiclass_roc_auc_macro_ovr, 'multiclass_roc_auc_macro_ovo': multiclass_roc_auc_macro_ovo, 'multiclass_roc_auc_weighted_ovr': multiclass_roc_auc_weighted_ovr, 'multiclass_roc_auc_weighted_ovo': multiclass_roc_auc_weighted_ovo, 'binary_macro_f1': binary_macro_f1, 'binary_weighted_f1': binary_weighted_f1, 'best_train_macro_f1': best_train_macro_f1, 'best_val_macro_f1': best_val_macro_f1, 'best_train_weighted_f1': best_train_weighted_f1, 'best_val_weighted_f1': best_val_weighted_f1}, ignore_index=True)
            
            else:
                multiclass_acc, multiclass_f1_weighted, multiclass_f1_macro, multiclass_roc_auc_macro_ovr, multiclass_roc_auc_macro_ovo, multiclass_roc_auc_weighted_ovr, multiclass_roc_auc_weighted_ovo, binary_macro_f1, binary_weighted_f1 = calculate_multiclass_test_metrics(results['test_preds'], results['test_targets'], results['test_probs'])

                fine_tuned_gnn_df = fine_tuned_gnn_df.append({'dataset':dataset_name ,'model_type': gnn_type, 'window_size': window_size, 'window_memory': window_memory, 'multiclass_acc': multiclass_acc, 'multiclass_f1_weighted': multiclass_f1_weighted, 'multiclass_f1_macro': multiclass_f1_macro, 'multiclass_roc_auc_macro_ovr': multiclass_roc_auc_macro_ovr, 'multiclass_roc_auc_macro_ovo': multiclass_roc_auc_macro_ovo, 'multiclass_roc_auc_weighted_ovr': multiclass_roc_auc_weighted_ovr, 'multiclass_roc_auc_weighted_ovo': multiclass_roc_auc_weighted_ovo, 'binary_macro_f1': binary_macro_f1, 'binary_weighted_f1': binary_weighted_f1}, ignore_index=True)

print('GNN from scratch results')
print(gnn_from_scratch_df)  
print('Baselines results')
print(baselines_df)
print('Fine-tuned GNN results')
print(fine_tuned_gnn_df)
print('K-Shot Learning results')
print(k_shot_learning_df)

GNN from scratch results
Empty DataFrame
Columns: [dataset, model_type, window_size, window_memory, multiclass_acc, multiclass_f1_weighted, multiclass_f1_macro, multiclass_roc_auc_macro_ovr, multiclass_roc_auc_macro_ovo, multiclass_roc_auc_weighted_ovr, multiclass_roc_auc_weighted_ovo, binary_macro_f1, binary_weighted_f1]
Index: []
Baselines results
      dataset model_type  multiclass_acc  multiclass_f1_weighted  \
0  NF_ToN_IoT        MLP        0.863127                0.885071   

   multiclass_f1_macro  multiclass_roc_auc_macro_ovr  \
0             0.756166                       0.96405   

   multiclass_roc_auc_macro_ovo  multiclass_roc_auc_weighted_ovr  \
0                      0.970319                         0.967359   

   multiclass_roc_auc_weighted_ovo  binary_macro_f1  binary_weighted_f1  
0                         0.968857         0.906766            0.970484  
Fine-tuned GNN results
Empty DataFrame
Columns: [dataset, model_type, window_size, window_memory, multiclass_acc,