In [1]:
%matplotlib inline

import matplotlib.pyplot as plt

from distributions import (
    UniformDist,
    ExponentialDist,
    NormalDist,
    GammaDist,
    BetaDist,
    ChiSquareDist,
)

from plots import (
    plot_pdf,
    plot_cdf,
    plot_pdf_with_interval,
    plot_sample_histogram,
)

from utils import probability_between, suggest_plot_range


In [2]:
import ipywidgets as widgets
from IPython.display import display

In [3]:
def create_distribution(dist_name, **params):
    """
    Build a distribution object from a case-insensitive name plus keyword parameters.

    Parameters
    ----------
    dist_name : str
        "uniform", "exponential", "normal", "gamma", "beta", or one of the
        chi-square spellings "chi-square" / "chisquare" / "chi2".
        The comparison is done on ``dist_name.lower()``.
    **params
        Parameters read for each distribution (extra keys are ignored):
        uniform -> a, b; exponential -> lam; normal -> mu, sigma;
        gamma and beta -> alpha, beta_param; chi-square -> k.

    Returns
    -------
    The distribution instance built by the matching class.

    Raises
    ------
    KeyError
        If a parameter required by the chosen distribution is missing.
    ValueError
        If ``dist_name`` does not match any known distribution.

    Examples
    --------
    create_distribution("uniform", a=0, b=5)
    create_distribution("normal", mu=0, sigma=1)
    create_distribution("exponential", lam=2.0)
    """
    def build_chi_square(p):
        return ChiSquareDist(k=p["k"])

    # Each builder pulls only the keys its constructor needs, in constructor order.
    builders = {
        "uniform": lambda p: UniformDist(a=p["a"], b=p["b"]),
        "exponential": lambda p: ExponentialDist(lam=p["lam"]),
        "normal": lambda p: NormalDist(mu=p["mu"], sigma=p["sigma"]),
        "gamma": lambda p: GammaDist(alpha=p["alpha"], beta_param=p["beta_param"]),
        "beta": lambda p: BetaDist(alpha=p["alpha"], beta_param=p["beta_param"]),
        "chi-square": build_chi_square,
        "chisquare": build_chi_square,
        "chi2": build_chi_square,
    }

    builder = builders.get(dist_name.lower())
    if builder is None:
        # Report the name exactly as the caller passed it (not lower-cased).
        raise ValueError(f"Unknown distribution name: {dist_name}")
    return builder(params)


In [4]:
# Parameter names each dropdown option needs, in the order they are passed to
# create_distribution / suggest_plot_range. Drives both the params dict built
# on every update and which parameter widgets are visible.
_DIST_PARAM_NAMES = {
    "uniform": ("a", "b"),
    "exponential": ("lam",),
    "normal": ("mu", "sigma"),
    "gamma": ("alpha", "beta_param"),
    "beta": ("alpha", "beta_param"),
    "chi-square": ("k",),
}

# Alternative spellings accepted for the chi-square distribution.
_CHI_SQUARE_ALIASES = ("chi-square", "chisquare", "chi2")

# Width used to repair uniform bounds entered with b <= a.
_UNIFORM_MIN_WIDTH = 1e-6


