Imports, function definitions, and some light data loading.

In [None]:
import sys
import time
from importlib import reload
from pathlib import Path

import h5py
import numpy as np
import torch
from bokeh.io import output_notebook, show
from bokeh.plotting import figure
from hermes.aeriel.client import InferenceClient
from hermes.aeriel.serve import serve
from hermes import quiver as qv
from tqdm import tqdm, trange

import aframe.architectures
import export.main
from export.snapshotter import BackgroundSnapshotter, BatchWhitener

# in case you need to make any changes to the code
# and want them reflected without having to restart
# the notebook
# NOTE: `reload` returns the reloaded module object, so after the next
# line the name `export` refers to the `export.main` *module* rather than
# the top-level `export` package. The later call `export.main(...)` is
# therefore an attribute lookup on that module. Re-running this cell is
# safe because `import export.main` above rebinds `export` to the package
# before the reload happens again.
export = reload(export.main)
architectures = reload(aframe.architectures)

# direct Bokeh output to the notebook so that `show` renders inline
output_notebook()

# Directory holding the background HDF5 files used for benchmarking.
DATA_DIR = Path.home() / "bbhnet" / "data" / "test" / "background"
# Names of the per-channel datasets read from each file
# (presumably interferometer identifiers -- confirm against the data).
IFOS = ["H1", "L1"]
SAMPLE_RATE = 2048  # samples per second of the input strain
PSD_LENGTH = 64  # presumably seconds of data used for the PSD estimate
KERNEL_LENGTH = 1.5  # seconds; converted to samples in KERNEL_SIZE below
FDURATION = 1  # seconds; added to the kernel length to get FFTLENGTH
HIGHPASS = 32  # presumably a highpass cutoff in Hz -- TODO confirm

# Quantities derived from the settings above.
NUM_IFOS = len(IFOS)
FFTLENGTH = KERNEL_LENGTH + FDURATION
KERNEL_SIZE = int(KERNEL_LENGTH * SAMPLE_RATE)
SNAPSHOT_SIZE = int(FFTLENGTH * SAMPLE_RATE)

# `iterdir` yields entries in arbitrary order, so sort them to make sure
# the same file is benchmarked on every run and on every machine.
fnames = sorted(DATA_DIR.iterdir())
if not fnames:
    raise FileNotFoundError(f"No background files found in {DATA_DIR}")

# Load one timeseries per entry of IFOS from the first file and stack
# them along a new leading axis, giving an array of shape (NUM_IFOS, ...).
with h5py.File(fnames[0], "r") as f:
    data = np.stack([f[ifo][:] for ifo in IFOS])

# duration of the loaded data in seconds
length = data.shape[-1] / SAMPLE_RATE


@torch.no_grad()
def do_local_inference(nn, batch_size, inference_sampling_rate):
    """Run the snapshot -> whiten -> network pipeline in-process on the GPU.

    The module-level `data` array is fed through the pipeline in
    consecutive, non-overlapping updates and each iteration is timed.

    Args:
        nn: Callable applied to each whitened batch. Its output must
            support `.cpu().numpy()` (e.g. a torch module on "cuda").
        batch_size: Number of inference kernels produced per update.
        inference_sampling_rate: Number of kernels sampled per second
            of data.

    Returns:
        A `(throughput, latencies)` tuple. `throughput` is the number of
        seconds of data processed per second of wall time, and
        `latencies` is a numpy array holding the wall time in seconds
        taken by each update.
    """
    snapshotter = BackgroundSnapshotter(
        PSD_LENGTH,
        KERNEL_LENGTH,
        FDURATION,
        SAMPLE_RATE,
        inference_sampling_rate,
    ).to("cuda")
    whitener = BatchWhitener(
        KERNEL_LENGTH,
        SAMPLE_RATE,
        inference_sampling_rate,
        batch_size,
        FDURATION,
        FFTLENGTH,
        HIGHPASS
    ).to("cuda")

    # the streaming state starts out empty (all zeros)
    state = torch.zeros((1, NUM_IFOS, snapshotter.state_size)).to("cuda")

    # Truncate the per-kernel stride to a whole number of samples *before*
    # scaling by the batch size, exactly as `do_inference` does, so both
    # benchmarks step through the data identically even when SAMPLE_RATE
    # is not an integer multiple of the inference sampling rate.
    step_size = batch_size * int(SAMPLE_RATE / inference_sampling_rate)

    num_steps = (data.shape[-1] - SNAPSHOT_SIZE) // step_size + 1
    start_time = time.time()
    latencies = np.zeros(num_steps)
    for i in trange(num_steps):
        tick = time.time()
        update = data[:, i * step_size: (i + 1) * step_size]
        update = torch.Tensor(update).to("cuda").view(1, NUM_IFOS, -1)

        x, state = snapshotter(update, state)
        x = whitener(x)

        # The prediction itself is not needed here; copying it back to
        # the host keeps the device-to-host transfer inside the timed
        # region (and, for CUDA tensors, waits for the GPU work to finish).
        nn(x).cpu().numpy()

        tock = time.time()
        latencies[i] = tock - tick

    end_time = time.time()

    # Only the samples actually pushed through the loop count towards
    # throughput; the tail of the file that was never consumed does not.
    processed = min(num_steps * step_size, data.shape[-1]) / SAMPLE_RATE
    throughput = processed / (end_time - start_time)
    return throughput, latencies


