In [1]:
import os
import sys
sys.path.append(os.path.join(os.getcwd().replace("large_models", "")))

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.profiler import profiler
import matplotlib.pyplot as plt
import time
from parse_dataset import NetworkDataset, parse_dataset, split_datasets
from model import Model
from load_models import models

In [2]:
# Hyperparameters read by the cells below through conf["..."].
# NOTE(review): 516 is an unusual batch size - confirm that 512 was not intended.
# NOTE(review): "load_model" is not read by any of the visible cells.
conf = dict(
    batch_size=516,
    epochs=10,
    learning_rate=1e-4,
    load_model=True,
)

In [3]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
# extract dataset from csv file
# The path is built from the current working directory with the substring
# "large_models" removed, followed by datasets/network_packet_data_test.csv.
network_data_file = os.path.join(os.getcwd().replace("large_models", ""), "datasets", "network_packet_data_test.csv")

data, labels, label_dict = parse_dataset(network_data_file)

X_train, y_train, X_val, y_val, X_test, y_test = split_datasets(data, labels)

# append a trailing axis of size 1 to every feature tensor
X_train, X_val, X_test = X_train.unsqueeze(-1), X_val.unsqueeze(-1), X_test.unsqueeze(-1)
print(X_train.shape)
print(y_train.shape)

# create train, val and test datasets
# only the training loader is shuffled
train_dataset = NetworkDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=conf["batch_size"], shuffle=True)

# Validation metrics are averaged over the whole split, so sample order does not
# matter; keeping it fixed makes evaluation deterministic and stops it from
# drawing from the global RNG between training epochs.
val_dataset = NetworkDataset(X_val, y_val)
val_loader = DataLoader(val_dataset, batch_size=conf["batch_size"], shuffle=False)

test_dataset = NetworkDataset(X_test, y_test)
test_loader = DataLoader(test_dataset, batch_size=conf["batch_size"], shuffle=False)

packet_dat,attack_cat

torch.Size([356334, 513, 1])
torch.Size([356334])


In [26]:
# binary dataset
def binary_dataset(labels, label_dict):
    """Collapse multi-class labels into binary targets: 0 for "Normal", 1 otherwise.

    Args:
        labels: 1-D tensor (or sequence) of non-negative integer class indices.
        label_dict: mapping whose keys are the class names; the position of a
            key in iteration order is used as its class index.

    Returns:
        Float tensor with one entry per label, placed on the global ``device``.
    """
    label_names = list(label_dict.keys())
    class_ids = labels if isinstance(labels, torch.Tensor) else torch.tensor(labels)

    if "Normal" not in label_names:
        # no "Normal" class exists, so every sample is treated as non-normal
        return torch.ones(class_ids.shape, dtype=torch.float).to(device)

    # a single vectorised comparison replaces the per-element Python loop, which
    # forced one host synchronisation per label when the labels live on the GPU
    normal_idx = label_names.index("Normal")
    y_binary = (class_ids != normal_idx).to(dtype=torch.float).to(device)
    return y_binary

