In [1]:
import os
import sys
sys.path.append(os.path.join(os.getcwd().replace("model_inference", "")))

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.profiler import profiler
import matplotlib.pyplot as plt
import time
from parse_dataset import NetworkDataset, parse_dataset, split_datasets
from split_model import SplitModelDPU, SplitModelHost
from load_models import models

In [2]:
conf = {
    "batch_size": 512,
    "epochs": 10,
    "learning_rate": 0.0001
}

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [4]:
# extract dataset from csv file
network_data_file = os.path.join(os.getcwd().replace("model_inference", ""), "datasets", "network_packet_data_test.csv")

data, labels, label_dict = parse_dataset(network_data_file)

X_train, y_train, X_val, y_val, X_test, y_test = split_datasets(data, labels)

X_train, X_val, X_test = X_train.unsqueeze(-1), X_val.unsqueeze(-1), X_test.unsqueeze(-1)
print(X_train.shape)
print(y_train.shape)

# create train, val and test datasets
train_dataset = NetworkDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=conf["batch_size"], shuffle=True)

val_dataset = NetworkDataset(X_val, y_val)
val_loader = DataLoader(val_dataset, batch_size=conf["batch_size"], shuffle=True)

test_dataset = NetworkDataset(X_test, y_test)
test_loader = DataLoader(test_dataset, batch_size=conf["batch_size"])

packet_dat,attack_cat

torch.Size([356334, 513, 1])
torch.Size([356334])


In [5]:
dpu_path = os.path.join(os.getcwd().replace("model_inference", ""), "checkpoint", "dpu_split_model.pth")
dpu_model = models["dpu"]
dpu_model.load(dpu_path)

Checkpoint loaded from /home/jorgetf/testmodel/Network-Packet-ML-Model/checkpoint/dpu_split_model.pth!


In [6]:
host_path = os.path.join(os.getcwd().replace("model_inference", ""), "checkpoint", "host_split_model.pth")
host_model = models["host"]
host_model.load(host_path)

Checkpoint loaded from /home/jorgetf/testmodel/Network-Packet-ML-Model/checkpoint/host_split_model.pth!


In [7]:
# test inference:
data, labels = next(iter(test_loader))
if not data.is_cuda or not labels.is_cuda:
    data, labels = data.to(device), labels.to(device)

In [8]:
dpu_model.model.eval()
pred, logits, targets = dpu_model.inference(data, labels)

In [9]:
print(pred.shape, logits.shape, targets.shape)
print(targets[:10])
print(pred[:10])

torch.Size([512]) torch.Size([430, 1, 128]) torch.Size([430])
tensor([ 0,  0, 14,  0, 12, 10, 15, 10, 19, 14], device='cuda:0')
tensor([1., 1., 1., 1., 1., 1., 0., 1., 1., 0.], device='cuda:0')


In [10]:
host_model.model.eval()
pred = host_model.inference(logits)

In [11]:
print(pred.shape)
print(pred[:10])

torch.Size([430, 24])
tensor([[ 4.6968e+00,  3.1629e+00, -3.6513e+00, -3.6988e+00, -4.2607e+00,
         -4.6674e-01, -2.0412e+00, -4.6080e+00, -3.6925e+00, -3.7131e+00,
         -3.3246e+00, -1.8385e+00, -2.9786e+00, -3.6216e+00, -3.7784e+00,
          4.5936e+00,  2.4731e-02, -3.2962e+00, -1.7925e+00, -3.3871e+00,
         -3.5597e+00, -2.5878e+00, -6.3501e-01, -2.8789e+00],
        [ 6.0477e+00,  3.8143e+00, -3.1752e+00, -4.1990e+00, -5.5945e+00,
         -1.4785e+00, -5.4158e+00, -2.9913e+00, -1.7090e+00, -4.6664e+00,
         -6.5852e+00, -1.7637e+00, -2.1553e+00, -4.0471e+00, -4.9171e+00,
         -1.8685e+00,  2.1936e-01, -4.0545e+00, -2.5911e+00, -2.5408e+00,
         -5.2973e+00, -3.9422e+00, -6.8539e-01, -5.3514e+00],
        [-3.7172e+00, -6.4881e+00,  4.3727e+00,  1.5676e+00, -2.2972e+00,
         -1.5643e+00, -4.5361e+00, -5.7751e+00,  3.6252e+00, -4.1444e+00,
         -3.7361e+00, -2.8719e+00,  2.9702e+00,  2.7930e-01,  4.9404e+00,
         -5.6994e+00, -3.0737e+00, -2.73

In [12]:
dpu_acc, dpu_loss, logits, logits_target = dpu_model.evaluate(test_loader, label_dict)

In [13]:
print(logits[0].shape)
print(logits_target[0].shape)

print(logits_target[:10])

torch.Size([430, 1, 128])
torch.Size([430])
[tensor([ 0,  0, 14,  0, 12, 10, 15, 10, 19, 14, 10,  4,  8,  1,  1, 19, 12,  4,
         0,  3,  0,  3, 11,  0,  1,  1,  0, 12, 14,  4,  0,  4, 12, 14,  0,  4,
         0,  0, 14,  0,  1,  0,  0,  0,  7,  0, 15, 10, 10,  0,  0, 14,  0,  0,
         1,  0,  4, 17,  4,  0,  1,  0, 14,  0,  0,  2,  0,  0,  3,  0,  2,  2,
         0,  0,  2, 10,  1, 14,  0,  0,  0,  0, 16,  0,  4,  1,  1,  4, 10,  1,
        11, 12,  1,  0,  8,  0,  4,  1,  0,  0,  1,  8,  0, 10,  3,  0, 15, 15,
         0,  0,  0,  0,  0,  0,  0,  8,  0,  0,  8,  0,  0,  4,  1,  0,  0,  0,
         0, 12,  8,  4,  7,  0,  0,  1,  0,  0,  0, 10,  0,  0,  2,  0,  4,  3,
         8, 12,  0,  0,  0,  0,  1,  7,  0,  2, 14,  0,  0,  0,  1,  0,  1,  0,
         5,  4,  0,  7,  8,  7,  0,  0,  0, 13,  0,  4,  0,  0,  0,  0,  0,  0,
         0, 10,  0,  0,  0,  0,  0,  4, 10,  1,  0,  0,  0, 12,  0,  8,  0,  2,
         0,  7,  7,  1,  4,  0,  0,  4,  0, 15, 10,  4, 12,  1, 10, 15,  0,

In [14]:
host_acc, host_loss = host_model.evaluate(logits, logits_target)

In [15]:
print(f"DPU model accuracy: {100*dpu_acc:.2f}%, Host model accuracy: {100*host_acc:.2f}%")

DPU model accuracy: 99.67%, Host model accuracy: 91.62%
