In [1]:
import os
import sys
# Make the project-local modules imported below resolvable: add the working
# directory, with the "distilled_models" component stripped, to sys.path.
sys.path.append(os.getcwd().replace("distilled_models", ""))

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.profiler import profiler
import matplotlib.pyplot as plt
import time
from parse_dataset import NetworkDataset, parse_dataset, split_datasets
from split_model import SplitModelDPU
from load_models import models
from knowledge_distillation import SplitModelKD

In [2]:
# Run configuration shared by the data loaders and the distillation loop.
conf = dict(
    batch_size=516,  # NOTE(review): unusual value; 512 may have been intended — confirm
    epochs=10,
    learning_rate=1e-4,
    load_model=True,  # NOTE(review): not referenced in the visible cells — confirm it is still needed
)

In [3]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
# Locate the packet dataset CSV under <project root>/datasets, where the
# project root is the working directory with "distilled_models" stripped.
network_data_file = os.path.join(os.getcwd().replace("distilled_models", ""), "datasets", "network_packet_data_test.csv")

# Parse the CSV and split it into train / validation / test partitions.
data, labels, label_dict = parse_dataset(network_data_file)
X_train, y_train, X_val, y_val, X_test, y_test = split_datasets(data, labels)

# Append a trailing axis of size 1 to every feature tensor; its size is what
# the student model later receives as its per-step input size.
X_train, X_val, X_test = X_train.unsqueeze(-1), X_val.unsqueeze(-1), X_test.unsqueeze(-1)
print(X_train.shape)
print(y_train.shape)

# Only the training loader shuffles. Validation and test data are evaluated
# in a fixed order so the reported metrics do not depend on (or consume) the
# global random state.
train_dataset = NetworkDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=conf["batch_size"], shuffle=True)

val_dataset = NetworkDataset(X_val, y_val)
val_loader = DataLoader(val_dataset, batch_size=conf["batch_size"], shuffle=False)

test_dataset = NetworkDataset(X_test, y_test)
test_loader = DataLoader(test_dataset, batch_size=conf["batch_size"], shuffle=False)

packet_dat,attack_cat

torch.Size([356334, 513, 1])
torch.Size([356334])


In [5]:
# Path of the teacher checkpoint: <project root>/checkpoint/dpu_split_model.pth
dpu_path = os.path.join(
    os.getcwd().replace("distilled_models", ""),
    "checkpoint",
    "dpu_split_model.pth",
)

# Fetch the full DPU model from the project's model registry and restore its
# weights from the checkpoint; it acts as the teacher during distillation.
full_model_dpu = models["dpu"]
full_model_dpu.load(dpu_path)

Checkpoint loaded from /home/jorgetf/testmodel/Network-Packet-ML-Model/checkpoint/dpu_split_model.pth!


In [6]:
class LightSplitDPU(nn.Module):
    """Compact bidirectional-GRU binary classifier used as the distillation student.

    A single bidirectional GRU layer reads the input sequence. The final hidden
    states of the two directions are concatenated, batch-normalized and fed to
    a small fully connected head that ends in a sigmoid.

    The module does not pin itself to any device: every sub-module is created
    on the default device and the caller moves the whole model with
    ``.to(device)``.

    Args:
        i_size: Number of input features per time step.
        h_size: Hidden size of the GRU, per direction.
    """

    def __init__(self, i_size, h_size):
        super().__init__()
        self.i_size = i_size
        self.h_size = h_size
        self.gru1 = nn.GRU(input_size=i_size, hidden_size=h_size, num_layers=1, batch_first=True, bidirectional=True)
        self.bn1 = nn.BatchNorm1d(2 * h_size)

        self.fc = nn.Sequential(
            nn.Linear(2 * h_size, 32),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Dropout(p=0.15),
            nn.Linear(32, 1),
            nn.Sigmoid()
        )

    def forward(self, x, h0=None, c0=None):
        """Classify a batch of sequences.

        Args:
            x: Input of shape [B, T, i_size] (the GRU is batch-first).
            h0: Optional initial hidden state of shape [2, B, h_size]. When
                omitted, the GRU starts from zeros.
            c0: Unused; accepted only so existing callers keep working.

        Returns:
            A tuple ``(out, logits)``: ``out`` is the sigmoid output of shape
            [B, 1]; ``logits`` is the batch-normalized concatenation of the
            final forward/backward hidden states, shape [B, 2*h_size] (despite
            the name, these are intermediate features, not pre-sigmoid logits).
        """
        # With h0=None, nn.GRU builds the zero initial state itself, matching
        # the input's device and dtype, so no global device is needed here.
        _, h_n = self.gru1(x, h0)  # h_n: [2, B, h_size] for one bidirectional layer

        # Final hidden state of the forward (h_n[-2]) and backward (h_n[-1]) pass.
        h_last = torch.cat((h_n[-2], h_n[-1]), dim=1)  # [B, 2*h_size]

        logits = self.bn1(h_last)  # [B, 2*h_size]
        out = self.fc(logits)      # [B, 1]
        return out, logits

