## Imports

In [None]:
import data_handling
import os
import torch
import json
from util_scripts import gnn_architectures
from util_scripts.link_prediction_pretraining import train_model, load_link_prediction_model, process_data_and_build_graphs, get_metadata_and_sample_graph, process_data_and_build_out_of_context_graphs

## Set hyperparameters

In [5]:
# Set paths (all relative to the notebook's current working directory)
current_dir = os.getcwd()
DATA_PATH = f'{current_dir}/data/ingested'
UTILS_PATH = f'{current_dir}/data/utils'
SAVED_MODELS_PATH = f'{current_dir}/saved_models'
CONFIG_PATH = f'{current_dir}/configs'
SAVED_GRAPHS_PATH = f'{current_dir}/data/saved_graphs'

# General hyperparameters
# Choose from ['NF_ToN_IoT', 'NF_UNSW_NB15', 'NF_BoT_IoT', 'all'].
# 'all' is only usable with 'in_context': 'out_context' pre-trains on the two
# datasets other than this one, so it needs a single concrete dataset here.
dataset_name = 'NF_ToN_IoT'
truncate = True  # NOTE(review): not referenced in the visible cells; confirm it is still needed
graph_type = 'temporal'  # Choose from ['static', 'temporal']
pre_training_strategy = 'out_context'  # Choose from ['in_context', 'out_context']

# Model hyperparameters
gnn_layers = 2  # e.g. 2, 3
gnn_hidden_channels = 128
window_size = 5  # values tried include 30, 20, 10, 5, 2, 1, 0.5
graph_building = 'normal'  # 'connect_flows', 'normal' (defined once here; the duplicate assignment was removed)
include_port = False  # True, False
self_loops = False
window_memory = 5  # e.g. 3, 5
classifier_layers = 2
classifier_hidden_channels = 128

# Specific hyperparameters for temporal graph + model type
flow_memory = 20

# Training hyperparameters
num_epochs = 3
batch_size = 4  # Keep low enough to not run out of memory (especially with BoT_IoT involved)
save_epoch_every = 1
learning_rate = 0.001
continue_training_from_checkpoint = False

# Specific for out_context pre-training:
# reuse the last saved graph list for the two non-target datasets if it exists.
# The pretraining cell compares its saved metadata JSON against the current
# graph hyperparameters and rebuilds the list when they differ.
use_last_saved_graphs = True

# Only used if continue_training_from_checkpoint is True
# NOTE(review): hard-coded Google Drive path; adjust for the local environment.
checkpoint_dir = '/content/drive/MyDrive/BNN-UPC/pre_training/saved_models/all/pretraining_experiments/TemporalPlus_v2/experiment_5'
checkpoint_epoch = 20  # Epoch of the checkpoint to resume from; training restarts at checkpoint_epoch + 1

## Pretraining

In [None]:
# Datasets available for pre-training; 'out_context' trains on all of them except the target.
SUPPORTED_DATASETS = ['NF_BoT_IoT', 'NF_ToN_IoT', 'NF_UNSW_NB15']


def _next_experiment_dir(model_dir):
    """Return the path of the first unused ``experiment_<idx>`` directory inside *model_dir*.

    Starts from the number of existing entries (the original numbering scheme) and
    skips forward past any index that is already taken, so gaps left by deleted
    experiments or unrelated files in *model_dir* cannot cause a name collision.
    """
    idx = len(os.listdir(model_dir))
    while os.path.exists(os.path.join(model_dir, f'experiment_{idx}')):
        idx += 1
    return os.path.join(model_dir, f'experiment_{idx}')


def _saved_graphs_are_reusable(graphs_path, metadata_path, expected_metadata):
    """Return True if the saved graph list at *graphs_path* may be reused.

    The list is reusable when the file exists and either its metadata JSON matches
    *expected_metadata*, or no metadata file exists (legacy files: reused as before,
    with a warning). Prints the reason whenever an existing list is rejected.
    """
    if not os.path.exists(graphs_path):
        return False
    if not os.path.exists(metadata_path):
        print(f'Warning: no metadata found at {metadata_path}; reusing {graphs_path} without checking its hyperparameters.')
        return True
    with open(metadata_path) as f:
        saved_metadata = json.load(f)
    if saved_metadata != expected_metadata:
        print(f'Saved graph list at {graphs_path} was built with {saved_metadata}, which differs from the current settings {expected_metadata}.')
        return False
    return True


