In [29]:
from __future__ import annotations

import jax.numpy as jnp
from jax.scipy.stats import norm
from nptyping import NDArray, Shape, Int, Float, Bool
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import display, HTML


def load_mathjax():
    """Inject the MathJax script into the notebook page if it is not present yet.

    The emitted snippet runs in the browser: it tests whether a global
    ``MathJax`` object exists and, only when it does not, appends a
    ``<script>`` tag pointing at the MathJax 2.7.7 CDN build to the document
    head. Per the original intent, this is what lets LaTeX strings render
    inside Plotly charts in JupyterLab.
    """
    # The guard lives in JavaScript, so re-running the cell does not add a
    # second script tag once MathJax has been defined on the page.
    mathjax_loader = HTML("""
        <script>
            if (typeof MathJax === 'undefined') {
                var script = document.createElement('script');
                script.type = 'text/javascript';
                script.src = 'https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/MathJax.js?config=TeX-AMS-MML_HTMLorMML';
                document.head.appendChild(script);
            }
        </script>
    """)
    display(mathjax_loader)
# Load MathJax once, up front, so that figures created in later cells can
# render LaTeX strings such as `$...$` axis tick labels.
load_mathjax()

In [33]:
# Evaluation grid: 500 evenly spaced points on [-3, 3].
# nptyping's NDArray takes a shape AND a dtype, hence NDArray[Shape[...], Float].
# NOTE(review): jnp.linspace returns a JAX array, so the NDArray annotation is
# descriptive only.
x: NDArray[Shape['500'], Float] = jnp.linspace(-3, 3, 500)

# Plot the standard normal PDF, then the CDF, over the same grid.
# One loop replaces two copy-pasted cells; after it finishes, `data` and `fig`
# hold the CDF table and figure, as before.
for column, values in (('pdf', norm.pdf(x)), ('cdf', norm.cdf(x))):
    data: dict[str, NDArray] = {
        'x': x,
        column: values,
    }
    fig = px.scatter(data, x='x', y=column)
    fig.show()

In [32]:
def add_shaded_area(fig, x, y, mask, fillcolor='rgba(0, 0, 255, 0.2)'):
    """
    Shade the region between the x-axis (y = 0) and a curve on a Plotly figure.

    The points selected by ``mask`` are turned into a closed polygon: first
    along the baseline ``y = 0`` from left to right, then back along the curve
    from right to left. The polygon is added to ``fig`` as a filled trace with
    no visible outline and no legend entry.

    Parameters
    ----------
    fig : plotly.graph_objs.Figure
        The figure to which the shaded area will be added. It is modified in
        place via ``fig.add_trace``.
    x : array
        The x-values of the data points. Must support boolean-mask indexing
        and reversal with ``[::-1]``.
    y : array
        The y-values of the data points, aligned with ``x``.
    mask : array of bool
        A boolean mask, aligned with ``x``, indicating which points should be
        included in the shaded area.
    fillcolor : str, optional
        The fill color for the shaded area (default is 'rgba(0, 0, 255, 0.2)').

    Returns
    -------
    None
        The trace is appended to ``fig``; nothing is returned.
    """
    # Select the masked points once instead of re-indexing for every use.
    x_sel = x[mask]
    y_sel = y[mask]

    fig.add_trace(
        go.Scatter(
            # Baseline left-to-right, then the curve right-to-left, so the
            # vertices trace the outline of the region in order.
            x=jnp.concatenate([x_sel, x_sel[::-1]]),
            y=jnp.concatenate([jnp.zeros_like(x_sel), y_sel[::-1]]),
            # Pin the mode: Plotly defaults to 'lines+markers' for traces with
            # fewer than 20 points, which would draw markers on the outline of
            # a small shaded region.
            mode='lines',
            fill='toself',
            fillcolor=fillcolor,
            line=dict(width=0),
            showlegend=False,
        )
    )

# Evaluation grid and standard normal density, rebuilt here so this cell does
# not rely on variables left over from another cell.
# NOTE(review): jnp returns JAX arrays, so the NDArray annotations are
# descriptive only.
x: NDArray[Shape['500'], Float] = jnp.linspace(-3, 3, 500)
data: dict[str, NDArray] = {
    'x': x,
    'pdf': norm.pdf(x),
}

# Bounds of the shaded regions: the 2.5% and 97.5% quantiles of the standard
# normal (alpha/2 and 1 - alpha/2 in the tick labels, i.e. alpha = 0.05).
# norm.ppf returns a JAX scalar; convert to a plain float so the values match
# their annotation and are passed to Plotly as ordinary numbers.
left_bound: float = float(norm.ppf(0.025))
right_bound: float = float(norm.ppf(0.975))

# Shared x-axis ticks: label the two quantiles and the origin with LaTeX.
quantile_ticks = dict(
    tickvals=[left_bound, 0, right_bound],
    ticktext=[r'$\Phi^{-1}(\alpha/2)$', r'$0$', r'$\Phi^{-1}(1-\alpha/2)$'],
)

# Figure 1: shade the two tails outside the interval.
mask_left: NDArray[Shape['500'], Bool] = x < left_bound
mask_right: NDArray[Shape['500'], Bool] = x > right_bound
fig = px.scatter(data, x='x', y='pdf')
add_shaded_area(fig, x=x, y=data['pdf'], mask=mask_left)
add_shaded_area(fig, x=x, y=data['pdf'], mask=mask_right)
fig.update_xaxes(**quantile_ticks)
fig.show()

# Figure 2: shade the central region between the two quantiles (inclusive).
mask: NDArray[Shape['500'], Bool] = (x >= left_bound) & (x <= right_bound)
fig = px.scatter(data, x='x', y='pdf')
add_shaded_area(fig, x=x, y=data['pdf'], mask=mask)
fig.update_xaxes(**quantile_ticks)
fig.show()