In [2]:
import copy
import io

import numpy as np

import graphviz
import matplotlib.pyplot as plt
import pandas as pd
import rootpath
import shap
from imblearn.over_sampling import RandomOverSampler
from imblearn.under_sampling import RandomUnderSampler
from interpret import show
from interpret.blackbox import LimeTabular
from skexplain.enums.feature_type import FeatureType
from skexplain.imitation import (ClassificationDagger,
                                 IncrementalClassificationDagger,
                                 RegressionDagger)
from skexplain.utils import dataset, log, persist
from skexplain.utils.const import (BOSTON_DATASET_META,
                                   CIC_IDS_2017_DATASET_META,
                                   DIABETES_DATASET_META,
                                   DOWNLOAD_DATASET_META, IOT_DATASET_META,
                                   WINE_DATASET_META)
from sklearn import tree
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.metrics import classification_report, f1_score, r2_score
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier, MLPRegressor
from sklearn.svm import LinearSVC

# Notebook configuration: dataset metadata used for training and testing,
# the black-box model class, and whether datasets are read as DataFrames.
df_test_meta = DOWNLOAD_DATASET_META
df_train_meta = CIC_IDS_2017_DATASET_META
model = RandomForestClassifier
as_df = False

""" Test using Reinforcement Learning to extract Decision Tree from a generic Blackbox model """
# The log file path is built from rootpath.detect(), the training dataset
# name and the model class name.
logger = log.Logger(
    "{}/res/log/{}/notebook_{}_{}.log".format(
        rootpath.detect(),
        df_train_meta['name'],
        model.__name__,
        "Raw",
    )
)
logger.log('Init done.')

2021-04-22 13:51:33,813 - INFO - Init done.


In [3]:
# Step 1: Read the training dataset described by df_train_meta.
# Only the features, targets and feature names are kept; the last two
# return values are discarded.
X_train, y_train, feature_names, _, _ = dataset.read(
    df_train_meta['path'],
    metadata=df_train_meta,
    verbose=True,
    logger=logger,
    as_df=as_df,
)

