# Transmission benchmark

This notebook defines a single-pass benchmark and runs randomized, uniform sampling across all implementations in `src/impls` for the configured packet size and deletion probability. Results are collected in `results_table` with columns `impl`, `message_bitsize`, `packets_until_reconstructed`, `time`, `sender_time`, `receiver_time`.


### Collect samples

In [1]:
import importlib
from random import randint
from joblib import Parallel, delayed
import math
import time
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm
import numpy as np

from impls._interface import Config, Deletion, Protocol


# -----------------------------
# Config
# -----------------------------
packet_bitsize = 7
DELETION_PROB = 0.1
SAMPLING_DENSITY = 100.0
CHUNK_SIZE = 100  # "emit an event every 100 samples" -> one task per 100
N_WORKERS = -1  # -1 = use all cores; or set int like 8


# -----------------------------
# Benchmark primitive
# -----------------------------
def benchmark_single_pass(
    name: str,
    message_bitsize: int,
    protocol: Protocol,
    max_iters: int = 10000,
    deletion_prob: float = 0.1,
):
    start_time = time.perf_counter()
    sender_time = 0.0
    receiver_time = 0.0

    data = np.random.randint(0, 2, size=message_bitsize, dtype=np.bool_)

    t = time.perf_counter()
    sampler = protocol.make_sampler(data)
    sender_time += time.perf_counter() - t

    t = time.perf_counter()
    estimator = protocol.make_estimator()
    receiver_time += time.perf_counter() - t

    def next_sample():
        nonlocal sender_time
        t = time.perf_counter()
        try:
            return next(sampler)
        except StopIteration:
            raise RuntimeError("Sampler exhausted")
        finally:
            sender_time += time.perf_counter() - t

    packets_until_reconstructed = 0

    try:
        # Simulate unknown start
        for _ in range(randint(1, 100)):
            next_sample()

        # Simulate transmission with deletions
        prev_was_deletion = True
        t = time.perf_counter()
        progress = next(estimator)
        receiver_time += time.perf_counter() - t
        while True:
            packet = next_sample()

            if np.random.rand() < deletion_prob:
                if not prev_was_deletion:
                    t = time.perf_counter()
                    progress = estimator.send(Deletion)
                    receiver_time += time.perf_counter() - t
                prev_was_deletion = True
            else:
                t = time.perf_counter()
                progress = estimator.send(packet)
                receiver_time += time.perf_counter() - t
                prev_was_deletion = False

            packets_until_reconstructed += 1
            if packets_until_reconstructed > max_iters:
                raise RuntimeError(
                    f"Protocol {name} exceeded max iters at {progress} on {message_bitsize} message bits"
                )

    except StopIteration as e:
        total_time = time.perf_counter() - start_time

        recovered = e.value

        # Robust comparison for numpy bool arrays (and avoids "truth value is ambiguous" errors)
        ok = (
            isinstance(recovered, np.ndarray)
            and recovered.dtype == np.bool_
            and recovered.shape == data.shape
            and np.array_equal(recovered, data)
        )

        if not ok:
            # Helpful diagnostics (first mismatch, etc.)
            msg = f"Invalid reconstruction for protocol={name}, message_bitsize={message_bitsize}, expected={data}, got={recovered}"
            if isinstance(recovered, np.ndarray) and recovered.shape == data.shape:
                mism = np.flatnonzero(recovered != data)
                if mism.size:
                    i = int(mism[0])
                    msg += f"; first mismatch at bit {i}: got={bool(recovered[i])}, expected={bool(data[i])}"
                    msg += f"; mismatches={int(mism.size)}/{int(data.size)}"
            raise AssertionError(msg)

        return packets_until_reconstructed, total_time, sender_time, receiver_time

    raise RuntimeError("Unreachable")


# -----------------------------
# Discover implementations (PASS ONLY STRINGS/RANGES)
# -----------------------------
impl_dir = Path("src/impls")
if not impl_dir.exists():
    impl_dir = Path("impls")

impl_names = sorted(p.stem for p in impl_dir.glob("*.py") if not p.name.startswith("_"))

