### Libraries

In [73]:
%%capture
%reset -f                        # clear all variables from the workspace
'generic imports'
import os
import pandas as pd
import datetime
import numpy as np
import sys
sys.path.append(os.path.abspath('..'))
from src import utils
import importlib
importlib.reload(utils)        

'machine learning imports'
import torch
from pytorch_tabnet.tab_model import TabNetClassifier
from pytorch_tabnet.pretraining import TabNetPretrainer
from sklearn import metrics
from sklearn.preprocessing import LabelEncoder
from sklearn.utils import shuffle

### GPU

In [74]:
DEVICE = 'CUDA' if torch.cuda.is_available() else 'CPU'
print("Using {}".format(DEVICE))

# Info on the device available memory
if DEVICE == 'CUDA':
    gpu = torch.device('cuda')
    total_memory = torch.cuda.get_device_properties(gpu).total_memory / 1024**3
    current_memory = torch.cuda.memory_allocated(gpu) / 1024**3

    print(f'Total GPU memory: {total_memory:.1f} GB | Current usage: {current_memory:.1f} GB')

Using CPU


### Load data

In [75]:
AUGMENTATION = 'RealTabFormer'

data_dir = os.path.abspath('../data')

# Load the train and test datasets
df_train, df_test = utils.load_dataset(data_directory=data_dir, 
                                       augmentation=AUGMENTATION, 
                                       ignore_columns=['mbtcp.unit_id', 
                                                       'mbtcp.trans_id']) 

Loading complete.
Train data: 1500000 rows, 46 columns. 
Test data: 381934 rows, 46 columns.


### Data preparation

In [76]:
# Creates X_train, y_train
X_train = df_train.drop(['Attack_label', 'Attack_type'], axis=1)
y_train = df_train['Attack_type']

# Creates X_test, y_test
X_test = df_test.drop(['Attack_label', 'Attack_type'], axis=1)
y_test = df_test['Attack_type']

#### Convert categorical features to one-hot encoded features

In [77]:
# Encode the training and test labels if needed
X_train_enc, X_test_enc, cat_cols, cat_dims = utils.encode_categorical(X_train, X_test, encoding='label')

Categorical features to be encoded:

dns.qry.name.len
mqtt.conack.flags
http.request.version
http.request.method
mqtt.topic
http.referer
mqtt.protoname

Encoding complete.
No of features before encoding: 44
No of features after encoding: 44


#### Label encoding

In [81]:
y_train_enc, y_test_enc, le = utils.encode_labels(y_train, y_test)

Attack_type and encoded labels:

Backdoor                0
DDoS_HTTP               1
DDoS_ICMP               2
DDoS_TCP                3
DDoS_UDP                4
Fingerprinting          5
MITM                    6
Normal                  7
Password                8
Port_Scanning           9
Ransomware              10
SQL_injection           11
Uploading               12
Vulnerability_scanner   13
XSS                     14


### Model Training

In [82]:
# Shuffle training data
X_train, y_train = shuffle(X_train, y_train, random_state=42)

if AUGMENTATION == 'SMOTE'or AUGMENTATION == 'SMOTE-NC':
    # pytorch_tabnet default parameters
    tabnet = TabNetClassifier()
    
    tabnet.fit(X_train=X_train_enc.values, 
               y_train=y_train_enc,
               augmentations=None,
               max_epochs=100,
               )
else: # AUGMENTATION == 'None', 'RealTabFormer', 'GReaT'

    cat_idxs = [ i for i, f in enumerate(X_train_enc.columns) if f in cat_cols]
    cat_dims = [ cat_dims[f] for i, f in enumerate(X_train_enc.columns) if f in cat_dims]

    tabnet = TabNetClassifier(cat_idxs=cat_idxs,
                              cat_dims=cat_dims,
                              cat_emb_dim=10,    # categorical features embedding dimension
                              )
    tabnet.fit(X_train=X_train_enc.values, 
               y_train=y_train_enc,
               augmentations=None,
               max_epochs=100,
               )



epoch 0  | loss: 0.68884 |  0:01:11s
epoch 1  | loss: 0.52337 |  0:02:15s
epoch 2  | loss: 0.50097 |  0:03:18s
epoch 3  | loss: 0.50112 |  0:04:21s
epoch 4  | loss: 0.49437 |  0:05:24s
epoch 5  | loss: 0.49036 |  0:06:27s
epoch 6  | loss: 0.48705 |  0:07:30s
epoch 7  | loss: 0.48715 |  0:08:33s
epoch 8  | loss: 0.48315 |  0:09:36s
epoch 9  | loss: 0.48321 |  0:10:40s
epoch 10 | loss: 0.47921 |  0:11:42s
epoch 11 | loss: 0.48297 |  0:12:45s
epoch 12 | loss: 0.48542 |  0:13:48s
epoch 13 | loss: 0.49166 |  0:14:56s
epoch 14 | loss: 0.48334 |  0:16:16s
epoch 15 | loss: 0.44547 |  0:17:32s
epoch 16 | loss: 0.43948 |  0:18:51s
epoch 17 | loss: 0.44841 |  0:20:03s
epoch 18 | loss: 0.45203 |  0:21:08s
epoch 19 | loss: 0.44218 |  0:22:13s
epoch 20 | loss: 0.44596 |  0:23:18s
epoch 21 | loss: 0.4444  |  0:24:23s
epoch 22 | loss: 0.44329 |  0:25:28s
epoch 23 | loss: 0.48609 |  0:26:33s
epoch 24 | loss: 0.46481 |  0:27:39s
epoch 25 | loss: 0.45282 |  0:28:44s
epoch 26 | loss: 0.48839 |  0:29:50s
e

