In [8]:
import os
import sys
from itertools import combinations

import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, f1_score
from sklearn.mixture import BayesianGaussianMixture, GaussianMixture
from sklearn.preprocessing import OrdinalEncoder, RobustScaler, StandardScaler
from xgboost import XGBClassifier
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm, trange

from utils import *

In [9]:
dataset_train = FlowDataset(train=True)
dataset_valid = FlowDataset(train=False)

In [10]:
x_train, y_train = dataset_train.get_xy()
x_valid, y_valid = dataset_valid.get_xy()

In [11]:
model = XGBClassifier(n_jobs=-1, random_state=GLOBAL_SEED, use_label_encoder=False)

In [12]:
model.fit(x_train, y_train)



XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
              colsample_bynode=1, colsample_bytree=1, enable_categorical=False,
              gamma=0, gpu_id=-1, importance_type=None,
              interaction_constraints='', learning_rate=0.300000012,
              max_delta_step=0, max_depth=6, min_child_weight=1, missing=nan,
              monotone_constraints='()', n_estimators=100, n_jobs=-1,
              num_parallel_tree=1, predictor='auto', random_state=755,
              reg_alpha=0, reg_lambda=1, scale_pos_weight=1, subsample=1,
              tree_method='exact', use_label_encoder=False,
              validate_parameters=1, verbosity=None)

In [13]:
y_pred = model.predict(x_valid)

In [14]:
print(classification_report(y_valid, y_pred))

              precision    recall  f1-score   support

           0       1.00      0.99      0.99    190273
           1       0.99      1.00      0.99    135287

    accuracy                           0.99    325560
   macro avg       0.99      0.99      0.99    325560
weighted avg       0.99      0.99      0.99    325560



In [134]:
def get_detect_result(model, df:pd.DataFrame, x:np.array, threshold:int=0, scaler=None):
    pred = model.predict_proba(x)
    assert len(pred) == len(x)
    total = set(df['dst_ip'])
    candidate = df[pred >= 0.995].groupby('dst_ip')['dst_ip'].count()
    detected = set(candidate[candidate >= threshold].index)
    not_detected = total - detected
    return detected, not_detected

In [152]:
def get_detect_report(detected, not_detected, mal_ip, ben_ip, digit=4, verbose=False):
    detected, not_detected = map(set, [detected, not_detected])
    mal_ip, ben_ip = map(set, [mal_ip, ben_ip])
    tp = detected & mal_ip
    fp = detected & ben_ip
    tn = not_detected & ben_ip
    fn = not_detected & mal_ip
    tp, fp, tn, fn = map(len, [tp, fp, tn, fn])
    total = tp + fp + tn + fn
    acc = (tp + tn) / total
    pre = (tp) / (tp + fp + 1e-8)
    rec = (tp) / (tp + fn + 1e-8)
    f1 = 2 * (pre * rec) / (pre + rec + 1e-8)
    if verbose:
        print(f"accuracy: {acc:.{digit}}\nprecision: {pre:.{digit}}\nrecall: {rec:.{digit}}\nf1: {f1:.{digit}}")
    return f1

In [145]:
detected, not_detected = get_detect_result(model, dataset_valid.df, x_valid, 1)

In [146]:
get_detect_report(detected, not_detected, dataset_valid.outer_mal, dataset_valid.outer_ben)

0.5826210782217982

In [147]:
pred = model.predict_proba(x_valid)

In [148]:
df = dataset_valid.df

In [149]:
df['prob_0'] = pred[:, 0]
df['prob_1'] = pred[:, 1]


In [150]:
raw = df.groupby('dst_ip', as_index=False)[['prob_0', 'prob_1']]

In [160]:
temp = raw.mean()
whole = set(temp['dst_ip'])
best_f1, best_threshold = 0, 0
best_detected, best_not_detected = None, None
for i in tqdm(np.linspace(0, 1, 1000)):
    detected = set(temp[temp['prob_1'] >= i]['dst_ip'])
    not_detected = whole - detected
    t = get_detect_report(detected, not_detected, dataset_valid.outer_mal, dataset_valid.outer_ben)
    if t > best_f1:
        best_threshold = i
        best_f1 = t
        best_detected = detected
        best_not_detected = not_detected
get_detect_report(best_detected, best_not_detected, dataset_valid.outer_mal, dataset_valid.outer_ben, digit=6, verbose=True)

  0%|          | 0/10000 [00:00<?, ?it/s]

accuracy: 0.990449
precision: 0.980645
recall: 0.997812
f1: 0.989154


0.9891540079941042

In [163]:
best_threshold

0.6085608560856086

In [161]:
import pickle

In [162]:
with open('proba_best_detected.pkl', 'wb') as f:
    pickle.dump(best_detected, f)
with open('proba_best_not_detected.pkl', 'wb') as f:
    pickle.dump(best_not_detected, f)