base_samples = 0
impl_specs = []  # list[(impl_name: str, (lo, hi))]
for name in impl_names:
    module = importlib.import_module(f"impls.{name}")
    max_payload = int(module.max_message_bitsize(packet_bitsize))
    start = packet_bitsize + 1
    if max_payload < start:
        continue
    impl_specs.append((name, (start, max_payload)))
    base_samples += max_payload - start + 1

if not impl_specs:
    raise RuntimeError("No eligible implementations found for this packet size.")

samples = int(base_samples * SAMPLING_DENSITY)
n_tasks = math.ceil(samples / CHUNK_SIZE)


# -----------------------------
# One task = one chunk of samples
# (imports modules inside, maintains per-process protocol cache)
# -----------------------------
def run_chunk(n_in_chunk: int, seed: int, impl_specs_local):
    import importlib
    import random
    import numpy as np

    random.seed(seed)
    np.random.seed(seed)

    protocol_cache = {}
    module_cache = {}

    def get_module(impl_name: str):
        m = module_cache.get(impl_name)
        if m is None:
            m = importlib.import_module(f"impls.{impl_name}")
            module_cache[impl_name] = m
        return m

    def get_protocol(impl_name: str, message_bitsize: int):
        key = (impl_name, message_bitsize)
        p = protocol_cache.get(key)
        if p is None:
            module = get_module(impl_name)
            config = Config(
                packet_bitsize=packet_bitsize, message_bitsize=message_bitsize
            )
            p = module.create_protocol(config)
            protocol_cache[key] = p
        return p

    impl_specs_local_weights = [hi - lo + 1 for _, (lo, hi) in impl_specs_local]

    rows = []
    for _ in range(n_in_chunk):
        impl_name, (lo, hi) = random.choices(
            impl_specs_local, k=1, weights=impl_specs_local_weights
        )[0]
        message_bitsize = random.randint(lo, hi)
        protocol = get_protocol(impl_name, message_bitsize)

        packets, total_time, sender_time, receiver_time = benchmark_single_pass(
            impl_name,
            message_bitsize,
            protocol,
            deletion_prob=DELETION_PROB,
        )
        rows.append(
            [
                impl_name,
                message_bitsize,
                packets,
                total_time,
                sender_time,
                receiver_time,
            ]
        )

    return rows


# -----------------------------
# Launch: tqdm advances once per finished chunk
# -----------------------------

# Make task sizes: mostly CHUNK_SIZE, last one smaller
task_sizes = [CHUNK_SIZE] * (n_tasks - 1) + [samples - CHUNK_SIZE * (n_tasks - 1)]
seeds = np.random.SeedSequence(12345).spawn(n_tasks)  # reproducible-ish across runs


results_chunks = Parallel(n_jobs=N_WORKERS, prefer="processes")(
    delayed(run_chunk)(task_sizes[i], int(seeds[i].generate_state(1)[0]), impl_specs)
    for i in tqdm(range(n_tasks), desc="Benchmarking", total=n_tasks)
)

Benchmarking:   0%|          | 0/1811 [00:00<?, ?it/s]

In [2]:
all_rows = [row for chunk in results_chunks for row in chunk]

results_table = pd.DataFrame(
    all_rows,
    columns=[
        "impl",
        "message_bitsize",
        "packets_until_reconstructed",
        "time",
        "sender_time",
        "receiver_time",
    ],
)

results_table

Unnamed: 0,impl,message_bitsize,packets_until_reconstructed,time,sender_time,receiver_time
0,chain_r,124,52,0.104554,0.083395,0.000931
1,chain_r,128,34,0.003195,0.001823,0.000688
2,chain_2n_1,38,13,0.092736,0.043665,0.000331
3,chain_2n_1,779,354,0.028239,0.006445,0.008736
4,chain_r,371,128,0.016561,0.007977,0.002156
...,...,...,...,...,...,...
181095,chain_r,387,135,0.006117,0.002953,0.000777
181096,chain_r,774,278,0.033644,0.015505,0.001728
181097,chain_r,799,361,0.040119,0.018958,0.002668
181098,chain_2n_1,278,63,0.001583,0.000627,0.000443