def configure_deployment(
    batch_size: int,
    inference_sampling_rate: float,
    instances: int
) -> None:
    """Export a ResNet into the local `model_repo` model repository.

    Calls `export.main` with the notebook-level settings, targeting the
    TensorRT platform and passing `clean=True`.

    Args:
        batch_size: Batch size forwarded to the export.
        inference_sampling_rate: Inference sampling rate forwarded to
            the export.
        instances: Forwarded as `aframe_instances` (presumably the
            number of concurrent model instances -- confirm against
            `export.main`).
    """

    def build_architecture(num_ifos):
        return architectures.ResNet(
            num_ifos, layers=[3, 4, 4, 3], norm_groups=16
        )

    export.main(
        build_architecture,
        repository_directory=Path("model_repo"),
        logdir=Path("."),
        num_ifos=NUM_IFOS,
        kernel_length=KERNEL_LENGTH,
        inference_sampling_rate=inference_sampling_rate,
        sample_rate=SAMPLE_RATE,
        batch_size=batch_size,
        fduration=FDURATION,
        psd_length=PSD_LENGTH,
        # reuse the shared constants so that the exported model can never
        # drift away from the settings used by `do_local_inference`
        fftlength=FFTLENGTH,
        highpass=HIGHPASS,
        streams_per_gpu=1,
        aframe_instances=instances,
        platform=qv.Platform.TENSORRT,
        clean=True
    )


def do_inference(batch_size, inference_sampling_rate, max_steps=10000):
    """Stream `data` to a served model and time every request.

    Serves the contents of `model_repo` with the
    `hermes/tritonserver:22.12` image on GPU 0, then sends the
    module-level `data` array to the `aframe-stream` model as a single
    sequence of fixed-size updates, pacing the requests at roughly the
    rate the data would arrive in real time.

    Args:
        batch_size: Number of inference kernels covered by each request.
        inference_sampling_rate: Number of kernels sampled per second
            of data.
        max_steps: Upper bound on the number of requests sent, which
            keeps the run time bounded for long files.

    Returns:
        List of request latencies in seconds (time between sending a
        request and its callback firing), in the order the callbacks
        were invoked.
    """
    step_size = batch_size * int(SAMPLE_RATE / inference_sampling_rate)
    num_steps = (data.shape[-1] - SNAPSHOT_SIZE) // step_size + 1

    num_steps = min(num_steps, max_steps)
    total_time = num_steps * step_size / SAMPLE_RATE

    ctx = serve(
        "model_repo",
        image="hermes/tritonserver:22.12",
        gpus=[0],
        wait=True
    )
    # the progress bar counts seconds of data that have been answered
    pbar = tqdm(total=total_time)
    ticks = {}
    latencies = []

    def callback(y, request_id, sequence_id):
        # record how long ago the matching request was sent; removing the
        # entry from `ticks` doubles as the "response received" signal
        # that the warm-up wait below polls for
        tock = time.time()
        tick = ticks.pop(request_id)
        latencies.append(tock - tick)
        pbar.update(step_size / SAMPLE_RATE)

    with ctx, pbar:
        client = InferenceClient(
            "localhost:8001",
            model_name="aframe-stream",
            model_version=-1,
            callback=callback
        )
        with client:
            for i in range(num_steps):
                ticks[i] = time.time()
                start, stop = i * step_size, (i + 1) * step_size
                update = data[:, start:stop].astype("float32")
                client.infer(
                    update,
                    sequence_id=1001,
                    request_id=i,
                    sequence_start=i == 0,
                    sequence_end=i == (num_steps - 1)
                )

                # Send the first few requests one at a time, blocking
                # until the callback has consumed each response. This
                # relies on the client invoking `callback` concurrently
                # with this loop.
                if i < 5:
                    while i in ticks:
                        time.sleep(0.01)

                # Wait for 80% of the time span covered by one request
                # (presumably leaving headroom for the overhead of the
                # loop itself so that requests go out at about real time).
                time.sleep(0.8 * batch_size / inference_sampling_rate)

    # NOTE(review): responses still in flight when the client context
    # exits may not be reflected in `latencies` -- confirm against the
    # client's shutdown behavior.
    return latencies


