### Libraries

In [71]:
'generic imports'
import os
import pandas as pd
import sys
sys.path.append(os.path.abspath('..'))
from src import utils
from psutil import virtual_memory      # memory usage

'machine learning imports'
import tensorflow as tf
from tensorflow.keras.models import Sequential, load_model
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split

### GPU

In [72]:
gpu = !nvidia-smi --query-gpu=gpu_name --format=csv,noheader
ram_gb = virtual_memory().total / 1e9
print(f'{gpu.s} with {round(ram_gb,1)} GB available RAM.')
!nvcc --version

'nvidia-smi' is not recognized as an internal or external command, operable program or batch file. with 8.5 GB available RAM.


'nvcc' is not recognized as an internal or external command,
operable program or batch file.


### Load Data

In [73]:
data_dir = os.path.abspath('../data')

# Non-augmented dataset
df_train = pd.read_csv(os.path.join(data_dir, 'EdgeIIot_train_100k.csv'), low_memory=False)
AUGMENTATION = 'None'

# SMOTE augmented dataset
# df_train = pd.read_csv(os.path.join(data_dir, 'EdgeIIot_train_100k_SMOTE.csv'), low_memory=False)
# AUGMENTATION = 'SMOTE'

# SMOTE-NC augmented dataset
# df_train = pd.read_csv(os.path.join(data_dir, 'EdgeIIot_train_100k_SMOTE_NC.csv'), low_memory=False)
# AUGMENTATION = 'SMOTE-NC'

# RealTabFormer augmentation dataset
# df_train = pd.read_csv(os.path.join(data_dir, 'EdgeIIot_train_100k_RealTabFormer.csv'), low_memory=False)
# AUGMENTATION = 'RealTabFormer'

# GReaT augmentation dataset
# df_train = pd.read_csv(os.path.join(data_dir, 'EdgeIIot_train_100k_GReaT.csv'), low_memory=False)
# AUGMENTATION = 'GReaT'

# Test data for all datasets
df_test = pd.read_csv(os.path.join(data_dir, 'EdgeIIot_test.csv'), low_memory=False)

In [74]:
# list columns of df_train
df_train.columns

Index(['arp.opcode', 'arp.hw.size', 'icmp.checksum', 'icmp.seq_le',
       'icmp.unused', 'http.content_length', 'http.request.method',
       'http.referer', 'http.request.version', 'http.response',
       'http.tls_port', 'tcp.ack', 'tcp.ack_raw', 'tcp.checksum',
       'tcp.connection.fin', 'tcp.connection.rst', 'tcp.connection.syn',
       'tcp.connection.synack', 'tcp.flags', 'tcp.flags.ack', 'tcp.len',
       'tcp.seq', 'udp.stream', 'udp.time_delta', 'dns.qry.name',
       'dns.qry.name.len', 'dns.qry.qu', 'dns.qry.type', 'dns.retransmission',
       'dns.retransmit_request', 'dns.retransmit_request_in',
       'mqtt.conack.flags', 'mqtt.conflag.cleansess', 'mqtt.conflags',
       'mqtt.hdrflags', 'mqtt.len', 'mqtt.msg_decoded_as', 'mqtt.msgtype',
       'mqtt.proto_len', 'mqtt.protoname', 'mqtt.topic', 'mqtt.topic_len',
       'mqtt.ver', 'mbtcp.len', 'mbtcp.trans_id', 'mbtcp.unit_id',
       'Attack_label', 'Attack_type'],
      dtype='object')

### Data Preparation

In [75]:
# Drop columns mbtcp.unit_id and mbtcp.trans_id from train and test data    
df_train = df_train.drop(['mbtcp.unit_id', 'mbtcp.trans_id'], axis=1)
df_test = df_test.drop(['mbtcp.unit_id', 'mbtcp.trans_id'], axis=1)

# Creates X_train, y_train
X_train = df_train.drop(['Attack_label', 'Attack_type'], axis=1)
y_train = df_train['Attack_type']

# Creates X_test, y_test
X_test = df_test.drop(['Attack_label', 'Attack_type'], axis=1)
y_test = df_test['Attack_type']

In [76]:
# Extract categorical features
categorical_features = X_train.select_dtypes(include="object").columns

# Get the unique values of all categorical columns
for col in X_train[categorical_features].columns:
        unique_values = X_train[col].unique()
        print(f'{col}: \n{unique_values}\n')

http.request.method: 
['0.0' '0' 'GET' 'POST' 'TRACE' 'OPTIONS' 'SEARCH' 'PROPFIND' 'PUT']

http.referer: 
['0.0' '0' '127.0.0.1'
 '() { _; } >_[$($())] { echo 93e4r0-CVE-2014-6278: true; echo;echo; }'
 'TESTING_PURPOSES_ONLY']

