# Transmission benchmark

This notebook defines a single-pass benchmark and runs randomized, uniform sampling across all implementations in `src/impls` for the configured packet size and deletion probability. Results are collected in `results_table` with columns `impl`, `message_bitsize`, `packets_until_reconstructed`, `time`, `sender_time`, `receiver_time`.


### Collect samples

In [3]:
import importlib
from typing import Tuple
from joblib import Parallel, delayed
import math
import time
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm
import numpy as np

from impls._interface import Config, Deletion, Protocol


# -----------------------------
# Config
# -----------------------------
packet_bitsize = 5
GILBERT_ELIOTT = (0.01, 0.99, 1.0, 0.0)  # P(G->B), P(B->G), P(1 | G), P(1 | B)
SAMPLING_DENSITY = 100.0
CHUNK_SIZE = 100  # "emit an event every 100 samples" -> one task per 100
N_WORKERS = -1  # -1 = use all cores; or set int like 8


# -----------------------------
# Benchmark primitive
# -----------------------------
def benchmark_single_pass(
    name: str,
    message_bitsize: int,
    protocol: Protocol,
    gilbert_eliott_k: Tuple[float, float, float, float],  # pGB, pBG, pG, pB
    max_iters: int = 10000,
):
    pGB, pBG, pG, pB = gilbert_eliott_k

    start_time = time.perf_counter()
    sender_time = 0.0
    receiver_time = 0.0

    data = np.random.randint(0, 2, size=message_bitsize, dtype=np.bool_)

    t = time.perf_counter()
    sampler = protocol.make_sampler(data)
    sender_time += time.perf_counter() - t

    t = time.perf_counter()
    estimator = protocol.make_estimator()
    receiver_time += time.perf_counter() - t

    def next_sample():
        nonlocal sender_time
        t = time.perf_counter()
        try:
            return next(sampler)
        except StopIteration:
            raise RuntimeError("Sampler exhausted")
        finally:
            sender_time += time.perf_counter() - t

    packets_until_reconstructed = 0

    try:
        # Simulate transmission with skips
        prev_was_deletion = True
        t = time.perf_counter()
        progress = next(estimator)
        receiver_time += time.perf_counter() - t
        is_in_good_state = False
        while True:
            packet = next_sample()

            obs_p = pG if is_in_good_state else pB
            if np.random.rand() > obs_p:
                if not prev_was_deletion:
                    t = time.perf_counter()
                    progress = estimator.send(Deletion)
                    receiver_time += time.perf_counter() - t
                prev_was_deletion = True
            else:
                t = time.perf_counter()
                progress = estimator.send(packet)
                receiver_time += time.perf_counter() - t
                prev_was_deletion = False

            trans_p = pGB if is_in_good_state else pBG

            if np.random.rand() < trans_p:
                is_in_good_state = not is_in_good_state

            packets_until_reconstructed += 1
            if packets_until_reconstructed > max_iters:
                raise RuntimeError(
                    f"Protocol {name} exceeded max iters at {progress} on {message_bitsize} message bits"
                )

    except StopIteration as e:
        total_time = time.perf_counter() - start_time

        recovered = e.value

        # Robust comparison for numpy bool arrays (and avoids "truth value is ambiguous" errors)
        is_correct = (
            isinstance(recovered, np.ndarray)
            and recovered.dtype == np.bool_
            and recovered.shape == data.shape
            and np.array_equal(recovered, data)
        )

        # Keep incorrect reconstructions to compute failure rates downstream.
        return (
            packets_until_reconstructed,
            total_time,
            sender_time,
            receiver_time,
            0.0 if is_correct else 1.0,
        )

    raise RuntimeError("Unreachable")


# -----------------------------
# Discover implementations (PASS ONLY STRINGS/RANGES)
# -----------------------------
impl_dir = Path("src/impls")
if not impl_dir.exists():
    impl_dir = Path("impls")

impl_names = sorted(p.stem for p in impl_dir.glob("*.py") if not p.name.startswith("_"))

base_samples = 0
impl_specs = []  # list[(impl_name: str, (lo, hi))]
for name in impl_names:
    module = importlib.import_module(f"impls.{name}")
    max_payload = int(module.max_message_bitsize(packet_bitsize))
    start = packet_bitsize + 1
    if max_payload < start:
        continue
    impl_specs.append((name, (start, max_payload)))
    base_samples += max_payload - start + 1

if not impl_specs:
    raise RuntimeError("No eligible implementations found for this packet size.")

samples = int(base_samples * SAMPLING_DENSITY)
n_tasks = math.ceil(samples / CHUNK_SIZE)


