In [1]:
import os
import sys
import json
sys.path.append(os.path.join(os.getcwd().replace("split_models", "")))

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from parse_dataset import NetworkDataset
from model import SplitModel

In [None]:
# Run configuration read by the cells below.
conf = dict(
    batch_size=512,       # mini-batch size used for every DataLoader
    epochs=10,            # number of epochs passed to SplitModel.train
    learning_rate=1e-4,   # initial AdamW learning rate
    model="lstm",         # anything other than "mlp"/"light_mlp" adds a trailing feature axis to X
    train_dpu=False,      # True: train and save the DPU-side model instead of only loading it
    train_host=False,     # True: train and save the host-side model instead of only loading it
)

In [3]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [4]:
# Datasets live in "<project root>/datasets". The project root is derived from
# the working directory by stripping the "split_models" path component.
dataset_path = os.path.join(os.getcwd().replace("split_models", ""), "datasets")

# load label_dict
json_file = os.path.join(dataset_path, "label_index.json")
# Explicit encoding: the platform default is locale dependent.
with open(json_file, 'r', encoding="utf-8") as file:
    label_dict = json.load(file)

# load train, val and test datasets
# NOTE(review): torch.load unpickles its input -- only point it at trusted files.
# Each file is expected to hold an (X, y) pair.
train_dataset_file = os.path.join(dataset_path, "train_dataset.pt")
X_train, y_train = torch.load(train_dataset_file)

val_dataset_file = os.path.join(dataset_path, "val_dataset.pt")
X_val, y_val = torch.load(val_dataset_file)

test_dataset_file = os.path.join(dataset_path, "test_dataset.pt")
X_test, y_test = torch.load(test_dataset_file)

# Every non-MLP model gets a trailing axis of size 1 appended to its inputs
# (the recurrent models below are built with i_size=1).
if conf["model"] not in ("mlp", "light_mlp"):
    X_train, X_val, X_test = (t.unsqueeze(-1) for t in (X_train, X_val, X_test))

print(X_train.shape)
print(y_train.shape)

# create train, val and test datasets
# Only the training loader shuffles; validation and test keep a fixed order so
# evaluation is deterministic and does not consume RNG state.
train_dataset = NetworkDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=conf["batch_size"], shuffle=True)

val_dataset = NetworkDataset(X_val, y_val)
val_loader = DataLoader(val_dataset, batch_size=conf["batch_size"], shuffle=False)

test_dataset = NetworkDataset(X_test, y_test)
test_loader = DataLoader(test_dataset, batch_size=conf["batch_size"])

torch.Size([356334, 513, 1])
torch.Size([356334])


