# Setting Parameters

In [1]:
# Classes: '33+1', '8+1', '1+1'
apply_classes = ['33+1', '8+1', '1+1']

# Samplers: None, 'RandomOverSampler', 'RandomUnderSampler', 'SMOTE', ['Clustering', 'SMOTE']
apply_sampling = 'RandomOverSampler'    # Select ONE from above

# Evaluators: 'XGBoost', 'LogisticRegression', 'Perceptron', 'AdaBoost', 'RandomForest', 'DeepNeuralNetwork', 'KNearestNeighbor'
apply_evaluators = ['XGBoost', 'LogisticRegression', 'Perceptron', 'AdaBoost', 'RandomForest', 'DeepNeuralNetwork', 'KNearestNeighbor']


# Notebook parameter validation
# Each set of accepted values is declared once so the checks below cannot drift
# apart from each other.
VALID_CLASSES = ('33+1', '8+1', '1+1')
VALID_SAMPLERS = (None, 'RandomOverSampler', 'RandomUnderSampler', 'SMOTE', ['Clustering', 'SMOTE'])
VALID_EVALUATORS = ('XGBoost', 'LogisticRegression', 'Perceptron', 'AdaBoost',
                    'RandomForest', 'DeepNeuralNetwork', 'KNearestNeighbor')

# Explicit exceptions are raised (rather than `assert False`) so the checks
# still run when Python is started with -O, which strips assert statements.
if not apply_classes:
    # Later cells index apply_classes[0], so at least one entry is required.
    raise ValueError('At least one class structure must be selected.')

for _class in apply_classes:
    if _class not in VALID_CLASSES:
        raise ValueError(f'{_class} is an invalid class structure.')

if apply_sampling not in VALID_SAMPLERS:
    raise ValueError(f'{apply_sampling} is an invalid sampling method.')

for evaluator in apply_evaluators:
    if evaluator not in VALID_EVALUATORS:
        raise ValueError(f'{evaluator} is an invalid evaluator.')

# Dataset Handling
## Common Imports

In [2]:
import os
import pandas as pd
import random
from datetime import datetime
from sklearn.preprocessing import LabelEncoder
from sklearn.utils import shuffle
from ydata_profiling import ProfileReport
from tqdm import tqdm

## Loading the Dataset

In [3]:
DATASET_DIRECTORY = '../dataset/'  # If your dataset is within your python project directory, change this to the relative path to your dataset
csv_filepaths = [filename for filename in os.listdir(DATASET_DIRECTORY) if filename.endswith('.csv')]

print(csv_filepaths)

# Fail early with a clear message instead of a KeyError on 'label' further down.
if not csv_filepaths:
    raise FileNotFoundError(f'No CSV files found in {DATASET_DIRECTORY}')

# If there are more than X CSV files, randomly select X files from the list
sample_size = 5

if len(csv_filepaths) > sample_size:
    # os.listdir returns names in arbitrary order, so sort before sampling and
    # draw from a dedicated seeded generator: the same subset is then picked on
    # every run without touching the global `random` state.
    csv_filepaths = random.Random(42).sample(sorted(csv_filepaths), sample_size)
    print(csv_filepaths)

csv_filepaths.sort()

# list of csv files used
data_sets = csv_filepaths

# Read every file first and concatenate once; concatenating inside the loop
# re-copies all previously loaded rows on each iteration.
frames = []
for position, data_set in enumerate(data_sets, start=1):
    print(f"data set {position} out of {len(data_sets)}: {data_set} \n")
    data_path = os.path.join(DATASET_DIRECTORY, data_set)
    df = pd.read_csv(data_path)
    frames.append(df)
full_data = pd.concat(frames)

# prints an instance of each class
print("Before encoding:")
unique_labels = full_data['label'].unique()
for label in unique_labels:
    print(f"First instance of {label}:")
    print(full_data[full_data['label'] == label].iloc[0])

# Shuffle data
full_data = shuffle(full_data, random_state=42)

# prove if the data is loaded properly
print("Real data:")
print(full_data[:2])
print(full_data.shape)

# 'label' is the column holding the class labels in `full_data`
unique_labels = full_data['label'].nunique()

# Print the number of unique labels
print(f"There are {unique_labels} unique labels in the dataset.")

class_counts = full_data['label'].value_counts()
print(class_counts)

# Display summary statistics of the numeric columns as the cell output
full_data.describe()

