# Imports

In [None]:
import copy
import json
import os
import pickle
import time

import numpy as np
import pandas as pd
import torch
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import AddSelfLoops
from tqdm import tqdm

import data_handling
from util_scripts import gnn_architectures
from util_scripts.gnn_training import train_batch, validate_batch, calculate_multiclass_metrics, calculate_multiclass_test_metrics, export_pretty_confusion_matrix, deduplicate_multiclass_sliding_window_results, balanced_temporal_undersampler

# Run Experiments

## Experiment 1: General GNNs from scratch Parameter Optimization

### Experiment Parameters

In [9]:
# --- Locations of data and saved artefacts (relative paths) ---
RAW_DATA_PATH = 'data/raw'
LANDED_DATA_PATH = 'data/landed'
INGESTED_DATA_PATH = 'data/ingested'
UTILS_PATH = 'data/utils'
SAVED_MODELS_PATH = 'saved_models'
CONFIG_PATH = 'configs'

# --- Dataset and architecture selection ---
dataset_name = 'NF_ToN_IoT'  # one of 'NF_ToN_IoT', 'NF_BoT_IoT', 'NF_UNSW_NB15'
truncate = True  # forwarded to preprocess_NF(truncate=...)
gnn_type = 'temporal'  # one of 'temporal', 'static'
temporal = gnn_type == 'temporal'  # boolean form of gnn_type, handed to the classifier model

# --- Optimisation settings ---
batch_size = 10  # number of graphs per DataLoader batch
num_epochs = 2
weighted_loss = True  # weight the cross-entropy loss per class

# --- Swept model options: list every value that should be tried ---
gnn_layer_options = [2]  # e.g. [2, 3]
window_size_options = [10, 1]  # e.g. [10, 30]
self_loops_options = [False]
save_epoch_every = 1  # a checkpoint is written every this many epochs

# --- Only used when gnn_type == 'temporal' ---
window_memory_options = [3, 5]  # e.g. [3, 5]
flow_memory = 20

### Experiment Runs

In [None]:
# Experiment 1 driver: for every combination of the swept options, build the graphs,
# train a GNN from scratch, and store checkpoints, the best weights (by macro and by
# weighted validation F1), the metric curves and the run metadata in a new
# experiment directory.
data_processor = data_handling.DataPreprocessor(INGESTED_DATA_PATH, UTILS_PATH)
graph_builder = data_handling.GraphBuilder()

attack_mapping = data_processor.load_attack_mapping(dataset_name)
train_raw = data_processor.load_mixed_train(dataset_name)
val_raw = data_processor.load_mixed_val(dataset_name)
train_attrs, train_labels = data_processor.preprocess_NF(dataset_name, train_raw, keep_IPs_and_timestamp=True, binary=False, remove_minority_labels=False, only_attacks=False, scale=True, truncate=truncate)
val_attrs, val_labels = data_processor.preprocess_NF(dataset_name, val_raw, keep_IPs_and_timestamp=True, binary=False, remove_minority_labels=False, only_attacks=False, scale=True, truncate=truncate)

# For this dataset we undersample benign flows a bit for performance. Only on the training set though, not on eval!
if dataset_name == 'NF_UNSW_NB15':
    train_attrs, train_labels = data_processor.randomly_drop_benign_flows(train_attrs, train_labels, 0.7)

for gnn_layers in gnn_layer_options:
    gnn_hidden_channels = 128
    classifier_layers = 2
    classifier_hidden_channels = 128

    for window_size in window_size_options:
        for window_memory in window_memory_options:
            window_stride = window_size  # stride equals the window size
            # batch_size = max(1,int((1/(window_size*window_memory))*200)+10) # A metric for having small enough batch size when windows are getting big
            for self_loops in self_loops_options:

                # Every column except the addressing / timing ones is passed on as a feature.
                features = [feat for feat in train_attrs.columns if feat not in ['Dst IP', 'Dst Port', 'Flow Duration Graph Building', 'Src IP', 'Src Port', 'Timestamp']]
                train_windows = graph_builder.time_window_with_flow_duration(train_attrs, window_size, window_stride)
                val_windows = graph_builder.time_window_with_flow_duration(val_attrs, window_size, window_stride)
                if gnn_type == 'temporal':
                    train_graphs, _ = graph_builder.build_spatio_temporal_pyg_graphs(train_windows, train_attrs, train_labels, window_memory, flow_memory, False, features, attack_mapping, True)
                    val_graphs, val_window_indices_for_classification = graph_builder.build_spatio_temporal_pyg_graphs(val_windows, val_attrs, val_labels, window_memory, flow_memory, False, features, attack_mapping, True)
                elif gnn_type == 'static':
                    train_graphs = graph_builder.build_static_pyg_graphs(train_windows, train_attrs, train_labels, False, features, attack_mapping, True)
                    val_graphs = graph_builder.build_static_pyg_graphs(val_windows, val_attrs, val_labels, False, features, attack_mapping, True)
                else:
                    # Fail before any graph-dependent code runs on undefined names.
                    raise ValueError('Unknown GNN type')

                if self_loops:
                    # Apply the swept option so the graphs match what experiment_metadata.json records.
                    add_self_loops = AddSelfLoops()
                    train_graphs = [add_self_loops(graph) for graph in train_graphs]
                    val_graphs = [add_self_loops(graph) for graph in val_graphs]

                metadata = train_graphs[0].metadata()
                sample_graph = train_graphs[0]

                if gnn_type == 'temporal':
                    gnn_base = gnn_architectures.TemporalSAGE(metadata, gnn_hidden_channels, gnn_layers)
                else:  # 'static' (any other value was rejected above)
                    gnn_base = gnn_architectures.SAGE(metadata, gnn_hidden_channels, gnn_layers)

                model = gnn_architectures.multiclass_NIDS_model(gnn_base, len(attack_mapping), classifier_hidden_channels, classifier_layers, temporal)

                # Each run gets the next free 'experiment_<n>' directory.
                model_dir = os.path.join(SAVED_MODELS_PATH, dataset_name, 'experiments', gnn_type)
                os.makedirs(model_dir, exist_ok=True)
                experiment_idx = len(os.listdir(model_dir))
                experiment_dir = os.path.join(model_dir, f'experiment_{experiment_idx}')
                os.makedirs(experiment_dir)
                experiment_dict = {
                    'dataset': dataset_name,
                    'gnn_type': gnn_type,
                    'gnn_layers': gnn_layers,
                    'gnn_hidden_channels': gnn_hidden_channels,
                    'classifier_layers': classifier_layers,
                    'classifier_hidden_channels': classifier_hidden_channels,
                    'self_loops': self_loops,
                    'window_size': window_size,
                    'window_stride': window_stride,
                    'include_port': False,
                    'window_memory': window_memory,
                    'batch_size': batch_size,
                    'num_epochs': num_epochs,
                    'weighted_loss': weighted_loss,
                    'truncate': truncate,
                    'flow_memory': flow_memory,
                }
                train_losses = []
                val_losses = []
                train_weighted_f1 = []
                val_weighted_f1 = []
                train_macro_f1 = []
                val_macro_f1 = []

                # Setup optimizer and data loaders
                optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
                loader = DataLoader(train_graphs, batch_size=batch_size, shuffle=True)
                loader_val = DataLoader(val_graphs, batch_size=batch_size, shuffle=False)

                # Set model to device
                device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
                model.to(device)

                # Setup Loss Criterion
                if weighted_loss:
                    num_classes = len(attack_mapping)
                    class_counts = np.zeros(num_classes)
                    for batch in loader:
                        # Summing the targets over dim 0 yields one count per class.
                        target = batch['con'].y
                        class_counts += target.sum(dim=0).cpu().numpy()

                    total_samples = class_counts.sum()
                    # Inverse-frequency weights. A class without training samples keeps
                    # weight 0 instead of the infinite weight a division by zero would give.
                    class_weights = np.zeros(num_classes)
                    present = class_counts > 0
                    class_weights[present] = total_samples / (num_classes * class_counts[present])
                    weights = torch.tensor(class_weights, dtype=torch.float32).to(device)
                    criterion = torch.nn.CrossEntropyLoss(weight=weights)
                else:
                    criterion = torch.nn.CrossEntropyLoss()

                # Train the model
                best_model_weights_macro = None
                best_model_weights_weighted = None
                best_model_val_f1_macro = 0
                best_model_val_f1_weighted = 0
                best_model_weights_macro_epoch = 0
                best_model_weights_weighted_epoch = 0

                for epoch in tqdm(range(num_epochs)):
                    total_train_loss = 0
                    epoch_preds = np.array([])
                    epoch_targets = np.array([])
                    for batch in loader:
                        train_data = batch.to(device)
                        batch_loss, batch_preds, batch_targets = train_batch(model, train_data, optimizer, criterion)
                        total_train_loss += batch_loss
                        epoch_preds = np.concatenate((epoch_preds, batch_preds))
                        epoch_targets = np.concatenate((epoch_targets, batch_targets))

                    # Calculate metrics
                    epoch_train_accuracy, epoch_train_f1_weighted, epoch_train_f1_macro = calculate_multiclass_metrics(epoch_preds, epoch_targets, attack_mapping)
                    train_losses.append(total_train_loss)
                    # .item() turns the metric into a plain Python number that json.dump can always serialise.
                    train_weighted_f1.append(epoch_train_f1_weighted.item())
                    train_macro_f1.append(epoch_train_f1_macro.item())
                    print('Epoch:', epoch, 'Train Loss:', total_train_loss, 'Train Accuracy:', epoch_train_accuracy.item(), 'Train Multiclass Weighted F1:', epoch_train_f1_weighted.item())

                    total_val_loss = 0
                    epoch_preds = np.array([])
                    epoch_targets = np.array([])
                    for batch in loader_val:
                        val_data = batch.to(device)
                        batch_loss, batch_preds, batch_targets = validate_batch(model, val_data, criterion)
                        total_val_loss += batch_loss
                        epoch_preds = np.concatenate((epoch_preds, batch_preds))
                        epoch_targets = np.concatenate((epoch_targets, batch_targets))

                    val_accuracy, val_f1_weighted, val_f1_macro = calculate_multiclass_metrics(epoch_preds, epoch_targets, attack_mapping)
                    val_losses.append(total_val_loss)
                    val_weighted_f1.append(val_f1_weighted.item())
                    val_macro_f1.append(val_f1_macro.item())
                    # state_dict() hands out references to the live parameter tensors, so the
                    # best weights must be deep-copied or later epochs would overwrite them.
                    if val_f1_macro > best_model_val_f1_macro:
                        best_model_val_f1_macro = val_f1_macro
                        best_model_weights_macro = copy.deepcopy(model.state_dict())
                        best_model_weights_macro_epoch = epoch
                    if val_f1_weighted > best_model_val_f1_weighted:
                        best_model_val_f1_weighted = val_f1_weighted
                        best_model_weights_weighted = copy.deepcopy(model.state_dict())
                        best_model_weights_weighted_epoch = epoch
                    if (epoch % save_epoch_every == 0) and (epoch != 0):
                        # Call state_dict(): saving the bound method would not store any weights.
                        torch.save(model.state_dict(), os.path.join(experiment_dir, f'model_weights_checkpoint_epoch_{epoch}.pth'))

                    print('Validation Loss:', total_val_loss, 'Validation Accuracy:', val_accuracy.item(), 'Validation Multiclass Weighted F1:', val_f1_weighted.item(), 'Validation Multiclass Macro F1:', val_f1_macro.item())

                # Save the best models
                torch.save(best_model_weights_macro, os.path.join(experiment_dir, f'best_model_weights_macro_f1_epoch_{best_model_weights_macro_epoch}.pth'))
                torch.save(best_model_weights_weighted, os.path.join(experiment_dir, f'best_model_weights_weighted_f1_epoch_{best_model_weights_weighted_epoch}.pth'))

                # Save the metric curves and the experiment metadata
                experiments_results = {'train_losses': train_losses, 'val_losses': val_losses, 'train_weighted_f1': train_weighted_f1, 'val_weighted_f1': val_weighted_f1, 'train_macro_f1': train_macro_f1, 'val_macro_f1': val_macro_f1}
                with open(os.path.join(experiment_dir, 'results.json'), 'w') as f:
                    json.dump(experiments_results, f)
                with open(os.path.join(experiment_dir, 'experiment_metadata.json'), 'w') as f:
                    json.dump(experiment_dict, f)

