In [1]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import joblib
import tensorflow as tf
from gmvae import GMVAE

# ==========================
# Load and preprocess the NSL-KDD dataset
# ==========================
data = pd.read_csv('./dataset/kdd_full_clean_5classes.csv')
print('Original shape:', data.shape)

X = data.drop(columns=['target'])
y = data['target']

# Apply fitted transformer
tdt = joblib.load("./typed_nslkdd_all_features.pkl")
X_encoded = tdt.transform(data)

# Train-test split
x_train, x_test, y_train, y_test = train_test_split(
    X_encoded, y, test_size=0.3, random_state=42, stratify=y
)

# Encode labels
encoder_label = LabelEncoder()
y_train_encoded = encoder_label.fit_transform(y_train)

# ==========================
# Training parameters
# ==========================
latent_dim = 128
batch_size = 128
epochs = 30
lambda_kl = 0.05
input_dim = x_train.shape[1]
n_classes = len(np.unique(y_train_encoded))

# ==========================
# Prepare GMVAE and dataset
# ==========================
model = GMVAE(
    input_dim=input_dim,
    latent_dim=latent_dim,
    n_classes=n_classes,
    lambda_kl=lambda_kl
)

# TensorFlow dataset
X_tf = tf.convert_to_tensor(x_train, dtype=tf.float32)
Y_tf = tf.convert_to_tensor(y_train_encoded, dtype=tf.int32)
train_dataset = tf.data.Dataset.from_tensor_slices((X_tf, Y_tf)).shuffle(1000).batch(batch_size)

# ==========================
# Train the model
# ==========================
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
model.train(train_dataset, optimizer=optimizer, epochs=epochs)

# Trigger a dummy forward pass to ensure the model is 'built'
_ = model(tf.convert_to_tensor(x_train[:1], dtype=tf.float32))

# ==========================
# Save model and learned parameters
# ==========================
model.save_weights("./models_training/gmvae_weights.weights.h5")


2025-04-16 12:12:33.198717: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-16 12:12:33.213337: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744798353.229667   84882 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744798353.234586   84882 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1744798353.247742   84882 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

Original shape: (148517, 42)


I0000 00:00:1744798358.059028   84882 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 1768 MB memory:  -> device: 0, name: NVIDIA RTX A1000 Laptop GPU, pci bus id: 0000:01:00.0, compute capability: 8.6



Epoch 1/30


Batch:  99%|█████████▉| 808/813 [00:18<00:00, 55.85it/s, Loss=3.99, Recon=3.58, Class=0.05, KL=7.05]      2025-04-16 12:12:56.810432: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
                                                                                                    

→ Epoch 01 | Total Loss: 36818.3 | Recon: 16425.5 | Class: 5848.1 | KL: 290896.7

Epoch 2/30


Batch: 100%|█████████▉| 810/813 [00:14<00:00, 60.45it/s, Loss=3.89, Recon=3.74, Class=0.04, KL=2.25]     2025-04-16 12:13:11.552756: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
                                                                                                    

→ Epoch 02 | Total Loss: 22537.1 | Recon: 15953.1 | Class: 2500.0 | KL: 81679.3

Epoch 3/30


                                                                                                         

→ Epoch 03 | Total Loss: 21111.0 | Recon: 15945.1 | Class: 2018.2 | KL: 62954.6

Epoch 4/30


Batch:  99%|█████████▉| 807/813 [00:14<00:00, 52.81it/s, Loss=4.03, Recon=3.90, Class=0.04, KL=2.00]     2025-04-16 12:13:39.985001: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
                                                                                                    

→ Epoch 04 | Total Loss: 20389.9 | Recon: 15936.4 | Class: 1771.8 | KL: 53633.5

Epoch 5/30


                                                                                                         

→ Epoch 05 | Total Loss: 19679.2 | Recon: 15931.1 | Class: 1466.3 | KL: 45635.2

Epoch 6/30


                                                                                                         

→ Epoch 06 | Total Loss: 19296.3 | Recon: 15928.2 | Class: 1309.4 | KL: 41175.3

Epoch 7/30


                                                                                                         

→ Epoch 07 | Total Loss: 19325.3 | Recon: 15922.5 | Class: 1322.0 | KL: 41616.7

Epoch 8/30