In [5]:
class SplitLSTM(nn.Module):
    """Bidirectional LSTM encoder followed by a small MLP classification head.

    The final hidden state of the last LSTM layer (forward and backward
    directions concatenated) is batch-normalised and then used in two ways:
    returned as a feature tensor, and fed through the classifier head.

    Args:
        i_size: number of input features per time step.
        h_size: LSTM hidden size per direction.
        num_layers: number of stacked LSTM layers.
        num_classes: width of the classifier output (defaults to 24).

    The module is created on the default device and does not read any global
    ``device`` variable; move it with ``.to(...)`` as usual. Parameter names
    (``lstm``, ``bn1``, ``fc``) are unchanged, so existing checkpoints load.
    Because the LSTM weights are now initialised on the default device rather
    than directly on the accelerator, a freshly seeded model may start from
    different random weights than before.
    """

    def __init__(self, i_size, h_size, num_layers, num_classes=24):
        super().__init__()
        self.i_size = i_size
        self.h_size = h_size
        self.num_layers = num_layers
        self.num_classes = num_classes
        # nn.LSTM applies dropout only between stacked layers, so it is
        # enabled only when there is more than one layer.
        self.dropout = 0.15 if self.num_layers > 1 else 0.0

        self.lstm = nn.LSTM(input_size=i_size,
                            hidden_size=h_size,
                            num_layers=self.num_layers,
                            batch_first=True,
                            dropout=self.dropout,
                            bidirectional=True)
        self.bn1 = nn.BatchNorm1d(2 * h_size)
        self.fc = nn.Sequential(
            nn.Linear(2 * self.h_size, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(p=0.15),
            nn.Linear(128, num_classes)
        )

    def forward(self, x, h0=None, c0=None):
        """Encode a batch of sequences and classify them.

        Args:
            x: input of shape [B, T, i_size] (the LSTM is batch_first).
            h0: optional initial hidden state, [2*num_layers, B, h_size].
            c0: optional initial cell state, [2*num_layers, B, h_size].
                Whichever of ``h0``/``c0`` is omitted is replaced by zeros
                created on ``x``'s device and dtype.

        Returns:
            A tuple ``(features, out)``:
            features: [B, 1, 2*h_size] batch-normalised last hidden state.
            out: [B, num_classes] classifier outputs.
        """
        state_shape = (2 * self.num_layers, x.shape[0], self.h_size)
        if h0 is None:
            h0 = x.new_zeros(state_shape)
        if c0 is None:
            c0 = x.new_zeros(state_shape)

        _, (h_n, _) = self.lstm(x, (h0, c0))  # h_n: [num_layers*2, B, h_size]

        # h_n is ordered layer by layer with the two directions adjacent, so
        # the last two entries are the last layer's forward/backward states.
        h_last = torch.cat((h_n[-2], h_n[-1]), dim=1)  # [B, 2*h_size]

        # apply BN + FC
        h_last = self.bn1(h_last)           # [B, 2*h_size] -> batch norm
        features = h_last.unsqueeze(1)      # [B, 1, 2*h_size]
        out = self.fc(h_last)               # [B, num_classes]

        return features, out
        

In [6]:
class FocalLoss(nn.Module):
    """Multi-class focal loss built on top of cross-entropy.

    Per sample the loss is ``alpha * (1 - p_t) ** gamma * CE`` where ``CE`` is
    the cross-entropy of the sample and ``p_t = exp(-CE)``.

    Args:
        gamma: focusing exponent applied to ``(1 - p_t)``.
        alpha: scalar weight multiplied into every sample's loss.
        reduction: ``'mean'``, ``'sum'`` or ``'none'`` (per-sample losses).

    Raises:
        ValueError: if ``reduction`` is not one of the supported values.
            Previously an unknown value silently returned the unreduced loss.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, gamma=2.1, alpha=0.9, reduction='mean'):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss of ``inputs`` against ``targets``.

        ``inputs`` and ``targets`` are passed unchanged to
        ``F.cross_entropy(..., reduction='none')``.
        """
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

In [9]:
# Seed the RNG before the DPU-side model is constructed.
torch.manual_seed(42)

# DPU half: a single-layer bidirectional LSTM taking one feature per time step.
dpu_lstm_model = SplitLSTM(i_size=1, h_size=64, num_layers=1).to(device)
dpu_lstm_criterion = FocalLoss()
dpu_lstm_optimizer = torch.optim.AdamW(
    dpu_lstm_model.parameters(),
    lr=conf["learning_rate"],
    weight_decay=0.01,
)
# Exponential decay of the learning rate with factor 0.9.
dpu_lstm_scheduler = torch.optim.lr_scheduler.ExponentialLR(dpu_lstm_optimizer, gamma=0.9)

# SplitModel bundles model, loss, optimizer and scheduler behind the
# train/evaluate/save/load calls used in the cells below.
dpu_lstm = SplitModel(
    dpu_lstm_model,
    dpu_lstm_criterion,
    dpu_lstm_optimizer,
    dpu_lstm_scheduler,
    device,
)

In [10]:
# Checkpoint location for the DPU-side model: <project root>/checkpoint/.
dpu_path = os.path.join(
    os.getcwd().replace("split_models", ""),
    "checkpoint",
    "dpu_lstm_model.pth",
)

# Training is opt-in through conf["train_dpu"]; when it is off, this cell only
# defines dpu_path and the checkpoint must already exist for the load below.
if conf["train_dpu"]:
    acc, train_loss, val_loss = dpu_lstm.train(train_loader, val_loader, conf["epochs"])
    dpu_lstm.save(dpu_path)

Epoch: 1/10, Accuracy: 85.65%, Train loss: 0.7184, Val loss: 0.2934
Epoch: 2/10, Accuracy: 92.59%, Train loss: 0.2262, Val loss: 0.1494
Epoch: 3/10, Accuracy: 93.83%, Train loss: 0.1385, Val loss: 0.1077
Epoch: 4/10, Accuracy: 94.70%, Train loss: 0.1083, Val loss: 0.0885
Epoch: 5/10, Accuracy: 95.12%, Train loss: 0.0922, Val loss: 0.0778
Epoch: 6/10, Accuracy: 95.47%, Train loss: 0.0827, Val loss: 0.0702
Epoch: 7/10, Accuracy: 95.67%, Train loss: 0.0757, Val loss: 0.0644
Epoch: 8/10, Accuracy: 95.78%, Train loss: 0.0696, Val loss: 0.0605
Epoch: 9/10, Accuracy: 95.96%, Train loss: 0.0659, Val loss: 0.0575
Epoch: 10/10, Accuracy: 95.98%, Train loss: 0.0629, Val loss: 0.0560
Checkpoint saved at /home/jorgetf/testmodel/Network-Packet-ML-Model/checkpoint/dpu_lstm_model.pth


In [11]:
# Load the DPU-side checkpoint stored at dpu_path into dpu_lstm. This runs
# whether or not the training cell above trained anything, so with
# conf["train_dpu"] set to False the checkpoint file must already exist.
dpu_lstm.load(dpu_path)

Checkpoint loaded from /home/jorgetf/testmodel/Network-Packet-ML-Model/checkpoint/dpu_lstm_model.pth!


In [12]:
# Evaluate the DPU-side model on the validation loader. evaluate() is unpacked
# as (loss, accuracy) -- a different order from train()'s
# (acc, train_loss, val_loss) -- and this overwrites the notebook-level
# val_loss / acc names. acc is multiplied by 100 to print a percentage.
val_loss, acc = dpu_lstm.evaluate(val_loader)
print(f"Accuracy: {100*acc:.2f}%, Val loss: {val_loss:.4f}")

Accuracy: 95.98%, Val loss: 0.0560


In [14]:
# Seed the RNG before the host-side model is constructed.
torch.manual_seed(42)

# Host half: a 3-layer bidirectional LSTM. Its input width (2 * 64) matches the
# 2 * h_size feature width that the DPU SplitLSTM (h_size=64) returns as its
# first output.
host_lstm_model = SplitLSTM(i_size=2 * 64, h_size=64, num_layers=3).to(device)
host_lstm_criterion = FocalLoss()
host_lstm_optimizer = torch.optim.AdamW(
    host_lstm_model.parameters(),
    lr=conf["learning_rate"],
    weight_decay=0.01,
)
# Exponential decay of the learning rate with factor 0.9.
host_lstm_scheduler = torch.optim.lr_scheduler.ExponentialLR(host_lstm_optimizer, gamma=0.9)

# The DPU model is handed over via dpu_model; presumably SplitModel runs it
# first and feeds its features to the host model -- confirm in model.py.
host_lstm = SplitModel(
    host_lstm_model,
    host_lstm_criterion,
    host_lstm_optimizer,
    host_lstm_scheduler,
    device,
    dpu_model=dpu_lstm.model,
)

In [15]:
# Checkpoint location for the host-side model: <project root>/checkpoint/.
host_path = os.path.join(
    os.getcwd().replace("split_models", ""),
    "checkpoint",
    "host_lstm_model.pth",
)

# Training is opt-in through conf["train_host"]; when it is off, this cell only
# defines host_path and the checkpoint must already exist for the load below.
if conf["train_host"]:
    acc, train_loss, val_loss = host_lstm.train(train_loader, val_loader, conf["epochs"])
    host_lstm.save(host_path)

Epoch: 1/10, Accuracy: 93.40%, Train loss: 0.4761, Val loss: 0.1320
Epoch: 2/10, Accuracy: 94.83%, Train loss: 0.1225, Val loss: 0.0830
Epoch: 3/10, Accuracy: 95.54%, Train loss: 0.0887, Val loss: 0.0674
Epoch: 4/10, Accuracy: 95.89%, Train loss: 0.0755, Val loss: 0.0599
Epoch: 5/10, Accuracy: 96.04%, Train loss: 0.0681, Val loss: 0.0550
Epoch: 6/10, Accuracy: 96.18%, Train loss: 0.0634, Val loss: 0.0518
Epoch: 7/10, Accuracy: 96.28%, Train loss: 0.0596, Val loss: 0.0495
Epoch: 8/10, Accuracy: 96.36%, Train loss: 0.0567, Val loss: 0.0476
Epoch: 9/10, Accuracy: 96.44%, Train loss: 0.0545, Val loss: 0.0463
Epoch: 10/10, Accuracy: 96.56%, Train loss: 0.0529, Val loss: 0.0448
Checkpoint saved at /home/jorgetf/testmodel/Network-Packet-ML-Model/checkpoint/host_lstm_model.pth


In [16]:
# Load the host-side checkpoint stored at host_path into host_lstm. This runs
# whether or not the training cell above trained anything, so with
# conf["train_host"] set to False the checkpoint file must already exist.
host_lstm.load(host_path)

Checkpoint loaded from /home/jorgetf/testmodel/Network-Packet-ML-Model/checkpoint/host_lstm_model.pth!


In [17]:
# Evaluate the host-side model on the validation loader. evaluate() is unpacked
# as (loss, accuracy) -- a different order from train()'s
# (acc, train_loss, val_loss) -- and this overwrites the val_loss / acc values
# left by the DPU evaluation. acc is multiplied by 100 to print a percentage.
val_loss, acc = host_lstm.evaluate(val_loader)
print(f"Accuracy: {100*acc:.2f}%, Val loss: {val_loss:.4f}")

Accuracy: 96.56%, Val loss: 0.0448
