In [1]:
import copy
import io

import numpy as np

import graphviz
import matplotlib.pyplot as plt
import pandas as pd
import rootpath
import shap
from imblearn.over_sampling import RandomOverSampler
from imblearn.under_sampling import RandomUnderSampler
from interpret import show
from interpret.blackbox import LimeTabular
from skexplain.enums.feature_type import FeatureType
from skexplain.imitation import (ClassificationDagger,
                                 IncrementalClassificationDagger,
                                 RegressionDagger)
from skexplain.utils import dataset, log, persist
from skexplain.utils.const import (BOSTON_DATASET_META,
                                   CIC_IDS_2017_DATASET_META,
                                   HEARTBLEED_DATASET_META,
                                   HEARTBLEED_LARGE_DATASET_META,
                                   DIABETES_DATASET_META,
                                   DOWNLOAD_DATASET_META, IOT_DATASET_META,
                                   WINE_DATASET_META)
from sklearn import tree
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.metrics import classification_report, f1_score, r2_score
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier, MLPRegressor
from sklearn.svm import LinearSVC

df_test_meta=HEARTBLEED_LARGE_DATASET_META
df_train_meta=CIC_IDS_2017_DATASET_META 
model=RandomForestClassifier
as_df=False


""" Test using Reinforcement Learning to extract Decision Tree from a generic Blackbox model """
logger = log.Logger(
    "{}/res/log/{}/notebook_{}_{}.log".format(rootpath.detect(), df_train_meta['name'],  model.__name__, "Raw")
)
logger.log('Init done.')

2021-06-25 07:10:37,955 - INFO - Init done.


In [2]:
# Step 1: Parse test def
X_train, y_train, feature_names, _, _ = dataset.read(df_train_meta['path'], metadata=df_train_meta,
                                         verbose=True, logger=logger, as_df=as_df)