# Validate the configuration before any side effect (data loading, model
# construction, creating an experiment directory).
if pre_training_strategy not in ('in_context', 'out_context'):
    raise ValueError(f'No valid pretraining strategy specified ({pre_training_strategy!r}). Choose from "in_context" or "out_context"')
if graph_type not in ('temporal', 'static'):
    raise ValueError('Unknown GNN type')
if pre_training_strategy == 'out_context' and dataset_name not in SUPPORTED_DATASETS:
    # out_context pre-trains on every dataset except the target, so the target must be one concrete dataset (not e.g. 'all').
    raise ValueError(f'out_context pre-training needs a single target dataset from {SUPPORTED_DATASETS}, got {dataset_name!r}')

data_preprocessor = data_handling.DataPreprocessor(DATA_PATH, UTILS_PATH)
graph_builder = data_handling.GraphBuilder()

if continue_training_from_checkpoint:
    # Resume inside the checkpoint's experiment directory.
    experiment_dir = checkpoint_dir
else:
    # Otherwise create a fresh, uniquely numbered experiment directory.
    model_dir = os.path.join(SAVED_MODELS_PATH, dataset_name, 'pretraining_experiments', graph_type)
    os.makedirs(model_dir, exist_ok=True)
    experiment_dir = _next_experiment_dir(model_dir)
    os.makedirs(experiment_dir)

# Build an initial graph to obtain the graph sample and metadata used to initialise the model.
if pre_training_strategy == 'out_context':
    # Any dataset works here: the sample is only used for initialisation and is
    # standardised across all datasets, so NF_ToN_IoT is used as a fixed dummy.
    attack_mapping = data_preprocessor.load_attack_mapping('NF_ToN_IoT')
    all_train_files_and_indices = data_preprocessor.get_all_train_files_and_indices('NF_ToN_IoT')
    graph_metadata, graph_sample, features = get_metadata_and_sample_graph(True, 'NF_ToN_IoT', data_preprocessor, graph_builder, all_train_files_and_indices[0][1], graph_type, window_size, window_memory, include_port, attack_mapping)
else:
    attack_mapping = data_preprocessor.load_attack_mapping(dataset_name)
    all_train_files_and_indices = data_preprocessor.get_all_train_files_and_indices(dataset_name)
    graph_metadata, graph_sample, features = get_metadata_and_sample_graph(False, dataset_name, data_preprocessor, graph_builder, all_train_files_and_indices[0][1], graph_type, window_size, window_memory, include_port, attack_mapping)

# Initialize model (graph_type was validated above)
starting_epoch = 0
if graph_type == 'temporal':
    gnn_base = gnn_architectures.TemporalPlusConv_v2(graph_metadata, gnn_hidden_channels, gnn_layers)
else:  # 'static'
    gnn_base = gnn_architectures.SAGE(graph_metadata, gnn_hidden_channels, gnn_layers)

model = gnn_architectures.LinkPredictionModel(gnn_base)

if continue_training_from_checkpoint:
    model = load_link_prediction_model(model, checkpoint_dir, checkpoint_epoch, graph_metadata, graph_type, graph_sample)
    # NOTE(review): train_model receives epoch_end=num_epochs; confirm whether that is an
    # absolute end epoch, since resuming past num_epochs would then train nothing.
    starting_epoch = checkpoint_epoch + 1