Batch: 100%|██████████| 813/813 [00:14<00:00, 57.94it/s, Loss=9.05, Recon=3.81, Class=2.11, KL=62.50]    2025-04-16 12:14:38.028116: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
                                                                                                     

→ Epoch 08 | Total Loss: 18849.7 | Recon: 15925.6 | Class: 1128.9 | KL: 35903.9

Epoch 9/30


                                                                                                         

→ Epoch 09 | Total Loss: 18665.9 | Recon: 15926.0 | Class: 1063.5 | KL: 33530.1

Epoch 10/30


                                                                                                         

→ Epoch 10 | Total Loss: 18479.6 | Recon: 15927.2 | Class: 967.1 | KL: 31705.0

Epoch 11/30


                                                                                                         

→ Epoch 11 | Total Loss: 18513.7 | Recon: 15927.1 | Class: 1003.9 | KL: 31654.0

Epoch 12/30


                                                                                                         

→ Epoch 12 | Total Loss: 18059.8 | Recon: 15925.0 | Class: 803.6 | KL: 26623.7

Epoch 13/30


                                                                                                         

→ Epoch 13 | Total Loss: 18078.6 | Recon: 15923.8 | Class: 826.6 | KL: 26565.0

Epoch 14/30


                                                                                                         

→ Epoch 14 | Total Loss: 18024.5 | Recon: 15925.6 | Class: 797.3 | KL: 26031.7

Epoch 15/30


                                                                                                         

→ Epoch 15 | Total Loss: 18089.8 | Recon: 15924.8 | Class: 816.5 | KL: 26970.4

Epoch 16/30


Batch: 100%|█████████▉| 812/813 [00:14<00:00, 56.46it/s, Loss=4.12, Recon=4.11, Class=0.00, KL=0.14]     2025-04-16 12:16:36.632902: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
                                                                                                    

→ Epoch 16 | Total Loss: 17797.3 | Recon: 15923.3 | Class: 697.5 | KL: 23528.7

Epoch 17/30


                                                                                                         

→ Epoch 17 | Total Loss: 17652.4 | Recon: 15921.4 | Class: 648.6 | KL: 21647.5

Epoch 18/30


                                                                                                        

→ Epoch 18 | Total Loss: 17644.3 | Recon: 15923.4 | Class: 644.4 | KL: 21529.5

Epoch 19/30


                                                                                                         

→ Epoch 19 | Total Loss: 17667.2 | Recon: 15921.7 | Class: 658.2 | KL: 21747.6

Epoch 20/30


                                                                                                         

→ Epoch 20 | Total Loss: 17540.4 | Recon: 15921.1 | Class: 587.9 | KL: 20628.4

Epoch 21/30


                                                                                                         

→ Epoch 21 | Total Loss: 17369.7 | Recon: 15920.5 | Class: 547.9 | KL: 18028.4

Epoch 22/30


                                                                                                         

→ Epoch 22 | Total Loss: 17464.5 | Recon: 15917.5 | Class: 580.2 | KL: 19335.8

Epoch 23/30


                                                                                                         

→ Epoch 23 | Total Loss: 17342.9 | Recon: 15920.4 | Class: 525.2 | KL: 17946.8

Epoch 24/30


                                                                                                         

→ Epoch 24 | Total Loss: 17636.2 | Recon: 15918.5 | Class: 647.0 | KL: 21413.7

Epoch 25/30


                                                                                                         

→ Epoch 25 | Total Loss: 17476.0 | Recon: 15917.8 | Class: 570.7 | KL: 19751.2

Epoch 26/30


                                                                                                         

→ Epoch 26 | Total Loss: 17427.9 | Recon: 15919.8 | Class: 548.6 | KL: 19192.1

Epoch 27/30


                                                                                                         

→ Epoch 27 | Total Loss: 17162.4 | Recon: 15918.4 | Class: 444.3 | KL: 15995.4

Epoch 28/30


                                                                                                         

→ Epoch 28 | Total Loss: 17116.7 | Recon: 15918.7 | Class: 435.0 | KL: 15260.7

Epoch 29/30


                                                                                                         

→ Epoch 29 | Total Loss: 17397.0 | Recon: 15916.7 | Class: 537.5 | KL: 18855.7

Epoch 30/30


                                                                                                        

→ Epoch 30 | Total Loss: 17203.8 | Recon: 15917.9 | Class: 461.8 | KL: 16483.3
