# FAIIA-IDS Ablation Study (Refactored)

This notebook runs the ablation study for the FAIIA-IDS model by cloning the refactored codebase from GitHub.

**Note:** The setup cell below clones `https://github.com/Arif-Foysal/FAA-Net.git`. If you are working from a fork, update `GIT_REPO_URL` and `REPO_DIR` in that cell to match your repository.

In [None]:
# 1. Clone Repository
# TODO: Replace with your actual repository URL
GIT_REPO_URL = "https://github.com/Arif-Foysal/FAA-Net.git"
REPO_DIR = "FAA-Net" # This usually matches the name of the git repo

!git clone {GIT_REPO_URL}

import os
if os.path.exists(REPO_DIR):
    os.chdir(REPO_DIR)
    print(f"Changed directory to: {os.getcwd()}")
else:
    print(f"Warning: Could not find directory {REPO_DIR}. Check if git clone succeeded.")

# 2. Mount Google Drive (for saving models)
try:
    from google.colab import drive
    drive.mount('/content/drive')
    print("Google Drive mounted successfully.")
except ImportError:
    print("Not running in Google Colab, skipping Drive mount.")

# 3. Install Dependencies
!pip install -r requirements.txt

In [None]:
import os
import sys

# Ensure project root is in path.
# `os` is imported here too so this cell does not rely on another cell's namespace.
project_root = os.getcwd()
if project_root not in sys.path:
    sys.path.append(project_root)

import torch
import pandas as pd
from scripts.run_ablation import main as run_ablation_experiment
from scripts.train_main import main as train_main_model
from scripts.train_baselines import main as run_baseline_experiment

## 1. Train Main EDAN v3 Model

Trains the full EDAN v3 model with FAIIA and Minority Prototypes using the standard configuration.

In [None]:
# Kick off training of the main model. A missing dataset is reported with a
# short hint instead of a full traceback.
try:
    train_main_model()
except FileNotFoundError as missing_file:
    print(
        f"Error: {missing_file}",
        "Please ensure dataset files (UNSW_NB15) are present in the project root or /content.",
        sep="\n",
    )

## 2. Run Ablation Study

Runs 6 experiments:
1. Vanilla DNN + BCE
2. Vanilla DNN + Focal Loss
3. FAIIA + BCE
4. FAIIA + Focal Loss
5. **FAIIA + EWKM + BCE** *(NEW — Entropy-Weighted KMeans prototypes)*
6. **FAIIA + EWKM + Focal Loss** *(NEW — Entropy-Weighted KMeans prototypes)*

In [None]:
# Launch the ablation suite; a missing input file is printed rather than raised.
try:
    run_ablation_experiment()
except FileNotFoundError as missing_file:
    print(f"Error: {missing_file}")

## 2b. EWKM Ablation (Entropy-Weighted KMeans Prototypes)

Runs **only** the two new EWKM experiments:
1. **FAIIA + EWKM + BCE** — Feature-discriminative prototypes with weighted BCE
2. **FAIIA + EWKM + Focal Loss** — Feature-discriminative prototypes with Focal Loss

> Requires the setup cell (clone, `chdir`, dependency install) and the imports cell above to have been executed. This cell loads the dataset itself.

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import pandas as pd
from core.config import V3_CONFIG, RANDOM_STATE
from core.data_loader import load_and_preprocess_data, create_dataloaders
from core.ablation import EDANv3_Ablation
from core.loss import ImbalanceAwareFocalLoss_Logits
from core.trainer import train_model
from core.utils import set_all_seeds, evaluate_model, print_metrics, save_predictions
from sklearn.cluster import KMeans

