In [11]:

import re

import numpy as np
import pandas as pd

# Pre-processing
from sklearn.pipeline import Pipeline
from sklearn.base import TransformerMixin
from sklearn.preprocessing import OneHotEncoder
from sklearn.compose import ColumnTransformer

from src.data_loader import DataRetriever, LabelManager

class CustomColumnTransformer(ColumnTransformer):
    """ColumnTransformer that reports feature names without bytes-literal markers.

    Categories stored as ``bytes`` show up in generated feature names as
    their repr (for example ``x0_b'tcp'``). ``get_feature_names`` rewrites
    such names to the plain form (``x0_tcp``).
    """

    # Matches the repr of a bytes value such as ``b'tcp'`` and captures the
    # text between the quotes. Requiring the closing quote keeps a category
    # that merely ends in "b" (``b'sub'``) from losing its final letter,
    # which a blanket ``replace("b'", "")`` would cause.
    _BYTES_REPR = re.compile(r"b'([^']*)'")

    def get_feature_names(self):
        """Return the parent's feature names with ``b'...'`` wrappers removed.

        Returns:
            list of str: one cleaned name per feature name reported by
            ``ColumnTransformer.get_feature_names``. Any single quote still
            present after unwrapping is stripped as well.

        NOTE(review): this relies on ``ColumnTransformer.get_feature_names``,
        which newer scikit-learn releases replace with
        ``get_feature_names_out`` - confirm against the installed version.
        """
        cleaned_names = []
        for raw_name in super().get_feature_names():
            # Unwrap b'...' first, then drop any stray quotes left behind.
            unwrapped = self._BYTES_REPR.sub(r"\1", raw_name)
            cleaned_names.append(unwrapped.replace("'", ""))

        return cleaned_names

# Build the label manager from ../data.json and pull the feature/target split
# through the project's data retriever.
label_manager = LabelManager(config_file="../data.json")
data_retriever = DataRetriever(label_manager=label_manager)
X, y = data_retriever.X_y_dataset()

# Column-label groups as exposed by the label manager.
X_discreet = label_manager.X_discreet
X_continuous = label_manager.X_continuous

# Translate every label in X_discreet into its position within X.columns,
# printing each "label: position" pair as it is recorded.
discreet_cols = []
for item in X_discreet:
    discreet_cols.append(X.columns.get_loc(item))
    print(f"{item}: {discreet_cols[-1]}")

protocol_type: 1
service: 2
flag: 3
land: 6
logged_in: 11
is_host_login: 20
is_guest_login: 21


In [12]:
print(discreet_cols)

# One-hot encode only the columns listed in discreet_cols; every other column
# is dropped from the result. n_jobs=-1 is passed through to the transformer.
ct1 = CustomColumnTransformer(
    transformers=[("ohe", OneHotEncoder(), discreet_cols)],
    remainder="drop",
    n_jobs=-1,
)

# Fit on X and keep the encoded matrix for the following cells.
_X = ct1.fit_transform(X)

[1, 2, 3, 6, 11, 20, 21]


In [13]:
# Ask the fitted "ohe" step directly for its feature names, then pair them
# with the densified encoded matrix in a DataFrame.
names = ct1.named_transformers_["ohe"].get_feature_names()
d = pd.DataFrame(_X.toarray(), columns=names)

# Plain-text dump of the names and of the resulting frame.
print(names)
print(d)

["x0_b'icmp'" "x0_b'tcp'" "x0_b'udp'" "x1_b'IRC'" "x1_b'X11'"
 "x1_b'Z39_50'" "x1_b'auth'" "x1_b'bgp'" "x1_b'courier'" "x1_b'csnet_ns'"
 "x1_b'ctf'" "x1_b'daytime'" "x1_b'discard'" "x1_b'domain'"
 "x1_b'domain_u'" "x1_b'echo'" "x1_b'eco_i'" "x1_b'ecr_i'" "x1_b'efs'"
 "x1_b'exec'" "x1_b'finger'" "x1_b'ftp'" "x1_b'ftp_data'" "x1_b'gopher'"
 "x1_b'hostnames'" "x1_b'http'" "x1_b'http_443'" "x1_b'imap4'"
 "x1_b'iso_tsap'" "x1_b'klogin'" "x1_b'kshell'" "x1_b'ldap'" "x1_b'link'"
 "x1_b'login'" "x1_b'mtp'" "x1_b'name'" "x1_b'netbios_dgm'"
 "x1_b'netbios_ns'" "x1_b'netbios_ssn'" "x1_b'netstat'" "x1_b'nnsp'"
 "x1_b'nntp'" "x1_b'ntp_u'" "x1_b'other'" "x1_b'pm_dump'" "x1_b'pop_2'"
 "x1_b'pop_3'" "x1_b'printer'" "x1_b'private'" "x1_b'red_i'"
 "x1_b'remote_job'" "x1_b'rje'" "x1_b'shell'" "x1_b'smtp'" "x1_b'sql_net'"
 "x1_b'ssh'" "x1_b'sunrpc'" "x1_b'supdup'" "x1_b'systat'" "x1_b'telnet'"
 "x1_b'tftp_u'" "x1_b'tim_i'" "x1_b'time'" "x1_b'urh_i'" "x1_b'urp_i'"
 "x1_b'uucp'" "x1_b'uucp_path'" "x1_b'vmne

In [15]:
# Preview the first ten rows of the feature frame.
X.head(10)

Unnamed: 0,duration,protocol_type,service,flag,src_bytes,dst_bytes,land,wrong_fragment,urgent,hot,...,dst_host_count,dst_host_srv_count,dst_host_same_srv_rate,dst_host_diff_srv_rate,dst_host_same_src_port_rate,dst_host_srv_diff_host_rate,dst_host_serror_rate,dst_host_srv_serror_rate,dst_host_rerror_rate,dst_host_srv_rerror_rate
0,0,b'tcp',b'http',b'SF',181,5450,0,0,0,0,...,9,9,1,0,0.11,0.0,0,0,0,0
1,0,b'tcp',b'http',b'SF',239,486,0,0,0,0,...,19,19,1,0,0.05,0.0,0,0,0,0
2,0,b'tcp',b'http',b'SF',235,1337,0,0,0,0,...,29,29,1,0,0.03,0.0,0,0,0,0
3,0,b'tcp',b'http',b'SF',219,1337,0,0,0,0,...,39,39,1,0,0.03,0.0,0,0,0,0
4,0,b'tcp',b'http',b'SF',217,2032,0,0,0,0,...,49,49,1,0,0.02,0.0,0,0,0,0
5,0,b'tcp',b'http',b'SF',217,2032,0,0,0,0,...,59,59,1,0,0.02,0.0,0,0,0,0
6,0,b'tcp',b'http',b'SF',212,1940,0,0,0,0,...,1,69,1,0,1.0,0.04,0,0,0,0
7,0,b'tcp',b'http',b'SF',159,4087,0,0,0,0,...,11,79,1,0,0.09,0.04,0,0,0,0
8,0,b'tcp',b'http',b'SF',210,151,0,0,0,0,...,8,89,1,0,0.12,0.04,0,0,0,0
9,0,b'tcp',b'http',b'SF',212,786,0,0,0,1,...,8,99,1,0,0.12,0.05,0,0,0,0


# Bare expression: evaluates the DataFrame `d` so the notebook can render it.
d