### Plot results

In [3]:
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# =========================
# Config
# =========================

BANDS = [0.25, 0.5]

# Kernel bandwidth measured in "message_bitsize units"
# Higher => more smoothing (wider neighborhood)
KERNEL_BW = 3.0

PACKETS_COL = "packets_until_reconstructed"
TIME_COLS = [
    ("time", "Total time"),
    ("sender_time", "Sender time"),
    ("receiver_time", "Receiver time"),
]


# =========================
# Colors
# =========================

palette = px.colors.qualitative.D3
impls = sorted(results_table["impl"].unique())
color_map = {impl: palette[i % len(palette)] for i, impl in enumerate(impls)}


def to_rgba(color, alpha):
    if color.startswith("rgb"):
        r, g, b = map(int, color[color.find("(") + 1 : color.find(")")].split(","))
    else:
        c = color.lstrip("#")
        r, g, b = int(c[0:2], 16), int(c[2:4], 16), int(c[4:6], 16)
    return f"rgba({r},{g},{b},{alpha})"


# =========================
# Weighted stats
# =========================


def gaussian_kernel(dx, bw):
    # dx: array of (x - x0)
    return np.exp(-0.5 * (dx / bw) ** 2)


def weighted_mean(y, w):
    sw = w.sum()
    if sw <= 0:
        return np.nan
    return (w * y).sum() / sw


def weighted_quantile(y, w, q):
    """
    Weighted quantile for q in [0,1].
    Uses the standard definition based on cumulative normalized weights.
    """
    y = np.asarray(y, dtype=float)
    w = np.asarray(w, dtype=float)

    mask = np.isfinite(y) & np.isfinite(w) & (w > 0)
    y = y[mask]
    w = w[mask]

    if len(y) == 0:
        return np.nan

    order = np.argsort(y)
    y = y[order]
    w = w[order]

    cw = np.cumsum(w)
    total = cw[-1]
    if total <= 0:
        return np.nan

    t = q * total
    idx = np.searchsorted(cw, t, side="left")
    idx = min(max(idx, 0), len(y) - 1)
    return y[idx]


# =========================
# Kernel-smoothed asymmetric bands from neighborhood samples
# =========================


def compute_kernel_asymmetric_bands(
    raw_df, bands, bw, value_col, min_total_weight=1e-12
):
    rows = []

    for impl, g in raw_df.groupby("impl"):
        x_all = g["message_bitsize"].to_numpy(dtype=float)
        y_all = g[value_col].to_numpy(dtype=float)

        # Evaluate summaries at observed payload values (unique x)
        x_grid = np.sort(g["message_bitsize"].unique().astype(float))

        for x0 in x_grid:
            w = gaussian_kernel(x_all - x0, bw)
            sw = w.sum()

            if sw < min_total_weight:
                continue

            m = weighted_mean(y_all, w)
            if not np.isfinite(m):
                continue

            # asymmetric split around the *kernel-weighted* mean at x0
            above_mask = y_all >= m
            below_mask = y_all <= m

            ya, wa = y_all[above_mask], w[above_mask]
            yb, wb = y_all[below_mask], w[below_mask]

            row = {"impl": impl, "message_bitsize": x0, "mean": m}

            for p in bands:
                # upper band: quantile among samples >= mean
                # lower band: quantile among samples <= mean
                up = weighted_quantile(ya, wa, p) if wa.sum() > 0 else m
                lo = weighted_quantile(yb, wb, 1 - p) if wb.sum() > 0 else m

                # Safety: keep ordering sane
                if not np.isfinite(up):
                    up = m
                if not np.isfinite(lo):
                    lo = m

                row[f"upper_{p}"] = max(up, m)
                row[f"lower_{p}"] = min(lo, m)

            rows.append(row)

    return (
        pd.DataFrame(rows)
        .sort_values(["impl", "message_bitsize"])
        .reset_index(drop=True)
    )


