In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import ipywidgets as widgets
from IPython.display import display
from matplotlib.backends.backend_pgf import FigureCanvasPgf  # register PGF backend
# Most recently drawn figure: assigned at the end of simulate_and_plot and
# written to disk by the "Save PDF" button handler (None until first plot).
last_fig = None
from math import sqrt

# Global matplotlib styling. Text is rendered through LaTeX with newtxmath
# (Times-like math) and a Computer Modern-first serif font list; PGF export
# uses pdflatex with the same math package so exported figures match.
mpl.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['CMU Serif', 'Computer Modern Roman', 'DejaVu Serif', 'Times New Roman', 'Times'],
    'text.usetex': True,
    'text.latex.preamble': r'\usepackage{newtxmath}',
    # Slightly opaque legend background.
    'legend.framealpha': .9,
    'pgf.texsystem': 'pdflatex',
    'pgf.preamble': r'\usepackage{newtxmath}',
    'axes.titlesize': 14,
})

# Default time-series figure size. The width is given in TeX points
# (72.27 pt per inch); 390 pt is presumably the target document's
# text width -- TODO confirm. Height = width * (sqrt(5) - 1) / 2 (~0.618).
width_pt = 390
inches_per_pt = 1.0 / 72.27
golden_ratio = (5**.5 - 1) / 2
fig_width = width_pt * inches_per_pt
fig_height = golden_ratio * fig_width


def _equilibrium_pair(q_own, q_other):
    """Return the two roots of ``x**2 - (1 - q_own + q_other) * x + q_other = 0``.

    ``simulate_and_plot`` draws these roots as reference equilibrium lines:
    ``(Q1, Q2)`` gives the W levels and ``(Q2, Q1)`` gives the Y levels.

    Returns:
        ``(plus_root, minus_root)``, or ``(nan, nan)`` when the discriminant
        is negative (no real equilibrium).
    """
    b = 1 - q_own + q_other
    disc = b**2 - 4 * q_other
    if disc >= 0:
        root = np.sqrt(disc)
        return 0.5 * (b + root), 0.5 * (b - root)
    return np.nan, np.nan


def _extinction_schedule(Time, dt, rate, severity, variance):
    """Return extinction events as ``(time, severity)`` tuples sorted by time.

    With ``variance > 0`` the gaps between events are drawn from
    ``np.random.normal(loc=rate, scale=variance)`` (so ``variance`` acts as a
    standard deviation) and clipped at 0; otherwise events are spaced every
    ``rate`` time units (none when ``rate <= 0``). An event at t = 0 is
    always included.
    """
    if variance > 0:
        events = []
        now = 0.0
        while now < Time:
            now += max(0, np.random.normal(loc=rate, scale=variance))
            if now < Time:
                events.append((now, severity))
    elif rate > 0:
        events = [(k, severity) for k in np.arange(rate, Time + dt, rate)]
    else:
        # np.arange with a zero step would raise; a non-positive rate means
        # "no periodic events".
        events = []
    events.append((0.0, severity))
    return sorted(events, key=lambda ev: ev[0])


def _threshold_crossings(series, t, threshold):
    """Return the times ``t[i]`` where ``series`` falls from >= threshold to < threshold."""
    series = np.asarray(series)
    dropped = (series[:-1] >= threshold) & (series[1:] < threshold)
    return t[1:][dropped]


