In [1]:
import os
import sys
import json
sys.path.append(os.path.join(os.getcwd().replace("split_models", "")))

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from parse_dataset import NetworkDataset
from model import SplitModel

In [2]:
conf = {
    "batch_size": 512,
    "epochs": 10,
    "learning_rate": 0.0001,
    "model": "lstm",
    "train_dpu": True,
    "train_host": True
}

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [4]:
dataset_path = os.path.join(os.getcwd().replace("split_models", ""), "datasets")

# load label_dict
json_file = os.path.join(dataset_path, "label_index.json")
with open(json_file, 'r') as file:
    label_dict = json.load(file)

# load train, val and test datasets
train_dataset_file = os.path.join(dataset_path, "train_dataset.pt")
X_train, y_train = torch.load(train_dataset_file)

val_dataset_file = os.path.join(dataset_path, "val_dataset.pt")
X_val, y_val = torch.load(val_dataset_file)

test_dataset_file = os.path.join(dataset_path, "test_dataset.pt")
X_test, y_test = torch.load(test_dataset_file)

if conf["model"] != "mlp" and conf["model"] != "light_mlp":
    X_train, X_val, X_test = X_train.unsqueeze(-1), X_val.unsqueeze(-1), X_test.unsqueeze(-1)
    
print(X_train.shape)
print(y_train.shape)

# create train, val and test datasets
train_dataset = NetworkDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=conf["batch_size"], shuffle=True)

val_dataset = NetworkDataset(X_val, y_val)
val_loader = DataLoader(val_dataset, batch_size=conf["batch_size"], shuffle=True)

test_dataset = NetworkDataset(X_test, y_test)
test_loader = DataLoader(test_dataset, batch_size=conf["batch_size"])

torch.Size([356334, 513, 1])
torch.Size([356334])


