# CICIDS - Shallow Models (DT)

The CICIDS2017 dataset is a comprehensive dataset for network intrusion detection, created by the Canadian Institute for Cybersecurity. It includes a diverse set of attack scenarios and normal traffic, making it suitable for training and evaluating intrusion detection systems.

The dataset covers attack types including Brute Force (FTP-Patator and SSH-Patator), Heartbleed, Botnet, DoS (Denial of Service), DDoS (Distributed Denial of Service), PortScan, Web attacks (Brute Force, XSS and SQL Injection), and Infiltration of the network from inside.

In [1]:
model_name = "decision_trees_(DT)"

In [2]:
import warnings
from sklearn.exceptions import UndefinedMetricWarning
warnings.filterwarnings("ignore", category=UndefinedMetricWarning)

## Step 1. Read data and import necessary libraries

In [3]:
import pandas as pd
df_train = pd.read_csv("../data/concatenated/concat.csv")

In [4]:
df_train.head(5)

Unnamed: 0,Destination Port,Flow Duration,Total Fwd Packets,Total Backward Packets,Total Length of Fwd Packets,Total Length of Bwd Packets,Fwd Packet Length Max,Fwd Packet Length Min,Fwd Packet Length Mean,Fwd Packet Length Std,Bwd Packet Length Max,Bwd Packet Length Min,Bwd Packet Length Mean,Bwd Packet Length Std,Flow Bytes/s,Flow Packets/s,Flow IAT Mean,Flow IAT Std,Flow IAT Max,Flow IAT Min,Fwd IAT Total,Fwd IAT Mean,Fwd IAT Std,Fwd IAT Max,Fwd IAT Min,Bwd IAT Total,Bwd IAT Mean,Bwd IAT Std,Bwd IAT Max,Bwd IAT Min,Fwd PSH Flags,Bwd PSH Flags,Fwd URG Flags,Bwd URG Flags,Fwd Header Length,Bwd Header Length,Fwd Packets/s,Bwd Packets/s,Min Packet Length,Max Packet Length,Packet Length Mean,Packet Length Std,Packet Length Variance,FIN Flag Count,SYN Flag Count,RST Flag Count,PSH Flag Count,ACK Flag Count,URG Flag Count,CWE Flag Count,ECE Flag Count,Down/Up Ratio,Average Packet Size,Avg Fwd Segment Size,Avg Bwd Segment Size,Fwd Header Length.1,Fwd Avg Bytes/Bulk,Fwd Avg Packets/Bulk,Fwd Avg Bulk Rate,Bwd Avg Bytes/Bulk,Bwd Avg Packets/Bulk,Bwd Avg Bulk Rate,Subflow Fwd Packets,Subflow Fwd Bytes,Subflow Bwd Packets,Subflow Bwd Bytes,Init_Win_bytes_forward,Init_Win_bytes_backward,act_data_pkt_fwd,min_seg_size_forward,Active Mean,Active Std,Active Max,Active Min,Idle Mean,Idle Std,Idle Max,Idle Min,Label
0,49188,4,2,0,12,0,6,6,6.0,0.0,0,0,0.0,0.0,3000000.0,500000.0,4.0,0.0,4,4,4,4.0,0.0,4,4,0,0.0,0.0,0,0,0,0,0,0,40,0,500000.0,0.0,6,6,6.0,0.0,0.0,0,0,0,0,1,1,0,0,0,9.0,6.0,0.0,40,0,0,0,0,0,0,2,12,0,0,329,-1,1,20,0.0,0.0,0,0,0.0,0.0,0,0,BENIGN
1,49188,1,2,0,12,0,6,6,6.0,0.0,0,0,0.0,0.0,12000000.0,2000000.0,1.0,0.0,1,1,1,1.0,0.0,1,1,0,0.0,0.0,0,0,0,0,0,0,40,0,2000000.0,0.0,6,6,6.0,0.0,0.0,0,0,0,0,1,1,0,0,0,9.0,6.0,0.0,40,0,0,0,0,0,0,2,12,0,0,329,-1,1,20,0.0,0.0,0,0,0.0,0.0,0,0,BENIGN
2,49188,1,2,0,12,0,6,6,6.0,0.0,0,0,0.0,0.0,12000000.0,2000000.0,1.0,0.0,1,1,1,1.0,0.0,1,1,0,0.0,0.0,0,0,0,0,0,0,40,0,2000000.0,0.0,6,6,6.0,0.0,0.0,0,0,0,0,1,1,0,0,0,9.0,6.0,0.0,40,0,0,0,0,0,0,2,12,0,0,329,-1,1,20,0.0,0.0,0,0,0.0,0.0,0,0,BENIGN
3,49188,1,2,0,12,0,6,6,6.0,0.0,0,0,0.0,0.0,12000000.0,2000000.0,1.0,0.0,1,1,1,1.0,0.0,1,1,0,0.0,0.0,0,0,0,0,0,0,40,0,2000000.0,0.0,6,6,6.0,0.0,0.0,0,0,0,0,1,1,0,0,0,9.0,6.0,0.0,40,0,0,0,0,0,0,2,12,0,0,329,-1,1,20,0.0,0.0,0,0,0.0,0.0,0,0,BENIGN
4,49486,3,2,0,12,0,6,6,6.0,0.0,0,0,0.0,0.0,4000000.0,666666.7,3.0,0.0,3,3,3,3.0,0.0,3,3,0,0.0,0.0,0,0,0,0,0,0,40,0,666666.7,0.0,6,6,6.0,0.0,0.0,0,0,0,0,1,1,0,0,0,9.0,6.0,0.0,40,0,0,0,0,0,0,2,12,0,0,245,-1,1,20,0.0,0.0,0,0,0.0,0.0,0,0,BENIGN


In [5]:
df_train.shape

(2830743, 79)

In [6]:
df_train.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2830743 entries, 0 to 2830742
Data columns (total 79 columns):
 #   Column                        Dtype  
---  ------                        -----  
 0    Destination Port             int64  
 1    Flow Duration                int64  
 2    Total Fwd Packets            int64  
 3    Total Backward Packets       int64  
 4   Total Length of Fwd Packets   int64  
 5    Total Length of Bwd Packets  int64  
 6    Fwd Packet Length Max        int64  
 7    Fwd Packet Length Min        int64  
 8    Fwd Packet Length Mean       float64
 9    Fwd Packet Length Std        float64
 10  Bwd Packet Length Max         int64  
 11   Bwd Packet Length Min        int64  
 12   Bwd Packet Length Mean       float64
 13   Bwd Packet Length Std        float64
 14  Flow Bytes/s                  float64
 15   Flow Packets/s               float64
 16   Flow IAT Mean                float64
 17   Flow IAT Std                 float64
 18   Flow IAT Max         

In [7]:
df_train.describe()



Unnamed: 0,Destination Port,Flow Duration,Total Fwd Packets,Total Backward Packets,Total Length of Fwd Packets,Total Length of Bwd Packets,Fwd Packet Length Max,Fwd Packet Length Min,Fwd Packet Length Mean,Fwd Packet Length Std,Bwd Packet Length Max,Bwd Packet Length Min,Bwd Packet Length Mean,Bwd Packet Length Std,Flow Bytes/s,Flow Packets/s,Flow IAT Mean,Flow IAT Std,Flow IAT Max,Flow IAT Min,Fwd IAT Total,Fwd IAT Mean,Fwd IAT Std,Fwd IAT Max,Fwd IAT Min,Bwd IAT Total,Bwd IAT Mean,Bwd IAT Std,Bwd IAT Max,Bwd IAT Min,Fwd PSH Flags,Bwd PSH Flags,Fwd URG Flags,Bwd URG Flags,Fwd Header Length,Bwd Header Length,Fwd Packets/s,Bwd Packets/s,Min Packet Length,Max Packet Length,Packet Length Mean,Packet Length Std,Packet Length Variance,FIN Flag Count,SYN Flag Count,RST Flag Count,PSH Flag Count,ACK Flag Count,URG Flag Count,CWE Flag Count,ECE Flag Count,Down/Up Ratio,Average Packet Size,Avg Fwd Segment Size,Avg Bwd Segment Size,Fwd Header Length.1,Fwd Avg Bytes/Bulk,Fwd Avg Packets/Bulk,Fwd Avg Bulk Rate,Bwd Avg Bytes/Bulk,Bwd Avg Packets/Bulk,Bwd Avg Bulk Rate,Subflow Fwd Packets,Subflow Fwd Bytes,Subflow Bwd Packets,Subflow Bwd Bytes,Init_Win_bytes_forward,Init_Win_bytes_backward,act_data_pkt_fwd,min_seg_size_forward,Active Mean,Active Std,Active Max,Active Min,Idle Mean,Idle Std,Idle Max,Idle Min
count,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2829385.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0,2830743.0
mean,8071.483,14785660.0,9.36116,10.39377,549.3024,16162.64,207.5999,18.71366,58.20194,68.91013,870.8495,41.04958,305.9493,335.3257,inf,inf,1298449.0,2919271.0,9182475.0,162379.6,14482960.0,2610193.0,3266957.0,9042939.0,1021893.0,9893830.0,1805784.0,1485973.0,4684692.0,967261.4,0.04644646,0.0,0.0001112782,0.0,-25997.39,-2273.275,63865.35,6995.192,16.4345,950.4024,171.9444,294.9756,486154.8,0.03537976,0.04644646,0.0002423392,0.2980705,0.3158443,0.09482316,0.0001112782,0.000243399,0.6835004,191.9837,58.20194,305.9493,-25997.39,0.0,0.0,0.0,0.0,0.0,0.0,9.36116,549.2919,10.39377,16162.3,6989.837,1989.433,5.418218,-2741.688,81551.32,41134.12,153182.5,58295.82,8316037.0,503843.9,8695752.0,7920031.0
std,18283.63,33653740.0,749.6728,997.3883,9993.589,2263088.0,717.1848,60.33935,186.0912,281.1871,1946.367,68.8626,605.2568,839.6932,,,4507944.0,8045870.0,24459540.0,2950282.0,33575810.0,9525722.0,9639055.0,24529160.0,8591436.0,28736610.0,8887197.0,6278469.0,17160950.0,8308983.0,0.21045,0.0,0.01054826,0.0,21052860.0,1452209.0,247537.1,38151.7,25.23772,2028.229,305.4915,631.8001,1647490.0,0.1847378,0.21045,0.01556536,0.4574107,0.4648513,0.2929706,0.01054826,0.01559935,0.680492,331.8603,186.0912,605.2568,21052860.0,0.0,0.0,0.0,0.0,0.0,0.0,749.6728,9980.07,997.3883,2263057.0,14338.73,8456.883,636.4257,1084989.0,648599.9,393381.5,1025825.0,577092.3,23630080.0,4602984.0,24366890.0,23363420.0
min,0.0,-13.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-261000000.0,-2000000.0,-13.0,0.0,-13.0,-14.0,0.0,0.0,0.0,0.0,-12.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-32212230000.0,-1073741000.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-32212230000.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,-1.0,-1.0,0.0,-536870700.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,53.0,155.0,2.0,1.0,12.0,0.0,6.0,0.0,6.0,0.0,0.0,0.0,0.0,0.0,119.3197,3.446226,63.66667,0.0,123.0,3.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,40.0,20.0,1.749446,0.1229197,0.0,6.0,6.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,7.5,6.0,0.0,40.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0,12.0,1.0,0.0,-1.0,-1.0,0.0,20.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
50%,80.0,31316.0,2.0,2.0,62.0,123.0,37.0,2.0,34.0,0.0,79.0,0.0,72.0,0.0,4595.549,110.6684,11438.84,137.1787,30865.0,4.0,43.0,26.0,0.0,43.0,3.0,3.0,3.0,0.0,3.0,1.0,0.0,0.0,0.0,0.0,64.0,40.0,61.32524,19.82789,2.0,87.0,57.2,25.98076,675.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,72.25,34.0,72.0,64.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0,62.0,2.0,123.0,251.0,-1.0,1.0,24.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
75%,443.0,3204828.0,5.0,4.0,187.0,482.0,81.0,36.0,50.0,26.16295,280.0,77.0,181.0,77.94054,166666.7,23255.81,337426.6,691266.3,2440145.0,64.0,1242844.0,206306.4,65989.82,931006.0,48.0,98580.5,18248.57,15724.09,60210.0,45.0,0.0,0.0,0.0,0.0,120.0,104.0,12048.19,7352.941,36.0,525.0,119.8,174.3239,30388.84,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,1.0,149.2639,50.0,181.0,120.0,0.0,0.0,0.0,0.0,0.0,0.0,5.0,187.0,4.0,482.0,8192.0,235.0,2.0,32.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
max,65535.0,120000000.0,219759.0,291922.0,12900000.0,655453000.0,24820.0,2325.0,5940.857,7125.597,19530.0,2896.0,5800.5,8194.66,inf,inf,120000000.0,84800260.0,120000000.0,120000000.0,120000000.0,120000000.0,84602930.0,120000000.0,120000000.0,120000000.0,120000000.0,84418010.0,120000000.0,120000000.0,1.0,0.0,1.0,0.0,4644908.0,5838440.0,3000000.0,2000000.0,1448.0,24820.0,3337.143,4731.522,22400000.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,156.0,3893.333,5940.857,5800.5,4644908.0,0.0,0.0,0.0,0.0,0.0,0.0,219759.0,12870340.0,291922.0,655453000.0,65535.0,65535.0,213557.0,138.0,110000000.0,74200000.0,110000000.0,110000000.0,120000000.0,76900000.0,120000000.0,120000000.0


In [8]:
# Remove leading and trailing whitespace from column names (e.g. " Destination Port" -> "Destination Port")
df_train.columns = df_train.columns.str.strip()

In [9]:
df_train.columns

Index(['Destination Port', 'Flow Duration', 'Total Fwd Packets',
       'Total Backward Packets', 'Total Length of Fwd Packets',
       'Total Length of Bwd Packets', 'Fwd Packet Length Max',
       'Fwd Packet Length Min', 'Fwd Packet Length Mean',
       'Fwd Packet Length Std', 'Bwd Packet Length Max',
       'Bwd Packet Length Min', 'Bwd Packet Length Mean',
       'Bwd Packet Length Std', 'Flow Bytes/s', 'Flow Packets/s',
       'Flow IAT Mean', 'Flow IAT Std', 'Flow IAT Max', 'Flow IAT Min',
       'Fwd IAT Total', 'Fwd IAT Mean', 'Fwd IAT Std', 'Fwd IAT Max',
       'Fwd IAT Min', 'Bwd IAT Total', 'Bwd IAT Mean', 'Bwd IAT Std',
       'Bwd IAT Max', 'Bwd IAT Min', 'Fwd PSH Flags', 'Bwd PSH Flags',
       'Fwd URG Flags', 'Bwd URG Flags', 'Fwd Header Length',
       'Bwd Header Length', 'Fwd Packets/s', 'Bwd Packets/s',
       'Min Packet Length', 'Max Packet Length', 'Packet Length Mean',
       'Packet Length Std', 'Packet Length Variance', 'FIN Flag Count',
       'SYN Flag Co

## Step 2. Data Cleaning

### A. Missing values

In [10]:
print(df_train.isna().sum().sum())

1358


In [11]:
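# All 1358 missing values are in "Flow Bytes/s" (describe() counts 2829385 non-null values there vs. 2830743 elsewhere)
df_train.dropna(subset=["Flow Bytes/s"], inplace=True)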

In [12]:
print(df_train.isna().sum().sum())

0


### B. Infinite values

In [13]:
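import numpy as np

# "Flow Bytes/s" and "Flow Packets/s" contain +/-inf (see the describe() mean and max above).
# Treat them as missing and drop those rows; the label counts below imply 1509 rows are removed here.
df_train = df_train.replace([np.inf, -np.inf], np.nan).dropna()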
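For numeric columns, `replace([np.inf, -np.inf], np.nan).dropna()` removes every row that has a NaN or an infinite value in any column, not only in the two rate columns. Below is a minimal plain-Python sketch of that row-wise rule. It runs over hypothetical row dicts that stand in for DataFrame rows.

```python
import math


def drop_non_finite_rows(rows):
    """Keep only the rows whose numeric values are all finite.

    Sketch of df.replace([np.inf, -np.inf], np.nan).dropna() restricted to
    numeric values: one NaN or +/-inf anywhere in a row drops the whole row.
    Non-numeric values such as the Label string are not checked here.
    """
    kept = []
    for row in rows:
        numeric = [v for v in row.values() if isinstance(v, (int, float))]
        if all(math.isfinite(v) for v in numeric):
            kept.append(row)
    return kept


# Hypothetical rows using two of the notebook's column names
rows = [
    {"Flow Bytes/s": 3000000.0, "Flow Packets/s": 500000.0, "Label": "BENIGN"},
    {"Flow Bytes/s": math.inf, "Flow Packets/s": 2000000.0, "Label": "BENIGN"},
    {"Flow Bytes/s": math.nan, "Flow Packets/s": 10.0, "Label": "DDoS"},
]
clean = drop_non_finite_rows(rows)
print(len(clean))  # 1
```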

## Step 3. Data Preparation

### A. Normalise numeric features

In [14]:
# Get all numerical columns
numerical_columns = df_train.select_dtypes(include="number").columns

In [15]:
from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler()
df_train[numerical_columns] = scaler.fit_transform(df_train[numerical_columns])
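`MinMaxScaler` rescales each column to x' = (x - min) / (max - min), using the minimum and maximum it sees during `fit`. This cell fits on `df_train` before the train/test split in Step 3C, so rows that later land in the test set also contribute to those minima and maxima. The leakage-free variant fits on the training split only and reuses those parameters for the test split. In scikit-learn that would be `scaler.fit_transform(X_train)` followed by `scaler.transform(X_test)`, both after the split.

The plain-Python sketch below shows the same idea on hypothetical column data:

```python
def fit_min_max(columns):
    """Learn (min, range) per column from training data only.

    A zero range (constant column) is replaced by 1.0, as sklearn's
    MinMaxScaler does, so constant columns map to x - min instead of
    dividing by zero.
    """
    params = []
    for values in columns:
        lo, hi = min(values), max(values)
        params.append((lo, (hi - lo) or 1.0))
    return params


def transform_min_max(columns, params):
    """Apply x' = (x - min) / range using the training-set parameters.

    Test values outside the training range fall outside [0, 1]
    (no clipping, matching MinMaxScaler's default clip=False).
    """
    return [[(x - lo) / span for x in values] for values, (lo, span) in zip(columns, params)]


# Hypothetical column-major data: column 0 varies, column 1 is constant
train_cols = [[0.0, 5.0, 10.0], [3.0, 3.0, 3.0]]
test_cols = [[20.0], [3.0]]
params = fit_min_max(train_cols)
print(transform_min_max(train_cols, params))  # [[0.0, 0.5, 1.0], [0.0, 0.0, 0.0]]
print(transform_min_max(test_cols, params))   # [[2.0], [0.0]]
```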

### B. Map Labels to Multi-class

In [16]:
df_train["Label"].value_counts()

Label
BENIGN                        2271320
DoS Hulk                       230124
PortScan                       158804
DDoS                           128025
DoS GoldenEye                   10293
FTP-Patator                      7935
SSH-Patator                      5897
DoS slowloris                    5796
DoS Slowhttptest                 5499
Bot                              1956
Web Attack � Brute Force         1507
Web Attack � XSS                  652
Infiltration                       36
Web Attack � Sql Injection         21
Heartbleed                         11
Name: count, dtype: int64

In [17]:
attack_mapping = {
	"BENIGN": 0,
	"DoS Hulk": 1,
	"PortScan": 2,
	"DDoS": 3,
	"DoS GoldenEye": 4,
	"FTP-Patator": 5,
	"SSH-Patator": 6,
	"DoS slowloris": 7,
	"DoS Slowhttptest": 8,
	"Bot": 9,
	"Web Attack � Brute Force": 10,
	"Web Attack � XSS": 11,
	"Infiltration": 12,
	"Web Attack � Sql Injection": 13,
	"Heartbleed": 14,
}

df_train["Label"] = df_train["Label"].map(attack_mapping)
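`Series.map(dict)` returns NaN for any label that has no key in the dictionary. This matters for the web-attack labels: their separator character must match the data exactly, or those rows silently lose their class. Here the counts at In [18] match those at In [16], which confirms that nothing was dropped. In pandas, the same check is `df_train["Label"].isna().sum()` right after mapping.

Below is a plain-Python sketch of the check, using a hypothetical mapping and hypothetical label strings:

```python
def unmapped_labels(labels, mapping):
    """Return the distinct labels with no key in `mapping`.

    pandas' Series.map(dict) turns exactly these labels into NaN, so a
    non-empty result means some rows would lose their class.
    """
    return sorted(set(labels) - set(mapping))


# Hypothetical example: the mapping key uses U+FFFD as the separator,
# but one label string was decoded with an en dash (U+2013) instead.
mapping = {"BENIGN": 0, "Web Attack \ufffd XSS": 11}
labels = ["BENIGN", "Web Attack \ufffd XSS", "Web Attack \u2013 XSS"]
missing = unmapped_labels(labels, mapping)
print(missing)  # only the en-dash variant is reported
```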

In [18]:
df_train["Label"].value_counts()

Label
0     2271320
1      230124
2      158804
3      128025
4       10293
5        7935
6        5897
7        5796
8        5499
9        1956
10       1507
11        652
12         36
13         21
14         11
Name: count, dtype: int64

### C. Data Splitting

In [19]:
X = df_train.drop(columns="Label")
y = df_train["Label"]

In [20]:
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
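`train_test_split` is called here without `stratify`, so the rarest classes end up in the test set in whatever proportion the random shuffle produces. In this run, 2 of the 11 Heartbleed rows were left for testing (compare the counts in Step 3D). Passing `stratify=y` to `train_test_split` makes scikit-learn preserve each class's proportion in both splits.

The idea can be sketched in plain Python as a per-class split. This is a simplified stand-in, not scikit-learn's exact allocation algorithm:

```python
import random
from collections import defaultdict


def stratified_split_indices(labels, test_size=0.2, seed=42):
    """Split row indices so each class keeps about `test_size` of its rows in the test set.

    Simplified sketch: per class, shuffle its indices and send
    round(test_size * n) of them to the test side.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    train_idx, test_idx = [], []
    for indices in by_class.values():
        rng.shuffle(indices)
        n_test = round(test_size * len(indices))
        test_idx.extend(indices[:n_test])
        train_idx.extend(indices[n_test:])
    return sorted(train_idx), sorted(test_idx)


labels = [0] * 100 + [14] * 11  # hypothetical: 100 BENIGN rows, 11 Heartbleed rows
train_idx, test_idx = stratified_split_indices(labels)
test_counts = {c: sum(1 for i in test_idx if labels[i] == c) for c in (0, 14)}
print(test_counts)  # {0: 20, 14: 2}
```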

### D. Balance the training data (undersampling + SMOTE)

In [21]:
from imblearn.over_sampling import SMOTE
from imblearn.under_sampling import RandomUnderSampler

# 1. Randomly undersample the eleven largest classes (labels 0-10) to 1000 rows each
undersampling_strategy = {label: 1000 for label in range(11)}
rus = RandomUnderSampler(random_state=42, sampling_strategy=undersampling_strategy)
X_train_undersampled, y_train_undersampled = rus.fit_resample(X_train, y_train)

# 2. Oversample the remaining minority classes (labels 11-14) with SMOTE;
#    sampling_strategy="auto" raises each smaller class to the size of the largest one (1000)
smote = SMOTE(random_state=42, sampling_strategy="auto")
X_train_balanced, y_train_balanced = smote.fit_resample(X_train_undersampled, y_train_undersampled)
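SMOTE creates each synthetic sample in three steps:

- Pick a minority-class sample x.
- Choose one of its k nearest neighbours x_nn from the same class (imbalanced-learn's default is `k_neighbors=5`).
- Interpolate part of the way between them: x_new = x + λ(x_nn - x), with λ drawn from [0, 1).

Each class therefore needs more than k samples. The smallest class here, Heartbleed, has 9 rows in the training split.

Below is a plain-Python sketch of that step, using Euclidean distance and hypothetical 2-D points. imbalanced-learn's implementation is vectorised and differs in details.

```python
import math
import random


def smote_samples(points, n_new, k=5, seed=42):
    """Generate n_new synthetic points from one minority class (SMOTE-style).

    For each new point: pick a random sample x, pick one of its k nearest
    neighbours x_nn within the same class, and return x + lam * (x_nn - x)
    with lam drawn uniformly from [0, 1).
    """
    if len(points) <= k:
        raise ValueError("need more than k samples in the class")
    rng = random.Random(seed)
    synthetic = []
    for _ in range(n_new):
        x = rng.choice(points)
        # k nearest neighbours of x; x itself sorts first (distance 0) and is skipped
        neighbours = sorted(points, key=lambda p: math.dist(p, x))[1:k + 1]
        x_nn = rng.choice(neighbours)
        lam = rng.random()
        synthetic.append(tuple(a + lam * (b - a) for a, b in zip(x, x_nn)))
    return synthetic


minority = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0), (0.5, 0.5), (2.0, 2.0)]
new_points = smote_samples(minority, n_new=20, k=3)
print(len(new_points))  # 20
```

Because every new point lies on a segment between two real samples of the same class, SMOTE fills in the region the class already occupies rather than extrapolating beyond it.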

In [22]:
# Check class distribution after SMOTE
from collections import Counter

print(f"Class distribution before SMOTE: {Counter(y_train)}")
print(f"Class distribution after SMOTE: {Counter(y_train_balanced)}")

Class distribution before SMOTE: Counter({0: 1817112, 1: 184342, 2: 126927, 3: 102239, 4: 8219, 5: 6363, 6: 4769, 7: 4630, 8: 4390, 9: 1515, 10: 1206, 11: 533, 12: 29, 13: 17, 14: 9})
Class distribution after SMOTE: Counter({0: 1000, 1: 1000, 2: 1000, 3: 1000, 4: 1000, 5: 1000, 6: 1000, 7: 1000, 8: 1000, 9: 1000, 10: 1000, 11: 1000, 12: 1000, 13: 1000, 14: 1000})


## Step 4. Model

### A. Defining the model

In [31]:
from sklearn.tree import DecisionTreeClassifier
from tqdm import tqdm

In [32]:
model = DecisionTreeClassifier(random_state=42, max_depth=10)
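`DecisionTreeClassifier` grows the tree greedily. At each node it picks the feature and threshold whose split most reduces impurity, measured by default (`criterion="gini"`) with the Gini index G = 1 - Σ p_k², where p_k is the fraction of class k at the node. Here `max_depth=10` stops growth ten levels below the root, so the tree has at most 2^10 = 1024 leaves.

Below is a plain-Python sketch of the split search for a single feature, using hypothetical min-max scaled values. It simplifies what scikit-learn does across all features.

```python
from collections import Counter


def gini(labels):
    """Gini impurity 1 - sum_k p_k^2 of a list of class labels."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def best_threshold(values, labels):
    """Try midpoints between consecutive distinct values of one feature.

    Returns (threshold, weighted child impurity) of the best split
    x <= threshold; threshold is None if no split beats the parent node.
    """
    pairs = list(zip(values, labels))
    distinct = sorted(set(values))
    n = len(pairs)
    best = (None, gini(labels))
    for lo, hi in zip(distinct, distinct[1:]):
        t = (lo + hi) / 2
        left = [y for x, y in pairs if x <= t]
        right = [y for x, y in pairs if x > t]
        impurity = (len(left) * gini(left) + len(right) * gini(right)) / n
        if impurity < best[1]:
            best = (t, impurity)
    return best


values = [0.1, 0.2, 0.3, 0.8, 0.9]  # hypothetical scaled feature values
labels = [0, 0, 0, 1, 1]            # e.g. 0 = BENIGN, 1 = DoS Hulk
threshold, impurity = best_threshold(values, labels)
print(threshold, impurity)  # 0.55 0.0
```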

### B. Training the model

In [33]:
model.fit(X_train_balanced, y_train_balanced)

In [34]:
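# Save the trained model (create the output directory first if it does not exist)
import os
import joblib

os.makedirs("../models/shallow", exist_ok=True)
joblib.dump(model, f"../models/shallow/{model_name}.joblib")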


['../models/shallow/decision_trees_(DT).joblib']

### C. Evaluating the model

In [35]:
from sklearn.metrics import classification_report, accuracy_score, f1_score

# Balance the test set: undersample every class to at most SAMPLE_TARGET rows
# and keep all rows of classes that are already smaller (labels follow attack_mapping)
SAMPLE_TARGET = 1000
test_counts = y_test.value_counts()
undersampling_strategy_test_set = {
    label: min(count, SAMPLE_TARGET) for label, count in test_counts.items()
}

rus_test = RandomUnderSampler(random_state=42, sampling_strategy=undersampling_strategy_test_set)
X_test_balanced, y_test_balanced = rus_test.fit_resample(X_test, y_test)


In [36]:
batch_size = 100  # Adjust based on your dataset size
y_pred = []

for i in tqdm(range(0, len(X_test_balanced), batch_size), desc="Predicting batches"):
    batch = X_test_balanced[i:i + batch_size]
    y_pred.extend(model.predict(batch))

y_pred = np.array(y_pred)

Predicting batches: 100%|██████████| 99/99 [00:00<00:00, 523.80it/s]


In [37]:
print(f"Accuracy: {accuracy_score(y_test_balanced, y_pred)}")
print(f"F1 Score: {f1_score(y_test_balanced, y_pred, average='weighted')}")

Accuracy: 0.952096414826818
F1 Score: 0.9539228941604226
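`average='weighted'` averages the per-class F1 scores weighted by each class's support. The nine classes capped at 1000 test rows therefore dominate the score. Small classes such as the XSS web attack (119 test rows, F1 0.48 in the report below) barely move it. `f1_score(..., average='macro')` gives every class equal weight instead.

Below is a plain-Python sketch of both averages on a hypothetical two-class example. Undefined precision or recall is treated as 0, which is what scikit-learn does by default, with a warning.

```python
from collections import Counter


def per_class_f1(y_true, y_pred):
    """Return {label: (f1, support)} from per-class precision and recall."""
    labels = sorted(set(y_true) | set(y_pred))
    tp = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    pred_count, true_count = Counter(y_pred), Counter(y_true)
    scores = {}
    for label in labels:
        precision = tp[label] / pred_count[label] if pred_count[label] else 0.0
        recall = tp[label] / true_count[label] if true_count[label] else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        scores[label] = (f1, true_count[label])
    return scores


def macro_and_weighted_f1(y_true, y_pred):
    """Macro F1 = plain mean of per-class F1; weighted F1 = support-weighted mean."""
    scores = per_class_f1(y_true, y_pred)
    macro = sum(f1 for f1, _ in scores.values()) / len(scores)
    total = sum(support for _, support in scores.values())
    weighted = sum(f1 * support for f1, support in scores.values()) / total
    return macro, weighted


y_true = [0] * 100 + [1] * 10           # hypothetical: 100 benign rows, 10 attack rows
y_pred = [0] * 100 + [0] * 5 + [1] * 5  # half of the attacks are missed
macro, weighted = macro_and_weighted_f1(y_true, y_pred)
print(f"macro={macro:.3f} weighted={weighted:.3f}")  # macro=0.821 weighted=0.948
```

Reporting both averages makes weak performance on the rare web-attack classes visible instead of hiding it behind the large classes.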


In [38]:
print(classification_report(y_test_balanced, y_pred, target_names=attack_mapping.keys()))

                            precision    recall  f1-score   support

                    BENIGN       0.98      0.90      0.94      1000
                  DoS Hulk       0.97      0.96      0.97      1000
                  PortScan       0.99      0.98      0.99      1000
                      DDoS       1.00      1.00      1.00      1000
             DoS GoldenEye       0.95      0.98      0.97      1000
               FTP-Patator       1.00      1.00      1.00      1000
               SSH-Patator       1.00      1.00      1.00      1000
             DoS slowloris       0.93      0.99      0.96      1000
          DoS Slowhttptest       0.99      0.91      0.95      1000
                       Bot       0.95      1.00      0.97       441
  Web Attack � Brute Force       0.69      0.42      0.52       301
          Web Attack � XSS       0.34      0.81      0.48       119
              Infiltration       1.00      1.00      1.00         7
Web Attack � Sql Injection       0.15      0.75