In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from scipy.stats import multinomial, multivariate_hypergeom
from matplotlib.patches import Rectangle
import ipywidgets as widgets
from IPython.display import display, HTML
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from matplotlib.ticker import ScalarFormatter, FuncFormatter, MultipleLocator, MaxNLocator


# Global matplotlib style: Dark2 colour cycle, white backgrounds, black text
# and ticks, and a 12 pt normal-weight monospace font.
plt.rcParams.update({
    'axes.prop_cycle': plt.cycler(color=plt.cm.Dark2.colors),
    'figure.facecolor': 'white',
    'axes.facecolor': 'white',
    'text.color': 'black',
    'axes.labelcolor': 'black',
    'xtick.color': 'black',
    'ytick.color': 'black',
    'font.family': 'monospace',
    'font.size': 12,          # font size
    'font.weight': 'normal',  # 'bold', 'light', 'normal'
})

def format_number(prob):
    """Return *prob* as display text.

    Values of at least 0.001 are shown with four decimal places; anything
    smaller is shown in scientific notation with one decimal digit.
    """
    template = "{:.4f}" if prob >= 0.001 else "{:.1e}"
    return template.format(prob)

def custom_formatter(x, pos):
    """Format a tick value for use with ``FuncFormatter``.

    *pos* is accepted because ``FuncFormatter`` passes it, but it is not used.

    - magnitude below 1e-15: the literal string "0"
    - magnitude below 1e-4: scientific notation with one decimal digit
    - otherwise: fixed-point with four decimal places
    """
    magnitude = abs(x)
    if magnitude < 1e-15:
        return "0"
    template = "{:.1e}" if magnitude < 1e-4 else '{:.4f}'
    return template.format(x)
    