# -----------------------------
# One task = one chunk of samples
# (imports modules inside, maintains per-process protocol cache)
# -----------------------------
def run_chunk(n_in_chunk: int, seed: int, impl_specs_local):
    import importlib
    import random
    import numpy as np

    random.seed(seed)
    np.random.seed(seed)

    protocol_cache = {}
    module_cache = {}

    def get_module(impl_name: str):
        m = module_cache.get(impl_name)
        if m is None:
            m = importlib.import_module(f"impls.{impl_name}")
            module_cache[impl_name] = m
        return m

    def get_protocol(impl_name: str, message_bitsize: int):
        key = (impl_name, message_bitsize)
        p = protocol_cache.get(key)
        if p is None:
            module = get_module(impl_name)
            config = Config(
                packet_bitsize=packet_bitsize, message_bitsize=message_bitsize
            )
            p = module.create_protocol(config)
            protocol_cache[key] = p
        return p

    impl_specs_local_weights = [hi - lo + 1 for _, (lo, hi) in impl_specs_local]

    rows = []
    for _ in range(n_in_chunk):
        impl_name, (lo, hi) = random.choices(
            impl_specs_local, k=1, weights=impl_specs_local_weights
        )[0]
        message_bitsize = random.randint(lo, hi)
        protocol = get_protocol(impl_name, message_bitsize)

        packets, total_time, sender_time, receiver_time, is_correct = (
            benchmark_single_pass(
                impl_name,
                message_bitsize,
                protocol,
                GILBERT_ELIOTT,
            )
        )
        rows.append(
            [
                impl_name,
                message_bitsize,
                packets,
                total_time,
                sender_time,
                receiver_time,
                is_correct,
            ]
        )

    return rows


# -----------------------------
# Launch: tqdm advances once per finished chunk
# -----------------------------

# Make task sizes: mostly CHUNK_SIZE, last one smaller
task_sizes = [CHUNK_SIZE] * (n_tasks - 1) + [samples - CHUNK_SIZE * (n_tasks - 1)]
seeds = np.random.SeedSequence(12345).spawn(n_tasks)  # reproducible-ish across runs


results_chunks = Parallel(n_jobs=N_WORKERS, prefer="processes")(
    delayed(run_chunk)(task_sizes[i], int(seeds[i].generate_state(1)[0]), impl_specs)
    for i in tqdm(range(n_tasks), desc="Benchmarking", total=n_tasks)
)

# -----------------------------
# Collect into pd data frame
# -----------------------------

all_rows = [row for chunk in results_chunks for row in chunk]

results_table = pd.DataFrame(
    all_rows,
    columns=[
        "impl",
        "message_bitsize",
        "packets_until_reconstructed",
        "time",
        "sender_time",
        "receiver_time",
        "is_incorrect_reconstruction",
    ],
)

results_table

Benchmarking:   0%|          | 0/489 [00:00<?, ?it/s]

Unnamed: 0,impl,message_bitsize,packets_until_reconstructed,time,sender_time,receiver_time,is_incorrect_reconstruction
0,chain_dl2,20,7,1.075167,1.046067,0.000148,0.0
1,chain_2n,29,7,0.083952,0.037732,0.000188,0.0
2,chain_dl2,71,24,0.001238,0.000511,0.000419,0.0
3,chain_dl2,55,18,0.000964,0.000386,0.000350,0.0
4,chain_2n,22,6,0.000669,0.000336,0.000171,0.0
...,...,...,...,...,...,...,...
48895,chain_2n,81,24,0.000676,0.000277,0.000257,0.0
48896,chain_dl2,33,11,0.000330,0.000149,0.000093,0.0
48897,chain_dl2,127,44,0.000811,0.000330,0.000296,0.0
48898,chain_2n,139,52,0.001407,0.000457,0.000643,0.0


### Plot results

In [4]:
import importlib
from pathlib import Path
import sys
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# =========================
# Config
# =========================

SAMPLES = False  # True => add "(samples)" scatter traces; False => omit them entirely

BANDS = [0.25, 0.5]

# Kernel bandwidth measured in "message_bitsize units"
# Higher => more smoothing (wider neighborhood)
KERNEL_BW = 0.0

PACKETS_COL = "packets_until_reconstructed"
FAILURE_COL = "is_incorrect_reconstruction"
TIME_COLS = [
    ("time", "Total time"),
    ("sender_time", "Sender time"),
    ("receiver_time", "Receiver time"),
]


# =========================
# Colors
# =========================