def simulate_and_plot(V0, W0, Y0, X0, W_birth, Y_birth, W_death, Y_death,
                      X_in, X_out, Time, use_X, show_phase_only,
                      Extinction_severity, Extinction_rate, extinction_affects_W=False,
                      random_variance=0, Extinction_duration=0, Y_size=1, PVDead=0,
                      PXDead=0, PB1=0, PB2=0, Z0=0):
    """Integrate the V/W/Y/X model with forward Euler and plot the result.

    The state is stepped on the grid ``t = 0, 0.01, ..., Time``. Extinction
    events multiply the affected populations by ``1 - Extinction_severity``;
    for ``Extinction_duration`` time units after an event those populations
    are capped at their post-event values.

    Args:
        V0, W0, Y0: Initial values of V, W and Y.
        X0: Initial value of X on the *plotted* scale (see ``X_scaler``).
        W_birth, W_death: Birth and death rates used by both V and W.
        Y_birth, Y_death: Birth and death rates of Y.
        X_in, X_out: Transfer rates W -> X and X -> W.
        Time: Simulated time span.
        use_X: If true, the W <-> X exchange feeds back into W and X is plotted.
        show_phase_only: If true, draw the W-Y phase plot and the (Q1, Q2)
            constraint diagram instead of the time series.
        Extinction_severity: Fraction removed at each extinction event.
        Extinction_rate: (Mean) time between extinction events.
        extinction_affects_W: If true events hit W and V, otherwise Y.
        random_variance: Used as the standard deviation (``scale``) of the
            normally distributed gaps between events; 0 gives evenly spaced
            events.
        Extinction_duration: Length of the post-event cap period.
        Y_size, PVDead, PXDead, PB1, PB2, Z0: Unused; accepted so existing
            callers keep working.

    Side effects:
        Creates a matplotlib figure and stores it in the module-level
        ``last_fig``.

    NOTE(review): events that fall inside a cap period are not dropped; they
    are all applied (compounding) at the first step after the cap ends.
    Confirm this is intended.
    """
    global last_fig

    # X is integrated in raw units but initialised and plotted multiplied by
    # X_out / X_in -- the factor that makes the plotted X equal W whenever
    # dX = 0 (X_out * X == X_in * W). If either rate is zero that factor is 0
    # or undefined, and X0 / X_scaler would raise ZeroDivisionError (the X_out
    # slider reaches 0), so fall back to plotting X unscaled.
    if X_in != 0 and X_out != 0:
        X_scaler = X_out / X_in
    else:
        X_scaler = 1.0

    # Reference equilibria from Q1 = W_death / W_birth, Q2 = Y_death / Y_birth.
    Q1 = W_death / W_birth
    Q2 = Y_death / Y_birth
    W_equil1, W_equil2 = _equilibrium_pair(Q1, Q2)
    Y_equil1, Y_equil2 = _equilibrium_pair(Q2, Q1)

    dt = 0.01
    t = np.arange(0, Time + dt, dt)
    V = np.zeros_like(t)
    W = np.zeros_like(t)
    Y = np.zeros_like(t)
    X = np.zeros_like(t)
    # Z is turned off entirely: no state, no dynamics, no plot.
    V[0], W[0], Y[0], X[0] = V0, W0, Y0, X0 / X_scaler

    extinction_events = _extinction_schedule(
        Time, dt, Extinction_rate, Extinction_severity, random_variance)
    event_index = 0  # index of the next event not yet applied

    # Cap ("lock") period after an event: until lock_end_time the affected
    # populations may shrink but not grow above their post-event values.
    lock_end_time = -1.0
    locked_value_1 = None  # cap for W (or for Y when events hit Y)
    locked_value_2 = None  # cap for V (only used when events hit W)

    for i in range(1, len(t)):
        v, w, y, x = V[i - 1], W[i - 1], Y[i - 1], X[i - 1]

        # V and W growth share the factor (1 - W - V) and scale with Y;
        # Y growth scales with V + W.
        dV = W_birth * (1 - w - v) * v * y - W_death * v
        dW = W_birth * (1 - w - v) * w * y - W_death * w
        dY = Y_birth * (1 - y) * y * (v + w) - Y_death * y
        if use_X:
            dW += X_out * x - X_in * w
        dX = -X_out * x + X_in * w

        # Forward-Euler step.
        V[i] = v + dt * dV
        W[i] = w + dt * dW
        Y[i] = y + dt * dY
        X[i] = x + dt * dX

        current_sim_time = i * dt
        if current_sim_time < lock_end_time:
            if extinction_affects_W:
                if W[i] > locked_value_1:
                    W[i] = locked_value_1
                if V[i] > locked_value_2:
                    V[i] = locked_value_2
            elif Y[i] > locked_value_1:
                Y[i] = locked_value_1
        else:
            # Apply every event whose time has been reached. The t = 0 event
            # therefore fires on the first step (t = dt).
            while (event_index < len(extinction_events)
                   and current_sim_time >= extinction_events[event_index][0]):
                severity = extinction_events[event_index][1]
                if extinction_affects_W:
                    W[i] = W[i] * (1 - severity)
                    locked_value_1 = W[i]
                    V[i] = V[i] * (1 - severity)
                    locked_value_2 = V[i]
                else:
                    Y[i] = Y[i] * (1 - severity)
                    locked_value_1 = Y[i]
                lock_end_time = current_sim_time + Extinction_duration
                event_index += 1

    X_plot = X * X_scaler

    if not show_phase_only:
        plt.figure(figsize=(fig_width, fig_height))
        plt.plot(t, Y, label=r'$Y$', color='darkblue')
        if use_X:
            plt.plot(t, X_plot, label=f'{X_scaler:.1f}'r'$W^d$', color='lightgreen')
        plt.plot(t, V, label=r'$\widetilde{W}^a$', color='orange')
        plt.plot(t, W, label=r'$W^a$', color='darkgreen')

        # Horizontal reference lines for the equilibria that exist.
        reference_lines = [
            (W_equil1, 'darkgreen', '--', r'$W_{(eq,+)}^a$'),
            (W_equil2, 'darkgreen', '--', r'$W_{(eq,-)}^a$'),
            (Y_equil1, 'darkblue', ':', r'$Y_{(eq,+)}$'),
            (Y_equil2, 'darkblue', ':', r'$Y_{(eq,-)}$'),
        ]
        for level, color, style, label in reference_lines:
            if not np.isnan(level):
                plt.axhline(level, color=color, linestyle=style, label=label)

        # Vertical lines wherever a population drops below the extinction
        # threshold, in the colour of its curve.
        extinction_threshold = 1e-3
        tracked = [(V, 'orange', 'V'), (Y, 'darkblue', 'Y'), (W, 'darkgreen', 'W')]
        if use_X:
            tracked.append((X, 'lightgreen', 'X'))
        for series, color, name in tracked:
            crossings = _threshold_crossings(series, t, extinction_threshold)
            for k, ext_time in enumerate(crossings):
                # Label only the first line per species so the legend has no duplicates.
                extra = {'label': f'{name} extinction'} if k == 0 else {}
                plt.axvline(x=ext_time, color=color, linestyle=':', lw=1, **extra)

        plt.xlabel('Time')
        plt.ylabel('Population')
        plt.title(r' \textbf{Population Dynamics Over Time}')
        plt.ylim(0, 1)
        plt.legend(loc='upper left', bbox_to_anchor=(1.05, 1))
        plt.tight_layout()
    else:
        _, axs = plt.subplots(1, 2, figsize=(12, 5))

        # Left: W-Y phase trajectory with the equilibrium points.
        axs[0].plot(W, Y, label='W vs Y', color='purple')
        if not np.isnan(W_equil1) and not np.isnan(Y_equil1):
            axs[0].scatter(W_equil1, Y_equil1, color='black', marker='o', label='Equilibrium +')
        if not np.isnan(W_equil2) and not np.isnan(Y_equil2):
            axs[0].scatter(W_equil2, Y_equil2, color='gray', marker='x', label='Equilibrium -')
        axs[0].set_xlabel('W')
        axs[0].set_ylabel('Y')
        axs[0].set_title('Phase Plot')
        axs[0].set_xlim(0, 1)
        axs[0].set_ylim(0, 1)
        axs[0].grid(True)
        axs[0].legend(loc='best')

        # Right: region sqrt(Q1) + sqrt(Q2) <= 1 and the current (Q1, Q2).
        Q_vals = np.linspace(0, 1, 400)
        q1_grid, q2_grid = np.meshgrid(Q_vals, Q_vals)
        valid_region = (np.sqrt(q1_grid) + np.sqrt(q2_grid)) <= 1
        axs[1].contourf(Q_vals, Q_vals, valid_region, levels=[0.5, 1], colors=['#e0f7fa'])
        axs[1].plot(Q_vals, (1 - np.sqrt(Q_vals))**2, 'k--', label=r'$\sqrt{Q_1} + \sqrt{Q_2} = 1$')
        axs[1].scatter(Q1, Q2, color='red', label='Current (Q1, Q2)', zorder=5)
        qsum = np.sqrt(Q1) + np.sqrt(Q2)
        text = r"$\sqrt{Q_1} + \sqrt{Q_2} = $" + f"{qsum:.2f}"
        axs[1].text(Q1 + 0.02, Q2 + 0.02, text, fontsize=10, color='black')
        # With text.usetex enabled, a bare '_' outside math mode is a LaTeX
        # error, so the axis labels are written in math mode.
        axs[1].set_xlabel(r'$Q_1 = W_{\mathrm{death}} / W_{\mathrm{birth}}$')
        axs[1].set_ylabel(r'$Q_2 = Y_{\mathrm{death}} / Y_{\mathrm{birth}}$')
        axs[1].set_title('Constraint Region')
        axs[1].set_xlim(0, 1)
        axs[1].set_ylim(0, 1)
        axs[1].legend(loc='best')
        plt.tight_layout()

    last_fig = plt.gcf()

