In [3]:
from copy import deepcopy
from datetime import datetime
from os import getcwd, path, makedirs
from string import ascii_letters, digits
import json
import multiprocessing as mp

from scipy import stats as sps
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import math


from _99_shared_functions import SIR_from_params, qdraw, jumper, power_spline,\
    reopen_wrapper

from _02_munge_chains import SD_plot, mk_projection_tables, plt_predictive, \
    plt_pairplot_posteriors, SEIR_plot, Rt_plot
from utils import beta_from_q

# Pool of single characters (a-z, A-Z, 0-9) used to build random directory suffixes.
LET_NUMS = pd.Series([*ascii_letters, *digits])

def get_dir_name(options):
    """Create a fresh, timestamped output directory under ``./output`` and return its path.

    The directory name is ``YYYY_MM_DD_HH_MM_SS``, optionally followed by
    ``_<options.prefix>`` and ``_<options.out>`` when those are truthy.

    If that name is already taken (several instances started within the same
    second), a random 6-character suffix drawn from ``LET_NUMS`` is appended
    and creation is retried. Creation itself is the existence test, so two
    processes racing for the same name cannot both "win" and then crash in
    ``makedirs``.
    """
    base = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    if options.prefix:
        base = f"{base}_{options.prefix}"
    if options.out:
        base = f"{base}_{options.out}"
    root = path.join(getcwd(), "output")
    outdir = path.join(root, base)
    while True:
        try:
            # exist_ok=False: raises FileExistsError if someone else got there first.
            makedirs(outdir)
            return outdir
        except FileExistsError:
            suffix = "".join(LET_NUMS.sample(6, replace=True))
            outdir = path.join(root, f"{base}_{suffix}")


def _impute_vent(census_ts):
    """Fill missing ``vent`` values of ``census_ts`` in place from ``hosp``; return the frame.

    A missing vent count is imputed as ``hosp * r``, where ``r`` is the mean
    vent/hosp ratio over rows where that ratio is finite. This is a crude hack
    carried over from the original code. Infinite ratios (vent > 0 with
    hosp == 0) are excluded so a single such row cannot turn every imputed
    value into inf.
    """
    missing = census_ts.vent.isna()
    ratio = (census_ts.vent / census_ts.hosp).replace([np.inf, -np.inf], np.nan)
    # Series.mean skips NaN, matching the original np.mean(Series) behaviour.
    census_ts.loc[missing, "vent"] = census_ts.hosp.loc[missing] * ratio.mean()
    return census_ts


def get_inputs(options):
    """Load the census time series and the parameter table selected by ``options``.

    Sources are applied in order, and later ones override earlier ones:

    1. ``options.prefix``: read ``data/<prefix>_ts.csv`` and
       ``data/<prefix>_parameters.csv`` relative to the current working directory.
    2. ``options.parameters``: path to a parameter CSV (overrides 1).
    3. ``options.ts``: path to a census CSV (overrides 1).

    All CSVs are read as latin-1. Missing ``vent`` values in the census are
    imputed from ``hosp`` (see ``_impute_vent``).

    Returns:
        ``(census_ts, params)``; either is None when no source supplied it.
    """
    census_ts, params = None, None
    if options.prefix is not None:
        datadir = path.join(getcwd(), "data")
        census_ts = _impute_vent(
            pd.read_csv(path.join(datadir, f"{options.prefix}_ts.csv"), encoding="latin-1")
        )
        params = pd.read_csv(
            path.join(datadir, f"{options.prefix}_parameters.csv"), encoding="latin-1"
        )
    if options.parameters is not None:
        params = pd.read_csv(options.parameters, encoding="latin-1")
    if options.ts is not None:
        census_ts = _impute_vent(pd.read_csv(options.ts, encoding="latin-1"))
    return census_ts, params


def write_inputs(options, paramdir, census_ts, params):
    """Snapshot the run's inputs into ``paramdir`` for reproducibility.

    Writes:
        args.json      -- ``options.__dict__`` as JSON
        census_ts.csv  -- the census time series (no index)
        params.csv     -- the parameter table (no index)
        git.sha        -- HEAD commit of the git checkout containing the cwd

    Raises:
        subprocess.CalledProcessError: if ``git rev-parse HEAD`` fails
            (e.g. not run inside a git checkout).
        FileNotFoundError: if the ``git`` executable is not available.
    """
    import subprocess  # local: only needed for the provenance stamp

    with open(path.join(paramdir, "args.json"), "w") as f:
        json.dump(options.__dict__, f)
    census_ts.to_csv(path.join(paramdir, "census_ts.csv"), index=False)
    params.to_csv(path.join(paramdir, "params.csv"), index=False)
    # The original used GitPython's `Repo`, which is not imported in this cell
    # and so raised NameError. `git rev-parse HEAD` also searches parent
    # directories from the cwd and yields the same hex sha without a new dependency.
    sha = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode().strip()
    with open(path.join(paramdir, "git.sha"), "w") as f:
        f.write(sha)


