In [1]:
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display
from math import sqrt

def _equilibrium_roots(own_ratio, other_ratio):
    """Return the ``(plus, minus)`` equilibrium levels for one population.

    The levels are the two roots of ``x**2 - b*x + other_ratio = 0`` with
    ``b = 1 - own_ratio + other_ratio``. Both entries are ``nan`` when the
    discriminant is negative (no real root).
    """
    b = 1 - own_ratio + other_ratio
    disc = b**2 - 4 * other_ratio
    if not disc >= 0:
        return np.nan, np.nan
    root = np.sqrt(disc)
    return 0.5 * (b + root), 0.5 * (b - root)


def _euler_trajectories(W0, Y0, X0, Z0, W_birth, Y_birth, W_death, Y_death,
                        X_size, Z_size, X_rate, Z_rate, Time, use_X, use_Z,
                        dt=0.01):
    """Integrate W, Y, X and Z with forward-Euler steps of size ``dt``.

    Returns ``(t, W, Y, X, Z)`` as equally long float arrays.
    """
    # np.arange with a float step can gain or lose its last sample to
    # rounding, so count the steps as an integer and scale an integer range.
    # Rounding to 9 decimals strips float noise before taking the ceiling.
    n_steps = int(np.ceil(round(Time / dt, 9)))
    t = dt * np.arange(n_steps + 1)

    W, Y, X, Z = (np.zeros_like(t) for _ in range(4))
    W[0], Y[0], X[0], Z[0] = W0, Y0, X0, Z0

    # Coupling strengths do not change during the run; compute them once.
    X_push = X_rate * X_size
    Z_push = Z_rate * Z_size

    for i in range(1, len(t)):
        w, y, x, z = W[i - 1], Y[i - 1], X[i - 1], Z[i - 1]

        dW = -W_birth * w**2 * y + W_birth * w * y - W_death * w
        dY = -Y_birth * y**2 * w + Y_birth * w * y - Y_death * y
        if use_X:
            # X pulls W towards itself only when the X coupling is enabled.
            dW += X_push * (x - w)
        if use_Z:
            dY += Z_push * (z - y)
        # X and Z always relax towards W and Y, even when their feedback
        # into W / Y is switched off.
        dX = X_rate * (w - x)
        dZ = Z_rate * (y - z)

        W[i] = w + dt * dW
        Y[i] = y + dt * dY
        X[i] = x + dt * dX
        Z[i] = z + dt * dZ

    return t, W, Y, X, Z


def _plot_time_series(t, W, Y, X, Z, W_equil, Y_equil, use_X, use_Z):
    """Draw the trajectories against time plus dashed/dotted equilibrium lines."""
    plt.figure(figsize=(12, 5))

    # Auxiliary variables are drawn first so W and Y end up on top.
    if use_X:
        plt.plot(t, X, label=r'$X_t$', color='lightgreen')
    if use_Z:
        plt.plot(t, Z, label=r'$Z_t$', color='skyblue')
    plt.plot(t, W, label=r'$W_t$', color='darkgreen')
    plt.plot(t, Y, label=r'$Y_t$', color='darkblue')

    guides = (
        (W_equil[0], 'darkgreen', '--', r'$W_{(eq)}^+$'),
        (W_equil[1], 'darkgreen', '--', r'$W_{(eq)}^-$'),
        (Y_equil[0], 'darkblue', ':', r'$Y_{(eq)}^+$'),
        (Y_equil[1], 'darkblue', ':', r'$Y_{(eq)}^-$'),
    )
    for level, colour, line_style, label in guides:
        # nan marks "no real equilibrium"; skip the guide line in that case.
        if not np.isnan(level):
            plt.axhline(level, color=colour, linestyle=line_style, label=label)

    plt.xlabel('Time')
    plt.ylabel('Population')
    plt.title('Population Dynamics Over Time')
    plt.ylim(0, 1)
    plt.legend(loc='upper left', bbox_to_anchor=(1.05, 1))
    plt.tight_layout()


