# Fitting a family of exponentials

This app fits several different exponential functions to a dataset.


## Data format

Your data must be a two-column CSV file with no header row. Column one must contain the time, and column two the value measured at that time point. Each time point must have a corresponding value. Data should be ordered from the earliest time point to the latest.

## Statistics

$R^2$ is the coefficient of determination, which indicates how much of the variance in your data the model (fit) is able to capture.

Adj. $R^2$ is $R^2$ corrected for the number of parameters, which is useful when deciding which model explains your data best.

In [14]:
from pathlib import Path
import io, base64, re
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import lmfit
import ipywidgets as w
from traitlets import HasTraits, Unicode
from IPython.display import display, clear_output, Markdown, Javascript


# Font handling for exported PDF and SVG files;
# 'none' makes SVG text labels properly editable
plt.rcParams.update({
    'pdf.fonttype': 42,
    'svg.fonttype': 'none',
})


# Define models that you want to fit and provide their representation in LaTeX.
# Each entry maps a display name to a (function, LaTeX string) pair:
#   * the function is wrapped in lmfit.Model by plot_fits; `t` is supplied at
#     fit time and the remaining argument names are the fit parameters, so
#     they must match the keys of the initial guesses in plot_fits
#   * the LaTeX string (without the surrounding $) is rendered next to the
#     model's checkbox
models = {
    # Decay from y_max towards zero with rate k
    'single exp': (
        lambda t, k, y_max: y_max * np.exp(-k*t),
        'y_{max} e^{-kt}'
    ),
    # Decay from y_max towards a constant floor c with rate k
    'exp with floor': (
        lambda t, k, c, y_max: c + (y_max - c) * np.exp(-k*t),
        'c + (y_{max} - c) e^{-kt}'
    ),
    # Sum of two decays with rates k_1 and k_2, mixed with weights w and 1 - w
    'double exp': (
        lambda t, k_1, k_2, w, y_max: y_max * (w * np.exp(-k_1*t) + (1 - w) * np.exp(-k_2*t)),
        'y_{max} [w e^{-k_1 t} + (1 - w) e^{-k_2 t}]'
    )
}

# Build the (initially empty) figure once and render it into an Output
# widget so that it can be embedded in the app layout
output = w.Output()
with output:
    fig = plt.figure()
    ax = fig.subplots()
    # The figure must be closed, otherwise it sticks around below the widget
    plt.close(fig)
    display(fig)


def _fit_grid(t, num=2000):
    """Return the time points at which fitted curves are evaluated for plotting.

    The grid is log-spaced between the smallest positive and the largest
    time in ``t``. If the data also contain an earlier (zero or negative)
    time point, it is prepended so that the curve starts where the data do.
    Falls back to a linear grid when there is no positive time point.
    """
    t_arr = np.asarray(t, dtype=float)
    t_arr = t_arr[np.isfinite(t_arr)]
    positive = t_arr[t_arr > 0]
    if positive.size == 0:
        return np.linspace(t_arr.min(), t_arr.max(), num)

    grid = np.logspace(np.log10(positive.min()), np.log10(positive.max()), num)
    if t_arr.min() < positive.min():
        grid = np.concatenate(([t_arr.min()], grid))
    return grid


def plot_fits(t, y, title='', scale='log'):
    """Fit every model in ``models`` to the data and draw data and fits on ``ax``.

    Parameters
    ----------
    t, y : array-like
        Time points and the values measured at them. ``y[0]`` is used as the
        fixed ``y_max`` of every model and ``y[-1]`` as the initial guess for
        the floor ``c``.
    title : str
        Title of the axes.
    scale : str
        x-axis scale, passed on to ``_set_scale``.

    Returns
    -------
    plots : dict
        Maps each model name to a ``(line, res, r2_adj)`` tuple: the plotted
        line, the lmfit fit result and the adjusted R^2.
    t
        The time points exactly as they were passed in.
    """
    y_max = dict(value=y[0], vary=False)
    k0 = 1 / max(t)

    # Initial guesses
    params_dict = {
        'single exp': dict(
            k=k0,
            y_max=y_max
        ),
        'exp with floor': dict(
            k=k0,
            c=y[-1],
            y_max=y_max
        ),
        'double exp': dict(
            k_1=k0,
            k_2=5 * k0,
            w=.5,
            y_max=y_max
        ),
    }

    plots = {}
    ax.clear()  # Important when reloading the data file
    ax.set_title(title)

    # Plot data points
    ax.scatter(t, y, s=40, color='black', label='data')

    # Points for plotting the fits. Our data is inherently logarithmic,
    # thus we produce them on a log scale; the grid follows the range of the
    # data and is the same for every model.
    tt = _fit_grid(t)

    # Fit each model
    for name, (f, _mlabel) in models.items():
        m = lmfit.Model(f)

        # Unpack initial guesses and pass into the model
        params = m.make_params()
        for key, guess in params_dict[name].items():
            if isinstance(guess, dict):
                params[key].set(**guess)
            else:
                params[key].set(value=guess)

        # Fit
        res = m.fit(y, params, t=t,
                    method='least_squares',
                    nan_policy='omit',
                    max_nfev=20000)

        y_fit = res.eval(t=tt)

        # Adjust R^2 for the number of parameters in the model.
        # Only parameters that were actually varied count (y_max is fixed),
        # and only the data points that entered the fit (NaNs are omitted).
        n = res.ndata
        p = res.nvarys
        dof = n - p - 1
        if dof > 0:
            r2_adj = 1 - (1 - res.rsquared) * (n - 1) / dof
        else:
            # Too few points for the correction to be defined
            r2_adj = float('nan')

        # Plot the fit
        line, = ax.plot(
            tt,
            y_fit,
            linestyle='--',
            label=rf'{name}: $\bar{{R}}^2={r2_adj:.2f}$'
        )
        plots[name] = (line, res, r2_adj)

    ax.set_xlabel('Time, sec')
    ax.set_ylabel('Substrate fraction')

    # Honour the requested scale instead of always forcing 'log'
    _set_scale(scale, t)
    ax.set_ylim([-.05, 1.05])
    ax.legend()

    return plots, t