## Experiment 2: GNNs from pre-trained

#### Experiment Parameters

In [13]:
# Paths
RAW_DATA_PATH = 'data/raw'
LANDED_DATA_PATH = 'data/landed'
INGESTED_DATA_PATH = 'data/ingested'
UTILS_PATH = 'data/utils'
SAVED_MODELS_PATH = 'saved_models'
CONFIG_PATH = 'configs'

# General parameters
dataset_name = 'NF_ToN_IoT' # 'NF_UNSW_NB15', 'NF_ToN_IoT', 'NF_BoT_IoT'
pretraining_strategy = 'no_pretraining' # 'in_context','out_context','no_pretraining'
undersample_fracs = [0.05] # [0.1, 0.2, 0.5, 0.8, 1.0] # K-shot learning

# Pretrained model parameters (which one to select)
experiment_idx = 0
checkpoint_idx = 2

# Training parameters
batch_size = 10
num_epochs = 5
learning_rate = 0.001
weighted_loss = True
classifier_layers = 2
classifier_hidden_channels = 128
truncate = True # Truncate the extreme numerical values in standardization

# For if pretraining_strategy=='no_pretraining'. Else, ignore (we'll load the best model from the pretraining)
flow_memory = 20
window_size = 5
window_memory = 5
window_stride = 5
use_ports = False
# The run cell reads `include_port`; define it here so a 'no_pretraining' run does not
# depend on a value left behind by another cell.
include_port = use_ports
self_loops = False
# The run cell also reads `gnn_layers`. 2 matches the value swept in Experiment 1;
# adjust it to the depth you want to train.
gnn_layers = 2
gnn_hidden_channels = 128
gnn_type = 'temporal' # 'temporal', 'static'

#### Experiment Runs

In [None]:
# Experiment 2 driver: for every undersampling fraction, train the classifier on the
# reduced training set (optionally starting from a pre-trained GNN base) and store the
# best weights (by macro validation F1), the metric curves and the run metadata.
data_processor = data_handling.DataPreprocessor(INGESTED_DATA_PATH, UTILS_PATH)
graph_builder = data_handling.GraphBuilder()

if pretraining_strategy != 'no_pretraining':
    # Graph and model hyper-parameters come from the selected pre-training run.
    pre_trained_gnn_dir = f'{SAVED_MODELS_PATH}/{dataset_name}/pretraining_experiments/{gnn_type}/experiment_{experiment_idx}'
    weights_path = f'{pre_trained_gnn_dir}/checkpoint_{checkpoint_idx}_gnnbase.pt'
    metadata_path = f'{pre_trained_gnn_dir}/experiment_metadata.json'
    with open(metadata_path, 'r') as f:
        pretrain_metadata = json.load(f)

    gnn_type = pretrain_metadata["graph_type"]
    flow_memory = pretrain_metadata["flow_memory"]
    gnn_layers = pretrain_metadata["gnn_layer"]
    window_size = pretrain_metadata["window_size"]
    window_memory = pretrain_metadata["window_memory"]
    include_port = pretrain_metadata["include_port"]
    self_loops = pretrain_metadata["self_loops"]
    gnn_hidden_channels = pretrain_metadata["gnn_hidden_channels"]
    classifier_layers = pretrain_metadata["classifier_layers"]
    classifier_hidden_channels = pretrain_metadata["classifier_hidden_channels"]
    window_stride = window_size

