In [2]:
import datetime as dt

import altair as alt
import numpy as np
import numpy.typing as npt
import polars as pl
import scipy.optimize
import scipy.stats

# Data

Look only at the national-level data for a single season (July 2009 through April 2010).


In [3]:
# Season window; polars `is_between` is inclusive on both ends by default
SEASON_START = dt.date(2009, 7, 1)
SEASON_END = dt.date(2010, 4, 1)

data = (
    pl.read_parquet("../data/raw.parquet")
    .filter(
        pl.col("geography_type") == "nation",
        pl.col("time_end").is_between(SEASON_START, SEASON_END),
    )
    .sort("time_end")
    .with_columns(
        # days since the first observation in the window, as float for scipy
        t=(pl.col("time_end") - pl.col("time_end").min())
        .dt.total_days()
        .cast(pl.Float64)
    )
)

# Expect exactly one national estimate per date; duplicates would be
# double-counted by the sum-of-squares fit below.
assert data.height > 0, "no national rows in the season window"
assert data["time_end"].n_unique() == data.height, "duplicate national dates"

data

geography_type,geography,time_end,estimate,lci,uci,sample_size,t
str,str,date,f64,f64,f64,u32,f64
"""nation""","""nation""",2009-08-01,0.007,0.006,0.008,361485,0.0
"""nation""","""nation""",2009-09-01,0.076,0.074,0.078,361485,31.0
"""nation""","""nation""",2009-10-01,0.244,0.241,0.247,361485,61.0
"""nation""","""nation""",2009-11-01,0.335,0.332,0.338,361485,92.0
"""nation""","""nation""",2009-12-01,0.366,0.363,0.369,361485,122.0
"""nation""","""nation""",2010-01-01,0.379,0.375,0.383,377569,153.0
"""nation""","""nation""",2010-02-01,0.392,0.388,0.396,377569,184.0
"""nation""","""nation""",2010-03-01,0.399,0.395,0.403,377569,212.0
"""nation""","""nation""",2010-04-01,0.403,0.399,0.407,377569,243.0


# Manual

Pick some parameters by eyeballing the data and see how well they fit.


In [4]:
def predict(t: npt.ArrayLike, loc1, loc2, amp1, amp2, sig1, sig2) -> np.ndarray:
    """Predict cumulative uptake at times `t` as a sum of two scaled normal CDFs.

    Equivalent to assuming that *incident* uptake is the sum of two normal
    curves; integrating over time gives cumulative uptake

        amp1 * Phi((t - loc1) / sig1) + amp2 * Phi((t - loc2) / sig2)

    where Phi is the standard normal CDF.

    Args:
        t: times at which to predict (in this notebook, days since the first
            observation); any array-like, e.g. a polars Series or numpy array.
        loc1, loc2: midpoints (means) of the two components.
        amp1, amp2: amplitudes, i.e. the eventual uptake each component adds.
        sig1, sig2: spreads (standard deviations) of the two components;
            must be positive.

    Returns:
        numpy array of predicted cumulative uptake with the same shape as `t`.
    """
    t = np.asarray(t, dtype=float)
    y1 = amp1 * scipy.stats.norm.cdf(t, loc=loc1, scale=sig1)
    y2 = amp2 * scipy.stats.norm.cdf(t, loc=loc2, scale=sig2)
    return y1 + y2


def plot_fit(args) -> alt.LayerChart:
    """Overlay model predictions for parameters `args` on the observed data.

    Args:
        args: the six model parameters, in `predict` order
            (loc1, loc2, amp1, amp2, sig1, sig2).

    Returns:
        Layered Altair chart of the observed estimates with their lci/uci
        intervals (rules), plus the model predictions (red points).
    """
    assert len(args) == 6, f"expected 6 parameters, got {len(args)}"

    # need to use .map_batches() because there is no polars-native implementation of normal cdf
    chart_base = alt.Chart(
        data.with_columns(pred=pl.col("t").map_batches(lambda x: predict(x, *args)))
    ).encode(alt.X("time_end", title="Date"))

    # Give every layer the same y-axis title; otherwise Altair merges the
    # per-layer field names ("lci", "estimate", "pred") into the axis title.
    y_title = "Cumulative uptake"
    chart_error = chart_base.mark_rule().encode(
        alt.X2("time_end"), alt.Y("lci", title=y_title), alt.Y2("uci")
    )
    chart_data = chart_base.mark_point().encode(alt.Y("estimate", title=y_title))
    chart_pred = chart_base.mark_point(color="red").encode(
        alt.Y("pred", title=y_title)
    )

    return chart_error + chart_data + chart_pred


# Hand-picked starting guess from eyeballing the data, in `predict` order
manual_args = [
    31.0,  # loc1: midpoint of the first component
    100.0,  # loc2: midpoint of the second component
    0.3,  # amp1: uptake added by the first component
    0.1,  # amp2: uptake added by the second component
    10.0,  # sig1: spread of the first component
    10.0,  # sig2: spread of the second component
]
plot_fit(manual_args)

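# Least-squares fit

Find the parameters that best fit this curve by minimizing the sum of squared errors. (This coincides with the maximum-likelihood fit if the errors are independent and normally distributed with a common variance.)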


In [5]:
def obj(x: np.ndarray, data=data):
    """Objective for the optimizer: sum of squared residuals.

    `x` holds the six `predict` parameters. `data` defaults to the notebook's
    data frame as it was when this function was defined.
    """
    residuals = data["estimate"] - predict(data["t"], *x)
    return (residuals**2).sum()


# NOTE: the two components are interchangeable (swapping all *1 and *2
# parameters gives the same curve), and amp1 + amp2 is not constrained to be
# <= 1; only each amplitude individually is bounded.
fit = scipy.optimize.minimize(
    obj,
    x0=np.array(manual_args),
    # minimize() already picks L-BFGS-B when only bounds are given; naming it
    # makes the choice explicit.
    method="L-BFGS-B",
    bounds=(
        # midpoints can be anywhere
        (None, None),
        (None, None),
        # amplitudes must be nonnegative (and shouldn't exceed 1)
        (0.0, 1.0),
        (0.0, 1.0),
        # SDs must be positive
        (1e-6, None),
        (1e-6, None),
    ),
)

# fail loudly rather than silently plotting a non-converged fit
assert fit.success, fit.message

fit

  message: CONVERGENCE: NORM OF PROJECTED GRADIENT <= PGTOL
  success: True
   status: 0
      fun: 6.830304730880782e-05
        x: [ 4.762e+01  8.397e+01  2.594e-01  1.403e-01  1.914e+01
             5.734e+01]
      nit: 47
      jac: [-1.402e-06 -4.608e-06  9.744e-06  1.034e-06 -5.921e-06
             1.325e-07]
     nfev: 420
     njev: 60
 hess_inv: <6x6 LbfgsInvHessProduct with dtype=float64>

In [6]:
# chart the data against the least-squares parameter estimates
plot_fit(args=fit.x)