In [13]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the read-only Kaggle input tree and print the full path of every file found.
input_file_paths = (
    os.path.join(folder, entry)
    for folder, _subdirs, entries in os.walk('/kaggle/input')
    for entry in entries
)
for file_path in input_file_paths:
    print(file_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/nsl-kdd-augmented/smote_augmented.csv
/kaggle/input/nslkdd/KDDTest+.arff
/kaggle/input/nslkdd/KDDTest-21.arff
/kaggle/input/nslkdd/KDDTest1.jpg
/kaggle/input/nslkdd/KDDTrain+.txt
/kaggle/input/nslkdd/KDDTrain+_20Percent.txt
/kaggle/input/nslkdd/KDDTest-21.txt
/kaggle/input/nslkdd/KDDTest+.txt
/kaggle/input/nslkdd/KDDTrain+.arff
/kaggle/input/nslkdd/index.html
/kaggle/input/nslkdd/KDDTrain+_20Percent.arff
/kaggle/input/nslkdd/KDDTrain1.jpg
/kaggle/input/nslkdd/nsl-kdd/KDDTest+.arff
/kaggle/input/nslkdd/nsl-kdd/KDDTest-21.arff
/kaggle/input/nslkdd/nsl-kdd/KDDTest1.jpg
/kaggle/input/nslkdd/nsl-kdd/KDDTrain+.txt
/kaggle/input/nslkdd/nsl-kdd/KDDTrain+_20Percent.txt
/kaggle/input/nslkdd/nsl-kdd/KDDTest-21.txt
/kaggle/input/nslkdd/nsl-kdd/KDDTest+.txt
/kaggle/input/nslkdd/nsl-kdd/KDDTrain+.arff
/kaggle/input/nslkdd/nsl-kdd/index.html
/kaggle/input/nslkdd/nsl-kdd/KDDTrain+_20Percent.arff
/kaggle/input/nslkdd/nsl-kdd/KDDTrain1.jpg


In [61]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from xgboost import XGBClassifier
from sklearn.metrics import classification_report
import numpy as np

# ===========================================
# 1️⃣ The Micro-Specialist (The "Small Model")
# ===========================================
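# The micro-specialist is a neural model specialized in R2L/U2R patterns.
# The best-performing trained specialist instance is used here as `model_sp`; it must
# already be defined (together with `expert`, `le` and `device`) before this cell runs.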
def hea_fusion_inference(X_proc, df_orig):
    """Fuse the neural specialist and the XGBoost expert into one label per row.

    Every row is decided by the first rule that applies:

    1. Protocol anchor: keep the XGBoost arg-max when either the 'back'
       probability is above 0.3 and ``src_bytes`` is above 5000, or the arg-max
       class is one of neptune/smurf/satan/ipsweep with probability above 0.8.
    2. Stateful sieve: when any of ``hot``, ``num_failed_logins``,
       ``is_guest_login`` or ``num_compromised`` is positive, take the
       specialist's arg-max after zeroing 'normal' and the DoS classes.
    3. Residual: predict 'normal' when XGBoost gives it more than 0.95,
       otherwise the arg-max of ``0.6 * xgb + 0.4 * specialist``.

    The function reads four names that are not parameters and must already be
    defined in the notebook: ``model_sp`` (called on a float32 tensor, returns a
    ``(logits, extra)`` pair), ``expert`` (exposes ``predict_proba``), ``le``
    (exposes ``transform`` and ``classes_``) and ``device``.
    NOTE(review): assumes both models emit their probability columns in
    ``le.classes_`` order -- confirm against the training cells.

    Args:
        X_proc: Pre-processed feature matrix, passed to both models.
        df_orig: DataFrame of the original (unscaled) features, positionally
            aligned with ``X_proc``. Must provide the columns ``src_bytes``,
            ``hot``, ``num_failed_logins``, ``is_guest_login`` and
            ``num_compromised``.

    Returns:
        1-D integer ``np.ndarray`` of encoded class indices, one per row.

    Raises:
        ValueError: If ``df_orig`` and ``X_proc`` have different row counts.
            Rows are matched by position, so a mismatch would otherwise pair
            features with the wrong raw record.
    """
    # Tunable constants of the rule cascade, named here so each has one home.
    logit_scale = 2.0            # sharpens the specialist's softmax
    back_min_prob = 0.3
    back_min_src_bytes = 5000
    anchor_classes = ['neptune', 'smurf', 'satan', 'ipsweep']
    anchor_min_prob = 0.8
    dos_classes = ['neptune', 'back', 'land', 'pod', 'smurf', 'teardrop']
    content_columns = ['hot', 'num_failed_logins', 'is_guest_login', 'num_compromised']
    normal_min_prob = 0.95
    xgb_weight, specialist_weight = 0.6, 0.4

    n_rows = len(X_proc)
    if len(df_orig) != n_rows:
        raise ValueError(
            f"df_orig has {len(df_orig)} rows but X_proc has {n_rows}; "
            "they must be positionally aligned"
        )

    model_sp.eval()
    with torch.no_grad():
        logits, _ = model_sp(torch.tensor(X_proc, dtype=torch.float32).to(device))
        probs_sp = torch.softmax(logits * logit_scale, dim=1).cpu().numpy()

    probs_xgb = np.asarray(expert.predict_proba(X_proc))

    # Class indices are loop-invariant: resolve them once up front.
    class_names = np.asarray(le.classes_)
    idx_normal = le.transform(['normal'])[0]
    idx_back = le.transform(['back'])[0]
    dos_indices = le.transform(dos_classes)

    xgb_top = probs_xgb.argmax(axis=1)
    xgb_conf = probs_xgb.max(axis=1)

    # --- LEVEL 1: protocol anchors -> trust the XGBoost arg-max outright ---
    heavy_back = (probs_xgb[:, idx_back] > back_min_prob) & (
        df_orig['src_bytes'].to_numpy() > back_min_src_bytes
    )
    # Per-class membership table, then looked up by each row's top class.
    top_is_anchor = np.isin(class_names, anchor_classes)[xgb_top]
    keep_xgb = heavy_back | (top_is_anchor & (xgb_conf > anchor_min_prob))

    # --- LEVEL 2: stateful sieve -> specialist with 'normal' and DoS masked ---
    content_flag = np.zeros(n_rows, dtype=bool)
    for column in content_columns:
        content_flag |= df_orig[column].to_numpy() > 0

    probs_sp_masked = probs_sp.copy()
    probs_sp_masked[:, idx_normal] = 0
    probs_sp_masked[:, dos_indices] = 0
    sieve_top = probs_sp_masked.argmax(axis=1)

    # --- LEVEL 3: residual -> confident 'normal', else the weighted blend ---
    blended_top = (xgb_weight * probs_xgb + specialist_weight * probs_sp).argmax(axis=1)
    residual_top = np.where(
        probs_xgb[:, idx_normal] > normal_min_prob, idx_normal, blended_top
    )

    # Earlier levels take precedence over later ones.
    return np.where(keep_xgb, xgb_top, np.where(content_flag, sieve_top, residual_top))

# ===========================================
# 2️⃣ Final Execution
# ===========================================
# Run the fused predictor over the held-out test split.
print("Executing HEA-Net Final Fusion...")
final_preds = hea_fusion_inference(X_test_proc, df_test)

# Report on every class that occurs in the ground truth or in the predictions,
# so the label list and the display names stay the same length and order.
unique_labels = np.union1d(y_test_enc, final_preds)
target_names = [le.classes_[label_idx] for label_idx in unique_labels]

print("\n--- HEA-Net Q1 FINAL RESULTS ---")
print(
    classification_report(
        y_test_enc,
        final_preds,
        labels=unique_labels,
        target_names=target_names,
        zero_division=0,
    )
)

Executing HEA-Net Final Fusion...

--- HEA-Net Q1 FINAL RESULTS ---
                 precision    recall  f1-score   support

           back       0.98      0.95      0.97       359
buffer_overflow       0.22      0.40      0.28        20
      ftp_write       0.04      0.33      0.07         3
   guess_passwd       0.71      0.25      0.37      1231
           imap       0.00      0.00      0.00         1
        ipsweep       0.99      0.98      0.98       141
           land       0.00      0.00      0.00         7
     loadmodule       0.00      0.00      0.00         2
       multihop       0.00      0.00      0.00        18
        neptune       1.00      1.00      1.00      4657
           nmap       0.99      1.00      0.99        73
         normal       0.89      0.97      0.93      9711
           perl       0.00      0.00      0.00         2
            phf       0.00      0.00      0.00         2
            pod       0.72      0.93      0.81        41
      portsweep    