In [1]:
import time
from pathlib import Path

import numpy as np
import pandas as pd
import onnxruntime
import torch
import yaml

from rtal.datasets.dataset import ROMDataset
from torch.utils.data import DataLoader


In [2]:
# Paths and run-time configuration
# `Path` comes from the import cell at the top of the notebook.

# Start the search from this file's directory, or from the current working
# directory when `__file__` is not defined.
NOTEBOOK_DIR = Path(__file__).resolve().parent if '__file__' in globals() else Path.cwd()

# Directories that must both exist directly under the project root.
ROOT_MARKERS = ('checkpoints', 'onnx_files_narrow_but_deep_untrained')

# Walk upwards from NOTEBOOK_DIR and take the first directory holding every marker.
ROOT_DIR = next(
    (
        candidate
        for candidate in [NOTEBOOK_DIR, *NOTEBOOK_DIR.parents]
        if all((candidate / marker).exists() for marker in ROOT_MARKERS)
    ),
    None,
)

if ROOT_DIR is None:
    # Name every directory the search requires so the manual fix is unambiguous.
    raise FileNotFoundError(
        'Could not locate project root. Please set ROOT_DIR manually to the directory that contains the '
        + ' and '.join(ROOT_MARKERS)
        + ' folders.'
    )

DATA_ROOT = ROOT_DIR / 'data/rom_det-3_part-200_cont-and-rounded_excerpt/'
CONFIG_PATH = ROOT_DIR / 'checkpoints/config.yaml'
ONNX_MODEL_PATH = ROOT_DIR / 'onnx_files_narrow_but_deep_untrained/mlp_half.onnx'
NUM_SAMPLES = 10  # number of inputs to benchmark
THREAD_COUNTS = [1, 2, 3, 4, 5, 6, 7, 8]  # CPU thread counts to test
BATCH_SIZE = 1

print(f'Using data from {DATA_ROOT}')
print(f'Using config {CONFIG_PATH}')
print(f'Using ONNX model {ONNX_MODEL_PATH}')

# Fail fast here instead of deep inside the benchmark cell if the model file is missing.
if not ONNX_MODEL_PATH.is_file():
    raise FileNotFoundError(f'ONNX model {ONNX_MODEL_PATH} does not exist.')

# safe_load only builds plain Python objects from the YAML document.
with open(CONFIG_PATH, 'r', encoding='utf-8') as handle:
    config = yaml.safe_load(handle)

num_particles = config['data']['num_particles']
model_in_features = config['model']['in_features']
print(f"Loaded configuration: num_particles={num_particles}, in_features={model_in_features}")


Using data from /home/synthara/VersalPrjs/LDRD/rtda_demo/model/RealTimeAlignment/onnx_no-residual/data/rom_det-3_part-200_cont-and-rounded_excerpt
Using config /home/synthara/VersalPrjs/LDRD/rtda_demo/model/RealTimeAlignment/onnx_no-residual/checkpoints/config.yaml
Using ONNX model /home/synthara/VersalPrjs/LDRD/rtda_demo/model/RealTimeAlignment/onnx_no-residual/onnx_files_narrow_but_deep_untrained/mlp_half.onnx
Loaded configuration: num_particles=50, in_features=6


In [3]:
# Collect inputs for benchmarking
if not DATA_ROOT.exists():
    raise FileNotFoundError(f'Dataset directory {DATA_ROOT} does not exist. Please check ROOT_DIR/data configuration.')

dataset = ROMDataset(str(DATA_ROOT), split='train', num_particles=num_particles)
dataset_size = len(dataset)
if dataset_size == 0:
    raise RuntimeError(f'Found zero samples under {DATA_ROOT} (split="train").')

dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

# Each benchmark input is one DataLoader batch, so cap the request by the number
# of batches, not by the number of dataset items (they differ when BATCH_SIZE > 1).
samples_to_collect = min(NUM_SAMPLES, len(dataloader))
if samples_to_collect < NUM_SAMPLES:
    print(f'Only {samples_to_collect} samples available; will benchmark on all of them instead of the requested {NUM_SAMPLES}.')

inputs = []
for event in dataloader:
    readout = event['readout_curr_cont']
    # Swap axes 1 and 2, then merge the last two axes into one.
    # NOTE(review): the recorded run shows a resulting shape of (1, 50, 6),
    # i.e. (batch, num_particles, in_features) - confirm for other datasets.
    readout = torch.transpose(readout, 1, 2).flatten(-2, -1)
    # Cast to float16 before handing the array to ONNX Runtime.
    # NOTE(review): presumably required by the half-precision model (mlp_half.onnx) - confirm.
    sample = readout.detach().to(torch.float16).cpu().numpy()
    inputs.append(sample)
    # Stop right after the last wanted batch so no extra batch is loaded.
    if len(inputs) >= samples_to_collect:
        break

if not inputs:
    raise RuntimeError(f'Failed to collect samples from {DATA_ROOT} (processed 0 batches).')