['part-00000-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv', 'part-00001-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv', 'part-00002-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv', 'part-00003-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv', 'part-00004-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv', 'part-00005-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv', 'part-00006-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv', 'part-00007-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv', 'part-00008-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv', 'part-00009-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv', 'part-00010-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv', 'part-00011-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv', 'part-00012-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv', 'part-00013-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv', 'part-00014-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv', 'part-00015-363d1ba3-8ab5-4f96-bc25-4d5862db7cb9-c000.csv', 'part-00016-363d1ba3-8ab5-4f96-bc25-4d5

Unnamed: 0,flow_duration,Header_Length,Protocol Type,Duration,Rate,Srate,Drate,fin_flag_number,syn_flag_number,rst_flag_number,...,AVG,Std,Tot size,IAT,Number,Magnitue,Radius,Covariance,Variance,Weight
count,1613206.0,1613206.0,1613206.0,1613206.0,1613206.0,1613206.0,1613206.0,1613206.0,1613206.0,1613206.0,...,1613206.0,1613206.0,1613206.0,1613206.0,1613206.0,1613206.0,1613206.0,1613206.0,1613206.0,1613206.0
mean,5.732339,76760.34,9.064177,66.34553,8864.859,8864.859,5.967919e-06,0.08665291,0.2076015,0.09060839,...,124.4705,33.19336,124.4812,83176240.0,9.498189,13.11702,46.90881,30951.41,0.09635494,141.504
std,264.2836,461441.8,8.947365,13.97637,97426.69,97426.69,0.002453211,0.2813259,0.40559,0.2870515,...,240.4729,161.0986,240.4294,17013810.0,0.8172624,8.611963,227.8422,366879.3,0.2328841,21.02494
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,42.0,0.0,42.0,0.0,1.0,9.165151,0.0,0.0,0.0,1.0
25%,0.0,54.0,6.0,64.0,2.090256,2.090256,0.0,0.0,0.0,0.0,...,50.0,0.0,50.0,83068550.0,9.5,10.0,0.0,0.0,0.0,141.55
50%,0.0,54.0,6.0,64.0,15.7208,15.7208,0.0,0.0,0.0,0.0,...,54.0,0.0,54.0,83124520.0,9.5,10.3923,0.0,0.0,0.0,141.55
75%,0.1052889,266.3975,14.2,64.0,116.3404,116.3404,0.0,0.0,0.0,0.0,...,54.04638,0.3712949,54.06,83343990.0,9.5,10.39663,0.5059213,1.328944,0.08,141.55
max,61547.2,9844135.0,47.0,255.0,8388608.0,8388608.0,1.939832,1.0,1.0,1.0,...,11600.47,10932.14,9998.4,167639400.0,15.0,145.3904,15460.38,143542700.0,1.0,244.6


# Preprocessing
## Duplicating data for classes

In [4]:
all_data = {}

for _class in apply_classes:
    all_data[_class] = full_data.copy()
    
    match _class:            
        case '8+1':
            label_categories = {
                'Backdoor_Malware': 'Web',
                'BenignTraffic': 'Benign',
                'BrowserHijacking': 'Web',
                'CommandInjection': 'DDoS',
                'DDoS-ACK_Fragmentation': 'DDoS',
                'DDoS-HTTP_Flood': 'DDoS',
                'DDoS-ICMP_Flood': 'DDoS',
                'DDoS-ICMP_Fragmentation': 'DDoS',
                'DDoS-PSHACK_Flood': 'DDoS',
                'DDoS-RSTFINFlood': 'DDoS',
                'DDoS-SYN_Flood': 'DDoS',
                'DDoS-SlowLoris': 'DDoS',
                'DDoS-SynonymousIP_Flood': 'DDoS',
                'DDoS-TCP_Flood': 'DDoS',
                'DDoS-UDP_Flood': 'DDoS',
                'DDoS-UDP_Fragmentation': 'DDoS',
                'DNS_Spoofing': 'Spoofing',
                'DictionaryBruteForce': 'BruteForce',
                'DoS-HTTP_Flood': 'DoS',
                'DoS-SYN_Flood': 'DoS',
                'DoS-TCP_Flood': 'DoS',
                'DoS-UDP_Flood': 'DoS',
                'MITM-ArpSpoofing': 'Spoofing',
                'Mirai-greeth_flood': 'Mirai',
                'Mirai-greip_flood': 'Mirai',
                'Mirai-udpplain': 'Mirai',
                'Recon-HostDiscovery': 'Recon',
                'Recon-OSScan': 'Recon',
                'Recon-PingSweep': 'Recon',
                'Recon-PortScan': 'Recon',
                'SqlInjection': 'Web',
                'Uploading_Attack': 'Web',
                'VulnerabilityScan': 'Recon',
                'XSS': 'Web'
            }
            all_data['8+1']['label'] = all_data['8+1']['label'].map(label_categories)
            
        case '1+1':
            all_data['1+1'].loc[all_data['1+1']['label'] != 'BenignTraffic', 'label'] = 'Attack'
            all_data['1+1'].loc[all_data['1+1']['label'] == 'BenignTraffic', 'label'] = 'Benign'

all_data[apply_classes[0]].head()

Unnamed: 0,flow_duration,Header_Length,Protocol Type,Duration,Rate,Srate,Drate,fin_flag_number,syn_flag_number,rst_flag_number,...,Std,Tot size,IAT,Number,Magnitue,Radius,Covariance,Variance,Weight,label
91016,4.204911,108.0,6.0,64.0,0.475634,0.475634,0.0,0.0,0.0,0.0,...,0.0,54.0,82947240.0,9.5,10.392305,0.0,0.0,0.0,141.55,DoS-TCP_Flood
132583,0.03339,10377.0,17.0,64.0,35482.24667,35482.24667,0.0,0.0,0.0,0.0,...,0.0,50.0,83487430.0,9.5,10.0,0.0,0.0,0.0,141.55,DDoS-UDP_Flood
17850,0.0,54.93,6.11,64.64,26.850333,26.850333,0.0,0.0,0.0,0.0,...,1.299305,54.93,83076180.0,9.5,10.42277,1.844534,28.400844,0.06,141.55,DDoS-TCP_Flood
106569,0.040928,36075.0,17.0,64.0,17621.950302,17621.950302,0.0,0.0,0.0,0.0,...,0.0,50.0,83016670.0,9.5,10.0,0.0,0.0,0.0,141.55,DoS-UDP_Flood
76802,5.43375,108.0,6.0,64.0,0.36807,0.36807,0.0,0.0,0.0,0.0,...,0.0,54.0,82956230.0,9.5,10.392305,0.0,0.0,0.0,141.55,DoS-TCP_Flood


In [5]:
# Preview the second selected class structure; shows nothing (instead of
# raising IndexError) when fewer than two structures are in apply_classes.
all_data[apply_classes[1]].head() if len(apply_classes) > 1 else None

Unnamed: 0,flow_duration,Header_Length,Protocol Type,Duration,Rate,Srate,Drate,fin_flag_number,syn_flag_number,rst_flag_number,...,Std,Tot size,IAT,Number,Magnitue,Radius,Covariance,Variance,Weight,label
91016,4.204911,108.0,6.0,64.0,0.475634,0.475634,0.0,0.0,0.0,0.0,...,0.0,54.0,82947240.0,9.5,10.392305,0.0,0.0,0.0,141.55,DoS
132583,0.03339,10377.0,17.0,64.0,35482.24667,35482.24667,0.0,0.0,0.0,0.0,...,0.0,50.0,83487430.0,9.5,10.0,0.0,0.0,0.0,141.55,DDoS
17850,0.0,54.93,6.11,64.64,26.850333,26.850333,0.0,0.0,0.0,0.0,...,1.299305,54.93,83076180.0,9.5,10.42277,1.844534,28.400844,0.06,141.55,DDoS
106569,0.040928,36075.0,17.0,64.0,17621.950302,17621.950302,0.0,0.0,0.0,0.0,...,0.0,50.0,83016670.0,9.5,10.0,0.0,0.0,0.0,141.55,DoS
76802,5.43375,108.0,6.0,64.0,0.36807,0.36807,0.0,0.0,0.0,0.0,...,0.0,54.0,82956230.0,9.5,10.392305,0.0,0.0,0.0,141.55,DoS


In [6]:
# Preview the third selected class structure; shows nothing (instead of
# raising IndexError) when fewer than three structures are in apply_classes.
all_data[apply_classes[2]].head() if len(apply_classes) > 2 else None

Unnamed: 0,flow_duration,Header_Length,Protocol Type,Duration,Rate,Srate,Drate,fin_flag_number,syn_flag_number,rst_flag_number,...,Std,Tot size,IAT,Number,Magnitue,Radius,Covariance,Variance,Weight,label
91016,4.204911,108.0,6.0,64.0,0.475634,0.475634,0.0,0.0,0.0,0.0,...,0.0,54.0,82947240.0,9.5,10.392305,0.0,0.0,0.0,141.55,Attack
132583,0.03339,10377.0,17.0,64.0,35482.24667,35482.24667,0.0,0.0,0.0,0.0,...,0.0,50.0,83487430.0,9.5,10.0,0.0,0.0,0.0,141.55,Attack
17850,0.0,54.93,6.11,64.64,26.850333,26.850333,0.0,0.0,0.0,0.0,...,1.299305,54.93,83076180.0,9.5,10.42277,1.844534,28.400844,0.06,141.55,Attack
106569,0.040928,36075.0,17.0,64.0,17621.950302,17621.950302,0.0,0.0,0.0,0.0,...,0.0,50.0,83016670.0,9.5,10.0,0.0,0.0,0.0,141.55,Attack
76802,5.43375,108.0,6.0,64.0,0.36807,0.36807,0.0,0.0,0.0,0.0,...,0.0,54.0,82956230.0,9.5,10.392305,0.0,0.0,0.0,141.55,Attack



## Encoding Labels

In [7]:
# Fit one LabelEncoder per selected class structure and overwrite that
# structure's 'label' column with the integer codes.
# NOTE: the column is replaced in place, so re-running this cell alone would
# re-fit the encoders on integers; re-run the duplication cell first.
label_encoders = {}
for _class in apply_classes:
    encoder = LabelEncoder()
    all_data[_class]['label'] = encoder.fit_transform(all_data[_class]['label'])
    label_encoders[_class] = encoder

# Keep the original per-structure variable names available for whichever
# structures were actually selected.
if '33+1' in label_encoders:
    full_label_encoder = label_encoders['33+1']
if '8+1' in label_encoders:
    class_label_encoder = label_encoders['8+1']
if '1+1' in label_encoders:
    binary_label_encoder = label_encoders['1+1']

# The inspection below (and the resampled-data preview later on) reads the
# frame of apply_classes[0], so the mapping must come from that structure's
# encoder; this also works when '33+1' is not selected.
primary_encoder = label_encoders[apply_classes[0]]

# Store label mappings (a label's position in classes_ is its numeric code)
label_mapping = dict(enumerate(primary_encoder.classes_))
print("Label mappings:", label_mapping)

# Retrieve the numeric codes for classes
class_codes = {label: code for code, label in label_mapping.items()}

# Print specific instances after label encoding
print("After encoding:")
primary_data = all_data[apply_classes[0]]
for label, code in class_codes.items():
    # Print the first instance of each class
    print(f"First instance of {label} (code {code}):")
    print(primary_data[primary_data['label'] == code].iloc[0])

all_data[apply_classes[0]].head()

Label mappings: {0: 'Backdoor_Malware', 1: 'BenignTraffic', 2: 'BrowserHijacking', 3: 'CommandInjection', 4: 'DDoS-ACK_Fragmentation', 5: 'DDoS-HTTP_Flood', 6: 'DDoS-ICMP_Flood', 7: 'DDoS-ICMP_Fragmentation', 8: 'DDoS-PSHACK_Flood', 9: 'DDoS-RSTFINFlood', 10: 'DDoS-SYN_Flood', 11: 'DDoS-SlowLoris', 12: 'DDoS-SynonymousIP_Flood', 13: 'DDoS-TCP_Flood', 14: 'DDoS-UDP_Flood', 15: 'DDoS-UDP_Fragmentation', 16: 'DNS_Spoofing', 17: 'DictionaryBruteForce', 18: 'DoS-HTTP_Flood', 19: 'DoS-SYN_Flood', 20: 'DoS-TCP_Flood', 21: 'DoS-UDP_Flood', 22: 'MITM-ArpSpoofing', 23: 'Mirai-greeth_flood', 24: 'Mirai-greip_flood', 25: 'Mirai-udpplain', 26: 'Recon-HostDiscovery', 27: 'Recon-OSScan', 28: 'Recon-PingSweep', 29: 'Recon-PortScan', 30: 'SqlInjection', 31: 'Uploading_Attack', 32: 'VulnerabilityScan', 33: 'XSS'}
After encoding:
First instance of Backdoor_Malware (code 0):
flow_duration      1007.509440
Header_Length      7380.300000
Protocol Type         6.500000
Duration             90.600000
Rate    

Unnamed: 0,flow_duration,Header_Length,Protocol Type,Duration,Rate,Srate,Drate,fin_flag_number,syn_flag_number,rst_flag_number,...,Std,Tot size,IAT,Number,Magnitue,Radius,Covariance,Variance,Weight,label
91016,4.204911,108.0,6.0,64.0,0.475634,0.475634,0.0,0.0,0.0,0.0,...,0.0,54.0,82947240.0,9.5,10.392305,0.0,0.0,0.0,141.55,20
132583,0.03339,10377.0,17.0,64.0,35482.24667,35482.24667,0.0,0.0,0.0,0.0,...,0.0,50.0,83487430.0,9.5,10.0,0.0,0.0,0.0,141.55,14
17850,0.0,54.93,6.11,64.64,26.850333,26.850333,0.0,0.0,0.0,0.0,...,1.299305,54.93,83076180.0,9.5,10.42277,1.844534,28.400844,0.06,141.55,13
106569,0.040928,36075.0,17.0,64.0,17621.950302,17621.950302,0.0,0.0,0.0,0.0,...,0.0,50.0,83016670.0,9.5,10.0,0.0,0.0,0.0,141.55,21
76802,5.43375,108.0,6.0,64.0,0.36807,0.36807,0.0,0.0,0.0,0.0,...,0.0,54.0,82956230.0,9.5,10.392305,0.0,0.0,0.0,141.55,20


## X, y Splitting

In [8]:
# Separate every class structure into its feature frame (all columns except
# 'label') and its target series, keyed by class-structure name.
X = {name: all_data[name].drop(columns='label') for name in apply_classes}
y = {name: all_data[name]['label'] for name in apply_classes}

# Report the shapes for the first selected structure as a sanity check.
first_structure = apply_classes[0]
print(f'X: {X[first_structure].shape}, y: {y[first_structure].shape}')

X: (1613206, 46), y: (1613206,)


# Sampling

In [9]:
if apply_sampling is not None:
    
    undersampler = None
    oversampler = None
    
    for sampler in apply_sampling:
        match apply_sampling:
            case 'RandomOverSampler':
                from imblearn.over_sampling import RandomOverSampler
                oversampler = RandomOverSampler(random_state=42)
            case 'RandomUnderSampler':
                from imblearn.under_sampling import RandomUnderSampler
                undersampler = RandomUnderSampler(random_state=42)
            case 'SMOTENC':
                from imblearn.over_sampling import SMOTENC
                cat_cols = [
                    'Protocol Type', 'Drate', 'fin_flag_number', 'syn_flag_number', 'rst_flag_number',
                    'psh_flag_number', 'ack_flag_number', 'ece_flag_number',
                    'cwr_flag_number', 'HTTP', 'HTTPS', 'DNS', 'Telnet',
                    'SMTP', 'SSH', 'IRC', 'TCP', 'UDP', 'DHCP', 'ARP',
                    'ICMP', 'IPv', 'LLC'
                ]
                oversampler = SMOTENC(categorical_features=cat_cols, random_state=42)
            case 'Clustering':
                from imblearn.under_sampling import ClusterCentroids
                undersampler = ClusterCentroids(random_state=42)
    
    for _class in apply_classes:
        if undersampler is not None:
            X[_class], y[_class] = undersampler.fit_resample(X[_class], y[_class])  
        if oversampler is not None:
            X[_class], y[_class] = oversampler.fit_resample(X[_class], y[_class])
    
    print(f'X: {X[apply_classes[0]].shape}, y: {y[apply_classes[0]].shape}')
else:
    print('No sampling selected.')

X: (8454882, 46), y: (8454882,)


In [10]:
# Recombine the resampled features and labels back
all_data_resampled = {}
for _class in apply_classes:
    all_data_resampled[_class] = pd.concat([X[_class], y[_class]], axis=1)

print("Resampled Data (UNSCALED):")
for label, code in class_codes.items():
    # Print the first instance of each class
    print(f"First instance of {label} (code {code}):")
    print(all_data_resampled[apply_classes[0]][all_data_resampled[apply_classes[0]]['label'] == code].iloc[0])

Resampled Data (UNSCALED):
First instance of Backdoor_Malware (code 0):
flow_duration      1007.509440
Header_Length      7380.300000
Protocol Type         6.500000
Duration             90.600000
Rate                  0.266572
Srate                 0.266572
Drate                 0.000000
fin_flag_number       0.000000
syn_flag_number       0.000000
rst_flag_number       0.000000
psh_flag_number       0.000000
ack_flag_number       1.000000
ece_flag_number       0.000000
cwr_flag_number       0.000000
ack_count             0.000000
syn_count             0.400000
fin_count             0.000000
urg_count            23.400000
rst_count            49.500000
HTTP                  0.000000
HTTPS                 0.000000
DNS                   0.000000
Telnet                0.000000
SMTP                  0.000000
SSH                   0.000000
IRC                   0.000000
TCP                   1.000000
UDP                   0.000000
DHCP                  0.000000
ARP                   0.00000

## Real vs Resampled Dataset Analysis

In [11]:
# Summary statistics of the resampled frame for the first selected class
# structure, shown as the cell output.
resampled_first = all_data_resampled[apply_classes[0]]
resampled_first.describe()

Unnamed: 0,flow_duration,Header_Length,Protocol Type,Duration,Rate,Srate,Drate,fin_flag_number,syn_flag_number,rst_flag_number,...,Std,Tot size,IAT,Number,Magnitue,Radius,Covariance,Variance,Weight,label
count,8454882.0,8454882.0,8454882.0,8454882.0,8454882.0,8454882.0,8454882.0,8454882.0,8454882.0,8454882.0,...,8454882.0,8454882.0,8454882.0,8454882.0,8454882.0,8454882.0,8454882.0,8454882.0,8454882.0,8454882.0
mean,122.4254,223895.3,10.39046,80.25666,5526.85,5526.85,2.178548e-06,0.02987067,0.1109603,0.04885107,...,180.6209,339.9129,81918680.0,9.427299,21.59087,255.2632,144834.5,0.5480666,139.678,16.5
std,1016.272,800925.3,9.803101,32.5519,68339.92,68339.92,0.00144938,0.1702305,0.314083,0.2155566,...,300.4174,443.6551,53562360.0,2.56766,13.83429,424.8616,586669.8,0.4357656,66.12926,9.810709
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,42.0,0.0,1.0,9.165151,0.0,0.0,0.0,1.0,0.0
25%,0.00271975,58.0,6.0,64.0,2.972179,2.972179,0.0,0.0,0.0,0.0,...,0.0,54.06,82972690.0,9.5,10.39781,0.0,0.0,0.0,141.55,8.0
50%,1.33093,3321.2,6.11,64.0,20.95643,20.95643,0.0,0.0,0.0,0.0,...,41.62112,115.6,83249960.0,9.5,15.04428,58.74259,2710.033,0.8,141.55,16.5
75%,28.73764,37043.5,11.33,86.7,97.71302,97.71302,0.0,0.0,0.0,0.0,...,241.413,558.76,83677410.0,9.5,33.28663,340.6369,110976.8,0.95,141.55,25.0
max,61547.2,9844135.0,47.0,255.0,8388608.0,8388608.0,1.939832,1.0,1.0,1.0,...,10932.14,9998.4,167639400.0,15.0,145.3904,15460.38,143542700.0,1.0,244.6,33.0


### Generate Reports

In [None]:
# Create the output folder up front so a fresh checkout does not depend on it
# already existing.
REPORT_DIRECTORY = './profile_reports'
os.makedirs(REPORT_DIRECTORY, exist_ok=True)

# For every class structure, profile the original and the resampled frame and
# write a side-by-side comparison report as HTML.
for _class in apply_classes:
    original_report = ProfileReport(all_data[_class], title=f'{_class} Original Data', minimal=True)
    resampled_report = ProfileReport(all_data_resampled[_class], title=f'{_class} Resampled Data', minimal=True)
    comparison_report = original_report.compare(resampled_report)
    comparison_report.to_file(f'{REPORT_DIRECTORY}/{apply_sampling}_{_class}_resampling_report.html')

# Evaluator Model

## Imports

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

## Preprocessing
### Scaling Numerical Features

In [None]:
# Columns that get standardised; every column not listed here is left as-is.
num_cols = [
    'flow_duration', 'Header_Length', 'Duration', 'Rate', 'Srate',
    'ack_count', 'syn_count', 'fin_count', 'urg_count', 'rst_count',
    'Tot sum', 'Min', 'Max', 'AVG', 'Std', 'Tot size', 'IAT', 'Number',
    'Magnitue', 'Radius', 'Covariance', 'Variance', 'Weight',
]

# One scaler object is re-fitted per class structure; after the loop it holds
# the statistics of the last structure only.
# NOTE(review): the scaler is fitted on the full resampled frame, i.e. before
# the train/test split in the next cell -- confirm this is intended.
scaler = StandardScaler()
for _class in apply_classes:
    frame = all_data_resampled[_class]
    frame[num_cols] = scaler.fit_transform(frame[num_cols])

### X, y Train/Test Splitting

In [None]:
# Per-class-structure train/test partitions, keyed like apply_classes.
X_train = {}
X_test = {}
y_train = {}
y_test = {}

for _class in apply_classes:
    # Local names on purpose: X and y are the notebook-level per-class dicts
    # from the splitting/sampling cells; rebinding them to a single DataFrame
    # here would break re-running those cells.
    features = all_data_resampled[_class].drop('label', axis=1)
    target = all_data_resampled[_class]['label']

    # 80/20 split with a fixed seed so results are repeatable.
    X_train[_class], X_test[_class], y_train[_class], y_test[_class] = train_test_split(
        features, target, test_size=0.2, random_state=42
    )

print(f'X_train: {X_train[apply_classes[0]].shape}, y_train: {y_train[apply_classes[0]].shape}, X_test: {X_test[apply_classes[0]].shape}, y_test: {y_test[apply_classes[0]].shape}')

## Training

In [None]:
for evaluator_type in apply_evaluators:
    match evaluator_type:
        case 'XGBoost':
            from xgboost import XGBClassifier
            evaluator = XGBClassifier()
        case 'LogisticRegression':
            from sklearn.linear_model import LogisticRegression
            evaluator = LogisticRegression(random_state=42, n_jobs=-1)
        case 'Perceptron':
            from sklearn.linear_model import Perceptron
            evaluator = Perceptron(random_state=42, n_jobs=-1)
        case 'AdaBoost':
            from sklearn.ensemble import AdaBoostClassifier
            evaluator = AdaBoostClassifier(random_state=42, algorithm='SAMME')
        case 'RandomForest':
            from sklearn.ensemble import RandomForestClassifier
            evaluator = RandomForestClassifier(random_state=42, n_jobs=-1)
        case 'DeepNeuralNetwork':
            from sklearn.neural_network import MLPClassifier
            evaluator = MLPClassifier(random_state=42)
        case 'KNearestNeighbor':
            from sklearn.neighbors import KNeighborsClassifier
            evaluator = KNeighborsClassifier(n_jobs=-1)
        case _:
            print(f'Invalid evaluator model: {evaluator_type}')
    
    
    
    for _class in apply_classes:
        # XGBoost for binary classification must be a binary objective
        if evaluator_type == 'XGBoost' and _class == '1+1':
            evaluator = XGBClassifier(objective='binary:logistic')
            
        print(f'{datetime.now()} : Training {evaluator_type} on {apply_sampling} balanced data with {_class} label classes')
        evaluator.fit(X_train[_class], y_train[_class])
    
        print(f'{datetime.now()} : Predicting {evaluator_type} on {_class} classes')
        y_pred = evaluator.predict(X_test[_class])
    
        print(f'{evaluator_type} {_class} Metrics')
        print(f'   Accuracy: {accuracy_score(y_test[_class], y_pred)}')
        print(f'   Precision: {precision_score(y_test[_class], y_pred, average='weighted', zero_division=0.0)}')
        print(f'   Recall: {recall_score(y_test[_class], y_pred, average='weighted')}')
        print(f'   F1: {f1_score(y_test[_class], y_pred, average='weighted')}')
        print()

## Model Analysis

In [None]:
# Disabled, needs multi-label-group and multi-estimator re-implementation
# 
# from sklearn.metrics import confusion_matrix
# 
# cm = pd.DataFrame(confusion_matrix(y_test, y_pred), columns = full_label_encoder.classes_)
# cm.insert(0, column='Actual', value=full_label_encoder.classes_)
# cm