In [1]:
import os
import sys
# Make the project root importable. The notebook is run from
# <project root>/model_inference, so only a *trailing* `model_inference`
# directory is stripped (str.replace would also mangle any parent directory
# whose name merely contains that substring).
notebook_dir = os.getcwd()
sys.path.append(
    os.path.dirname(notebook_dir)
    if os.path.basename(notebook_dir) == "model_inference"
    else notebook_dir
)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.profiler import profiler
import matplotlib.pyplot as plt
import time
from parse_dataset import NetworkDataset, parse_dataset, split_datasets, binary_dataset
from split_model import SplitModelDPU, SplitModelHost
from load_models import models
from transfer_tensors import HostSocket

In [2]:
# Run configuration for this notebook.
# NOTE(review): none of the cells shown here read these values (the batch size
# is fixed by the sender) - confirm they are still needed.
conf = dict(
    batch_size=128,
    epochs=10,
    learning_rate=1e-4,
    dpu=False,
)

In [3]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
print(device)

cpu


In [4]:
class BenchmarkHost:
    """Benchmark the host half of a split model on one batch of inputs.

    The host model's checkpoint is loaded once at construction; ``run`` then
    performs a single profiled forward pass and reports accuracy, CPU
    utilisation, profiler memory figures and wall-clock time.
    """

    def __init__(self, host_model, host_model_path):
        """
        Args:
            host_model: wrapper exposing ``load(path)`` and a ``model`` attribute
                holding the ``torch.nn.Module`` that is run.
            host_model_path: checkpoint path passed to ``host_model.load``.
        """
        self.host_model = host_model
        self.host_model.load(host_model_path)

    def run(self, logits, labels):
        """Run one profiled forward pass and measure accuracy, CPU time, memory and runtime.

        Args:
            logits: input batch for the host model (named ``logits`` because it is
                the output of the other half of the split model).
                NOTE(review): expected shape is not visible here - confirm.
            labels: integer class labels; must have the same shape as
                ``model(logits)[0].argmax(dim=1)``.

        Returns:
            Tuple ``(accuracy_pct, cpu_utilization, peak_mem_mb, wall_clock_s,
            out_logits, labels)``:
              - ``accuracy_pct``: 0-d float tensor, percentage in [0, 100].
              - ``cpu_utilization``: profiled CPU time / wall-clock time.
              - ``peak_mem_mb``: largest per-operator ``cpu_memory_usage``
                aggregate reported by the profiler, in MiB.
              - ``wall_clock_s``: seconds spent in the forward pass.
              - ``out_logits``: second output of the host model.
              - ``labels``: the ``labels`` argument, returned unchanged.

        Raises:
            ValueError: if the predicted class indices and ``labels`` differ in
                shape (the comparison would otherwise broadcast silently and
                produce a meaningless accuracy).
        """
        model = self.host_model.model
        model.eval()

        with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU],
            record_shapes=True,
            profile_memory=True,
            with_stack=True,
        ) as prof:
            # No warm-up: this single pass is both timed and profiled, so the
            # wall-clock time includes first-call and profiler overhead.
            start = time.perf_counter()
            with torch.no_grad():
                pred, out_logits = model(logits)
            end = time.perf_counter()

        predicted = pred.argmax(dim=1)
        if predicted.shape != labels.shape:
            raise ValueError(
                f"prediction shape {tuple(predicted.shape)} does not match "
                f"labels shape {tuple(labels.shape)}"
            )
        accuracy = (predicted == labels).float().mean()

        wall_clock_time = end - start

        events = prof.key_averages()
        # self_cpu_time_total is in microseconds and excludes child operators,
        # so summing it does not double-count nested calls.
        cpu_time_total_s = sum(e.self_cpu_time_total for e in events) / 1e6
        cpu_utilization = cpu_time_total_s / wall_clock_time if wall_clock_time > 0 else 0.0

        # Maximum over per-operator profiler aggregates (bytes -> MiB); this is
        # not a measured process peak. `default` guards an empty event list.
        peak_mem = max((e.cpu_memory_usage for e in events), default=0) / 1024**2

        return 100 * accuracy, cpu_utilization, peak_mem, wall_clock_time, out_logits, labels

