In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sb
import pandas as pd
import pymc3 as pm
import datetime
import platform
import sys

import theano
theano.config.floatX = 'float64'  # necessary to prevent an dtype error
import theano.tensor as tt
import os
import matplotlib.pyplot as plt

from src.SEIR_Forecast.Bayesian_Inference_plotting import *
import src.SEIR_Forecast.Bayesian_Inference_SEIR_helper as model_helper

%matplotlib inline

## Tutorial: sampling a toy PyMC3 model with Metropolis

In [None]:
# Seed the sampler so that Restart & Run All reproduces the same tutorial trace.
TUTORIAL_SEED = 42

model = pm.Model()
with model:
    mu1 = pm.Normal("mu1", mu=0, sigma=1)
    # An explicit Metropolis step is supplied; init=None requests no special initialisation.
    trace = pm.sample(2000, init=None, step=pm.Metropolis(), random_seed=TUTORIAL_SEED)
    print(trace.stat_names)

In [None]:
# Names of the free random variables, with the '_log__' suffix of log-transformed RVs stripped.
varnames = [
    rv_name.replace('_log__', '')
    for rv_name in map(str, model.free_RVs)
]
print(varnames)
print(trace['mu1'].shape)

In [None]:
def plot_hist_custom(model, trace, varname, colors=('tab:blue', 'tab:green'), bins=50):
    """
    Plot the posterior histogram of ``varname`` with its prior density overlaid.

    Draws on the current matplotlib axes (``plt.gca()``).

    Parameters
    ----------
    model: pm.Model instance
        Model passed to ``get_prior_distribution`` to evaluate the prior.
    trace: trace of the model
        ``trace[varname]`` must return the posterior samples.
    varname: string
        Name of the variable to plot.
    colors: sequence of 2 color names
        ``colors[0]`` is used for the prior curve, ``colors[1]`` for the posterior histogram.
    bins: number or array
        Passed to ``plt.hist``.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that were drawn on.
    """
    plt.hist(trace[varname], bins=bins, density=True, color=colors[1],
             label='Posterior')
    ax = plt.gca()
    # Evaluate the prior only over the x-range spanned by the posterior histogram.
    limits = ax.get_xlim()
    x = np.linspace(*limits, num=100)
    try:
        ax.plot(x, get_prior_distribution(model, x, varname), label='Prior',
                color=colors[0], linewidth=3)
    except Exception as exc:
        # The prior overlay is best-effort, but say why it is missing instead of hiding it.
        print(f"plot_hist_custom: could not plot prior of '{varname}': {exc!r}")
    ax.set_xlim(*limits)
    ax.set_xlabel(varname)
    return ax


plot_hist_custom(model, trace, 'mu1');

## Bayesian MCMC: SIR model with change points (training and forecasting)

In [None]:
# Prior mean dates of the first three change points, keyed by the end date of the training window.
_CHANGE_POINT_PRIOR_DATES = {
    datetime.datetime(2020, 4, 30): (datetime.datetime(2020, 3, 25),
                                     datetime.datetime(2020, 4, 1),
                                     datetime.datetime(2020, 4, 8)),
    datetime.datetime(2020, 5, 15): (datetime.datetime(2020, 4, 10),
                                     datetime.datetime(2020, 4, 17),
                                     datetime.datetime(2020, 4, 24)),
    datetime.datetime(2020, 5, 31): (datetime.datetime(2020, 4, 25),
                                     datetime.datetime(2020, 5, 1),
                                     datetime.datetime(2020, 5, 8)),
    datetime.datetime(2020, 6, 15): (datetime.datetime(2020, 5, 10),
                                     datetime.datetime(2020, 5, 17),
                                     datetime.datetime(2020, 5, 24)),
    datetime.datetime(2020, 6, 30): (datetime.datetime(2020, 5, 25),
                                     datetime.datetime(2020, 6, 1),
                                     datetime.datetime(2020, 6, 8)),
}


def get_change_points(final_date, final_change_date, cluster_id, prior_dates=None):
    """
    Build the change-point prior dicts consumed by ``SIR_with_change_points``.

    The leading change points are taken from ``_CHANGE_POINT_PRIOR_DATES[final_date]``
    (or from ``prior_dates`` when given). One more change point is appended at
    ``final_change_date``. Every change point gets the same priors:
    ``pr_sigma_date_begin_transient=1``, ``pr_median_lambda=0.2`` and ``pr_sigma_lambda=0.5``.

    Parameters
    ----------
    final_date : datetime.datetime
        End of the training window. Must be a key of ``_CHANGE_POINT_PRIOR_DATES``
        unless ``prior_dates`` is given.
    final_change_date : datetime.datetime
        Prior mean date of the last change point.
    cluster_id : int or str
        Only used in the log message.
    prior_dates : sequence of datetime.datetime, optional
        Explicit prior mean dates of the leading change points. Overrides the lookup table.

    Returns
    -------
    list of dict
        A freshly created dict per change point, so callers may mutate them.

    Raises
    ------
    ValueError
        If ``final_date`` is not in the table and ``prior_dates`` is not given.
    """
    print('get_change_points', final_date, cluster_id)

    if prior_dates is None:
        try:
            prior_dates = _CHANGE_POINT_PRIOR_DATES[final_date]
        except KeyError:
            known = ', '.join(d.strftime('%Y-%m-%d') for d in sorted(_CHANGE_POINT_PRIOR_DATES))
            raise ValueError(
                f'No change-point prior dates defined for final_date={final_date}; '
                f'known final dates: {known}. Pass prior_dates explicitly.'
            ) from None

    return [
        dict(pr_mean_date_begin_transient=mean_date,
             pr_sigma_date_begin_transient=1,
             pr_median_lambda=0.2,
             pr_sigma_lambda=0.5)
        for mean_date in (*prior_dates, final_change_date)
    ]