def run_expt(batch_size, inference_sampling_rate, instances):
    """Export a model for one configuration, then benchmark it.

    The deployment is (re)configured via `configure_deployment` first;
    the value returned is whatever `do_inference` produces for the same
    batch size and inference sampling rate.
    """
    configure_deployment(
        batch_size=batch_size,
        inference_sampling_rate=inference_sampling_rate,
        instances=instances,
    )
    request_latencies = do_inference(batch_size, inference_sampling_rate)
    return request_latencies

Here's where the actual heavy lifting happens. For each inference sampling rate (with the batch size scaled up alongside it), the loop below exports a model, launches it in a Triton server, then runs inference against it at a "real-time" request rate and measures the end-to-end latency of every request made.

In [None]:
# Maps batch size -> list of per-request latencies from `run_expt`.
latencies = {}
for exponent in range(5):
    # the batch size doubles at every step (1, 2, 4, 8, 16) and the
    # inference sampling rate doubles with it, starting from 2**7 = 128
    batch_size = 1 << exponent
    inference_sampling_rate = 128 << exponent
    latencies[batch_size] = run_expt(
        batch_size, inference_sampling_rate, instances=6
    )
# TODO: might make sense to save these out as you iterate on the plotting code

Now plot all these measurements

In [None]:
# Figure with a log-scaled x axis; labels and title use MathText strings.
p = figure(
    height=300,
    width=500,
    x_axis_label=r"$$\text{Predictions per second [Hz]}$$",
    y_axis_label=r"$$\text{Expected trigger latency [s]}$$",
    x_axis_type="log",
    title=r"$$\text{Expected latency vs. predictive resolution}$$"
)
p.toolbar_location = None
p.title.text_font_style = "normal"

for batch_size, measurements in latencies.items():
    # the sweep that filled `latencies` used a sampling rate of
    # 2**7 times the batch size, so recover it from the key
    inference_sampling_rate = 2**7 * batch_size

    # only take the last few thousand samples to allow
    # server to stabilize after some burn-in
    measurements = np.array(measurements[-8000:])

    # add 1.5 to account for integration and whitening
    # filter settle-in, then account for the average
    # delay encountered by batch samples for having to
    # wait for a full batch to generate (i.e. relative
    # to when we _could_ have gotten an inference response
    # if we were doing batch 1 inference)
    batching_delay = (batch_size - 1) / 2 / inference_sampling_rate
    latency = 1.5 + measurements + batching_delay
    low, mid, high = np.percentile(latency, [5, 50, 95])

    # only the batch size 1 glyphs get legend entries so that
    # each label shows up in the legend exactly once
    labelled = batch_size == 1
    interval_kwargs = {"legend_label": "90% interval"} if labelled else {}
    median_kwargs = {"legend_label": "Median"} if labelled else {}

    # vertical bar spanning the 5th to 95th percentile
    p.line(
        [inference_sampling_rate, inference_sampling_rate],
        [low, high],
        line_color="#555555",
        line_width=1.2,
        **interval_kwargs
    )

    # horizontal caps at both ends of the bar; the multiplicative
    # offsets keep them symmetric on the log-scaled axis
    for val in [low, high]:
        p.line(
            [inference_sampling_rate * 0.98, inference_sampling_rate / 0.98],
            [val, val],
            line_color="#555555",
            line_width=1.2,
        )

    # `scatter` with a circle marker draws the same screen-sized marker
    # as `circle(size=...)`, which recent Bokeh releases deprecate
    p.scatter(
        inference_sampling_rate,
        mid,
        marker="circle",
        fill_alpha=0.5,
        size=9,
        **median_kwargs
    )
p.legend.location = "bottom_right"
show(p)