# Preliminary sensitivity analysis

In [16]:
from scipy.integrate import solve_ivp
import numpy as np
from ipywidgets import widgets, interact
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from matplotlib import cm
from matplotlib.colors import rgb2hex

# Shared presentation settings passed to every widget below, so that the
# description labels and the controls all have the same width.
style = dict(description_width='250px')
layout = widgets.Layout(width='600px')

@interact
def fragmentmnp(K=widgets.IntSlider(min=1, max=20, value=7, description="Number of size classes", style=style, layout=layout),
                K_range=widgets.IntRangeSlider(min=-15, max=-2, value=[-9, -3], description="Size class diameter range (10^value)", style=style, layout=layout),
                T=widgets.IntSlider(min=1, max=1000, value=100, description="Number of timesteps", style=style, layout=layout),
                bar_ts=widgets.IntText(value=99, description="Show bar chart for time step...", style=style, layout=layout),
                _k_frag=widgets.BoundedFloatText(0.01, min=0.0, max=1.0, step=0.002, description="Average fragmentation rate", style=style, layout=layout),
                theta1=widgets.BoundedFloatText(0.0, min=0.0, max=0.5, step=0.01, description="Empirical constant theta1", style=style, layout=layout),
                _n_0=widgets.FloatText(42.0, description="Initial particle number concentration", style=style, layout=layout),
                loss=widgets.Checkbox(value=True, description="Allow fragmentation from smallest size class?", style=style, layout=layout)):
    """Solve a size-class fragmentation model and plot the result.

    The particle number concentration ``n`` of each size class evolves as
    ``dn_k/dt = -k_frag_k * n_k + sum_i(fsd[i, k] * k_frag_i * n_i)``, which
    is integrated with ``scipy.integrate.solve_ivp`` and shown as a time
    series (left) and as a bar chart for a single time step (right). The
    figure is displayed and also written to ``output/theta1_<theta1>.png``.

    Parameters
    ----------
    K : int
        Number of size classes.
    K_range : sequence of two ints
        Exponents (base 10) bounding the size class diameters; the classes
        are log-spaced from ``10**K_range[1]`` (index 0) to
        ``10**K_range[0]`` (index ``K - 1``).
    T : int
        End of the integration interval; the solution is evaluated at the
        integer times ``0 .. T - 1``.
    bar_ts : int
        Index of the evaluated time step shown in the bar chart. The bar
        chart is omitted if the index is outside the evaluated time steps.
    _k_frag : float
        Fragmentation rate assigned to the median diameter.
    theta1 : float
        Exponent constant; ``k_frag`` scales as ``d ** (2 * theta1)``.
    _n_0 : float
        Initial particle number concentration, used for every size class.
    loss : bool
        If false, ``k_frag`` of the last size class (index ``K - 1``) is set
        to zero so that nothing fragments out of the modelled classes.
    """
    # Local import: only needed to make sure the output directory exists
    import os

    # Set up a particle size distribution and get the median
    d = np.logspace(K_range[1], K_range[0], K)
    d_median = np.median(d)

    # Set k_frag for each size class, based on the average k_frag for d_median
    # and theta1. First get the proportionality constant, then scale it by
    # each class's diameter.
    k_prop = _k_frag / (d_median ** (2 * theta1))
    k_frag = k_prop * d ** (2 * theta1)

    # If fragmentation from the last size class isn't allowed, then set k_frag
    # for that class to zero
    if not loss:
        k_frag[K - 1] = 0.0
    n_0 = np.full(K, _n_0)

    # Fragment size distribution matrix - assume a fragmentation event results
    # in an even split between the size classes of the daughter particles.
    # Only entries above the diagonal are filled, so nothing is distributed to
    # the fragmenting class itself or to classes with a lower index; the last
    # class has no classes after it, so its row stays zero.
    fsd = np.zeros((K, K))
    for k in range(K - 1):
        fsd[k, k + 1:] = 1 / (K - k - 1)

    # Define the function that satisfies n'(t) = f(t, n),
    # i.e. the RHS of our differential eq
    def rhs(t, n):
        # Rate at which particles fragment out of each size class
        frag_flux = k_frag * n
        # Loss from each class plus the gain it receives from the other
        # classes, as distributed by fsd
        return -frag_flux + fsd.T @ frag_flux

    # Numerically solve this given the initial values for n, over T time steps
    soln = solve_ivp(fun=rhs,
                     t_span=(0, T),
                     y0=n_0,
                     t_eval=np.arange(0, T))

    # If k_frag != 0 for the last size class, then there will be a loss to the
    # system, so keep track of that here
    n_loss = np.sum(n_0) - np.sum(soln.y, axis=0)

    # Finally, create the graphs!

    # Define the colour map to use: K evenly spaced samples of viridis.
    # (cm.get_cmap was removed in Matplotlib 3.9; sampling the colormap
    # directly gives the same K colours.)
    viridis = [rgb2hex(rgba) for rgba in cm.viridis(np.linspace(0, 1, K))]

    # Set up the subplots
    fig = make_subplots(rows=1, cols=2, shared_yaxes=True, horizontal_spacing=0.03,
                        subplot_titles=(f'k_frag: {np.array2string(np.flip(k_frag), precision=3)}<br>theta1: {theta1}',
                                        f'Time step: {bar_ts}'))
    # Update the size of the subplot titles
    fig.update_annotations(font_size=14)

    # Line chart timeseries first
    for i in range(K):
        fig.add_trace(go.Scatter(x=soln.t, y=soln.y[i], name=f'{d[i]:.2e} m',
                                 line_color=viridis[i]), row=1, col=1)
    if loss:
        # Plot the loss (if there is any) with a different style
        fig.add_trace(go.Scatter(x=soln.t, y=n_loss, name='Loss',
                                 line={'width': 3, 'dash': 'dash', 'color': 'lightblue'}), row=1, col=1)
    fig.update_xaxes(title='Time', col=1, row=1)
    fig.update_yaxes(title='Particle number concentration', col=1, row=1)

    # Now the bar chart. Only draw it when bar_ts is a valid, non-negative
    # index into the evaluated time steps: a negative value would otherwise
    # silently wrap around (or raise) instead of matching the subplot title.
    if 0 <= bar_ts < soln.t.size:
        fig.add_trace(go.Bar(x=[f'{sc:.2e} m' for sc in d], y=soln.y[:, bar_ts],
                             marker_color=viridis, showlegend=False), row=1, col=2)
    fig.update_xaxes(title='Size class diameter', autorange='reversed', col=2, row=1)

    # Update the layout and show
    fig.update_layout(width=1300)
    fig.show()
    # Make sure the target directory exists before exporting the image
    os.makedirs('output', exist_ok=True)
    fig.write_image(f'output/theta1_{theta1}.png')

interactive(children=(IntSlider(value=7, description='Number of size classes', layout=Layout(width='600px'), m…