In [1]:
import random
from collections import defaultdict
from typing import Callable

import numpy as np
from tqdm import tqdm

from interface import Producer, Recoverer

def test(
    producer_constructor: Callable[[int, int, int], Producer],
    recoverer_constructor: Callable[[int, int], Recoverer],
    N: int,
    D: int,
    passes: int = 100,
    sample_burst_size: Callable[[], int] = lambda: random.randint(1, 10),
    sample_data_size: Callable[[], int] = lambda: random.randint(3, 30),
    iters_bound: int = 10000,
    failure_report_limit: int = 3,
) -> dict:
    """Measure the empirical distribution of the time needed to recover a value.

    Each pass draws a random integer ``data`` in ``[0, D - 1]``, builds a fresh
    producer and recoverer, and then alternates between two phases until the
    recoverer returns a non-``None`` value or ``iters_bound`` rounds elapse:

    1. a burst of ``sample_burst_size()`` generated samples that are discarded
       (the recoverer never sees them, but they still count as elapsed time);
    2. a run of up to ``sample_data_size()`` samples that are fed to the
       recoverer one by one, followed by ``recoverer.feed(None)`` to signal
       the end of the contiguous run.

    Args:
        producer_constructor: Called as ``producer_constructor(data, N, D)``;
            the result must expose ``generate()``.
        recoverer_constructor: Called as ``recoverer_constructor(N, D)``; the
            result must expose ``feed(sample)``, which is treated as "not yet
            recovered" while it returns ``None``.
        N: Passed through unchanged to both constructors.
        D: Exclusive upper bound of the random value; also passed through.
        passes: Number of independent trials.
        sample_burst_size: Draws the number of skipped samples per round.
        sample_data_size: Draws the maximum number of fed samples per round.
        iters_bound: Maximum number of skip/feed rounds per pass.
        failure_report_limit: How many failed passes are reported individually
            before further failure messages are suppressed.

    Returns:
        A dict mapping ``time_to_recover`` (generated samples, skipped and fed)
        to its relative frequency among the *successful* passes. It is empty
        when no pass succeeded.

    Raises:
        AssertionError: If a recoverer returns a value different from ``data``.
    """
    recover_time_counts = defaultdict(int)
    failures = 0

    for _ in tqdm(range(passes)):
        # Draw the random integer value the recoverer has to reconstruct.
        data = random.randint(0, D - 1)
        recovered = None

        # Fresh producer/recoverer pair for every pass.
        producer = producer_constructor(data, N, D)
        recoverer = recoverer_constructor(N, D)
        time_to_recover = 0

        for _ in range(iters_bound):
            # Skipped samples: generated but never shown to the recoverer.
            burst_size = sample_burst_size()
            time_to_recover += burst_size
            for _ in range(burst_size):
                producer.generate()

            # Contiguous samples that the recoverer does get to see.
            for _ in range(sample_data_size()):
                sample = producer.generate()
                time_to_recover += 1
                recovered = recoverer.feed(sample)
                if recovered is not None:
                    break

            if recovered is not None:
                break

            recoverer.feed(None)  # Indicate end of continuity

        if recovered is None:
            # Report the first few failures individually, then announce once
            # that reporting stops. Testing itself always continues.
            if failures < failure_report_limit:
                print(f"Failed to recover {data} within iteration bound")
            elif failures == failure_report_limit:
                print("Suppressing further failure reports due to repeated failures.")
            failures += 1
            continue

        # Check correctness
        assert np.array_equal(recovered, data), f"Expected {data}, got {recovered}"
        recover_time_counts[time_to_recover] += 1

    print(f"{failures} failures.")

    # Normalize over successful passes only. When there are none the counter
    # is empty, so the division below is never evaluated.
    successes = passes - failures
    return {t: count / successes for t, count in recover_time_counts.items()}

In [2]:
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
from IPython.display import HTML, display
import math

def _gaussian_smooth(values: np.ndarray, sigma: float) -> np.ndarray:
    """Smooth ``values`` with a normalized Gaussian kernel truncated at 3*sigma."""
    radius = max(1, int(3 * sigma))
    offsets = np.arange(-radius, radius + 1)
    kernel = np.exp(-0.5 * (offsets / sigma) ** 2)
    kernel = kernel / kernel.sum()
    # mode="same" yields max(len(values), len(kernel)) samples, which is longer
    # than the input when the input is shorter than the kernel. Slicing the
    # full convolution always returns exactly len(values) centered samples.
    full = np.convolve(values, kernel, mode="full")
    return full[radius:radius + len(values)]


def _bin_for_plot(df_full: pd.DataFrame, min_t: int, max_t: int, max_plot_points: int) -> pd.DataFrame:
    """Aggregate the dense per-step frame into at most ``max_plot_points`` rows."""
    range_len = max_t - min_t + 1
    if range_len <= max_plot_points:
        return df_full

    bin_size = int(math.ceil(range_len / max_plot_points))
    # Bins are half-open [edge, edge + bin_size). The last edge must lie
    # strictly above max_t, otherwise max_t falls outside every bin and is
    # silently dropped; hence the extra "+ 1".
    edges = list(range(min_t, max_t + bin_size + 1, bin_size))
    with_bins = df_full.assign(
        bin=pd.cut(df_full["time_to_recover"], bins=edges, right=False)
    )
    # Both probability columns are summed so the bars and the smoothed line
    # stay on the same scale (probability mass per bin).
    return with_bins.groupby("bin", observed=True).agg(
        time_to_recover=("time_to_recover", "mean"),
        prob=("prob", "sum"),
        smoothed=("smoothed", "sum"),
    ).reset_index(drop=True)


def _summary_stats(distribution: dict) -> tuple:
    """Return (mean, std, median, p90) of the distribution, or NaNs if it has no mass."""
    total_prob = sum(distribution.values())
    if total_prob <= 0:
        nan = float("nan")
        return nan, nan, nan, nan

    items = sorted(distribution.items())
    times = np.array([t for t, _ in items], dtype=float)
    probs = np.array([p for _, p in items], dtype=float)
    # Normalize only when the input is noticeably off from summing to 1.
    if abs(total_prob - 1.0) > 1e-8:
        probs = probs / total_prob

    expected_time = float(np.sum(times * probs))
    var_time = float(np.sum(((times - expected_time) ** 2) * probs))
    std_time = math.sqrt(var_time)

    # Median and percentile via the cumulative distribution. The index is
    # clamped as a guard against floating-point shortfall in the last entry.
    cumsum = np.cumsum(probs)
    last = len(times) - 1
    median_time = float(times[min(int(np.searchsorted(cumsum, 0.5)), last)])
    p90_time = float(times[min(int(np.searchsorted(cumsum, 0.9)), last)])
    return expected_time, std_time, median_time, p90_time


def visualize_distribution(distribution: dict):
    """Print summary statistics and plot a time-to-recover distribution.

    Args:
        distribution: Mapping from integer time-to-recover to its probability.
            Values are normalized for the summary statistics when they do not
            sum to 1. An empty mapping prints a short notice and returns
            without plotting.

    The figure shows the raw probabilities as bars with a Gaussian-smoothed
    line on top. Ranges wider than 500 steps are binned so the rendered figure
    stays small. The summary statistics (mean, standard deviation, median,
    90th percentile) are always computed from the raw ``distribution``.
    """
    if not distribution:
        # Nothing to aggregate or plot, e.g. when every pass failed.
        print("No recovery times to visualize (empty distribution).")
        return

    # Convert the distribution (dict: time -> probability) into a sorted DataFrame
    df = pd.DataFrame(list(distribution.items()), columns=["time_to_recover", "prob"])
    df = df.sort_values("time_to_recover").reset_index(drop=True)

    # Reindex to a full integer range so smoothing is continuous across missing bins
    min_t = int(df["time_to_recover"].min())
    max_t = int(df["time_to_recover"].max())
    full_idx = pd.Series(range(min_t, max_t + 1), name="time_to_recover")
    df_full = pd.DataFrame({"time_to_recover": full_idx})
    df_full = df_full.merge(df, on="time_to_recover", how="left").fillna(0)

    # Smoothing: choose sigma adaptively relative to the range length
    range_len = max_t - min_t + 1
    sigma = max(1.0, range_len / 100.0)  # smaller divisor -> stronger smoothing
    df_full["smoothed"] = _gaussian_smooth(df_full["prob"].to_numpy(), sigma)

    # Adaptive downsampling / binning to avoid extremely large HTML payloads
    max_plot_points = 500
    plot_df = _bin_for_plot(df_full, min_t, max_t, max_plot_points)

    # Statistics come from the raw `distribution` (more accurate than binned data)
    expected_time, std_time, median_time, p90_time = _summary_stats(distribution)

    # Print summary statistics
    print(f"Expected time to recover: {expected_time:.3f}")
    print(f"Std dev: {std_time:.3f}")
    print(f"Median: {median_time:.3f}, 90th percentile: {p90_time:.3f}")

    # Determine x-axis tick spacing to keep labels readable (about 20 ticks max)
    plot_min = plot_df["time_to_recover"].min()
    plot_max = plot_df["time_to_recover"].max()
    plot_range_len = int(plot_max - plot_min + 1)
    tick_step = max(1, int(plot_range_len / 20))
    # Use integer tickvals where possible
    tickvals = list(range(int(plot_min), int(plot_max) + 1, tick_step))

    # Build figure: bar + smoothed line overlay for readability
    fig = go.Figure()
    fig.add_trace(
        go.Bar(
            x=plot_df["time_to_recover"],
            y=plot_df["prob"],
            name="probability",
            marker_color="steelblue",
            opacity=0.6,
            hovertemplate="Time=%{x}<br>Probability=%{y:.6f}<extra></extra>",
        )
    )
    fig.add_trace(
        go.Scatter(
            x=plot_df["time_to_recover"],
            y=plot_df["smoothed"],
            mode="lines",
            name=f"smoothed (gaussian, sigma={sigma:.2f})",
            line=dict(color="firebrick", width=2),
            hovertemplate="Time=%{x}<br>Smoothed=%{y:.6f}<extra></extra>",
        )
    )

    # Add vertical lines/annotations for expected and median if finite
    ymax = max(float(plot_df["prob"].max()), float(plot_df["smoothed"].max()))
    if not math.isnan(expected_time):
        fig.add_shape(type="line", x0=expected_time, x1=expected_time, y0=0, y1=ymax * 1.05,
                    line=dict(color="green", width=2, dash="dash"), name="expected")
        fig.add_annotation(x=expected_time, y=ymax * 1.05, text=f"E={expected_time:.2f}", showarrow=True, arrowhead=2, ax=0, ay=-30)
    if not math.isnan(median_time):
        fig.add_shape(type="line", x0=median_time, x1=median_time, y0=0, y1=ymax * 1.05,
                    line=dict(color="orange", width=2, dash="dot"), name="median")
        fig.add_annotation(x=median_time, y=ymax * 1.05, text=f"median={median_time:.0f}", showarrow=True, arrowhead=2, ax=0, ay=-50)

    fig.update_layout(
        title="Time-to-Recover Distribution (smoothed)",
        xaxis_title="Time to recover (steps)",
        yaxis_title="Probability",
        template="plotly_white",
        bargap=0.1,
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
    )
    fig.update_xaxes(tickmode="array", tickvals=tickvals)

    # Fall back to a lightweight HTML fragment when the default renderer is
    # unavailable; full_html=False avoids the large wrapper document.
    try:
        fig.show()
    except Exception as e:
        msg = str(e)
        if "nbformat" in msg or "Mime type rendering requires nbformat" in msg:
            print("nbformat isn't available or is too old in this environment — falling back to compact HTML export for the figure.")
            html = pio.to_html(fig, include_plotlyjs='cdn', full_html=False)
            display(HTML(html))
        else:
            raise

In [3]:
def visual_test(
    producer_constructor: Callable[[int, int, int], Producer],
    recoverer_constructor: Callable[[int, int], Recoverer],
    N: int,
    D: int,
    passes: int = 100,
    sample_burst_size: Callable[[], int] = lambda: random.randint(1, 10),
    sample_data_size: Callable[[], int] = lambda: random.randint(3, 5),
    iters_bound: int = 300,
):
    """Run ``test`` with the given settings and plot the resulting distribution.

    All arguments are forwarded unchanged to ``test``; the dict it returns is
    handed straight to ``visualize_distribution``. Nothing is returned.
    """
    recovery_distribution = test(
        producer_constructor,
        recoverer_constructor,
        N,
        D,
        passes=passes,
        sample_burst_size=sample_burst_size,
        sample_data_size=sample_data_size,
        iters_bound=iters_bound,
    )
    visualize_distribution(recovery_distribution)

In [4]:
from impls.matrix import producer_constructor, recoverer_constructor

# Benchmark the matrix implementation: 10000 passes with N=5 and D=9999360.
visual_test(producer_constructor, recoverer_constructor, N=5, D=9999360, passes=10000)

 87%|████████▋ | 8701/10000 [00:11<00:01, 716.66it/s]

Failed to recover 2691660 within iteration bound


100%|██████████| 10000/10000 [00:12<00:00, 783.77it/s]


1 failures.
Expected time to recover: 30.473
Std dev: 15.739
Median: 26.000, 90th percentile: 51.000
nbformat isn't available or is too old in this environment — falling back to compact HTML export for the figure.


In [5]:
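# Disabled comparison run for the zero_log2 implementation (same parameters as the matrix run above).
# from impls.zero_log2 import producer_constructor, recoverer_constructor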

# visual_test(
#     producer_constructor,
#     recoverer_constructor,
#     N=5,
#     D=9999360,
#     passes=10000,
# )