In [1]:
import os
import sys
sys.path.append(os.path.join(os.getcwd().replace("model_inference", "")))

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.profiler import profiler
import matplotlib.pyplot as plt
import time
from parse_dataset import NetworkDataset, parse_dataset, split_datasets, binary_dataset
from split_model import SplitModelDPU, SplitModelHost
from load_models import models
from transfer_tensors import HostSocket

In [2]:
# Run configuration for this notebook (batch size, epochs, learning rate, DPU flag).
# NOTE(review): `dpu=False` presumably marks this as the host-side run — confirm.
conf = dict(
    batch_size=512,
    epochs=10,
    learning_rate=1e-4,
    dpu=False,
)

In [3]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cpu


In [4]:
class BenchmarkHost:
    """Benchmark the host-side model on a pre-computed input tensor.

    The wrapped ``host_model`` must expose a ``load(path)`` method and a
    ``model`` attribute (a callable ``torch.nn.Module`` returning a
    ``(pred, logits)`` pair). The checkpoint is loaded once at construction.
    """

    def __init__(self, host_model, host_model_path):
        """Store ``host_model`` and load its weights from ``host_model_path``."""
        self.host_model = host_model
        self.host_model.load(host_model_path)

    def run(self, logits, labels, warmup_runs=0):
        """Run one profiled inference pass and report accuracy and resource usage.

        Args:
            logits: Input tensor forwarded to ``self.host_model.model`` (in this
                notebook, the tensor received over the socket).
            labels: Ground-truth class indices, compared against
                ``pred.argmax(dim=1)``.
            warmup_runs: Number of un-profiled, un-timed forward passes executed
                before the measured one, to keep one-time initialisation costs
                out of the numbers. Defaults to 0 (no warm-up).

        Returns:
            Tuple ``(accuracy_pct, cpu_utilization, peak_mem_mb, wall_clock_s,
            output_logits, labels)``:
              * ``accuracy_pct`` — 0-d tensor, accuracy in percent;
              * ``cpu_utilization`` — summed profiled self-CPU time divided by
                wall-clock time (roughly, the number of busy cores);
              * ``peak_mem_mb`` — largest per-operator aggregated CPU allocation
                reported by the profiler, in MiB (a proxy, not a true
                process-level peak);
              * ``wall_clock_s`` — wall-clock seconds of the measured pass;
              * ``output_logits`` — second value returned by the model;
              * ``labels`` — the ``labels`` argument, returned unchanged.
        """
        model = self.host_model.model
        model.eval()

        # Optional warm-up passes run outside the profiler and the timer.
        with torch.no_grad():
            for _ in range(warmup_runs):
                model(logits)

        with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU],
            record_shapes=True,
            profile_memory=True,
            with_stack=True,
        ) as prof:
            # Timing is taken inside the profiler context, so it includes the
            # profiler's own recording overhead (consistent with the CPU time below).
            start = time.perf_counter()
            with torch.no_grad():
                # Keep the model output under its own name instead of
                # shadowing the `logits` input.
                pred, output_logits = model(logits)
            end = time.perf_counter()

        # Compare on the prediction's device so CPU labels also work with a GPU model.
        target = torch.as_tensor(labels, device=pred.device)
        accuracy = (pred.argmax(dim=1) == target).float().mean()

        wall_clock_time = end - start

        events = prof.key_averages()

        # Sum *self* CPU time so nested operators are not double-counted (µs → s).
        cpu_time_total_s = sum(e.self_cpu_time_total for e in events) / 1e6
        cpu_utilization = cpu_time_total_s / wall_clock_time if wall_clock_time > 0 else 0.0

        # `default=0` keeps an empty trace from raising ValueError.
        peak_mem = max((e.cpu_memory_usage for e in events), default=0)
        peak_mem = peak_mem / 1024**2  # bytes → MiB

        return 100 * accuracy, cpu_utilization, peak_mem, wall_clock_time, output_logits, labels