In [None]:
def run(sir_model, N_SAMPLES, cluster_save_path):
    """
    Sample ``sir_model`` with Metropolis MCMC and persist the results.

    ``cluster_save_path`` is used as a plain string prefix, so it should end with a path
    separator. Files written:
      * ``sir_model.trace`` - the sampled trace (``pm.save_trace``)
      * ``plot_hist.png``   - prior/posterior histogram of every free RV
      * ``varnames.npy``    - names of the free RVs
      * ``SIR_params.npy``  - 2-element object array ``[median lambda_t, median mu]``,
        loaded back with ``np.load(..., allow_pickle=True)``

    Parameters
    ----------
    sir_model : pm.Model
        Model containing ``lambda_t`` and ``mu`` (e.g. from ``SIR_with_change_points``).
    N_SAMPLES : int
        Number of draws passed to ``pm.sample``.
    cluster_save_path : str
        Output path prefix.
    """
    print('sample start')
    with sir_model:
        trace = pm.sample(N_SAMPLES, model=sir_model, step=pm.Metropolis(), progressbar=True)
    pm.save_trace(trace, cluster_save_path + 'sir_model.trace', overwrite=True)
    print('sample end')

    # -------- prepare data for visualization ---------------
    varnames = get_all_free_RVs_names(sir_model)
    # Posterior medians. lambda_t has one value per simulated day. μ is kept as a
    # length-1 array because callers unpack it with ``μ[0]``.
    lambda_t = np.median(trace['lambda_t'], axis=0)
    μ = np.median(trace['mu'][:, None], axis=0)

    # -------- visualize histogram ---------------
    num_cols = 5
    num_rows = int(np.ceil(len(varnames) / num_cols))
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(num_cols * 2.5, num_rows * 2.5), squeeze=False)
    for i_ax, ax in enumerate(axes.flat):
        if i_ax >= len(varnames):
            ax.set_visible(False)
            continue
        plot_hist(sir_model, trace, ax, varnames[i_ax], colors=('tab:blue', 'tab:green'))
        i_row, i_col = divmod(i_ax, num_cols)
        if i_col == 0:
            ax.set_ylabel('Density')
            if i_row == 0:
                ax.legend()
    fig.subplots_adjust(wspace=0.25, hspace=0.4)
    fig.savefig(cluster_save_path + 'plot_hist.png')
    # Close (not just clear) the figure. run() is called once per cluster, and
    # cleared-but-open figures accumulate in memory.
    plt.close(fig)

    np.save(cluster_save_path + 'varnames.npy', varnames)
    # lambda_t and μ have different lengths. np.save on the plain ragged list is an
    # error on NumPy >= 1.24, so build the object array explicitly.
    sir_params = np.empty(2, dtype=object)
    sir_params[0] = lambda_t
    sir_params[1] = μ
    np.save(cluster_save_path + 'SIR_params.npy', sir_params)
    
    