def add_kernel_traces(
    fig,
    raw_df,
    bands_df,
    value_col,
    row=1,
    col=1,
    legend_tag="",
    show_legend=True,
):
    for impl in impls:
        color = color_map[impl]
        group_prefix = f"{impl}{legend_tag}"
        show_legend_item = bool(show_legend and row == 1 and col == 1)

        # ---------- raw scatter ----------
        raw = raw_df[raw_df["impl"] == impl]
        fig.add_trace(
            go.Scatter(
                x=raw["message_bitsize"],
                y=raw[value_col],
                mode="markers",
                name=f"{impl} (samples)",
                legendgroup=f"{group_prefix}_scatter",
                marker=dict(color=to_rgba(color, 0.35), size=5),
                visible="legendonly",
                showlegend=show_legend_item,
            ),
            row=row,
            col=col,
        )

        # ---------- kernel bands ----------
        df = bands_df[bands_df["impl"] == impl]
        x = df["message_bitsize"].to_numpy()

        alpha_step = 0.1 / max(1, len(BANDS) - 1)
        for i, p in enumerate(sorted(BANDS, reverse=True)):
            alpha = 0.2 + i * alpha_step
            fillcolor = to_rgba(color, alpha)

            fig.add_trace(
                go.Scatter(
                    x=x,
                    y=df[f"upper_{p}"],
                    mode="lines",
                    line=dict(width=0),
                    legendgroup=group_prefix,
                    showlegend=False,
                    hoverinfo="skip",
                ),
                row=row,
                col=col,
            )
            fig.add_trace(
                go.Scatter(
                    x=x,
                    y=df[f"lower_{p}"],
                    mode="lines",
                    line=dict(width=0),
                    fill="tonexty",
                    fillcolor=fillcolor,
                    legendgroup=group_prefix,
                    showlegend=False,
                    hoverinfo="skip",
                ),
                row=row,
                col=col,
            )

        # ---------- kernel mean ----------
        fig.add_trace(
            go.Scatter(
                x=x,
                y=df["mean"],
                mode="lines+markers",
                name=impl,
                legendgroup=group_prefix,
                line=dict(color=color),
                marker=dict(color=color, size=6),
                showlegend=show_legend_item,
            ),
            row=row,
            col=col,
        )


# =========================
# Plot (packets)
# =========================

bands_df = compute_kernel_asymmetric_bands(
    results_table, BANDS, bw=KERNEL_BW, value_col=PACKETS_COL
)

fig = make_subplots(rows=1, cols=1)
add_kernel_traces(fig, results_table, bands_df, PACKETS_COL)

fig.update_layout(
    title=f"Convergence (packet_bitsize={packet_bitsize}, deletion_prob={DELETION_PROB}, kernel_bw={KERNEL_BW})",
    xaxis_title="message_bitsize",
    yaxis_title=PACKETS_COL,
    legend=dict(groupclick="togglegroup"),
)

fig.show()

# =========================
# Plot (compute time)
# =========================

time_bands = {
    col: compute_kernel_asymmetric_bands(
        results_table, BANDS, bw=KERNEL_BW, value_col=col
    )
    for col, _ in TIME_COLS
}

fig_time = make_subplots(
    rows=len(TIME_COLS),
    cols=1,
    shared_xaxes=True,
    vertical_spacing=0.08,
    subplot_titles=[title for _, title in TIME_COLS],
)

for i, (col, _) in enumerate(TIME_COLS, start=1):
    add_kernel_traces(
        fig_time,
        results_table,
        time_bands[col],
        col,
        row=i,
        col=1,
        show_legend=True,
    )
    fig_time.update_yaxes(title_text="seconds", row=i, col=1)

fig_time.update_layout(
    title=f"Compute time (packet_bitsize={packet_bitsize}, deletion_prob={DELETION_PROB}, kernel_bw={KERNEL_BW})",
    height=1200,
    xaxis_title="message_bitsize",
    legend=dict(groupclick="togglegroup"),
)

fig_time.show()