In [1]:
import numpy as np
import pandas as pd
import subprocess
import time
import sys
import tenseal as ts
import torch
import time
import pickle
import json
import base64

In [2]:
# Client Side
def create_context(N, q, scale, galois = False):
    """Build a CKKS TenSEAL context with the given encryption parameters.

    Parameters
    ----------
    N : int
        Polynomial modulus degree (e.g. 8192).
    q : list of int
        Bit sizes of the coefficient-modulus primes, e.g. ``[60, 40, 60]``.
    scale : int
        Exponent of the global scale; the context's ``global_scale`` is set
        to ``2 ** scale``.
    galois : bool, default False
        If truthy, Galois keys are also generated on the context.
        NOTE(review): presumably required by the encrypted model's rotations
        (e.g. encrypted dot products) — confirm against the server model.

    Returns
    -------
    tenseal.Context
        A fresh context that still holds its secret key.
    """
    # Third positional argument is plain_modulus; presumably unused by CKKS,
    # and -1 is the placeholder TenSEAL's own CKKS examples pass.
    ctx = ts.context(ts.SCHEME_TYPE.CKKS, N, -1, q)
    ctx.global_scale = 2 ** scale
    if galois:
        ctx.generate_galois_keys()
    return ctx

In [3]:
# Client Side
def load_prepare_input(context, csv_path, model, n = None, m = None,
                       n_samples = 100, kernel_shape = (3, 3), stride = 1,
                       scaler_path = r"C:\Users\manig\Downloads\Mitacs\Anomaly-Detection-On-Encrypted-Traffic\Code\scaler.pkl"):
    """Load flow records from a CSV, scale them, and CKKS-encrypt each sample.

    Parameters
    ----------
    context : tenseal.Context
        CKKS context used to encrypt the samples.
    csv_path : str or path-like
        CSV with the raw flow-feature columns listed below plus ``Label``.
    model : str
        ``"CNN"`` im2col-encodes each sample as an ``n`` x ``m`` matrix;
        any other value encrypts each sample as a flat CKKS vector.
    n, m : int, optional
        Matrix height/width for the CNN path; ``n * m`` must equal the
        number of features (30). Ignored for other models.
    n_samples : int or None, default 100
        Encrypt at most the first ``n_samples`` rows (fewer if the file is
        shorter). ``None`` encrypts every row.
    kernel_shape : (int, int), default (3, 3)
        Kernel rows/columns passed to ``ts.im2col_encoding`` (CNN path only).
    stride : int, default 1
        Stride passed to ``ts.im2col_encoding`` (CNN path only).
    scaler_path : str or path-like
        Pickled fitted scaler whose ``transform`` is applied to the features.

    Returns
    -------
    (list, int or None)
        The encrypted samples, and for the CNN path the ``windows_nb`` value
        returned by ``ts.im2col_encoding`` (``None`` otherwise).

    Raises
    ------
    ValueError
        If ``n_samples`` is negative, or if ``model == "CNN"`` and ``n``/``m``
        are missing or do not match the feature count.
    """
    # Column headers as they appear in the CSV, in model-input order; 'Label' last.
    raw_cols = ['Flow Duration', 'Total Fwd Packet', 'Total Bwd packets',
                'Total Length of Fwd Packet', 'Total Length of Bwd Packet',
                'Fwd Packet Length Min', 'Fwd Packet Length Mean',
                'Fwd Packet Length Std', 'Flow Bytes/s', 'Flow Packets/s',
                'Flow IAT Std', 'Flow IAT Min', 'Fwd IAT Total', 'Fwd IAT Min',
                'Bwd IAT Min', 'Fwd PSH Flags', 'Fwd URG Flags', 'Bwd Packets/s',
                'Packet Length Max', 'FIN Flag Count', 'RST Flag Count',
                'PSH Flag Count', 'ACK Flag Count', 'URG Flag Count', 'Down/Up Ratio',
                'FWD Init Win Bytes', 'Bwd Init Win Bytes',
                'Fwd Seg Size Min', 'Active Mean', 'Active Std', 'Label']
    # Same columns under the names presumably used when the scaler was fitted;
    # raw_cols are renamed to these position by position.
    feature_names = ['Flow Duration', 'Total Fwd Packets', 'Total Backward Packets',
                     'Total Length of Fwd Packets', 'Total Length of Bwd Packets',
                     'Fwd Packet Length Min', 'Fwd Packet Length Mean',
                     'Fwd Packet Length Std', 'Flow Bytes/s', 'Flow Packets/s',
                     'Flow IAT Std', 'Flow IAT Min', 'Fwd IAT Total', 'Fwd IAT Min',
                     'Bwd IAT Min', 'Fwd PSH Flags', 'Fwd URG Flags', 'Bwd Packets/s',
                     'Max Packet Length', 'FIN Flag Count', 'RST Flag Count',
                     'PSH Flag Count', 'ACK Flag Count', 'URG Flag Count', 'Down/Up Ratio',
                     'Init_Win_bytes_forward', 'Init_Win_bytes_backward',
                     'min_seg_size_forward', 'Active Mean', 'Active Std', 'Label']

    if n_samples is not None and n_samples < 0:
        raise ValueError(f"n_samples must be non-negative or None, got {n_samples}")

    # low_memory=False infers each column's dtype from the whole file, avoiding
    # the chunk-wise mixed-type DtypeWarning that embedded header rows can cause.
    df = pd.read_csv(csv_path, low_memory=False)
    # Drop rows that are repeated copies of the header line.
    df = df.drop(df[df.Protocol == 'Protocol'].index)
    df = df[raw_cols]
    df.columns = feature_names
    features = df.drop('Label', axis=1)

    # SECURITY: pickle.load can execute arbitrary code; only load a scaler
    # file from a trusted source.
    with open(scaler_path, 'rb') as f:
        scaler = pickle.load(f)
    # Scale every row, then keep the first n_samples (slicing clamps to the
    # available rows instead of raising IndexError on short files).
    x = np.asarray(scaler.transform(features), dtype=np.float64)[:n_samples]

    windows_nb = None
    if model == "CNN":
        n_features = x.shape[1]
        if n is None or m is None or n * m != n_features:
            raise ValueError(
                f"CNN input needs n * m == {n_features} features, got n={n}, m={m}")
        images = torch.from_numpy(x).float().reshape(-1, n, m)
        enc_x = []
        for image in images:
            # Each sample is im2col-encoded on its own. windows_nb presumably
            # depends only on matrix/kernel shape and stride, so the last value
            # applies to every sample.
            enc_image, windows_nb = ts.im2col_encoding(
                context, image.tolist(), kernel_shape[0], kernel_shape[1], stride)
            enc_x.append(enc_image)
    else:
        rows = torch.from_numpy(x).float()
        enc_x = [ts.ckks_vector(context, row.tolist()) for row in rows]

    return enc_x, windows_nb

