In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
import ipywidgets as widgets
import matplotlib.colors as mcolors
from matplotlib.ticker import ScalarFormatter, FuncFormatter, MultipleLocator, MaxNLocator

# Global matplotlib look: Dark2 colour cycle, white backgrounds, black text
# and ticks, 12pt normal-weight monospace font.
plt.rcParams.update({
    'axes.prop_cycle': plt.cycler(color=plt.cm.Dark2.colors),
    'figure.facecolor': 'white',
    'axes.facecolor': 'white',
    'text.color': 'black',
    'axes.labelcolor': 'black',
    'xtick.color': 'black',
    'ytick.color': 'black',
    'font.family': 'monospace',
    'font.size': 12,
    'font.weight': 'normal',  # other accepted values include 'bold', 'light'
})

def plot_pdf_3d(X, Y, Z, x1_value, x2_value, title, details, Z_selected, X1_pdf, X2_pdf, x1_values, x2_values):
    """Draw a density surface with two 1-D slice panels, then print notes.

    Parameters
    ----------
    X, Y, Z :
        Coordinate and height arrays handed straight to ``plot_surface``.
    x1_value, x2_value :
        Currently selected point; marked with red dots beside the surface
        and interpolated (two decimals) into the slice panels' y-labels.
    title :
        Accepted for interface compatibility; not used by this function.
    details :
        Iterable of strings printed (indented four spaces) after the figure.
    Z_selected :
        Accepted for interface compatibility; not used by this function
        (the marker that would have used it is disabled).
    X1_pdf, x1_values :
        Curve and abscissae for the upper slice panel.
    X2_pdf, x2_values :
        Curve and abscissae for the lower slice panel.

    The function shows the figure via ``plt.show()`` and returns ``None``.
    """
    hairline = 0.2   # line width shared by axes lines, grids, spines, ticks
    tiny_font = 4    # font size shared by every label and tick label

    fig = plt.figure(figsize=(12, 4), dpi=400)

    # 2x3 grid; only the right-hand column cells are used below.
    # NOTE(review): the 3-D axes is created with plt.axes() rather than from
    # a cell of this grid -- confirm whether gs[:, 0] was intended.
    gs = fig.add_gridspec(2, 3, width_ratios=[6, 3, 4], height_ratios=[1, 1])

    ax = plt.axes(projection='3d')

    # Viridis with its darkest 20% trimmed off, resampled to 512 colours.
    trimmed_viridis = mcolors.ListedColormap(plt.cm.viridis(np.linspace(0.2, 1, 512)))

    # Fully opaque surface (alpha=1) with a very fine black wire overlay.
    ax.plot_surface(X, Y, Z, cmap=trimmed_viridis, edgecolor='black', zorder=5, alpha=1, linewidth=0.08)

    # Labels and ticks of the 3-D axes; negative pads pull text toward the axes.
    ax.set_xlabel('X1=x1', fontsize=tiny_font, labelpad=-10)
    ax.set_ylabel('X2=x2', fontsize=tiny_font, labelpad=-10)
    ax.set_zlabel('f(X1=x1, X2=x2)', fontsize=tiny_font)
    ax.zaxis.labelpad = -9
    ax.zaxis._axinfo['label']['space_factor'] = 1.1
    ax.tick_params(axis='both', which='major', labelsize=tiny_font, pad=-3)
    for axis3d in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis3d.line.set_linewidth(hairline)
    ax.view_init(30, 30)

    # Red markers placed 2% of the data range beyond the max-Y / max-X edge,
    # at the lowest height of the surface.
    floor = Z.min()
    beyond_y = Y.max() + (Y.max() - Y.min()) * 0.02
    beyond_x = X.max() + (X.max() - X.min()) * 0.02
    ax.scatter([x1_value], [beyond_y], [floor], color='red', s=5, zorder=2)
    ax.scatter([beyond_x], [x2_value], [floor], color='red', s=5, zorder=2)

    # Thin grid lines on all three 3-D axes (uses mplot3d's private _axinfo).
    for axis3d in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis3d._axinfo["grid"]['linewidth'] = hairline

    # Two stacked 2-D panels in the right-hand grid column.
    ax2 = fig.add_subplot(gs[0, 2])
    ax3 = fig.add_subplot(gs[1, 2])
    slice_panels = (
        (ax2, x1_values, X1_pdf, 'X1', f'f(X1=x1,X2={x2_value:.2f})'),
        (ax3, x2_values, X2_pdf, 'X2', f'f(X1={x1_value:.2f},X2=x2)'),
    )
    for panel, abscissae, curve, x_label, y_label in slice_panels:
        panel.plot(abscissae, curve)
        panel.fill_between(abscissae, curve, color=(0.92, 0.92, 0.92))
        panel.set_xlabel(x_label, fontsize=tiny_font)
        panel.set_ylabel(y_label, fontsize=tiny_font)
        panel.tick_params(axis='both', which='major', labelsize=tiny_font)
        for side in ('top', 'bottom', 'left', 'right'):
            panel.spines[side].set_linewidth(hairline)
        panel.xaxis.set_tick_params(width=hairline)
        panel.yaxis.set_tick_params(width=hairline)

    # Cap the number of major ticks per axis, preferring integer positions.
    tick_caps = (
        (ax.xaxis, 8),
        (ax.yaxis, 8),
        (ax.zaxis, 6),
        (ax2.xaxis, 8),
        (ax3.xaxis, 8),
    )
    for axis, max_ticks in tick_caps:
        axis.set_major_locator(MaxNLocator(max_ticks, integer=True))

    plt.tight_layout()
    plt.show()

    # Textual notes go below the rendered figure.
    for note in details:
        print("    " + note)
        
