In [None]:
# test_vqvae.py
import torch
import numpy as np

from sklearn.metrics import roc_auc_score
from sklearn.metrics import precision_recall_fscore_support as prf

from forward_step import ComputeLossVQVAE
from model import DAGMM_VQVAE


def eval(args, model, dataloaders, device, sub=20):
    """Evaluate the DAGMM‑VQ‑VAE model with GMM energy scoring.

    NOTE: this function shadows Python's builtin ``eval`` for the rest of the
    notebook session; the name is kept because later cells call it.

    Steps:
        1. Run the model over the training loader and pool the per-batch
           GMM parameters returned by ``compute.compute_params`` into one
           global ``(phi, mu, cov)``.
        2. Score every train and test sample with ``compute.compute_energy``.
        3. Set the anomaly threshold at the ``(100 - sub)``-th percentile of
           the *combined* train + test energies.
        4. Report precision / recall / F1 and ROC AUC on the test split.

    Args:
        args: object exposing ``lambda_energy``, ``lambda_cov`` and ``n_gmm``.
        model: module whose forward pass returns a mapping with the keys
            ``'z_q'`` and ``'gamma'``.
        dataloaders: ``(train_loader, test_loader)`` pair; each loader
            yields ``(x, y)`` batches.
        device: device the input batches are moved to before the forward pass.
        sub: percentage of highest-energy samples treated as anomalies when
            choosing the threshold.

    Returns:
        tuple: ``(labels_test, energy_test)`` as numpy arrays.

    Raises:
        ValueError: if the training loader yields no samples, so no GMM
            parameters can be estimated.
    """
    train_loader, test_loader = dataloaders
    model.eval()
    print('Evaluating DAGMM‑VQ‑VAE...')

    # ComputeLossVQVAE is used only for its compute_params / compute_energy
    # helpers; no loss is computed here.
    compute = ComputeLossVQVAE(
        lambda_energy=args.lambda_energy,
        lambda_cov=args.lambda_cov,
        device=device,
        n_gmm=args.n_gmm
    )

    # 1) Estimate GMM parameters on the training data
    with torch.no_grad():
        n_samples = 0
        gamma_sum = 0
        mu_sum    = 0
        cov_sum   = 0

        for x, _ in train_loader:
            x = x.float().to(device)
            out = model(x)
            z_q   = out['z_q']
            gamma = out['gamma']

            # The per-batch mixture weights are not needed: the global phi is
            # rebuilt below from the accumulated responsibilities.
            _, mu_batch, cov_batch = compute.compute_params(z_q, gamma)
            batch_gamma_sum = gamma.sum(dim=0)

            # Weight each batch statistic by its responsibility mass so the
            # division after the loop yields a responsibility-weighted average
            # across batches.
            gamma_sum += batch_gamma_sum
            mu_sum    += mu_batch * batch_gamma_sum.unsqueeze(-1)
            cov_sum   += cov_batch * batch_gamma_sum.unsqueeze(-1).unsqueeze(-1)
            n_samples += x.size(0)

        if n_samples == 0:
            raise ValueError(
                "train_loader yielded no samples; cannot estimate GMM parameters"
            )

        phi = gamma_sum / n_samples
        mu  = mu_sum    / gamma_sum.unsqueeze(-1)
        cov = cov_sum   / gamma_sum.unsqueeze(-1).unsqueeze(-1)

    # 2) Compute energy scores for train and test
    def get_scores(loader):
        """Return (energies, labels) for every sample yielded by `loader`."""
        scores, labels = [], []
        with torch.no_grad():
            for x, y in loader:
                x = x.float().to(device)
                out = model(x)
                z_q   = out['z_q']
                gamma = out['gamma']

                energy, _ = compute.compute_energy(
                    z_q, gamma,
                    phi=phi, mu=mu, cov=cov,
                    sample_mean=False
                )
                scores.append(energy.cpu())
                labels.append(y)
        return torch.cat(scores).numpy(), torch.cat(labels).numpy()

    energy_train, labels_train = get_scores(train_loader)
    energy_test,  labels_test  = get_scores(test_loader)

    # 3) Threshold on the combined train + test energies: the top `sub`% of
    #    all scored samples lie above it.
    all_scores = np.concatenate([energy_train, energy_test])
    thresh = np.percentile(all_scores, 100 - sub)
    print(f"Threshold (top {sub}% of TRAIN+TEST): {thresh:.4f}")

    # 4) Predict & evaluate on TEST
    preds = (energy_test > thresh).astype(int)
    precision, recall, f1, _ = prf(labels_test, preds, average='binary', zero_division=0)
    auc = roc_auc_score(labels_test, energy_test)

    print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
    print(f"ROC AUC: {auc*100:.2f}%")

    return labels_test, energy_test