def SIR_with_change_points(
    S_begin_beta,
    I_begin_beta,
    new_cases_obs,
    change_points_list,
    date_begin_simulation,
    num_days_sim,
    diff_data_sim,
    N,
    priors_dict=None,
    weekends_modulated=False
):
    """
        Build a PyMC3 SIR model whose spreading rate changes at the given change points.

        Parameters
        ----------
        S_begin_beta : number
            Prior mean of the initial susceptible count. The prior is
            ``Normal(S_begin_beta, S_begin_beta / 10)``.
        I_begin_beta : number
            Prior mean of the initial infected count. The prior is
            ``Normal(I_begin_beta, I_begin_beta / 10)``.
        new_cases_obs : list or array
            Timeseries (day over day) of newly reported cases (not the total number).
            Converted with ``np.asarray``; its last axis is the time axis.
        change_points_list : list of dicts
            List of dictionaries, each corresponding to one change point.
            Each dict can have the following key-value pairs. If a pair is not provided,
            the respective default is used. The caller's dicts are not modified.
                * pr_mean_date_begin_transient :     datetime.datetime, NO default
                * pr_median_lambda :                 number, same as pr_median_lambda_0 below
                * pr_sigma_lambda :                  number, same as pr_sigma_lambda_0 below
                * pr_sigma_date_begin_transient :    number, 3
                * pr_median_transient_len :          number, 3
                * pr_sigma_transient_len :           number, 0.3
            The dates must be in chronological order and not before
            ``date_begin_simulation``.
        date_begin_simulation: datetime.datetime
            The begin of the simulation data.
        num_days_sim : integer
            Total number of simulated days (length of ``lambda_t``). Must be at least
            ``len(new_cases_obs) + diff_data_sim``.
        diff_data_sim : integer
            Number of days that the simulation-begin predates the first data point in
            `new_cases_obs`. The reporting-delay step is currently disabled, so this
            value is only used in the ``num_days_sim`` length check.
        N : number
            The population size.
        priors_dict : dict, optional
            Overrides for the global priors. The caller's dict is not modified.
            Possible key-value pairs (and default values) are:
                * pr_beta_I_begin :        number, default = 10000 (not used by the model)
                * pr_median_lambda_0 :     number, default = 0.2
                * pr_sigma_lambda_0 :      number, default = 0.5
                * pr_median_mu :           number, default = 1/8
                * pr_sigma_mu :            number, default = 0.2
                * pr_median_delay :        number, default = 1
                * pr_sigma_delay :         number, default = 0.2
                * pr_beta_sigma_obs :      number, default = 5
            Accepted only when ``weekends_modulated`` is True:
                * week_end_days :          tuple,  default = (6,7)
                * pr_mean_weekend_factor : number, default = 0.7
                * pr_sigma_weekend_factor :number, default = 0.17
        weekends_modulated : bool
            Only decides whether the weekend prior names above are accepted. The model
            built here applies no weekend modulation.

        Returns
        -------
        : pymc3.Model
            Model with free RVs I_begin, S_begin, lambda_0..lambda_K, transient_begin_*,
            transient_len_*, mu, delay (sampled but unused while the delay step is
            disabled) and sigma_obs, plus the deterministics ``lambda_t`` and ``new_cases``.

        Raises
        ------
        RuntimeError
            If a prior name is unknown, the change-point dates are not temporally
            ordered, or ``num_days_sim`` is too short for the data.
    """
    # Work on copies: defaults are filled in below and must not leak into the caller's dicts.
    priors_dict = {} if priors_dict is None else dict(priors_dict)
    change_points_list = [dict(change_point) for change_point in change_points_list]
    new_cases_obs = np.asarray(new_cases_obs)

    default_priors = dict(
        pr_beta_I_begin=10000.0,
        pr_median_lambda_0=0.2,
        pr_sigma_lambda_0=0.5,
        pr_median_mu=1 / 8,
        pr_sigma_mu=0.2,
        pr_median_delay=1.0,
        pr_sigma_delay=0.2,
        pr_beta_sigma_obs=5.0,
        week_end_days=(6, 7),
        pr_mean_weekend_factor=0.7,
        pr_sigma_weekend_factor=0.17
    )
    default_priors_change_points = dict(
        pr_median_lambda=default_priors["pr_median_lambda_0"],
        pr_sigma_lambda=default_priors["pr_sigma_lambda_0"],
        pr_sigma_date_begin_transient=3.0,
        pr_median_transient_len=3.0,
        pr_sigma_transient_len=0.3,
        pr_mean_date_begin_transient=None,
    )

    if not weekends_modulated:
        del default_priors['week_end_days']
        del default_priors['pr_mean_weekend_factor']
        del default_priors['pr_sigma_weekend_factor']

    # Reject unknown prior names before any defaults are filled in.
    for prior_name in priors_dict:
        if prior_name not in default_priors:
            raise RuntimeError(f"Prior with name {prior_name} not known")
    for change_point in change_points_list:
        for prior_name in change_point:
            if prior_name not in default_priors_change_points:
                raise RuntimeError(f"Prior with name {prior_name} not known")

    for prior_name, value in default_priors.items():
        priors_dict.setdefault(prior_name, value)
    for change_point in change_points_list:
        for prior_name, value in default_priors_change_points.items():
            change_point.setdefault(prior_name, value)

    if num_days_sim < len(new_cases_obs) + diff_data_sim:
        raise RuntimeError(
            "Simulation ends before the end of the data. Increase num_days_sim."
        )

    # ------------------------------------------------------------------------------ #
    # Model and prior implementation
    # ------------------------------------------------------------------------------ #

    with pm.Model() as model:
        # Initial compartment sizes are uncertain: Normal priors with 10% relative sigma.
        I_begin = pm.Normal(name="I_begin", mu=I_begin_beta, sigma=I_begin_beta/10)
        S_begin = pm.Normal(name="S_begin", mu=S_begin_beta, sigma=S_begin_beta/10)

        # Spreading rate before the first change point (lambda_0) and after each one.
        lambda_list = [
            pm.Lognormal(
                name="lambda_0",
                mu=np.log(priors_dict["pr_median_lambda_0"]),
                sigma=priors_dict["pr_sigma_lambda_0"],
            )
        ]
        for i, cp in enumerate(change_points_list):
            lambda_list.append(
                pm.Lognormal(
                    name=f"lambda_{i + 1}",
                    mu=np.log(cp["pr_median_lambda"]),
                    sigma=cp["pr_sigma_lambda"],
                )
            )

        # Start day (relative to date_begin_simulation) of each change point's transient.
        tr_begin_list = []
        dt_before = date_begin_simulation
        for i, cp in enumerate(change_points_list):
            dt_begin_transient = cp["pr_mean_date_begin_transient"]
            if dt_before is not None and dt_before > dt_begin_transient:
                raise RuntimeError("Dates of change points are not temporally ordered")
            # Convert the date into a day offset from the simulation start.
            # NOTE(review): the extra -1 day is kept as-is; its rationale is not documented here.
            prior_mean = (dt_begin_transient - date_begin_simulation).days - 1

            tr_begin = pm.Normal(
                name=f"transient_begin_{i}",
                mu=prior_mean,
                sigma=cp["pr_sigma_date_begin_transient"],
            )
            tr_begin_list.append(tr_begin)
            dt_before = dt_begin_transient

        # Duration of each transient.
        tr_len_list = []
        for i, cp in enumerate(change_points_list):
            tr_len = pm.Lognormal(
                name=f"transient_len_{i}",
                mu=np.log(cp["pr_median_transient_len"]),
                sigma=cp["pr_sigma_transient_len"],
            )
            tr_len_list.append(tr_len)

        # Build the time-dependent spreading rate: lambda_0 everywhere, plus one smooth
        # step of height (lambda_after - lambda_before) per change point.
        lambda_t_list = [lambda_list[0] * tt.ones(num_days_sim)]
        lambda_before = lambda_list[0]

        for tr_begin, tr_len, lambda_after in zip(
            tr_begin_list, tr_len_list, lambda_list[1:]
        ):
            lambda_t = model_helper.smooth_step_function(
                start_val=0,
                end_val=1,
                t_begin=tr_begin,
                t_end=tr_begin + tr_len,
                t_total=num_days_sim,
            ) * (lambda_after - lambda_before)
            lambda_before = lambda_after
            lambda_t_list.append(lambda_t)
        lambda_t = sum(lambda_t_list)

        # Fraction of infected that recover each day (recovery rate mu).
        mu = pm.Lognormal(
            name="mu",
            mu=np.log(priors_dict["pr_median_mu"]),
            sigma=priors_dict["pr_sigma_mu"],
        )

        # Delay in days between contracting the disease and being recorded.
        # Kept as a free RV, but not used below because the delay step is disabled.
        delay = pm.Lognormal(
            name="delay",
            mu=np.log(priors_dict["pr_median_delay"]),
            sigma=priors_dict["pr_sigma_delay"],
        )

        # Prior of the error of observed cases.
        sigma_obs = pm.HalfCauchy("sigma_obs", beta=priors_dict["pr_beta_sigma_obs"])

        # -------------------------------------------------------------------------- #
        # training the model with loaded data provided as argument
        # -------------------------------------------------------------------------- #

        S, I, new_I = _SIR_model(
            lambda_t=lambda_t, mu=mu, S_begin=S_begin, I_begin=I_begin, N=N
        )

        # The reporting delay (model_helper.delay_cases) is intentionally not applied:
        # the inferred new cases are compared to the data directly.
        new_cases_inferred = new_I

        # Likelihood of the model:
        # observed cases are distributed following studentT around the model.
        # We want to approximate a Poisson distribution of new cases.
        # We choose nu=4 to get heavy tails and robustness to outliers.
        # https://www.jstor.org/stable/2290063
        num_days_data = new_cases_obs.shape[-1]
        pm.StudentT(
            name="_new_cases_studentT",
            nu=4,
            mu=new_cases_inferred[:num_days_data],
            sigma=tt.abs_(new_cases_inferred[:num_days_data] + 1) ** 0.5
            * sigma_obs,  # +1 and tt.abs to avoid nans
            observed=new_cases_obs,
        )

        # Add these observables to the model so a time series of them can be extracted
        # later via e.g. `trace['lambda_t']`.
        pm.Deterministic("lambda_t", lambda_t)
        pm.Deterministic("new_cases", new_cases_inferred)
    return model



