In [5]:
from __future__ import annotations

import jax.numpy as jnp
from scipy import stats
from scipy.stats import norm
from nptyping import NDArray, Shape, Int, Float, Bool
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import display, HTML


def load_mathjax():
    """Inject MathJax into the notebook page so Plotly can render LaTeX text.

    An HTML ``<script>`` snippet is handed to IPython's ``display``. In the
    browser, the snippet appends a script tag pointing at the MathJax 2.7.7
    CDN build, but only when no global ``MathJax`` object is defined yet, so
    the library is not added again if it is already present.
    """
    mathjax_loader = """
        <script>
            if (typeof MathJax === 'undefined') {
                var script = document.createElement('script');
                script.type = 'text/javascript';
                script.src = 'https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/MathJax.js?config=TeX-AMS-MML_HTMLorMML';
                document.head.appendChild(script);
            }
        </script>
    """
    display(HTML(mathjax_loader))
# Load MathJax once, up front, so the `$...$` axis titles used by the Plotly
# figures in this notebook can be typeset.
load_mathjax()

In [23]:
def plot_mixture_norm_probability_density(means, std_deviations, distribution_weights, line_width=2):
    """
    Plot the probability density function of a mixture of normal distributions.

    The density is evaluated on 600 evenly spaced points spanning every
    component out to four standard deviations on either side of its mean.
    A red star on the x-axis marks the mean of the mixture, computed exactly
    as ``sum(weight_i * mean_i)``.

    Parameters
    ----------
    means : list of float
        List of means for each normal distribution in the mixture.
    std_deviations : list of float
        List of standard deviations for each normal distribution in the mixture.
        Every entry must be strictly positive.
    distribution_weights : list of float
        List of weights for each normal distribution in the mixture. Must sum to 1.
    line_width : float, optional, default: 2
        Line width for the probability density plot.

    Raises
    ------
    ValueError
        If lengths of means, std_deviations, and distribution_weights are not equal,
        if any element in std_deviations is not strictly positive,
        if any element in distribution_weights is negative,
        or if the sum of distribution_weights is not equal to 1.

    Examples
    --------
    >>> means = [0, 2]
    >>> std_deviations = [1, 0.05]
    >>> distribution_weights = [0.5, 0.5]
    >>> plot_mixture_norm_probability_density(means, std_deviations, distribution_weights)
    """
    def _validate():
        if len(means) != len(std_deviations) or len(means) != len(distribution_weights):
            raise ValueError("Lengths of means, std_deviations, and distribution_weights must be equal.")
        # A zero scale makes the scipy normal pdf evaluate to NaN everywhere,
        # so zero is rejected together with negative values.
        if not all(std_dev > 0 for std_dev in std_deviations):
            raise ValueError("All elements in std_deviations must be positive.")
        if not all(weight >= 0 for weight in distribution_weights):
            raise ValueError("All elements in distribution_weights must be non-negative.")
        # Built-in sum: jax.numpy reductions reject plain Python lists.
        weight_sum = sum(distribution_weights)
        if not jnp.isclose(weight_sum, 1, rtol=1e-5, atol=1e-8):
            raise ValueError("The sum of distribution_weights must be equal to 1.")

    _validate()

    # Choose the plotting interval from the components themselves so that no
    # component is clipped, whatever the sign or spread of the means.
    tail_std_devs = 4
    x_lower = min(mean - tail_std_devs * std_dev for mean, std_dev in zip(means, std_deviations))
    x_upper = max(mean + tail_std_devs * std_dev for mean, std_dev in zip(means, std_deviations))
    x_values = jnp.linspace(x_lower, x_upper, 600)

    # Weighted sum of the component densities, evaluated at the x points.
    probability_density = sum(
        weight * norm(loc=mean, scale=std_dev).pdf(x_values)
        for mean, std_dev, weight in zip(means, std_deviations, distribution_weights)
    )

    # The mean of a mixture is the weight-averaged mean of its components.
    mixture_mean = float(sum(weight * mean for weight, mean in zip(distribution_weights, means)))

    # Column mapping consumed by plotly.express
    data = {
        'x': x_values,
        'probability_density': probability_density
    }
    fig = px.line(data, x='x', y='probability_density')
    # Apply the requested width to the density curve (before the marker trace is added).
    fig.update_traces(line=dict(width=line_width))

    # Star marker for the mean, pinned to the x-axis.
    fig.add_trace(go.Scatter(
        x=[mixture_mean],
        y=[0],
        mode='markers',
        marker=dict(symbol='star', size=10, color='red'),
        name='mean'
    ))

    fig.update_xaxes(title_text='$x$')
    fig.update_yaxes(title_text='$p(x)$')
    fig.show()


# Demo: two equally weighted components, a broad one centred at 0 and a
# narrow one centred at 2.
means = [0, 2]
std_deviations = [1, 0.05]
distribution_weights = [0.5, 0.5]
plot_mixture_norm_probability_density(
    means=means,
    std_deviations=std_deviations,
    distribution_weights=distribution_weights,
)