# Functions to plot each 3D pmf
def plot_pmf_2d(x1_value, x2_value, probs_x1_x2, probs, n, title, details):
    """Draw a joint PMF of (X1, X2) as 3D bars plus marginal and slice views.

    The figure contains:
      * a 3D bar chart of every strictly positive P(X1=x1, X2=x2) with
        x1 + x2 <= n, where the bar at (x1_value, x2_value) is painted red;
      * flat bars on the planes x = -1 and y = -1 showing the sums of the
        joint PMF over x1 and over x2 respectively;
      * two 2D bar charts with the slices P(X1=x1, X2=x2_value) and
        P(X1=x1_value, X2=x2), the selected point marked by a red dot.

    After the figure is shown, each entry of *details* is printed indented by
    four spaces.

    Parameters
    ----------
    x1_value, x2_value : int
        Coordinates of the highlighted outcome.
    probs_x1_x2 : float
        Probability shown in the legend for the highlighted outcome.
    probs : mapping
        Maps ``(x1, x2)`` tuples to probabilities; missing keys count as 0.
    n : int
        Largest value considered for x1 and x2 (both range over 0..n).
    title : str
        Accepted for interface compatibility; it is not drawn on the figure.
    details : iterable of str
        Lines printed below the figure.

    Raises
    ------
    ValueError
        If no (x1, x2) with x1 + x2 <= n has a strictly positive probability
        (for example when every probability is NaN), since the colour scale
        cannot be built from an empty set of bars.
    """
    support = range(n + 1)

    # Collect the bars to draw: only outcomes with x1 + x2 <= n and a strictly
    # positive probability. This is done before any figure is created so that a
    # failure does not leave an empty high-resolution figure behind.
    valid_x1 = []
    valid_x2 = []
    valid_probs = []
    for x1 in support:
        for x2 in range(n - x1 + 1):
            prob = probs.get((x1, x2), 0)
            if prob > 0:
                valid_x1.append(x1)
                valid_x2.append(x2)
                valid_probs.append(prob)

    if not valid_probs:
        raise ValueError(
            "plot_pmf_2d: no strictly positive probabilities to plot for x1 + x2 <= n"
        )

    fig = plt.figure(figsize=(12, 4), dpi=400)

    # The grid is used to place the two slice subplots in its right-hand
    # column. The 3D axes below is created with plt.axes and is therefore not
    # tied to a grid cell.
    # NOTE(review): the grid's first column looks intended for the 3D plot;
    # confirm whether the current (un-gridded) placement is deliberate.
    gs = fig.add_gridspec(2, 3, width_ratios=[6, 3, 4], height_ratios=[1, 1])

    ax = plt.axes(projection='3d')
    dx = dy = 1

    # Use only the upper 70% of viridis so the darkest tones are skipped.
    custom_cmap = plt.cm.viridis(np.linspace(0.3, 1, 512))
    custom_cmap = mcolors.ListedColormap(custom_cmap)

    norm = mcolors.Normalize(vmin=min(valid_probs), vmax=max(valid_probs))
    colors = custom_cmap(norm(valid_probs))

    # Paint the selected outcome's bar red (if it is among the drawn bars).
    for idx, (x1, x2) in enumerate(zip(valid_x1, valid_x2)):
        if x1 == x1_value and x2 == x2_value:
            colors[idx] = mcolors.to_rgba('red')
            break

    ax.bar3d(valid_x1, valid_x2, np.zeros_like(valid_probs), dx, dy, valid_probs, shade=False, color=colors, zorder=10, edgecolor = 'black', linewidth=0.1)

    # Marginals: sum the joint PMF over the other variable.
    # marginal_x2[x2] = sum over x1, marginal_x1[x1] = sum over x2.
    marginal_x2 = {
        x2: sum(probs.get((x1, x2), 0) for x1 in support) for x2 in support
    }
    marginal_x1 = {
        x1: sum(probs.get((x1, x2), 0) for x2 in support) for x1 in support
    }

    x2_keys, x2_heights = list(marginal_x2.keys()), list(marginal_x2.values())
    x1_keys, x1_heights = list(marginal_x1.keys()), list(marginal_x1.values())

    # Each marginal gets its own colour scale.
    norm_marginal_x2 = mcolors.Normalize(vmin=min(x2_heights), vmax=max(x2_heights))
    norm_marginal_x1 = mcolors.Normalize(vmin=min(x1_heights), vmax=max(x1_heights))

    # Marginal of X2 (summed over x1): zero-depth bars on the plane x = -1.
    ax.bar3d(-1*np.ones_like(x2_keys), x2_keys, np.zeros_like(x2_heights), 0, 1.0, x2_heights, shade=False, color=custom_cmap(norm_marginal_x2(x2_heights)), alpha=0.8, edgecolor = 'black', linewidth=0.1, zorder=1)

    # Marginal of X1 (summed over x2): zero-depth bars on the plane y = -1.
    ax.bar3d(x1_keys, -1*np.ones_like(x1_keys), np.zeros_like(x1_heights), 1.0, 0, x1_heights, shade=False, color=custom_cmap(norm_marginal_x1(x1_heights)), alpha=0.8, edgecolor = 'black', linewidth=0.1, zorder=1)

    # Setting the grid line thickness for X, Y, and Z axes
    ax.xaxis._axinfo["grid"]['linewidth'] = 0.2
    ax.yaxis._axinfo["grid"]['linewidth'] = 0.2
    ax.zaxis._axinfo["grid"]['linewidth'] = 0.2

    # Add legend to indicate the red bar's probability.
    red_patch = Rectangle((0, 0), 1, 1, fc="red", edgecolor='black', linewidth=0.05)
    legend = ax.legend([red_patch], [f'P(X1={x1_value}, X2={x2_value}) = {format_number(probs_x1_x2)}'], loc='upper left', fontsize=4)
    legend.get_frame().set_linewidth(0.5)  # Thinner legend edge

    ax.set_xlabel('X1=x1', fontsize=4, labelpad=-10)
    ax.set_ylabel('X2=x2', fontsize=4, labelpad=-10)
    ax.set_zlabel('P(X1=x1, X2=x2)', fontsize=4)
    ax.zaxis.labelpad = -10
    ax.zaxis._axinfo['label']['space_factor'] = 1.1
    ax.tick_params(axis='both', which='major', labelsize=4, pad=-3)

    # Thinner axis lines
    ax.xaxis.line.set_linewidth(0.2)
    ax.yaxis.line.set_linewidth(0.2)
    ax.zaxis.line.set_linewidth(0.2)

    ax.view_init(30, 30)

    # Set integer ticks for the 3D plot's X and Y axes with max 8 tickers
    ax.xaxis.set_major_locator(MaxNLocator(8, integer=True))
    ax.yaxis.set_major_locator(MaxNLocator(8, integer=True))

    ax2 = fig.add_subplot(gs[0, 2])
    ax3 = fig.add_subplot(gs[1, 2])

    # Upper subplot: P(X1=x1, X2={x2_value}). Bars reuse the 3D colour scale.
    for x1 in support:
        prob = probs.get((x1, x2_value), 0)
        ax2.bar(x1, prob, width=1, color=custom_cmap(norm(prob)), edgecolor='black', linewidth=0.1)
        if x1 == x1_value:
            ax2.scatter(x1, prob, color='red', s=5, zorder=3)

    # Lower subplot: P(X1={x1_value}, X2=x2)
    for x2 in support:
        prob = probs.get((x1_value, x2), 0)
        ax3.bar(x2, prob, width=1, color=custom_cmap(norm(prob)), edgecolor='black', linewidth=0.1)
        if x2 == x2_value:
            ax3.scatter(x2, prob, color='red', s=5, zorder=3)

    # Shared cosmetics for the two slice subplots: small labels, thin spines
    # and ticks, probability tick labels via custom_formatter, and at most 8
    # integer ticks on the x axis.
    slice_panels = (
        (ax2, 'X1=x1', f'P(X1=x1, X2={x2_value})'),
        (ax3, 'X2=x2', f'P(X1={x1_value}, X2=x2)'),
    )
    for panel, xlabel, ylabel in slice_panels:
        panel.set_xlabel(xlabel, fontsize=4)
        panel.set_ylabel(ylabel, fontsize=4)
        panel.tick_params(axis='both', which='major', labelsize=4)
        for side in ('top', 'bottom', 'left', 'right'):
            panel.spines[side].set_linewidth(0.2)
        panel.xaxis.set_tick_params(width=0.2)
        panel.yaxis.set_tick_params(width=0.2)
        panel.yaxis.set_major_formatter(FuncFormatter(custom_formatter))
        panel.xaxis.set_major_locator(MaxNLocator(8, integer=True))

    plt.tight_layout()
    plt.show()

    for detail in details:
        print("    " + detail)

