In [13]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found under the Kaggle input directory.
# The unused directory-list slot is named `_subdirs` instead of `_`: rebinding `_`
# at notebook top level shadows IPython's "last output" variable for the rest of
# the session.
for dirname, _subdirs, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/nsl-kdd-augmented/smote_augmented.csv
/kaggle/input/nslkdd/KDDTest+.arff
/kaggle/input/nslkdd/KDDTest-21.arff
/kaggle/input/nslkdd/KDDTest1.jpg
/kaggle/input/nslkdd/KDDTrain+.txt
/kaggle/input/nslkdd/KDDTrain+_20Percent.txt
/kaggle/input/nslkdd/KDDTest-21.txt
/kaggle/input/nslkdd/KDDTest+.txt
/kaggle/input/nslkdd/KDDTrain+.arff
/kaggle/input/nslkdd/index.html
/kaggle/input/nslkdd/KDDTrain+_20Percent.arff
/kaggle/input/nslkdd/KDDTrain1.jpg
/kaggle/input/nslkdd/nsl-kdd/KDDTest+.arff
/kaggle/input/nslkdd/nsl-kdd/KDDTest-21.arff
/kaggle/input/nslkdd/nsl-kdd/KDDTest1.jpg
/kaggle/input/nslkdd/nsl-kdd/KDDTrain+.txt
/kaggle/input/nslkdd/nsl-kdd/KDDTrain+_20Percent.txt
/kaggle/input/nslkdd/nsl-kdd/KDDTest-21.txt
/kaggle/input/nslkdd/nsl-kdd/KDDTest+.txt
/kaggle/input/nslkdd/nsl-kdd/KDDTrain+.arff
/kaggle/input/nslkdd/nsl-kdd/index.html
/kaggle/input/nslkdd/nsl-kdd/KDDTrain+_20Percent.arff
/kaggle/input/nslkdd/nsl-kdd/KDDTrain1.jpg


In [64]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from xgboost import XGBClassifier
from sklearn.metrics import classification_report
import numpy as np