def loglik(r):
    """Gaussian log-likelihood of residuals ``r`` under N(0, sigma^2) with plug-in variance.

    Uses sigma^2 = np.var(r) and evaluates

        -n/2 * log(2*pi*sigma^2) - sum(r**2) / (2*sigma^2)

    The original divided the quadratic term by ``2*pi*sigma^2``, a stray factor
    of pi. That under-weighted the fit term relative to the normalising
    constant. Note that the quadratic term is centred at 0, not at mean(r).

    Args:
        r: 1-D array-like of residuals. A zero variance gives inf/nan.

    Returns:
        The summed log-likelihood as a float.
    """
    r = np.asarray(r, dtype=float)
    n = len(r)
    sigma2 = np.var(r)
    return -n / 2 * np.log(2 * np.pi * sigma2) - np.sum(np.square(r)) / (2 * sigma2)


def do_shrinkage(pos, shrinkage, shrink_mask):
    """Return the shrinkage penalty: the negative summed Beta log-density of ``pos``.

    Args:
        pos: array of prior quantiles in (0, 1).
        shrinkage: length-2 sequence ``(a, b)`` of Beta shape parameters.
        shrink_mask: 0/1 (or boolean) array aligned with ``pos`` selecting
            which entries are penalised, or None to penalise all of them.

    Returns:
        ``-sum(log(beta.pdf(pos[mask], a, b)))``.

    The original multiplied the densities by the mask before taking the log.
    Every masked-out entry then contributed ``-log(0) = +inf``, so the penalty
    was infinite whenever the mask contained a 0. Masked-out entries are now
    simply excluded.
    """
    densities = np.atleast_1d(sps.beta.pdf(pos, a=shrinkage[0], b=shrinkage[1]))
    if shrink_mask is not None:
        densities = densities[np.asarray(shrink_mask, dtype=bool)]
    return -np.sum(np.log(densities))