In [5]:
# Resolve the project root from the working directory. The notebook presumably
# runs from <root>/model_inference, so drop only that trailing path component.
# `str.replace` would also mangle any other "model_inference" substring in the path.
_cwd = os.getcwd()
project_root = os.path.dirname(_cwd) if os.path.basename(_cwd) == "model_inference" else _cwd
host_path = os.path.join(project_root, "checkpoint", "host_split_model.pth")
host_model = models["host"]

In [6]:
# Build the benchmark; this loads the host checkpoint from `host_path`.
benchmark = BenchmarkHost(host_model=host_model, host_model_path=host_path)

Checkpoint loaded from /home/jorgetf/testmodel/Network-Packet-ML-Model/checkpoint/host_split_model.pth!


In [7]:
# Locate the compiled socket library under the project root. The notebook
# presumably runs from <root>/model_inference, so drop only that trailing path
# component instead of string-replacing "model_inference" anywhere in the path.
_cwd = os.getcwd()
project_root = os.path.dirname(_cwd) if os.path.basename(_cwd) == "model_inference" else _cwd
so_file = os.path.join(project_root, "socket_transfer", "socket_transfer.so")
socket = HostSocket(so_file)

In [8]:
# First message on the socket: the input tensor for the host model.
# NOTE(review): relies on the sender transmitting this tensor before the targets — confirm against the sending side.
logits = socket.receive()

In [9]:
# Second message on the socket: ground-truth class indices, cast to int64 (torch.long).
targets = socket.receive().long()

In [10]:
# Run the profiled host-side inference. The model's output logits are unused
# here; bind them to a named throwaway rather than `_`, because assigning `_`
# in IPython shadows the automatic last-output variable and stops it updating.
acc, cpu, mem, runtime, host_output_logits, labels = benchmark.run(logits=logits, labels=targets)

ERROR:2025-10-26 12:02:08 3127010:3127010 DeviceProperties.cpp:47] gpuGetDeviceCount failed with code 35


In [11]:
# A single forward pass takes on the order of milliseconds, so print runtime with
# enough precision that it does not collapse to "0.01s".
print(f"Benchmark (Host): Accuracy: {acc:.2f}%, CPU Usage: {cpu:.2f} cores, Memory Usage: {mem:.2f}MB, Runtime: {runtime:.4f}s")

Benchmark (Host): Accuracy: 93.95%, CPU Usage: 0.92 cores, Memory Usage: 3.30MB, Runtime: 0.01s


In [12]:
# Spot-check the received data: one input sample, the first ten targets, then both shapes.
for sample in (logits[15], targets[:10]):
    print(sample)
print(logits.shape, targets.shape)

tensor([[ 0.7313, -0.6907, -0.2624,  0.0244, -1.0083,  0.7555, -0.5137, -0.8355,
          0.8498, -0.7010, -0.8411, -0.9054, -0.7942,  0.0053, -1.1667, -1.1217,
         -1.3613, -1.1800,  0.7099,  0.7456, -0.7751, -0.4873, -1.7606,  0.8915,
          0.7598,  0.7118, -0.9263, -1.0522,  0.8669, -0.8070,  0.7431,  1.9215,
         -0.7002,  0.7611,  0.6988, -0.1856,  1.2212,  0.7642, -0.7228,  0.8630,
         -0.6999, -0.7822, -0.8375,  0.8328,  0.8041,  0.2072, -0.8121,  0.9513,
          1.2035,  0.7163,  0.0371,  0.7709,  0.7297, -0.6607, -0.4662,  0.2501,
          0.6822, -0.6901,  0.7201, -0.7794, -0.7908, -0.9197,  0.7846,  0.7815,
          0.5245, -0.0925, -0.8306, -0.8944, -0.9474, -0.6359, -0.8204, -1.1732,
          0.0523, -0.7144, -0.9882, -1.0135, -0.9948,  0.5071, -1.1235, -0.5889,
          0.7702, -0.6648, -1.1151, -0.1653, -1.0198,  0.7425, -1.0049, -0.8738,
          1.0249, -0.9712, -0.6420,  1.0417, -0.8071,  1.2857,  0.6720,  0.8138,
          1.0195,  0.8417, -