# ============================================================
# EWKM: Only used for better prototype centroid generation
# ============================================================
class EntropyWeightedKMeans:
    """
    Entropy-Weighted K-Means for feature-discriminative prototype generation.

    Starts from a standard KMeans solution, then alternates between
    (1) assigning samples to the nearest center under per-cluster feature
    weights and (2) updating each cluster's center and feature weights.

    Parameters
    ----------
    n_clusters : int
        Number of clusters / prototypes.
    gamma : float
        Entropy-regularisation strength; should be positive. Smaller values
        concentrate each cluster's weight on its lowest-dispersion features.
    max_iter : int
        Maximum number of refinement iterations after the KMeans start.
        ``0`` returns the KMeans solution with uniform feature weights.
    tol : float
        Stop when the Frobenius norm of the center shift falls below this.
    random_state : int
        Seed forwarded to the KMeans initialisation.

    Attributes (set by ``fit``)
    ---------------------------
    cluster_centers_ : ndarray of shape (n_clusters, n_features)
    feature_weights_ : ndarray of shape (n_clusters, n_features); rows sum to 1
    labels_ : ndarray of shape (n_samples,)
        Assignment from the last E-step (the KMeans labels if no iteration ran).
    n_iter_ : int
        Number of refinement iterations actually performed.
    """
    def __init__(self, n_clusters=8, gamma=0.1, max_iter=100, tol=1e-4, random_state=42):
        self.n_clusters = n_clusters
        self.gamma = gamma
        self.max_iter = max_iter
        self.tol = tol
        self.random_state = random_state
        self.cluster_centers_ = None
        self.feature_weights_ = None
        self.labels_ = None
        self.n_iter_ = 0

    def fit(self, X):
        """Cluster ``X`` (array-like, 2-D) and return ``self``.

        Raises
        ------
        ValueError
            If ``X`` is not two-dimensional.
        """
        # Accept lists / DataFrames as well as ndarrays; dtype is left untouched.
        X = np.asarray(X)
        if X.ndim != 2:
            raise ValueError(f"X must be 2-D (n_samples, n_features); got ndim={X.ndim}")
        n_samples, n_features = X.shape

        kmeans_init = KMeans(n_clusters=self.n_clusters, random_state=self.random_state, n_init=10)
        kmeans_init.fit(X)
        labels = kmeans_init.labels_
        centers = kmeans_init.cluster_centers_.copy()
        # Every cluster starts with uniform feature weights.
        weights = np.full((self.n_clusters, n_features), 1.0 / n_features)

        # Tracked outside the loop so max_iter=0 reports 0 instead of failing
        # on an unbound loop variable.
        n_iter = 0
        for n_iter in range(1, self.max_iter + 1):
            old_centers = centers.copy()

            # E-step: weighted distance assignment
            distances = np.zeros((n_samples, self.n_clusters))
            for k in range(self.n_clusters):
                diff = X - centers[k]
                distances[:, k] = np.sum(weights[k] * (diff ** 2), axis=1)
            labels = np.argmin(distances, axis=1)

            # M-step: update each cluster's center, then its feature weights.
            for k in range(self.n_clusters):
                members = X[labels == k]
                if len(members) == 0:
                    # Empty cluster: keep its previous center and weights.
                    continue
                centers[k] = members.mean(axis=0)

                # NOTE(review): dispersions are summed over members, not averaged,
                # so the effective sharpness of the weights grows with cluster size.
                dispersions = np.sum((members - centers[k]) ** 2, axis=0) + 1e-10
                log_w = -dispersions / self.gamma
                log_w -= log_w.max()  # shift so the largest exponent is 0 (avoids underflow of every term)
                weights[k] = np.exp(log_w)
                weights[k] /= weights[k].sum()

            if np.linalg.norm(centers - old_centers) < self.tol:
                break

        self.cluster_centers_ = centers
        self.feature_weights_ = weights
        self.labels_ = labels
        self.n_iter_ = n_iter
        return self

# ============================================================
# Config
# ============================================================
# Copy of the shared V3 settings with the EWKM-specific keys layered on top;
# keys listed here override any same-named entries in V3_CONFIG.
V3_EWKM_CONFIG = {**V3_CONFIG, 'use_ewkm': True, 'ewkm_gamma': 0.1, 'patience': 30}

# Use the GPU when one is visible to torch, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seeded before the data is loaded and the loaders are built.
# NOTE(review): presumably this makes any shuffling/splitting reproducible — confirm in core.utils.
set_all_seeds(RANDOM_STATE)