#### Save model

In [83]:
saved_filename = tabnet.save_model(f'checkpoints/tabnet/tabnet_{AUGMENTATION}')

Successfully saved model at checkpoints/tabnet/tabnet_RealTabFormer.zip


### Model Evaluation

In [84]:
predictions = tabnet.predict(X_test_enc.values)

#### Metrics

In [85]:
accuracy = metrics.accuracy_score(y_test_enc, predictions)
precision_w = metrics.precision_score(y_test_enc, predictions, average='weighted', zero_division=1)
recall_w = metrics.recall_score(y_test_enc, predictions, average='weighted')
f1_score_w = metrics.f1_score(y_test_enc, predictions, average='weighted')
precision_m = metrics.precision_score(y_test_enc, predictions, average='macro', zero_division=1)
recall_m = metrics.recall_score(y_test_enc, predictions, average='macro')
f1_score_m = metrics.f1_score(y_test_enc, predictions, average='macro')

In [86]:
# Create dictionary for results
results = {
    "model": "TabNet",
    "augmentations": AUGMENTATION,
    "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    "accuracy": accuracy,
    "precision_macro": precision_m,
    "recall_macro": recall_m,
    "f1_macro": f1_score_m,
    "precision_weighted": precision_w,
    "recall_weighted": recall_w,
    "f1_weighted": f1_score_w
    }

utils.print_results_table(results)

╒══════════════════════╤═════════╕
│ Metric               │ Value   │
╞══════════════════════╪═════════╡
│ Accuracy             │ 90.89%  │
├──────────────────────┼─────────┤
│ Precision (macro)    │ 70.16%  │
├──────────────────────┼─────────┤
│ Recall (macro)       │ 65.02%  │
├──────────────────────┼─────────┤
│ F1 (macro)           │ 63.40%  │
├──────────────────────┼─────────┤
│ Precision (weighted) │ 92.08%  │
├──────────────────────┼─────────┤
│ Recall (weighted)    │ 90.89%  │
├──────────────────────┼─────────┤
│ F1 (weighted)        │ 90.35%  │
╘══════════════════════╧═════════╛


#### Save Metrics Results 

In [87]:
# save results to csv   
utils.save_results_to_csv([results], '../results/metrics/tabnet.csv')

#### Confusion matrix

In [88]:
conf_mat = metrics.confusion_matrix(y_test_enc, predictions)

attack_labels = ['Backdoor', 'DDoS_HTTP', 'DDoS_ICMP', 'DDoS_TCP', 'DDoS_UDP', 
'Fingerprinting', 'MITM', 'Normal', 'Password', 'Port_Scanning', 'Ransomware', 
'SQL_injection', 'Uploading', 'Vulnerability_scanner', 'XSS']

# Create a dataframe from the confusion matrix
conf_mat_df = pd.DataFrame(conf_mat, 
                           index = attack_labels, 
                           columns = attack_labels)
conf_mat_df.index.name = 'Actual'
conf_mat_df.columns.name = 'Predicted'

# Save the confusion matrix
conf_mat_df.to_csv(f"../results/conf_matrix/{results['model']}_{results['augmentations']}.csv")
conf_mat_df

Predicted,Backdoor,DDoS_HTTP,DDoS_ICMP,DDoS_TCP,DDoS_UDP,Fingerprinting,MITM,Normal,Password,Port_Scanning,Ransomware,SQL_injection,Uploading,Vulnerability_scanner,XSS
Actual,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
Backdoor,4446,0,0,248,1,0,0,0,0,2,85,0,0,0,0
DDoS_HTTP,0,3633,0,0,0,0,0,0,0,0,0,0,0,3201,2794
DDoS_ICMP,0,0,13422,0,25,54,0,0,0,0,0,0,0,0,0
DDoS_TCP,0,0,0,9941,0,0,0,0,0,68,0,0,0,0,0
DDoS_UDP,0,0,0,0,24601,0,0,0,0,0,0,0,0,0,0
Fingerprinting,0,0,0,47,2,96,0,0,0,0,0,0,0,0,1
MITM,0,0,0,0,0,0,76,0,0,0,0,0,0,0,0
Normal,0,0,0,0,0,0,0,272776,0,0,0,0,0,0,0
Password,0,52,0,0,0,0,0,0,8254,0,0,1802,0,0,0
Port_Scanning,0,0,0,4033,0,0,0,0,0,29,0,0,0,0,0
