In [12]:
from pathlib import Path

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm

from gmvae import GMVAE

# Load data
raw_data = pd.read_csv("./dataset/kdd_full_clean_5classes.csv")
X = raw_data.drop(columns=["target"])
y = raw_data["target"]

# Encode features.
# SECURITY: joblib.load unpickles arbitrary Python objects — only load trusted files.
# NOTE(review): the transformer receives the full frame (including "target"), not X.
# The decoded output later fits exactly len(X.columns) columns, which suggests the
# transformer selects/ignores "target" itself — confirm it does not leak the label
# into the model input.
tdt = joblib.load("./typed_nslkdd_all_features.pkl")
X_encoded = tdt.transform(raw_data)

SEED = 42
# Seed TF's global RNG so TF-side sampling (presumably model.reparameterize in the
# next cell) is repeatable across Restart & Run All.
tf.random.set_seed(SEED)

# Split data: only the held-out 30% is used below; the training part is discarded.
_, X_test, _, y_test = train_test_split(
    X_encoded, y, test_size=0.3, stratify=y, random_state=SEED
)

# Fit the label encoder on ALL labels, not only the test split: a later cell calls
# encoder_label.transform on the full label column, which raises on labels the
# encoder has never seen. With stratify=y the resulting codes are unchanged.
encoder_label = LabelEncoder().fit(y)
y_test_encoded = encoder_label.transform(y_test)

# Rebuild the GMVAE with the training-time hyperparameters and load its weights.
input_dim = X_test.shape[1]
latent_dim = 128
n_classes = len(encoder_label.classes_)
model = GMVAE(input_dim=input_dim, latent_dim=latent_dim, n_classes=n_classes, lambda_kl=0.05)
# One forward pass on a single row creates the model's variables (presumably GMVAE
# is a subclassed Keras model, which has none until first call) so load_weights
# has something to restore into.
_ = model(tf.convert_to_tensor(X_test[:1], dtype=tf.float32))
model.load_weights("./models_training/gmvae_weights.weights.h5")