# One compact layout object shared by every control: 240px wide with tight
# margins (the negative left margin pulls each control leftwards).
slider_layout = widgets.Layout(width='240px', margin='2px 2px 2px -20px')


def _colored_float_slider(description, color, lo, hi, step, value):
    """Return a FloatSlider using the shared layout and a *color* handle."""
    return widgets.FloatSlider(min=lo, max=hi, step=step, value=value,
                               description=description,
                               style={'handle_color': color},
                               layout=slider_layout)


# --- Column 1: initial values of the state variables ---
W0_slider = _colored_float_slider("W0", 'darkgreen', 0, 1, 0.01, 0.2)
X0_slider = _colored_float_slider("X0", 'lightgreen', 0, 1, 0.01, 0.2)
Y0_slider = _colored_float_slider("Y0", 'darkblue', 0, 1, 0.01, 0.75)
V0_slider = _colored_float_slider("V0", 'orange', 0, 1, 0.01, 0.47)
col1 = widgets.VBox([W0_slider, X0_slider, Y0_slider, V0_slider])

# --- Column 2: birth and death rates of W and Y ---
W_birth_slider = _colored_float_slider("W_birth", 'darkgreen', 0.01, 2.0, 0.01, 0.4)
W_death_slider = _colored_float_slider("W_death", 'darkgreen', 0.0, 1.0, 0.01, 0.1)
Y_birth_slider = _colored_float_slider("Y_birth", 'darkblue', 0.01, 2.0, 0.01, 0.9)
Y_death_slider = _colored_float_slider("Y_death", 'darkblue', 0.0, 1.0, 0.01, 0.15)
col2 = widgets.VBox([W_birth_slider, W_death_slider, Y_birth_slider, Y_death_slider])

