In [1]:
# Test dataset loading

In [1]:
from pathlib import Path

from torch.utils.data import DataLoader

from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

import lightning as L

import mlflow
import mlflow.pytorch
from mlflow import MlflowClient

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import sys
sys.path.append('./datasets')
sys.path.append('./models')

from UNSW_NB15 import UNSWNB15Dataset
from UNSW_NB15 import load_UNSWNB15Dataset, split_UNSWNB15Dataset
from SimpleAE import SimpleAE

In [4]:
data_dir = Path("../data/UNSW-NB15/preprocessed")
data = load_UNSWNB15Dataset(data_dir)

In [5]:
pipeline = Pipeline([
    ('imputer', SimpleImputer(strategy='median').set_output(transform='pandas')),
    ('scaler', StandardScaler().set_output(transform='pandas'))
])

train_dataset, val_dataset, test_dataset = split_UNSWNB15Dataset( 
                                                        data_dir = data_dir,
                                                        data = data,
                                                        records_num = {"train" : 1e5, "val" : 1e4, "test" : 1e5},
                                                        normal_records_num = {"train" : 1.0, "val" : 0.5, "test" : 0.5},
                                                        transformer = pipeline,
                                                        random_state = 42)

In [6]:
train_loader = DataLoader(train_dataset, batch_size=16384, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1024, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)

In [8]:
for x, y, attack_cat in test_loader:
    print(x)
    print(y)
    print(attack_cat)
    break

tensor([[-1.9093e-01, -3.8237e-01, -2.4498e-01,  ..., -3.1623e-03,
         -1.5144e-01,  0.0000e+00],
        [-1.9093e-01, -3.8237e-01, -2.4498e-01,  ..., -3.1623e-03,
         -1.5144e-01,  0.0000e+00],
        [ 2.1202e-01,  1.4153e+00,  6.3807e+00,  ..., -3.1623e-03,
         -1.5144e-01,  0.0000e+00],
        ...,
        [-1.9062e-01, -3.8082e-01, -2.4399e-01,  ..., -3.1623e-03,
         -1.5144e-01,  0.0000e+00],
        [-1.9093e-01, -3.6785e-01, -2.4498e-01,  ..., -3.1623e-03,
         -1.5144e-01,  0.0000e+00],
        [-1.9093e-01, -3.8237e-01, -2.4498e-01,  ..., -3.1623e-03,
         -1.5144e-01,  0.0000e+00]])
tensor([1, 1, 0,  ..., 0, 0, 1])
tensor([ 9,  9, 13,  ..., 13, 13,  9])


In [9]:
mlflow.set_tracking_uri("http://127.0.0.1:8080")

mlflow.set_experiment("Config test")

mlflow.pytorch.autolog()




In [27]:


input_size = train_dataset[0][0].shape[0]
print("Input dim:", input_size)

model = SimpleAE(input_size, 8)

trainer = L.Trainer(accelerator='gpu', max_epochs=10)

with mlflow.start_run() as run:
    trainer.fit(model, train_loader, val_loader)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type       | Params | Mode 
-----------------------------------------------
0 | encoder | Sequential | 35.0 K | train
1 | decoder | Sequential | 35.2 K | train
-----------------------------------------------
70.2 K    Trainable params
0         Non-trainable params
70.2 K    Total params
0.281     Total estimated model params size (MB)
12        Modules in train mode
0         Modules in eval mode


Input dim: 204


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=150` reached.


In [26]:
print("threshold:", model.threshold)

model.threshold = 0.15

trainer.validate(model, test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
2025/04/02 22:44:09 INFO mlflow.utils.autologging_utils: Created MLflow autologging run with ID 'e2d91111aca54cadafa816a490fe3bad', which will track hyperparameters, performance metrics, model artifacts, and lineage information for the current pytorch workflow
2025/04/02 22:44:09 INFO mlflow.utils.autologging_utils: Created MLflow autologging run with ID '033c9f01ca9046a5b4531dbe81e2e8ce', which will track hyperparameters, performance metrics, model artifacts, and lineage information for the current pytorch workflow
2025/04/02 22:44:09 INFO mlflow.utils.autologging_utils: Created MLflow autologging run with ID '484a22cfabab4408824dc6fbc89dfe51', which will track hyperparameters, performance metrics, model artifacts, and lineage information for the current pytorch workflow


threshold: 0.13


2025/04/02 22:44:09 INFO mlflow.utils.autologging_utils: Created MLflow autologging run with ID '04cb0d2583aa4297a3ecbe718e4e1fd5', which will track hyperparameters, performance metrics, model artifacts, and lineage information for the current pytorch workflow


Validation: |          | 0/? [00:00<?, ?it/s]

2025/04/02 22:44:09 INFO mlflow.utils.autologging_utils: Created MLflow autologging run with ID '83ad609c86484e219c474cb7a37e6ab3', which will track hyperparameters, performance metrics, model artifacts, and lineage information for the current pytorch workflow
2025/04/02 22:44:09 INFO mlflow.utils.autologging_utils: Created MLflow autologging run with ID '353acea74d2447b58032e051e959a7d7', which will track hyperparameters, performance metrics, model artifacts, and lineage information for the current pytorch workflow
2025/04/02 22:44:09 INFO mlflow.utils.autologging_utils: Created MLflow autologging run with ID '2df2eccfb9c74d7693823ce79c5431c5', which will track hyperparameters, performance metrics, model artifacts, and lineage information for the current pytorch workflow
2025/04/02 22:44:09 INFO mlflow.utils.autologging_utils: Created MLflow autologging run with ID 'a832226ba15949ab8bab42b5b1d8c8d4', which will track hyperparameters, performance metrics, model artifacts, and lineage i

[{'val_loss': -2.0989444255828857,
  'val_accuracy': 0.9485999941825867,
  'val_precision': 0.9434559345245361,
  'val_recall': 0.9544000029563904,
  'positive_rate': 0.5058000087738037}]