def _plot_phase_and_constraint(W, Y, X, Z, W_equil, Y_equil, Q1, Q2, use_X, use_Z):
    """Draw the W-Y phase portrait next to the (Q1, Q2) constraint diagram."""
    fig, axs = plt.subplots(1, 2, figsize=(12, 5))
    phase_ax, region_ax = axs[0], axs[1]

    # --- Phase plot: W vs Y ---
    phase_ax.plot(W, Y, label='W vs Y', color='purple')
    if use_X and use_Z:
        phase_ax.plot(X, Z, label='X vs Z', color='brown', linestyle='--')

    if not np.isnan(W_equil[0]) and not np.isnan(Y_equil[0]):
        phase_ax.scatter(W_equil[0], Y_equil[0], color='black', marker='o', label='Equilibrium +')
    if not np.isnan(W_equil[1]) and not np.isnan(Y_equil[1]):
        phase_ax.scatter(W_equil[1], Y_equil[1], color='gray', marker='x', label='Equilibrium -')

    phase_ax.set_xlabel('W')
    phase_ax.set_ylabel('Y')
    phase_ax.set_title('Phase Plot')
    phase_ax.set_xlim(0, 1)
    phase_ax.set_ylim(0, 1)
    phase_ax.grid(True)
    phase_ax.legend(loc='best')

    # --- Constraint plot: sqrt(Q1) + sqrt(Q2) = 1 ---
    Q_vals = np.linspace(0, 1, 400)
    q1_grid, q2_grid = np.meshgrid(Q_vals, Q_vals)
    valid_region = (np.sqrt(q1_grid) + np.sqrt(q2_grid)) <= 1

    # Shade the part of the unit square where the inequality holds.
    region_ax.contourf(Q_vals, Q_vals, valid_region, levels=[0.5, 1], colors=['#e0f7fa'], alpha=0.5)
    region_ax.plot(Q_vals, (1 - np.sqrt(Q_vals))**2, 'k--', label=r'$\sqrt{Q_1} + \sqrt{Q_2} = 1$')
    region_ax.scatter(Q1, Q2, color='red', label='Current (Q1, Q2)', zorder=5)

    # Annotate the current point with the value of the sum.
    qsum = np.sqrt(Q1) + np.sqrt(Q2)
    text = r"$\sqrt{Q_1} + \sqrt{Q_2} = $" + f"{qsum:.2f}"
    region_ax.text(Q1 + 0.02, Q2 + 0.02, text, fontsize=10, color='black')

    region_ax.set_xlabel('Q1 = W_death / W_birth')
    region_ax.set_ylabel('Q2 = Y_death / Y_birth')
    region_ax.set_title('Constraint Region')
    region_ax.set_xlim(0, 1)
    region_ax.set_ylim(0, 1)
    region_ax.legend(loc='best')

    plt.tight_layout()


def simulate_and_plot(W0, Y0, X0, Z0, W_birth, Y_birth, W_death, Y_death,
                      X_size, Z_size, X_rate, Z_rate, Time, use_X, use_Z, show_phase_only):
    """Simulate the coupled W/Y/X/Z system and draw one of two figures.

    The state is advanced with forward Euler (step 0.01) from ``t = 0`` to
    ``Time``. W and Y follow the birth/death terms; X and Z relax towards W
    and Y at ``X_rate`` / ``Z_rate`` and feed back into W / Y with strength
    ``rate * size`` only when ``use_X`` / ``use_Z`` is true.

    Parameters
    ----------
    W0, Y0, X0, Z0 : initial values of the four state variables.
    W_birth, Y_birth : birth coefficients; must be non-zero because the
        ratios ``Q1 = W_death / W_birth`` and ``Q2 = Y_death / Y_birth`` are
        formed from them.
    W_death, Y_death : death coefficients.
    X_size, Z_size, X_rate, Z_rate : coupling size and rate for X and Z.
    Time : end time of the simulation.
    use_X, use_Z : enable the X -> W and Z -> Y feedback terms (and, in the
        time plot, the X / Z curves).
    show_phase_only : when true, draw the phase portrait and the (Q1, Q2)
        constraint diagram; otherwise draw the time series.

    Returns
    -------
    None. The figure is created through pyplot and is neither shown
    explicitly nor returned.
    """
    Q1 = W_death / W_birth
    Q2 = Y_death / Y_birth

    # (plus, minus) equilibrium levels; nan entries mean "no real root".
    W_equil = _equilibrium_roots(Q1, Q2)
    Y_equil = _equilibrium_roots(Q2, Q1)

    t, W, Y, X, Z = _euler_trajectories(
        W0, Y0, X0, Z0, W_birth, Y_birth, W_death, Y_death,
        X_size, Z_size, X_rate, Z_rate, Time, use_X, use_Z,
    )

    if show_phase_only:
        _plot_phase_and_constraint(W, Y, X, Z, W_equil, Y_equil, Q1, Q2, use_X, use_Z)
    else:
        _plot_time_series(t, W, Y, X, Z, W_equil, Y_equil, use_X, use_Z)
 

