In [1]:
import os
import sys
import json
sys.path.append(os.path.join(os.getcwd().replace("split_models", "")))

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from parse_dataset import NetworkDataset
from model import SplitModel

In [None]:
# Experiment configuration: batching, optimisation and which halves of the split
# model to (re)train. With both train_* flags False the notebook only loads the
# saved checkpoints and evaluates them.
conf = dict(
    batch_size=512,
    epochs=20,
    learning_rate=0.0001,
    model="mlp",        # "mlp"/"light_mlp" keep 2-D input; other models get a trailing axis
    train_dpu=False,    # True -> train the DPU half and overwrite its checkpoint
    train_host=False,   # True -> train the host half and overwrite its checkpoint
)

In [3]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [4]:
# Datasets live in <project root>/datasets, where the project root is the cwd with
# "split_models" stripped out.
# NOTE(review): str.replace removes that substring anywhere in the path, not only
# as the final directory component.
dataset_path = os.path.join(os.getcwd().replace("split_models", ""), "datasets")

# Load the label mapping (presumably class name -> index; verify against the
# dataset-building script). JSON is UTF-8 by spec, so don't rely on the locale.
json_file = os.path.join(dataset_path, "label_index.json")
with open(json_file, "r", encoding="utf-8") as file:
    label_dict = json.load(file)


def _load_split(path):
    """Load an ``(X, y)`` pair written with ``torch.save`` and sanity-check it.

    ``weights_only=True`` restricts unpickling to tensors and plain containers,
    so a tampered ``.pt`` file cannot execute arbitrary code.

    Raises:
        ValueError: if the number of samples and labels differ.
    """
    features, labels = torch.load(path, weights_only=True)
    if features.shape[0] != labels.shape[0]:
        raise ValueError(
            f"{path}: {features.shape[0]} samples but {labels.shape[0]} labels"
        )
    return features, labels


# load train, val and test datasets
train_dataset_file = os.path.join(dataset_path, "train_dataset.pt")
X_train, y_train = _load_split(train_dataset_file)

val_dataset_file = os.path.join(dataset_path, "val_dataset.pt")
X_val, y_val = _load_split(val_dataset_file)

test_dataset_file = os.path.join(dataset_path, "test_dataset.pt")
X_test, y_test = _load_split(test_dataset_file)

# The MLP variants take the features as loaded; every other model gets a trailing
# singleton axis appended to its inputs.
if conf["model"] not in ("mlp", "light_mlp"):
    X_train, X_val, X_test = X_train.unsqueeze(-1), X_val.unsqueeze(-1), X_test.unsqueeze(-1)

print(X_train.shape)
print(y_train.shape)

# create train, val and test datasets
train_dataset = NetworkDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=conf["batch_size"], shuffle=True)

val_dataset = NetworkDataset(X_val, y_val)
# Evaluation needs no shuffling. A fixed order keeps validation batches identical
# across epochs and avoids drawing a sampler seed from the global RNG on every pass.
val_loader = DataLoader(val_dataset, batch_size=conf["batch_size"], shuffle=False)

test_dataset = NetworkDataset(X_test, y_test)
test_loader = DataLoader(test_dataset, batch_size=conf["batch_size"])

torch.Size([356334, 513])
torch.Size([356334])