http.request.version: 
['0.0' '0' 'HTTP/1.1' 'HTTP/1.0' 'script>alert(1)/script><\\" HTTP/1.1'
 '/etc/passwd|?data=Download HTTP/1.1'
 '-al&ABSOLUTE_PATH_STUDIP=http://cirt.net/rfiinc.txt?? HTTP/1.1'
 '-al&_PHPLIB[libdir]=http://cirt.net/rfiinc.txt?? HTTP/1.1' '-a HTTP/1.1'
 'Src=javascript:alert(\'Vulnerable\')><Img Src=\\" HTTP/1.1'
 "name=a><input name=i value=XSS>&lt;script>alert('Vulnerable')</script> HTTP/1.1"
 'By Dr HTTP/1.1' '> HTTP/1.1']

dns.qry.name.len: 
['0.0' '1.0' '0' '2.debian.pool.ntp.org' '1.debian.pool.ntp.org'
 '3.debian.pool.ntp.org' '0.debian.pool.ntp.org' 'raspberrypi.local']

mqtt.conack.flags: 
['0.0' '0' '0x00000000' '1574358' '1461589' '1461383' '1574359']

mqtt.protoname: 
['0.0' '0' 'MQTT']

mqtt.topic: 
['0.0' '0' 'Temperature_and

In [77]:
# Concatenate X_train and X_test
X_comb = pd.concat([X_train[categorical_features], X_test[categorical_features]], axis=0)

# Apply one-hot encoding (get_dummies)
X_comb_enc = pd.get_dummies(X_comb)

# Split back into X_train and X_test
X_train_enc, X_test_enc = train_test_split(
    X_comb_enc, test_size=len(X_test), random_state=42)

# Print the shape of X_train_enc and X_test_enc
print(f'X_train_enc shape: {X_train_enc.shape}, X_test_enc shape: {X_test_enc.shape}')

X_train_enc shape: (536515, 53), X_test_enc shape: (381934, 53)


In [78]:
# Drop columns categorical_features from X_train and X_test 

# VER QU COLUNAS TÊM NOMES IGUAIS

X_train = X_train.drop(categorical_features, axis=1)
X_test = X_test.drop(categorical_features, axis=1)

In [80]:
# Print the shape of X_train and X_test
print(f'X_train shape: {X_train.shape}, X_test shape: {X_test.shape}')

X_train shape: (536515, 37), X_test shape: (381934, 37)


In [81]:
# Concatenate X_train and X_train_enc
X_train = pd.concat([X_train, X_train_enc], axis=1)


InvalidIndexError: Reindexing only valid with uniquely valued Index objects

In [82]:
X_train_enc

Unnamed: 0,http.request.method_0,http.request.method_0.0,http.request.method_GET,http.request.method_OPTIONS,http.request.method_POST,http.request.method_PROPFIND,http.request.method_PUT,http.request.method_SEARCH,http.request.method_TRACE,http.referer_() { _; } >_[$($())] { echo 93e4r0-CVE-2014-6278: true; echo;echo; },...,mqtt.conack.flags_1461589,mqtt.conack.flags_1461591,mqtt.conack.flags_1574358,mqtt.conack.flags_1574359,mqtt.protoname_0,mqtt.protoname_0.0,mqtt.protoname_MQTT,mqtt.topic_0,mqtt.topic_0.0,mqtt.topic_Temperature_and_Humidity
370188,False,True,False,False,False,False,False,False,False,False,...,False,False,False,False,False,True,False,False,True,False
152099,False,True,False,False,False,False,False,False,False,False,...,False,False,False,False,True,False,False,True,False,False
342773,True,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,True,False,False,True,False
486881,False,True,False,False,False,False,False,False,False,False,...,False,False,False,False,True,False,False,True,False,False
143248,False,True,False,False,False,False,False,False,False,False,...,False,False,False,False,True,False,False,True,False,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
259178,True,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,True,False,False,True,False
365838,True,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,True,False,False,True,False
131932,False,True,False,False,False,False,False,False,False,False,...,False,False,False,False,False,True,False,False,True,False
134640,False,True,False,False,False,False,False,False,False,False,...,False,False,False,False,True,False,False,True,False,False


#### Label Encoding

In [28]:
# instantiate the label encoder
le = LabelEncoder()

# fit and encode the training labels
y_train = le.fit_transform(y_train)

# encode the test labels
y_test = le.transform(y_test)

print('Attack_type and encoded labels:\n')
for i, label in enumerate(le.classes_):
    print(f'{label:23s} {i:d}')

Attack_type and encoded labels:

Backdoor                0
DDoS_HTTP               1
DDoS_ICMP               2
DDoS_TCP                3
DDoS_UDP                4
Fingerprinting          5
MITM                    6
Normal                  7
Password                8
Port_Scanning           9
Ransomware              10
SQL_injection           11
Uploading               12
Vulnerability_scanner   13
XSS                     14