# Derived for every strategy so the flag always matches the gnn_type actually in use.
temporal = gnn_type == 'temporal'

attack_mapping = data_processor.load_attack_mapping(dataset_name)
truncate = True
val_raw, test_raw = data_processor.load_mixed_val(dataset_name), data_processor.load_mixed_test(dataset_name)

# First argument of preprocess_NF: 'all' unless the pre-training was in-context.
# NOTE(review): presumably this selects which preprocessing statistics are used - confirm.
preprocess_source = 'all' if pretraining_strategy != 'in_context' else dataset_name

val_attrs, val_labels = data_processor.preprocess_NF(preprocess_source, val_raw, keep_IPs_and_timestamp=True, binary=False, remove_minority_labels=False, only_attacks=False, scale=True, truncate=truncate)
test_attrs, test_labels = data_processor.preprocess_NF(preprocess_source, test_raw, keep_IPs_and_timestamp=True, binary=False, remove_minority_labels=False, only_attacks=False, scale=True, truncate=truncate)

# Every column except the addressing / timing ones is passed on as a feature.
features = [feat for feat in val_attrs.columns if feat not in ['Dst IP', 'Dst Port', 'Flow Duration Graph Building', 'Src IP', 'Src Port', 'Timestamp']]
val_windows = graph_builder.time_window_with_flow_duration(val_attrs, window_size, window_stride)
if gnn_type == 'temporal':
    val_graphs, val_window_indices_for_classification = graph_builder.build_spatio_temporal_pyg_graphs(val_windows, val_attrs, val_labels, window_memory, flow_memory, include_port, features, attack_mapping, True)
elif gnn_type == 'static':
    val_graphs = graph_builder.build_static_pyg_graphs(val_windows, val_attrs, val_labels, include_port, features, attack_mapping, True)
else:
    raise ValueError('GNN type not recognized!')
if self_loops:
    # Rebind val_graphs itself; the transformed list was previously stored under a
    # different name and therefore never reached the DataLoader.
    val_graphs = [AddSelfLoops()(graph) for graph in val_graphs]

# All fractions of this run share one 'experiment_<n>' directory. A separate name is
# used so the `experiment_idx` parameter (which selects the pre-trained run) is not overwritten.
fine_tuned_root = os.path.join(SAVED_MODELS_PATH, dataset_name, 'fine_tuned_experiments', gnn_type)
os.makedirs(fine_tuned_root, exist_ok=True)
fine_tuned_experiment_idx = len(os.listdir(fine_tuned_root))

for undersample_frac in undersample_fracs:
    train_raw = data_processor.load_mixed_train(dataset_name)
    train_attrs, train_labels = data_processor.preprocess_NF(preprocess_source, train_raw, keep_IPs_and_timestamp=True, binary=False, remove_minority_labels=False, only_attacks=False, scale=True, truncate=truncate)

    # Undersample the training data for K-shot learning
    train_attrs, train_labels, practical_fraction = balanced_temporal_undersampler(train_attrs, train_labels, undersample_frac)

    features = [feat for feat in train_attrs.columns if feat not in ['Dst IP', 'Dst Port', 'Flow Duration Graph Building', 'Src IP', 'Src Port', 'Timestamp']]

    train_windows = graph_builder.time_window_with_flow_duration(train_attrs, window_size, window_stride)

    if gnn_type == 'temporal':
        train_graphs, _ = graph_builder.build_spatio_temporal_pyg_graphs(train_windows, train_attrs, train_labels, window_memory, flow_memory, include_port, features, attack_mapping, True)
    elif gnn_type == 'static':
        train_graphs = graph_builder.build_static_pyg_graphs(train_windows, train_attrs, train_labels, include_port, features, attack_mapping, True)
    else:
        raise ValueError('GNN type not recognized!')

    if self_loops:
        # Same fix as for the validation graphs: keep the transformed list.
        train_graphs = [AddSelfLoops()(graph) for graph in train_graphs]

    metadata = train_graphs[0].metadata()
    sample_graph = train_graphs[0]

    if gnn_type == 'temporal':
        gnn_base = gnn_architectures.TemporalSAGE(metadata, gnn_hidden_channels, gnn_layers)
    else:  # 'static' (any other value was rejected above)
        gnn_base = gnn_architectures.SAGE(metadata, gnn_hidden_channels, gnn_layers)

    if pretraining_strategy != 'no_pretraining':
        # Load the pre-trained model
        with torch.no_grad():  # Initialize lazy modules.
            out = gnn_base(sample_graph.x_dict, sample_graph.edge_index_dict)
        gnn_base.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
        print('Pretrained GNN loaded successfully!')

    # Pass `temporal` explicitly, as Experiment 1 does, so both experiments build the
    # classifier the same way for a given gnn_type.
    model = gnn_architectures.multiclass_NIDS_model(gnn_base, len(attack_mapping), classifier_hidden_channels, classifier_layers, temporal)

    train_losses = []
    val_losses = []
    train_weighted_f1 = []
    val_weighted_f1 = []
    train_macro_f1 = []
    val_macro_f1 = []

    # Setup optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Setup dataloader
    loader = DataLoader(train_graphs, batch_size=batch_size, shuffle=True)
    loader_val = DataLoader(val_graphs, batch_size=batch_size, shuffle=False)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Setup Loss Criterion
    if weighted_loss:
        num_classes = len(attack_mapping)
        class_counts = np.zeros(num_classes)
        for batch in loader:
            # Summing the targets over dim 0 yields one count per class.
            target = batch['con'].y
            class_counts += target.sum(dim=0).cpu().numpy()

        total_samples = class_counts.sum()
        # Inverse-frequency weights. Undersampling can remove a class entirely; such a
        # class keeps weight 0 instead of the infinite weight a division by zero would give.
        class_weights = np.zeros(num_classes)
        present = class_counts > 0
        class_weights[present] = total_samples / (num_classes * class_counts[present])
        weights = torch.tensor(class_weights, dtype=torch.float32).to(device)
        criterion = torch.nn.CrossEntropyLoss(weight=weights)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    # Set model to device
    model.to(device)

    # Train the model
    best_model_weights = None
    best_model_val_f1 = 0
    best_model_weights_epoch = 0

    # get the start time
    st = time.time()

    for epoch in tqdm(range(num_epochs)):
        total_train_loss = 0
        epoch_preds = np.array([])
        epoch_targets = np.array([])
        for batch in loader:
            train_data = batch.to(device)
            batch_loss, batch_preds, batch_targets = train_batch(model, train_data, optimizer, criterion)
            total_train_loss += batch_loss
            epoch_preds = np.concatenate((epoch_preds, batch_preds))
            epoch_targets = np.concatenate((epoch_targets, batch_targets))

        # Calculate metrics
        epoch_train_accuracy, epoch_train_f1_weighted, epoch_train_f1_macro = calculate_multiclass_metrics(epoch_preds, epoch_targets, attack_mapping)
        train_losses.append(total_train_loss)
        # .item() turns the metric into a plain Python number that json.dump can always serialise.
        train_weighted_f1.append(epoch_train_f1_weighted.item())
        train_macro_f1.append(epoch_train_f1_macro.item())
        print('Epoch:', epoch, 'Train Loss:', total_train_loss, 'Train Accuracy:', epoch_train_accuracy.item(), 'Train Multiclass Weighted F1:', epoch_train_f1_weighted.item(), 'Train Multiclass Macro F1:', epoch_train_f1_macro.item())

        total_val_loss = 0
        epoch_preds = np.array([])
        epoch_targets = np.array([])
        for batch in loader_val:
            val_data = batch.to(device)
            batch_loss, batch_preds, batch_targets = validate_batch(model, val_data, criterion)
            total_val_loss += batch_loss
            epoch_preds = np.concatenate((epoch_preds, batch_preds))
            epoch_targets = np.concatenate((epoch_targets, batch_targets))

        val_accuracy, val_f1_weighted, val_f1_macro = calculate_multiclass_metrics(epoch_preds, epoch_targets, attack_mapping)
        val_losses.append(total_val_loss)
        val_weighted_f1.append(val_f1_weighted.item())
        val_macro_f1.append(val_f1_macro.item())
        if val_f1_macro > best_model_val_f1:
            best_model_val_f1 = val_f1_macro
            # state_dict() hands out references to the live parameter tensors, so the
            # best weights must be deep-copied or later epochs would overwrite them.
            best_model_weights = copy.deepcopy(model.state_dict())
            best_model_weights_epoch = epoch
        print('Validation Loss:', total_val_loss, 'Validation Accuracy:', val_accuracy.item(), 'Validation Multiclass Weighted F1:', val_f1_weighted.item(), 'Validation Multiclass Macro F1:', val_f1_macro.item())

    # get the end time
    et = time.time()
    # get the execution time
    elapsed_time = et - st
    print('Execution time:', elapsed_time, 'seconds')

    experiment_dir = os.path.join(fine_tuned_root, f'experiment_{fine_tuned_experiment_idx}', str(undersample_frac))
    os.makedirs(experiment_dir, exist_ok=True)
    experiment_dict = {
        'dataset': dataset_name,
        'K-shot-dataset_frac': practical_fraction,
        'pretrain_strategy': pretraining_strategy,
        'gnn_type': gnn_type,
        'gnn_layers': gnn_layers,
        'gnn_hidden_channels': gnn_hidden_channels,
        'classifier_layers': classifier_layers,
        'classifier_hidden_channels': classifier_hidden_channels,
        'self_loops': self_loops,
        'window_size': window_size,
        'window_stride': window_stride,
        'window_memory': window_memory,
        'batch_size': batch_size,
        'num_epochs': num_epochs,
        'execution_time': elapsed_time,
    }

    # Save the best model
    torch.save(best_model_weights, os.path.join(experiment_dir, f'best_model_weights_epoch_{best_model_weights_epoch}.pth'))

    # Save the metric curves and the experiment metadata
    experiments_results = {'train_losses': train_losses, 'val_losses': val_losses, 'train_weighted_f1': train_weighted_f1, 'val_weighted_f1': val_weighted_f1, 'train_macro_f1': train_macro_f1, 'val_macro_f1': val_macro_f1}
    with open(os.path.join(experiment_dir, 'results.json'), 'w') as f:
        json.dump(experiments_results, f)
    with open(os.path.join(experiment_dir, 'experiment_metadata.json'), 'w') as f:
        json.dump(experiment_dict, f)