In [20]:
def plot_combined_distributions(
        x_values: NDArray[Float],
        distributions: list[stats.distributions.rv_frozen],
        weights: list[float],
        labels: list[str] | None = None,
        colors: list[str] | None = None,
        linestyle: str = 'solid') -> None:
    """
    Plots weighted distributions individually together with their weighted sum.

    Each distribution's weighted density is drawn as its own trace; the sum of
    all weighted densities is drawn on top as a black dash-dot line.

    Parameters
    ----------
    x_values : numpy.ndarray
        The x values used for plotting the distributions.
    distributions : list of frozen scipy.stats distributions
        The distributions to be combined and plotted, e.g. ``norm(loc=0, scale=1)``.
        Only their ``pdf`` method is used.
    weights : list[float]
        A list of weights for each distribution; the length should match the number of distributions.
    labels : list[str], optional
        A list of labels for each distribution; the length must match the number of distributions (default is None).
        When given, the combined curve is labelled with the labels joined by ``+``
        (with ``$`` delimiters stripped and re-added around the whole expression).
    colors : list[str], optional
        A list of colors for each distribution; the length must match the number of distributions (default is None).
    linestyle : str, optional
        Line style for the plotted distributions (default is 'solid').

    Raises
    ------
    ValueError
        If the lengths of distributions and weights are not equal.
        If labels or colors are given and their length differs from the number of distributions.
        If any element in weights is negative.
        If the sum of weights is not equal to 1.
    """
    def _validate():
        if len(distributions) != len(weights):
            raise ValueError("Lengths of distributions and weights must be equal.")
        # The plotting loop zips these lists together; a shorter list would
        # silently drop components from the figure.
        if labels is not None and len(labels) != len(distributions):
            raise ValueError("Length of labels must match the number of distributions.")
        if colors is not None and len(colors) != len(distributions):
            raise ValueError("Length of colors must match the number of distributions.")
        if not all(weight >= 0 for weight in weights):
            raise ValueError("All elements in weights must be non-negative.")
        weight_sum = sum(weights)
        if not jnp.isclose(weight_sum, 1, rtol=1e-5, atol=1e-8):
            raise ValueError("The sum of weights must be equal to 1.")
    _validate()

    num_dists: int = len(distributions)
    # Keep the caller's arguments untouched so "labels is None" stays meaningful below.
    trace_labels = [None] * num_dists if labels is None else labels
    trace_colors = [None] * num_dists if colors is None else colors

    fig = go.Figure()
    combined_density = 0
    for distribution, weight, trace_label, trace_color in zip(distributions, weights, trace_labels, trace_colors):
        weighted_density: NDArray[Float] = weight * distribution.pdf(x_values)
        # Accumulate here so each pdf is evaluated only once.
        combined_density = combined_density + weighted_density
        fig.add_trace(go.Scatter(
            x=x_values,
            y=weighted_density,
            mode='lines',
            line=dict(color=trace_color, dash=linestyle),
            name=trace_label,
        ))

    # Only build a joined label when the caller actually supplied labels;
    # joining the placeholder None entries would raise a TypeError.
    if labels is None:
        combined_label = None
    else:
        combined_label = f"$ {'+'.join(labels).replace('$', '')} $"
    fig.add_trace(go.Scatter(
        x=x_values,
        y=combined_density,
        mode='lines',
        line=dict(color='black', dash='dashdot'),
        name=combined_label,
    ))

    fig.update_xaxes(title_text='$x$')
    fig.update_yaxes(title_text='$p(x)$')
    fig.show()
    

# Data for both distributions
means = [0, 2]
std_deviations = [0.5, 0.5]
weights = [0.5, 0.5]
x_values = jnp.linspace(-2, 2 * max(means), 600)
# Pair each mean with its standard deviation instead of indexing by a hard-coded count.
distributions = [norm(loc=mean, scale=std_dev) for mean, std_dev in zip(means, std_deviations)]

colors = ["green", "red"]
# Raw strings: "\m" is not a valid escape sequence in a normal string literal.
labels = [r"$0.5\mathcal{N}(x|0,0.5)$", r"$0.5\mathcal{N}(x|2,0.5)$"]

plot_combined_distributions(x_values=x_values, distributions=distributions, weights=weights, labels=labels, colors=colors)

In [23]:
# Three-component mixture, drawn with Plotly's default trace colors.
means = [0, 2, 3]
std_deviations = [0.5, 0.3, 0.4]
weights = [0.3, 0.4, 0.3]
x_values = jnp.linspace(-2, 2 * max(means), 600)
# Pair each mean with its standard deviation instead of indexing by a hard-coded count.
distributions = [norm(loc=mean, scale=std_dev) for mean, std_dev in zip(means, std_deviations)]

# Raw strings: "\m" is not a valid escape sequence in a normal string literal.
labels = [r"$0.3\mathcal{N}(x|0,0.5)$", r"$0.4\mathcal{N}(x|2,0.3)$", r"$0.3\mathcal{N}(x|3,0.4)$"]

plot_combined_distributions(x_values=x_values, distributions=distributions, weights=weights, labels=labels)