palette = px.colors.qualitative.D3
impls = sorted(results_table["impl"].unique())
color_map = {impl: palette[i % len(palette)] for i, impl in enumerate(impls)}


def to_rgba(color, alpha):
    if color.startswith("rgb"):
        r, g, b = map(int, color[color.find("(") + 1 : color.find(")")].split(","))
    else:
        c = color.lstrip("#")
        r, g, b = int(c[0:2], 16), int(c[2:4], 16), int(c[4:6], 16)
    return f"rgba({r},{g},{b},{alpha})"


# =========================
# Weighted stats
# =========================


def gaussian_kernel(dx, bw):
    # If bw <= 0 => no smoothing: only exact matches get weight 1
    if bw is None or bw <= 0:
        return (dx == 0).astype(float)
    return np.exp(-0.5 * (dx / bw) ** 2)


def weighted_mean(y, w):
    sw = w.sum()
    if sw <= 0:
        return np.nan
    return (w * y).sum() / sw


def weighted_quantile(y, w, q):
    """
    Weighted quantile for q in [0,1].
    Uses the standard definition based on cumulative normalized weights.
    """
    y = np.asarray(y, dtype=float)
    w = np.asarray(w, dtype=float)

    mask = np.isfinite(y) & np.isfinite(w) & (w > 0)
    y = y[mask]
    w = w[mask]

    if len(y) == 0:
        return np.nan

    order = np.argsort(y)
    y = y[order]
    w = w[order]

    cw = np.cumsum(w)
    total = cw[-1]
    if total <= 0:
        return np.nan

    t = q * total
    idx = np.searchsorted(cw, t, side="left")
    idx = min(max(idx, 0), len(y) - 1)
    return y[idx]


# =========================
# Kernel-smoothed asymmetric bands from neighborhood samples
# =========================


def compute_kernel_asymmetric_bands(
    raw_df, bands, bw, value_col, min_total_weight=1e-12
):
    rows = []

    for impl, g in raw_df.groupby("impl"):
        # Evaluate summaries at observed payload values (unique x)
        x_grid = np.sort(g["message_bitsize"].unique().astype(float))

        if bw is None or bw <= 0:
            # -------------------------
            # No smoothing mode: per-x distributions
            # -------------------------
            for x0 in x_grid:
                gg = g[g["message_bitsize"].astype(float) == x0]
                y = gg[value_col].to_numpy(dtype=float)
                y = y[np.isfinite(y)]
                if y.size == 0:
                    continue

                m = float(np.mean(y))
                if not np.isfinite(m):
                    continue

                row = {"impl": impl, "message_bitsize": x0, "mean": m}

                if bands:
                    ya = y[y >= m]
                    yb = y[y <= m]

                    # Unweighted quantiles (consistent with weighted_quantile API via w=1)
                    wa = np.ones(ya.size, dtype=float)
                    wb = np.ones(yb.size, dtype=float)

                    for p in bands:
                        up = weighted_quantile(ya, wa, p) if ya.size > 0 else m
                        lo = weighted_quantile(yb, wb, 1 - p) if yb.size > 0 else m

                        if not np.isfinite(up):
                            up = m
                        if not np.isfinite(lo):
                            lo = m

                        row[f"upper_{p}"] = max(up, m)
                        row[f"lower_{p}"] = min(lo, m)

                rows.append(row)

            continue  # done with this impl

        # -------------------------
        # Smoothed mode (original behavior)
        # -------------------------
        x_all = g["message_bitsize"].to_numpy(dtype=float)
        y_all = g[value_col].to_numpy(dtype=float)

        for x0 in x_grid:
            w = gaussian_kernel(x_all - x0, bw)
            sw = w.sum()

            if sw < min_total_weight:
                continue

            m = weighted_mean(y_all, w)
            if not np.isfinite(m):
                continue

            above_mask = y_all >= m
            below_mask = y_all <= m

            ya, wa = y_all[above_mask], w[above_mask]
            yb, wb = y_all[below_mask], w[below_mask]

            row = {"impl": impl, "message_bitsize": x0, "mean": m}

            for p in bands:
                up = weighted_quantile(ya, wa, p) if wa.sum() > 0 else m
                lo = weighted_quantile(yb, wb, 1 - p) if wb.sum() > 0 else m

                if not np.isfinite(up):
                    up = m
                if not np.isfinite(lo):
                    lo = m

                row[f"upper_{p}"] = max(up, m)
                row[f"lower_{p}"] = min(lo, m)

            rows.append(row)

    return (
        pd.DataFrame(rows)
        .sort_values(["impl", "message_bitsize"])
        .reset_index(drop=True)
    )