def _SIR_model(lambda_t, mu, S_begin, I_begin, N):
    """
        Unroll the discrete-time susceptible-infected-recovered recursion with theano.scan.

        Each step consumes one entry of ``lambda_t`` and computes
            new_I = lambda_t / N * I * S
            S     = S - new_I
            I     = clip(I + new_I - mu * I, 0, N)

        Parameters
        ----------
        lambda_t : ~numpy.ndarray or tensor
            Spreading rate per step; its length sets the number of steps.
        mu : number or tensor
            Recovery rate.
        S_begin : number or tensor
            Susceptible count before the first step.
        I_begin : number or tensor
            Infected count before the first step.
        N : number
            Population size.

        Returns
        -------
        list of three tensors
            [S, I, new_I] time series with one entry per step.
    """
    def _advance_one_day(spread_rate, susceptible, infected, _previous_new, recovery_rate, population):
        newly_infected = spread_rate / population * infected * susceptible
        remaining_susceptible = susceptible - newly_infected
        updated_infected = infected + newly_infected - recovery_rate * infected
        # Keep the infected count inside [0, N] for numerical stability.
        updated_infected = tt.clip(updated_infected, 0, population)
        return remaining_susceptible, updated_infected, newly_infected

    # scan returns (outputs, updates); outputs holds one series per entry of outputs_info.
    series, _updates = theano.scan(
        fn=_advance_one_day,
        sequences=[lambda_t],
        outputs_info=[S_begin, I_begin, tt.zeros_like(I_begin)],
        non_sequences=[mu, N],
    )
    return series

In [None]:
'''
Author: Junbong Jang
Date 4/29/2020

Load timeseries data, train the model, and forecast

'''

