In [None]:
import torch
import torch.optim as optim
from models import create_dynamic_neural_network, create_dynamic_dm_neural_network
from train import train_adversarial
from plotting import plot_training_metrics
from dataloader_ids import load_and_prepare_data

import os
import time
import json

In [None]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":
    # Let cuDNN benchmark candidate kernels and cache the fastest one per input shape.
    torch.backends.cudnn.benchmark = True

print(device)

# CONFIG

In [None]:
# Training methods to sweep. Each name is lowercased in the training loop;
# 'BASELINE' is run without any method-specific parameters.
ADVERSARIAL_TRAINING_MODES = [
    'BASELINE',
    'PGD',
    'MART',
    'TRADES',
]

# DATASETS and MULTICLASS are paired by position (zipped in the training loop):
# 'mirai' -> multiclass=False, 'unsw-nb15' -> multiclass=True.
DATASETS = ['mirai', 'unsw-nb15']
MULTICLASS = [False, True]

# Input encodings to train on; 'DM' and 'Stats' are additionally swept over
# constrain=True/False by the training loop.
ENCODINGS = ['DM', 'Stats', 'Raw']

# Upper bound on training epochs per run (early stopping may end a run sooner).
NUM_EPOCHS = 100

# TRAINING

In [None]:
# Attack hyperparameters, indexed as MODE_PARAMS[dataset][encoding][method].
# All (dataset, encoding) pairs currently share the same settings, so the table
# is generated; each innermost dict literal is re-evaluated per pair, so every
# entry is an independent dict. To tune one pair, override it after this cell.
#
# NOTE(review): PGD uses eps=0.3 while MART/TRADES use epsilon=0.03, a 10x
# difference in perturbation budget between methods — confirm this is intended
# before comparing their robustness.
MODE_PARAMS = {
    dataset_name: {
        encoding_name: {
            "pgd": {"eps": 0.3, "alpha": 0.01, "iters": 40},
            "mart": {"step_size": 0.003, "epsilon": 0.03, "perturb_steps": 10, "beta": 6.0},
            "trades": {"step_size": 0.003, "epsilon": 0.03, "perturb_steps": 10, "beta": 6.0},
        }
        for encoding_name in ("dm", "stats", "raw")
    }
    for dataset_name in ("mirai", "unsw-nb15")
}


def get_method_params(method, dataset, encoding):
    """Look up the hyperparameter dict for one adversarial-training method.

    All three keys are matched case-insensitively (they are lowercased first).

    Args:
        method: training method name, e.g. "PGD", "MART", "TRADES".
        dataset: dataset name, e.g. "mirai".
        encoding: encoding name, e.g. "DM".

    Returns:
        The dict stored in MODE_PARAMS (not a copy).

    Raises:
        ValueError: if the dataset, the encoding, or the method has no entry.
    """
    method_key = method.lower()
    dataset_key = dataset.lower()
    encoding_key = encoding.lower()

    if dataset_key not in MODE_PARAMS:
        raise ValueError(f"Dataset '{dataset_key}' is not configured in MODE_PARAMS.")
    by_encoding = MODE_PARAMS[dataset_key]

    if encoding_key not in by_encoding:
        raise ValueError(f"Encoding '{encoding_key}' is not configured for dataset '{dataset_key}'.")
    by_method = by_encoding[encoding_key]

    if method_key not in by_method:
        raise ValueError(f"Mode '{method_key}' is not configured for dataset '{dataset_key}' and encoding '{encoding_key}'.")
    return by_method[method_key]

In [None]:
def save_model(dataset, encoding, method, constrain, model_timestamp, model):
    if encoding in ['DM', 'Stats']:
        model_save_path = f'trained_models/{dataset}/{encoding}/{method}_{constrain}/'
    else:
        model_save_path = f'trained_models/{dataset}/{encoding}/{method}/'
    os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
    torch.save(model.state_dict(), model_save_path + f'{model_timestamp}.pth')

def save_training_results(dataset, encoding, method, constrain, model_timestamp, training_results):
    if encoding in ['DM', 'Stats']:
        results_save_path = f'results/training/{dataset}/{encoding}/{method}_{constrain}/'
    else:
        results_save_path = f'results/training/{dataset}/{encoding}/{method}/'
    os.makedirs(os.path.dirname(results_save_path), exist_ok=True)
    with open(results_save_path + f'{model_timestamp}.json', 'w') as f:
        json.dump(training_results, f)

In [None]:
def extract_hidden_dims(params):
    """Collect hidden-layer widths from a hyperparameter dict, ordered by layer index.

    Args:
        params: mapping that may contain keys "hidden_dim_layer_<i>" alongside
            unrelated hyperparameters (e.g. "lr", "dropout", "batch_size").

    Returns:
        List of the values of the "hidden_dim_layer_<i>" keys, ordered by the
        integer <i> (so layer 10 comes after layer 2). Keys whose suffix is not
        an integer are placed after the numeric ones, in string order.
        Returns [] when there are no such keys.
    """
    prefix = "hidden_dim_layer_"

    def layer_order(key):
        suffix = key[len(prefix):]
        # Sort by integer index: plain string sorting would put "10" before "2".
        if suffix.isdecimal():
            return (0, int(suffix), suffix)
        return (1, 0, suffix)

    layer_keys = sorted((key for key in params if key.startswith(prefix)), key=layer_order)
    return [params[key] for key in layer_keys]

