# Random Forest Model

## Imports

In [4]:
# Custom
import sys
sys.path.append('../')
from utils.constant import FEATURES, LABELS, ALL_ATTACKS
from utils.dataset_manager import fit_dataset, get_classes_weights

# General
import warnings
from joblib import dump

# Model and Metrics
from sklearn.metrics import classification_report
from sklearn.ensemble import RandomForestClassifier

# Warnings
warnings.filterwarnings('ignore')

## Dataset


In [5]:
n_files = 20

df_train, df_test = fit_dataset(n_files, ALL_ATTACKS)

X_train, y_train = df_train[FEATURES], df_train[LABELS]

# Prints
print('Training Population: {}'.format(len(df_train)))
print('Testing Population: {}'.format(len(df_test)))

100%|██████████| 20/20 [00:38<00:00,  1.90s/it]
100%|██████████| 6/6 [00:07<00:00,  1.27s/it]


Training Population: 4723822
Testing Population: 1648176


# Model

In [6]:
# Model
rf_model = RandomForestClassifier(class_weight=get_classes_weights(df_train))

# Train
rf_model.fit(X_train, y_train)

In [7]:
# Save Model
name = f"../outputs/random_forest_{n_files}.joblib"
dump(rf_model, name)

['../outputs/random_forest_20.joblib']

# Evaluation

In [8]:
X_test , y_test = df_test[FEATURES], df_test[LABELS]

# Predict
y_pred = rf_model.predict(X_test)

# Evaluate
print('Classification Report: \n{}'.format(classification_report(y_test, y_pred)))

Classification Report: 
              precision    recall  f1-score   support

           0       1.00      1.00      1.00    142361
           1       1.00      1.00      1.00    144128
           2       1.00      1.00      1.00    143521
           3       1.00      1.00      1.00    191686
           4       1.00      1.00      1.00    159101
           5       1.00      1.00      1.00    254077
           6       1.00      1.00      1.00    126849
           7       1.00      0.99      0.99     10061
           8       1.00      1.00      1.00     10244
           9       0.99      1.00      1.00     16043
          10       0.98      0.99      0.98       844
          11       1.00      0.96      0.98      1050
          12       1.00      1.00      1.00    116827
          13       1.00      1.00      1.00     71688
          14       1.00      1.00      1.00     94194
          15       0.99      0.99      0.99      2474
          16       1.00      1.00      1.00     35144
   