# --- Column 3: transfer rates W -> X (X_in) and X -> W (X_out) ---
X_in_slider = _colored_float_slider("X_in", 'lightgreen', 0.01, 2.0, 0.01, 0.2)
X_out_slider = _colored_float_slider("X_out", 'lightgreen', 0.0, 2.0, 0.01, 0.1)
col3 = widgets.VBox([X_in_slider, X_out_slider])

# --- Run length, toggles and the export button ---
Time_slider = widgets.IntSlider(
    min=10, max=50000, step=10, value=400,
    description="Time", layout=slider_layout,
)
phase_only_checkbox = widgets.Checkbox(
    value=False, description="Phase Plot Only", layout=slider_layout,
)
X_checkbox = widgets.Checkbox(
    value=True, description="Enable X", layout=slider_layout,
)
save_button = widgets.Button(description='Save PDF', layout=slider_layout)

def _on_save_clicked(b):
    """Save the most recently drawn figure as both PDF and PGF.

    Output goes to ``plots/`` when that directory exists, otherwise to the
    current working directory. The file stem is ``BirthDeathEvo_phase`` or
    ``BirthDeathEvo_time`` depending on the "Phase Plot Only" checkbox, so a
    save overwrites the previous file of the same mode.

    Args:
        b: The clicked button (required by ``Button.on_click``; unused).
    """
    import os

    if last_fig is None:
        # Nothing has been plotted yet.
        return

    mode = 'phase' if phase_only_checkbox.value else 'time'
    base = f'BirthDeathEvo_{mode}'
    outdir = 'plots' if os.path.isdir('plots') else '.'
    pdf_path = os.path.join(outdir, base + '.pdf')
    pgf_path = os.path.join(outdir, base + '.pgf')

    # Legends are saved in their on-screen positions. There is no temporary
    # repositioning, and therefore nothing to restore afterwards (the old
    # restore loop referenced an undefined `legend_states` and raised).
    last_fig.savefig(pdf_path, bbox_inches='tight')
    last_fig.savefig(pgf_path, bbox_inches='tight')

    print(f'Saved {pdf_path} and {pgf_path}')