print(f'Collected {len(inputs)} samples (requested {NUM_SAMPLES}) with shape {inputs[0].shape} for benchmarking.')


Collected 10 samples (requested 10) with shape (1, 50, 6) for benchmarking.


In [4]:
def benchmark_cpu(inputs, num_threads, onnx_path, warmup_runs=0):
    '''Run the ONNX model on CPU and capture timing breakdowns.

    Parameters
    ----------
    inputs : list[np.ndarray]
        List of input batches shaped like (1, num_particles, in_features).
        Must contain at least one element.
    num_threads : int
        Number of CPU threads to allow ONNX Runtime to use (intra-op).
    onnx_path : Path
        Path to the ONNX model.
    warmup_runs : int, optional
        Number of untimed inferences on ``inputs[0]`` executed before the
        timed loop. Defaults to 0, which reproduces the original behaviour
        where the first timed call also carries any first-call overhead.

    Returns
    -------
    dict
        One row of results keyed by column label. Totals are summed over all
        inputs; the "per input" entries divide by ``len(inputs)``. The host
        totals include the one-off session creation time.

    Raises
    ------
    ValueError
        If ``inputs`` is empty (per-input averages would be undefined).
    '''
    # Validate before creating the session so an empty batch list fails fast
    # with a clear message instead of a ZeroDivisionError after the model load.
    samples = len(inputs)
    if samples == 0:
        raise ValueError('benchmark_cpu requires at least one input sample.')

    session_options = onnxruntime.SessionOptions()
    session_options.intra_op_num_threads = num_threads
    # Sequential execution with a single inter-op thread: only the intra-op
    # thread count varies between benchmark runs.
    session_options.inter_op_num_threads = 1
    session_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL

    # Session creation time is reported under the 'weights/bias transfer (ms)' label.
    t0 = time.perf_counter()
    session = onnxruntime.InferenceSession(str(onnx_path), sess_options=session_options, providers=['CPUExecutionProvider'])
    load_ms = (time.perf_counter() - t0) * 1e3

    input_name = session.get_inputs()[0].name

    # Optional untimed runs so one-off initialisation does not skew the averages.
    for _ in range(warmup_runs):
        session.run(None, {input_name: inputs[0]})

    input_transfer_ms = 0.0
    output_read_ms = 0.0
    graph_exec_ms = 0.0

    for sample in inputs:
        # "Input transfer" here covers only building the feed dictionary.
        t_in_start = time.perf_counter()
        ort_inputs = {input_name: sample}
        t_in_end = time.perf_counter()
        outputs = session.run(None, ort_inputs)
        t_run_end = time.perf_counter()
        # "Output read" here covers only indexing the returned output list.
        _ = outputs[0]
        t_out_end = time.perf_counter()

        input_transfer_ms += (t_in_end - t_in_start) * 1e3
        graph_exec_ms += (t_run_end - t_in_end) * 1e3
        output_read_ms += (t_out_end - t_run_end) * 1e3

    graph_exec_per_input_ms = graph_exec_ms / samples
    # The host total amortises the one-off session load across all inputs.
    host_total_overall_ms = load_ms + input_transfer_ms + output_read_ms + graph_exec_ms
    host_total_per_input_ms = host_total_overall_ms / samples

    return {
        'threads': num_threads,
        'weights/bias transfer (ms)': load_ms,
        'no. of inputs': samples,
        'input transfer (ms)': input_transfer_ms,
        'output read (ms)': output_read_ms,
        'graph exec (total ms)': graph_exec_ms,
        'graph exec per input (ms)': graph_exec_per_input_ms,
        'host total per input (ms)': host_total_per_input_ms,
        'host total (overall ms)': host_total_overall_ms,
    }


In [15]:
# Run the CPU benchmark once for each configured thread count.
results = [
    benchmark_cpu(inputs, num_threads=thread_count, onnx_path=ONNX_MODEL_PATH)
    for thread_count in THREAD_COUNTS
]
# One row per run, indexed by thread count; the bare last expression renders the table.
df_results = pd.DataFrame(results).set_index(keys='threads')
df_results

Unnamed: 0_level_0,weights/bias transfer (ms),no. of inputs,input transfer (ms),output read (ms),graph exec (total ms),graph exec per input (ms),host total per input (ms),host total (overall ms)
threads,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
1,6.491118,10,0.003276,0.008936,11.722563,1.172256,1.822589,18.225893
2,5.169051,10,0.002765,0.008943,13.673847,1.367385,1.885461,18.854606
3,5.684597,10,0.003606,0.009635,13.34278,1.334278,1.904062,19.040618
4,3.477669,10,0.002243,0.008062,10.682279,1.068228,1.417025,14.170253
5,3.521425,10,0.002284,0.007663,10.61081,1.061081,1.414218,14.142183
6,3.573643,10,0.002693,0.009264,11.925808,1.192581,1.551141,15.511408
7,3.701074,10,0.002625,0.008952,11.525888,1.152589,1.523854,15.238539
8,4.410831,10,0.002733,0.007693,13.262958,1.326296,1.768421,17.684215