2021-04-22 13:51:35,643 - INFO - Names: ['Flow Duration', 'Total Fwd Packets', 'Total Backward Packets', 'Total Length of Fwd Packets', 'Total Length of Bwd Packets', 'Fwd Packet Length Max', 'Fwd Packet Length Min', 'Fwd Packet Length Mean', 'Fwd Packet Length Std', 'Bwd Packet Length Max', 'Bwd Packet Length Min', 'Bwd Packet Length Mean', 'Bwd Packet Length Std', 'Flow Bytes/s', 'Flow Packets/s', 'Flow IAT Mean', 'Flow IAT Std', 'Flow IAT Max', 'Flow IAT Min', 'Fwd IAT Total', 'Fwd IAT Mean', 'Fwd IAT Std', 'Fwd IAT Max', 'Fwd IAT Min', 'Bwd IAT Total', 'Bwd IAT Mean', 'Bwd IAT Std', 'Bwd IAT Max', 'Bwd IAT Min', 'Fwd PSH Flags', 'Fwd URG Flags', 'Bwd Header Length', 'Fwd Packets/s', 'Bwd Packets/s', 'Min Packet Length', 'Max Packet Length', 'Packet Length Mean', 'Packet Length Std', 'Packet Length Variance', 'FIN Flag Count', 'SYN Flag Count', 'RST Flag Count', 'PSH Flag Count', 'ACK Flag Count', 'URG Flag Count', 'CWE Flag Count', 'ECE Flag Count', 'Down/Up Ratio', 'Average Packet

2021-04-22 13:51:55,133 - INFO - Targets shape: (2275074, 1) Index(['Label'], dtype='object')


In [4]:
# Step 2: Load the pre-trained black-box model (nothing is trained here).
logger.log("#" * 10, "Model init", "#" * 10)

# The weights file name encodes the model class, the "Raw" variant tag,
# the training dataset name and the number of feature columns.
model_path = "../res/weights/{}_{}_{}_{}.joblib".format(
    model.__name__,
    "Raw",
    df_train_meta['name'],
    X_train.shape[1],
)
logger.log("Looking for pre-trained model: {}...".format(model_path))

blackbox = persist.load_model(model_path)
if not blackbox:
    # Every later cell needs a fitted model, so stop immediately.
    raise ValueError("Trained model not found. Please train model before unit testing it.")

logger.log("#" * 10, "Done", "#" * 10)

2021-04-22 13:52:00,065 - INFO - ########## Model init ##########
2021-04-22 13:52:00,068 - INFO - Looking for pre-trained model: ../res/weights/RandomForestClassifier_Raw_cic_ids_2017_70.joblib...
2021-04-22 13:52:00,142 - INFO - ########## Done ##########


In [14]:
# Step 3: Read the test cases described by df_test_meta and evaluate the
# black-box model on them.
# NOTE: this rebinds feature_names to the names returned for the test set;
# the explainer cells below use this binding.
X_test, y_test, feature_names, _, _ = dataset.read(
    df_test_meta['path'],
    metadata=df_test_meta,
    verbose=True,
    logger=logger,
    as_df=as_df,
)

logger.log(X_test)
logger.log(y_test)

y_pred = blackbox.predict(X_test)
# Log the flattened true labels next to the predictions for a quick comparison.
logger.log(y_test.ravel(), y_pred)

# The report below is computed on the test set (X_test / y_test), so the
# heading says "test" rather than "training".
logger.log("Blackbox model test classification report:")
logger.log("\n{}".format(classification_report(y_test, y_pred, digits=3)))

2021-04-22 14:47:05,102 - INFO - Names: ['Flow Duration', 'Total Fwd Packets', 'Total Backward Packets', 'Total Length of Fwd Packets', 'Total Length of Bwd Packets', 'Fwd Packet Length Max', 'Fwd Packet Length Min', 'Fwd Packet Length Mean', 'Fwd Packet Length Std', 'Bwd Packet Length Max', 'Bwd Packet Length Min', 'Bwd Packet Length Mean', 'Bwd Packet Length Std', 'Flow Bytes/s', 'Flow Packets/s', 'Flow IAT Mean', 'Flow IAT Std', 'Flow IAT Max', 'Flow IAT Min', 'Fwd IAT Total', 'Fwd IAT Mean', 'Fwd IAT Std', 'Fwd IAT Max', 'Fwd IAT Min', 'Bwd IAT Total', 'Bwd IAT Mean', 'Bwd IAT Std', 'Bwd IAT Max', 'Bwd IAT Min', 'Fwd PSH Flags', 'Fwd URG Flags', 'Bwd Header Length', 'Fwd Packets/s', 'Bwd Packets/s', 'Min Packet Length', 'Max Packet Length', 'Packet Length Mean', 'Packet Length Std', 'Packet Length Variance', 'FIN Flag Count', 'SYN Flag Count', 'RST Flag Count', 'PSH Flag Count', 'ACK Flag Count', 'URG Flag Count', 'CWE Flag Count', 'ECE Flag Count', 'Down/Up Ratio', 'Average Packet

2021-04-22 14:47:05,143 - INFO - (16, 69)
2021-04-22 14:47:05,145 - INFO - Any NAN? 0
2021-04-22 14:47:05,150 - INFO - Total memory usage: 0.00 MB
2021-04-22 14:47:05,151 - INFO - Average memory usage: 0.00 MB
2021-04-22 14:47:05,152 - INFO - [0 1]
2021-04-22 14:47:05,159 - INFO - Features Shape: (16, 70)
2021-04-22 14:47:05,160 - INFO - Column names:
0: Flow Duration
1: Total Fwd Packets
2: Total Backward Packets
3: Total Length of Fwd Packets
4: Total Length of Bwd Packets
5: Fwd Packet Length Max
6: Fwd Packet Length Min
7: Fwd Packet Length Mean
8: Fwd Packet Length Std
9: Bwd Packet Length Max
10: Bwd Packet Length Min
11: Bwd Packet Length Mean
12: Bwd Packet Length Std
13: Flow Bytes/s
14: Flow Packets/s
15: Flow IAT Mean
16: Flow IAT Std
17: Flow IAT Max
18: Flow IAT Min
19: Fwd IAT Total
20: Fwd IAT Mean
21: Fwd IAT Std
22: Fwd IAT Max
23: Fwd IAT Min
24: Bwd IAT Total
25: Bwd IAT Mean
26: Bwd IAT Std
27: Bwd IAT Max
28: Bwd IAT Min
29: Bwd Header Length
30: Fwd Packets/s
31: 

In [15]:
# Build the LIME explainer from the model's probability function and the
# training data; random_state is fixed at 1.
lime = LimeTabular(
    predict_fn=blackbox.predict_proba,
    data=X_train,
    random_state=1,
    feature_names=list(feature_names),
)

# Explain test row 0 on its own, passing the matching y_test entry as label.
lime_local = lime.explain_local(
    [X_test[0]],
    y_test[0],
    name='LIME Benign Traffic Test Case',
)

show(lime_local)

In [16]:
# Explain test row 1 on its own, passing the matching y_test entry as label.
lime_local = lime.explain_local(
    [X_test[1]],
    y_test[1],
    name='LIME Small Heartbleed Test Case',
)

show(lime_local)

In [17]:
# Explain test row 2 on its own, passing the matching y_test entry as label.
lime_local = lime.explain_local(
    [X_test[2]],
    y_test[2],
    name='LIME Large Heartbleed Test Case',
)

show(lime_local)

In [18]:
# Explain test row 3 on its own, passing the matching y_test entry as label.
lime_local = lime.explain_local(
    [X_test[3]],
    y_test[3],
    name='LIME Original Heartbleed Test Case',
)

show(lime_local)

In [19]:
from interpret.blackbox import ShapKernel
import numpy as np

# Background data for Kernel SHAP: a single row holding the per-column
# median of the training features.
background_val = np.median(X_train, axis=0).reshape(1, -1)

# Named shap_explainer (not `shap`) so the `shap` module imported at the top
# of the notebook is not rebound to an explainer instance.
shap_explainer = ShapKernel(
    predict_fn=blackbox.predict_proba,
    data=background_val,
    feature_names=list(feature_names),
)

# Explain every test row in one call, passing y_test as the labels.
shap_local = shap_explainer.explain_local(X_test, y_test, name='SHAP All Test Cases')


show(shap_local)

HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))




In [20]:
from interpret.blackbox import MorrisSensitivity

# Global explanation: Morris sensitivity built from the model's probability
# function and the training data.
sensitivity = MorrisSensitivity(
    predict_fn=blackbox.predict_proba,
    data=X_train,
    feature_names=list(feature_names),
)
sensitivity_global = sensitivity.explain_global(name="Global Sensitivity")

show(sensitivity_global)

In [13]:
from interpret.glassbox import ClassificationTree

# Surrogate model: the tree is fitted to the black box's own predictions on
# the training features (not to y_train), so it describes the model's behaviour.
# Stored as y_train_pred so the name y_pred, used by the test-set evaluation
# cell, is not overwritten with training-set predictions.
y_train_pred = blackbox.predict(X_train)

# Named surrogate_tree (not `tree`) so the `sklearn.tree` module imported at
# the top of the notebook is not rebound to a model instance.
surrogate_tree = ClassificationTree(feature_names=list(feature_names))
surrogate_tree.fit(X_train, y_train_pred)

tree_global = surrogate_tree.explain_global(name='Classification Tree')

show(tree_global)