def add_kernel_traces(
    fig,
    raw_df,
    bands_df,
    value_col,
    row=1,
    col=1,
    legend_tag="",
    show_legend=True,
):
    for impl in impls:
        color = color_map[impl]
        group_prefix = f"{impl}{legend_tag}"
        show_legend_item = bool(show_legend and row == 1 and col == 1)

        # ---------- raw scatter ----------
        if SAMPLES:
            raw = raw_df[raw_df["impl"] == impl]
            fig.add_trace(
                go.Scatter(
                    x=raw["message_bitsize"],
                    y=raw[value_col],
                    mode="markers",
                    name=f"{impl} (samples)",
                    legendgroup=f"{group_prefix}_scatter",
                    marker=dict(color=to_rgba(color, 0.35), size=5),
                    visible="legendonly",
                    showlegend=show_legend_item,
                ),
                row=row,
                col=col,
            )

        # ---------- kernel bands ----------
        df = bands_df[bands_df["impl"] == impl]
        x = df["message_bitsize"].to_numpy()

        alpha_step = 0.1 / max(1, len(BANDS) - 1)
        for i, p in enumerate(sorted(BANDS, reverse=True)):
            alpha = 0.2 + i * alpha_step
            fillcolor = to_rgba(color, alpha)

            fig.add_trace(
                go.Scatter(
                    x=x,
                    y=df[f"upper_{p}"],
                    mode="lines",
                    line=dict(width=0),
                    legendgroup=group_prefix,
                    showlegend=False,
                    hoverinfo="skip",
                ),
                row=row,
                col=col,
            )
            fig.add_trace(
                go.Scatter(
                    x=x,
                    y=df[f"lower_{p}"],
                    mode="lines",
                    line=dict(width=0),
                    fill="tonexty",
                    fillcolor=fillcolor,
                    legendgroup=group_prefix,
                    showlegend=False,
                    hoverinfo="skip",
                ),
                row=row,
                col=col,
            )

        # ---------- kernel mean ----------
        fig.add_trace(
            go.Scatter(
                x=x,
                y=df["mean"],
                mode="lines+markers",
                name=impl,
                legendgroup=group_prefix,
                line=dict(color=color),
                marker=dict(color=color, size=6),
                showlegend=show_legend_item,
            ),
            row=row,
            col=col,
        )


def add_kernel_mean_traces(
    fig,
    raw_df,
    bands_df,
    value_col,
    row=1,
    col=1,
    legend_tag="",
    show_legend=True,
):
    for impl in impls:
        color = color_map[impl]
        group_prefix = f"{impl}{legend_tag}"
        show_legend_item = bool(show_legend and row == 1 and col == 1)

        # ---------- raw scatter ----------
        if SAMPLES:
            raw = raw_df[raw_df["impl"] == impl]
            fig.add_trace(
                go.Scatter(
                    x=raw["message_bitsize"],
                    y=raw[value_col],
                    mode="markers",
                    name=f"{impl} (samples)",
                    legendgroup=f"{group_prefix}_scatter",
                    marker=dict(color=to_rgba(color, 0.35), size=5),
                    visible="legendonly",
                    showlegend=show_legend_item,
                ),
                row=row,
                col=col,
            )

        # ---------- kernel mean ----------
        df = bands_df[bands_df["impl"] == impl]
        x = df["message_bitsize"].to_numpy()
        fig.add_trace(
            go.Scatter(
                x=x,
                y=df["mean"],
                mode="lines+markers",
                name=impl,
                legendgroup=group_prefix,
                line=dict(color=color),
                marker=dict(color=color, size=6),
                showlegend=show_legend_item,
            ),
            row=row,
            col=col,
        )


def ensure_impls_on_path():
    impl_dir = Path("src/impls")
    if impl_dir.exists():
        root = impl_dir.parent
    else:
        impl_dir = Path("impls")
        root = impl_dir.parent if impl_dir.exists() else None

    if root is not None:
        root_str = str(root.resolve())
        if root_str not in sys.path:
            sys.path.insert(0, root_str)
    return impl_dir if impl_dir.exists() else None


def load_impl_module(impl, impl_dir):
    importlib.invalidate_caches()
    module_name = f"impls.{impl}"
    try:
        module = importlib.import_module(module_name)
        return importlib.reload(module)
    except ModuleNotFoundError:
        pass

    if impl_dir is None:
        return None

    impl_path = impl_dir / f"{impl}.py"
    if not impl_path.exists():
        return None

    fallback_name = f"impls_{impl}_theory"
    if fallback_name in sys.modules:
        del sys.modules[fallback_name]

    spec = importlib.util.spec_from_file_location(fallback_name, impl_path)
    if spec is None or spec.loader is None:
        return None
    module = importlib.util.module_from_spec(spec)
    sys.modules[fallback_name] = module
    spec.loader.exec_module(module)
    return module