In [4]:
# Client Side
def serialize_input(context, enc_x):
    """Bundle a public copy of ``context`` plus the encrypted samples as JSON text.

    Parameters
    ----------
    context : tenseal.Context
        The client's context. ``make_context_public`` is called on a copy,
        never on this object itself.
    enc_x : sequence
        Encrypted samples; every element must provide ``serialize()`` -> bytes.

    Returns
    -------
    str
        JSON ``{"context": <base64>, "data": [<base64>, ...]}`` in which each
        serialized blob is base64-encoded so it can be embedded in JSON.
    """
    def to_text(blob):
        # Raw bytes -> ASCII base64 string (JSON cannot carry bytes directly).
        return base64.b64encode(blob).decode()

    shared_ctx = context.copy()
    # In TenSEAL this removes the secret key from the copy that will be shipped.
    shared_ctx.make_context_public()

    return json.dumps({
        "context": to_text(shared_ctx.serialize()),
        "data": [to_text(vec.serialize()) for vec in enc_x],
    })

In [5]:
# CKKS parameters: polynomial modulus degree, coefficient-modulus prime bit
# sizes, and the exponent of the global scale (2 ** scl).
# NOTE(review): presumably the 40-bit middle prime is sized to match scl = 40,
# leaving one rescale level — confirm that is enough depth for the model.
poly_mod_degree, coeff_mod_bit_sizes, scl = 8192, [60, 40, 60], 40
ctx = create_context(N=poly_mod_degree, q=coeff_mod_bit_sizes, scale=scl, galois=True)