# Save the experiment metadata (identical for both strategies)
experiments_metadata = {'pre_training_strategy': pre_training_strategy, 'graph_type': graph_type, 'flow_memory': flow_memory, 'gnn_layer': gnn_layers, 'window_size': window_size, 'window_memory': window_memory, 'include_port': include_port, 'self_loops': self_loops, 'gnn_hidden_channels': gnn_hidden_channels, 'classifier_layers': classifier_layers, 'classifier_hidden_channels': classifier_hidden_channels}
with open(os.path.join(experiment_dir, 'experiment_metadata.json'), 'w') as f:
    json.dump(experiments_metadata, f)

if pre_training_strategy == 'in_context':
    print('Training on curated training set')

    train_graphs = process_data_and_build_graphs(False, dataset_name, data_preprocessor, graph_builder, graph_type, window_size, window_memory, include_port, attack_mapping, flow_memory, on_curated_train=True)
    train_model(model, train_graphs, graph_metadata, graph_type, learning_rate, batch_size, experiment_dir, save_epoch_every, starting_epoch=starting_epoch, epoch_end=num_epochs)

else:  # 'out_context'
    print("Training on unlabeled training sets that are not the dataset itself (the target dataset on which we'll later fine-tune)")

    # Order follows SUPPORTED_DATASETS so saved file names stay stable.
    datasets = [name for name in SUPPORTED_DATASETS if name != dataset_name]
    graphs_stem = os.path.join(SAVED_GRAPHS_PATH, f'{datasets[0]}_{datasets[1]}_link_prediction_training_graph_list')
    graphs_path = f'{graphs_stem}.pt'
    graphs_metadata_path = f'{graphs_stem}_metadata.json'
    # Hyperparameters that determine how the saved graph list was built.
    graphs_metadata = {'graph_type': graph_type, 'window_size': window_size, 'window_memory': window_memory, 'include_port': include_port, 'flow_memory': flow_memory}

    if use_last_saved_graphs and _saved_graphs_are_reusable(graphs_path, graphs_metadata_path, graphs_metadata):
        # NOTE(review): since PyTorch 2.6, torch.load defaults to weights_only=True, which may
        # refuse pickled graph objects; pass weights_only=False (only for files you created) if loading fails.
        mixed_graphs = torch.load(graphs_path)
    else:
        print(f'No reusable saved graph list of mixed datasets {datasets[0]} and {datasets[1]} found. Building new graph list and saving it...')
        mixed_graphs = process_data_and_build_out_of_context_graphs(datasets, data_preprocessor, graph_builder, graph_type, window_size, window_memory, include_port, flow_memory, idx=0)
        os.makedirs(SAVED_GRAPHS_PATH, exist_ok=True)
        torch.save(mixed_graphs, graphs_path)
        print(f'Graph list of mixed datasets {datasets[0]} and {datasets[1]} built and saved at {graphs_path}')
        print(f'Saving hyperparameters and metadata of the saved pretrain graphs at {graphs_metadata_path}')
        with open(graphs_metadata_path, 'w') as f:
            json.dump(graphs_metadata, f)

    train_model(model, mixed_graphs, graph_metadata, graph_type, learning_rate, batch_size, experiment_dir, save_epoch_every, starting_epoch=starting_epoch, epoch_end=num_epochs)

- Loading in all training set file nr 0...


  ohe[f'{attribute}_{value}'] = 0


-- Min-Max scaling numerical columns...


100%|██████████| 10/10 [00:00<00:00, 10.70it/s]


Training on unlabeled training sets that are not the dataset itself (the target dataset on which we'll later fine-tune)


 33%|███▎      | 1/3 [02:31<05:02, 151.20s/it]

Epoch 0 -- Training Loss: 288.2467613220215, Validation Loss: 288.10518860816956, validation Acc: 0.5150745370262246


 67%|██████▋   | 2/3 [03:56<01:52, 112.62s/it]

Epoch 1 -- Training Loss: 286.9778376221657, Validation Loss: 286.7381873726845, validation Acc: 0.5304013743675012


100%|██████████| 3/3 [05:29<00:00, 109.89s/it]

Epoch 2 -- Training Loss: 281.8711902499199, Validation Loss: 281.4733741879463, validation Acc: 0.5768741382111748