# --- Load Data ---
# /content is used whenever it exists, even if the current directory also holds the data.
# NOTE(review): confirm this is the intended precedence after the chdir into the repo.
data_dir = "/content" if os.path.exists("/content") else "."
# The last two returned values are not needed here and are discarded.
X_train_scaled, X_test_scaled, y_train, y_test, _, _ = load_and_preprocess_data(data_dir=data_dir)

# The third returned value is discarded; only the train/val loaders and the test tensor are kept.
train_loader, val_loader, _, X_test_tensor = create_dataloaders(
    X_train_scaled, y_train, X_test_scaled, y_test,
    batch_size=V3_EWKM_CONFIG['batch_size']
)

input_dim = X_train_scaled.shape[1]
# Class balance of the training labels (label 1 is counted as positive).
count_positive = y_train.sum()
count_negative = len(y_train) - count_positive
# Ordered [negative, positive].
class_counts = [count_negative, count_positive]
# Negative-to-positive count ratio as a one-element float32 tensor on `device`.
pos_weight = torch.tensor([count_negative / count_positive], device=device, dtype=torch.float32)

# --- Generate EWKM Prototypes ---
print("Generating EWKM prototypes...")

# Only the label == 1 training rows are clustered.
minority_mask = (y_train.values == 1)
X_minority = X_train_scaled[minority_mask]

# fit() returns the estimator itself, so construction and fitting can be chained.
ewkm = EntropyWeightedKMeans(
    n_clusters=V3_EWKM_CONFIG['n_prototypes'],
    gamma=V3_EWKM_CONFIG['ewkm_gamma'],
    random_state=RANDOM_STATE,
).fit(X_minority)
prototypes_ewkm = ewkm.cluster_centers_

for report_line in (
    f"  Shape: {prototypes_ewkm.shape}",
    f"  Gamma: {V3_EWKM_CONFIG['ewkm_gamma']}",
    f"  Converged in {ewkm.n_iter_} iterations",
):
    print(report_line)

# Diagnostic: verify feature weights are NOT uniform.
# For the first (up to) three prototypes, report the weight entropy relative to
# the uniform maximum and the three highest-weighted feature indices.
for i, w in enumerate(ewkm.feature_weights_[:3]):
    ent = -np.sum(w * np.log(w + 1e-10))
    max_ent = np.log(len(w))
    ratio = ent / max_ent
    top3 = np.argsort(w)[::-1][:3]
    print(
        f"  Prototype {i}: entropy_ratio={ratio:.3f} (1.0=uniform, lower=more discriminative), top features: {top3}"
    )

# --- Helper: build model + init with EWKM prototypes using existing API ---
def build_ewkm_model():
    """Return a freshly seeded EDANv3_Ablation on `device` whose prototypes come from the EWKM centroids."""
    # Re-seed so every model built by this helper starts from the same RNG state.
    set_all_seeds(RANDOM_STATE)

    arch_kwargs = {
        key: V3_EWKM_CONFIG[key]
        for key in ('num_heads', 'attention_dim', 'n_prototypes')
    }
    net = EDANv3_Ablation(input_dim=input_dim, **arch_kwargs)
    net = net.to(device)

    # Prototype initialisation goes through the model's own API.
    net.faiia.initialize_all_prototypes(prototypes_ewkm, device)
    return net

ewkm_results = {}


def run_ewkm_experiment(header, make_criterion, metrics_title, predictions_file, results_key):
    """Build, train, evaluate and record one EWKM-initialised model.

    `make_criterion` is a zero-argument callable so the loss is constructed
    after the model is built, matching the order of the steps below.
    Returns (model, criterion, history, metrics, y_probs, y_pred).
    """
    print("\n" + "="*60)
    print(header)
    print("="*60)

    model = build_ewkm_model()
    criterion = make_criterion()

    model, history = train_model(
        model, train_loader, val_loader, V3_EWKM_CONFIG, criterion, device
    )
    metrics, y_probs, y_pred = evaluate_model(model, X_test_tensor, y_test, device)
    print_metrics(metrics, metrics_title)
    save_predictions(y_test, y_probs, predictions_file)
    ewkm_results[results_key] = metrics
    return model, criterion, history, metrics, y_probs, y_pred