def eval_pos(pos, params, obs, shrinkage, shrink_mask, holdout,
             sample_obs, forecast_priors, ignore_vent):
    """Score one position in prior-quantile space; return its posterior and diagnostics.

    ``qdraw`` maps ``pos`` (one quantile per row of ``params``) to parameter
    values, and ``SIR_from_params`` simulates them. The log-posterior is built
    from these terms:

    * the Gaussian log-likelihood (``loglik``) of hosp census residuals
      (column 3 of ``draw["arr"]``) on the training window;
    * the vent census residuals (column 5), added only when the training vent
      series sums to > 0 and ``ignore_vent is False``;
    * the summed log prior probabilities ``log(draw["parms"].prob)``;
    * minus the shrinkage penalty from ``do_shrinkage``, when ``shrinkage`` is given;
    * plus a log "forecast prior" on the one-week-ahead percent change of
      hosp and vent, when ``forecast_priors["sig"] > 0``.

    Args:
        pos: quantiles of the priors, one per row of ``params``.
        params: parameter table passed to ``qdraw``.
        obs: observed census. Reads ``hosp`` and ``vent`` (plus ``hosp_rwstd``
            and ``vent_rwstd`` when ``sample_obs``). Never modified.
        shrinkage: length-2 ndarray of Beta ``(a, b)`` parameters, or None.
        shrink_mask: forwarded to ``do_shrinkage``.
        holdout: number of trailing rows of ``obs`` held out as a test set.
        sample_obs: if True, score against a noisy copy of ``obs``. The noise
            is Gaussian with the per-row ``*_rwstd`` scale, and the first row is
            left unperturbed.
        forecast_priors: mapping with ``mu`` and ``sig`` (percent-change prior).
        ignore_vent: vent residuals enter the likelihood only when this is exactly False.

    Returns:
        dict with ``pos``, ``draw``, ``posterior``, ``residuals_vent`` (None if
        vent was not scored), ``residuals_hosp`` and, when ``holdout > 0``,
        ``test_loss`` (mean of the hosp and vent test-set MSEs).
    """
    n_obs = obs.shape[0]
    nobs = n_obs - holdout  # number of training rows
    draw = SIR_from_params(qdraw(pos, params))
    if sample_obs:
        # Perturb a copy. The original added noise to the caller's DataFrame in
        # place, so the noise accumulated over every MCMC iteration (a random
        # walk on the data) instead of being a fresh draw each time.
        obs = obs.copy()
        ynoise_h = np.random.normal(scale=obs.hosp_rwstd)
        ynoise_h[0] = 0
        obs["hosp"] += ynoise_h
        ynoise_v = np.random.normal(scale=obs.vent_rwstd)
        ynoise_v[0] = 0
        obs["vent"] += ynoise_v
    if holdout > 0:
        train = obs[:-holdout]
        test = obs[-holdout:]
    else:
        train = obs

    LL = 0
    # Vent loss: column 5 of draw["arr"] is the vent census.
    residuals_vent = None
    if train.vent.sum() > 0:
        residuals_vent = draw["arr"][:nobs, 5] - train.vent.values[:nobs]
        # Exact-zero residuals are nudged to 0.01 (kept from the original).
        residuals_vent[residuals_vent == 0] = 0.01
        if ignore_vent is False:
            LL += loglik(residuals_vent)

    # Hosp loss: column 3 of draw["arr"] is the hosp census.
    residuals_hosp = draw["arr"][:nobs, 3] - train.hosp.values[:nobs]
    residuals_hosp[residuals_hosp == 0] = 0.01
    LL += loglik(residuals_hosp)

    posterior = LL + np.log(draw["parms"].prob).sum()

    # Shrinkage: the penalty is subtracted from the posterior.
    if shrinkage is not None:
        assert isinstance(shrinkage, np.ndarray) and len(shrinkage) == 2
        posterior -= do_shrinkage(pos, shrinkage, shrink_mask)

    # Forecast prior: the probability of the one-week-ahead percent change
    # under N(mu, sig). Row i of draw["arr"] lines up with row i of obs (see the
    # residuals above), so "now" is the last training row and one week out is
    # 7 rows later. The original used n_obs + 7, which was 8 days out and also
    # shifted by `holdout`.
    if forecast_priors["sig"] > 0:
        last_train = nobs - 1
        week_ahead = last_train + 7

        hosp_pct_diff = (draw["arr"][week_ahead, 3] / train.hosp.values[-1] - 1) * 100
        hosp_forecast_prob = sps.norm.pdf(
            hosp_pct_diff, forecast_priors["mu"], forecast_priors["sig"]
        )

        vent_pct_diff = (draw["arr"][week_ahead, 5] / train.vent.values[-1] - 1) * 100
        vent_forecast_prob = sps.norm.pdf(
            vent_pct_diff, forecast_priors["mu"], forecast_priors["sig"]
        )

        forecast_prior_contrib = hosp_forecast_prob * vent_forecast_prob
        posterior += np.log(forecast_prior_contrib) if forecast_prior_contrib > 0 else -np.inf

    out = dict(
        pos=pos,
        draw=draw,
        posterior=posterior,
        residuals_vent=residuals_vent,
        residuals_hosp=residuals_hosp,
    )
    if holdout > 0:
        res_te_vent = draw["arr"][nobs:n_obs, 5] - test.vent.values
        res_te_hosp = draw["arr"][nobs:n_obs, 3] - test.hosp.values
        out["test_loss"] = (np.mean(res_te_hosp ** 2) + np.mean(res_te_vent ** 2)) / 2
    return out