def interactive_continuous_visualizer():
    """
    Create an interactive widget panel to explore continuous distributions.

    A dropdown selects the distribution, and only the parameter controls that
    distribution uses are shown. Any change redraws three panels (PDF, CDF,
    and either a shaded probability interval or a sample histogram), then
    prints the parameters, plot range, and analytical mean / variance.

    Uses:
    - create_distribution (your factory)
    - suggest_plot_range
    - plot_pdf, plot_cdf, plot_pdf_with_interval, plot_sample_histogram

    This does NOT save images; it only shows them for exploration.
    """

    # ---- Widgets ----

    dist_widget = widgets.Dropdown(
        options=["uniform", "exponential", "normal", "gamma", "beta", "chi-square"],
        value="normal",
        description="Dist:",
    )

    a_widget = widgets.FloatText(value=0.0, description="a (uniform)")
    b_widget = widgets.FloatText(value=10.0, description="b (uniform)")

    lam_widget = widgets.FloatSlider(
        value=1.0, min=0.1, max=5.0, step=0.1, description="λ (exp)"
    )

    mu_widget = widgets.FloatSlider(
        value=0.0, min=-5.0, max=5.0, step=0.1, description="μ (normal)"
    )
    sigma_widget = widgets.FloatSlider(
        value=1.0, min=0.1, max=5.0, step=0.1, description="σ (normal)"
    )

    alpha_widget = widgets.FloatSlider(
        value=2.0, min=0.1, max=10.0, step=0.1, description="α"
    )
    beta_param_widget = widgets.FloatSlider(
        value=2.0, min=0.1, max=10.0, step=0.1, description="β"
    )

    k_widget = widgets.IntSlider(
        value=4, min=1, max=20, step=1, description="k (chi²)"
    )

    sample_size_widget = widgets.IntSlider(
        value=2000, min=100, max=10000, step=100, description="Samples"
    )

    use_interval_widget = widgets.Checkbox(
        value=True, description="Shade interval instead of histogram"
    )

    a_interval_widget = widgets.FloatText(value=-1.0, description="a (interval)")
    b_interval_widget = widgets.FloatText(value=1.0, description="b (interval)")

    # Distribution-parameter widgets keyed by the names in _DIST_PARAM_NAMES.
    param_widgets = {
        "a": a_widget,
        "b": b_widget,
        "lam": lam_widget,
        "mu": mu_widget,
        "sigma": sigma_widget,
        "alpha": alpha_widget,
        "beta_param": beta_param_widget,
        "k": k_widget,
    }

    def _sync_param_visibility(change=None):
        """Show only the parameter widgets used by the selected distribution."""
        active = _DIST_PARAM_NAMES.get(dist_widget.value, ())
        for name, widget in param_widgets.items():
            # None restores the default CSS display; "none" hides the widget.
            widget.layout.display = None if name in active else "none"

    # ---- Update function ----

    def _update(dist_name,
                a, b,
                lam,
                mu, sigma,
                alpha, beta_param,
                k,
                sample_size,
                use_interval,
                a_interval, b_interval):
        # interactive_output runs this inside its own Output widget and clears
        # it (wait=True) before each call, so no separate Output is managed here.

        # 1. Decide parameters based on distribution
        if dist_name in _CHI_SQUARE_ALIASES:
            dist_name = "chi-square"
        if dist_name not in _DIST_PARAM_NAMES:
            print("Unknown distribution.")
            return

        if dist_name == "uniform" and b <= a:
            # ensure a < b; tell the user instead of adjusting silently
            b = a + _UNIFORM_MIN_WIDTH
            print(f"Note: uniform requires a < b; using b = {b}.")

        values = {
            "a": a, "b": b,
            "lam": lam,
            "mu": mu, "sigma": sigma,
            "alpha": alpha, "beta_param": beta_param,
            "k": k,
        }
        params = {name: values[name] for name in _DIST_PARAM_NAMES[dist_name]}

        # 2. Create distribution object
        dist = create_distribution(dist_name, **params)

        # 3. Suggest plotting range
        x_min, x_max = suggest_plot_range(dist_name, **params)

        print(f"Distribution: {dist_name}")
        print(f"Parameters: {params}")
        print(f"Plot range: [{x_min:.3f}, {x_max:.3f}]")

        # 4. Make plots
        fig, axes = plt.subplots(1, 3, figsize=(18, 4))

        # PDF
        plot_pdf(
            dist,
            x_min,
            x_max,
            ax=axes[0],
            title=f"{dist_name.title()} PDF"
        )

        # CDF
        plot_cdf(
            dist,
            x_min,
            x_max,
            ax=axes[1],
            title=f"{dist_name.title()} CDF"
        )

        # Interval or histogram
        if use_interval:
            if a_interval > b_interval:
                # A reversed interval has no meaningful probability; swap bounds.
                a_interval, b_interval = b_interval, a_interval
                print("Note: interval bounds were reversed; swapped them.")
            _, prob = plot_pdf_with_interval(
                dist,
                x_min,
                x_max,
                a=a_interval,
                b=b_interval,
                ax=axes[2],
                title=f"{dist_name.title()} interval [{a_interval}, {b_interval}]"
            )
            print(f"P({a_interval} <= X <= {b_interval}) = {prob:.4f}")
        else:
            plot_sample_histogram(
                dist,
                sample_size,
                x_min,
                x_max,
                ax=axes[2],
                title=f"{dist_name.title()} histogram"
            )

        plt.tight_layout()
        plt.show()

        print(f"Mean (analytical): {dist.mean():.4f}")
        print(f"Variance (analytical): {dist.var():.4f}")

    # ---- Wire widgets to update function ----

    # Keys must match _update's parameter names.
    controls = {
        "dist_name": dist_widget,
        **param_widgets,
        "sample_size": sample_size_widget,
        "use_interval": use_interval_widget,
        "a_interval": a_interval_widget,
        "b_interval": b_interval_widget,
    }

    ui_left = widgets.VBox([
        dist_widget,
        widgets.HTML("<b>Parameters</b>"),
        a_widget, b_widget,
        lam_widget,
        mu_widget, sigma_widget,
        alpha_widget, beta_param_widget,
        k_widget,
    ])

    ui_right = widgets.VBox([
        widgets.HTML("<b>Interval & sampling</b>"),
        use_interval_widget,
        a_interval_widget,
        b_interval_widget,
        sample_size_widget,
    ])

    ui = widgets.HBox([ui_left, ui_right])

    dist_widget.observe(_sync_param_visibility, names="value")
    _sync_param_visibility()

    # interactive_output returns the single Output widget that receives all
    # text and plots from _update.
    interactive_out = widgets.interactive_output(_update, controls)

    display(ui, interactive_out)


In [5]:
# Build and display the interactive distribution explorer; changing any
# control re-runs the plotting/summary update.
interactive_continuous_visualizer()

HBox(children=(VBox(children=(Dropdown(description='Dist:', index=2, options=('uniform', 'exponential', 'norma…

Output()

Output()