In [7]:
# Seed before building the student so its initial weights are reproducible.
torch.manual_seed(42)

# Student network: per-step input size taken from the last axis of X_train,
# 64 hidden units per GRU direction.
dpu_model = LightSplitDPU(i_size=X_train.shape[2], h_size=64).to(device)

# Binary cross-entropy on the sigmoid output, AdamW with weight decay, and a
# learning rate multiplied by 0.9 on every scheduler step.
dpu_criterion = nn.BCELoss()
dpu_optimizer = torch.optim.AdamW(
    dpu_model.parameters(),
    lr=conf["learning_rate"],
    weight_decay=0.01,
)
dpu_scheduler = torch.optim.lr_scheduler.ExponentialLR(dpu_optimizer, gamma=0.9)

# Bundle network, loss, optimizer and scheduler in the project wrapper.
light_dpu_model = SplitModelDPU(dpu_model, dpu_criterion, dpu_optimizer, dpu_scheduler, device)

In [8]:
# Knowledge-distillation driver: the restored full DPU model is the teacher,
# the lightweight wrapper built above is the student.
kd = SplitModelKD(
    teacher=full_model_dpu,
    student=light_dpu_model,
    device=device,
    label_dict=label_dict,
)

In [9]:
# Distill for the configured number of epochs; the per-epoch log printed
# below reports accuracy, training loss and validation loss.
accuracy, train_loss, val_loss = kd.train_kd(
    train_loader,
    val_loader,
    conf["epochs"],
)

Epoch: 1/10, Accuracy: 98.57%, Train loss: 0.0007, Val loss: 1.4282
Epoch: 2/10, Accuracy: 99.13%, Train loss: 0.0004, Val loss: 0.8722
Epoch: 3/10, Accuracy: 99.42%, Train loss: 0.0002, Val loss: 0.5834
Epoch: 4/10, Accuracy: 99.39%, Train loss: 0.0001, Val loss: 0.6149
Epoch: 5/10, Accuracy: 99.52%, Train loss: 0.0001, Val loss: 0.4754
Epoch: 6/10, Accuracy: 99.53%, Train loss: 0.0001, Val loss: 0.4656
Epoch: 7/10, Accuracy: 99.59%, Train loss: 0.0001, Val loss: 0.4066
Epoch: 8/10, Accuracy: 99.61%, Train loss: 0.0001, Val loss: 0.3850
Epoch: 9/10, Accuracy: 99.46%, Train loss: 0.0001, Val loss: 0.5422
Epoch: 10/10, Accuracy: 99.63%, Train loss: 0.0000, Val loss: 0.3674


In [10]:
# Save the distilled student under its own file name in the checkpoint folder.
# Only `new_dpu_path` is assigned here, so `dpu_path` keeps pointing at the
# teacher checkpoint instead of being silently rebound to the student file.
new_dpu_path = os.path.join(os.getcwd().replace("distilled_models", ""), "checkpoint", "light_dpu_split_model.pth")
light_dpu_model.save(new_dpu_path)

Checkpoint saved at /home/jorgetf/testmodel/Network-Packet-ML-Model/checkpoint/light_dpu_split_model.pth
