In [None]:
import os
import random

import numpy as np
import pandas as pd
import sklearn.tree as tree
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.inspection import permutation_importance
import joblib

import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Single seed for every random number generator used below (Python's `random` and
# NumPy's legacy global RNG); the split, the tree, and permutation importance read it too.
RANDOM_SEED = 33
for seed_fn in (random.seed, np.random.seed):
    seed_fn(RANDOM_SEED)

In [None]:
# Relative locations: where the fitted model is written, and where the per-day
# labeled CSVs are read from. The model directory is created up front.
MODELS_DIR = 'Models/'
DATA_DIR = 'Data/CICIDS2017_labeled/'

os.makedirs(MODELS_DIR, exist_ok=True)

model_path = MODELS_DIR + 'dtree.joblib'

monday_data_path = DATA_DIR + 'monday_labeled.csv'
tuesday_data_path = DATA_DIR + 'tuesday_labeled.csv'
wednesday_data_path = DATA_DIR + 'wednesday_labeled.csv'
thursday_data_path = DATA_DIR + 'thursday_labeled.csv'
friday_data_path = DATA_DIR + 'friday_labeled.csv'

In [None]:
# Read each day's labeled flows (Monday..Friday, in that order) ...
df_mon, df_tue, df_wed, df_thu, df_fri = (
    pd.read_csv(day_path)
    for day_path in (
        monday_data_path,
        tuesday_data_path,
        wednesday_data_path,
        thursday_data_path,
        friday_data_path,
    )
)

# ... and stack them into one frame; ignore_index gives the result a fresh 0..n-1 index.
df = pd.concat([df_mon, df_tue, df_wed, df_thu, df_fri], ignore_index=True)

In [None]:
# Per-class row counts over the combined week; the train/test split below stratifies on this column.
label_counts = df['Label'].value_counts()
label_counts

In [None]:
"""
excluded_cols = ['id','Flow ID','Src IP','Src Port','Dst IP','Dst Port','Protocol','Timestamp',
                'Fwd URG Flags','Bwd URG Flags','URG Flag Count',
                'Attempted Category','Label']
"""

excluded_cols = ['Flow ID','Src IP','Src Port','Dst IP','Dst Port','Protocol','Timestamp',
                 'Out of order packets','Malformed packets','Direction guessed',
                 'Fwd URG Flags','Bwd URG Flags','URG Flag Count',
                 'TCP Bwd invalid seq',
                 'Label']

X = df.drop(columns=excluded_cols, axis=1)
y = df['Label']

In [None]:
# Hold out a fixed fraction of rows for evaluation; stratifying on y keeps the
# class proportions the same in both splits, and RANDOM_SEED makes the split repeatable.
TEST_FRACTION = 0.2

X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    stratify=y,
    test_size=TEST_FRACTION,
    random_state=RANDOM_SEED,
)

print(f"Training shape: {X_train.shape}, test shape: {X_test.shape}")

In [None]:
MAX_DEPTH = 5  # depth cap for the tree that is printed and plotted further down

# fit() returns the estimator itself, so construction and training can be chained.
dtc = DecisionTreeClassifier(random_state=RANDOM_SEED, max_depth=MAX_DEPTH).fit(X_train, y_train)

# Persist the fitted tree; joblib.dump returns the list of written file names (shown as cell output).
joblib.dump(dtc, model_path)

In [None]:
# Predicted labels for the held-out rows; consumed by the report and confusion matrix below.
y_pred = dtc.predict(X=X_test)

In [None]:
# Fixed display order for the two classes, shared with the confusion matrix below.
labels = ["BENIGN", "ATTACK"]

# `target_names` is matched positionally to `labels`. Without `labels=`, the report uses
# the sorted unique labels (["ATTACK", "BENIGN"]), so the names above would be attached to
# the wrong rows. Passing `labels` pins the row order to match `target_names`.
print(classification_report(y_true=y_test, y_pred=y_pred, labels=labels, target_names=labels))

In [None]:
# Confusion matrix in the fixed `labels` order (rows = actual, columns = predicted).
cm = confusion_matrix(y_true=y_test, y_pred=y_pred, labels=labels)

# Divide each row by its actual-class total so each row sums to 1 (per-class recall on the diagonal).
cmn = cm.astype('float') / cm.sum(axis=1, keepdims=True)