def plot_bivariate_gaussian(mu1, mu2, sigma1, sigma2, rho, x1, x2):
    """Evaluate a bivariate Gaussian PDF and hand everything to plot_pdf_3d.

    Parameters
    ----------
    mu1, mu2 :
        Means of X1 and X2.
    sigma1, sigma2 :
        Standard deviations of X1 and X2; they are squared to form the
        diagonal of the covariance matrix.
    rho :
        Correlation coefficient used for the off-diagonal covariance
        ``rho * sigma1 * sigma2``.
    x1, x2 :
        Selected point; the joint density is sliced along ``X2 = x2`` and
        ``X1 = x1`` for the two 1-D curves.

    Returns ``None``; the figure and equation text are produced by
    ``plot_pdf_3d``.
    """
    title = 'Bivariate Gaussian PDF'
    # Each standardised term is written as ((x - μ)/σ) and the exponent's
    # denominator is fully parenthesised so the printed formula is unambiguous.
    details = [
        "Equation: f(x1, x2) = (1 / (2π*σ1*σ2*√(1-ρ^2))) * e^−((((x1-μ1)/σ1)^2 + ((x2-μ2)/σ2)^2 − 2ρ*((x1-μ1)/σ1)*((x2-μ2)/σ2)) / (2(1−ρ^2)))",
    ]

    # Both axes span +/- 3 of the larger standard deviation around their own
    # mean, so the two axes always cover ranges of equal width.
    half_width = max(3 * sigma1, 3 * sigma2)

    x1_values = np.linspace(mu1 - half_width, mu1 + half_width, 1000)
    x2_values = np.linspace(mu2 - half_width, mu2 + half_width, 1000)
    X, Y = np.meshgrid(x1_values, x2_values)

    # Build the frozen distribution once and reuse it for every evaluation.
    covariance_matrix = [
        [sigma1**2, rho * sigma1 * sigma2],
        [rho * sigma1 * sigma2, sigma2**2],
    ]
    joint = multivariate_normal([mu1, mu2], covariance_matrix)

    # Joint density over the whole grid (last axis of the input holds x1, x2).
    Z = joint.pdf(np.dstack((X, Y)))

    # Density at the selected point.
    Z_selected = joint.pdf(np.array([[x1, x2]]))

    # 1-D slices of the joint density through the selected point.
    X1_pdf = joint.pdf(np.column_stack((x1_values, np.full_like(x1_values, x2))))
    X2_pdf = joint.pdf(np.column_stack((np.full_like(x2_values, x1), x2_values)))

    plot_pdf_3d(X, Y, Z, x1, x2, title, details, Z_selected, X1_pdf, X2_pdf, x1_values, x2_values)

# Shared widget objects used by the callbacks below.

