This is the user-space program that pretrains the ML model to
classify packets as malicious or benign. There are several ways
to deploy an ML model in the packet path.

One way is to use AF_XDP sockets to send data collected about the
packets from kernel space to user space, where the ML model decides the
fate of each packet: it either does nothing (drop) or forwards the packet to the
corresponding user-space application after any required processing. This hand-off
could also be implemented in other ways, such as a perf buffer or polling.


Another approach is to implement the ML model within the kernel, which
is the approach currently adopted.

Currently, we have implemented a very simple in-kernel neural network (logistic regression)
for classifying packets. We do quantization-aware training in
user space and store the quantized weights in a BPF map.
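
As a rough illustration of what "quantized weights" means here, the sketch below maps float weights to int8 with one shared scale and a zero point of 0 (symmetric quantization). The function names and the example weights are hypothetical and are not taken from the trained model; PyTorch's own observers compute the scale slightly differently.

```python
def quantize_symmetric_int8(weights):
    """Map float weights to int8 using a single shared scale (zero point 0)."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    # Round to the nearest integer step and clamp to the int8 range
    q = [max(-128, min(127, round(w / scale))) for w in weights]
    return q, scale


def dequantize(q, scale):
    """Recover approximate float weights from the int8 values."""
    return [v * scale for v in q]


weights = [0.25, -0.5, 0.125, 0.0]  # hypothetical values, for illustration only
q, scale = quantize_symmetric_int8(weights)
approx = dequantize(q, scale)
print(q, scale, approx)
```

Only the integers in `q` would need to live in the BPF map; the scale stays in user space unless the kernel side needs it.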

The kernel imposes the following restrictions:
    (1) limitations on the quantity of eBPF instructions and stack space,
    (2) prohibitions on unbounded loops, non-static global variables, variadic functions,
        multi-threaded programming, and floating-point representation, and
    (3) enforcement of array bound checks

Because of these restrictions, we use a simple model, which also involves less complex calculations.
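
One consequence of these restrictions is that inference has to be integer-only. For logistic regression this is feasible because sigmoid(z) > 0.5 exactly when z > 0, so the sigmoid never has to be evaluated. The sketch below mirrors that logic in Python. It assumes zero points of 0 for inputs and weights and a bias that user space has already divided by (input scale × weight scale) and rounded to an integer; the names `classify_int`, `q_x`, `q_w` and `q_bias` are hypothetical.

```python
def classify_int(q_x, q_w, q_bias):
    """Return 1 (malicious) if the integer logit is positive, else 0.

    With positive scales s_x and s_w, the float logit
    s_x * s_w * sum(q_x[i] * q_w[i]) + bias has the same sign as
    sum(q_x[i] * q_w[i]) + bias / (s_x * s_w), so only integers are needed.
    """
    acc = q_bias
    for i in range(len(q_w)):  # fixed trip count, like a bounded loop in eBPF
        acc += q_x[i] * q_w[i]
    return 1 if acc > 0 else 0


print(classify_int([1, 2], [3, -1], 0))   # integer logit is 1
print(classify_int([1, 2], [1, -1], 0))   # integer logit is -1
```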

These weights can be updated as we collect more data and retrain the model
at regular intervals, for example every 24 hours. This feature is planned
but not yet implemented.


In [25]:
import glob
import os
import random
import zipfile

import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
from torch.quantization import QuantStub, DeQuantStub

from tqdm import tqdm
from pathlib import Path

# Getting the CICIDS Dataset from Kaggle

In [2]:
!pip install kaggle



In [3]:
mkdir ~/.kaggle

In [4]:
cp kaggle.json ~/.kaggle/

In [5]:
!chmod 600 ~/.kaggle/kaggle.json

In [6]:
!kaggle datasets download -d cicdataset/cicids2017

Downloading cicids2017.zip to /content
 91% 209M/230M [00:03<00:00, 89.7MB/s]
100% 230M/230M [00:03<00:00, 77.4MB/s]


In [7]:
!unzip cicids2017.zip

Archive:  cicids2017.zip
  inflating: MachineLearningCSV.md5  
  inflating: MachineLearningCSV/MachineLearningCVE/Friday-WorkingHours-Afternoon-DDos.pcap_ISCX.csv  
  inflating: MachineLearningCSV/MachineLearningCVE/Friday-WorkingHours-Afternoon-PortScan.pcap_ISCX.csv  
  inflating: MachineLearningCSV/MachineLearningCVE/Friday-WorkingHours-Morning.pcap_ISCX.csv  
  inflating: MachineLearningCSV/MachineLearningCVE/Monday-WorkingHours.pcap_ISCX.csv  
  inflating: MachineLearningCSV/MachineLearningCVE/Thursday-WorkingHours-Afternoon-Infilteration.pcap_ISCX.csv  
  inflating: MachineLearningCSV/MachineLearningCVE/Thursday-WorkingHours-Morning-WebAttacks.pcap_ISCX.csv  
  inflating: MachineLearningCSV/MachineLearningCVE/Tuesday-WorkingHours.pcap_ISCX.csv  
  inflating: MachineLearningCSV/MachineLearningCVE/Wednesday-workingHours.pcap_ISCX.csv  


# Loading the dataset into a dataframe

In [8]:
# list all csv files only
csv_files = glob.glob('*.{}'.format('csv'))
csv_files

[]

In [9]:
# merging the files

# if local system use PATH = "dataset/archive/MachineLearningCSV/MachineLearningCVE/"
# else if on colab use PATH = "/content/MachineLearningCSV/MachineLearningCVE"
PATH = "/content/MachineLearningCSV/MachineLearningCVE"
joined_files = os.path.join(PATH, "*.csv")

# A list of all joined files is returned
joined_list = glob.glob(joined_files)

# Finally, the files are joined
df_concat = pd.concat(map(pd.read_csv, joined_list), ignore_index=True)
print(df_concat)

          Destination Port   Flow Duration   Total Fwd Packets  \
0                      389       113095465                  48   
1                      389       113473706                  68   
2                        0       119945515                 150   
3                      443        60261928                   9   
4                       53             269                   2   
...                    ...             ...                 ...   
2830738                443          196135                  49   
2830739                443          378424                  49   
2830740                443          161800                  70   
2830741                443          142864                  50   
2830742                443          186928                  46   

          Total Backward Packets  Total Length of Fwd Packets  \
0                             24                         9668   
1                             40                        11364   
2           

In [10]:
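# Normalize column names; regex=False treats '(' and ')' as literal characters
df_concat.columns = (df_concat.columns.str.strip().str.lower()
                     .str.replace(' ', '_', regex=False)
                     .str.replace('(', '', regex=False)
                     .str.replace(')', '', regex=False))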
df_concat.head()



Unnamed: 0,destination_port,flow_duration,total_fwd_packets,total_backward_packets,total_length_of_fwd_packets,total_length_of_bwd_packets,fwd_packet_length_max,fwd_packet_length_min,fwd_packet_length_mean,fwd_packet_length_std,...,min_seg_size_forward,active_mean,active_std,active_max,active_min,idle_mean,idle_std,idle_max,idle_min,label
0,389,113095465,48,24,9668,10012,403,0,201.416667,203.548293,...,32,203985.5,575837.3,1629110,379,13800000.0,4277541.0,16500000,6737603,BENIGN
1,389,113473706,68,40,11364,12718,403,0,167.117647,171.919413,...,32,178326.875,503426.9,1424245,325,13800000.0,4229413.0,16500000,6945512,BENIGN
2,0,119945515,150,0,0,0,0,0,0.0,0.0,...,0,6909777.333,11700000.0,20400000,6,24400000.0,24300000.0,60100000,5702188,BENIGN
3,443,60261928,9,7,2330,4221,1093,0,258.888889,409.702161,...,20,0.0,0.0,0,0,0.0,0.0,0,0,BENIGN
4,53,269,2,2,102,322,51,51,51.0,0.0,...,32,0.0,0.0,0,0,0.0,0.0,0,0,BENIGN


In [11]:
df_labels = df_concat['label']
df_labels.unique()

array(['BENIGN', 'Web Attack � Brute Force', 'Web Attack � XSS',
       'Web Attack � Sql Injection', 'Infiltration', 'FTP-Patator',
       'SSH-Patator', 'Bot', 'DoS slowloris', 'DoS Slowhttptest',
       'DoS Hulk', 'DoS GoldenEye', 'Heartbleed', 'DDoS', 'PortScan'],
      dtype=object)

In [12]:
df_concat.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2830743 entries, 0 to 2830742
Data columns (total 79 columns):
 #   Column                       Dtype  
---  ------                       -----  
 0   destination_port             int64  
 1   flow_duration                int64  
 2   total_fwd_packets            int64  
 3   total_backward_packets       int64  
 4   total_length_of_fwd_packets  int64  
 5   total_length_of_bwd_packets  int64  
 6   fwd_packet_length_max        int64  
 7   fwd_packet_length_min        int64  
 8   fwd_packet_length_mean       float64
 9   fwd_packet_length_std        float64
 10  bwd_packet_length_max        int64  
 11  bwd_packet_length_min        int64  
 12  bwd_packet_length_mean       float64
 13  bwd_packet_length_std        float64
 14  flow_bytes/s                 float64
 15  flow_packets/s               float64
 16  flow_iat_mean                float64
 17  flow_iat_std                 float64
 18  flow_iat_max                 int64  
 19  

In [13]:
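# Restrict the correlation to numeric columns (the string 'label' column is skipped)
df_concat.corr(numeric_only=True)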



Unnamed: 0,destination_port,flow_duration,total_fwd_packets,total_backward_packets,total_length_of_fwd_packets,total_length_of_bwd_packets,fwd_packet_length_max,fwd_packet_length_min,fwd_packet_length_mean,fwd_packet_length_std,...,act_data_pkt_fwd,min_seg_size_forward,active_mean,active_std,active_max,active_min,idle_mean,idle_std,idle_max,idle_min
destination_port,1.000000,-0.151680,-0.004236,-0.003947,0.011145,-0.003082,0.097926,-0.045388,0.140220,0.128861,...,-0.003226,0.000897,-0.035562,-0.043717,-0.051859,-0.023194,-0.112585,0.010399,-0.108185,-0.114614
flow_duration,-0.151680,1.000000,0.020857,0.019670,0.065456,0.016186,0.273308,-0.105230,0.143689,0.234437,...,0.015942,-0.001357,0.189299,0.241060,0.294034,0.121171,0.768034,0.243154,0.779527,0.738328
total_fwd_packets,-0.004236,0.020857,1.000000,0.999070,0.365508,0.996993,0.009358,-0.002989,0.000032,0.001403,...,0.887387,-0.000184,0.039937,0.008329,0.030459,0.041283,0.001820,0.000809,0.001906,0.001670
total_backward_packets,-0.003947,0.019670,0.999070,1.000000,0.359451,0.994429,0.009039,-0.002600,-0.000333,0.001026,...,0.882566,0.000018,0.038963,0.006437,0.028602,0.041278,0.001425,0.000492,0.001456,0.001330
total_length_of_fwd_packets,0.011145,0.065456,0.365508,0.359451,1.000000,0.353762,0.197030,-0.000275,0.185262,0.159787,...,0.407448,-0.001209,0.101084,0.103326,0.126493,0.068325,0.022660,0.027064,0.026079,0.018634
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
active_min,-0.023194,0.121171,0.041283,0.041278,0.068325,0.039069,0.105641,-0.025912,0.081170,0.094164,...,0.031394,-0.006834,0.905862,0.033874,0.584503,1.000000,0.118133,0.038302,0.122651,0.112880
idle_mean,-0.112585,0.768034,0.001820,0.001425,0.022660,0.000809,0.181135,-0.071304,0.127959,0.183139,...,0.000837,-0.000876,0.120171,0.036551,0.088904,0.118133,1.000000,0.150248,0.990387,0.990215
idle_std,0.010399,0.243154,0.000809,0.000492,0.027064,0.000105,0.178091,-0.029951,0.178462,0.191278,...,0.000721,-0.003720,0.070586,0.081435,0.070002,0.038302,0.150248,1.000000,0.283330,0.011609
idle_max,-0.108185,0.779527,0.001906,0.001456,0.026079,0.000797,0.199559,-0.073419,0.148402,0.203304,...,0.000929,-0.001407,0.132700,0.055300,0.102816,0.122651,0.990387,0.283330,1.000000,0.961812


In [15]:
from itertools import combinations


def clean_df(df):
    # Strip leading and trailing whitespace from the feature names
    df.columns = df.columns.str.strip()
    print('dataset shape', df.shape)

    # Numeric features should be non-negative, so clamp negative values to 0
    num = df._get_numeric_data()
    num[num < 0] = 0

    zero_variance_cols = []
    for col in df.columns:
        if len(df[col].unique()) == 1:
            zero_variance_cols.append(col)
    df.drop(zero_variance_cols, axis = 1, inplace = True)
    print('zero variance columns', zero_variance_cols, 'dropped')
    print('shape after removing zero variance columns:', df.shape)

    df.replace([np.inf, -np.inf], np.nan, inplace = True)
    print(df.isna().any(axis = 1).sum(), 'rows dropped')
    df.dropna(inplace = True)
    print('shape after removing nan:', df.shape)

    # Drop duplicate rows
    df.drop_duplicates(inplace = True)
    print('shape after dropping duplicates:', df.shape)

    column_pairs = [(i, j) for i, j in combinations(df, 2) if df[i].equals(df[j])]
    ide_cols = []
    for column_pair in column_pairs:
        ide_cols.append(column_pair[1])
    df.drop(ide_cols, axis = 1, inplace = True)
    print('columns which have identical values', column_pairs, 'dropped')
    print('shape after removing identical value columns:', df.shape)
    return df
df_concat = clean_df(df_concat)

dataset shape (2830743, 79)
zero variance columns ['bwd_psh_flags', 'bwd_urg_flags', 'fwd_avg_bytes/bulk', 'fwd_avg_packets/bulk', 'fwd_avg_bulk_rate', 'bwd_avg_bytes/bulk', 'bwd_avg_packets/bulk', 'bwd_avg_bulk_rate'] dropped
shape after removing zero variance columns: (2830743, 71)
2867 rows dropped
shape after removing nan: (2827876, 71)
shape after dropping duplicates: (2520798, 71)
columns which have identical values [('total_fwd_packets', 'subflow_fwd_packets'), ('total_backward_packets', 'subflow_bwd_packets'), ('fwd_psh_flags', 'syn_flag_count'), ('fwd_urg_flags', 'cwe_flag_count'), ('fwd_header_length', 'fwd_header_length.1')] dropped
shape after removing identical value columns: (2520798, 66)


In [16]:
df_concat.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 2520798 entries, 0 to 2830742
Data columns (total 66 columns):
 #   Column                       Dtype  
---  ------                       -----  
 0   destination_port             int64  
 1   flow_duration                int64  
 2   total_fwd_packets            int64  
 3   total_backward_packets       int64  
 4   total_length_of_fwd_packets  int64  
 5   total_length_of_bwd_packets  int64  
 6   fwd_packet_length_max        int64  
 7   fwd_packet_length_min        int64  
 8   fwd_packet_length_mean       float64
 9   fwd_packet_length_std        float64
 10  bwd_packet_length_max        int64  
 11  bwd_packet_length_min        int64  
 12  bwd_packet_length_mean       float64
 13  bwd_packet_length_std        float64
 14  flow_bytes/s                 float64
 15  flow_packets/s               float64
 16  flow_iat_mean                float64
 17  flow_iat_std                 float64
 18  flow_iat_max                 int64  
 19  

In [17]:
unique_vals = df_concat['label'].unique()
df_concat['label'].replace(to_replace=unique_vals,
           value= list(range(len(unique_vals))),
           inplace=True)

In [18]:
mask = df_concat['label'] != 0
df_concat.loc[mask, 'label'] = 1
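
The two cells above rely on 'BENIGN' being the first value returned by `unique()`, which depends on the order in which the CSV files were read. A more explicit alternative, sketched here under the assumption that it is applied to the original string labels, is to compare against the benign label directly. The helper name `to_binary_label` is hypothetical.

```python
def to_binary_label(label):
    """Return 0 for benign traffic and 1 for any attack class."""
    return 0 if label.strip().upper() == "BENIGN" else 1


labels = ["BENIGN", "DDoS", "PortScan", " BENIGN "]
binary = [to_binary_label(name) for name in labels]
print(binary)
```

With pandas, the same helper could be applied through `df_concat['label'].map(to_binary_label)` before any integer encoding takes place.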

In [19]:
df_concat.describe()

Unnamed: 0,destination_port,flow_duration,total_fwd_packets,total_backward_packets,total_length_of_fwd_packets,total_length_of_bwd_packets,fwd_packet_length_max,fwd_packet_length_min,fwd_packet_length_mean,fwd_packet_length_std,...,min_seg_size_forward,active_mean,active_std,active_max,active_min,idle_mean,idle_std,idle_max,idle_min,label
count,2520798.0,2520798.0,2520798.0,2520798.0,2520798.0,2520798.0,2520798.0,2520798.0,2520798.0,2520798.0,...,2520798.0,2520798.0,2520798.0,2520798.0,2520798.0,2520798.0,2520798.0,2520798.0,2520798.0,2520798.0
mean,8690.59,16591610.0,10.28174,11.5728,611.9477,18144.4,231.2292,19.20349,63.50497,77.32347,...,25.8855,91578.47,46191.77,172017.1,65463.59,9337367.0,565794.1,9763770.0,8892671.0,0.1688914
std,19012.8,35232760.0,794.4201,1056.922,10588.27,2398177.0,756.3755,60.79834,195.5526,296.8814,...,6.525341,686650.3,416584.4,1085571.0,611158.5,24848180.0,4874169.0,25617460.0,24581430.0,0.374656
min,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,53.0,208.0,2.0,1.0,12.0,6.0,6.0,0.0,6.0,0.0,...,20.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
50%,80.0,50622.0,2.0,2.0,66.0,156.0,40.0,2.0,36.25,0.0,...,20.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
75%,443.0,5333340.0,6.0,5.0,332.0,997.0,202.0,37.0,52.0,74.1928,...,32.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
max,65535.0,120000000.0,219759.0,291922.0,12900000.0,655453000.0,24820.0,2325.0,5940.857,7125.597,...,138.0,110000000.0,74200000.0,110000000.0,110000000.0,120000000.0,76900000.0,120000000.0,120000000.0,1.0


In [20]:
df_concat.head()

Unnamed: 0,destination_port,flow_duration,total_fwd_packets,total_backward_packets,total_length_of_fwd_packets,total_length_of_bwd_packets,fwd_packet_length_max,fwd_packet_length_min,fwd_packet_length_mean,fwd_packet_length_std,...,min_seg_size_forward,active_mean,active_std,active_max,active_min,idle_mean,idle_std,idle_max,idle_min,label
0,389,113095465,48,24,9668,10012,403,0,201.416667,203.548293,...,32,203985.5,575837.3,1629110,379,13800000.0,4277541.0,16500000,6737603,0
1,389,113473706,68,40,11364,12718,403,0,167.117647,171.919413,...,32,178326.875,503426.9,1424245,325,13800000.0,4229413.0,16500000,6945512,0
2,0,119945515,150,0,0,0,0,0,0.0,0.0,...,0,6909777.333,11700000.0,20400000,6,24400000.0,24300000.0,60100000,5702188,0
3,443,60261928,9,7,2330,4221,1093,0,258.888889,409.702161,...,20,0.0,0.0,0,0,0.0,0.0,0,0,0
4,53,269,2,2,102,322,51,51,51.0,0.0,...,32,0.0,0.0,0,0,0.0,0.0,0,0,0


In [21]:
feature_list = ['destination_port', 'packet_length_mean','packet_length_std','packet_length_variance','average_packet_size','fwd_iat_mean','fwd_iat_std','fwd_iat_max']

In [22]:
from sklearn.model_selection import train_test_split

X = df_concat[feature_list]
y = df_concat['label']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [23]:
class LogisticRegression(torch.nn.Module):

    def __init__(self, input_dim, output_dim):
        super(LogisticRegression, self).__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        y = self.linear(x)
        y = torch.sigmoid(y)
        y = self.dequant(y)
        return y

In [24]:
def evaluate(model, data, criterion):
    loss = 0.0
    with torch.no_grad():
        for (x, y_target) in data:
            y = model(x)
            loss += criterion(y, y_target)
    return loss

In [26]:
X_train = torch.tensor(X_train.values, dtype=torch.float)
y_train = torch.tensor(y_train.values, dtype=torch.float)
X_test = torch.tensor(X_test.values, dtype=torch.float)
y_test = torch.tensor(y_test.values, dtype=torch.float)

In [27]:
del csv_files, df_concat, X, feature_list, joined_files, joined_list, mask, unique_vals, y, df_labels

In [28]:
_DIM_INPUT = 8  # one input per selected feature
_DIM_OUTPUT = 1  # it's a binary classifier
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


In [29]:
model = LogisticRegression(_DIM_INPUT, _DIM_OUTPUT)

In [30]:
# Insert min-max observers in the model

model.qconfig = torch.ao.quantization.default_qconfig
model.train()
model_quantized = torch.ao.quantization.prepare_qat(model) # Insert observers
print(model_quantized)

LogisticRegression(
  (linear): Linear(
    in_features=8, out_features=1, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (dequant): DeQuantStub()
)


In [31]:
def train(x_data, y_data,model):

    criterion = torch.nn.BCELoss(reduction="sum")

    optimizer = torch.optim.Adagrad(model.parameters())

    for epoch in tqdm(range(1000)):
        model.train()
        optimizer.zero_grad()
        # Forward pass
        y_pred = torch.reshape(model(x_data),(-1,))
        # Compute Loss
        loss = criterion(y_pred, y_data)
        # Backward pass
        loss.backward()
        optimizer.step()
        if epoch % 10 == 0:
            print('epoch {}, loss {}'.format(epoch, loss.item() / len(x_data)))
    return
train(X_train, y_train,model_quantized)

epoch 0, loss 70.5285053638779
epoch 10, loss 37.20277808907697
epoch 20, loss 33.85971701415921
epoch 30, loss 33.6054919127776
epoch 40, loss 33.556832708696355
epoch 50, loss 33.48142006646706
epoch 60, loss 33.4035082151581
epoch 70, loss 33.35253625092852
epoch 80, loss 33.31947528510322
epoch 90, loss 32.765821133986364
epoch 100, loss 32.70620904693852
epoch 110, loss 32.657668852813444
epoch 120, loss 32.60026440045264
epoch 130, loss 32.540626527914284
epoch 140, loss 32.47997707074845
epoch 150, loss 32.38985876493451
epoch 160, loss 32.24763591680807
epoch 170, loss 32.06409479539709
epoch 180, loss 31.81753393519313
epoch 190, loss 31.64137143106497
epoch 200, loss 31.443882342790328
epoch 210, loss 31.22230167238741
epoch 220, loss 30.94448284719419
epoch 230, loss 30.712597898085825
epoch 240, loss 30.250111323896505
epoch 250, loss 29.251998623451506
epoch 260, loss 25.821241095327967
epoch 270, loss 14.777189560049944
epoch 280, loss 14.705113163592078
epoch 290, loss 14.668703059250099
epoch 300, loss 14.623171833516972
epoch 310, loss 14.52927396984486
epoch 320, loss 14.43533048568955
epoch 330, loss 14.320708029899269
epoch 340, loss 14.171011356525067
epoch 350, loss 13.862751768041662
epoch 360, loss 13.224065994987697
epoch 370, loss 14.562856595978058
epoch 380, loss 14.514981865857928
epoch 390, loss 14.46095828800211
epoch 400, loss 14.402442084300702
epoch 410, loss 14.333937970027343
epoch 420, loss 14.25744233719686
epoch 430, loss 14.169250009173684
epoch 440, loss 14.04322044908407
epoch 450, loss 13.838942834559301
epoch 460, loss 13.297397946483207
epoch 470, loss 14.298167544199801
epoch 480, loss 14.23501292745649
epoch 490, loss 17.729147224241533
epoch 500, loss 16.351274745393074
epoch 510, loss 15.804305978564324
epoch 520, loss 15.575151316200527
epoch 530, loss 15.121288005085692
epoch 540, loss 15.053774648697486
epoch 550, loss 15.686259011285118
epoch 560, loss 15.293605495879776
epoch 570, loss 15.916554185728922
epoch 580, loss 15.810994338101334
epoch 590, loss 15.672515344846225
epoch 600, loss 15.429676521021621
epoch 610, loss 14.945360545620979
epoch 620, loss 13.689115250233309
epoch 630, loss 13.252567887741876
epoch 640, loss 13.974018143067818
epoch 650, loss 13.900658422582536
epoch 660, loss 13.648122270828974
epoch 670, loss 13.233830761891822
epoch 680, loss 13.985917155186007
epoch 690, loss 13.87649840972946
epoch 700, loss 13.69353547835556
epoch 710, loss 13.34591533036668
epoch 720, loss 13.891288372033056
epoch 730, loss 13.72844010675193
epoch 740, loss 13.441529912656609
epoch 750, loss 13.907771250963236
epoch 760, loss 13.765982789176839
epoch 770, loss 13.538207650555032
epoch 780, loss 13.126168405038484
epoch 790, loss 13.685376354110158
epoch 800, loss 13.388714285855965
epoch 810, loss 15.473433506658111
epoch 820, loss 15.158871349245626
epoch 830, loss 15.469661882797011
epoch 840, loss 15.164658208364616
epoch 850, loss 15.309044062444524
epoch 860, loss 15.586776605419516
epoch 870, loss 15.398397729290036
epoch 880, loss 15.118782845508218
epoch 890, loss 15.398587153470281
epoch 900, loss 15.102529060743674
epoch 910, loss 14.589971030993167
epoch 920, loss 13.288312528078912
epoch 930, loss 13.711385980032112
epoch 940, loss 13.489883657850344
epoch 950, loss 13.132460064721581
epoch 960, loss 13.604935541232487
epoch 970, loss 13.343218763109691
epoch 980, loss 13.552981744864473
epoch 990, loss 13.282826169099263
100%|██████████| 1000/1000 [02:59<00:00,  5.56it/s]
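
The per-sample losses logged above (roughly 13 to 70) are far larger than ln 2 ≈ 0.69, the loss of a classifier that always outputs 0.5. One plausible explanation, not verified here, is that the unscaled features (some reach 1.2e8) push the sigmoid to exactly 0 or 1, at which point PyTorch's `BCELoss` clamps the log term at -100. The sketch below reproduces that effect in plain Python; `sigmoid` and `bce` are hypothetical stand-ins for the PyTorch operations, and the input value is made up.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def bce(p, y, log_floor=-100.0):
    """Binary cross-entropy with log terms clamped at log_floor, as BCELoss does."""
    log_p = max(math.log(p), log_floor) if p > 0 else log_floor
    log_1mp = max(math.log(1 - p), log_floor) if p < 1 else log_floor
    return -(y * log_p + (1 - y) * log_1mp)


z = 0.3 * 1_000_000        # hypothetical weight times a large unscaled feature
p = sigmoid(z)             # saturates to exactly 1.0 in floating point
print(p, bce(p, 0))        # a confidently wrong prediction costs the full clamp
print(bce(0.5, 1))         # an uninformative prediction costs ln 2
```

If this is what happens during training, scaling the features to a small range before training would be worth trying.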


In [32]:
def print_size_of_model(model):
    torch.save(model.state_dict(), "temp_delme.p")
    print('Size (KB):', os.path.getsize("temp_delme.p")/1e3)
    os.remove('temp_delme.p')

In [33]:
print_size_of_model(model_quantized)
print(f'Check statistics of the various layers')
print(model_quantized)

Size (KB): 4.194
Check statistics of the various layers
LogisticRegression(
  (linear): Linear(
    in_features=8, out_features=1, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.275702565908432, max_val=0.33877867460250854)
    (activation_post_process): MinMaxObserver(min_val=-33557856.0, max_val=17030178.0)
  )
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=0.0, max_val=120000000.0)
  )
  (dequant): DeQuantStub()
)


In [37]:
def acc(model_quantized):
    # Run inference without tracking gradients
    with torch.no_grad():
        y_pred = model_quantized(X_test).reshape(-1)
    # Predict "malicious" when the sigmoid output exceeds 0.5
    predicted = (y_pred > 0.5).float()
    num_correct = int((predicted == y_test).sum())

    print("\nTest on %d samples: %d malicious pkts, predicted correctly %d or %.2f%%\n" % (
        len(y_test), y_test.sum(), num_correct, num_correct * 100.0 / len(y_test)))

# Testing before quantization

In [38]:
acc(model_quantized)


Test on 504160 samples: 85403 malicious pkts, predicted correctly 404926 or 80.32%



# Quantize the model using the statistics collected

In [39]:
# prepare_qat returned a copy, so switch the trained copy (not `model`) to eval mode
model_quantized.eval()
model_quantized = torch.ao.quantization.convert(model_quantized)
print(f'Check statistics of the various layers')
print(model_quantized)

Check statistics of the various layers
LogisticRegression(
  (linear): QuantizedLinear(in_features=8, out_features=1, scale=398330.96875, zero_point=84, qscheme=torch.per_tensor_affine)
  (quant): Quantize(scale=tensor([944881.8750]), zero_point=tensor([0]), dtype=torch.quint8)
  (dequant): DeQuantize()
)
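
The `Quantize` module printed above uses one scale for all eight input features. The sketch below applies the usual affine formula q = round(x / scale) + zero_point with the printed scale and zero point, to show what that implies for features of very different magnitude. The clamp to 0..255 is an assumption based on the `quint8` dtype, and `quantize_input` is a hypothetical helper rather than a PyTorch function.

```python
INPUT_SCALE = 944881.875   # scale printed by the Quantize module above
INPUT_ZERO_POINT = 0       # zero point printed by the Quantize module above


def quantize_input(x, scale=INPUT_SCALE, zero_point=INPUT_ZERO_POINT):
    """Affine quantization of one input value to the quint8 range."""
    q = round(x / scale) + zero_point
    return max(0, min(255, q))


print(quantize_input(443.0))           # a typical destination port
print(quantize_input(65535.0))         # the largest possible port
print(quantize_input(120_000_000.0))   # the largest value seen by the observer
```

With this scale, every port number quantizes to 0, so small-valued features carry no information after quantization. Scaling each feature to a comparable range before training is one possible way to address this.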


# Print weights and size of the model after quantization

In [40]:
print('Weights after quantization')
print(torch.int_repr(model_quantized.linear.weight()))
print(model_quantized.linear.bias())
print("Size after quantization")
print_size_of_model(model_quantized)

Weights after quantization
tensor([[  0, -80, 106,  -9, -85, -52, 106, -45]], dtype=torch.int8)
Parameter containing:
tensor([0.0278], requires_grad=True)
Size after quantization
Size (KB): 2.854
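
To store the int8 weights printed above in a BPF map, user space first needs them as raw bytes. A minimal sketch using only the standard library follows; the layout (eight consecutive signed bytes) is an assumption about the map's value type, and writing the bytes into the map (for example through bpftool, libbpf or bcc) is outside this sketch.

```python
import struct

# int8 weights printed by the cell above
q_weights = [0, -80, 106, -9, -85, -52, 106, -45]

# Pack as eight signed bytes, matching a hypothetical `__s8 weights[8]` map value
blob = struct.pack("8b", *q_weights)

# Round-trip check: unpacking must give back the same integers
restored = list(struct.unpack("8b", blob))
print(len(blob), restored)
```

The bias is still a float at this point and would have to be converted to an integer separately before the kernel side can use it.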


# Testing after quantization


In [41]:
acc(model_quantized)


Test on 504160 samples: 85403 malicious pkts, predicted correctly 418576 or 83.02%
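
The test set is imbalanced (85403 malicious out of 504160), so accuracy is easier to interpret next to a majority-class baseline. The sketch below computes that baseline from the counts printed above and adds a small confusion-count helper. Both function names are hypothetical and the toy inputs to `confusion_counts` are not real predictions.

```python
def majority_baseline_accuracy(n_total, n_positive):
    """Accuracy obtained by always predicting the more frequent class."""
    n_negative = n_total - n_positive
    return max(n_positive, n_negative) / n_total


def confusion_counts(y_true, y_pred):
    """Return (tp, fp, tn, fn) for binary labels given as 0/1 sequences."""
    tp = fp = tn = fn = 0
    for t, p in zip(y_true, y_pred):
        if p == 1 and t == 1:
            tp += 1
        elif p == 1 and t == 0:
            fp += 1
        elif p == 0 and t == 0:
            tn += 1
        else:
            fn += 1
    return tp, fp, tn, fn


baseline = majority_baseline_accuracy(504160, 85403)
print(f"always-benign baseline: {baseline:.2%}")
print(confusion_counts([1, 0, 1, 0], [1, 1, 0, 0]))
```

The baseline comes out at about 83.06%, which is above the 80.32% measured before quantization and very close to the 83.02% measured after it. Accuracy alone therefore does not show whether any attacks are being detected; the confusion counts (or recall on the malicious class) would.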



## Saving the model weights


In [None]:
# Save the trained, quantized model; `model` itself was never updated by training
torch.save(model_quantized.state_dict(), './src/model_weights.pth')

w = (torch.load('./src/model_weights.pth'))
print(w)

In [None]:
for i in w:
  print("Key:",i)
  print("Value:",(w[i]))