def chain(seed, params, obs, n_iters, shrinkage, holdout,
          forecast_priors,
          sample_obs,
          ignore_vent):
    """Run one adaptive random-walk Metropolis chain in prior-quantile space.

    Args:
        seed: seed for ``np.random``. It is also stored in the ``chain`` column.
        params: parameter table. Its row count sets the dimension of the
            quantile vector, and its ``param`` column is used for the mask.
        obs: observed census passed to ``eval_pos``.
        n_iters: number of proposals (and of output rows).
        shrinkage: None, or a float in [0.05, 1) converted to Beta parameters
            via ``beta_from_q(shrinkage/2, 1 - shrinkage/2)``.
        holdout, forecast_priors, sample_obs, ignore_vent: forwarded to ``eval_pos``.

    Returns:
        DataFrame with one row per iteration. Columns are the current parameter
        values, ``arr`` (``draw["arr_stoch"]``), ``iter``, ``chain``,
        ``posterior``, ``offset``, ``s``, ``e``, ``i``, ``r``, and
        ``test_loss`` when ``holdout > 0``.

    Every ``100`` iterations after 200 the jump scale shrinks by 10% if the
    chain is "always rejecting" (fewer than 10 distinct posteriors in the last
    99). Every ``1000`` iterations after 2000 it also shrinks if the mean
    posterior has dropped between the two most recent windows.
    """
    np.random.seed(seed)
    # Must be bound even without shrinkage. The original raised
    # UnboundLocalError below when shrinkage was None.
    shrink_mask = None
    if shrinkage is not None:
        assert (shrinkage < 1) and (shrinkage >= 0.05)
        sq1 = shrinkage / 2
        sq2 = 1 - shrinkage / 2
        shrinkage = beta_from_q(sq1, sq2)
        # NOTE(review): `"" in i` is True for every string, so this mask is all
        # ones (every parameter is shrunk). It looks like a placeholder filter.
        shrink_mask = np.array([1 if "" in i else 0 for i in params.param])
    current_pos = eval_pos(
        pos=np.random.uniform(size=params.shape[0]),
        params=params,
        obs=obs,
        shrinkage=shrinkage,
        shrink_mask=shrink_mask,
        holdout=holdout,
        sample_obs=sample_obs,
        forecast_priors=forecast_priors,
        ignore_vent=ignore_vent,
    )
    outdicts = []
    U = np.random.uniform(0, 1, n_iters)
    posterior_history = []
    jump_sd = .2  # starting proposal scale; adapted below
    for ii in range(n_iters):
        try:
            proposed_pos = eval_pos(
                jumper(current_pos["pos"], jump_sd),
                params,
                obs,
                shrinkage=shrinkage,
                shrink_mask=shrink_mask,
                holdout=holdout,
                sample_obs=sample_obs,
                forecast_priors=forecast_priors,
                ignore_vent=ignore_vent,
            )
            p_accept = np.exp(proposed_pos["posterior"] - current_pos["posterior"])
            if U[ii] < p_accept:
                current_pos = proposed_pos
        except Exception as e:
            # Best effort: a failed proposal counts as a rejection.
            print(e)

        draw = current_pos["draw"]
        parms = draw["parms"]
        out = {parms.param[i]: parms.val[i] for i in range(params.shape[0])}
        out.update({
            "arr": draw["arr_stoch"],
            "iter": ii,
            "chain": seed,
            # Record the posterior of the state actually stored in this row.
            # The original stored the (possibly rejected) proposal's posterior,
            # and raised NameError if the very first proposal failed.
            "posterior": current_pos["posterior"],
            "offset": draw["offset"],
            "s": draw["s"],
            "e": draw["e"],
            "i": draw["i"],
            "r": draw["r"],
        })
        if holdout > 0:
            out["test_loss"] = current_pos["test_loss"]
        outdicts.append(out)
        posterior_history.append(current_pos["posterior"])

        if (ii % 100 == 0) and (ii > 200):
            always_rejecting = len(set(posterior_history[-99:])) < 10
            if (ii > 2000) and (ii % 1000 == 0):
                flat = np.mean(posterior_history[-999:]) < np.mean(posterior_history[-1990:-999])
            else:
                flat = False
            if always_rejecting or flat:
                jump_sd *= .9
        # TODO: write intermediate chains to disk in case of a crash, and re-read them on restart.
    return pd.DataFrame(outdicts)


def get_test_loss(n_iters, seed, holdout, shrinkage, params, obs,
                  forecast_priors, ignore_vent, sample_obs=False):
    """Run one chain and return its per-iteration held-out test loss.

    The original omitted ``sample_obs``, a required argument of ``chain``, so
    every call raised TypeError. It is now a keyword parameter defaulting to
    False (score against the raw observations).

    Returns:
        The ``test_loss`` column (a Series) of the chain's output. It exists
        only when ``holdout > 0``; otherwise a KeyError is raised.
    """
    return chain(
        n_iters=n_iters,
        seed=seed,
        params=params,
        obs=obs,
        shrinkage=shrinkage,
        holdout=holdout,
        forecast_priors=forecast_priors,
        sample_obs=sample_obs,
        ignore_vent=ignore_vent,
    )["test_loss"]


def do_chains(n_iters,
              params,
              obs,
              best_penalty,
              sample_obs,
              holdout,
              n_chains,
              forecast_priors,
              parallel,
              ignore_vent):
    """Run ``n_chains`` independent chains and stack their outputs.

    Chain ``i`` is seeded with ``i``. ``best_penalty`` is passed to ``chain``
    as its ``shrinkage`` argument.

    Args:
        parallel: if truthy, run the chains in a ``multiprocessing.Pool`` with
            one worker per CPU; otherwise run them sequentially in-process.

    Returns:
        All chain DataFrames concatenated with a fresh integer index.
    """
    # Positional order matches chain(seed, params, obs, n_iters, shrinkage,
    # holdout, forecast_priors, sample_obs, ignore_vent).
    chain_args = [
        (i, params, obs, n_iters, best_penalty, holdout,
         forecast_priors, sample_obs, ignore_vent)
        for i in range(n_chains)
    ]
    if parallel:
        # The context manager tears the workers down even if a chain raises.
        # The original leaked the pool in that case.
        with mp.Pool(mp.cpu_count()) as pool:
            chains = pool.starmap(chain, chain_args)
    else:
        chains = [chain(*args) for args in chain_args]
    return pd.concat(chains, ignore_index=True)

ModuleNotFoundError: No module named 'scipy'