In [None]:
%matplotlib inline
import datetime

import numpy
import pandas
import matplotlib
import matplotlib.pyplot as plt
import scipy.stats

import common
from common import load_ctp_stats, ST_STATS

In [None]:
# Earliest date that there is sufficient data for all states, including MA
EARLIEST_DATE = pandas.Period('2020-03-13', freq='D')

# Last date to analyze. Defaults to today; set LATEST_DATE_OVERRIDE to an ISO date
# string (e.g. '2021-01-27') when the most recent days contain garbage, such as on
# or just after holidays.
LATEST_DATE_OVERRIDE = None
LATEST_DATE = pandas.Period(LATEST_DATE_OVERRIDE or datetime.date.today(), freq='D')

STATS_LAG = 10   # days before max_date excluded from the correlation (shift-search) window
RATIO_DAYS = 14  # the hospital/death ratio uses the window max_date - RATIO_DAYS .. max_date
MIN_SHIFT = 0    # lower bound the selected hospitalization shift is pinned to
MAX_SHIFT = 7    # upper bound the selected shift is pinned to; also sizes the search range

In [None]:
# Hospitalization stats for every state, indexed by (ST, Date), keeping only the
# 'Hospital' column (as a one-column DataFrame).
ctp_stats = (
    load_ctp_stats()
    .set_index(['ST', 'Date'])
    .sort_index()
    .loc[:, ['Hospital']]
)
ctp_stats.tail(5)

In [None]:
# Hospitalization shifts, earliest good data, and ignore days for date-of-death states.
# One tuple per state: (state abbreviation, hospitalization shift in days,
# earliest good-data date as 'YYYY-MM-DD', max_lag).
# max_lag is the number of most recent days ignored: the analysis cells use
# max_date = LATEST_DATE - max_lag.
# Only the uncommented entries are loaded into STATES and stepped through below;
# currently that is just VA.
_ST_STATS = [
#     ('AL', 6, '2020-07-15', 28), ('AZ', 1, '2020-07-15', 23),
#             ('CT', 4, '2020-07-15', 20), ('FL', 6, '2020-07-15', 21),
#             ('GA', 7, '2020-07-15', 22),
#    ('DE', 3, '2020-07-15', 28), 
#     ('KS', 2, '2020-07-15', 28),
#     ('IA', 1, '2020-07-15', 30),
#             ('IN', 6, '2020-07-15', 25), ('MA', 0, '2020-07-15', 12),
#             ('MI', 6, '2020-07-15', 14), ('MO', 0, '2020-07-15', 44),
#             ('MS', 3, '2020-07-15', 18), ('NC', 5, '2020-07-15', 18),
#             ('ND', 0, '2020-07-15', 20), ('NJ', 5, '2020-07-15', 22),
#             ('NV', 4, '2020-07-15', 16),
#    ('OH', 7, '2020-07-15', 36),
#             ('PA', 2, '2020-07-15', 30), ('RI', 4, '2020-07-15', 20),
#             ('SC', 2, '2020-07-25', 15), ('SD', 0, '2020-07-15', 38),
#             ('TN', 1, '2020-07-15', 20), ('TX', 3, '2020-07-15', 25),
             ('VA', 0, '2020-07-15', 22),
]

In [None]:
# For each state in _ST_STATS, build a frame joining smoothed daily deaths with
# smoothed hospitalizations, restricted to the dates where both have data.
STATES = {}
for st, *_unused in _ST_STATS:
    # Per-state loader from `common`, e.g. common.load_va_data for 'VA'.
    deaths = getattr(common, f'load_{st.lower()}_data')().set_index('Date')
    # The loader's Deaths column is differenced day over day (i.e. treated as a
    # cumulative count), giving daily deaths; the first row becomes NaN.
    deaths['Deaths'] = deaths['Deaths'].diff()
    deaths['Deaths'] = deaths['Deaths'].rolling(window=5, center=True, min_periods=1).mean()
    # .copy() so the smoothed column is written to an independent frame rather than
    # to a slice of ctp_stats (avoids SettingWithCopyWarning / chained assignment).
    hosp = ctp_stats.loc[st, :].copy()
    hosp['Hospital'] = hosp['Hospital'].rolling(window=5, center=True, min_periods=1).mean()
    # Overlapping date range where both sources have non-NaN rows.
    min_ = max(deaths.dropna().index.min(), hosp.dropna().index.min())
    max_ = min(deaths.dropna().index.max(), hosp.dropna().index.max())
    STATES[st] = deaths.loc[min_:max_, :].join(hosp.loc[min_:max_])

In [None]:
# Index into _ST_STATS of the state the next cell shows; re-run this cell to start over.
idx = 0

In [None]:
# Show one state per execution, cycling through _ST_STATS: smoothed hospitalizations
# vs. daily deaths, with the analysis dates marked.
cur_idx = idx
st, _configured_shift, min_date, max_lag = _ST_STATS[cur_idx]  # MO 9, OH 15, VA 22
idx = (cur_idx + 1) % len(_ST_STATS)
# max_lag = 45
min_date, max_date = pandas.Period(min_date, freq='D'), LATEST_DATE - max_lag
stats_max_date = max_date - STATS_LAG
# Copy so columns added by later cells (e.g. 'Projected') do not leak back into STATES.
both = STATES[st].copy()
# Print the index of the state actually shown (stays correct after wrap-around).
print(cur_idx, st, min_date, max_date, stats_max_date, max_lag)
fam = both[['Deaths', 'Hospital']].plot(
    title="Current Hospitalizations vs. Daily Deaths",
    secondary_y='Deaths', figsize=(16,5), ylim=0)