2021-04-21 21:02:59,424 - INFO - Names: ['Flow Duration', 'Total Fwd Packets', 'Total Backward Packets', 'Total Length of Fwd Packets', 'Total Length of Bwd Packets', 'Fwd Packet Length Max', 'Fwd Packet Length Min', 'Fwd Packet Length Mean', 'Fwd Packet Length Std', 'Bwd Packet Length Max', 'Bwd Packet Length Min', 'Bwd Packet Length Mean', 'Bwd Packet Length Std', 'Flow Bytes/s', 'Flow Packets/s', 'Flow IAT Mean', 'Flow IAT Std', 'Flow IAT Max', 'Flow IAT Min', 'Fwd IAT Total', 'Fwd IAT Mean', 'Fwd IAT Std', 'Fwd IAT Max', 'Fwd IAT Min', 'Bwd IAT Total', 'Bwd IAT Mean', 'Bwd IAT Std', 'Bwd IAT Max', 'Bwd IAT Min', 'Fwd PSH Flags', 'Fwd URG Flags', 'Bwd Header Length', 'Fwd Packets/s', 'Bwd Packets/s', 'Min Packet Length', 'Max Packet Length', 'Packet Length Mean', 'Packet Length Std', 'Packet Length Variance', 'FIN Flag Count', 'SYN Flag Count', 'RST Flag Count', 'PSH Flag Count', 'ACK Flag Count', 'URG Flag Count', 'CWE Flag Count', 'ECE Flag Count', 'Down/Up Ratio', 'Average Packet

2021-04-21 21:03:19,957 - INFO - Targets shape: (2275074, 1) Index(['Label'], dtype='object')


In [3]:
# Step 2: Train black-box model with loaded dataset
logger.log("#" * 10, "Model init", "#" * 10)
model_path = "../res/weights/{}_{}_{}_{}.joblib".format(model.__name__,
                                                        "Raw", df_train_meta['name'], X_train.shape[1])
logger.log("Looking for pre-trained model: {}...".format(model_path))
blackbox = persist.load_model(model_path)
if not blackbox:
    raise ValueError("Trained model not found. Please train model before unit testing it.")
logger.log("#" * 10, "Done", "#" * 10)

2021-04-21 21:03:25,721 - INFO - ########## Model init ##########
2021-04-21 21:03:25,723 - INFO - Looking for pre-trained model: ../res/weights/RandomForestClassifier_Raw_cic_ids_2017_70.joblib...
2021-04-21 21:03:25,818 - INFO - ########## Done ##########


In [5]:
X_test, y_test, feature_names, _, _ = dataset.read(df_test_meta['path'], metadata=df_test_meta,
                                         verbose=True, logger=logger, as_df=as_df)                                                   


logger.log(df_test_meta['path'])
logger.log(X_test)
logger.log(y_test)

y_pred = blackbox.predict(X_test)
logger.log(list(y_pred))

logger.log("Blackbox model training classification report:")
logger.log("\n{}".format(classification_report(y_test, y_pred, digits=3)))

2021-04-21 21:04:09,385 - INFO - Names: ['Flow Duration', 'Total Fwd Packets', 'Total Backward Packets', 'Total Length of Fwd Packets', 'Total Length of Bwd Packets', 'Fwd Packet Length Max', 'Fwd Packet Length Min', 'Fwd Packet Length Mean', 'Fwd Packet Length Std', 'Bwd Packet Length Max', 'Bwd Packet Length Min', 'Bwd Packet Length Mean', 'Bwd Packet Length Std', 'Flow Bytes/s', 'Flow Packets/s', 'Flow IAT Mean', 'Flow IAT Std', 'Flow IAT Max', 'Flow IAT Min', 'Fwd IAT Total', 'Fwd IAT Mean', 'Fwd IAT Std', 'Fwd IAT Max', 'Fwd IAT Min', 'Bwd IAT Total', 'Bwd IAT Mean', 'Bwd IAT Std', 'Bwd IAT Max', 'Bwd IAT Min', 'Fwd PSH Flags', 'Fwd URG Flags', 'Bwd Header Length', 'Fwd Packets/s', 'Bwd Packets/s', 'Min Packet Length', 'Max Packet Length', 'Packet Length Mean', 'Packet Length Std', 'Packet Length Variance', 'FIN Flag Count', 'SYN Flag Count', 'RST Flag Count', 'PSH Flag Count', 'ACK Flag Count', 'URG Flag Count', 'CWE Flag Count', 'ECE Flag Count', 'Down/Up Ratio', 'Average Packet

2021-04-21 21:04:09,460 - INFO - Targets shape: (1041, 1) Index(['Label'], dtype='object')
2021-04-21 21:04:09,465 - INFO - /Users/asjacobs/workspace/explainability/scikit-explain/res/dataset/heartbleed-large/heartbleed-large.csv
2021-04-21 21:04:09,466 - INFO - [[1.4780e+03 6.0000e+00 6.0000e+00 ... 0.0000e+00 1.0000e+00 0.0000e+00]
 [1.2842e+04 9.0000e+00 8.0000e+00 ... 0.0000e+00 1.0000e+00 0.0000e+00]
 [1.0592e+04 9.0000e+00 8.0000e+00 ... 0.0000e+00 1.0000e+00 0.0000e+00]
 ...
 [1.1570e+03 6.0000e+00 6.0000e+00 ... 0.0000e+00 1.0000e+00 0.0000e+00]
 [9.6500e+02 6.0000e+00 6.0000e+00 ... 0.0000e+00 1.0000e+00 0.0000e+00]
 [1.3380e+03 6.0000e+00 6.0000e+00 ... 0.0000e+00 1.0000e+00 0.0000e+00]]
2021-04-21 21:04:09,467 - INFO - [[0]
 [1]
 [1]
 ...
 [0]
 [0]
 [0]]
2021-04-21 21:04:09,577 - INFO - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0