In [13]:
# Extract latent representations
def extract_latents(model, data, batch_size, use_mean=False):
    """Encode ``data`` batch by batch and return one latent vector per row.

    Parameters
    ----------
    model : GMVAE
        Must expose ``encode(x) -> (mu, logvar, _)`` and
        ``reparameterize(mu, logvar)``.
    data : array-like
        Encoded feature rows; anything supporting ``len`` and row slicing
        ``data[i:j]`` (e.g. an ndarray or DataFrame).
    batch_size : int
        Number of rows passed to ``model.encode`` per call; must be positive.
    use_mean : bool, default False
        If False (the original behaviour), each latent is a sample drawn via
        ``model.reparameterize(mu, logvar)``. If True, the posterior mean ``mu``
        is returned instead, which does not change between runs.

    Returns
    -------
    np.ndarray
        Latents of all batches concatenated along axis 0, in the row order of
        ``data``.

    Raises
    ------
    ValueError
        If ``batch_size`` is not positive or ``data`` has no rows.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if len(data) == 0:
        # Fail with a clear message instead of np.concatenate's
        # "need at least one array to concatenate".
        raise ValueError("extract_latents received no rows to encode")

    latents = []
    for start in tqdm(range(0, len(data), batch_size)):
        x_batch = tf.convert_to_tensor(data[start:start + batch_size], dtype=tf.float32)
        mu, logvar, _ = model.encode(x_batch)
        z = mu if use_mean else model.reparameterize(mu, logvar)
        # np.asarray works for eager TF tensors as well as plain arrays.
        latents.append(np.asarray(z))
    return np.concatenate(latents, axis=0)

latent_vectors = extract_latents(model, X_test, batch_size=10000)

# Plot the latent space in 2D. With latent_dim = 128 a "2D only" guard never
# fires, so higher-dimensional latents are first projected onto their top two
# principal components (PCA via SVD of the mean-centered latents).
# Note: extract_latents samples via reparameterize by default, so points are
# noisy draws around each posterior mean.
n_latent_dims = latent_vectors.shape[1]
if n_latent_dims >= 2:
    if n_latent_dims == 2:
        coords = latent_vectors
        x_label, y_label = "z[0]", "z[1]"
        plot_title = "GMVAE Latent Space"
    else:
        centered = latent_vectors - latent_vectors.mean(axis=0)
        # Rows of vt are principal directions, ordered by decreasing singular value.
        _, _, vt = np.linalg.svd(centered, full_matrices=False)
        coords = centered @ vt[:2].T
        x_label, y_label = "PC 1", "PC 2"
        plot_title = "GMVAE Latent Space (PCA projection)"

    fig, ax = plt.subplots(figsize=(8, 8))
    scatter = ax.scatter(coords[:, 0], coords[:, 1], c=y_test_encoded, cmap='viridis', s=10)
    fig.colorbar(scatter, ax=ax, label="encoded class")
    ax.set(title=plot_title, xlabel=x_label, ylabel=y_label)
    ax.grid(True)
    plt.show()




100%|██████████| 5/5 [00:00<00:00, 34.69it/s]


In [14]:
# ==========================
# Generate samples from learned priors
# ==========================
def generate_latent_points_from_clusters(class_distribution, center_cluster, sigma2_per_class, epsilon=1.0, seed=42):
    """
    Sample latent points around each class's cluster center.

    For every ``label -> count`` pair, ``count`` points are drawn from a normal
    distribution centred at ``center_cluster[label]`` with standard deviation
    ``sqrt(sigma2_per_class[label]) * epsilon``.

    Parameters
    ----------
    class_distribution : dict
        Maps an integer label to the number of points to generate for it.
        Labels are used directly as row indices into ``center_cluster`` and
        ``sigma2_per_class``.
    center_cluster : np.ndarray
        One center per label; ``center_cluster[label]`` is a 1-D vector whose
        length sets the latent width.
    sigma2_per_class : np.ndarray
        Variance per label, indexed like ``center_cluster``. Each entry may be a
        scalar or a per-dimension vector (it is broadcast against the samples).
    epsilon : float, default 1.0
        Multiplier on the standard deviation; values > 1 spread samples wider
        than the given variances.
    seed : int or None, default 42
        Seed for a private ``np.random.RandomState``. The draws match the former
        ``np.random.seed(seed)`` + ``np.random.randn`` calls exactly, but the
        global NumPy RNG is no longer reseeded as a side effect.

    Returns
    -------
    z : np.ndarray, shape (sum of counts, latent_dim)
        Generated latent points, grouped by label in the iteration order of
        ``class_distribution``.
    y : np.ndarray, shape (sum of counts,)
        Label of each row of ``z``.
    """
    rng = np.random.RandomState(seed)
    z_list = []
    y_list = []

    for label, count in class_distribution.items():
        mean = center_cluster[label]
        stddev = np.sqrt(sigma2_per_class[label]) * epsilon
        samples = rng.randn(count, mean.shape[0]) * stddev + mean
        z_list.append(samples)
        y_list.append(np.full(count, label))

    if not z_list:
        # Nothing requested: return correctly shaped empty arrays instead of
        # letting np.vstack([]) raise.
        latent_dim = np.shape(center_cluster)[-1]
        return np.empty((0, latent_dim)), np.empty(0, dtype=int)

    z = np.vstack(z_list)
    y = np.concatenate(y_list)
    return z, y

# Step 1: Retrieve cluster centers and variances of the learned prior.
# prior_logvar holds log-variances, so exp() yields variances.
cluster_centers = model.prior_mu.numpy()
sigma2_cluster = np.exp(model.prior_logvar.numpy())

# Step 2: Class distribution from the original dataset, so the synthetic set has
# the same per-class row counts as the real one.
y_all = raw_data['target'].values
y_all_encoded = encoder_label.transform(y_all)
unique, counts = np.unique(y_all_encoded, return_counts=True)
class_distribution = dict(zip(unique, counts))

# Encoded labels are used as row indices into the prior parameters, so every
# label needs a matching prior component.
# NOTE(review): this assumes prior component k was trained to represent the class
# LabelEncoder encodes as k — confirm against the training code.
assert all(label < len(cluster_centers) for label in class_distribution), (
    f"some labels in {sorted(class_distribution)} have no prior component "
    f"(only {len(cluster_centers)} available)"
)

# Step 3: Generate latent samples.
# GEN_EPSILON scales each prior's standard deviation: 1.0 samples with the
# learned spread, larger values push samples further from each center.
GEN_EPSILON = 2
z, y_gen = generate_latent_points_from_clusters(
    class_distribution=class_distribution,
    center_cluster=cluster_centers,
    sigma2_per_class=sigma2_cluster,
    epsilon=GEN_EPSILON,
)

# Step 4: Decode latent samples in batches
def decode_in_batches(model, z_array, batch_size=128):
    """Decode ``z_array`` through ``model.decode`` in chunks of ``batch_size`` rows.

    Each chunk is converted to a float32 tensor, decoded, and converted back to
    NumPy. The chunk outputs are stacked vertically in input order. An empty
    ``z_array`` yields no chunks, so ``np.vstack`` raises ``ValueError``.
    """
    chunk_starts = range(0, len(z_array), batch_size)
    decoded_chunks = [
        model.decode(
            tf.convert_to_tensor(z_array[start:start + batch_size], dtype=tf.float32)
        ).numpy()
        for start in chunk_starts
    ]
    return np.vstack(decoded_chunks)

generation_decod = decode_in_batches(model, z, batch_size=128)

# Step 5: Inverse preprocessing back to the original feature space.
Generation_real = tdt.inverse_transform(generation_decod)

# Step 6: Wrap in a DataFrame with the original feature names and add the label.
Generation_real = pd.DataFrame(Generation_real, columns=X.columns)
# Rows of z — and therefore of the decoded output — are in the same order as y_gen.
assert len(Generation_real) == len(y_gen), "decoded rows and labels are misaligned"
Generation_real['target'] = encoder_label.inverse_transform(y_gen)
# NOTE(review): decoded continuous features are not clipped; rate columns such as
# dst_host_same_srv_rate were observed slightly above 1.0 — decide whether to clip.

# Step 7: Export to CSV, creating the output directory on a fresh checkout.
output_path = Path('./generations/gmvae_df.csv')
output_path.parent.mkdir(parents=True, exist_ok=True)
Generation_real.to_csv(output_path, index=False)

print("Final shape:", Generation_real.shape)
Generation_real.head()


Final shape: (148517, 42)
   duration  protocol_type  service  flag  src_bytes  dst_bytes  land  \
0         0              1       15     9         79          0     0   
1         0              1       20     9         69          0     0   
2         0              0       14     9         42          0     0   
3         0              1       49     9         76          0     0   
4         0              0       15     9         55          0     0   

   wrong_fragment  urgent  hot  ...  dst_host_srv_count  \
0               0       0    0  ...                   2   
1               0       0    0  ...                   2   
2               0       0    0  ...                   2   
3               0       0    0  ...                   2   
4               0       0    0  ...                   2   

   dst_host_same_srv_rate  dst_host_diff_srv_rate  \
0                1.000029                0.000000   
1                0.999984                0.000074   
2                0.01