def plot_trinomial(n, p1, p2, x1, x2):
    """Plot the trinomial joint PMF of (X1, X2) and highlight (x1, x2).

    Parameters
    ----------
    n : int
        Number of trials.
    p1, p2 : float
        Probabilities of outcomes 1 and 2; outcome 3 gets the remainder.
    x1, x2 : int
        Counts of outcomes 1 and 2 to highlight. If x1 + x2 > n the point is
        outside the support and its probability is reported as 0.
    """
    # 1 - p1 - p2 can round to a tiny negative number when p1 + p2 equals 1
    # (e.g. p1=0.9, p2=0.1). A negative entry would be printed as "-0.00" and
    # may be rejected by SciPy as an invalid probability vector, so clamp at 0.
    p3 = max(0.0, 1 - p1 - p2)

    # Every (x1, x2) with x1 + x2 <= n, evaluated in a single vectorised call.
    support = [(a, b) for a in range(n + 1) for b in range(n - a + 1)]
    probabilities = multinomial.pmf(
        [[a, b, n - a - b] for a, b in support], n, [p1, p2, p3]
    )
    probs = dict(zip(support, probabilities))

    # Outcomes outside the support have probability zero rather than being an
    # error (the sliders can momentarily hold such a combination).
    prob_x1_x2 = probs.get((x1, x2), 0.0)

    title = 'Trinomial PMF'
    details = [
        f"Equation: P(X1=x1, X2=x2) = C(n, x1, x2) * p1^x1 * p2^x2 * p3^(n-x1-x2)",
        f"P(X1={x1}, X2={x2}) = C({n}, {x1}, {x2}) * {p1:.2f}^{x1} * {p2:.2f}^{x2} * {p3:.2f}^{n-x1-x2} = {format_number(prob_x1_x2)}"
    ]
    plot_pmf_2d(x1, x2, prob_x1_x2, probs, n, title, details)