# Layout applied to the full-width sliders.
slider_layout = widgets.Layout(width='100%')

# Initially empty box; it is filled with sliders and the plot output once a
# distribution is chosen.
output_container = widgets.VBox([])

# Distribution chooser; starts on the placeholder entry.
distribution_dropdown = widgets.Dropdown(
    options=["Select a distribution", "Bivariate Gaussian"],
    value="Select a distribution",
    description='Distribution:',
)

def display_distribution_widgets(change):
    """Build the slider panel and live plot for the chosen distribution.

    ``change`` is the observer payload; only ``change['new']`` is read. When
    it equals "Bivariate Gaussian", seven sliders (X1, X2, both means, both
    standard deviations, correlation) are created, wired to
    ``plot_bivariate_gaussian`` and placed in ``output_container``. Any other
    value leaves ``output_container`` untouched.
    """
    if change['new'] != "Bivariate Gaussian":
        return

    def make_slider(label, start, low, high, layout=slider_layout):
        # All sliders share step, update mode and label width.
        return widgets.FloatSlider(
            value=start,
            min=low,
            max=high,
            step=0.01,
            description=label,
            continuous_update=False,
            style={'description_width': '200px'},
            layout=layout,
        )

    mu1_slider = make_slider('μ (X1):', 0, -10, 10)
    mu2_slider = make_slider('μ (X2):', 0, -10, 10)
    sigma1_slider = make_slider('σ (X1):', 1, 0.1, 5)
    sigma2_slider = make_slider('σ (X2):', 1, 0.1, 5)
    rho_slider = make_slider('ρ (correlation):', 0, -0.99, 0.99, layout=widgets.Layout(width='50%'))
    x1_slider_gauss = make_slider('X1:', 0, -3, 3)
    x2_slider_gauss = make_slider('X2:', 0, -3, 3)

    def make_bounds_updater(mean_slider, spread_slider, point_slider):
        """Return a callback re-centring ``point_slider`` on mean +/- 3 sigma."""
        def refresh_bounds(*args):
            centre = mean_slider.value
            spread = spread_slider.value
            # The bound that cannot cross the slider's current range is moved
            # first -- presumably to avoid a transient min > max state.
            if point_slider.max > centre - 3 * spread:
                point_slider.min = centre - 3 * spread
                point_slider.max = centre + 3 * spread
            else:
                point_slider.max = centre + 3 * spread
                point_slider.min = centre - 3 * spread
            # The selected point snaps back to the mean after every change.
            point_slider.value = centre
        return refresh_bounds

    update_x1_bounds = make_bounds_updater(mu1_slider, sigma1_slider, x1_slider_gauss)
    update_x2_bounds = make_bounds_updater(mu2_slider, sigma2_slider, x2_slider_gauss)

    # Registration order: both means first, then both standard deviations.
    for driver, callback in (
        (mu1_slider, update_x1_bounds),
        (mu2_slider, update_x2_bounds),
        (sigma1_slider, update_x1_bounds),
        (sigma2_slider, update_x2_bounds),
    ):
        driver.observe(callback, 'value')

    plot_controls = {
        'mu1': mu1_slider,
        'mu2': mu2_slider,
        'sigma1': sigma1_slider,
        'sigma2': sigma2_slider,
        'rho': rho_slider,
        'x1': x1_slider_gauss,
        'x2': x2_slider_gauss,
    }
    gauss_interactive = widgets.interactive_output(plot_bivariate_gaussian, plot_controls)

    # Rows: selected point, means, standard deviations, correlation, plot.
    output_container.children = [
        widgets.HBox([x1_slider_gauss, x2_slider_gauss]),
        widgets.HBox([mu1_slider, mu2_slider]),
        widgets.HBox([sigma1_slider, sigma2_slider]),
        widgets.HBox([rho_slider]),
        gauss_interactive,
    ]

# Call display_distribution_widgets whenever the dropdown's value changes.
distribution_dropdown.observe(display_distribution_widgets, names='value')

# Show the dropdown followed by the (initially empty) output container.
# NOTE(review): `display` is not imported in this file -- presumably the
# IPython/Jupyter builtin; confirm before running outside a notebook.
display(distribution_dropdown, output_container)