In [1]:
import os
import sys
import json
# Append the current working directory, with the substring "model_inference"
# removed, to sys.path — presumably so the project-local imports below
# (parse_dataset, model, load_models) resolve from the parent folder.
# NOTE(review): str.replace strips *every* occurrence of "model_inference" in
# the path, not just a trailing component, and the single-argument
# os.path.join is a no-op — confirm the notebook is always launched from the
# `model_inference/` directory.
sys.path.append(os.path.join(os.getcwd().replace("model_inference", "")))

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from parse_dataset import NetworkDataset
from model import SplitModel
from load_models import models

In [2]:
# Run configuration shared by the cells below.
#   batch_size            -> batch size handed to every DataLoader
#   dpu_model, host_model -> keys looked up in `models` to pick each half
#   model                 -> decides whether the inputs get a trailing dimension
#   epochs, learning_rate -> not read by the cells visible in this notebook
conf = dict(
    batch_size=128,
    epochs=10,
    learning_rate=1e-4,
    dpu_model="dpu_mlp",
    host_model="host_mlp",
    model="mlp",
)

In [3]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [4]:
# Locate the `datasets/` directory by stripping "model_inference" from the
# current working directory.
# NOTE(review): str.replace removes every occurrence of the substring, not only
# a trailing path component — confirm the notebook is always launched from the
# `model_inference/` folder.
dataset_path = os.path.join(os.getcwd().replace("model_inference", ""), "datasets")

# load label_dict
json_file = os.path.join(dataset_path, "label_index.json")
with open(json_file, 'r') as file:
    label_dict = json.load(file)

# load train, val and test datasets; each .pt file unpacks to an (X, y) pair
train_dataset_file = os.path.join(dataset_path, "train_dataset.pt")
X_train, y_train = torch.load(train_dataset_file)

val_dataset_file = os.path.join(dataset_path, "val_dataset.pt")
X_val, y_val = torch.load(val_dataset_file)

test_dataset_file = os.path.join(dataset_path, "test_dataset.pt")
X_test, y_test = torch.load(test_dataset_file)

# The two MLP variants consume the features as loaded; every other model
# gets an extra trailing singleton dimension on its inputs.
if conf["model"] not in ("mlp", "light_mlp"):
    X_train, X_val, X_test = X_train.unsqueeze(-1), X_val.unsqueeze(-1), X_test.unsqueeze(-1)

print(X_train.shape)
print(y_train.shape)

# create train, val and test datasets
train_dataset = NetworkDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=conf["batch_size"], shuffle=True)

# Evaluation loaders keep a fixed order: shuffling validation batches gains
# nothing and makes per-batch results differ from run to run.
val_dataset = NetworkDataset(X_val, y_val)
val_loader = DataLoader(val_dataset, batch_size=conf["batch_size"], shuffle=False)

test_dataset = NetworkDataset(X_test, y_test)
test_loader = DataLoader(test_dataset, batch_size=conf["batch_size"])

torch.Size([356334, 513])
torch.Size([356334])


In [5]:
# Checkpoint for the DPU-side model, under `checkpoint/` next to `datasets/`.
dpu_path = os.path.join(
    os.getcwd().replace("model_inference", ""),
    "checkpoint",
    "dpu_mlp_model.pth",
)
# Pick the model wrapper named in `conf` and load its checkpoint.
# NOTE(review): the checkpoint file name is hard-coded while the model is
# selected through conf["dpu_model"] — keep the two in sync when changing conf.
dpu_mlp = models[conf["dpu_model"]]
dpu_mlp.load(dpu_path)

Checkpoint loaded from /home/jorgetf/testmodel/Network-Packet-ML-Model/checkpoint/dpu_mlp_model.pth!


In [6]:
# Checkpoint for the host-side model, under `checkpoint/` next to `datasets/`.
host_path = os.path.join(
    os.getcwd().replace("model_inference", ""),
    "checkpoint",
    "host_mlp_model.pth",
)
# Pick the model wrapper named in `conf` and load its checkpoint.
# NOTE(review): the checkpoint file name is hard-coded while the model is
# selected through conf["host_model"] — keep the two in sync when changing conf.
host_mlp = models[conf["host_model"]]
host_mlp.load(host_path)


Checkpoint loaded from /home/jorgetf/testmodel/Network-Packet-ML-Model/checkpoint/host_mlp_model.pth!


In [7]:
# Smoke-test the split pipeline on one test batch: the DPU-side model's first
# output is fed to the host-side model, whose second output is arg-maxed into
# class predictions and printed next to the ground-truth labels.
data, labels = next(iter(test_loader))
# Tensor.to() hands back the same tensor when it already lives on `device`,
# so no is_cuda guard is needed.
data, labels = data.to(device), labels.to(device)

# Inference only: switch both halves to eval mode (matters if they contain
# dropout / batch-norm layers) and skip autograd bookkeeping.
# NOTE(review): assumes `.model` on both wrappers is an nn.Module — confirm.
dpu_mlp.model.eval()
host_mlp.model.eval()
with torch.no_grad():
    # `logits` here is the DPU model's output that the host model takes as input.
    logits, _ = dpu_mlp.model(data)
    print(logits.shape)
    _, pred = host_mlp.model(logits)
print(pred.argmax(dim=1)[:10])
print(labels[:10])

torch.Size([128, 320])
tensor([ 0,  0, 14,  0,  8,  0,  5, 15, 10,  5], device='cuda:0')
tensor([ 0,  0, 14,  0, 12, 10,  5, 15, 10,  5], device='cuda:0')