# Cell text: raw count on the first line, row-normalised share on the second.
annotations = np.array(
    [
        [f"{count}\n({share:.2%})" for count, share in zip(count_row, share_row)]
        for count_row, share_row in zip(cm, cmn)
    ],
    dtype=object,
)

fig, ax = plt.subplots(figsize=(10,10))
sns.heatmap(
    cmn,
    annot=annotations,
    fmt='s',
    xticklabels=labels,
    yticklabels=labels,
    cmap='Blues',
    vmin=0,
    vmax=1,
    square=True,
    annot_kws={"size": 25},
    ax=ax,
)
ax.set_ylabel('Actual')
ax.set_xlabel('Predicted')
ax.set_title('Confusion matrix - DecisionTreeClassifier')
plt.show()

In [None]:
TOP_N = 7  # number of features shown in both importance plots

# Impurity-based (MDI) importances of the fitted tree, largest TOP_N as a horizontal bar chart.
fig, ax = plt.subplots(figsize=(10,8))
(
    pd.DataFrame({'importance': dtc.feature_importances_}, index=X.columns)
    .sort_values('importance', ascending=False)
    .head(TOP_N)
    .plot.barh(ax=ax, legend=False)  # single series, so a legend adds nothing
)
# barh draws the first row at the bottom; invert so the most important feature is on top.
ax.invert_yaxis()
ax.set_title(f'Feature importances using MDI (top {TOP_N})', fontsize=18)
# In a horizontal bar chart the importance values run along the x-axis
# (the y-axis carries the feature names), so the value label belongs on x.
ax.set_xlabel('Mean Decrease in Impurity')
fig.tight_layout()
plt.show()

In [None]:
# Permutation importance on the held-out split: each feature column is shuffled n_repeats
# times and the drop in the estimator's default score (accuracy for a classifier) is recorded.
# Seeded with RANDOM_SEED (previously a stray hard-coded 42) so a single constant
# controls the reproducibility of the split, the tree, and these shuffles.
result = permutation_importance(
    dtc, X_test, y_test, n_repeats=10, random_state=RANDOM_SEED, n_jobs=-1
)

In [None]:
# argsort is ascending, so the last TOP_N indices are the features with the largest
# mean permutation importance; in a horizontal box plot the last column is drawn on top.
sorted_importances_idx = np.argsort(result.importances_mean)
top_n_indices = sorted_importances_idx[-TOP_N:]

# One column per selected feature, one row per repeat.
importances = pd.DataFrame(
    data=result.importances[top_n_indices].T,
    columns=X_test.columns[top_n_indices],
)

# whis=10 extends the whiskers to data within 10x the IQR of the quartiles.
ax = importances.plot.box(vert=False, whis=10)
ax.axvline(x=0, color="k", linestyle="--")
ax.set_xlabel("Decrease in accuracy score")
ax.set_title("Permutation importances (test set)", fontsize=18)
ax.figure.tight_layout()

In [None]:
# Plain-text dump of the fitted tree's decision rules.
# feature_names is passed as a plain list of str: older scikit-learn releases documented
# that type for export_text and tested it with `if feature_names:`, which raises
# "truth value of an array is ambiguous" for a multi-element ndarray such as X.columns.values.
text_representation = tree.export_text(dtc, feature_names=list(X.columns))
print(text_representation)

In [None]:
# Graphical rendering of the fitted tree on a large canvas so the node text stays legible.
fig, ax = plt.subplots(figsize=(40, 20))
tree.plot_tree(
    dtc,
    ax=ax,
    feature_names=list(X.columns),
    label='all',        # show field labels in every node
    impurity=False,
    proportion=True,    # node counts shown as proportions
    filled=True,
    fontsize=45,
)
plt.show()

In [None]:
# Per-class histogram (with KDE) of one feature, with a dashed reference line at a fixed value.
# NOTE(review): 1496.86 looks like a split threshold copied from the tree printed above.
# It is hard-coded, so it will not follow the tree if the data, seed or MAX_DEPTH change; confirm or derive it.
HIST_FEATURE = 'Bwd Packet Length Std'
HIST_THRESHOLD = 1496.86

fig, ax = plt.subplots()
sns.histplot(data=X, x=HIST_FEATURE, hue=y.values, kde=True, bins=50, ax=ax)
ax.set_title('Distribution of \'Bwd Packet Length Std\'', pad=20)
ax.set_xlim(0, 5000)
ax.set_ylim(0, 80000)
ax.axvline(HIST_THRESHOLD, color='tomato', linestyle='--')
fig.tight_layout()
plt.show()