# --- Experiment: FAIIA + EWKM + BCE ---
(model_ewkm_bce, criterion_bce, hist_bce,
 metrics_bce, y_probs_bce, y_pred_bce) = run_ewkm_experiment(
    "Experiment: FAIIA + EWKM + BCE",
    lambda: nn.BCEWithLogitsLoss(pos_weight=pos_weight),
    "FAIIA + EWKM + BCE Results",
    "faiia_ewkm_bce_predictions.npz",
    'FAIIA + EWKM + BCE',
)

# --- Experiment: FAIIA + EWKM + Focal Loss ---
(model_ewkm_focal, criterion_focal, hist_focal,
 metrics_focal, y_probs_focal, y_pred_focal) = run_ewkm_experiment(
    "Experiment: FAIIA + EWKM + Focal Loss",
    lambda: ImbalanceAwareFocalLoss_Logits(class_counts=class_counts, gamma=2.0),
    "FAIIA + EWKM + Focal Results",
    "faiia_ewkm_focal_predictions.npz",
    'FAIIA + EWKM + Focal',
)

# --- Summary ---
banner = "=" * 60
print("\n" + banner)
print("EWKM Ablation Summary")
print(banner)

# One row per experiment, one column per metric.
df_ewkm = pd.DataFrame(ewkm_results).T
display(df_ewkm)

if not os.path.exists('ablation_summary.csv'):
    # No main summary yet: write the EWKM rows to their own file.
    df_ewkm.to_csv('ewkm_ablation_summary.csv')
    print("Saved to ewkm_ablation_summary.csv")
else:
    # Replace any earlier EWKM rows in the main summary with this run's rows.
    df_existing = pd.read_csv('ablation_summary.csv', index_col=0)
    is_ewkm_row = df_existing.index.str.contains('EWKM')
    df_existing = df_existing[~is_ewkm_row]
    df_combined = pd.concat([df_existing, df_ewkm])
    df_combined.to_csv('ablation_summary.csv')
    print("Updated ablation_summary.csv")

## 3. Run Standard Baselines

Runs classical ML baselines:
1. XGBoost
2. LightGBM

In [None]:
# Run the classical baselines; a missing input file is printed rather than raised.
try:
    run_baseline_experiment()
except FileNotFoundError as missing_file:
    print(f"Error: {missing_file}")

## 4. View Consolidated Results

Load and display the summary CSVs generated by the experiments.

In [None]:
# Summary files written by the experiment cells. The EWKM cell writes
# ewkm_ablation_summary.csv when ablation_summary.csv does not exist yet, so
# that file is read as well; otherwise those results would never be shown.
RESULT_FILES = [
    ("Ablation", "ablation_summary.csv"),
    ("EWKM Ablation", "ewkm_ablation_summary.csv"),
    ("Baseline", "baseline_summary.csv"),
]


def load_result_tables(labelled_paths):
    """Read every existing summary CSV and return the tables in the given order.

    Parameters
    ----------
    labelled_paths : iterable of (label, path)
        `label` is used in the "... Results Loaded" message; `path` is the CSV
        to read with its first column as the index.

    Returns
    -------
    list of pandas.DataFrame
        One table per file that exists and contributes rows. A row whose index
        label was already provided by an earlier file is dropped, so EWKM rows
        present in both the main and the fallback file appear once.
    """
    tables = []
    seen_rows = set()
    for label, path in labelled_paths:
        if not os.path.exists(path):
            continue
        table = pd.read_csv(path, index_col=0)
        table = table[~table.index.isin(seen_rows)]
        if table.empty:
            # Nothing new in this file (empty, or every row already loaded).
            continue
        seen_rows.update(table.index)
        print(f"{label} Results Loaded")
        tables.append(table)
    return tables


results = load_result_tables(RESULT_FILES)

if results:
    final_df = pd.concat(results)
    display(final_df)
else:
    print("No results files found. Ensure experiments ran successfully.")