def _set_scale(scale, t):
    """Switch the x-axis of ``ax`` to ``scale`` and rescale the view.

    For ``'log'`` a symmetric-log axis is used whose linear region ends at
    the first non-zero entry of ``t``, which keeps a zero time point
    visible. Any other value is handed to ``ax.set_xscale`` unchanged.
    """
    if scale != 'log':
        ax.set_xscale(scale)
    else:
        first_nonzero = np.nonzero(t)[0][0]
        linthresh = t[first_nonzero]
        # symlog rather than log so that the 0 time point can be shown too
        ax.set_xscale('symlog', linthresh=linthresh, linscale=1)
        minor_ticks = matplotlib.ticker.SymmetricalLogLocator(
            base=10, linthresh=linthresh, subs=np.arange(2, 10)
        )
        ax.xaxis.set_minor_locator(minor_ticks)

    # Recompute the data limits for the new scale
    ax.relim()
    ax.autoscale_view()

def _format_float(a):
    """Format floats as 1.7 x 10^3 if needed"""
    mantissa, exp = f'{a:.1e}'.split('e')
    exp = int(exp)
    if exp < -2 or exp > 2:
        output = f'{mantissa}\\times10^{{{exp}}}'
    else:
        output = f'{a:.2g}'
    return output

def _redraw():
    """Re-render ``fig`` inside the ``output`` widget."""
    canvas = fig.canvas
    canvas.draw_idle()
    with output:
        # wait=True postpones clearing until the new figure is displayed
        clear_output(wait=True)
        display(fig)


class AppState(HasTraits):
    """Tracks changes in the data file name.

    Plot writes ``filestem`` when a file is uploaded; Save observes it to
    name the downloaded plot.
    """
    # Stem of the uploaded file's name; may be None
    filestem = Unicode(allow_none=True).tag(sync=True)


