In [1]:
# =========================================================
# E-CLARAE-IDS : Energy-Centered Latent Residual Attention VAE
# =========================================================

# ---------------------------
# 0. IMPORTS
# ---------------------------
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.utils import shuffle
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    roc_auc_score
)

from torch.utils.data import Dataset, DataLoader

# Fix the torch and NumPy global RNG seeds so runs are repeatable.
torch.manual_seed(42)
np.random.seed(42)

# ---------------------------
# 1. DATA LOADING & SPLITTING
# ---------------------------
# Path is relative to the working directory the script/notebook is run from.
data = pd.read_csv("../data/train.csv")

# Every column except "attack" is used as a feature.
# NOTE(review): assumes all remaining columns are numeric — confirm against the CSV.
X = data.drop("attack", axis=1).values
y = data["attack"].values  # 0 = normal, 1 = attack

# Test split (contains attacks): 30% of all rows, stratified on the label.
X_rest, X_test, y_rest, y_test = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=42
)

# Train only on normal samples: attack rows in the remaining 70% are dropped.
X_normal = X_rest[y_rest == 0]

# 80/20 train/validation split of the normal-only rows; the validation part
# therefore contains no attacks.
X_train, X_val = train_test_split(
    X_normal, test_size=0.2, random_state=42
)

# NOTE(review): train_test_split shuffles by default, so these calls only
# reorder the rows again — confirm they are still needed.
X_train = shuffle(X_train, random_state=42)
X_val = shuffle(X_val, random_state=42)
# X_test and y_test are shuffled together so rows and labels stay aligned.
X_test, y_test = shuffle(X_test, y_test, random_state=42)

# Scaling: statistics are fitted on the training rows only and then applied
# unchanged to the validation and test rows.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_val = scaler.transform(X_val)
X_test = scaler.transform(X_test)

# ---------------------------
# 2. DATASET WRAPPER
# ---------------------------
class TabularDataset(Dataset):
    """Map-style dataset that serves the rows of a feature table.

    The input is converted to a single ``float32`` tensor once, at
    construction time; indexing then returns the selected row(s) of that
    tensor. No labels are stored.
    """

    def __init__(self, X):
        # One up-front conversion keeps __getitem__ a plain tensor index.
        self.X = torch.tensor(X, dtype=torch.float32)

    def __len__(self):
        # Number of rows, i.e. the size of the leading dimension.
        return self.X.shape[0]

    def __getitem__(self, idx):
        return self.X[idx]