In [5]:
# The checkpoint lives in <project root>/checkpoint. The notebook runs from
# <project root>/model_inference, so only a trailing `model_inference`
# directory is stripped (str.replace would also mangle parent directories
# whose names contain that substring).
# BenchmarkHost loads the checkpoint itself, so no explicit load happens here.
notebook_dir = os.getcwd()
project_root = (
    os.path.dirname(notebook_dir)
    if os.path.basename(notebook_dir) == "model_inference"
    else notebook_dir
)
host_path = os.path.join(project_root, "checkpoint", "host_split_model.pth")
host_model = models["host"]

In [6]:
# Constructing the benchmark loads the host checkpoint from host_path.
benchmark = BenchmarkHost(host_model=host_model, host_model_path=host_path)

Checkpoint loaded from /home/jorgetf/testmodel/Network-Packet-ML-Model/checkpoint/host_split_model.pth!


In [7]:
# Shared library backing HostSocket, located in <project root>/socket_transfer.
# Only a trailing `model_inference` directory is stripped from the cwd
# (str.replace would also mangle parent directories containing that substring).
# NOTE(review): the name `socket` shadows the stdlib `socket` module in this
# notebook; kept because later cells use it.
notebook_dir = os.getcwd()
project_root = (
    os.path.dirname(notebook_dir)
    if os.path.basename(notebook_dir) == "model_inference"
    else notebook_dir
)
so_file = os.path.join(project_root, "socket_transfer", "socket_transfer.so")
socket = HostSocket(so_file)

In [8]:
# First tensor received from the peer: the input batch for the host model.
# NOTE(review): presumably the peer sends the logits first and the labels
# second - this cell must run before the `targets` cell.
logits = socket.receive()

In [9]:
# Second tensor received from the peer: the class labels, cast to int64 (torch.long).
targets = socket.receive().long()

In [10]:
# One profiled forward pass of the host model on the received batch.
benchmark_results = benchmark.run(logits, targets)
acc, cpu, mem, runtime, _, labels = benchmark_results

ERROR:2025-10-24 12:08:50 2866632:2866632 DeviceProperties.cpp:47] gpuGetDeviceCount failed with code 35


In [11]:
summary = (
    f"Benchmark (Host): Accuracy: {acc:.2f}%, "
    f"CPU Usage: {cpu:.2f} cores, "
    f"Memory Usage: {mem:.2f}MB, "
    f"Runtime: {runtime:.2f}s"
)
print(summary)

Benchmark (Host): Accuracy: 4.69%, CPU Usage: 0.93 cores, Memory Usage: 0.82MB, Runtime: 0.01s


In [12]:
# Sanity check on the received tensors: one input sample, the first labels, and both shapes.
for sample in (logits[15], targets[:10]):
    print(sample)
print(logits.shape, targets.shape)

tensor([[ 0.7519, -0.7116,  0.7584, -0.7286, -0.8710, -0.7440, -0.7528, -0.6558,
         -0.7405,  0.7365,  0.0868,  0.6960,  0.5404,  0.4739, -1.0177, -0.7227,
          0.2148, -0.8008,  0.4949,  0.7224,  0.3411, -0.8695,  0.9343,  0.7056,
          0.5952, -0.6922, -0.4433,  0.2624,  0.7037,  0.1183,  0.7128, -0.7314,
          0.8568,  0.8124, -0.8194,  0.8370,  0.8752,  0.0959,  0.8442, -0.8577,
         -0.2609,  0.8041,  0.8916,  0.4259,  0.4938, -0.2489,  0.7129,  0.8262,
         -0.7765,  0.5222, -0.5039, -0.6198, -0.7860,  0.8677,  0.0428,  0.7230,
         -0.7287,  0.1237,  0.6157, -0.4312,  0.8740,  0.7747,  0.2759, -0.9036,
          0.8885, -0.7883, -0.9757,  0.7501,  0.3384, -0.9236, -1.0406, -0.8356,
         -0.8641, -1.2406,  0.8834, -0.8488,  0.9031,  1.1479,  0.8998, -0.9417,
         -0.0137, -1.3901, -0.9367,  1.2979,  1.2158, -1.1115,  0.7707,  0.5537,
          0.9766,  1.9593,  0.8900, -0.9053,  0.1653, -0.8392,  1.2220,  0.8334,
          1.0840,  0.0981,  