In [None]:
import torch
from preprocess import get_KDDCup99
from train import TrainerDAGMMVQVAE

In [None]:

class Args:
    """Hyperparameter container handed to get_KDDCup99, TrainerDAGMMVQVAE and eval.

    Values are plain class attributes; an instance is created once and
    passed around as `args`.
    """
    # Optimisation / schedule settings. Presumably read by TrainerDAGMMVQVAE;
    # exact semantics live in train.py — TODO confirm.
    num_epochs    = 100
    patience      = 50
    lr            = 1e-4
    lr_milestones = [50]
    # Presumably the DataLoader batch size used by get_KDDCup99 — TODO confirm.
    batch_size    = 1024
    # Model-size settings; presumably consumed when the trainer builds the
    # model — TODO confirm.
    latent_dim    = 1
    # Number of GMM components; eval() forwards it to ComputeLossVQVAE(n_gmm=...).
    n_gmm         = 4
    num_embeddings= 16
    # Forwarded by eval() to ComputeLossVQVAE as lambda_energy / lambda_cov.
    lambda_energy = 0.1
    lambda_cov    = 0.005

# Seed the RNGs before any data loading or model construction so that
# Restart Kernel -> Run All repeats the same run (as far as the libraries
# and hardware allow). `np` and `torch` are imported in the first cell.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

args   = Args()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# `data` is later unpacked by eval() as (train_loader, test_loader).
data   = get_KDDCup99(args)

trainer = TrainerDAGMMVQVAE(args, data, device)
trainer.train()


In [None]:

# Sweep the anomaly percentage from 10% to 60% in steps of 10 and print the
# test metrics for each setting. Every call re-estimates the GMM parameters
# and re-scores both loaders.
for sub in range(10, 70, 10):
    print(f"-- top {sub}% anomalies --")
    eval(args, trainer.model, data, device, sub=sub)

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def plot_inlier_outlier_kde(labels, scores, model_name='DAGMM', sub=None):
    """
    Plot KDEs of inlier (label=0) vs outlier (label=1) score distributions.

    Inputs are coerced to flat numpy arrays first, so plain Python lists work
    as well as arrays. A class that has fewer than two distinct scores cannot
    be density-estimated; it is skipped with a warning instead of crashing
    the whole plot.

    Args:
        labels (array-like): 0 for inliers, 1 for outliers
        scores (array-like): anomaly scores, one per label
        model_name (str): Used in the title
        sub (int, optional): percentile threshold used, if any

    Raises:
        ValueError: if `labels` and `scores` differ in length, or if neither
            class has enough distinct scores to draw a KDE.
    """
    import warnings

    labels = np.asarray(labels).ravel()
    scores = np.asarray(scores).ravel()
    if labels.shape != scores.shape:
        raise ValueError(
            f"labels and scores must have the same length, "
            f"got {labels.size} and {scores.size}"
        )

    # Split scores by class with boolean masks; keep only plottable groups.
    groups = []
    for column, label_value in (('Inlier', 0), ('Outlier', 1)):
        group_scores = scores[labels == label_value]
        if np.unique(group_scores).size < 2:
            warnings.warn(
                f"{column} group has fewer than two distinct scores; "
                f"skipping its KDE curve"
            )
            continue
        groups.append((column, group_scores))

    # Fail before creating a figure so no empty axes are left behind.
    if not groups:
        raise ValueError("no class has enough distinct scores to plot a KDE")

    # Plot: one single-column DataFrame per class so the column name becomes
    # the legend entry.
    fig, ax = plt.subplots(figsize=(8, 4))
    for column, group_scores in groups:
        pd.DataFrame(group_scores, columns=[column]).plot.kde(ax=ax, legend=True)

    # Title & grids
    title = f'{model_name} Inlier vs Outlier KDE'
    if sub is not None:
        title += f' (top {sub}% threshold)'
    ax.set_title(title)
    ax.grid(axis='x', linestyle='--', alpha=0.5)
    ax.grid(axis='y', linestyle='--', alpha=0.5)
    ax.set_xlabel('Anomaly Score')
    plt.tight_layout()
    plt.show()


In [None]:
# Evaluate at the top-30% threshold with `eval` (defined in the first cell),
# which returns (labels_test, energy_test), then plot the two score
# distributions. No `eval_vqvae` exists in this notebook.
labels_vq, scores_vq = eval(args, trainer.model, data, device, sub=30)
plot_inlier_outlier_kde(labels_vq, scores_vq, model_name='DAGMM‑VQ‑VAE', sub=30)