In [5]:
class SplitGRU(nn.Module):
    def __init__(self, i_size, h_size, num_layers):
        super(SplitGRU, self).__init__()
        self.i_size = i_size
        self.h_size = h_size
        self.num_layers = num_layers
        self.dropout = 0.0

        if self.num_layers > 1:
            self.dropout = 0.15

        self.gru = nn.GRU(input_size=i_size, 
                          hidden_size=h_size, 
                          num_layers=self.num_layers, 
                          batch_first=True, 
                          dropout=self.dropout, 
                          bidirectional=True, 
                          device=device)
        self.bn1 = nn.BatchNorm1d(2 * h_size)
        self.fc = nn.Sequential(
            nn.Linear(2*self.h_size, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(p=0.15),
            nn.Linear(128, 24)
        )

    def forward(self, x, h0=None, c0=None):
        if h0 is None:
            h0 = torch.zeros(2*self.num_layers, x.shape[0], self.h_size).to(device)
        
        output, h0 = self.gru(x, h0)  # output: [B, T, 2*h_size]

        # take last layer's hidden state (both directions)
        # h0 shape: [num_layers*2, B, size]
        h_last = h0.view(self.num_layers, 2, x.shape[0], self.h_size)[-1]  # [2, B, h_size]
        h_last = torch.cat((h_last[0], h_last[1]), dim=1)  # [B, 2*h_size]

        # apply BN + FC
        h_last = self.bn1(h_last)         # [B, 2*h_size] â†’ batch norm
        logits = h_last.unsqueeze(1)
        out = self.fc(h_last)

        return logits, out
        

In [6]:
# Focal Loss
class FocalLoss(nn.Module):
    def __init__(self, gamma=2.1, alpha=0.9, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

In [7]:
torch.manual_seed(42)
dpu_gru_model = SplitGRU(i_size=1, h_size=64, num_layers=1).to(device=device)
dpu_gru_criterion = FocalLoss()
dpu_gru_optimizer = torch.optim.AdamW(dpu_gru_model.parameters(), conf["learning_rate"], weight_decay=0.01)
dpu_gru_scheduler = torch.optim.lr_scheduler.ExponentialLR(dpu_gru_optimizer, 0.9)
dpu_gru = SplitModel(dpu_gru_model, dpu_gru_criterion, dpu_gru_optimizer, dpu_gru_scheduler, device)

In [8]:
dpu_path = os.path.join(os.getcwd().replace("split_models", ""), "checkpoint", "dpu_gru_model.pth")
if conf["train_dpu"]:
    acc, train_loss, val_loss = dpu_gru.train(train_loader, val_loader, conf["epochs"])
    dpu_gru.save(dpu_path)

Epoch: 1/10, Accuracy: 88.88%, Train loss: 0.7028, Val loss: 0.2626
Epoch: 2/10, Accuracy: 93.36%, Train loss: 0.1905, Val loss: 0.1264
Epoch: 3/10, Accuracy: 94.51%, Train loss: 0.1190, Val loss: 0.0937
Epoch: 4/10, Accuracy: 95.47%, Train loss: 0.0938, Val loss: 0.0765
Epoch: 5/10, Accuracy: 95.73%, Train loss: 0.0815, Val loss: 0.0682
Epoch: 6/10, Accuracy: 95.95%, Train loss: 0.0739, Val loss: 0.0636
Epoch: 7/10, Accuracy: 96.08%, Train loss: 0.0683, Val loss: 0.0582
Epoch: 8/10, Accuracy: 96.13%, Train loss: 0.0639, Val loss: 0.0556
Epoch: 9/10, Accuracy: 96.25%, Train loss: 0.0610, Val loss: 0.0534
Epoch: 10/10, Accuracy: 96.34%, Train loss: 0.0586, Val loss: 0.0516
Checkpoint saved at /home/jorgetf/testmodel/Network-Packet-ML-Model/checkpoint/dpu_gru_model.pth


In [9]:
# load model: 
dpu_gru.load(dpu_path)

Checkpoint loaded from /home/jorgetf/testmodel/Network-Packet-ML-Model/checkpoint/dpu_gru_model.pth!


In [10]:
val_loss, acc = dpu_gru.evaluate(val_loader)
print(f"Accuracy: {100*acc:.2f}%, Val loss: {val_loss:.4f}")

Accuracy: 96.34%, Val loss: 0.0516


In [11]:
torch.manual_seed(42)
host_gru_model = SplitGRU(i_size=2*64, h_size=64, num_layers=3).to(device=device)
host_gru_criterion = FocalLoss()
host_gru_optimizer = torch.optim.AdamW(host_gru_model.parameters(), lr=conf["learning_rate"], weight_decay=0.01)
host_gru_scheduler = torch.optim.lr_scheduler.ExponentialLR(host_gru_optimizer, 0.9)
host_gru = SplitModel(host_gru_model, host_gru_criterion, host_gru_optimizer, host_gru_scheduler, device, dpu_model=dpu_gru.model)

In [12]:
host_path = os.path.join(os.getcwd().replace("split_models", ""), "checkpoint", "host_gru_model.pth")
if conf["train_host"]:
    acc, train_loss, val_loss = host_gru.train(train_loader, val_loader, conf["epochs"])
    host_gru.save(host_path)

Epoch: 1/10, Accuracy: 93.30%, Train loss: 0.4826, Val loss: 0.1354
Epoch: 2/10, Accuracy: 94.85%, Train loss: 0.1263, Val loss: 0.0834
Epoch: 3/10, Accuracy: 95.61%, Train loss: 0.0906, Val loss: 0.0669
Epoch: 4/10, Accuracy: 96.08%, Train loss: 0.0763, Val loss: 0.0587
Epoch: 5/10, Accuracy: 96.35%, Train loss: 0.0684, Val loss: 0.0539
Epoch: 6/10, Accuracy: 96.43%, Train loss: 0.0634, Val loss: 0.0506
Epoch: 7/10, Accuracy: 96.57%, Train loss: 0.0598, Val loss: 0.0485
Epoch: 8/10, Accuracy: 96.59%, Train loss: 0.0569, Val loss: 0.0465
Epoch: 9/10, Accuracy: 96.69%, Train loss: 0.0549, Val loss: 0.0455
Epoch: 10/10, Accuracy: 96.73%, Train loss: 0.0532, Val loss: 0.0442
Checkpoint saved at /home/jorgetf/testmodel/Network-Packet-ML-Model/checkpoint/host_gru_model.pth


In [13]:
# load model:
host_gru.load(host_path)

Checkpoint loaded from /home/jorgetf/testmodel/Network-Packet-ML-Model/checkpoint/host_gru_model.pth!


In [14]:
val_loss, acc = host_gru.evaluate(val_loader)
print(f"Accuracy: {100*acc:.2f}%, Val loss: {val_loss:.4f}")

Accuracy: 96.73%, Val loss: 0.0442