save_button.on_click(_on_save_clicked)

# === Column 5: Extinction Parameters ===
ext_severity_slider = widgets.FloatSlider(min=0, max=1, step=0.01, value=0.5,
                                          description="Severity", layout=slider_layout)
ext_rate_slider = widgets.FloatSlider(min=1, max=50, step=1, value=25,
                                      description="Rate", layout=slider_layout)
# Bound to random_variance, which simulate_and_plot passes to
# np.random.normal as `scale`, i.e. the standard deviation of the gaps
# between extinction events (0 = evenly spaced events).
random_variance_slider = widgets.FloatSlider(min=0, max=10, step=0.1, value=0,
                                             description="VAR", layout=slider_layout)
ext_affects_checkbox = widgets.Checkbox(value=False, description="Affect W", layout=slider_layout)
ext_duration_slider = widgets.FloatSlider(min=0, max=50, step=0.5, value=0,
                                          description="Duration", layout=slider_layout)

# Column 4 includes the "Phase Plot Only" toggle so that the phase-plot mode
# bound below (and used to name saved files) can actually be selected in the UI.
col4 = widgets.VBox([Time_slider, phase_only_checkbox, ext_affects_checkbox,
                     X_checkbox, save_button])
col5 = widgets.VBox([ext_severity_slider, ext_rate_slider, random_variance_slider,
                     ext_duration_slider])
controls = widgets.HBox([col1, col2, col3, col4, col5])

# Re-run simulate_and_plot whenever any bound control changes.
out = widgets.interactive_output(simulate_and_plot, {
    'V0': V0_slider,
    'W0': W0_slider, 'Y0': Y0_slider, 'X0': X0_slider,
    'W_birth': W_birth_slider, 'Y_birth': Y_birth_slider,
    'W_death': W_death_slider, 'Y_death': Y_death_slider,
    'X_in': X_in_slider,
    'X_out': X_out_slider,
    'Time': Time_slider,
    'use_X': X_checkbox,
    'show_phase_only': phase_only_checkbox,
    'Extinction_severity': ext_severity_slider,
    'Extinction_rate': ext_rate_slider,
    'extinction_affects_W': ext_affects_checkbox,
    'random_variance': random_variance_slider,
    'Extinction_duration': ext_duration_slider,
})

display(controls, out)

HBox(children=(VBox(children=(FloatSlider(value=0.2, description='W0', layout=Layout(margin='2px 2px 2px -20px…

Output()