In [5]:
class SplitMLP_DPU(nn.Module):
    """First (DPU-side) half of the split MLP.

    A single ``Linear -> BatchNorm1d -> ReLU -> Dropout`` stage produces a hidden
    representation. A linear head on top of it gives class scores, so this half
    can be trained and evaluated on its own. The hidden representation is
    presumably what ``SplitModel`` hands to the host half (``SplitMLP_Host``
    takes 320 inputs by default); confirm in ``model.py``.

    Args:
        size: number of input features per sample (513 in this notebook).
        hidden_size: width of the hidden representation (default 320).
        num_classes: number of outputs of the local classification head (default 24).
        dropout: dropout probability applied after the ReLU (default 0.10).
    """

    def __init__(self, size, hidden_size=320, num_classes=24, dropout=0.10):
        super().__init__()
        self.input_size = size
        # Attribute names and layer order define the state_dict keys
        # (mlp.0.*, mlp.1.*, out.*). Keep them stable so existing checkpoints load.
        self.mlp = nn.Sequential(
            nn.Linear(size, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(),
            nn.Dropout(p=dropout),
        )

        self.out = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return ``(hidden, logits)``.

        ``hidden`` is the post-ReLU/Dropout activation, with last dimension
        ``hidden_size``. ``logits`` are the head's class scores, with last
        dimension ``num_classes``.
        """
        hidden = self.mlp(x)
        logits = self.out(hidden)
        return hidden, logits
        

In [6]:
class SplitMLP_Host(nn.Module):
    """Second (host-side) half of the split MLP.

    It is a stack of ``Linear -> BatchNorm1d -> ReLU -> Dropout`` stages followed
    by a final ``Linear`` classifier. The default ``size=320`` matches
    ``SplitMLP_DPU``'s default hidden width.

    Args:
        size: input width (default 320).
        num_classes: number of output classes (default 24).
        hidden_sizes: widths of the hidden stages (default ``(128, 128, 64)``).
        dropout: dropout probability after each hidden ReLU (default 0.10).
    """

    def __init__(self, size=320, num_classes=24, hidden_sizes=(128, 128, 64), dropout=0.10):
        super().__init__()
        self.input_size = size

        layers = []
        in_features = size
        for width in hidden_sizes:
            layers += [
                nn.Linear(in_features, width),
                nn.BatchNorm1d(width),
                nn.ReLU(),
                nn.Dropout(p=dropout),
            ]
            in_features = width
        layers.append(nn.Linear(in_features, num_classes))

        # A single flat Sequential reproduces the original layer indices
        # (mlp.0 ... mlp.12 with the defaults), so existing checkpoints still load.
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``(None, logits)``.

        The leading ``None`` mirrors the ``(hidden, logits)`` pair returned by
        ``SplitMLP_DPU.forward``, presumably so ``SplitModel`` can unpack both
        halves the same way.
        """
        return None, self.mlp(x)

In [7]:
# Focal Loss
class FocalLoss(nn.Module):
    """Multi-class focal loss built on ``F.cross_entropy``.

    Per sample the loss is ``alpha * (1 - p_t) ** gamma * CE``. Here ``CE`` is the
    unreduced cross-entropy and ``p_t = exp(-CE)`` is the softmax probability of
    the target class. Confident correct predictions (``p_t`` near 1) are
    down-weighted.

    Args:
        gamma: focusing exponent; ``gamma=0`` gives cross-entropy scaled by ``alpha``.
        alpha: scalar multiplier applied uniformly to every sample. It is not a
            per-class weight here, so it only rescales the loss and its gradients.
        reduction: ``'mean'``, ``'sum'`` or ``'none'`` (per-sample losses).

    Raises:
        ValueError: if ``reduction`` is not one of the supported values.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, gamma=2.1, alpha=0.9, reduction='mean'):
        super().__init__()
        # Fail fast: an unknown reduction would otherwise return an unreduced tensor.
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Compute the focal loss.

        Args:
            inputs: unnormalised class scores, in any form ``F.cross_entropy`` accepts.
            targets: target class indices, in any form ``F.cross_entropy`` accepts.
        """
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)  # probability assigned to the target class
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        if self.reduction == 'none':
            return focal_loss
        # Only reachable if ``reduction`` was reassigned after construction.
        raise ValueError(f"unsupported reduction {self.reduction!r}")