import pandas as pd
from datetime import date
import statistics
import math

from src.EDA.parseJohnsHopkins import johnsHopkinsPopulation, getTzuHsiClusters
from src.SEIR_Forecast import Bayesian_Inference_SEIR
from src.SEIR_Forecast.Bayesian_Inference_plotting import plot_cases
from src.SEIR_Forecast.timeseries_eval import *
from src.SEIR_Forecast.reassignment_helper import *
from src.SEIR_Forecast.data_processing import *
from src.SEIR_Forecast.visualizer import *
from src.SEIR_Forecast.SIR import sir_forecast_a_county


def _mean_or_nan(values):
    """Return the arithmetic mean of ``values``, or NaN if ``values`` is empty.

    ``statistics.mean`` raises ``StatisticsError`` on empty input. A cluster with no
    evaluated counties should give a NaN metric instead of aborting the whole run.
    """
    values = list(values)
    return statistics.mean(values) if values else float('nan')


def forecast_main(clusters, cases_df, vel_cases_df, population_df, cluster_mode, init_cluster_num, max_cluster_num, initial_date,
                  final_date, final_change_date, num_days_future, dataset_final_date, run_mode, root_save_path):
    """
    Train (``run_mode='train'``) or evaluate (``run_mode='eval'``) SIR models per cluster.

    Parameters
    ----------
    clusters : pandas.Series
        Cluster id per county (counties in the index).
    cases_df, vel_cases_df : pandas.DataFrame
        Cumulative and daily (velocity) case tables, passed to ``preprocess_dataset``.
    population_df : pandas.DataFrame
        Population table, passed to ``preprocess_dataset``.
    cluster_mode : {'clustered', 'unclustered'}
        'clustered' processes each cluster id in ``range(init_cluster_num, max_cluster_num)``.
        'unclustered' processes all counties once, saved under ``cluster_All``.
    init_cluster_num, max_cluster_num : int
        Range of cluster ids.
    initial_date, final_date, dataset_final_date : str
        'month/day' strings; the year 2020 is appended.
    final_change_date : datetime.datetime
        Date of the last change point (see ``get_change_points``).
    num_days_future : int
        Forecast horizon in days.
    run_mode : {'train', 'eval'}
        'train' samples and saves each model. 'eval' loads the saved trace and
        parameters, forecasts, and evaluates.
    root_save_path : str
        Output directory; a ``cluster_<id>/`` subfolder is created per cluster.

    Returns
    -------
    tuple of 14 lists
        The clustered-mode metrics followed by the unclustered metrics regrouped per
        cluster. Lists belonging to the other mode stay empty, as does
        ``mape_per_cluster_list`` (per-cluster MAPE is not aggregated).
    """
    rmse_per_cluster_list = []
    total_re_per_cluster_list = []
    mean_rsquared_per_cluster_list = []
    mape_per_cluster_list = []
    wape_per_cluster_list = []

    mse_per_county_per_cluster_list = []
    re_per_county_per_cluster_list = []
    rsquared_per_county_per_cluster_list = []
    mape_per_county_per_cluster_list = []
    wape_per_county_per_cluster_list = []

    unclustered_rmse_per_cluster_list = []
    unclustered_mse_per_county_per_cluster_list = []
    unclustered_total_re_per_cluster_list = []
    unclustered_re_per_county_per_cluster_list = []

    initial_date = datetime.datetime.strptime(f'{initial_date}/2020', '%m/%d/%Y')
    final_date = datetime.datetime.strptime(f'{final_date}/2020', '%m/%d/%Y')
    dataset_final_date = datetime.datetime.strptime(f'{dataset_final_date}/2020', '%m/%d/%Y')

    # cluster id 0 was not clustered by Tzu Hsi but it is still used
    for cluster_id in range(init_cluster_num, max_cluster_num):
        if cluster_mode == 'unclustered':
            cluster_id = 'All'
            chosen_cluster_series = clusters
        else:
            chosen_cluster_series = clusters[clusters == cluster_id]
        cluster_counties = chosen_cluster_series.index.tolist()

        print('-----------------------------')
        print('Cluster ID: ', cluster_id)
        # ------------- Create save folders --------------
        cluster_save_path = root_save_path + f'/cluster_{cluster_id}/'
        # makedirs also creates a missing root_save_path; exist_ok makes reruns a no-op.
        os.makedirs(cluster_save_path, exist_ok=True)
        cluster_all_save_path = root_save_path + '/cluster_All/'

        # -------------- Data Preprocessing --------------
        cluster_cases_df, proc_population_series = preprocess_dataset(cases_df.copy(), population_df.copy(), cluster_counties)
        cluster_vel_cases_df, _ = preprocess_dataset(vel_cases_df.copy(), population_df.copy(), cluster_counties)

        cluster_cases_df, current_cumulative_cases_df, future_cumulative_cases_df, old_cumulative_infected_cases_series, date_begin_sim, num_days_sim = \
            process_date(cluster_cases_df, initial_date, final_date, dataset_final_date, num_days_future)
        cluster_vel_cases_df, current_vel_cases_df, future_vel_cases_df, _, _, _ = \
            process_date(cluster_vel_cases_df, initial_date, final_date, dataset_final_date, num_days_future)

        current_cumulative_cases_series = current_cumulative_cases_df.sum(axis=1)
        current_vel_cases_series = current_vel_cases_df.sum(axis=1)
        cluster_total_population = proc_population_series.sum()
        # .iloc for positional access: Series[int] on a date index is deprecated in pandas.
        future_cumulative_cases = future_cumulative_cases_df.sum(axis=1).iloc[-1]

        print('old_cumulative_infected_cases_series:', old_cumulative_infected_cases_series)
        print('Cumulative future cases:', future_cumulative_cases)
        print('population:', cluster_total_population)
        print('Remaining population:', cluster_total_population - future_cumulative_cases)

        visualize_trend_with_r_not(cluster_id, cluster_vel_cases_df, cluster_save_path)

        # --------------- Get SIR Model -----------------
        # convert cumulative infected to daily total infected cases
        current_total_cases_series = current_cumulative_cases_series - old_cumulative_infected_cases_series.sum()

        day_1_cumulative_infected_cases = current_cumulative_cases_series.iloc[0]
        S_begin_beta = cluster_total_population - day_1_cumulative_infected_cases
        I_begin_beta = current_total_cases_series.iloc[0]  # day 1 total infected cases

        print('day_1_cumulative_infected_cases: ', day_1_cumulative_infected_cases)
        print('S_begin_beta: ', S_begin_beta)
        print('I_begin_beta: ', I_begin_beta)

        change_points = get_change_points(final_date, final_change_date, cluster_id)
        sir_model = SIR_with_change_points(S_begin_beta,
                                           I_begin_beta,
                                           current_vel_cases_series.to_numpy(),
                                           change_points_list=change_points,
                                           date_begin_simulation=date_begin_sim,
                                           num_days_sim=num_days_sim,
                                           diff_data_sim=0,
                                           N=cluster_total_population)

        # ---------- Estimate Parameters for SIR model ------------
        if run_mode == 'train':
            run(sir_model, N_SAMPLES=500, cluster_save_path=cluster_save_path)

        elif run_mode == 'eval':
            trace = pm.load_trace(cluster_save_path + 'sir_model.trace', model=sir_model)
            susceptible_series = proc_population_series - current_cumulative_cases_df.loc[final_date]

            # ---------- Forecast using unclustered data ------------------
            t = range(len(future_vel_cases_df.values))
            if cluster_mode == 'clustered':
                # SIR_params.npy is written by run(); allow_pickle is needed for its object array.
                lambda_t, μ = np.load(cluster_all_save_path + 'SIR_params.npy', allow_pickle=True)
                beta, gamma = lambda_t[-1], μ[0]
                print('beta, gamma', beta, gamma)
                cluster_all_vel_case_forecast = sir_forecast_a_county(susceptible_series.sum(), moving_average_from_df(current_vel_cases_df).sum(),
                                                                      cluster_total_population, t, beta, gamma, '', '')
            else:
                cluster_all_vel_case_forecast = None

            # ----------- Forecast using clustered data ------------------
            lambda_t, μ = np.load(cluster_save_path + 'SIR_params.npy', allow_pickle=True)
            beta, gamma = lambda_t[-1], μ[0]
            print('beta, gamma', beta, gamma)
            # Posterior mean of new_cases on the first simulated day after the observed data.
            cluster_forecast_I0 = np.mean(trace['new_cases'][:, len(current_vel_cases_series)], axis=0)

            cluster_vel_case_forecast = sir_forecast_a_county(susceptible_series.sum(), cluster_forecast_I0,
                                                              cluster_total_population, t, beta, gamma, '', '')

            # ----------- Forecast Visualization ---------------
            plot_cases(cluster_id, trace, current_vel_cases_series, future_vel_cases_df, cluster_vel_case_forecast,
                       cluster_all_vel_case_forecast, date_begin_sim, diff_data_sim=0, num_days_future=num_days_future, cluster_save_path=cluster_save_path)

            # ---------- Evaluation per county -----------
            cluster_mse_dict, cluster_re_dict, cluster_rsquared_dict, cluster_mape_dict, cluster_wape_dict = \
                eval_per_cluster(susceptible_series, moving_average_from_df(current_vel_cases_df), future_vel_cases_df, proc_population_series, num_days_future, cluster_save_path)

            if cluster_mode == 'unclustered':
                # Regroup the all-county metrics by cluster assignment, for comparison with
                # the clustered models. A separate loop variable avoids shadowing cluster_id.
                for county_cluster_id in range(0, max_cluster_num):
                    member_counties = clusters[clusters == county_cluster_id].index.tolist()
                    local_mse_list = [cluster_mse_dict[county] for county in member_counties if county in cluster_mse_dict]
                    local_re_list = [cluster_re_dict[county] for county in member_counties if county in cluster_re_dict]
                    unclustered_rmse_per_cluster_list.append(math.sqrt(_mean_or_nan(local_mse_list)))
                    unclustered_mse_per_county_per_cluster_list.append(local_mse_list)
                    unclustered_total_re_per_cluster_list.append(_mean_or_nan(local_re_list))
                    unclustered_re_per_county_per_cluster_list.append(local_re_list)

            elif cluster_mode == 'clustered':
                rmse_per_cluster_list.append(math.sqrt(_mean_or_nan(cluster_mse_dict.values())))
                total_re_per_cluster_list.append(_mean_or_nan(cluster_re_dict.values()))
                mean_rsquared_per_cluster_list.append(_mean_or_nan(cluster_rsquared_dict.values()))
                # Per-cluster MAPE is intentionally not aggregated (mape_per_cluster_list stays empty).
                wape_per_cluster_list.append(_mean_or_nan(cluster_wape_dict.values()))

                mse_per_county_per_cluster_list.append(list(cluster_mse_dict.values()))
                re_per_county_per_cluster_list.append(list(cluster_re_dict.values()))
                rsquared_per_county_per_cluster_list.append(list(cluster_rsquared_dict.values()))
                mape_per_county_per_cluster_list.append(list(cluster_mape_dict.values()))
                wape_per_county_per_cluster_list.append(list(cluster_wape_dict.values()))

        if cluster_mode == 'unclustered':
            break  # only run once for unclustered dataset

    return rmse_per_cluster_list, mse_per_county_per_cluster_list, mean_rsquared_per_cluster_list, rsquared_per_county_per_cluster_list, \
           mape_per_cluster_list, mape_per_county_per_cluster_list, wape_per_cluster_list, wape_per_county_per_cluster_list, \
           total_re_per_cluster_list, re_per_county_per_cluster_list, \
           unclustered_rmse_per_cluster_list, unclustered_mse_per_county_per_cluster_list, \
           unclustered_total_re_per_cluster_list, unclustered_re_per_county_per_cluster_list