# ---------------------------
# 3. MODEL COMPONENTS
# ---------------------------
class ResidualBlock(nn.Module):
    """Two linear layers with batch norm, wrapped in a skip connection.

    The main path is ``fc1 -> bn1 -> LeakyReLU(0.1) -> fc2 -> bn2``. Its
    output is added to the shortcut and passed through a final
    LeakyReLU(0.1). The shortcut is the identity when ``in_dim == out_dim``
    and a learned linear projection otherwise.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, out_dim)
        self.bn1 = nn.BatchNorm1d(out_dim)
        self.fc2 = nn.Linear(out_dim, out_dim)
        self.bn2 = nn.BatchNorm1d(out_dim)

        # A projection is only needed when the widths differ.
        if in_dim == out_dim:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        shortcut = self.skip(x)

        hidden = self.fc1(x)
        hidden = self.bn1(hidden)
        hidden = F.leaky_relu(hidden, negative_slope=0.1)
        hidden = self.bn2(self.fc2(hidden))

        return F.leaky_relu(hidden + shortcut, negative_slope=0.1)


class MultiHeadFeatureAttention(nn.Module):
    """Average of several independent sigmoid gates applied to the features.

    Each head is a small bottleneck MLP (``dim -> dim // 2 -> dim``) ending in
    a sigmoid, so it yields one weight in (0, 1) per feature. The head's
    output is that weight multiplied element-wise with the input; the module
    returns the mean over all heads.
    """

    def __init__(self, dim, heads=4):
        super().__init__()
        self.attn = nn.ModuleList(
            [self._build_gate(dim) for _ in range(heads)]
        )

    @staticmethod
    def _build_gate(dim):
        # One gating head: bottleneck MLP with a sigmoid output per feature.
        bottleneck = dim // 2
        return nn.Sequential(
            nn.Linear(dim, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, dim),
            nn.Sigmoid(),
        )

    def forward(self, x):
        gated = [gate(x) * x for gate in self.attn]
        # Stack along a new leading "head" axis and average it away.
        return torch.stack(gated).mean(dim=0)

# ---------------------------
# 4. AGR-VAE MODEL
# ---------------------------
class AGRVAE(nn.Module):
    """VAE with residual encoder/decoder stacks and a feature-gating layer.

    Encoder: ``input_dim -> 512 -> 256 -> 128`` residual blocks, followed by
    multi-head feature attention and two linear heads producing the latent
    mean and log-variance. Decoder: ``latent_dim -> 128 -> 256 -> 512``
    residual blocks and a final linear layer back to ``input_dim``.
    """

    def __init__(self, input_dim, latent_dim=32):
        super().__init__()

        # Encoder
        self.enc1 = ResidualBlock(input_dim, 512)
        self.enc2 = ResidualBlock(512, 256)
        self.enc3 = ResidualBlock(256, 128)

        self.attn = MultiHeadFeatureAttention(128, heads=4)

        # Heads for the parameters of the approximate posterior.
        self.mu = nn.Linear(128, latent_dim)
        self.logvar = nn.Linear(128, latent_dim)

        # Decoder
        self.dec1 = ResidualBlock(latent_dim, 128)
        self.dec2 = ResidualBlock(128, 256)
        self.dec3 = ResidualBlock(256, 512)
        self.out = nn.Linear(512, input_dim)

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * std`` with ``eps ~ N(0, I)``.

        Noise is sampled on every call; there is no train/eval switch here.
        """
        std = (0.5 * logvar).exp()
        noise = torch.randn_like(std)
        return mu + noise * std

    def forward(self, x):
        """Encode, sample, and decode ``x``.

        Returns:
            A 4-tuple ``(reconstruction, mu, logvar, decoded)`` where
            ``decoded`` is the output of the last decoder residual block
            (width 512), i.e. the tensor fed to the final linear layer.

        NOTE(review): the fourth element is NOT the sampled latent vector,
        because the sample is overwritten by the decoder stages. Confirm
        whether callers expect the latent sample instead.
        """
        hidden = x
        for stage in (self.enc1, self.enc2, self.enc3):
            hidden = stage(hidden)

        hidden = self.attn(hidden)

        mu = self.mu(hidden)
        # Clamp the log-variance so exp() in the sampler/KL stays bounded.
        logvar = self.logvar(hidden).clamp(min=-10, max=10)

        decoded = self.reparameterize(mu, logvar)
        for stage in (self.dec1, self.dec2, self.dec3):
            decoded = stage(decoded)

        return self.out(decoded), mu, logvar, decoded

# ---------------------------
# 5. ENERGY-AWARE LOSS
# ---------------------------
def energy_vae_loss(x, x_hat, mu, logvar, center, beta=1.0, gamma=0.1):
    """Batch-mean of reconstruction + beta * KL + gamma * latent distance.

    Args:
        x: Reconstruction target.
        x_hat: Model reconstruction, same shape as ``x``.
        mu: Latent means.
        logvar: Latent log-variances, same shape as ``mu``.
        center: Reference point in latent space; moved to ``mu``'s device.
        beta: Weight of the KL term.
        gamma: Weight of the squared distance between ``mu`` and ``center``.

    Returns:
        A scalar tensor: the mean over the batch of the per-sample loss.
    """
    center = center.to(mu.device)

    # Per-sample mean squared reconstruction error.
    recon_term = (x - x_hat).pow(2).mean(dim=1)

    # KL(q(z|x) || N(0, I)), divided by the latent width so it is a
    # per-dimension average rather than a sum.
    kl_inner = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * kl_inner.sum(dim=1) / mu.size(1)

    # Squared Euclidean distance of each latent mean from the center.
    center_term = (mu - center).pow(2).sum(dim=1)

    per_sample = recon_term + beta * kl_term + gamma * center_term
    return per_sample.mean()

# ---------------------------
# 6. LATENT CENTER COMPUTATION
# ---------------------------
def compute_latent_center(model, X):
    """Return the mean latent ``mu`` the model produces over all rows of X.

    The model is switched to eval mode (and left there), ``X`` is converted
    to a float32 tensor on the model's device, and the whole of ``X`` is
    pushed through the model in a single forward pass without gradients.

    Args:
        model: Module whose forward returns a 4-tuple with ``mu`` second.
        X: Array-like of input rows.

    Returns:
        Tensor holding the mean of ``mu`` over the batch dimension.
    """
    model.eval()
    target_device = next(model.parameters()).device
    batch = torch.tensor(X, dtype=torch.float32, device=target_device)

    with torch.no_grad():
        _recon, latent_means, _logvar, _decoded = model(batch)

    return latent_means.mean(dim=0)

# ---------------------------
# 7. TRAINING LOOP
# ---------------------------
def train(model, X_train, epochs=80, batch_size=256, beta=1.0, gamma=0.1):
    """Fit ``model`` on ``X_train`` with the energy-aware VAE loss.

    The model is moved to CUDA when available (CPU otherwise) and trained in
    place with AdamW (lr 3e-4, weight decay 1e-5). Each batch is perturbed
    with Gaussian noise (std 0.01) before the forward pass, while the loss is
    computed against the clean batch. Gradients are clipped to norm 5.0.

    The KL and latent-center weights ramp linearly from 0 at epoch 0 up to
    ``beta`` / ``gamma`` at epoch 20 and stay there afterwards.

    Args:
        model: Module whose forward returns ``(x_hat, mu, logvar, _)``.
        X_train: Array-like of training rows.
        epochs: Number of passes over the data.
        batch_size: Mini-batch size.
        beta: Final weight of the KL term.
        gamma: Final weight of the latent-center term.

    Returns:
        The same ``model`` object, trained.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    batches = DataLoader(
        TabularDataset(X_train), batch_size=batch_size, shuffle=True
    )
    opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-5)

    # Latent center: fixed to the mean ``mu`` of the very first batch and
    # reused, unchanged, for every later step.
    center = None

    for epoch in range(epochs):
        model.train()

        # Warm-up weights depend only on the epoch, so compute them once here.
        beta_w = min(beta, beta * epoch / 20)
        gamma_w = min(gamma, gamma * epoch / 20)
        running = 0.0

        for clean in batches:
            clean = clean.to(device)
            noisy = clean + 0.01 * torch.randn_like(clean)

            opt.zero_grad()
            x_hat, mu, logvar, _ = model(noisy)

            if center is None:
                center = mu.mean(dim=0).detach()

            loss = energy_vae_loss(
                clean, x_hat, mu, logvar, center, beta=beta_w, gamma=gamma_w
            )

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            opt.step()
            running += loss.item()

        # Report the mean batch loss every tenth epoch.
        if epoch % 10 == 0:
            mean_loss = running / len(batches)
            print(f"Epoch {epoch:03d} | Loss {mean_loss:.6f}")

    return model

# ---------------------------
# 8. ENERGY-BASED SCORING
# ---------------------------
def energy_score(model, X, center, lambda_z=0.1, lambda_kl=0.05):
    """Compute a per-sample anomaly score; higher means more anomalous.

    The score is ``recon + lambda_z * latent + lambda_kl * kl`` where

    * ``recon`` is the mean squared reconstruction error per sample,
    * ``latent`` is the squared distance of ``mu`` from ``center``,
    * ``kl`` is the KL divergence to N(0, I) averaged over latent dimensions.

    The model is switched to eval mode (and left there) and all of ``X`` is
    scored in one forward pass without gradients.

    Args:
        model: Module whose forward returns ``(x_hat, mu, logvar, _)``.
        X: Array-like of input rows.
        center: Latent-space reference point; moved to the model's device.
        lambda_z: Weight of the latent-distance term.
        lambda_kl: Weight of the KL term. Defaults to 0.05, the value that
            was previously hard-coded.

    Returns:
        1-D NumPy array with one score per input row.
    """
    model.eval()
    device = next(model.parameters()).device
    X = torch.tensor(X, dtype=torch.float32).to(device)
    center = center.to(device)

    # Keep every term inside no_grad so no part of the score builds a graph.
    with torch.no_grad():
        x_hat, mu, logvar, _ = model(X)
        recon = torch.mean((X - x_hat) ** 2, dim=1)
        latent = torch.sum((mu - center) ** 2, dim=1)
        kl = -0.5 * torch.sum(
            1 + logvar - mu.pow(2) - logvar.exp(),
            dim=1
        ) / mu.size(1)
        score = recon + lambda_z * latent + lambda_kl * kl

    return score.cpu().numpy()

# ---------------------------
# 9. EVALUATION
# ---------------------------
def evaluate(model, X_test, y_test, threshold, center):
    """Score the test rows and summarise detection quality.

    A row is predicted as an attack (label 1) when its energy score is
    strictly greater than ``threshold``. ROC-AUC is computed from the raw
    scores, so it does not depend on the threshold.

    Args:
        model: Trained model passed through to ``energy_score``.
        X_test: Array-like of test rows.
        y_test: Ground-truth binary labels.
        threshold: Decision cut-off on the energy score.
        center: Latent center passed through to ``energy_score``.

    Returns:
        Tuple ``(accuracy, precision, recall, f1, roc_auc)``.
    """
    scores = energy_score(model, X_test, center)
    predicted = (scores > threshold).astype(int)

    precision, recall, f1, _support = precision_recall_fscore_support(
        y_test, predicted, average="binary"
    )
    accuracy = accuracy_score(y_test, predicted)
    roc_auc = roc_auc_score(y_test, scores)

    return accuracy, precision, recall, f1, roc_auc

# ---------------------------
# 10. RUN EXPERIMENT
# ---------------------------
# Feature count is taken from the scaled training matrix.
input_dim = X_train.shape[1]
model = AGRVAE(input_dim, latent_dim=64)

# Train on normal-only rows; beta/gamma are the final (post warm-up) weights.
model = train(
    model,
    X_train,
    epochs=120,
    beta=1.5,
    gamma=0.01
)

# Recompute the latent center with the trained model over the full training
# set; this is the center used for scoring below.
center = compute_latent_center(model, X_train)

# Threshold = 90th percentile of the scores on the normal-only validation
# rows, so about 10% of validation rows score above it.
val_scores = energy_score(model, X_val, center)
threshold = np.percentile(val_scores, 90)

acc, p, r, f1, auc = evaluate(
    model,
    X_test,
    y_test,
    threshold,
    center
)

print("\n===== FINAL RESULTS =====")
print(f"Accuracy  : {acc:.4f}")
print(f"Precision : {p:.4f}")
print(f"Recall    : {r:.4f}")
print(f"F1-score  : {f1:.4f}")
print(f"ROC-AUC   : {auc:.4f}")

from sklearn.metrics import classification_report
# Generate classification report
# NOTE(review): this scores the test set a second time instead of reusing the
# predictions made inside evaluate(). AGRVAE.reparameterize draws fresh noise
# on every forward pass, so these predictions may differ slightly from the
# ones behind the metrics printed above — confirm whether that is acceptable.
report = classification_report(y_test, (threshold < energy_score(model, X_test, center)).astype(int))
print("Classification Report:\n", report)



Epoch 000 | Loss 0.283544
Epoch 010 | Loss 0.197184
Epoch 020 | Loss 0.273610
Epoch 030 | Loss 0.204874
Epoch 040 | Loss 0.180617
Epoch 050 | Loss 0.177068
Epoch 060 | Loss 0.170770
Epoch 070 | Loss 0.168089
Epoch 080 | Loss 0.162973
Epoch 090 | Loss 0.159840
Epoch 100 | Loss 0.160271
Epoch 110 | Loss 0.166670

===== FINAL RESULTS =====
Accuracy  : 0.9371
Precision : 0.9234
Recall    : 0.9622
F1-score  : 0.9424
ROC-AUC   : 0.9828
Classification Report:
               precision    recall  f1-score   support

           0       0.96      0.91      0.93     17589
           1       0.93      0.96      0.94     20203

    accuracy                           0.94     37792
   macro avg       0.94      0.94      0.94     37792
weighted avg       0.94      0.94      0.94     37792