def plot_trihypergeometric(N1, N2, N3, n, x1, x2):
    """Plot the three-category hypergeometric joint PMF and highlight (x1, x2).

    N1, N2, N3 are the numbers of items of each type, n is the number of
    items drawn, and x1, x2 are the drawn counts of types 1 and 2 whose
    probability is highlighted and printed.
    """
    population = [N1, N2, N3]

    def joint_pmf(k1, k2):
        # The third count is whatever remains of the n draws.
        return multivariate_hypergeom.pmf([k1, k2, n - k1 - k2], population, n)

    # Probability of the highlighted outcome.
    prob_x1_x2 = joint_pmf(x1, x2)

    # Tabulate the PMF over x1 <= min(N1, n) and x2 <= min(N2, n - x1);
    # combinations needing more type-3 items than exist are stored as 0.
    probs = {}
    for k1 in range(min(N1, n) + 1):
        for k2 in range(min(N2, n - k1) + 1):
            needed_type3 = n - k1 - k2
            probs[k1, k2] = joint_pmf(k1, k2) if needed_type3 <= N3 else 0

    title = 'Tri-hypergeometric PMF'
    equation_line = f"Equation: P(X1=x1, X2=x2) = C(N1, x1) * C(N2, x2) * C(N3, n-x1-x2) / C(N1+N2+N3, n)"
    evaluated_line = f"P(X1={x1}, X2={x2}) = C({N1}, {x1}) * C({N2}, {x2}) * C({N3}, {n-x1-x2}) / C({N1+N2+N3}, {n}) = {format_number(prob_x1_x2)}"
    details = [equation_line, evaluated_line]

    # Hand everything to the shared plotting routine.
    plot_pmf_2d(x1, x2, prob_x1_x2, probs, n, title, details)

# Widgets
# Selector for the distribution to explore; starts on a placeholder entry.
distribution_dropdown = widgets.Dropdown(
    description='Distribution:',
    options=["Select a distribution", "Trinomial", "Tri-hypergeometric"],
    value="Select a distribution",
)

# Initially empty container; the dropdown handler fills it with the sliders
# and the plot output of the chosen distribution.
output_container = widgets.VBox(children=[])

# Full-width layout shared by the parameter sliders.
slider_layout = widgets.Layout(width='100%')

def _make_slider(slider_cls, description, layout=None, **traits):
    """Build a slider with the shared label width and update-on-release."""
    return slider_cls(
        description=description,
        continuous_update=False,
        style={'description_width': '200px'},
        layout=slider_layout if layout is None else layout,
        **traits,
    )


def _trinomial_controls():
    """Return the sliders and linked plot output for the trinomial PMF."""
    n_slider = _make_slider(widgets.IntSlider, 'n (trials):', layout=widgets.Layout(width='50%'), value=10, min=1, max=50, step=1)
    p1_slider = _make_slider(widgets.FloatSlider, 'p1 (prob. of outcome 1):', value=0.3, min=0.01, max=0.99, step=0.01)
    p2_slider = _make_slider(widgets.FloatSlider, 'p2 (prob. of outcome 2):', value=0.3, min=0.01, max=0.99, step=0.01)
    x1_slider = _make_slider(widgets.IntSlider, 'x1 (# of outcome 1):', value=3, min=0, max=n_slider.value, step=1)
    x2_slider = _make_slider(widgets.IntSlider, 'x2 (# of outcome 2):', value=3, min=0, max=n_slider.value, step=1)

    def update_ranges(*args):
        # Keep p1 + p2 at most 1 so the third probability is not negative.
        p1_slider.max = 1 - p2_slider.value
        p2_slider.max = 1 - p1_slider.value

        if n_slider.value < x1_slider.value or n_slider.value < x2_slider.value:
            # n dropped below a selected count: reset both counts. These
            # assignments re-trigger this handler, which then takes the
            # branch below and tightens the maxima.
            x1_slider.value = 0
            x2_slider.value = 0
        else:
            # Keep x1 + x2 at most n.
            x1_slider.max = n_slider.value - x2_slider.value
            x2_slider.max = n_slider.value - x1_slider.value

    # Range handlers are registered before the plot output is created, as in
    # the original ordering.
    for slider in (n_slider, p1_slider, p2_slider, x1_slider, x2_slider):
        slider.observe(update_ranges, 'value')

    # Apply the constraints once so the initial slider maxima already match
    # what the handler enforces after the first interaction.
    update_ranges()

    plot_output = widgets.interactive_output(plot_trinomial, {'n': n_slider, 'p1': p1_slider, 'p2': p2_slider, 'x1': x1_slider, 'x2': x2_slider})

    return [
        n_slider,
        widgets.HBox([p1_slider, p2_slider]),
        widgets.HBox([x1_slider, x2_slider]),
        plot_output,
    ]