## Experiment 3: Reference MLP Baselines

#### Experiment Parameters

In [9]:
# --- Locations of data and saved artefacts (relative paths) ---
RAW_DATA_PATH = 'data/raw'
LANDED_DATA_PATH = 'data/landed'
INGESTED_DATA_PATH = 'data/ingested'
UTILS_PATH = 'data/utils'
SAVED_MODELS_PATH = 'saved_models'
CONFIG_PATH = 'configs'

# --- MLP baseline settings ---
dataset_name = 'NF_ToN_IoT'  # one of 'NF_ToN_IoT', 'NF_BoT_IoT', 'NF_UNSW_NB15'
batch_size = 128  # flows per DataLoader batch
num_epochs = 10
weighted_loss = True  # weight the cross-entropy loss per class
classifier_layers = 2
classifier_hidden_channels = 128

#### Experiment Run

In [None]:
# Experiment 3 driver: trains the reference MLP directly on the flow attributes
# (no graph) and stores the best weights (by macro validation F1), the metric curves
# and the run metadata in a new baseline experiment directory.
data_processor = data_handling.DataPreprocessor(INGESTED_DATA_PATH, UTILS_PATH)
graph_builder = data_handling.GraphBuilder()

attack_mapping = data_processor.load_attack_mapping(dataset_name)

train_raw, val_raw, test_raw = data_processor.load_mixed_train(dataset_name), data_processor.load_mixed_val(dataset_name), data_processor.load_mixed_test(dataset_name)

train_attrs, train_labels = data_processor.preprocess_NF(dataset_name, train_raw, keep_IPs_and_timestamp=False, binary=False, remove_minority_labels=False, only_attacks=False, scale=True, truncate=True)
val_attrs, val_labels = data_processor.preprocess_NF(dataset_name, val_raw, keep_IPs_and_timestamp=False, binary=False, remove_minority_labels=False, only_attacks=False, scale=True, truncate=True)
test_attrs, test_labels = data_processor.preprocess_NF(dataset_name, test_raw, keep_IPs_and_timestamp=False, binary=False, remove_minority_labels=False, only_attacks=False, scale=True, truncate=True)

# Pick the device first: the loss weights below are created on it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Setup model and optimizer
model = gnn_architectures.MLP(train_attrs.shape[1], classifier_hidden_channels, len(attack_mapping), classifier_layers)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Make sure attrs in alphabetical order and then in tensor form
train_attrs = train_attrs[train_attrs.columns.sort_values()]
val_attrs = val_attrs[val_attrs.columns.sort_values()]
test_attrs = test_attrs[test_attrs.columns.sort_values()]

# Each label is replaced by its attack_mapping entry. The code below sums these targets
# over dim 0 and takes argmax over dim 1, so every entry has to be a per-class vector.
train_attrs = torch.tensor(train_attrs.values, dtype=torch.float32)
train_labels = train_labels.to_numpy()
labels_torch = torch.Tensor([attack_mapping[attack] for attack in train_labels])

val_attrs = torch.tensor(val_attrs.values, dtype=torch.float32)
val_labels = val_labels.to_numpy()
labels_torch_val = torch.Tensor([attack_mapping[attack] for attack in val_labels])

test_attrs = torch.tensor(test_attrs.values, dtype=torch.float32)
test_labels = test_labels.to_numpy()
labels_torch_test = torch.Tensor([attack_mapping[attack] for attack in test_labels])

tensor_dataset = torch.utils.data.TensorDataset(train_attrs, labels_torch)
tensor_dataset_val = torch.utils.data.TensorDataset(val_attrs, labels_torch_val)

