In [None]:
%matplotlib inline
import datetime

import numpy
import pandas
import matplotlib
import matplotlib.pyplot as plt
import scipy.stats

import cdc_common
from cdc_common import load_data

In [None]:
# Earliest date that there is sufficient data for all states, including MA
EARLIEST_DATE = pandas.Period('2020-03-13', freq='D')

# Latest date of data to use. Defaults to today; when the most recent days
# contain garbage (e.g. on or after holidays), pin it instead, e.g.:
# LATEST_DATE = pandas.Period('2022-02-03', freq='D')
LATEST_DATE = pandas.Period(datetime.date.today(), freq='D')

MIN_STAT_DATE = '2021-02-01'  # start of the hosp/death correlation window
STATS_LAG = 14                # days the correlation window ends before max_date
RATIO_DAYS = 14               # length of the hosp/death ratio window ending at max_date
MIN_SHIFT = 4                 # lower clamp for the correlation-derived shift (days)
MAX_SHIFT = 12                # upper clamp for the correlation-derived shift (days)

In [None]:
# Load the per-state tables for [EARLIEST_DATE, LATEST_DATE].
# skip_projection=True presumably disables cdc_common's own projection — confirm.
latest_date, meta, all_stats, cdc_stats, hosp_stats = load_data(
    EARLIEST_DATE,
    LATEST_DATE,
    skip_projection=True,
)

In [None]:
# Peek at the last rows of the loaded stats.
all_stats.tail(n=5)

In [None]:
# Distinct state codes present in the loaded data, in order of first appearance.
STATES = all_stats.reset_index()['ST'].unique()
STATES

In [None]:
# Cursor for the stepping cell below, which runs `cnt += 1; idx = cnt // 2`.
# Starting at 2 * START_IDX makes its first run show DOD_META[START_IDX];
# after that, idx advances on every other run. Re-run this cell to restart.
START_IDX = 20
cnt = 2 * START_IDX

In [None]:
# Step through DOD_META: each run bumps `cnt` (set in the previous cell) and
# shows DOD_META[cnt // 2], so idx advances on every other run.
# NOTE(review): not idempotent — re-run the `cnt = ...` cell to restart.
cnt += 1
idx = cnt // 2
USE_BEST = False  # True: use the correlation-derived shift instead of the table's hosp_lag

# (state, hosp_lag, max_lag)
#   hosp_lag: days NewHosp is shifted later in time to line up with deaths
#   max_lag:  max_date = LATEST_DATE - max_lag; from max_date on, the
#             'Projected' series is max(deaths, hospital-based estimate)
DOD_META = [
    ('AK', 8, 35), ('AL', 12, 24), ('AR', 9, 28),  ('AZ', 6, 28),  ('CA', 12, 28),
    ('CO', 10, 28), ('CT', 7, 24), ('DC', 9, 42),  ('DE', 11, 26), ('FL', 11, 28),
    ('GA', 9, 30),  ('HI', 12, 21), ('IA', 12, 21), ('ID', 11, 28), ('IL', 9, 28),
    ('IN', 10, 33), ('KS', 7, 21),  ('KY', 12, 28), ('LA', 11, 28), ('MA', 7, 18),
    ('MD', 9, 28), ('ME', 6, 28), ('MI', 9, 28), ('MN', 9, 25), ('MO', 8, 28),
    ('MS', 12, 21), ('MT', 7, 28),  ('NC', 10, 34), ('ND', 5, 21),  ('NE', 12, 28),
    ('NH', 9, 24),  ('NJ', 11, 24), ('NM', 6, 31),  ('NV', 12, 26), ('NY', 10, 24),
    ('OH', 9, 30), ('OK', 8, 30),  ('OR', 7, 35),  ('PA', 6, 28),  ('RI', 12, 28),
    ('SC', 9, 27),  ('SD', 9, 28), ('TN', 11, 21), ('TX', 11, 31), ('UT', 4, 32),
    ('VA', 9, 28),  ('VT', 6, 28),  ('WA', 11, 31), ('WI', 6, 38), ('WV', 11, 28),
    ('WY', 12, 28),
]

st, hosp_lag, max_lag = DOD_META[idx]  # MO 9, OH 15, VA 22
min_date, max_date = pandas.Period(MIN_STAT_DATE, freq='D'), LATEST_DATE - max_lag
stats_max_date = max_date - STATS_LAG
print(idx, st, min_date, max_date, stats_max_date, max_lag)


def _mark_windows(ax, start, stats_end, proj_start):
    """Draw the correlation window (red) and the projection start (green) on ax."""
    ax.axvline(start, color="red", linestyle="--")
    ax.axvline(stats_end, color="red", linestyle="--")
    ax.axvline(proj_start, color="green", linestyle="--")


# Raw series for the selected state.
both = all_stats.loc[st, :].loc['2020-08-01':, :][['Daily', 'NewHosp']].copy()
both.columns = ['Deaths', 'Hospital']
fam = both.plot(title=f"{st}: New Hospitalizations vs. Daily Deaths",
                secondary_y='Deaths', figsize=(16, 5), ylim=0)
_mark_windows(fam, min_date, stats_max_date, max_date)
fam.get_figure().get_axes()[1].set_ylim(0)

