# Bottleneck transmission benchmark

This notebook defines a single-pass benchmark and runs randomized, uniform sampling across all implementations in `src/impls` for the configured packet size and deletion probability. Results are collected in `results_table` with columns `impl`, `payload_bitsize`, `n_received`.


### Collect samples

In [None]:
import numpy as np
from impls._interface import Deletion, Protocol


def benchmark_single_pass(
    payload_bitsize: int,
    protocol: Protocol,
    max_iters: int = 10000,
    deletion_prob: float = 0.1,
):
    data = np.random.randint(0, 2, size=payload_bitsize, dtype=np.bool_)

    sampler = protocol.make_sampler(data)
    estimator = protocol.make_estimator()

    packets_until_reconstructed = 0

    try:
        prev_was_deletion = False

        progress = next(estimator)

        while True:
            try:
                packet = next(sampler)
            except StopIteration:
                break

            if np.random.rand() < deletion_prob:
                if not prev_was_deletion:
                    progress = estimator.send(Deletion)
                prev_was_deletion = True
            else:
                progress = estimator.send(packet)
                prev_was_deletion = False

            packets_until_reconstructed += 1
            assert (
                packets_until_reconstructed <= max_iters
            ), f"Exceeded max iters at progress: {progress*100.0:.2f}%"

    except StopIteration as e:
        assert np.array_equal(e.value, data)
        return packets_until_reconstructed  # number of packets received until reconstructed

    assert False

In [18]:
# Collect samples

import importlib
import random
import pandas as pd
from pathlib import Path
from tqdm import tqdm

from impls._interface import Config

PACKET_BITSIZE = 5
DELETION_PROB = 0.1
SAMPLING_DENSITY = 100.0

# Discover impl modules (exclude private _*.py).
impl_dir = Path("src/impls")
if not impl_dir.exists():
    impl_dir = Path("impls")

impl_names = sorted(p.stem for p in impl_dir.glob("*.py") if not p.name.startswith("_"))

# Build payload-size ranges per impl for uniform sampling.
# At the same time estimate how many to sample
samples = 0
impl_specs = []
for name in impl_names:
    module = importlib.import_module(f"impls.{name}")
    max_payload = int(module.max_payload_bitsize(PACKET_BITSIZE))
    start = PACKET_BITSIZE + 1
    if max_payload < start:
        continue
    payload_sizes = [start, max_payload]
    samples += max_payload - start + 1
    impl_specs.append((name, module, payload_sizes))


impls_count = len(impl_specs)
if impls_count == 0:
    raise RuntimeError("No eligible implementations found for this packet size.")

samples = int(samples * SAMPLING_DENSITY)

# Cache protocol instances per (impl, payload_bitsize).
protocol_cache = {}


def get_protocol(name, module, payload_bitsize):
    key = (name, payload_bitsize)
    protocol = protocol_cache.get(key)
    if protocol is None:
        config = Config(packet_bitsize=PACKET_BITSIZE, payload_bitsize=payload_bitsize)
        protocol = module.create_protocol(config)
        protocol_cache[key] = protocol
    return protocol


# Each row samples an impl and payload size uniformly from available options.
rows = []
for _ in tqdm(range(samples), desc="Benchmarking"):
    name, module, payload_sizes = random.choice(impl_specs)
    payload_bitsize = random.randint(payload_sizes[0], payload_sizes[1])
    protocol = get_protocol(name, module, payload_bitsize)
    packets_until_reconstructed = benchmark_single_pass(
        payload_bitsize, protocol, deletion_prob=DELETION_PROB
    )
    rows.append([name, payload_bitsize, packets_until_reconstructed])


results_table = pd.DataFrame(
    rows, columns=["impl", "payload_bitsize", "packets_until_reconstructed"]
)

del rows

results_table

Benchmarking:   0%|          | 0/16900 [00:00<?, ?it/s]

Benchmarking: 100%|██████████| 16900/16900 [00:25<00:00, 675.56it/s]


Unnamed: 0,impl,payload_bitsize,packets_until_reconstructed
0,systematic,16,31
1,lt,13,6
2,chain,77,51
3,lt,16,22
4,systematic,12,4
...,...,...,...
16895,systematic,14,22
16896,chain,89,37
16897,lt,6,1
16898,systematic,10,4


### Plot results

In [19]:
import plotly.express as px
import plotly.graph_objects as go

summary = results_table.groupby(["impl", "payload_bitsize"], as_index=False).agg(
    mean_packets=("packets_until_reconstructed", "mean"),
    std_packets=("packets_until_reconstructed", "std"),
)

impls = sorted(summary["impl"].unique())
palette = px.colors.qualitative.D3
color_map = {impl: palette[i % len(palette)] for i, impl in enumerate(impls)}


def to_rgba(color, alpha):
    if color.startswith("rgb"):
        values = color[color.find("(") + 1 : color.find(")")].split(",")
        r, g, b = [int(v.strip()) for v in values]
    else:
        hex_color = color.lstrip("#")
        r = int(hex_color[0:2], 16)
        g = int(hex_color[2:4], 16)
        b = int(hex_color[4:6], 16)
    return f"rgba({r},{g},{b},{alpha})"


fig = go.Figure()

for impl in impls:
    df_impl = summary[summary["impl"] == impl].sort_values("payload_bitsize")
    x = df_impl["payload_bitsize"]
    mean = df_impl["mean_packets"]
    std = df_impl["std_packets"].fillna(0)
    color = color_map[impl]
    band_2 = to_rgba(color, 0.12)
    band_1 = to_rgba(color, 0.24)

    upper_2 = mean + 2 * std
    lower_2 = mean - 2 * std
    upper_1 = mean + std
    lower_1 = mean - std

    fig.add_trace(
        go.Scatter(
            x=x,
            y=upper_2,
            mode="lines",
            line=dict(width=0),
            legendgroup=impl,
            showlegend=False,
            hoverinfo="skip",
        )
    )
    fig.add_trace(
        go.Scatter(
            x=x,
            y=lower_2,
            mode="lines",
            line=dict(width=0),
            fill="tonexty",
            fillcolor=band_2,
            legendgroup=impl,
            showlegend=False,
            hoverinfo="skip",
        )
    )
    fig.add_trace(
        go.Scatter(
            x=x,
            y=upper_1,
            mode="lines",
            line=dict(width=0),
            legendgroup=impl,
            showlegend=False,
            hoverinfo="skip",
        )
    )
    fig.add_trace(
        go.Scatter(
            x=x,
            y=lower_1,
            mode="lines",
            line=dict(width=0),
            fill="tonexty",
            fillcolor=band_1,
            legendgroup=impl,
            showlegend=False,
            hoverinfo="skip",
        )
    )
    fig.add_trace(
        go.Scatter(
            x=x,
            y=mean,
            mode="lines+markers",
            name=impl,
            legendgroup=impl,
            showlegend=True,
            line=dict(color=color),
            marker=dict(color=color, size=6),
        )
    )

fig.update_layout(
    title=f"Convergence (packet_bitsize={PACKET_BITSIZE}, deletion_prob={DELETION_PROB})",
    xaxis_title="payload_bitsize",
    yaxis_title="packets_until_reconstructed",
    legend=dict(groupclick="togglegroup"),
)
fig.show()