def surgical_q1_fusion(X_proc, df_orig, *, model=None, xgb_expert=None,
                       label_encoder=None, torch_device=None,
                       sharpen=2.0, xgb_attack_conf=0.85, xgb_normal_conf=0.95):
    """Fuse the neural model and the XGBoost expert with a three-tier rule.

    Each row is resolved by the first tier that applies:

    1. Accuracy anchor: if the XGBoost top class is not 'normal' and its
       probability exceeds ``xgb_attack_conf``, that class is the prediction.
       (This applies to any non-normal class, not only Probe/DoS.)
    2. Semantic hard-lock: if ``num_failed_logins`` or ``hot`` is positive, the
       prediction is the R2L class the neural model scores highest; otherwise,
       if ``root_shell`` or ``num_shells`` is positive, the same is done within
       the U2R classes.
    3. Stability gate: 'normal' when XGBoost gives it more than
       ``xgb_normal_conf``; otherwise the argmax of the equal-weight average of
       both probability vectors.

    NOTE(review): assumes both experts emit probability columns in the label
    encoder's class order, and that ``df_orig`` rows are positionally aligned
    with ``X_proc`` - confirm against the training / preprocessing cells.

    Args:
        X_proc: Feature matrix passed to ``torch.tensor`` and ``predict_proba``.
        df_orig: DataFrame holding the columns ``num_failed_logins``, ``hot``,
            ``root_shell`` and ``num_shells``; one row per row of ``X_proc``.
        model: Neural model; defaults to the notebook global ``model_sp``.
        xgb_expert: Object with ``predict_proba``; defaults to global ``expert``.
        label_encoder: Object with ``transform``; defaults to global ``le``.
        torch_device: Device for the input tensor; defaults to global ``device``.
        sharpen: Multiplier applied to the logits before the softmax.
        xgb_attack_conf: Tier-1 probability threshold.
        xgb_normal_conf: Tier-3 probability threshold for 'normal'.

    Returns:
        1-D int64 NumPy array of encoded class indices, one per input row.

    Raises:
        ValueError: If ``X_proc`` and ``df_orig`` have different lengths.
    """
    # Fall back to the notebook-level objects only when no override is given,
    # so existing two-argument calls behave as before.
    model = model_sp if model is None else model
    xgb_expert = expert if xgb_expert is None else xgb_expert
    label_encoder = le if label_encoder is None else label_encoder
    torch_device = device if torch_device is None else torch_device

    n_rows = len(X_proc)
    if len(df_orig) != n_rows:
        raise ValueError(
            f"X_proc has {n_rows} rows but df_orig has {len(df_orig)}; "
            "they must be aligned row-for-row"
        )

    # Neural expert: the model may return a tuple whose first item is the logits.
    model.eval()
    with torch.no_grad():
        output = model(torch.tensor(X_proc, dtype=torch.float32).to(torch_device))
        logits = output[0] if isinstance(output, tuple) else output
        # Scaling the logits sharpens the softmax; this only matters for the
        # tier-3 blend, since scaling does not change the argmax.
        probs_sp = torch.softmax(logits * sharpen, dim=1).cpu().numpy()

    probs_xgb = np.asarray(xgb_expert.predict_proba(X_proc))

    # Encoded indices of the classes the tiers refer to.
    idx_normal = label_encoder.transform(['normal'])[0]
    r2l_classes = ['guess_passwd', 'warezmaster', 'ftp_write', 'imap', 'multihop', 'phf']
    u2r_classes = ['rootkit', 'buffer_overflow', 'loadmodule', 'perl']
    r2l_indices = np.asarray(label_encoder.transform(r2l_classes))
    u2r_indices = np.asarray(label_encoder.transform(u2r_classes))

    def _positive(column):
        # Boolean mask, one entry per row, taken positionally (index-agnostic).
        return (df_orig[column] > 0).to_numpy(dtype=bool)

    # Per-row security flags, computed once instead of via .iloc inside a loop.
    r2l_flag = _positive('num_failed_logins') | _positive('hot')
    u2r_flag = _positive('root_shell') | _positive('num_shells')

    # --- TIER 1: confident non-normal XGBoost prediction ---
    xgb_top = probs_xgb.argmax(axis=1)
    xgb_top_prob = probs_xgb[np.arange(n_rows), xgb_top]
    tier1 = (xgb_top_prob > xgb_attack_conf) & (xgb_top != idx_normal)

    # --- TIER 2: best class restricted to the R2L / U2R subsets ---
    r2l_choice = r2l_indices[probs_sp[:, r2l_indices].argmax(axis=1)]
    u2r_choice = u2r_indices[probs_sp[:, u2r_indices].argmax(axis=1)]

    # --- TIER 3: confident 'normal', else equal-weight blend of both experts ---
    blended = (0.5 * probs_xgb + 0.5 * probs_sp).argmax(axis=1)
    tier3 = np.where(probs_xgb[:, idx_normal] > xgb_normal_conf, idx_normal, blended)

    # np.select takes the first matching condition, which reproduces the
    # tier priority: tier 1, then the R2L flag, then the U2R flag, then tier 3.
    final_preds = np.select(
        [tier1, r2l_flag, u2r_flag],
        [xgb_top, r2l_choice, u2r_choice],
        default=tier3,
    )
    return np.asarray(final_preds, dtype=np.int64)

# ===========================================
# Execution
# ===========================================
# Run the tiered fusion on the test split.
print("Executing Surgical Q1 Fusion...")
final_preds = surgical_q1_fusion(X_test_proc, df_test)

# Restrict the report to classes present in the ground truth or the predictions,
# so the label ids and their display names stay aligned.
unique_labels = np.union1d(y_test_enc, final_preds)
target_names = [le.classes_[label_id] for label_id in unique_labels]

print("\n--- CS-HFL Q1 RESULTS ---")
report_text = classification_report(
    y_test_enc,
    final_preds,
    labels=unique_labels,
    target_names=target_names,
    zero_division=0,
)
print(report_text)

Executing Surgical Q1 Fusion...

--- CS-HFL Q1 RESULTS ---
                 precision    recall  f1-score   support

           back       1.00      0.94      0.97       359
buffer_overflow       0.00      0.00      0.00        20
      ftp_write       0.03      0.33      0.06         3
   guess_passwd       0.72      0.38      0.50      1231
           imap       0.00      0.00      0.00         1
        ipsweep       0.99      0.98      0.98       141
           land       0.00      0.00      0.00         7
     loadmodule       0.00      0.00      0.00         2
       multihop       0.00      0.00      0.00        18
        neptune       1.00      0.99      1.00      4657
           nmap       1.00      1.00      1.00        73
         normal       0.89      0.97      0.93      9711
           perl       0.00      0.00      0.00         2
            phf       1.00      0.50      0.67         2
            pod       0.71      0.90      0.80        41
      portsweep       0.79  