train_loader = torch.utils.data.DataLoader(tensor_dataset, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(tensor_dataset_val, batch_size=batch_size, shuffle=False)

# Setup Loss Criterion
if weighted_loss:
    num_classes = len(attack_mapping)
    # Per-class counts taken straight from the full target tensor; float64 keeps the
    # sum exact for large datasets.
    class_counts = labels_torch.sum(dim=0, dtype=torch.float64).numpy()

    total_samples = class_counts.sum()
    # Inverse-frequency weights. A class without training samples keeps weight 0
    # instead of the infinite weight a division by zero would give.
    class_weights = np.zeros(num_classes)
    present = class_counts > 0
    class_weights[present] = total_samples / (num_classes * class_counts[present])

    weights = torch.tensor(class_weights, dtype=torch.float32).to(device)
    criterion = torch.nn.CrossEntropyLoss(weight=weights)
else:
    criterion = torch.nn.CrossEntropyLoss()

# Set model to device
model.to(device)

# Train the model
best_model_weights = None
best_model_val_f1 = 0
best_model_weights_epoch = 0

train_losses = []
val_losses = []
train_weighted_f1 = []
val_weighted_f1 = []
train_macro_f1 = []
val_macro_f1 = []

for epoch in tqdm(range(num_epochs)):
    model.train()  # validation below switches to eval mode, so switch back every epoch
    total_train_loss = 0
    epoch_preds = np.array([])
    epoch_targets = np.array([])
    for batch in train_loader:
        train_data, train_targets = batch
        train_data, train_targets = train_data.to(device), train_targets.to(device)
        optimizer.zero_grad()
        batch_preds = model(train_data)
        batch_loss = criterion(batch_preds, train_targets)
        batch_loss.backward()
        optimizer.step()
        total_train_loss += batch_loss.item()
        epoch_preds = np.concatenate((epoch_preds, batch_preds.argmax(dim=1).cpu().detach().numpy()))
        epoch_targets = np.concatenate((epoch_targets, train_targets.argmax(dim=1).cpu().detach().numpy()))

    # Calculate metrics
    epoch_train_accuracy, epoch_train_f1_weighted, epoch_train_f1_macro = calculate_multiclass_metrics(epoch_preds, epoch_targets, attack_mapping)
    train_losses.append(total_train_loss)
    # .item() turns the metric into a plain Python number that json.dump can always serialise.
    train_weighted_f1.append(epoch_train_f1_weighted.item())
    train_macro_f1.append(epoch_train_f1_macro.item())
    print('Epoch:', epoch, 'Train Loss:', total_train_loss, 'Train Accuracy:', epoch_train_accuracy.item(), 'Train Multiclass Weighted F1:', epoch_train_f1_weighted.item(), 'Train Multiclass Macro F1:', epoch_train_f1_macro.item())

    total_val_loss = 0
    epoch_preds = np.array([])
    epoch_targets = np.array([])
    # Evaluate in eval mode and without autograd bookkeeping: no gradients are needed here.
    model.eval()
    with torch.no_grad():
        for batch in val_loader:
            val_data, val_targets = batch
            val_data, val_targets = val_data.to(device), val_targets.to(device)
            batch_preds = model(val_data)
            batch_loss = criterion(batch_preds, val_targets)
            total_val_loss += batch_loss.item()
            epoch_preds = np.concatenate((epoch_preds, batch_preds.argmax(dim=1).cpu().numpy()))
            epoch_targets = np.concatenate((epoch_targets, val_targets.argmax(dim=1).cpu().numpy()))

    val_accuracy, val_f1_weighted, val_f1_macro = calculate_multiclass_metrics(epoch_preds, epoch_targets, attack_mapping)
    val_losses.append(total_val_loss)
    val_weighted_f1.append(val_f1_weighted.item())
    val_macro_f1.append(val_f1_macro.item())
    if val_f1_macro > best_model_val_f1:
        best_model_val_f1 = val_f1_macro
        # state_dict() hands out references to the live parameter tensors, so the
        # best weights must be deep-copied or later epochs would overwrite them.
        best_model_weights = copy.deepcopy(model.state_dict())
        best_model_weights_epoch = epoch
    print('Validation Loss:', total_val_loss, 'Validation Accuracy:', val_accuracy.item(), 'Validation Multiclass Weighted F1:', val_f1_weighted.item(), 'Validation Multiclass Macro F1:', val_f1_macro.item())

# Each run gets the next free 'experiment_<n>' directory under the baselines folder.
baselines_dir = os.path.join(SAVED_MODELS_PATH, dataset_name, 'baselines')
os.makedirs(baselines_dir, exist_ok=True)
experiment_idx = len(os.listdir(baselines_dir))
experiment_dir = os.path.join(baselines_dir, f'experiment_{experiment_idx}')
os.makedirs(experiment_dir, exist_ok=True)

experiment_dict = {'dataset': dataset_name, 'gnn_type': 'MLP', 'gnn_layers': 0, 'gnn_hidden_channels': 0, 'classifier_layers': classifier_layers, 'classifier_hidden_channels': classifier_hidden_channels, 'batch_size': batch_size, 'num_epochs': num_epochs, 'weighted_loss': weighted_loss, 'truncate': True}

# Save the best model
torch.save(best_model_weights, os.path.join(experiment_dir, f'best_model_weights_epoch_{best_model_weights_epoch}.pth'))

# Save the metric curves and the experiment metadata
experiments_results = {'train_losses': train_losses, 'val_losses': val_losses, 'train_weighted_f1': train_weighted_f1, 'val_weighted_f1': val_weighted_f1, 'train_macro_f1': train_macro_f1, 'val_macro_f1': val_macro_f1}
with open(os.path.join(experiment_dir, 'results.json'), 'w') as f:
    json.dump(experiments_results, f)
with open(os.path.join(experiment_dir, 'experiment_metadata.json'), 'w') as f:
    json.dump(experiment_dict, f)

# Evaluation

### Get All Model Predictions On Test Set

In [None]:
# What to evaluate
dataset_to_evaluate = 'NF_ToN_IoT'  # Choose from 'NF_ToN_IoT', 'NF_BoT_IoT', 'NF_UNSW_NB15'
model_type_to_evaluate = 'temporal'  # Choose from 'static', 'temporal'
weights_to_select = 'best_macro'  # Choose from 'best_macro', 'best_weighted', or give a checkpoint number

# One results directory per experiment family, all below the chosen dataset's folder
evaluated_dataset_root = os.path.join(SAVED_MODELS_PATH, dataset_to_evaluate)
gnn_from_scratch_experiment_dir = os.path.join(evaluated_dataset_root, 'experiments', model_type_to_evaluate)
gnn_finetune_experiment_dir = os.path.join(evaluated_dataset_root, 'fine_tuned_experiments', model_type_to_evaluate)
baseline_experiment_dir = os.path.join(evaluated_dataset_root, 'baselines')

# Create any of the directories that does not exist yet, so they can be listed afterwards
for results_dir in (gnn_from_scratch_experiment_dir, gnn_finetune_experiment_dir, baseline_experiment_dir):
    os.makedirs(results_dir, exist_ok=True)

## 1) Testing routine for all GNNs from scratch ----------
# For every from-scratch experiment that has no stored test results yet: rebuild
# the test graphs described by the experiment metadata, load the selected
# weights, predict on the test set and pickle the predictions into the
# experiment folder as 'test_set_results.pkl'.

for experiment_name in os.listdir(gnn_from_scratch_experiment_dir):
    experiment_path = os.path.join(gnn_from_scratch_experiment_dir, experiment_name)

    # Stray files in the experiments folder are not experiments; calling
    # os.listdir on them would raise NotADirectoryError.
    if not os.path.isdir(experiment_path):
        print(f'Skipping {experiment_path} as it is not an experiment directory')
        continue

    # Check which experiments have already been tested. Skip if test results already exist
    files_in_experiment = os.listdir(experiment_path)
    if 'test_set_results.pkl' in files_in_experiment:
        print(f'Skipping {experiment_path} as test results already exist')
        continue
    if not files_in_experiment:
        print(f'Skipping {experiment_path} as no files in experiment')
        continue

    # Get Experiment Metadata
    with open(os.path.join(experiment_path, 'experiment_metadata.json'), 'r') as f:
        experiment_dict = json.load(f)

    # Get data according to the experiment
    test_data = data_processor.load_mixed_test(experiment_dict['dataset'])
    attack_mapping = data_processor.load_attack_mapping(experiment_dict['dataset'])
    if 'NF' in experiment_dict['dataset']:
        test_attrs, test_labels = data_processor.preprocess_NF(experiment_dict['dataset'], test_data, keep_IPs_and_timestamp=True, binary=False, remove_minority_labels=False, only_attacks=False, scale=True, truncate=experiment_dict['truncate'])
    else:
        raise ValueError('Unknown dataset name')

    # Get the graph list
    test_windows = graph_builder.time_window_with_flow_duration(test_attrs, experiment_dict['window_size'], experiment_dict['window_stride'])
    temporal = model_type_to_evaluate == 'temporal'
    # Everything except the addressing/timing columns is passed on as a feature
    features = test_attrs.columns
    features = [feat for feat in features if feat not in ['Dst IP', 'Dst Port', 'Flow Duration Graph Building', 'Src IP', 'Src Port', 'Timestamp']]
    if model_type_to_evaluate == 'temporal':
        test_graphs, test_window_indices_for_classification = graph_builder.build_spatio_temporal_pyg_graphs(test_windows, test_attrs, test_labels, experiment_dict['window_memory'], experiment_dict['flow_memory'], experiment_dict['include_port'], features, attack_mapping)
    elif model_type_to_evaluate == 'static':
        test_graphs = graph_builder.build_static_pyg_graphs(test_windows, test_attrs, test_labels, experiment_dict['include_port'], features, attack_mapping)
    else:
        # Without this branch an unknown type only surfaced later as a
        # NameError on the never-assigned test_graphs.
        raise ValueError('Unknown GNN type')
    if experiment_dict['self_loops']:
        test_graphs = [AddSelfLoops()(graph) for graph in test_graphs]
    sample_graph = test_graphs[0]

    # Load the model
    if model_type_to_evaluate == 'temporal':
        gnn_base = gnn_architectures.TemporalSAGE(sample_graph.metadata(), experiment_dict['gnn_hidden_channels'], experiment_dict['gnn_layers'])
    elif model_type_to_evaluate == 'static':
        gnn_base = gnn_architectures.SAGE(sample_graph.metadata(), experiment_dict['gnn_hidden_channels'], experiment_dict['gnn_layers'])
    else:
        raise ValueError('Unknown GNN type')

    model = gnn_architectures.multiclass_NIDS_model(gnn_base, len(attack_mapping), experiment_dict['classifier_hidden_channels'], experiment_dict['classifier_layers'], temporal)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # lazy init: one forward pass before the weights are loaded
    # (presumably so lazily-sized layers get their shapes - TODO confirm)
    with torch.no_grad():
        _ = model(test_graphs[0].x_dict, test_graphs[0].edge_index_dict)

    # Load the best model weights. 'best_macro' / 'best_weighted' pick the first
    # file whose name contains that keyword; anything else is treated as a
    # checkpoint epoch number.
    if weights_to_select == 'best_macro':
        best_model_weights = [f for f in os.listdir(experiment_path) if 'macro' in f][0]
    elif weights_to_select == 'best_weighted':
        best_model_weights = [f for f in os.listdir(experiment_path) if 'weighted' in f][0]
    else:
        best_model_weights = f'model_weights_checkpoint_epoch_{weights_to_select}.pth'
    best_model_weights = torch.load(os.path.join(experiment_path, best_model_weights), map_location=torch.device('cpu'))
    model.load_state_dict(best_model_weights)
    model.to(device)

    # Evaluate the model
    test_loader = DataLoader(test_graphs, batch_size=experiment_dict['batch_size'], shuffle=False)
    model.eval()
    # Per-batch outputs are collected and concatenated once after the loop
    # instead of re-copying the growing arrays on every batch. The leading empty
    # float64 array keeps the dtype that the previous incremental concatenation
    # onto np.array([]) produced for predictions and targets.
    pred_chunks = [np.array([])]
    target_chunks = [np.array([])]
    prob_chunks = []
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            out = model(batch.x_dict, batch.edge_index_dict)
            prob_chunks.append(out.cpu().numpy())
            pred_chunks.append(torch.argmax(out, dim=1).cpu().numpy())
            # The class index of each 'con' node is the argmax over its label vector
            target_chunks.append(torch.argmax(batch['con'].y, dim=1).cpu().numpy())
    test_preds = np.concatenate(pred_chunks, axis=0)
    test_targets = np.concatenate(target_chunks, axis=0)
    test_probs = np.concatenate(prob_chunks, axis=0)

    # Deduplicate the results if temporal model (because in the temporal model, reoccurring flows are connected to different windows)
    if experiment_dict['gnn_type'] == 'temporal':
        test_preds, test_targets, test_probs = deduplicate_multiclass_sliding_window_results(test_preds, test_targets, test_probs, test_window_indices_for_classification)

    # Save test_preds, test_targets and test_probs in pickle file
    with open(os.path.join(experiment_path, 'test_set_results.pkl'), 'wb') as f:
        pickle.dump({'test_preds': test_preds, 'test_targets': test_targets, 'test_probs': test_probs}, f)

## 2) Testing routine for all fine-tuned GNNs (including k-shot runs) --------
# For every fine-tuned experiment that has no stored test results yet: rebuild
# the test graphs, load the best weights, predict on the test set and pickle
# the predictions into the experiment folder as 'test_set_results.pkl'.

# Collect all fine-tuned experiment folders. A folder holding
# 'experiment_metadata.json' is an experiment itself; otherwise it is treated
# as a k-shot parent whose direct sub-folders are checked for metadata.
experiments_dirs_in_fine_tune = []
files_in_fine_tune = os.listdir(gnn_finetune_experiment_dir)
for entry in files_in_fine_tune:
    entry_path = os.path.join(gnn_finetune_experiment_dir, entry)
    # Plain files cannot be experiments and os.listdir would raise on them
    if not os.path.isdir(entry_path):
        continue
    if 'experiment_metadata.json' in os.listdir(entry_path):
        experiments_dirs_in_fine_tune.append(entry_path)
        continue
    for subdir in os.listdir(entry_path):
        subdir_path = os.path.join(entry_path, subdir)
        if os.path.isdir(subdir_path) and 'experiment_metadata.json' in os.listdir(subdir_path):
            experiments_dirs_in_fine_tune.append(subdir_path)

# Flow memory is fixed here instead of being read from the experiment metadata
flow_memory = 20

# Now evaluate the fine-tuned models
for experiment_path in experiments_dirs_in_fine_tune:
    # Every collected folder contains metadata, so only already-tested
    # experiments have to be skipped.
    files_in_experiment = os.listdir(experiment_path)
    if 'test_set_results.pkl' in files_in_experiment:
        print(f'Skipping {experiment_path} as test results already exist')
        continue

    # Get Experiment Metadata
    with open(os.path.join(experiment_path, 'experiment_metadata.json'), 'r') as f:
        experiment_dict = json.load(f)

    # Get data according to the experiment
    test_data = data_processor.load_mixed_test(experiment_dict['dataset'])
    attack_mapping = data_processor.load_attack_mapping(experiment_dict['dataset'])

    # Experiments with a 'pretrain_strategy' other than 'in_context' are
    # preprocessed under the name 'all' instead of the evaluated dataset.
    cross_data_preprocessing = experiment_dict.get('pretrain_strategy', 'in_context') != 'in_context'
    preprocessing_name = 'all' if cross_data_preprocessing else dataset_to_evaluate
    test_attrs, test_labels = data_processor.preprocess_NF(preprocessing_name, test_data, keep_IPs_and_timestamp=True, binary=False, remove_minority_labels=False, only_attacks=False, scale=True, truncate=True)

    # Get the graph list
    test_windows = graph_builder.time_window_with_flow_duration(test_attrs, experiment_dict['window_size'], experiment_dict['window_stride'])
    # Everything except the addressing/timing columns is passed on as a feature
    features = test_attrs.columns
    features = [feat for feat in features if feat not in ['Dst IP', 'Dst Port', 'Flow Duration Graph Building', 'Src IP', 'Src Port', 'Timestamp']]
    if model_type_to_evaluate == 'temporal':
        temporal = True
        test_graphs, test_window_indices_for_classification = graph_builder.build_spatio_temporal_pyg_graphs(test_windows, test_attrs, test_labels, experiment_dict['window_memory'], flow_memory, False, features, attack_mapping, True)
    elif model_type_to_evaluate == 'static':
        temporal = False
        test_graphs = graph_builder.build_static_pyg_graphs(test_windows, test_attrs, test_labels, experiment_dict['include_port'], features, attack_mapping)
    else:
        # Without this branch an unknown type only surfaced later as a
        # NameError on the never-assigned test_graphs.
        raise ValueError('Unknown GNN type')
    if experiment_dict['self_loops']:
        test_graphs = [AddSelfLoops()(graph) for graph in test_graphs]
    sample_graph = test_graphs[0]

    # Load the model
    if model_type_to_evaluate == 'temporal':
        gnn_base = gnn_architectures.TemporalSAGE(sample_graph.metadata(), experiment_dict['gnn_hidden_channels'], experiment_dict['gnn_layers'])
    elif model_type_to_evaluate == 'static':
        gnn_base = gnn_architectures.SAGE(sample_graph.metadata(), experiment_dict['gnn_hidden_channels'], experiment_dict['gnn_layers'])
    else:
        raise ValueError('Unknown GNN type')

    model = gnn_architectures.multiclass_NIDS_model(gnn_base, len(attack_mapping), experiment_dict['classifier_hidden_channels'], experiment_dict['classifier_layers'], temporal)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # lazy init: one forward pass before the weights are loaded
    # (presumably so lazily-sized layers get their shapes - TODO confirm)
    with torch.no_grad():
        _ = model(test_graphs[0].x_dict, test_graphs[0].edge_index_dict)

    # Load the best model weights (first file whose name contains 'best_model_weights')
    best_model_weights = [f for f in os.listdir(experiment_path) if 'best_model_weights' in f][0]
    best_model_weights = torch.load(os.path.join(experiment_path, best_model_weights), map_location=torch.device('cpu'))
    model.load_state_dict(best_model_weights)
    model.to(device)

    # Evaluate the model
    test_loader = DataLoader(test_graphs, batch_size=experiment_dict['batch_size'])
    model.eval()
    # Per-batch outputs are collected and concatenated once after the loop
    # instead of re-copying the growing arrays on every batch. The leading empty
    # float64 array keeps the dtype that the previous incremental concatenation
    # onto np.array([]) produced for predictions and targets.
    pred_chunks = [np.array([])]
    target_chunks = [np.array([])]
    prob_chunks = []
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            out = model(batch.x_dict, batch.edge_index_dict)
            prob_chunks.append(out.cpu().numpy())
            pred_chunks.append(torch.argmax(out, dim=1).cpu().numpy())
            # The class index of each 'con' node is the argmax over its label vector
            target_chunks.append(torch.argmax(batch['con'].y, dim=1).cpu().numpy())
    test_preds = np.concatenate(pred_chunks, axis=0)
    test_targets = np.concatenate(target_chunks, axis=0)
    test_probs = np.concatenate(prob_chunks, axis=0)

    # Deduplicate the results of temporal models, matching the from-scratch
    # routine. The window indices only exist for temporal graphs, so the former
    # check for 'static' never de-duplicated temporal results and referenced an
    # undefined (or stale) variable for static ones.
    if experiment_dict['gnn_type'] == 'temporal':
        test_preds, test_targets, test_probs = deduplicate_multiclass_sliding_window_results(test_preds, test_targets, test_probs, test_window_indices_for_classification)

    # Save test_preds, test_targets and test_probs in pickle file
    with open(os.path.join(experiment_path, 'test_set_results.pkl'), 'wb') as f:
        pickle.dump({'test_preds': test_preds, 'test_targets': test_targets, 'test_probs': test_probs}, f)

## 3) Testing routine for all baselines --------
# For every baseline experiment of type 'MLP' that has no stored test results
# yet: load the best weights, predict on the test set and pickle the
# predictions into the experiment folder as 'test_set_results.pkl'.

for baseline_name in os.listdir(baseline_experiment_dir):
    experiment_path = os.path.join(baseline_experiment_dir, baseline_name)
    files_in_experiment = os.listdir(experiment_path)

    # Nothing to do for experiments that are already tested or still empty
    if 'test_set_results.pkl' in files_in_experiment:
        print(f'Skipping {experiment_path} as test results already exist')
        continue
    if not files_in_experiment:
        print(f'Skipping {experiment_path} as no files in experiment')
        continue

    # Get Experiment Metadata
    metadata_file = os.path.join(experiment_path, 'experiment_metadata.json')
    with open(metadata_file, 'r') as f:
        experiment_dict = json.load(f)

    # Only MLP baselines are handled by this routine
    if experiment_dict['gnn_type'] != 'MLP':
        continue

    # Get data according to the experiment
    test_data = data_processor.load_mixed_test(experiment_dict['dataset'])
    attack_mapping = data_processor.load_attack_mapping(experiment_dict['dataset'])
    test_attrs, test_labels = data_processor.preprocess_NF(
        experiment_dict['dataset'],
        test_data,
        keep_IPs_and_timestamp=False,
        binary=False,
        remove_minority_labels=False,
        only_attacks=False,
        scale=True,
        truncate=True,
    )

    # Setup model: one input per attribute column, one output per attack class
    model = gnn_architectures.MLP(
        test_attrs.shape[1],
        experiment_dict['classifier_hidden_channels'],
        len(attack_mapping),
        experiment_dict['classifier_layers'],
    )

    # Load the best model weights (first file whose name contains 'best_model_weights')
    best_model_weights = [name for name in os.listdir(experiment_path) if 'best_model_weights' in name][0]
    best_model_weights = torch.load(
        os.path.join(experiment_path, best_model_weights),
        map_location=torch.device('cpu'),
    )
    model.load_state_dict(best_model_weights)

    # Make sure attrs are in alphabetical column order and then in tensor form
    # (presumably the column order used during training - TODO confirm)
    sorted_columns = test_attrs.columns.sort_values()
    test_attrs = test_attrs[sorted_columns]
    test_attrs = torch.tensor(test_attrs.values, dtype=torch.float32)
    test_labels = test_labels.to_numpy()
    # NOTE(review): the argmax over dim=1 further down implies that
    # attack_mapping yields one vector per label - confirm against the mapping file.
    labels_torch_test = torch.Tensor([attack_mapping[attack] for attack in test_labels])
    tensor_dataset = torch.utils.data.TensorDataset(test_attrs.float(), labels_torch_test.float())
    test_loader = torch.utils.data.DataLoader(
        tensor_dataset,
        batch_size=experiment_dict['batch_size'],
        shuffle=False,
    )

    # Evaluate the model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()
    test_preds_list = np.array([])
    test_targets_list = np.array([])
    for idx, (test_data, test_targets) in enumerate(test_loader):
        test_data = test_data.to(device)
        test_targets = test_targets.to(device)
        with torch.no_grad():
            # Softmax turns the model outputs into per-class probabilities
            out = torch.nn.functional.softmax(model(test_data), dim=1)
        preds = torch.argmax(out, dim=1).cpu().numpy()
        targets = torch.argmax(test_targets, dim=1).cpu().numpy()
        probs = out.cpu().numpy()

        # The first batch starts the probability matrix, later batches extend it
        test_probs_list = probs if idx == 0 else np.concatenate((test_probs_list, probs), axis=0)
        test_preds_list = np.concatenate((test_preds_list, preds), axis=0)
        test_targets_list = np.concatenate((test_targets_list, targets), axis=0)

    # Save test_preds, test_targets and test_probs in pickle file
    results_file = os.path.join(experiment_path, 'test_set_results.pkl')
    with open(results_file, 'wb') as f:
        pickle.dump({'test_preds': test_preds_list, 'test_targets': test_targets_list, 'test_probs': test_probs_list}, f)

### Aggregate Test Set Predictions to Evaluation Metrics across All Models and Datasets

In [None]:
# Aggregate the pickled test-set predictions of all selected experiments into
# one metrics table per model family and print the tables.
datasets_to_evaluate = ['NF_ToN_IoT']  # Choose from 'NF_ToN_IoT', 'NF_BoT_IoT', 'NF_UNSW_NB15'
model_types_to_evaluate = ['temporal']  # Choose from 'temporal' and 'static'
weights_to_select = 'best_macro'  # Choose from 'best_macro', 'best_weighted', or give a checkpoint number

# Column layouts of the result tables. metric_columns is in the order in which
# calculate_multiclass_test_metrics returns its values.
metric_columns = ['multiclass_acc', 'multiclass_f1_weighted', 'multiclass_f1_macro', 'multiclass_roc_auc_macro_ovr', 'multiclass_roc_auc_macro_ovo', 'multiclass_roc_auc_weighted_ovr', 'multiclass_roc_auc_weighted_ovo', 'binary_macro_f1', 'binary_weighted_f1']
gnn_columns = ['dataset', 'model_type', 'window_size', 'window_memory'] + metric_columns
baseline_columns = ['dataset', 'model_type'] + metric_columns
k_shot_columns = ['dataset', 'model_type', 'k_shot_frac', 'pretrain_strategy', 'window_size', 'window_memory'] + metric_columns + ['best_train_macro_f1', 'best_train_weighted_f1', 'best_val_macro_f1', 'best_val_weighted_f1']

# Rows are collected in plain lists and turned into DataFrames once at the end:
# DataFrame.append was deprecated in pandas 1.4 and removed in pandas 2.0.
gnn_from_scratch_rows = []
baseline_rows = []
fine_tuned_gnn_rows = []
k_shot_learning_rows = []


def _read_json(path):
    """Return the parsed content of the JSON file at *path*."""
    with open(path, 'r') as f:
        return json.load(f)


def _has_metadata(experiment_path):
    """Return True if *experiment_path* contains an 'experiment_metadata.json' file."""
    return os.path.isfile(os.path.join(experiment_path, 'experiment_metadata.json'))


def _test_metrics(experiment_path):
    """Return the test metrics of one experiment as a dict keyed by metric_columns.

    Reads 'test_set_results.pkl' from *experiment_path*. The file is written by
    the evaluation routines of this notebook; only unpickle files you trust.

    Raises:
        ValueError: if the number of returned metrics does not match metric_columns.
    """
    with open(os.path.join(experiment_path, 'test_set_results.pkl'), 'rb') as f:
        results = pickle.load(f)
    values = tuple(calculate_multiclass_test_metrics(results['test_preds'], results['test_targets'], results['test_probs']))
    if len(values) != len(metric_columns):
        raise ValueError(f'Expected {len(metric_columns)} metrics but got {len(values)}')
    return dict(zip(metric_columns, values))


for dataset_name in datasets_to_evaluate:

    # Baselines do not depend on the GNN model type, so they are read once per
    # dataset; inside the model-type loop every baseline row would be repeated
    # for each entry of model_types_to_evaluate.
    baseline_experiment_dir = os.path.join(SAVED_MODELS_PATH, dataset_name, 'baselines')
    os.makedirs(baseline_experiment_dir, exist_ok=True)

    for experiment_name in os.listdir(baseline_experiment_dir):
        experiment_path = os.path.join(baseline_experiment_dir, experiment_name)
        # Stray files and empty experiment folders carry no metadata
        if not _has_metadata(experiment_path):
            print(f'Skipping {experiment_path} as experiment metadata does not exist')
            continue
        experiment_dict = _read_json(os.path.join(experiment_path, 'experiment_metadata.json'))

        if 'test_set_results.pkl' not in os.listdir(experiment_path):
            print(f'Skipping {experiment_path} as test results do not exist')
            continue

        baseline_rows.append({'dataset': dataset_name, 'model_type': experiment_dict['gnn_type'], **_test_metrics(experiment_path)})

    for model_type_to_evaluate in model_types_to_evaluate:
        # Set the experiment directories
        gnn_from_scratch_experiment_dir = os.path.join(SAVED_MODELS_PATH, dataset_name, 'experiments', model_type_to_evaluate)
        gnn_finetune_experiment_dir = os.path.join(SAVED_MODELS_PATH, dataset_name, 'fine_tuned_experiments', model_type_to_evaluate)

        # Create missing experiment directories so that listing them cannot fail
        os.makedirs(gnn_from_scratch_experiment_dir, exist_ok=True)
        os.makedirs(gnn_finetune_experiment_dir, exist_ok=True)

        # --- GNNs trained from scratch ---
        for experiment_name in os.listdir(gnn_from_scratch_experiment_dir):
            experiment_path = os.path.join(gnn_from_scratch_experiment_dir, experiment_name)
            # Stray files and empty experiment folders carry no metadata
            if not _has_metadata(experiment_path):
                print(f'Skipping {experiment_path} as experiment metadata does not exist')
                continue
            experiment_dict = _read_json(os.path.join(experiment_path, 'experiment_metadata.json'))

            if 'test_set_results.pkl' not in os.listdir(experiment_path):
                print(f'Skipping {experiment_path} with window size {experiment_dict["window_size"]} and window memory {experiment_dict["window_memory"]} as test results do not exist')
                continue

            gnn_type, window_size, window_memory = experiment_dict['gnn_type'], experiment_dict['window_size'], experiment_dict['window_memory']
            # NOTE(review): only experiments with window_memory == 5 are
            # reported here - confirm that this filter is intended.
            if window_memory != 5:
                continue
            gnn_from_scratch_rows.append({'dataset': dataset_name, 'model_type': gnn_type, 'window_size': window_size, 'window_memory': window_memory, **_test_metrics(experiment_path)})

        # --- Fine-tuned GNNs ---
        # Collect all fine-tuned experiment folders. A folder holding
        # 'experiment_metadata.json' is an experiment itself; otherwise it is
        # treated as a k-shot parent whose direct sub-folders are checked.
        experiments_dirs_in_fine_tune = []
        for entry in os.listdir(gnn_finetune_experiment_dir):
            entry_path = os.path.join(gnn_finetune_experiment_dir, entry)
            # Plain files cannot be experiments and os.listdir would raise on them
            if not os.path.isdir(entry_path):
                continue
            if _has_metadata(entry_path):
                experiments_dirs_in_fine_tune.append(entry_path)
                continue
            for subdir in os.listdir(entry_path):
                subdir_path = os.path.join(entry_path, subdir)
                if _has_metadata(subdir_path):
                    experiments_dirs_in_fine_tune.append(subdir_path)

        for experiment_path in experiments_dirs_in_fine_tune:
            experiment_dict = _read_json(os.path.join(experiment_path, 'experiment_metadata.json'))

            if 'test_set_results.pkl' not in os.listdir(experiment_path):
                print(f'Skipping {experiment_path} with window size {experiment_dict["window_size"]} and window memory {experiment_dict["window_memory"]} as test results do not exist')
                continue

            gnn_type, window_size, window_memory = experiment_dict['gnn_type'], experiment_dict['window_size'], experiment_dict['window_memory']
            row = {'dataset': dataset_name, 'model_type': gnn_type, 'window_size': window_size, 'window_memory': window_memory, **_test_metrics(experiment_path)}

            # Metadata with a 'pretrain_strategy' entry marks a k-shot model;
            # those rows also report the best train/validation F1 scores
            # stored in results.json.
            if 'pretrain_strategy' in experiment_dict:
                train_results = _read_json(os.path.join(experiment_path, 'results.json'))
                row['k_shot_frac'] = experiment_dict['K-shot-dataset_frac']
                row['pretrain_strategy'] = experiment_dict['pretrain_strategy']
                row['best_train_macro_f1'] = max(train_results['train_macro_f1'])
                row['best_val_macro_f1'] = max(train_results['val_macro_f1'])
                row['best_train_weighted_f1'] = max(train_results['train_weighted_f1'])
                row['best_val_weighted_f1'] = max(train_results['val_weighted_f1'])
                k_shot_learning_rows.append(row)
            else:
                fine_tuned_gnn_rows.append(row)

# Passing columns= fixes the column order and keeps the header for empty tables
gnn_from_scratch_df = pd.DataFrame(gnn_from_scratch_rows, columns=gnn_columns)
baselines_df = pd.DataFrame(baseline_rows, columns=baseline_columns)
fine_tuned_gnn_df = pd.DataFrame(fine_tuned_gnn_rows, columns=gnn_columns)
k_shot_learning_df = pd.DataFrame(k_shot_learning_rows, columns=k_shot_columns)

print('GNN from scratch results')
print(gnn_from_scratch_df)
print('Baselines results')
print(baselines_df)
print('Fine-tuned GNN results')
print(fine_tuned_gnn_df)
print('K-Shot Learning results')
print(k_shot_learning_df)