In [6]:
# Encrypt the test flows for the (non-CNN) model path.
test_csv_path = r"C:\Users\manig\Downloads\Mitacs\Anomaly-Detection-On-Encrypted-Traffic\main\test_ISCX.csv"
enc_ip, windows_nb = load_prepare_input(ctx, csv_path=test_csv_path, model="ANN")

  df = pd.read_csv(csv_path)


In [7]:
# JSON payload (public context + ciphertexts) that the client sends to the server.
send = serialize_input(context=ctx, enc_x=enc_ip)

In [8]:
def preapre_client_data(client_json):
    """Server side: decode the client's JSON payload into a context and ciphertexts.

    Reverses ``serialize_input``: the base64 ``"context"`` entry is turned back
    into a TenSEAL context, and every base64 entry of ``"data"`` into a CKKS
    vector bound to that context.

    Parameters
    ----------
    client_json : str or bytes
        JSON text with keys ``"context"`` (base64 string) and ``"data"``
        (list of base64 strings).

    Returns
    -------
    (tenseal.Context, list)
        The deserialized context and the CKKS vectors, in payload order.
    """
    payload = json.loads(client_json)
    context = ts.context_from(base64.b64decode(payload["context"]))
    encrypted_input = [
        ts.ckks_vector_from(context, base64.b64decode(blob))
        for blob in payload["data"]
    ]
    return context, encrypted_input


# Correctly spelled alias; the misspelled name is kept so existing callers work.
prepare_client_data = preapre_client_data

In [10]:
# Server side: rebuild the secret-key-free context and the ciphertexts from the payload.
context_no_sk, encrypted_input = preapre_client_data(client_json=send)

In [21]:
import models
from pathlib import Path

# Server side: locate the saved plaintext model weights.
MODEL_PATH = Path("models")
MODEL_PATH.mkdir(parents=True, exist_ok=True)
MODEL_NAME = "lr.pth"
MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME

# NOTE(review): presumably 30 is the input feature count, matching the 30
# non-'Label' columns selected in load_prepare_input — confirm with models.LR.
model = models.LR(30)
model.load_state_dict(torch.load(f=MODEL_SAVE_PATH))

# Wrap the plaintext model and run it on the client's ciphertexts.
eelr = models.EncryptedLR(model)
server_output = models.encrypted_evaluation(eelr, encrypted_input)

# Ciphertext bytes -> base64 text so each result can travel inside JSON.
encrypted_output = []
for i, result in enumerate(server_output):
    serialized = base64.b64encode(result.serialize()).decode()
    encrypted_output.append(serialized)
server_response = json.dumps({"data": encrypted_output})

In [22]:
# Client side: parse the server's JSON reply ({"data": [<base64>, ...]}).
data = json.loads(s=server_response)

In [30]:
# Client side: decode each base64 result, deserialize it against `ctx` (the
# client's original context — only a copy was made public for the server) and
# decrypt it.
# NOTE(review): despite its name, `encrypted_output` now holds *decrypted*
# plaintext values, not ciphertexts.
encrypted_output = []
for enc in data['data']:
    enc = base64.b64decode(enc)
    encrypted_output.append(ts.ckks_vector_from(ctx, enc).decrypt())
# The sigmoid is applied here on plaintext; presumably the server returns raw
# (pre-sigmoid) scores.
op = torch.sigmoid(torch.tensor(encrypted_output))

In [31]:
# Sigmoid scores of the decrypted results, one row per encrypted sample.
op

tensor([[0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0020],
        [0.0020],
        [0.0020],
        [0.0018],
        [0.0018],
        [0.0000],
        [0.0018],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0018],
        [0.0000],
        [0.0018],
        [0.0020],
        [0.0000],
        [0.0000],
        [0.0018],
        [0.0020],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0020],
        [0.0018],
        [0.0020],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0020],
        [0.0020],
        [0.0000],
        [0.0000],
        [0.0018],
        [0.0000],
        [0.0000],
        [0.0018],
        [0.0018],
        [0.0018],
        [0.0000],
        [0.0020],
        [0.0018],
        [0.0000],
        [0.0000],
        [0.0018],
        [0.0020],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0020],
        [0