# Every control shares one fixed width so the four columns line up.
slider_layout = widgets.Layout(width='250px')


def _float_slider(label, lo, hi, start, handle_color):
    """Return a 0.01-step FloatSlider using the shared layout and a coloured handle."""
    return widgets.FloatSlider(
        value=start,
        min=lo,
        max=hi,
        step=0.01,
        description=label,
        style={'handle_color': handle_color},
        layout=slider_layout,
    )


def _checkbox(label, checked):
    """Return a Checkbox using the shared layout."""
    return widgets.Checkbox(value=checked, description=label, layout=slider_layout)


# --- Column 1: initial values of the four state variables ---
W0_slider = _float_slider("W0", 0, 1, 0.5, 'darkgreen')
X0_slider = _float_slider("X0", 0, 1, 0.66, 'lightgreen')
Y0_slider = _float_slider("Y0", 0, 1, 0.7, 'darkblue')
Z0_slider = _float_slider("Z0", 0, 1, 0.75, 'lightskyblue')
col1 = widgets.VBox([W0_slider, X0_slider, Y0_slider, Z0_slider])

# --- Column 2: birth and death coefficients for W and Y ---
# Birth sliders start at 0.01, so the slider range never offers zero.
W_birth_slider = _float_slider("W_birth", 0.01, 2.0, 0.4, 'darkgreen')
W_death_slider = _float_slider("W_death", 0.0, 1.0, 0.1, 'darkgreen')
Y_birth_slider = _float_slider("Y_birth", 0.01, 2.0, 0.9, 'darkblue')
Y_death_slider = _float_slider("Y_death", 0.0, 1.0, 0.15, 'darkblue')
col2 = widgets.VBox([W_birth_slider, W_death_slider, Y_birth_slider, Y_death_slider])

# --- Column 3: size and rate parameters for X and Z ---
X_size_slider = _float_slider("X_size", 0.01, 10.0, 0.3, 'lightgreen')
X_rate_slider = _float_slider("X_rate", 0.0, 2.0, 0.1, 'lightgreen')
Z_size_slider = _float_slider("Z_size", 0.01, 10.0, 0.2, 'lightskyblue')
Z_rate_slider = _float_slider("Z_rate", 0.0, 2.0, 0.05, 'lightskyblue')
col3 = widgets.VBox([X_size_slider, X_rate_slider, Z_size_slider, Z_rate_slider])

# --- Column 4: simulation length and display toggles ---
Time_slider = widgets.IntSlider(
    value=100, min=10, max=1000, step=10,
    description="Time", layout=slider_layout,
)
phase_only_checkbox = _checkbox("Phase Plot Only", False)
X_checkbox = _checkbox("Enable X", True)
Z_checkbox = _checkbox("Enable Z", True)
col4 = widgets.VBox([Time_slider, phase_only_checkbox, X_checkbox, Z_checkbox])

# --- Assemble the control panel ---
controls = widgets.HBox([col1, col2, col3, col4])

# Map each simulate_and_plot keyword to the widget that supplies its value;
# interactive_output passes the current widget values under these keys.
_bindings = {
    'W0': W0_slider,
    'Y0': Y0_slider,
    'X0': X0_slider,
    'Z0': Z0_slider,
    'W_birth': W_birth_slider,
    'Y_birth': Y_birth_slider,
    'W_death': W_death_slider,
    'Y_death': Y_death_slider,
    'X_size': X_size_slider,
    'Z_size': Z_size_slider,
    'X_rate': X_rate_slider,
    'Z_rate': Z_rate_slider,
    'Time': Time_slider,
    'use_X': X_checkbox,
    'use_Z': Z_checkbox,
    'show_phase_only': phase_only_checkbox,
}
out = widgets.interactive_output(simulate_and_plot, _bindings)

display(controls, out)

HBox(children=(VBox(children=(FloatSlider(value=0.5, description='W0', layout=Layout(width='250px'), max=1.0, …

Output()