def iter_theory_series(raw_df):
    impl_dir = ensure_impls_on_path()
    for impl in impls:
        module = load_impl_module(impl, impl_dir)
        if module is None:
            continue

        expected_estimate = getattr(
            module, "expected_packets_until_reconstructed", None
        )
        if expected_estimate is None:
            continue

        x = np.sort(raw_df[raw_df["impl"] == impl]["message_bitsize"].unique())
        if x.size == 0:
            continue

        y = np.array(
            [expected_estimate(GILBERT_ELIOTT, packet_bitsize, int(v)) for v in x],
            dtype=float,
        )
        yield impl, x, y


def add_theory_traces(fig, raw_df, row=1, col=1, show_legend=True):
    for impl, x, y in iter_theory_series(raw_df):
        color = color_map[impl]
        show_legend_item = bool(show_legend and row == 1 and col == 1)
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode="lines",
                name=f"{impl} (theory)",
                legendgroup=f"{impl}__theory",
                line=dict(color=color, dash="dot", width=2),
                showlegend=show_legend_item,
            ),
            row=row,
            col=col,
        )


# =========================
# Plot (packets)
# =========================

bands_df = compute_kernel_asymmetric_bands(
    results_table, BANDS, bw=KERNEL_BW, value_col=PACKETS_COL
)

fig = make_subplots(rows=1, cols=1)
add_kernel_traces(fig, results_table, bands_df, PACKETS_COL)
add_theory_traces(fig, results_table)

pGB, pBG, pG, pB = GILBERT_ELIOTT
survival_prob = (pBG * pG + pGB * pB) / (pBG + pGB)
ideal_x = np.sort(results_table["message_bitsize"].unique())
bits_through = ideal_x / survival_prob
ideal_y = bits_through / packet_bitsize
fig.add_trace(
    go.Scatter(
        x=ideal_x,
        y=ideal_y,
        mode="lines",
        name="bound",
        line=dict(color="black", dash="dash"),
        showlegend=True,
    )
)

fig.update_layout(
    title=f"Convergence (packet_bitsize={packet_bitsize}, GE=(pGB={pGB}, pBG={pBG}, pG={pG}, pB={pB}), kernel_bw={KERNEL_BW})",
    xaxis_title="message_bitsize",
    yaxis_title=PACKETS_COL,
    legend=dict(groupclick="togglegroup"),
)

fig.show()
# =========================
# Plot (compute time)
# =========================

time_bands = {
    col: compute_kernel_asymmetric_bands(
        results_table, BANDS, bw=KERNEL_BW, value_col=col
    )
    for col, _ in TIME_COLS
}

fig_time = make_subplots(
    rows=len(TIME_COLS),
    cols=1,
    shared_xaxes=True,
    vertical_spacing=0.08,
    subplot_titles=[title for _, title in TIME_COLS],
)

for i, (col, _) in enumerate(TIME_COLS, start=1):
    add_kernel_traces(
        fig_time,
        results_table,
        time_bands[col],
        col,
        row=i,
        col=1,
        show_legend=True,
    )
    fig_time.update_yaxes(title_text="seconds", row=i, col=1)

fig_time.update_layout(
    title=f"Compute time (packet_bitsize={packet_bitsize}, GE=(pGB={pGB}, pBG={pBG}, pG={pG}, pB={pB}), kernel_bw={KERNEL_BW})",
    height=1200,
    xaxis_title="message_bitsize",
    legend=dict(groupclick="togglegroup"),
)

fig_time.show()

# =========================
# Plot (failure rate)
# =========================

failure_bands = compute_kernel_asymmetric_bands(
    results_table, [], bw=KERNEL_BW, value_col=FAILURE_COL
)

fig_fail = make_subplots(rows=1, cols=1)
add_kernel_mean_traces(fig_fail, results_table, failure_bands, FAILURE_COL)

fig_fail.update_layout(
    title=f"Failure rate (packet_bitsize={packet_bitsize}, GE=(pGB={pGB}, pBG={pBG}, pG={pG}, pB={pB}), kernel_bw={KERNEL_BW})",
    xaxis_title="message_bitsize",
    yaxis_title=FAILURE_COL,
    legend=dict(groupclick="togglegroup"),
)

fig_fail.update_yaxes(range=[0, 1])

fig_fail.show()