In [8]:
# Seed before building the model so its initial weights are reproducible.
torch.manual_seed(42)
dpu_mlp_model = SplitMLP_DPU(513).to(device)
dpu_mlp_criterion = FocalLoss()
dpu_mlp_optimizer = torch.optim.AdamW(
    dpu_mlp_model.parameters(),
    lr=conf["learning_rate"],
    weight_decay=0.01,
)
# Each scheduler step multiplies the learning rate by 0.9.
dpu_mlp_scheduler = torch.optim.lr_scheduler.ExponentialLR(dpu_mlp_optimizer, gamma=0.9)
dpu_mlp = SplitModel(
    dpu_mlp_model,
    dpu_mlp_criterion,
    dpu_mlp_optimizer,
    dpu_mlp_scheduler,
    device,
)

In [9]:
# Checkpoint for the DPU half, stored under <project root>/checkpoint.
dpu_path = os.path.join(os.getcwd().replace("split_models", ""), "checkpoint", "dpu_mlp_model.pth")
if conf["train_dpu"]:
    # Create the checkpoint folder before the long training run, so save() cannot
    # fail on a missing directory afterwards (no-op if it already exists).
    os.makedirs(os.path.dirname(dpu_path), exist_ok=True)
    # Presumably per-epoch histories returned by SplitModel.train; confirm in model.py.
    acc, train_loss, val_loss = dpu_mlp.train(train_loader, val_loader, conf["epochs"])
    dpu_mlp.save(dpu_path)

Epoch: 1/20, Accuracy: 50.04%, Train loss: 1.4713, Val loss: 1.5798
Epoch: 2/20, Accuracy: 64.03%, Train loss: 0.8306, Val loss: 1.1392
Epoch: 3/20, Accuracy: 63.58%, Train loss: 0.6459, Val loss: 1.0531
Epoch: 4/20, Accuracy: 73.74%, Train loss: 0.5491, Val loss: 0.6675
Epoch: 5/20, Accuracy: 75.00%, Train loss: 0.4898, Val loss: 0.6006
Epoch: 6/20, Accuracy: 71.33%, Train loss: 0.4288, Val loss: 0.6749
Epoch: 7/20, Accuracy: 77.18%, Train loss: 0.4027, Val loss: 0.4766
Epoch: 8/20, Accuracy: 78.31%, Train loss: 0.3629, Val loss: 0.4563
Epoch: 9/20, Accuracy: 80.48%, Train loss: 0.3425, Val loss: 0.3869
Epoch: 10/20, Accuracy: 77.22%, Train loss: 0.3171, Val loss: 0.4721
Epoch: 11/20, Accuracy: 79.32%, Train loss: 0.3006, Val loss: 0.4163
Epoch: 12/20, Accuracy: 79.61%, Train loss: 0.2815, Val loss: 0.3876
Epoch: 13/20, Accuracy: 81.35%, Train loss: 0.2693, Val loss: 0.3482
Epoch: 14/20, Accuracy: 82.83%, Train loss: 0.2496, Val loss: 0.2987
Epoch: 15/20, Accuracy: 85.01%, Train loss:

In [10]:
# Restore the DPU weights (either just trained above or saved by an earlier run).
# Fail fast with an actionable message instead of an error from deep inside load().
if not os.path.isfile(dpu_path):
    raise FileNotFoundError(
        f"No DPU checkpoint at {dpu_path}; set conf['train_dpu'] = True to train and save one."
    )
dpu_mlp.load(dpu_path)

Checkpoint loaded from /home/jorgetf/testmodel/Network-Packet-ML-Model/checkpoint/dpu_mlp_model.pth!


In [16]:
# Score the DPU half on the validation split.
# evaluate() returns (loss, accuracy as a fraction); accuracy is printed as a percentage.
val_loss, acc = dpu_mlp.evaluate(val_loader)
print(
    f"Accuracy: {100*acc:.2f}%, Val loss: {val_loss:.4f}"
)

Accuracy: 87.90%, Val loss: 0.1871