In [6]:
# create GRU model:
class SPLITGRU(nn.Module):
    """Two-layer bidirectional GRU encoder with a binary head and a multi-class head.

    Both heads are always built (so state dicts have the same keys in either
    mode); ``binary`` selects which head ``forward`` applies.

    Args:
        i_size: number of features per time step.
        h_size: GRU hidden size per direction.
        binary: if True, ``forward`` applies the sigmoid binary head, otherwise
            the multi-class head.
        num_classes: output width of the multi-class head (default 24).
    """

    def __init__(self, i_size, h_size, binary=False, num_classes=24):
        super().__init__()
        self.i_size = i_size
        self.h_size = h_size
        self.binary = binary
        self.num_classes = num_classes
        # No device is pinned here: the module is placed with .to(...) like its
        # other layers, so it no longer depends on a module-level `device`.
        self.gru1 = nn.GRU(input_size=i_size, hidden_size=h_size, num_layers=2, batch_first=True, dropout=0.15, bidirectional=True)
        self.bn1 = nn.BatchNorm1d(2 * h_size)

        self.fc_binary = nn.Sequential(
            nn.Linear(2 * self.h_size, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(p=0.15),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

        self.fc_class = nn.Sequential(
            nn.Linear(2 * self.h_size, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(p=0.15),
            nn.Linear(128, num_classes)
        )

    def forward(self, x, h0=None):
        """Encode a batch of sequences and apply the selected head.

        Args:
            x: input of shape [B, T, i_size] (the GRU is batch_first).
            h0: optional initial hidden state of shape [4, B, h_size]
                (2 layers x 2 directions); zeros are used when omitted.

        Returns:
            Tuple ``(out, features)``: ``out`` is [B, 1] sigmoid outputs in
            binary mode, otherwise [B, num_classes] unnormalised scores;
            ``features`` is the batch-normalised [B, 2*h_size] vector built
            from the last layer's final hidden states.
        """
        # With h0=None the GRU starts from zeros allocated on x's own device and
        # dtype, so the model works wherever it and its input have been moved.
        _, h_n = self.gru1(x, h0)

        # h_n is [num_layers * 2, B, h_size]; split layers from directions and
        # keep the last layer -> [2, B, h_size]
        batch_size = x.shape[0]
        h_last = h_n.reshape(self.gru1.num_layers, 2, batch_size, self.h_size)[-1]
        # concatenate forward and backward states -> [B, 2*h_size]
        features = torch.cat((h_last[0], h_last[1]), dim=1)

        # apply BN + FC
        features = self.bn1(features)
        head = self.fc_binary if self.binary else self.fc_class
        out = head(features)

        return out, features

In [7]:
class FocalLoss(nn.Module):
    """Focal loss computed from the per-sample cross entropy.

    Every sample's cross-entropy term is scaled by ``alpha * (1 - p_t) ** gamma``
    with ``p_t = exp(-cross_entropy)``. ``reduction`` may be 'mean' or 'sum';
    any other value returns the unreduced per-sample losses.
    """

    def __init__(self, gamma=2.1, alpha=0.9, reduction='mean'):
        super().__init__()
        self.reduction = reduction
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        """Return the focal loss of ``inputs`` (class scores) against ``targets`` (class indices)."""
        per_sample_ce = F.cross_entropy(inputs, targets, reduction='none')
        # probability assigned to the true class, recovered from its cross entropy
        p_true = torch.exp(-per_sample_ce)
        scaled = self.alpha * (1 - p_true) ** self.gamma * per_sample_ce

        if self.reduction == 'sum':
            return scaled.sum()
        if self.reduction == 'mean':
            return scaled.mean()
        return scaled

In [8]:
# Binary-head model: its sigmoid output is trained with binary cross entropy.
dpu_model = SPLITGRU(i_size=X_train.size(2), h_size=64, binary=True).to(device)
dpu_criterion = nn.BCELoss()
dpu_optimizer = torch.optim.AdamW(
    params=dpu_model.parameters(),
    lr=conf["learning_rate"],
    weight_decay=0.01,
)
# multiply the learning rate by 0.9 on every scheduler step
dpu_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=dpu_optimizer, gamma=0.9)

In [9]:
# Multi-class model fed with the feature vector returned by dpu_model:
# its input size 128 equals 2 * 64, the width of that feature vector.
host_model = SPLITGRU(i_size=128, h_size=64).to(device)
host_criterion = FocalLoss()
host_optimizer = torch.optim.AdamW(
    params=host_model.parameters(),
    lr=conf["learning_rate"],
    weight_decay=0.01,
)
# multiply the learning rate by 0.9 on every scheduler step
host_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=host_optimizer, gamma=0.9)

In [32]:
def evaluate(val_loader):
    """Evaluate both models on every batch of ``val_loader``.

    Uses the module-level ``dpu_model``, ``host_model``, their criteria,
    ``label_dict``, ``binary_dataset`` and ``device``. Both models are switched
    to eval mode and left in that mode.

    Args:
        val_loader: iterable of ``(data, labels)`` batches.

    Returns:
        Tuple ``(class_loss, class_acc, bin_loss, bin_acc)`` of scalar tensors
        computed over the concatenation of all batches.
    """
    y_true, y_pred = [], []
    y_bin_true, y_bin_pred = [], []

    dpu_model.eval()
    host_model.eval()
    # nothing in here needs gradients
    with torch.no_grad():
        for data, labels in val_loader:
            # .to() is a no-op for tensors that already live on `device`
            data, labels = data.to(device), labels.to(device)
            bin_labels = binary_dataset(labels, label_dict)

            bin_pred, features = dpu_model(data)
            # the host model consumes the features as a length-1 sequence
            pred, _ = host_model(features.unsqueeze(1))

            y_true.append(labels)
            y_pred.append(pred)
            y_bin_true.append(bin_labels)
            # Flatten instead of a bare squeeze: squeeze() turns a final batch
            # of one sample into a 0-dim tensor, which torch.cat rejects.
            y_bin_pred.append(bin_pred.reshape(-1))

        # concat predictions and targets
        y_true, y_pred = torch.cat(y_true, dim=0), torch.cat(y_pred, dim=0)
        y_bin_true, y_bin_pred = torch.cat(y_bin_true, dim=0), torch.cat(y_bin_pred, dim=0)

        # calculate loss
        bin_loss = dpu_criterion(y_bin_pred, y_bin_true)
        class_loss = host_criterion(y_pred, y_true)

        # evaluate accuracy
        class_acc = (y_pred.argmax(dim=1) == y_true).float().mean()
        bin_acc = (y_bin_pred.round() == y_bin_true).float().mean()

    return class_loss, class_acc, bin_loss, bin_acc

In [33]:
def train(train_loader, val_loader, epochs):
    """Train the binary (DPU) and multi-class (host) models side by side.

    Uses the module-level models, optimizers, schedulers, criteria,
    ``label_dict``, ``binary_dataset``, ``evaluate`` and ``device``. After each
    epoch both models are evaluated on ``val_loader``, both schedulers are
    stepped and one progress line is printed.

    Args:
        train_loader: iterable of ``(data, labels)`` training batches.
        val_loader: iterable of ``(data, labels)`` batches passed to ``evaluate``.
        epochs: number of passes over ``train_loader``.

    Returns:
        ``(bin_acc, class_acc)`` from the last epoch's evaluation, or
        ``(None, None)`` when ``epochs`` is zero or negative.
    """
    # defined up front so the return below is valid even when no epoch runs
    bin_acc, class_acc = None, None

    for epoch in range(1, epochs + 1):
        dpu_model.train()
        host_model.train()

        # train model
        for data, labels in train_loader:
            # .to() is a no-op for tensors that already live on `device`
            data, labels = data.to(device), labels.to(device)
            bin_labels = binary_dataset(labels, label_dict)

            bin_pred, features = dpu_model(data)
            # The host model trains on detached features (as a length-1
            # sequence), so its loss never back-propagates into dpu_model.
            pred, _ = host_model(features.detach().unsqueeze(1))

            # calculate loss
            # reshape(-1) keeps a 1-D prediction vector for any batch size,
            # whereas a bare squeeze() collapses a single-sample batch to 0-dim
            bin_loss = dpu_criterion(bin_pred.reshape(-1), bin_labels)
            class_loss = host_criterion(pred, labels)

            dpu_optimizer.zero_grad()
            bin_loss.backward()
            dpu_optimizer.step()

            host_optimizer.zero_grad()
            class_loss.backward()
            host_optimizer.step()

        # evaluate model
        _, class_acc, _, bin_acc = evaluate(val_loader)
        dpu_scheduler.step()
        host_scheduler.step()

        print(f"Epoch: {epoch}/{epochs}, DPU model accuracy: {100*bin_acc:.2f}%, Host model accuracy: {100*class_acc:.2f}%")

    return bin_acc, class_acc

In [34]:
# train both models and keep the validation accuracies of the last epoch
bin_acc, class_acc = train(
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=conf["epochs"],
)

Epoch: 1/10, DPU model accuracy: 98.74%, Host model accuracy: 82.23%
Epoch: 2/10, DPU model accuracy: 99.28%, Host model accuracy: 83.98%
Epoch: 3/10, DPU model accuracy: 99.39%, Host model accuracy: 87.27%
Epoch: 4/10, DPU model accuracy: 99.44%, Host model accuracy: 88.42%
Epoch: 5/10, DPU model accuracy: 99.48%, Host model accuracy: 89.31%
Epoch: 6/10, DPU model accuracy: 99.52%, Host model accuracy: 92.33%
Epoch: 7/10, DPU model accuracy: 99.57%, Host model accuracy: 91.81%
Epoch: 8/10, DPU model accuracy: 99.59%, Host model accuracy: 92.31%
Epoch: 9/10, DPU model accuracy: 99.61%, Host model accuracy: 93.63%
Epoch: 10/10, DPU model accuracy: 99.60%, Host model accuracy: 92.89%


In [35]:
def save(checkpoint_path, model, optimizer, scheduler):
    """Write the state dicts of a model, its optimizer and its scheduler to one file.

    Args:
        checkpoint_path: destination file path (or an object accepted by torch.save).
        model: object whose ``state_dict()`` is stored under "model_state_dict".
        optimizer: object whose ``state_dict()`` is stored under "optimizer_state_dict".
        scheduler: object whose ``state_dict()`` is stored under "scheduler_state_dict".
    """
    checkpoint = {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict()
    }

    # torch.save does not create missing folders, so make sure the parent
    # directory exists before writing (only applies to path-like destinations)
    if isinstance(checkpoint_path, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(checkpoint_path))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved at {checkpoint_path}")

In [36]:
# save models:
# both checkpoints go to <cwd with "large_models" removed>/checkpoint/
dpu_checkpoint_path, host_checkpoint_path = (
    os.path.join(os.getcwd().replace("large_models", ""), "checkpoint", file_name)
    for file_name in ("dpu_split_model.pth", "host_split_model.pth")
)

save(checkpoint_path=dpu_checkpoint_path, model=dpu_model, optimizer=dpu_optimizer, scheduler=dpu_scheduler)
save(checkpoint_path=host_checkpoint_path, model=host_model, optimizer=host_optimizer, scheduler=host_scheduler)

Checkpoint saved at /home/jorgetf/testmodel/Network-Packet-ML-Model/checkpoint/dpu_split_model.pth
Checkpoint saved at /home/jorgetf/testmodel/Network-Packet-ML-Model/checkpoint/host_split_model.pth
