# Model Comparison

## Imports

In [1]:
from utils.dataset_manager import fit_dataset, get_classes_weights
from utils.constant import ALL_ATTACKS, FEATURES, LABELS

# Models
from sklearn.svm import SVC
from xgboost import XGBClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, classification_report

# Other
import warnings
from tqdm import tqdm

# Ignore warnings
warnings.filterwarnings('ignore')

## Dataset

In [2]:
n_files = 2

df_train, df_test = fit_dataset(n_files, ALL_ATTACKS)

X_train, y_train = df_train[FEATURES], df_train[LABELS]
X_test, y_test = df_test[FEATURES], df_test[LABELS]

# Prints
print('Training Population: {}'.format(len(df_train)))
print('Testing Population: {}'.format(len(df_test)))

100%|██████████| 2/2 [00:09<00:00,  4.97s/it]
100%|██████████| 1/1 [00:03<00:00,  3.46s/it]


Training Population: 457492
Testing Population: 275258


## Models

Since we have unbalanced data, we have to add their weigth in each model

In [4]:
models = {
    'log_reg': LogisticRegression(class_weight=get_classes_weights(df_train)),
    'xgb': XGBClassifier(),
    #'svm': SVC(), # Too slow
}

### Training

In [5]:
for model in tqdm(models):
    models[model].fit(X_train, y_train)

100%|██████████| 2/2 [04:01<00:00, 120.87s/it]


## Evaluation

In [6]:
# Predict
for model in tqdm(models):
    y_pred = list(models[model].predict(X_test))

    # Evaluate
    y_test = list(y_test)
    print('Model: ', model)
    print('  accuracy_score = ', accuracy_score(y_pred, y_test))
    print('  recall_score = ', recall_score(y_pred, y_test, average='macro'))
    print('  precision_score = ', precision_score(y_pred, y_test, average='macro'))
    print('  f1_score = ', f1_score(y_pred, y_test, average='macro'))
    print('  classification_report = \n', classification_report(y_pred, y_test))

  0%|          | 0/2 [00:00<?, ?it/s]

Model:  log_reg
  accuracy_score =  0.7741936655792021
  recall_score =  0.48097141706008584
  precision_score =  0.5256358749894402
  f1_score =  0.473294530373888


 50%|█████     | 1/2 [00:02<00:02,  2.74s/it]

  classification_report = 
               precision    recall  f1-score   support

           0       1.00      1.00      1.00     23779
           1       0.98      1.00      0.99     23474
           2       0.93      0.66      0.77     34163
           3       0.90      0.72      0.80     39454
           4       0.79      0.65      0.71     32416
           5       1.00      1.00      1.00     42280
           6       0.63      0.77      0.69     17480
           7       0.98      0.91      0.94      1758
           8       0.98      0.92      0.95      1800
           9       0.97      0.97      0.97      2680
          10       0.51      0.15      0.23       459
          11       0.69      0.26      0.37       396
          12       0.44      0.74      0.55     11745
          13       0.21      0.47      0.29      5279
          14       0.27      0.43      0.33     10205
          15       0.76      0.62      0.68       562
          16       0.65      0.59      0.62      6261

100%|██████████| 2/2 [00:11<00:00,  5.62s/it]

  classification_report = 
               precision    recall  f1-score   support

           0       1.00      1.00      1.00     23824
           1       1.00      1.00      1.00     23890
           2       1.00      1.00      1.00     24172
           3       1.00      1.00      1.00     31749
           4       1.00      1.00      1.00     26523
           5       1.00      1.00      1.00     42340
           6       1.00      1.00      1.00     21422
           7       1.00      1.00      1.00      1635
           8       1.00      1.00      1.00      1689
           9       1.00      1.00      1.00      2676
          10       0.98      0.97      0.98       133
          11       0.98      0.97      0.97       148
          12       1.00      1.00      1.00     19592
          13       1.00      1.00      1.00     12010
          14       1.00      1.00      1.00     15932
          15       0.99      1.00      0.99       450
          16       1.00      1.00      1.00      5683