In [12]:
# Re-seed so the host model's initial weights are reproducible independently of
# everything that ran before this cell.
torch.manual_seed(42)
# 320 inputs matches the hidden width of SplitMLP_DPU.
host_mlp_model = SplitMLP_Host(size=320).to(device)
host_mlp_criterion = FocalLoss()
host_mlp_optimizer = torch.optim.AdamW(
    host_mlp_model.parameters(),
    lr=conf["learning_rate"],
    weight_decay=0.01,
)
# Each scheduler step multiplies the learning rate by 0.9.
host_mlp_scheduler = torch.optim.lr_scheduler.ExponentialLR(host_mlp_optimizer, gamma=0.9)
# The host half is wrapped together with the (restored) DPU network. Presumably
# SplitModel feeds the DPU's hidden features into the host model; confirm in model.py.
host_mlp = SplitModel(
    host_mlp_model,
    host_mlp_criterion,
    host_mlp_optimizer,
    host_mlp_scheduler,
    device,
    dpu_model=dpu_mlp.model,
)

In [13]:
# Checkpoint for the host half, stored under <project root>/checkpoint.
host_path = os.path.join(os.getcwd().replace("split_models", ""), "checkpoint", "host_mlp_model.pth")
if conf["train_host"]:
    # Create the checkpoint folder before the long training run, so save() cannot
    # fail on a missing directory afterwards (no-op if it already exists).
    os.makedirs(os.path.dirname(host_path), exist_ok=True)
    # Presumably per-epoch histories returned by SplitModel.train; confirm in model.py.
    acc, train_loss, val_loss = host_mlp.train(train_loader, val_loader, conf["epochs"])
    host_mlp.save(host_path)

Epoch: 1/20, Accuracy: 87.04%, Train loss: 0.7930, Val loss: 0.3117
Epoch: 2/20, Accuracy: 89.83%, Train loss: 0.2701, Val loss: 0.1857
Epoch: 3/20, Accuracy: 91.28%, Train loss: 0.1924, Val loss: 0.1362
Epoch: 4/20, Accuracy: 89.99%, Train loss: 0.1576, Val loss: 0.1514
Epoch: 5/20, Accuracy: 92.18%, Train loss: 0.1402, Val loss: 0.1080
Epoch: 6/20, Accuracy: 92.27%, Train loss: 0.1274, Val loss: 0.1090
Epoch: 7/20, Accuracy: 93.11%, Train loss: 0.1190, Val loss: 0.0926
Epoch: 8/20, Accuracy: 93.28%, Train loss: 0.1122, Val loss: 0.0878
Epoch: 9/20, Accuracy: 92.83%, Train loss: 0.1079, Val loss: 0.0939
Epoch: 10/20, Accuracy: 93.49%, Train loss: 0.1044, Val loss: 0.0820
Epoch: 11/20, Accuracy: 93.58%, Train loss: 0.1006, Val loss: 0.0795
Epoch: 12/20, Accuracy: 93.45%, Train loss: 0.0980, Val loss: 0.0821
Epoch: 13/20, Accuracy: 93.76%, Train loss: 0.0961, Val loss: 0.0782
Epoch: 14/20, Accuracy: 94.16%, Train loss: 0.0945, Val loss: 0.0755
Epoch: 15/20, Accuracy: 94.24%, Train loss:

In [14]:
# Restore the host weights (either just trained above or saved by an earlier run).
# Fail fast with an actionable message instead of an error from deep inside load().
if not os.path.isfile(host_path):
    raise FileNotFoundError(
        f"No host checkpoint at {host_path}; set conf['train_host'] = True to train and save one."
    )
host_mlp.load(host_path)

Checkpoint loaded from /home/jorgetf/testmodel/Network-Packet-ML-Model/checkpoint/host_mlp_model.pth!


In [15]:
# Score the full split pipeline (DPU + host) on the validation split.
# evaluate() returns (loss, accuracy as a fraction); accuracy is printed as a percentage.
val_loss, acc = host_mlp.evaluate(val_loader)
print(
    f"Accuracy: {100*acc:.2f}%, Val loss: {val_loss:.4f}"
)

Accuracy: 94.31%, Val loss: 0.0708