def _trihypergeometric_controls():
    """Return the sliders and linked plot output for the tri-hypergeometric PMF."""
    N1_slider = _make_slider(widgets.IntSlider, 'N1 (# of items in type 1):', value=10, min=1, max=30, step=1)
    N2_slider = _make_slider(widgets.IntSlider, 'N2 (# of items in type 2):', value=10, min=0, max=30, step=1)
    N3_slider = _make_slider(widgets.IntSlider, 'N3 (# of items in type 3):', value=12, min=0, max=30, step=1)
    n_slider = _make_slider(widgets.IntSlider, 'n (# of picks):', value=10, min=1, max=N1_slider.value+N2_slider.value+N3_slider.value, step=1)
    x1_slider = _make_slider(widgets.IntSlider, 'x1 (# of picked type 1 items):', value=5, min=0, max=N1_slider.value, step=1)
    x2_slider = _make_slider(widgets.IntSlider, 'x2 (# of picked type 2 items):', value=4, min=0, max=N2_slider.value, step=1)

    def update_ranges(*args):
        # Cannot pick more items than exist in total.
        n_slider.max = N1_slider.value + N2_slider.value + N3_slider.value

        # x1 is limited by the type-1 stock and the number of picks; x2 by
        # the type-2 stock and the picks left after x1.
        x1_slider.max = min(N1_slider.value, n_slider.value)
        x2_slider.max = min(N2_slider.value, n_slider.value - x1_slider.value)

    # Range handlers are registered before the plot output is created, as in
    # the original ordering.
    for slider in (N1_slider, N2_slider, N3_slider, n_slider, x1_slider, x2_slider):
        slider.observe(update_ranges, 'value')

    # Apply the constraints once so the initial slider maxima already match
    # what the handler enforces after the first interaction.
    update_ranges()

    # Interactive plot for the trihypergeometric distribution
    plot_output = widgets.interactive_output(plot_trihypergeometric, {'N1': N1_slider, 'N2': N2_slider, 'N3': N3_slider, 'n': n_slider, 'x1': x1_slider,'x2': x2_slider})

    return [
        widgets.HBox([N1_slider, N2_slider, N3_slider]),
        widgets.HBox([n_slider, x1_slider, x2_slider]),
        plot_output,
    ]


def display_distribution_widgets(change):
    """Dropdown handler: show the controls for the newly selected distribution.

    *change* is the change dictionary passed by the dropdown's ``observe``
    callback; only ``change['new']`` is read. The controls and plot output
    are placed in ``output_container``. Any value other than the two known
    distributions (i.e. the placeholder entry) empties the container so
    that stale controls are not left on screen.
    """
    selected = change['new']
    if selected == "Trinomial":
        output_container.children = _trinomial_controls()
    elif selected == "Tri-hypergeometric":
        output_container.children = _trihypergeometric_controls()
    else:
        output_container.children = []
        
# Observer: rebuild the controls whenever the dropdown's selected value changes.
distribution_dropdown.observe(display_distribution_widgets, 'value')

# Show the dropdown followed by the (initially empty) container for the
# sliders and plot.
display(distribution_dropdown, output_container)