# Search shifts MIN_SHIFT-4 .. MAX_SHIFT+3 for the highest Pearson correlation
# between deaths and shifted hospitalizations over [min_date, stats_max_date],
# then clamp to [MIN_SHIFT, MAX_SHIFT]. Starting from -inf (not 0.0) lets a
# best-but-negative correlation still be recorded; NaN never compares greater.
d = both.Deaths.loc[min_date:stats_max_date]  # independent of the shift
best, best_sh, best_corr = None, 0, -numpy.inf
for shift in range(MIN_SHIFT - 4, MAX_SHIFT + 4):
    h = both.Hospital.shift(shift).loc[min_date:stats_max_date]
    corr = d.corr(h)
    if corr > best_corr:
        best_sh, best_corr = shift, corr
        best = pandas.concat([d, h], axis=1)  # kept for interactive inspection
best_sh = min(max(best_sh, MIN_SHIFT), MAX_SHIFT)
print(f"Best shift is {best_sh}")

# One lag for both the ratio and the projection, so the ratio is applied to
# the same shifted series it was computed from.
lag = best_sh if USE_BEST else hosp_lag
h = both.Hospital.shift(lag).loc[max_date - RATIO_DAYS:max_date].sum()
d = both.Deaths.loc[max_date - RATIO_DAYS:max_date].sum()
# With no hospitalizations or no deaths in the window the ratio is unusable;
# an infinite ratio makes the hospital-based estimate 0, so the projection
# falls back to reported deaths instead of NaN/inf.
hd_ratio = h / d if (h and d) else numpy.inf

proj = both.copy()
proj['Hospital'] = proj['Hospital'].shift(lag)
proj['Projected'] = proj['Deaths'].astype(float)  # float: receives hosp/ratio values below

old_vals = proj.loc[max_date:, 'Projected']
new_vals = proj.loc[max_date:, 'Hospital'] / hd_ratio
# Write through the frame's .loc; chained `proj.Projected.loc[...] = ...` does
# not update proj under pandas Copy-on-Write. fmax ignores a NaN on either side.
proj.loc[max_date:, 'Projected'] = numpy.fmax(new_vals, old_vals)
print(f"proj={proj.Projected.sum()}, deaths={proj.Deaths.sum()}")

fam = proj.plot(title=f"{st}: New Hospitalizations (shifted {lag}d) vs. Daily Deaths",
                secondary_y='Hospital', figsize=(16, 5), ylim=0)
fam.get_figure().get_axes()[1].set_ylim(0)
_mark_windows(fam, min_date, stats_max_date, max_date)

In [None]:
# Deaths reported for 2021-01-02 through 2021-01-10 (label slicing is inclusive).
both.loc['2021-01-02':'2021-01-10', 'Deaths'].sum()

In [None]:
# Experiment: add a fixed number of deaths to every day in a short window and
# re-plot, leaving `both` untouched.
# NOTE(review): where the per-day adjustment comes from is not documented here.
DEATH_ADJUST_PER_DAY = 17
spaz = both.copy()
spaz.loc['2021-01-02':'2021-01-10', 'Deaths'] += DEATH_ADJUST_PER_DAY
fam = spaz.plot(
    title="New Hospitalizations vs. Daily Deaths (adjusted)",
    secondary_y='Hospital', figsize=(16, 5), ylim=0)

In [None]:
# Total deaths across the whole series for the selected state.
both['Deaths'].sum()

In [None]:
# Deliberate stop so "Run All" does not continue into the exploratory cells below.
raise ValueError("Intentional stop: the cells below are an exploratory per-state shift survey")

In [None]:
# Survey: for every state in DOD_META, report the shift of NewHosp in
# [-SURVEY_MAX_SHIFT, SURVEY_MAX_SHIFT] with the highest Pearson correlation
# to daily deaths over [MIN_STAT_DATE, stats_max_date].
# The earlier version of this cell iterated an undefined `ST_STATS` and indexed
# STATES (an array of state codes) by name, so it could not run; it is rebuilt
# here from DOD_META and all_stats, mirroring the single-state cell above.
# NOTE(review): per-state min dates from ST_STATS were lost; MIN_STAT_DATE is used.
SURVEY_MAX_SHIFT = 7  # own name so the global MAX_SHIFT clamp bound is not overwritten


def _best_shift(deaths, hospital, lo_date, hi_date, max_shift):
    """Return (shift, corr) maximizing deaths.corr(hospital.shift(shift)) on [lo_date, hi_date].

    Shifts range over -max_shift..max_shift inclusive. Returns (0, -inf) when
    every correlation is NaN.
    """
    window_deaths = deaths.loc[lo_date:hi_date]
    best_sh, best_corr = 0, -numpy.inf
    for shift in range(-max_shift, max_shift + 1):
        corr = window_deaths.corr(hospital.shift(shift).loc[lo_date:hi_date])
        if corr > best_corr:
            best_sh, best_corr = shift, corr
    return best_sh, best_corr


survey_min_date = pandas.Period(MIN_STAT_DATE, freq='D')
for survey_st, _table_lag, survey_max_lag in DOD_META:
    if survey_st not in STATES:
        print(f"{survey_st}: not in all_stats, skipped")
        continue
    survey_stats_max_date = LATEST_DATE - survey_max_lag - STATS_LAG
    state_stats = all_stats.loc[survey_st, :]
    survey_sh, survey_corr = _best_shift(
        state_stats['Daily'], state_stats['NewHosp'],
        survey_min_date, survey_stats_max_date, SURVEY_MAX_SHIFT)
    print(f"{survey_st}: {survey_sh}, {survey_corr}")