In [None]:
# Encodings swept over both constrain=True and constrain=False; every other
# encoding is trained once with constrain=False. (What `constrain` does is
# defined by train_adversarial.)
CONSTRAINED_ENCODINGS = ('DM', 'Stats')
# Early-stopping patience passed to train_adversarial (presumably in epochs — confirm).
EARLY_STOPPING_PATIENCE = 25


def _constrain_values(encoding):
    """Return the `constrain` settings to train for `encoding`."""
    return [True, False] if encoding in CONSTRAINED_ENCODINGS else [False]


def _load_best_hparams(dataset, encoding):
    """Load the best hyperparameters recorded by model discovery.

    Reads results/model_discovery/<dataset>/<encoding>/best_params.json and
    returns (hidden_dims, learning_rate, dropout_rate, batch_size).
    """
    best_params_fp = f"results/model_discovery/{dataset}/{encoding}/best_params.json"
    with open(best_params_fp, "r") as f:
        best_params = json.load(f)
    return (
        extract_hidden_dims(best_params),
        best_params["lr"],
        best_params["dropout"],
        best_params["batch_size"],
    )


# zip() below silently drops unpaired entries (and total_iterations would then be
# wrong), so require DATASETS and MULTICLASS to line up.
if len(DATASETS) != len(MULTICLASS):
    raise ValueError(
        f"DATASETS ({len(DATASETS)}) and MULTICLASS ({len(MULTICLASS)}) must have the same length."
    )

# One run per (dataset, encoding, constrain setting, training mode).
total_iterations = len(DATASETS) * len(ADVERSARIAL_TRAINING_MODES) * sum(
    len(_constrain_values(enc)) for enc in ENCODINGS
)
iteration_counter = 0

for dataset, multiclass in zip(DATASETS, MULTICLASS):
    # Lowercase once: data loading, method params and output paths all use this key.
    dataset = dataset.lower()

    for encoding in ENCODINGS:
        print(f'DATASET: {dataset}, encoding: {encoding}')
        hidden_dims, learning_rate, dropout_rate, batch_size = _load_best_hparams(dataset, encoding)

        train_loader, val_loader, _, _, input_dim, output_dim, y_mapping, _ = load_and_prepare_data(
            dataset_key=dataset,
            encoding_key=encoding,
            multiclass=multiclass,
            batch_size=batch_size,
        )
        print("Data loaded and DataLoader created successfully!")
        print(f'Input dimension: {input_dim}, Output dimension: {output_dim}')
        # The second loader is the validation split, so report it as such.
        print(f'Shapes: Train: {len(train_loader.dataset)}, Val: {len(val_loader.dataset)}')

        # 'DM' has its own model factory; every other encoding uses the generic one.
        build_network = create_dynamic_dm_neural_network if encoding == 'DM' else create_dynamic_neural_network

        for constrain in _constrain_values(encoding):
            for mode in ADVERSARIAL_TRAINING_MODES:
                iteration_counter += 1
                method = mode.lower()
                # 'baseline' has no MODE_PARAMS entry and takes no extra kwargs.
                method_params = {} if method == 'baseline' else get_method_params(method, dataset, encoding)

                print(f' ------ [{iteration_counter}/{total_iterations}]: ADVERSARIAL TRAINING MODE: {method}, Constrain: {constrain}')

                # Build a fresh model/optimizer per run so no weights or optimizer
                # state carry over between runs.
                model, criterion, optimizer = build_network(
                    input_dim=input_dim,
                    output_dim=output_dim,
                    multiclass=multiclass,
                    hidden_dims=hidden_dims,
                    optimizer="adam",
                    lr=learning_rate,
                    dropout_rate=dropout_rate,
                )
                model = model.to(device)
                print(model)

                training_results, model = train_adversarial(
                    model=model,
                    train_loader=train_loader,
                    val_loader=val_loader,
                    optimizer=optimizer,
                    criterion=criterion,
                    device=device,
                    encoding=encoding,
                    constrain=constrain,
                    num_epochs=NUM_EPOCHS,
                    patience=EARLY_STOPPING_PATIENCE,
                    method=method,
                    verbose=False,
                    **method_params
                )

                model_timestamp = time.strftime("%Y%m%d-%H%M%S")
                save_model(dataset, encoding, method, constrain, model_timestamp, model)
                save_training_results(dataset, encoding, method, constrain, model_timestamp, training_results)
                plot_training_metrics(dataset, encoding, method, constrain, model_timestamp, training_results)