# Red: correlation window [min_date, stats_max_date]; green: max_date, where projection starts.
fam.axvline(min_date, color="red", linestyle="--")
fam.axvline(stats_max_date, color="red", linestyle="--")
fam.axvline(max_date, color="green", linestyle="--")
# Anchor the secondary (Deaths) axis at zero as well.
fam.get_figure().get_axes()[1].set_ylim(0);

In [None]:
# Sum of the smoothed daily deaths over the current state's aligned date range.
both['Deaths'].sum()

In [None]:
# Columns currently present in the selected state's frame.
both.columns

In [None]:
# Find the hospitalization shift (days) maximizing the Pearson correlation between
# shifted hospitalizations and daily deaths over [min_date, stats_max_date].
# Shifts up to MAX_SHIFT+3 are tried; the result is then clamped to [MIN_SHIFT, MAX_SHIFT].
# Only positive correlations are accepted; if none is found the shift stays 0.
d = both.Deaths.loc[min_date:stats_max_date]  # loop-invariant
best_sh, best_corr = 0, 0.0  # int, because the shift is later passed to Series.shift
for shift in range(-MAX_SHIFT, MAX_SHIFT+4):
    h = both.Hospital.shift(shift).loc[min_date:stats_max_date]
    corr = d.corr(h)
    print(shift, corr)
    if corr > best_corr:
        best_sh, best_corr = shift, corr
print("Best:", best_sh, best_corr)
if best_sh < MIN_SHIFT:
    best_sh = MIN_SHIFT
    print(f"Pinning shift to {MIN_SHIFT}")
if best_sh > MAX_SHIFT:
    best_sh = MAX_SHIFT
    print(f"Pinning shift to {MAX_SHIFT}")
# Build the comparison frame from the shift actually used downstream (after pinning).
# It always exists, even when no shift had a positive correlation.
h = both.Hospital.shift(best_sh).loc[min_date:stats_max_date]
best = pandas.concat([d, h], axis=1)
fam = best.plot(
    title="Current Hospitalizations vs. Daily Deaths",
    secondary_y='Deaths', figsize=(16,5), ylim=0)

In [None]:
# Ratio of shifted hospitalizations to deaths over the label-based (inclusive) window
# max_date - RATIO_DAYS .. max_date. Used to scale hospitalizations into projected deaths.
# NOTE(review): a window with zero deaths yields inf/NaN here.
h = both['Hospital'].shift(best_sh).loc[max_date - RATIO_DAYS:max_date].sum()
d = both['Deaths'].loc[max_date - RATIO_DAYS:max_date].sum()
hd_ratio = h / d

In [None]:
# Projected deaths: observed deaths before max_date, then shifted hospitalizations
# divided by hd_ratio from max_date onward (.loc slicing includes max_date itself).
both['Projected'] = both['Deaths'].copy()
# Assign through both.loc: the previous form, both.Projected.loc[...] = ..., is chained
# assignment, which warns and under pandas Copy-on-Write silently leaves `both` unchanged.
both.loc[max_date:, 'Projected'] = both.Hospital.shift(best_sh).loc[max_date:] / hd_ratio
both.tail()

In [None]:
# Plot every column of the current state's frame, with hospitalizations on the right axis.
fam = both.plot(
    secondary_y='Hospital',
    title=f"Current Hospitalizations vs. Daily Deaths",
    ylim=0,
    figsize=(16, 5),
)

In [None]:
# (projected total, observed total) of daily deaths.
# NOTE(review): "160" looks like a value noted from an earlier run — confirm its meaning.
both['Projected'].sum(), both['Deaths'].sum()  # 160

In [None]:
# Deaths summed over the hard-coded window 2021-01-02 .. 2021-01-10 (inclusive).
both.loc['2021-01-02':'2021-01-10', 'Deaths'].sum()

In [None]:
# What-if on a copy (leaves `both` untouched): add 17 to Deaths on every row of the
# 2021-01-02 .. 2021-01-10 window and re-plot.
# NOTE(review): 17 is a hard-coded adjustment; its origin is not shown here.
spaz = both.copy()
spaz.loc['2021-01-02':'2021-01-10', 'Deaths'] += 17
fam = spaz.plot(
    figsize=(16, 5),
    ylim=0,
    secondary_y='Hospital',
    title=f"Current Hospitalizations vs. Daily Deaths",
)

In [None]:
# Total deaths after the what-if adjustment.
spaz['Deaths'].sum()

In [None]:
# Halts "Run All" here, so the batch cell below runs only when executed manually.
raise ValueError("Stopping Run All here on purpose; run the following cells manually.")

In [None]:
# For every state in ST_STATS, find the shift in [-MAX_SHIFT, MAX_SHIFT] that maximizes
# the correlation between shifted hospitalizations and daily deaths over
# [min_date, stats_max_date], and print it (not pinned to MIN_SHIFT).
# Uses MAX_SHIFT from the config cell instead of redefining it here.
for st, orig_shift, min_date, max_lag in ST_STATS:
    if st not in STATES:
        # STATES is only built for the entries of _ST_STATS; avoid a KeyError for the rest.
        print(f"{st}: skipped (not loaded into STATES)")
        continue
    min_date, max_date = pandas.Period(min_date, freq='D'), LATEST_DATE - max_lag
    stats_max_date = max_date - STATS_LAG
    both = STATES[st]

    d = both.Deaths.loc[min_date:stats_max_date]  # loop-invariant
    best_sh, best_corr = 0, 0.0  # int shift; only positive correlations are accepted
    for shift in range(-MAX_SHIFT, MAX_SHIFT+1):
        h = both.Hospital.shift(shift).loc[min_date:stats_max_date]
        corr = d.corr(h)
        if corr > best_corr:
            best_sh, best_corr = shift, corr
    print(f"{st}: {best_sh}, {best_corr}")