if __name__ == "__main__":
    # Fitting the SEIR model to the data and estimating the parameters with the cluster id.
    dataset_final_date = '8/1'
    cluster_type = "no_constants"
    run_mode = 'train'
    # cluster_mode = 'clustered'

    # initial_date_list = ['3/15', '4/1', '4/15', '5/1', '5/15']
    # final_date_list = ['4/30', '5/15', '5/31', '6/15', '6/30']
    # final_change_date_list = [datetime.datetime(2020, 4, 15), datetime.datetime(2020, 4, 30), datetime.datetime(2020, 5, 15), datetime.datetime(2020, 5, 31), datetime.datetime(2020, 6, 15)]

    initial_date_list = ['5/1']
    final_date_list = ['6/15']
    final_change_date_list = [datetime.datetime(2020, 6, 5)]

    # initial_date_list = ['4/1']
    # final_date_list = ['5/15']
    # final_change_date_list = [datetime.datetime(2020, 5, 5)]

    for initial_date, final_date, final_change_date in zip(initial_date_list, final_date_list, final_change_date_list):
        # load data
        initial_clusters = getTzuHsiClusters(column_date=f"{initial_date}~{final_date}", cluster_type=cluster_type)
        max_cluster_num = len(initial_clusters.unique())
        cases_df = pd.read_csv(f'../../generated/us_cases_counties.csv', header=0, index_col=0)
        vel_cases_df = pd.read_csv(f'../../generated/us_velocity_cases_counties.csv', header=0, index_col=0)
        population_df = johnsHopkinsPopulation()

        # set save path
        date_info = f'{initial_date.replace("/","-")}_{final_date.replace("/","-")}_{dataset_final_date.replace("/","-")}_{final_change_date.strftime("%m-%d")}_{cluster_type}'
        root_save_path = f'../../generated/plots/{date_info}/'
        if os.path.isdir(root_save_path) is False:
            os.mkdir(root_save_path)

        # initial Parameters
        reassigned_clusters = initial_clusters
        REASSIGN_COUNTER_MAX = 1
        reassign_counter_init = 0  # to load cluster data from intermediate reassign num
        init_cluster_num = 0

        if run_mode == 'eval':
            max_cluster_num = max_cluster_num - 1  # remove cluster 11 as a outlier

        # if cluster_mode == 'unclustered':
        #     REASSIGN_COUNTER_MAX = 1

        # root_save_path = f'../../generated/plots/{date_info}/reassign_{reassign_counter_init-1}/'
        # reassigned_clusters = pd.read_csv(root_save_path + f'clusters.csv', header=0, index_col=0)
        # reassigned_clusters = reassigned_clusters.iloc[:, 0]
        # reassigned_clusters = reassign_county(reassigned_clusters, max_cluster_num, cases_df, population_df, initial_date, final_date, num_days_future, root_save_path)

        # -------------- Run Reassignment Model ------------------------
        # math.sqrt is used in the evaluation branch below; `math` is not part of the
        # notebook's import cell, so import it explicitly rather than relying on a star import.
        import math

        for reassign_counter in range(reassign_counter_init, REASSIGN_COUNTER_MAX):
            print('reassign_counter: ', reassign_counter)
            # Intermediate steps forecast the validation set only (7 days);
            # the last step forecasts validation + test (14 days).
            is_last_reassign_step = reassign_counter >= REASSIGN_COUNTER_MAX - 1
            num_days_future = 14 if is_last_reassign_step else 7

            # Each reassignment step has its own output folder holding its clusters.csv.
            root_save_path = f'../../generated/plots/{date_info}/reassign_{reassign_counter}/'
            clusters_csv_path = root_save_path + 'clusters.csv'
            if run_mode == 'eval':
                # Reload the county-to-cluster assignment saved for this step (first column).
                reassigned_clusters = pd.read_csv(clusters_csv_path, header=0, index_col=0).iloc[:, 0]
            elif run_mode == 'train':
                # Save the assignment used in this step; makedirs also creates missing parents.
                os.makedirs(root_save_path, exist_ok=True)
                reassigned_clusters.to_csv(clusters_csv_path)
            print(reassigned_clusters)

            # ------------ Forecast -----------------
            # forecast_main returns 14 values: only the first ten are kept from the
            # 'clustered' call and only the last four from the 'unclustered' call.
            (rmse_per_cluster_list, mse_per_county_per_cluster_list,
             mean_rsquared_per_cluster_list, rsquared_per_county_per_cluster_list,
             mape_per_cluster_list, mape_per_county_per_cluster_list,
             wape_per_cluster_list, wape_per_county_per_cluster_list,
             total_re_per_cluster_list, re_per_county_per_cluster_list,
             _, _, _, _) = forecast_main(reassigned_clusters, cases_df, vel_cases_df, population_df, 'clustered',
                                         init_cluster_num, max_cluster_num, initial_date, final_date, final_change_date,
                                         num_days_future, dataset_final_date, run_mode, root_save_path)

            (_, _, _, _,
             _, _, _, _,
             _, _,
             unclustered_rmse_per_cluster_list, unclustered_mse_per_county_per_cluster_list,
             unclustered_total_re_per_cluster_list, unclustered_re_per_county_per_cluster_list) = forecast_main(
                reassigned_clusters, cases_df, vel_cases_df, population_df, 'unclustered',
                init_cluster_num, max_cluster_num, initial_date, final_date, final_change_date,
                num_days_future, dataset_final_date, run_mode, root_save_path)

            # ------------------------------
            # In training, recompute the county assignment after every step except the last.
            if run_mode == 'train' and not is_last_reassign_step:
                reassigned_clusters = reassign_county(reassigned_clusters, max_cluster_num, cases_df, vel_cases_df,
                                                      population_df, initial_date, final_date,
                                                      num_days_future, root_save_path)

            # ------------ Evaluation ---------------
            if run_mode == 'eval':

                # ------------------ for clustered dataset -----------------------
                clustered_average_mse = average_from_list_of_list(mse_per_county_per_cluster_list)
                clustered_average_rmse = round(math.sqrt(clustered_average_mse), 3)
                average_of_rsquared = round(average_from_list_of_list(rsquared_per_county_per_cluster_list), 3)
                average_of_mape = round(average_from_list_of_list(mape_per_county_per_cluster_list), 3)
                average_of_wape = round(average_from_list_of_list(wape_per_county_per_cluster_list), 3)

                histogram_clusters(reassigned_clusters, max_cluster_num, root_save_path)

                # Violin plots of the per-county metrics (same order as before: MSE, R^2, MAPE, WAPE).
                for per_county_metric, metric_label in ((mse_per_county_per_cluster_list, 'MSE'),
                                                        (rsquared_per_county_per_cluster_list, 'R^2'),
                                                        (mape_per_county_per_cluster_list, 'MAPE'),
                                                        (wape_per_county_per_cluster_list, 'WAPE')):
                    violin_eval_clusters(per_county_metric, metric_label, root_save_path)

                bar_eval_clusters(max_cluster_num, rmse_per_cluster_list, clustered_average_rmse, 'RMSE', 'clustered',
                                  cluster_type, root_save_path)
                bar_eval_clusters(max_cluster_num, mean_rsquared_per_cluster_list, average_of_rsquared, 'R^2', 'clustered',
                                  cluster_type, root_save_path)
                # MAPE bar chart is disabled; average_of_mape is still computed above.
                # bar_eval_clusters(max_cluster_num, mape_per_cluster_list, average_of_mape, 'MAPE', 'clustered',
                #                   cluster_type, root_save_path)
                bar_eval_clusters(max_cluster_num, wape_per_cluster_list, average_of_wape, 'WAPE', 'clustered',
                                  cluster_type, root_save_path)

                # --------------- unclustered ---------------------
                unclustered_average_mse = average_from_list_of_list(unclustered_mse_per_county_per_cluster_list)
                unclustered_average_rmse = round(math.sqrt(unclustered_average_mse), 3)
                bar_eval_clusters(max_cluster_num, unclustered_rmse_per_cluster_list, unclustered_average_rmse, 'RMSE', 'unclustered',
                                  cluster_type, root_save_path)

                # ---- clustered and unclustered -----
                print('----------------')

                clustered_average_re = round(average_from_list_of_list(re_per_county_per_cluster_list), 3)
                unclustered_average_re = round(average_from_list_of_list(unclustered_re_per_county_per_cluster_list), 3)

                bar_eval_clusters_compare(rmse_per_cluster_list, unclustered_rmse_per_cluster_list, clustered_average_rmse, unclustered_average_rmse, 'RMSE', root_save_path)
                bar_eval_clusters_compare(total_re_per_cluster_list, unclustered_total_re_per_cluster_list, clustered_average_re, unclustered_average_re, 'Relative Error', root_save_path)