In [1]:
import numpy as np
import matplotlib.pyplot as plt

from scipy.integrate import solve_ivp
from scipy.optimize import root
from scipy.misc import derivative
from inspect import getfullargspec
from fractions import Fraction
from IPython.display import clear_output

import plotly.graph_objects as go
import dash_core_components as dcc
import dash_html_components as html

from dash.dependencies import Output, Input, State
from dash.exceptions import PreventUpdate
from dash import no_update, callback_context


from jupyter_dash import JupyterDash  # for local use
# Single module-level Dash app shared by every call to run(): run() assigns
# app.layout and registers its callbacks on this same object each time it is called.
# NOTE(review): calling run() a second time (e.g. for another example equation)
# registers callbacks for the same outputs again on this app — confirm the Dash
# version in use tolerates that, or create a fresh app per run().
app = JupyterDash(__name__)           # with jupyter

In [2]:
def get_array(start, end, step):
    """Inclusive, evenly spaced grid with an ``np.arange``-like signature.

    Built on ``np.linspace`` so the number of points is fixed up front by
    rounding ``(end - start) / step``, instead of depending on how floating
    point steps accumulate (the usual source of off-by-one shapes with
    ``np.arange``).

    Args:
        start (float): first value of the grid.
        end (float): last value of the grid (included).
        step (float): spacing between consecutive values.

    Returns:
        ndarray: ``round((end - start) / step) + 1`` values from ``start`` to ``end``.
    """
    n_intervals = int(np.rint((end - start) / step))

    # Write the step as an exact decimal fraction and scale by its denominator,
    # so linspace works on (near-)integer endpoints before scaling back down.
    scale = Fraction("{0:.16f}".format(step)).denominator
    if scale > 1e12:
        # A huge denominator means the step has no short decimal representation.
        print("WARNING: Ensure that step is not an irrational number")

    scaled_grid = np.linspace(start * scale, end * scale, n_intervals + 1)
    return scaled_grid / scale

In [3]:
def integrate(t_array, diff_eq, y_ini=0, method='LSODA'):
    """Integrate the ODE dy/dt = diff_eq over ``t_array`` starting from ``y_ini``.

    Args:
        t_array (ndarray): sorted times at which the solution is returned;
            ``t_array[0]`` is the initial time and ``t_array[-1]`` the final time.
        diff_eq (callable): right-hand side of the ODE, either ``diff_eq(y)``
            (autonomous) or ``diff_eq(t, y)`` (non-autonomous); the form is
            detected from the number of declared positional parameters.
        y_ini (float / ndarray): initial value(s); scalars are promoted to 1-D
            because ``solve_ivp`` requires an array ``y0``.
        method (str): ``scipy.integrate.solve_ivp`` method name:
            - 'LSODA': automatic stiff / non-stiff switching (default)
            - 'RK45', 'RK23', 'DOP853' for non-stiff problems
            - 'Radau', 'BDF' for stiff problems

    Returns:
        ndarray: solution sampled at ``t_array``, squeezed (1-D for a scalar ODE).

    Raises:
        TypeError: if ``diff_eq`` does not declare exactly 1 or 2 arguments.

    Warns:
        RuntimeWarning: if the solver reports failure; the returned array may
            then be shorter than ``t_array``.
    """
    nb_args_diff_eq = len(getfullargspec(diff_eq).args)

    if nb_args_diff_eq == 1:
        # solve_ivp always calls fun(t, y); drop t for autonomous equations.
        def function(t, y):
            return diff_eq(y)
    elif nb_args_diff_eq == 2:
        function = diff_eq
    else:
        raise TypeError(
            "diff_eq only accepts 1 argument (autonomous) or 2 arguments (non-autonomous)"
        )

    # Use the caller's time grid (the previous version read a global `t` here).
    with np.errstate(all='warn'):
        solution = solve_ivp(function, (t_array[0], t_array[-1]),
                             np.atleast_1d(y_ini),  # solve_ivp requires y0 to be an ndarray
                             method=method, t_eval=t_array)

    if not solution.success:
        # Surface solver failures instead of silently returning a truncated trajectory.
        import warnings
        warnings.warn(f"integrate ({method}): {solution.message}", RuntimeWarning)

    return np.squeeze(solution.y)

In [4]:
def _is_known_root(x_star, *root_lists, tol=1e-6):
    """Return True if ``x_star`` lies within ``tol`` of a root in any of ``root_lists``."""
    return any(abs(x_star - known) <= tol for roots in root_lists for known in roots)


def _classify_fixed_point(f, x_star, h=1e-4):
    """Classify the fixed point ``x_star`` of the 1-D flow dx/dt = f(x).

    Compares the flow direction just left and right of ``x_star``:
    - flow towards the point from both sides (+ then -) -> 'stable'
    - flow away from the point on both sides (- then +) -> 'unstable'
    - same direction on both sides                      -> 'halfstable'
    Unlike the sign of a numerical derivative, this also classifies
    degenerate points (f'(x*) = 0) correctly.

    Returns:
        'stable', 'unstable', 'halfstable', or None if undecidable
        (f is zero or NaN right next to ``x_star``).
    """
    left = np.sign(f(x_star - h))
    right = np.sign(f(x_star + h))
    if left > 0 and right < 0:
        return "stable"
    if left < 0 and right > 0:
        return "unstable"
    if left == right and left != 0:
        return "halfstable"
    return None


def run(t, diff_eq, x, x_ini_slider):
    """Build and serve the phase-portrait app for dx/dt = diff_eq(x) on the global ``app``.

    The page shows the phase portrait (click near a zero crossing to locate and
    classify a fixed point), a slider for the initial value, a solver dropdown,
    and the trajectory x(t) obtained with ``integrate``.

    Args:
        t (ndarray): time grid for the trajectory plot.
        diff_eq (callable): autonomous right-hand side ``diff_eq(x)``; it is
            called with both arrays and scalars.
        x (ndarray): x-values at which the phase-portrait curve is drawn.
        x_ini_slider (list of (float, str or None)): slider positions (initial
            values) with optional mark labels; assumed evenly spaced, and the
            middle entry is the default value.
    """
    default_x_ini = x_ini_slider[len(x_ini_slider) // 2][0]

    # Set up Phase Portrait graph
    dx_dt = diff_eq(x)
    y_equals_0 = go.Scatter(
        x=x, y=np.zeros(len(x)), mode="lines", name="y = 0", line=dict(color="black")
    )
    differential = go.Scatter(
        x=x, y=dx_dt, mode="lines", name="differential", line=dict(color="blue")
    )
    initial_val = go.Scatter(
        x=[default_x_ini],
        y=[diff_eq(default_x_ini)],
        mode="markers",
        name="initial value",
        marker=dict(size=14, color="blue"),
    )
    # Fixed-point traces start empty and are filled by update_portrait; their
    # names must match the "stored_<name>_fixed_pts" store ids.
    stable_points = go.Scatter(
        x=[],
        y=[],
        mode="markers",
        name="stable",
        marker=dict(size=12, color="black", line={"width": 2}),
    )
    unstable_points = go.Scatter(
        x=[],
        y=[],
        mode="markers",
        name="unstable",
        marker=dict(size=12, color="white", line={"width": 2}),
    )
    halfstable_points = go.Scatter(
        x=[],
        y=[],
        mode="markers",
        name="halfstable",
        marker=dict(size=12, symbol='asterisk', line={"width": 2, 'color': 'red'}),
    )
    data = [y_equals_0, differential, initial_val, stable_points, unstable_points, halfstable_points]
    layout = go.Layout(title="Phase Portrait", hovermode="closest", template="plotly_white",
                       xaxis_title='x', yaxis_title='dx/dt')
    phase_portrait = dcc.Graph(id="phase_portrait", figure={"data": data, "layout": layout})

    # Set up slider (for initial value)
    slider_min = x_ini_slider[0][0]
    slider_max = x_ini_slider[-1][0]
    myslider = dcc.Slider(
        id="myslider",
        min=slider_min,
        max=slider_max,
        step=(slider_max - slider_min) / (len(x_ini_slider) - 1),
        value=default_x_ini,
        updatemode="drag",
        marks={value: {'label': label} for value, label in x_ini_slider if label},
    )

    # Set up dropdown (for solver options)
    methods = [{"label": item, "value": item}
               for item in ["LSODA", "RK45", "RK23", "DOP853", "Radau", "BDF"]]
    solver_options = dcc.Dropdown(id="solver", options=methods, value="LSODA")

    # Set up Dynamics graph
    mygraph = dcc.Graph(id="graph")

    # Set up app layout
    app.layout = html.Div(
        children=[
            html.Div(
                children=[
                    dcc.Store(id="stored_stable_fixed_pts", data=[]),
                    html.Div(
                        id="printed_stable_fixed_pts", style={"display": "inline-block"}
                    ),
                    dcc.Store(id="stored_unstable_fixed_pts", data=[]),
                    html.Div(
                        id="printed_unstable_fixed_pts", style={"display": "inline-block"}
                    ),
                    dcc.Store(id="stored_halfstable_fixed_pts", data=[]),
                    html.Div(
                        id="printed_halfstable_fixed_pts", style={"display": "inline-block"}
                    ),
                ],
                id="fixed_points",
                style={"width": "70%", "margin": "0 auto"},
            ),
            html.Div(phase_portrait, style={"width": "70%", "margin": "0 auto"}),
            html.Div(myslider, style={"width": "65%", "display": "inline-block"}),
            html.Div(
                solver_options,
                style={"width": "20%", "display": "inline-block", "float": "right"},
            ),
            html.Div(mygraph, style={"width": "70%", "margin": "0 auto"}),
        ]
    )

    # The stores are read as State (not Input): this callback is their only
    # writer, so listing them as Inputs would make its outputs its own triggers.
    @app.callback(
        [
            Output("stored_stable_fixed_pts", "data"),
            Output("printed_stable_fixed_pts", "children"),
            Output("stored_unstable_fixed_pts", "data"),
            Output("printed_unstable_fixed_pts", "children"),
            Output("stored_halfstable_fixed_pts", "data"),
            Output("printed_halfstable_fixed_pts", "children"),
        ],
        [Input("phase_portrait", "clickData")],
        [
            State("stored_stable_fixed_pts", "data"),
            State("stored_unstable_fixed_pts", "data"),
            State("stored_halfstable_fixed_pts", "data"),
        ],
        prevent_initial_call=True,
    )
    def update_roots(clickData, stable_roots, unstable_roots, halfstable_roots):
        """Find the fixed point nearest the click, classify it, and append it to its store."""
        if not clickData:
            raise PreventUpdate
        stable_roots = stable_roots or []
        unstable_roots = unstable_roots or []
        halfstable_roots = halfstable_roots or []

        click_x = clickData["points"][0]["x"]
        solution = root(diff_eq, click_x)
        if not solution.success:
            raise PreventUpdate
        # Plain float: JSON-friendly for the store and prints cleanly.
        x_star = float(solution.x[0])

        # Roots reached from different clicks differ in the last digits, so
        # compare with a tolerance rather than exact equality.
        if _is_known_root(x_star, stable_roots, unstable_roots, halfstable_roots):
            raise PreventUpdate

        kind = _classify_fixed_point(diff_eq, x_star)
        if kind == "stable":
            stable_roots = stable_roots + [x_star]
            return (stable_roots, str(stable_roots), no_update, no_update, no_update, no_update)
        if kind == "unstable":
            unstable_roots = unstable_roots + [x_star]
            return (no_update, no_update, unstable_roots, str(unstable_roots), no_update, no_update)
        if kind == "halfstable":
            halfstable_roots = halfstable_roots + [x_star]
            return (no_update, no_update, no_update, no_update, halfstable_roots, str(halfstable_roots))
        raise PreventUpdate

    @app.callback(
        Output("phase_portrait", "figure"),
        [
            Input("stored_stable_fixed_pts", "data"),
            Input("stored_unstable_fixed_pts", "data"),
            Input("stored_halfstable_fixed_pts", "data"),
            Input("myslider", "value"),
        ],
        [State("phase_portrait", "figure")],
        prevent_initial_call=True,
    )
    def update_portrait(stable_roots, unstable_roots, halfstable_roots, x_ini, fig):
        """Redraw either the initial-value marker or one fixed-point trace, depending on the trigger."""
        new_portrait = go.Figure(fig)
        roots = dict(stable=stable_roots, unstable=unstable_roots, halfstable=halfstable_roots)
        # id of the component whose property triggered this call
        input_source = callback_context.triggered[0]["prop_id"].split(".")[0]
        if input_source == 'myslider':
            new_portrait.update_traces(
                x=[x_ini], y=[diff_eq(x_ini)], selector=dict(name='initial value')
            )
        else:
            # "stored_<name>_fixed_pts" -> <name> (stable / unstable / halfstable),
            # which is also the name of the matching marker trace.
            name = input_source.split("_")[1]
            new_portrait.update_traces(
                x=roots[name], y=np.zeros(len(roots[name])), selector=dict(name=name)
            )
        return new_portrait

    @app.callback(
        Output("graph", "figure"),
        [
            Input("myslider", "value"),
            Input("solver", "value"),
        ],
    )
    def update_graph(x_ini, method):
        """Integrate from the slider's initial value with the selected solver and plot x(t)."""
        if not method:
            # The dropdown is clearable; solve_ivp cannot run without a method.
            raise PreventUpdate
        x_t = integrate(t_array=t, diff_eq=diff_eq, y_ini=x_ini, method=method)
        trace_x = go.Scatter(x=t, y=x_t, mode="lines", name="position")
        layout = go.Layout(title="Evolution over time", xaxis_title='time')
        return {"data": [trace_x], "layout": layout}

    app.run_server(debug=True, mode='external')   # for local use with jupyter
    # Outside Jupyter: app.run_server(debug=True, port=3050)

<hr>

## 2. Flows on the Line

### 2.1 A geometric way of thinking


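Consider the following nonlinear differential equation:
$$\dot{x} = \sin(x)$$
Instead of solving it analytically, we read the dynamics off the graph of $\dot{x}$ against $x$ (the phase portrait). The flow moves to the right wherever $\dot{x} > 0$ and to the left wherever $\dot{x} < 0$. The fixed points are the zeros of $\dot{x}$, here $x^* = k\pi$. In the app below, click near a zero crossing to locate and classify a fixed point. Drag the slider to choose the initial value $x_0$; its trajectory $x(t)$ is plotted underneath.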

In [5]:
# Parameters
t = get_array(start=0, end=15, step=0.01)    # time grid for the trajectory plot
x = get_array(start=-7, end=7, step=0.1)     # x-grid for the phase portrait

# Slider entries (initial value, mark label) for multiples 0, 0.05, ..., 2.25 of pi.
# Integer multiples are labelled "k<pi>", other quarter multiples "0.25<pi>" etc.,
# and everything else gets no mark.
x_ini_slider = [
    (
        m * np.pi,
        str(int(m)) + "\U0001D70B" if m.is_integer()
        else str(m) + "\U0001D70B" if np.round(m / 0.25, decimals=4).is_integer()
        else None,
    )
    for m in get_array(0, 2.25, 0.05)
]

# Differential equation
def diff_eq(x):
    """Right-hand side of the flow dx/dt = sin(x); applies elementwise to arrays."""
    return np.sin(x)


# Run app: serve the phase portrait of diff_eq sampled on x, with the
# trajectory computed on the time grid t for the initial value chosen on the slider.
run(t, diff_eq, x, x_ini_slider)

Dash app running on http://127.0.0.1:8050/