class Plot:
    """App panel: the figure plus upload, scale toggle and model checkboxes."""

    def __init__(self, state, save_btn):
        """Build the widgets and wire up their callbacks.

        state: object whose ``filestem`` is set when a file is uploaded.
        save_btn: widget placed next to the upload button.
        """
        self.state = state

        # File upload button
        self.picker = w.FileUpload(accept='.csv', multiple=False)
        self.picker.observe(self._on_pick, 'value')

        # A toggle between linear and log scale
        self.scale = w.ToggleButtons(options=['linear','log'], value='log', description='Scale:')
        self.scale.observe(self._on_scale, names='value')

        # Checkboxes for which model (fit) to display.
        # self.cbs maps model name -> HBox([checkbox, output with the label])
        self.cbs = {}
        cb_box = []
        for name, (_f, mlabel) in models.items():
            cb = w.Checkbox(
                value=True,
                indent=False,
                description=name,
            )
            # We want some math reported and Markdown format appears to work
            # most reliably
            out = w.Output()
            with out:
                display(Markdown(f'${mlabel}$'))

            ui = w.HBox([cb, out], layout=w.Layout(display='inline-block', margin='0', padding='0'))
            cb_box.append(ui)
            cb.observe(self._on_cb, names='value')
            self.cbs[name] = ui

        self.ui = w.HBox([
            output,
            w.VBox([w.HBox([self.picker, save_btn]),
                    self.scale,
                    w.VBox(cb_box)
            ])
        ])

        # Model name -> (line, fit result, adjusted R^2); filled on upload
        self.plots = {}
        # Time points of the loaded data; None until a file is uploaded
        self.t = None

    def _sync_visibility(self):
        """Show or hide each fitted line according to its checkbox."""
        for name, (line, _res, _r2_adj) in self.plots.items():
            line.set_visible(self.cbs[name].children[0].value)

    def _on_pick(self, change):
        """Load data file and get fits."""
        files = change['new']
        if not files:
            return

        up = files[0]
        self.state.filestem = Path(up.name).stem
        data = np.loadtxt(io.BytesIO(up.content), delimiter=',', encoding='utf-8-sig')
        self.plots, self.t = plot_fits(
            data[:, 0], data[:, 1],
            title=self.state.filestem,
            scale=self.scale.value
            )
        for name, (_line, res, r2_adj) in self.plots.items():
            self._update_cb_label(name, res, r2_adj)
        # Freshly drawn lines are all visible; re-apply the checkbox state
        self._sync_visibility()
        _redraw()

    def _on_scale(self, change):
        """Toggle between linear and log scale."""
        if self.t is None:
            # No data yet: the log scale needs the time points, and the
            # selected scale is passed to plot_fits once a file is loaded
            return
        with output:
            _set_scale(change['new'], self.t)
            _redraw()

    def _on_cb(self, change):
        """Show/hide a fit."""
        if change['name'] == 'value':
            # self.plots holds (line, res, r2_adj) tuples, not bare lines
            self._sync_visibility()
            ax.legend()
            _redraw()

    def _update_cb_label(self, name, res, r2_adj):
        """Report fits and parameters next to a checkbox.

        name: model name (key of ``models`` and ``self.cbs``).
        res: fit result providing ``rsquared`` and ``params``.
        r2_adj: adjusted R^2 of the fit.
        """
        out = self.cbs[name].children[1]
        out.clear_output()
        mlabel = models[name][1]

        with out:
            label = [
                f'Adj. $R^2 = {r2_adj:.3f}$',
                f'$R^2 = {res.rsquared:.3f}$'
            ]

            # Compute half-life (only for models with a single rate k)
            if 'k' in res.params:
                half_life = np.log(2) / res.params['k'].value
                s = _format_float(half_life)
                md = f'$t_{{1/2}}={s}$ s'
                label.append(md)

            # Report fit parameters in a nice way
            params = []
            for pname in res.params:
                if '_' in pname:
                    # k_1 -> k_{1}, y_max -> y_{max}
                    base, underscript = pname.split('_', maxsplit=1)
                    key = f'{base}_{{{underscript}}}'
                else:
                    key = pname
                s = _format_float(res.params[pname].value)
                params.append(f'${key}={s}$')

            display(Markdown('\n\n'.join([
                f'${mlabel}$',
                '; '.join(label),
                '; '.join(params)
                ])))


class Save:
    """'Download plot' button that exports the figure as an SVG file."""

    def __init__(self, state):
        """Create the button and link, and follow ``state.filestem`` changes."""
        self.state = state
        # Base name (without extension) of the downloaded file
        self.fname = 'plot'
        self.btn = w.Button(description='Download plot')
        self.link = w.HTML(value='')
        self.btn.on_click(self._on_save)
        self.ui = w.HBox([self.btn, self.link])
        self.state.observe(self._on_filestem, 'filestem')

    def _on_filestem(self, change):
        """Name the download after the data file, spaces replaced by '_'."""
        filestem = change['new']
        # Fall back to the default so that the file is never named just '.svg'
        self.fname = filestem.replace(' ', '_') if filestem else 'plot'

    def _on_save(self, _):
        """Allows downloading the plot."""
        # Function-scope imports: only needed for escaping the file name
        import html
        import json

        # Some clean up for SVG: hide the background patches while saving,
        # then restore them so the figure shown in the app is unaffected
        patches = [fig.patch, *(a.patch for a in fig.axes)]
        was_visible = [p.get_visible() for p in patches]
        for p in patches:
            p.set_visible(False)

        # Write image to a buffer
        buf = io.BytesIO()
        try:
            fig.savefig(buf, format='svg', transparent=True, bbox_inches='tight', pad_inches=0)
        finally:
            for p, visible in zip(patches, was_visible):
                p.set_visible(visible)
        b64 = base64.b64encode(buf.getvalue()).decode()

        name = f'{self.fname}.svg'
        href = f'data:image/svg+xml;base64,{b64}'

        # Always show a fallback link (works in Voilà).
        # The name comes from the uploaded file, so escape it for HTML.
        self.link.value = (
            f'<a download="{html.escape(name, quote=True)}" href="{href}">'
            f'Download {html.escape(name)}</a>'
        )

        # Try JS auto-click (ignored/blocked in Voilà).
        # json.dumps produces a safely quoted JS string literal.
        js = (
            "var a=document.createElement('a');"
            f"a.href='{href}';"
            f"a.download={json.dumps(name)};"
            "document.body.appendChild(a);a.click();a.remove();"
        )
        display(Javascript(js))


# Wire the app together: Save and Plot share one AppState so that the
# download is named after the uploaded file
state = AppState()
save = Save(state)
# Embed save.ui (button + link) rather than only save.btn, otherwise the
# fallback download link that Save fills in is never shown
plot = Plot(state, save.ui)
display(plot.ui)

HBox(children=(Output(), VBox(children=(HBox(children=(FileUpload(value=(), accept='.csv', description='Upload…

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>