# Hybrid IDS: Flow-Based Binary Detection + Multiclass Attack Classification

This notebook implements a two-stage intrusion detection system:
1. **Stage 1 (Binary)**: Ensemble model to classify traffic as Attack vs Benign
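2. **Stage 2 (Multiclass)**: For detected attacks, classify the attack type (SPLT features are the intended input; the code in this notebook trains Stage 2 on the same flow feature matrix as Stage 1)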

## Features
- Domain-invariant features for robustness across different networks
- Calibrated confidence thresholds to reduce false positives
- SPLT (Sequence of Packet Length and Time) for attack type classification (intended design; no SPLT features are computed in the cells below)

In [None]:
# Standard imports
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# ML imports
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.calibration import CalibratedClassifierCV
from sklearn.metrics import (
    classification_report, confusion_matrix, roc_auc_score, 
    precision_recall_curve, roc_curve, f1_score, accuracy_score
)
import joblib

# Optional: XGBoost and LightGBM
try:
    import xgboost as xgb
    HAS_XGB = True
except ImportError:
    HAS_XGB = False
    print("XGBoost not installed, will use alternatives")

try:
    import lightgbm as lgb
    HAS_LGB = True
except ImportError:
    HAS_LGB = False
    print("LightGBM not installed, will use alternatives")

# Set style
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('husl')

print("✓ Imports complete")

## Configuration

In [None]:
# === CONFIGURATION ===
# Every filesystem location and tunable threshold used by the notebook is
# defined in this cell; adjust them to match your setup.

DATA_DIR = Path("..") / "data"
OUTPUT_DIR = Path("..") / "outputs" / "hybrid_ids"
# Create the output tree up front so later savefig/joblib calls cannot fail
# on a missing directory.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# CICIDS2017 labeled data (directory holding the labeled CSV files)
# Download from: https://www.unb.ca/cic/datasets/ids-2017.html
CICIDS_CSV_DIR = DATA_DIR.joinpath("cicids2017")

# T-Pot attack flows (extracted with DPI)
ATTACK_FLOWS_PATH = DATA_DIR.joinpath("attack_flows_dpi.csv")

# Benign flows (from CICIDS Monday or extracted from benign PCAP)
BENIGN_FLOWS_PATH = DATA_DIR.joinpath("benign_flows_dpi.csv")

# Detection thresholds (can be tuned per environment)
CONFIG = dict(
    binary_threshold=0.7,      # Confidence threshold for binary detection
    multiclass_threshold=0.5,  # Confidence threshold for attack type
    test_size=0.2,             # Hold-out fraction passed to train_test_split
    random_state=42,           # Seed shared by the splits and the classifiers
)

print(f"Output directory: {OUTPUT_DIR}")

## 1. Data Loading & Preparation

In [None]:
def load_cicids_data(csv_dir: Path, sample_frac: float = 1.0) -> pd.DataFrame:
    """
    Load and concatenate every CICIDS2017 labeled CSV found in a directory.

    Files are read in sorted path order so that the row order of the result
    (and therefore any seeded sampling or split performed on it) does not
    depend on the order in which the filesystem lists the directory.

    Args:
        csv_dir: Directory containing the ``*.csv`` files (not searched
            recursively).
        sample_frac: Fraction of rows kept from each file. Values >= 1.0
            keep every row; smaller values sample with a fixed seed of 42.

    Returns:
        One DataFrame with a fresh RangeIndex. Column names are stripped of
        surrounding whitespace, lower-cased, and have spaces replaced by
        underscores.

    Raises:
        FileNotFoundError: If ``csv_dir`` contains no ``*.csv`` files.
    """
    # Sort explicitly: directory listing order is filesystem-dependent.
    csv_files = sorted(csv_dir.glob("*.csv"))

    if not csv_files:
        raise FileNotFoundError(f"No CSV files found in {csv_dir}")

    all_dfs = []
    for csv_file in csv_files:
        print(f"Loading {csv_file.name}...")
        try:
            df = pd.read_csv(csv_file, encoding='utf-8', low_memory=False)
        except UnicodeDecodeError:
            # Fall back to latin-1 for files that are not valid UTF-8.
            df = pd.read_csv(csv_file, encoding='latin-1', low_memory=False)

        # Standardize column names (literal replacement, not a regex)
        df.columns = (
            df.columns.str.strip().str.lower().str.replace(' ', '_', regex=False)
        )

        if sample_frac < 1.0:
            df = df.sample(frac=sample_frac, random_state=42)

        all_dfs.append(df)
        print(f"  → {len(df):,} rows, {len(df.columns)} columns")

    combined = pd.concat(all_dfs, ignore_index=True)
    print(f"\nTotal: {len(combined):,} rows")
    return combined


def load_extracted_flows(path: Path) -> pd.DataFrame:
    """
    Read a flow CSV produced by the DPI extraction script.

    Args:
        path: Location of the CSV file.

    Returns:
        The parsed flows, one row per CSV record.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    if path.exists():
        flows = pd.read_csv(path)
        print(f"Loaded {len(flows):,} flows from {path.name}")
        return flows

    raise FileNotFoundError(f"Flow file not found: {path}")

In [None]:
# Load CICIDS2017 data if available; otherwise leave cicids_df as None so the
# dataset-selection cell can fall back to the extracted flows.
USE_CICIDS = CICIDS_CSV_DIR.exists()

if not USE_CICIDS:
    print(f"CICIDS2017 directory not found: {CICIDS_CSV_DIR}")
    print("Will use extracted flows from PCAPs instead.")
    cicids_df = None
else:
    print("Loading CICIDS2017 dataset...")
    # Half of each file is sampled to keep the run time manageable.
    cicids_df = load_cicids_data(CICIDS_CSV_DIR, sample_frac=0.5)

    # Show label distribution
    if 'label' in cicids_df.columns:
        print("\nLabel distribution:")
        print(cicids_df['label'].value_counts())

In [None]:
# Load extracted flows (if using our DPI extraction).
# Both files must be present: one supplies the attack rows, the other benign.
USE_EXTRACTED = all(p.exists() for p in (ATTACK_FLOWS_PATH, BENIGN_FLOWS_PATH))

if not USE_EXTRACTED:
    print("Extracted flows not found. Run extract_flows_with_dpi.py first:")
    print(f"  python src/extract_flows_with_dpi.py --pcap <attack.pcap> --output {ATTACK_FLOWS_PATH}")
    print(f"  python src/extract_flows_with_dpi.py --pcap <benign.pcap> --output {BENIGN_FLOWS_PATH}")
    extracted_df = None
else:
    print("Loading extracted flows...")
    # Each file is labelled wholesale by its origin.
    attack_flows = load_extracted_flows(ATTACK_FLOWS_PATH).assign(label='attack')
    benign_flows = load_extracted_flows(BENIGN_FLOWS_PATH).assign(label='benign')

    # Stack attack rows first, then benign rows, under a fresh index.
    extracted_df = pd.concat([attack_flows, benign_flows], ignore_index=True)
    print(f"\nCombined: {len(extracted_df):,} flows")

In [None]:
# Choose which dataset to use: CICIDS2017 wins when both sources loaded.
if cicids_df is None and extracted_df is None:
    raise RuntimeError("No dataset available! Please provide CICIDS2017 data or extract flows from PCAPs.")

if cicids_df is None:
    df = extracted_df
    print("Using extracted flows")
else:
    df = cicids_df
    print("Using CICIDS2017 dataset")

print(f"\nDataset shape: {df.shape}")
df.head()

## 2. Feature Engineering (Domain-Invariant)

In [None]:
def create_binary_label(labels: pd.Series) -> pd.Series:
    """
    Map string labels to a binary target (benign=0, attack=1).

    A label counts as benign when it contains the substring ``BENIGN``
    in any letter case (CICIDS2017 uses 'BENIGN' for normal traffic).
    Everything else, including missing labels, is mapped to 1.
    """
    # Vectorized substring test; missing values are treated as "not benign".
    benign = labels.str.upper().str.contains('BENIGN', regex=False, na=False)
    return (~benign.astype(bool)).astype(int)


def engineer_domain_invariant_features(df: pd.DataFrame) -> pd.DataFrame:
    """
    Add ratio and log-scaled flow features to a copy of ``df``.

    The input frame is not modified. Columns using CICIDS names are
    mirrored under the short names of the extracted-flow format (the
    original columns are kept), and each derived feature is added only
    when all of the columns it needs are present.

    Returns:
        The copy, extended with whichever of these columns could be built:
        ``fwd_pkt_ratio``/``bwd_pkt_ratio``, ``fwd_bytes_ratio``/
        ``bwd_bytes_ratio``, ``log_<col>`` for the volume columns,
        ``pkt_size_ratio`` and ``iat_ratio``.
    """
    out = df.copy()

    # CICIDS column name -> short name (works with both CICIDS and our
    # extracted data). The two identity entries are no-ops by construction.
    aliases = {
        'flow_duration': 'duration',
        'total_fwd_packets': 'fwd_packets',
        'total_backward_packets': 'bwd_packets',
        'total_length_of_fwd_packets': 'fwd_bytes',
        'total_length_of_bwd_packets': 'bwd_bytes',
        'fwd_packet_length_mean': 'fwd_pkt_mean',
        'bwd_packet_length_mean': 'bwd_pkt_mean',
        'flow_iat_mean': 'iat_mean',
        'flow_iat_std': 'iat_std',
        'fwd_iat_mean': 'fwd_iat_mean',
        'bwd_iat_mean': 'bwd_iat_mean',
        'syn_flag_count': 'syn_count',
        'ack_flag_count': 'ack_count',
        'psh_flag_count': 'psh_count',
        'rst_flag_count': 'rst_count',
    }

    # Copy each CICIDS column to its short name unless that name is taken.
    for cicids_name, short_name in aliases.items():
        if cicids_name not in out.columns or short_name in out.columns:
            continue
        out[short_name] = out[cicids_name]

    def present(*names):
        """True when every named column exists in the working frame."""
        return all(name in out.columns for name in names)

    def add_direction_shares(fwd, bwd, fwd_share, bwd_share):
        """Add forward/backward shares of a total; +1 avoids a zero divisor."""
        total = out[fwd] + out[bwd] + 1
        out[fwd_share] = out[fwd] / total
        out[bwd_share] = out[bwd] / total

    # === Direction shares of packets and bytes ===
    if present('fwd_packets', 'bwd_packets'):
        add_direction_shares('fwd_packets', 'bwd_packets', 'fwd_pkt_ratio', 'bwd_pkt_ratio')

    if present('fwd_bytes', 'bwd_bytes'):
        add_direction_shares('fwd_bytes', 'bwd_bytes', 'fwd_bytes_ratio', 'bwd_bytes_ratio')

    # === Log-transform volume features (negatives are clipped to 0 first) ===
    for volume_col in ('fwd_packets', 'bwd_packets', 'fwd_bytes', 'bwd_bytes', 'duration'):
        if present(volume_col):
            out[f'log_{volume_col}'] = np.log1p(out[volume_col].clip(lower=0))

    # === Forward-over-backward ratios of mean packet size and mean IAT ===
    ratio_specs = (
        ('fwd_pkt_mean', 'bwd_pkt_mean', 'pkt_size_ratio'),
        ('fwd_iat_mean', 'bwd_iat_mean', 'iat_ratio'),
    )
    for numerator, denominator, ratio_name in ratio_specs:
        if present(numerator, denominator):
            out[ratio_name] = out[numerator] / (out[denominator] + 1)

    return out

In [None]:
# Apply feature engineering
df = engineer_domain_invariant_features(df)

# Both targets are derived from the 'label' column, so it must exist.
if 'label' not in df.columns:
    raise ValueError("No 'label' column found in dataset!")

# Binary target: 0 = benign, 1 = attack.
df['binary_label'] = create_binary_label(df['label'])
# Attack rows keep their original label; every other row becomes 'BENIGN'.
df['attack_type'] = df['label'].where(df['binary_label'].eq(1), 'BENIGN')

print("Binary label distribution:")
print(df['binary_label'].value_counts())
print(f"\nAttack ratio: {df['binary_label'].mean():.2%}")

In [None]:
# Select features for training
# Columns listed here never become model inputs: labels, identifiers and raw
# values that don't transfer across networks.
EXCLUDE_COLS = [
    # Labels
    'label', 'binary_label', 'attack_type',
    # IPs
    'src_ip', 'dst_ip', 'source_ip', 'destination_ip',
    # Raw ports
    'src_port', 'dst_port', 'source_port', 'destination_port',
    # Identifiers
    'flow_id', 'timestamp', 'first_seen_ms', 'last_seen_ms',
    # Strings
    'requested_server_name', 'client_fingerprint',
    # Will one-hot encode separately
    'application_name', 'application_category',
]

# Keep the remaining columns whose dtype is one of the four listed numeric
# dtypes, preserving the DataFrame's column order.
feature_cols = [
    name
    for name, dtype in df.dtypes.items()
    if dtype in ('int64', 'float64', 'int32', 'float32')
    and name not in EXCLUDE_COLS
]

print(f"Selected {len(feature_cols)} features for training")
print(f"Features: {feature_cols[:20]}...")  # Show first 20

In [None]:
# Prepare feature matrix: infinities are treated as missing, then every
# missing value is filled with its column median.
# NOTE(review): the medians are computed over the full dataset, before the
# train/test split below.
X = df[feature_cols].replace([np.inf, -np.inf], np.nan)
X = X.fillna(X.median())

# Binary labels (one per row of X)
y_binary = df['binary_label'].values

# Multiclass labels exist for attack rows only
attack_mask = df['binary_label'].eq(1)
y_multiclass = df.loc[attack_mask, 'attack_type'].values

print(f"Feature matrix shape: {X.shape}")
print(f"Binary labels: {len(y_binary)} (benign={(y_binary == 0).sum()}, attack={(y_binary == 1).sum()})")
print(f"Multiclass labels: {len(y_multiclass)} attacks")

## 3. Stage 1: Binary Classification (Attack vs Benign)

In [None]:
# Train/test split, stratified so both parts keep the attack/benign ratio
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y_binary,
    stratify=y_binary,
    test_size=CONFIG['test_size'],
    random_state=CONFIG['random_state'],
)

print(f"Train: {len(X_train):,} samples")
print(f"Test:  {len(X_test):,} samples")

# Scale features: statistics are learned from the training part only and then
# applied unchanged to both parts.
scaler = StandardScaler().fit(X_train)
X_train_scaled = scaler.transform(X_train)
X_test_scaled = scaler.transform(X_test)

In [None]:
# Build ensemble classifiers
# Nothing is fitted in this cell; it only assembles the (name, estimator)
# pairs that the voting ensemble consumes.
print("Training base classifiers...")

# Random Forest
rf = RandomForestClassifier(
    random_state=CONFIG['random_state'],
    n_jobs=-1,
    n_estimators=100,
    max_depth=15,
    min_samples_leaf=5,
)

# Gradient Boosting
gb = GradientBoostingClassifier(
    random_state=CONFIG['random_state'],
    n_estimators=100,
    learning_rate=0.1,
    max_depth=5,
)

classifiers = [('rf', rf), ('gb', gb)]

# XGBoost joins the ensemble only when its import succeeded
if HAS_XGB:
    xgb_clf = xgb.XGBClassifier(
        random_state=CONFIG['random_state'],
        n_jobs=-1,
        n_estimators=100,
        learning_rate=0.1,
        max_depth=6,
        eval_metric='logloss',
        use_label_encoder=False,
    )
    classifiers += [('xgb', xgb_clf)]

# LightGBM joins the ensemble only when its import succeeded
if HAS_LGB:
    lgb_clf = lgb.LGBMClassifier(
        random_state=CONFIG['random_state'],
        n_jobs=-1,
        n_estimators=100,
        learning_rate=0.1,
        max_depth=6,
        verbose=-1,
    )
    classifiers += [('lgb', lgb_clf)]

print(f"Ensemble: {[pair[0] for pair in classifiers]}")

In [None]:
# Train voting ensemble
# Soft voting averages the members' predicted probabilities, so the ensemble
# exposes a probability that can be thresholded later.
ensemble = VotingClassifier(voting='soft', estimators=classifiers)

print("Training ensemble...")
# NOTE(review): the calibration cell wraps `ensemble` in
# CalibratedClassifierCV(cv=3), which fits its own clones; confirm whether
# this separate fit is still needed.
ensemble.fit(X_train_scaled, y_train)
print("✓ Training complete")

In [None]:
# Calibrate probabilities for better confidence estimates.
# Isotonic calibration with 3-fold cross-validation over the training data.
print("Calibrating probabilities...")
calibrated_ensemble = CalibratedClassifierCV(
    ensemble, method='isotonic', cv=3
).fit(X_train_scaled, y_train)
print("✓ Calibration complete")

In [None]:
# Baseline: the calibrated model's own predict(), reported as the 0.5 cut-off
y_pred_default = calibrated_ensemble.predict(X_test_scaled)

# Calibrated attack probability (column 1), cut at the configured threshold
y_proba = calibrated_ensemble.predict_proba(X_test_scaled)[:, 1]
y_pred = np.where(y_proba >= CONFIG['binary_threshold'], 1, 0)

print(f"Threshold: {CONFIG['binary_threshold']}")
print(f"\n=== With Default Threshold (0.5) ===")
print(classification_report(y_test, y_pred_default, target_names=['Benign', 'Attack']))

print(f"\n=== With Custom Threshold ({CONFIG['binary_threshold']}) ===")
print(classification_report(y_test, y_pred, target_names=['Benign', 'Attack']))

In [None]:
# Plot ROC (left) and precision-recall (right) curves for the binary stage
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
roc_ax, pr_ax = axes

# ROC Curve
fpr, tpr, thresholds = roc_curve(y_test, y_proba)
auc = roc_auc_score(y_test, y_proba)

# Position of the ROC point whose threshold lies closest to the configured
# one; its false-positive rate is marked with a vertical line.
nearest_idx = np.argmin(np.abs(thresholds - CONFIG['binary_threshold']))

roc_ax.plot(fpr, tpr, label=f'ROC (AUC = {auc:.3f})', linewidth=2)
roc_ax.plot([0, 1], [0, 1], 'k--', alpha=0.5)  # chance diagonal
roc_ax.axvline(
    x=fpr[nearest_idx],
    color='r',
    linestyle=':',
    label=f"Threshold={CONFIG['binary_threshold']}",
)
roc_ax.set_xlabel('False Positive Rate')
roc_ax.set_ylabel('True Positive Rate')
roc_ax.set_title('ROC Curve - Binary Detection')
roc_ax.legend()

# Precision-Recall Curve
precision, recall, pr_thresholds = precision_recall_curve(y_test, y_proba)
pr_ax.plot(recall, precision, linewidth=2)
pr_ax.set_xlabel('Recall')
pr_ax.set_ylabel('Precision')
pr_ax.set_title('Precision-Recall Curve')

plt.tight_layout()
plt.savefig(OUTPUT_DIR / 'binary_roc_pr_curves.png', dpi=150)
plt.show()

In [None]:
# Confusion matrices side by side: default cut-off vs configured threshold
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

class_names = ['Benign', 'Attack']
panel_predictions = (y_pred_default, y_pred)
panel_titles = (
    'Default Threshold (0.5)',
    f'Custom Threshold ({CONFIG["binary_threshold"]})',
)

for ax, pred, title in zip(axes, panel_predictions, panel_titles):
    cm = confusion_matrix(y_test, pred)
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        ax=ax,
        xticklabels=class_names,
        yticklabels=class_names,
    )
    ax.set_xlabel('Predicted')
    ax.set_ylabel('Actual')
    ax.set_title(title)

plt.tight_layout()
plt.savefig(OUTPUT_DIR / 'binary_confusion_matrix.png', dpi=150)
plt.show()

## 4. Stage 2: Multiclass Attack Classification

In [None]:
# Prepare multiclass data (attacks only): rows of X selected by attack_mask,
# paired with the attack-type strings collected for those same rows.
X_attacks = X.loc[attack_mask].copy()
y_attacks = y_multiclass

# Encode attack types as integer class ids
label_encoder = LabelEncoder().fit(y_attacks)
y_attacks_encoded = label_encoder.transform(y_attacks)

print(f"Attack samples: {len(X_attacks):,}")
print(f"\nAttack type distribution:")
type_names, type_counts = np.unique(y_attacks, return_counts=True)
for label, count in zip(type_names, type_counts):
    print(f"  {label}: {count:,}")

In [None]:
# Train/test split for multiclass, stratified by attack type
X_train_mc, X_test_mc, y_train_mc, y_test_mc = train_test_split(
    X_attacks,
    y_attacks_encoded,
    stratify=y_attacks_encoded,
    test_size=CONFIG['test_size'],
    random_state=CONFIG['random_state'],
)

# Scale: Stage 2 gets its own scaler, fitted on attack training rows only
scaler_mc = StandardScaler().fit(X_train_mc)
X_train_mc_scaled = scaler_mc.transform(X_train_mc)
X_test_mc_scaled = scaler_mc.transform(X_test_mc)

print(f"Multiclass Train: {len(X_train_mc):,}")
print(f"Multiclass Test:  {len(X_test_mc):,}")

In [None]:
# Train multiclass classifier: XGBoost when installed, Random Forest otherwise
print("Training multiclass classifier...")

if not HAS_XGB:
    mc_classifier = RandomForestClassifier(
        random_state=CONFIG['random_state'],
        n_jobs=-1,
        n_estimators=150,
        max_depth=15,
    )
else:
    mc_classifier = xgb.XGBClassifier(
        random_state=CONFIG['random_state'],
        n_jobs=-1,
        n_estimators=150,
        learning_rate=0.1,
        max_depth=8,
        eval_metric='mlogloss',
        use_label_encoder=False,
    )

mc_classifier.fit(X_train_mc_scaled, y_train_mc)
print("✓ Training complete")

In [None]:
# Evaluate multiclass
y_pred_mc = mc_classifier.predict(X_test_mc_scaled)

print("Multiclass Classification Report:")
# Pass the full set of encoded class ids explicitly. Without `labels`, the
# report infers the classes from y_test_mc/y_pred_mc and raises when a rare
# attack type is absent from both, because the inferred class count no longer
# matches the length of `target_names`.
print(classification_report(
    y_test_mc, y_pred_mc,
    labels=np.arange(len(label_encoder.classes_)),
    target_names=label_encoder.classes_
))

In [None]:
# Multiclass confusion matrix
# Fix the row/column order to the encoder's class ids so the matrix always
# has one row and one column per attack type. Without `labels`, a class that
# is missing from both y_test_mc and y_pred_mc is dropped and the tick labels
# below no longer line up with the cells.
cm_mc = confusion_matrix(
    y_test_mc, y_pred_mc,
    labels=np.arange(len(label_encoder.classes_))
)

plt.figure(figsize=(12, 10))
sns.heatmap(cm_mc, annot=True, fmt='d', cmap='Blues',
            xticklabels=label_encoder.classes_,
            yticklabels=label_encoder.classes_)
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Multiclass Attack Type Classification')
# Class names can be long; slant the x labels and keep the y labels level.
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.savefig(OUTPUT_DIR / 'multiclass_confusion_matrix.png', dpi=150)
plt.show()

## 5. Robustness: Adversarial Validation & Drift Detection

In [None]:
def adversarial_validation(X_train: np.ndarray, X_test: np.ndarray) -> float:
    """
    Measure how well a classifier can tell training rows from test rows.

    Both sets are stacked, each row is tagged with its origin (train=0,
    test=1), and a small random forest is scored with 3-fold
    cross-validated ROC AUC on predicting that tag.

    Returns:
        The mean AUC over the three folds. The notebook treats a value
        above 0.6 as a sign of significant domain shift.
    """
    stacked = np.vstack((X_train, X_test))
    origin = np.repeat([0.0, 1.0], [len(X_train), len(X_test)])

    # Deliberately small forest: this is a quick probe, not a production model.
    probe = RandomForestClassifier(
        n_estimators=50, max_depth=5, n_jobs=-1, random_state=42
    )

    fold_aucs = cross_val_score(probe, stacked, origin, scoring='roc_auc', cv=3)
    return fold_aucs.mean()


# Check for domain shift between train and test
adv_auc = adversarial_validation(X_train_scaled, X_test_scaled)

print(f"Adversarial Validation AUC: {adv_auc:.3f}")
# Written as a negated '>' so a NaN score takes the same branch as before.
if not adv_auc > 0.6:
    print("✓ No significant domain shift detected.")
else:
    print("⚠️  WARNING: Significant domain shift detected!")
    print("   Consider collecting more diverse training data.")

In [None]:
def threshold_sensitivity_analysis(y_true, y_proba, thresholds=None):
    """
    Analyze how different decision thresholds affect FP/TP rates.

    Args:
        y_true: Ground-truth binary labels (0 = benign, 1 = attack).
        y_proba: Predicted attack probability per sample (any array-like).
        thresholds: Iterable of cut-offs to evaluate. Defaults to
            ``np.arange(0.3, 0.95, 0.05)``.

    Returns:
        DataFrame with one row per threshold and the columns
        ``threshold``, ``precision``, ``recall``, ``fpr`` and ``f1``.
        The columns are present even when ``thresholds`` is empty.
    """
    if thresholds is None:
        thresholds = np.arange(0.3, 0.95, 0.05)

    # Accept plain lists as well as arrays for the element-wise comparison.
    y_proba = np.asarray(y_proba)

    results = []
    for thresh in thresholds:
        y_pred = (y_proba >= thresh).astype(int)
        # labels=[0, 1] forces a 2x2 matrix. Without it, a threshold at which
        # only one class appears in y_true and y_pred yields a 1x1 matrix and
        # the four-way unpack below raises.
        tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

        results.append({
            'threshold': thresh,
            # The small epsilon keeps the ratios defined when a denominator is 0.
            'precision': tp / (tp + fp + 1e-9),
            'recall': tp / (tp + fn + 1e-9),
            'fpr': fp / (fp + tn + 1e-9),
            # zero_division=0 keeps the score at 0 when no positives are
            # predicted, without emitting an undefined-metric warning.
            'f1': f1_score(y_true, y_pred, zero_division=0),
        })

    return pd.DataFrame(
        results, columns=['threshold', 'precision', 'recall', 'fpr', 'f1']
    )


# Threshold sensitivity: one curve per metric across the candidate thresholds
thresh_df = threshold_sensitivity_analysis(y_test, y_proba)

fig, ax = plt.subplots(figsize=(10, 6))

# (column in thresh_df, legend text, marker), drawn in this order
curve_specs = (
    ('precision', 'Precision', 'o'),
    ('recall', 'Recall', 's'),
    ('f1', 'F1', '^'),
    ('fpr', 'False Positive Rate', 'x'),
)
for column, legend_text, marker in curve_specs:
    ax.plot(thresh_df['threshold'], thresh_df[column], label=legend_text, marker=marker)

# Mark the threshold currently configured for the binary stage
ax.axvline(
    x=CONFIG['binary_threshold'],
    color='r',
    linestyle='--',
    label=f"Current Threshold ({CONFIG['binary_threshold']})",
)
ax.set_xlabel('Detection Threshold')
ax.set_ylabel('Score')
ax.set_title('Threshold Sensitivity Analysis')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(OUTPUT_DIR / 'threshold_sensitivity.png', dpi=150)
plt.show()

print("\nThreshold Analysis:")
print(thresh_df.round(3).to_string(index=False))

## 6. Save Models

In [None]:
# Save all artifacts under OUTPUT_DIR/models
models_dir = OUTPUT_DIR / 'models'
models_dir.mkdir(exist_ok=True)

# Save models: each fitted object goes to its own joblib file
artifacts = (
    (calibrated_ensemble, 'binary_classifier.joblib'),
    (mc_classifier, 'multiclass_classifier.joblib'),
    (scaler, 'binary_scaler.joblib'),
    (scaler_mc, 'multiclass_scaler.joblib'),
    (label_encoder, 'label_encoder.joblib'),
)
for fitted_object, filename in artifacts:
    joblib.dump(fitted_object, models_dir / filename)

# Save config: thresholds plus the feature order and class names
import json
config_payload = dict(
    CONFIG,
    feature_columns=feature_cols,
    attack_types=label_encoder.classes_.tolist(),
)
with open(models_dir / 'config.json', 'w') as f:
    json.dump(config_payload, f, indent=2)

print(f"✓ Models saved to {models_dir}")
print(f"\nFiles:")
for f in models_dir.iterdir():
    print(f"  - {f.name}")

## 7. Inference Function

In [None]:
def predict_hybrid(flow_features: pd.DataFrame, 
                   binary_threshold: float = 0.7,
                   multiclass_threshold: float = 0.5) -> pd.DataFrame:
    """
    Two-stage hybrid prediction:
    1. Binary: Is this traffic an attack?
    2. Multiclass: If attack, what type?

    Reads the notebook globals ``feature_cols`` (columns to use, in training
    order) and ``models_dir`` (where the saved artifacts are loaded from).

    Args:
        flow_features: DataFrame with flow features; must contain every
            column named in ``feature_cols``.
        binary_threshold: Confidence threshold for attack detection
        multiclass_threshold: Confidence threshold for attack type

    Returns:
        DataFrame with one row per input row (fresh RangeIndex) and columns:
        ``is_attack`` (0/1), ``attack_confidence`` (Stage 1 probability),
        ``attack_type`` ('BENIGN', a known attack type, or 'UNKNOWN_ATTACK'
        when Stage 2 confidence is below ``multiclass_threshold``) and
        ``type_confidence`` (Stage 2 top probability, 0.0 for benign rows).
        A zero-row input yields a zero-row frame with the same columns.

    Raises:
        KeyError: If ``flow_features`` lacks any required feature column.
    """
    # Fail early with the names of the missing columns instead of a bare
    # pandas indexing error.
    missing = [col for col in feature_cols if col not in flow_features.columns]
    if missing:
        raise KeyError(f"flow_features is missing required feature columns: {missing}")

    # Prepare features
    X = flow_features[feature_cols].copy()
    # NOTE(review): training fills missing values with per-column medians,
    # whereas inference fills with 0; the medians are not among the saved
    # artifacts, so the two cannot be aligned from inside this function.
    X = X.replace([np.inf, -np.inf], np.nan).fillna(0)

    # Nothing to score: the scalers reject zero-row input, so return an empty
    # result with the same schema instead of crashing.
    if len(X) == 0:
        return pd.DataFrame({
            'is_attack': np.zeros(0, dtype=int),
            'attack_confidence': np.zeros(0, dtype=float),
            'attack_type': np.empty(0, dtype=object),
            'type_confidence': np.zeros(0, dtype=float),
        })

    # Load models (in production, load once at startup)
    binary_clf = joblib.load(models_dir / 'binary_classifier.joblib')
    mc_clf = joblib.load(models_dir / 'multiclass_classifier.joblib')
    binary_scaler = joblib.load(models_dir / 'binary_scaler.joblib')
    mc_scaler = joblib.load(models_dir / 'multiclass_scaler.joblib')
    le = joblib.load(models_dir / 'label_encoder.joblib')

    # Stage 1: Binary detection
    X_scaled = binary_scaler.transform(X)
    attack_proba = binary_clf.predict_proba(X_scaled)[:, 1]
    is_attack = (attack_proba >= binary_threshold).astype(int)

    # Stage 2: Multiclass (only for detected attacks)
    attack_type = np.full(len(X), 'BENIGN', dtype=object)
    attack_type_proba = np.zeros(len(X))

    attack_mask = is_attack == 1
    if attack_mask.any():
        # Stage 2 has its own scaler, so it starts from the unscaled rows.
        X_attacks = mc_scaler.transform(X[attack_mask])
        mc_proba = mc_clf.predict_proba(X_attacks)
        mc_pred = mc_clf.predict(X_attacks)
        mc_confidence = mc_proba.max(axis=1)

        # Apply multiclass threshold: low-confidence rows stay flagged as
        # attacks but are reported without a specific type.
        confident_preds = mc_confidence >= multiclass_threshold
        attack_type[attack_mask] = np.where(
            confident_preds,
            le.inverse_transform(mc_pred),
            'UNKNOWN_ATTACK'
        )
        attack_type_proba[attack_mask] = mc_confidence

    # Create results DataFrame
    results = pd.DataFrame({
        'is_attack': is_attack,
        'attack_confidence': attack_proba,
        'attack_type': attack_type,
        'type_confidence': attack_type_proba,
    })

    return results


# Test inference
# Smoke-test the saved artifacts on the first ten rows of the dataset. The
# thresholds come from CONFIG so this check follows any tuning done there
# instead of a hard-coded 0.7. The previously assigned `sample` frame was
# never passed to predict_hybrid and has been dropped.
print("Testing inference function...")
results = predict_hybrid(
    df.iloc[:10],
    binary_threshold=CONFIG['binary_threshold'],
    multiclass_threshold=CONFIG['multiclass_threshold'],
)
print(results)

## Summary

### Binary Classification Results
- **Model**: Calibrated soft-voting ensemble (RF + GB, plus XGBoost and LightGBM when they are installed)
- **Threshold**: Adjustable (default 0.7 for low false positives)

### Multiclass Classification Results  
- **Model**: XGBoost (or RF fallback)
- **Classes**: Attack types from CICIDS2017

### Robustness Features
- Domain-invariant features (ratios, log transforms)
- Calibrated confidence scores
- Adjustable thresholds per environment
- Adversarial validation for drift detection