In [2]:
import os

import joblib
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.preprocessing import LabelEncoder

from lcvae import LCVAE

# ==========================
# Load and preprocess the full NSL-KDD dataset
# ==========================
data = pd.read_csv('./dataset/cicids2017_clean_all_labels.csv')
print('Original shape:', data.shape)

X = data.drop(columns=['target'])
y = data['target']

# Apply fitted transformer
tdt = joblib.load("./typed_cicids2017_all_features.pkl")
X_encoded = tdt.transform(data)

# Encode labels
encoder_label = LabelEncoder()
y_encoded = encoder_label.fit_transform(y)

# ==========================
# Training parameters
# ==========================
latent_dim = 128
batch_size = 128
epochs = 20
k = 40
lambda_kl = 0.05
input_dim = X_encoded.shape[1]
n_classes = len(np.unique(y_encoded))

# ==========================
# Prepare LCVAE and dataset
# ==========================
model = LCVAE(
    input_dim=input_dim,
    latent_dim=latent_dim,
    n_classes=n_classes,
    lambda_kl=lambda_kl,
    k=k,
    batch_size=batch_size
)
model.build_clustering(X_encoded, y_encoded)

# TensorFlow dataset
X_tf = tf.convert_to_tensor(X_encoded, dtype=tf.float32)
Y_tf = tf.convert_to_tensor(y_encoded, dtype=tf.int32)
train_dataset = tf.data.Dataset.from_tensor_slices((X_tf, Y_tf)).shuffle(1000).batch(batch_size)

# ==========================
# Train the model
# ==========================
optimizer = tf.keras.optimizers.Adam(learning_rate=5e-4)
model.train(train_dataset, optimizer=optimizer, epochs=epochs)

# Trigger a dummy forward pass to ensure model is "built"
_ = model(tf.convert_to_tensor(X_encoded[:1], dtype=tf.float32))

# ==========================
# Save weights and learned latent parameters
# ==========================
model.save_weights("./models_training/lcvae_weights_cicids.weights.h5")
np.save("./models_training/cluster_centers_cicids.npy", model.cluster_centers.numpy())
np.save("./models_training/sigma_per_class_cicids.npy", model.sigma2_per_class.numpy())


Original shape: (2827876, 79)


: 