In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import json
import os
import copy

import matplotlib.pyplot as plt
import seaborn as sns

import matplotlib.dates as mdates
from matplotlib.dates import SA

from utils.fitting.util import get_ensemble_params
from scripts.seir.combine_multiple_runs import combine_multiple_runs
from utils.fitting.util import create_output

In [None]:
# Directory that holds one sub-directory per fitting run (each is read below as
# `<outputs_dir>/<run>/...`). Set the COVID_OUTPUTS_DIR environment variable to
# point the notebook at another experiment without editing this cell.
outputs_dir = os.environ.get(
    'COVID_OUTPUTS_DIR',
    '/scratche/users/sansiddh/covid-modelling/phparams_2021_0407_120856/athena/',
)

In [None]:
def create_run_ledger(outputs_dir):
    """Build a ledger with one row per run directory found in ``outputs_dir``.

    Every entry of ``outputs_dir`` whose name does not start with ``.`` is
    treated as a run and must contain a ``config.json`` holding
    ``fitting.split.start_date`` and ``fitting.optimiser_params.seed``.

    Args:
        outputs_dir: Path of the directory that contains the run folders.

    Returns:
        pd.DataFrame with the columns ``run``, ``start_date`` and ``seed``,
        one row per run, ordered by run name, with a contiguous ``0..n-1``
        index. The frame is empty (but has the three columns) when no run is
        found.

    Raises:
        OSError: if a run's ``config.json`` cannot be opened.
        KeyError: if a config lacks one of the expected keys.
    """
    columns = ['run', 'start_date', 'seed']
    records = []

    # os.listdir returns entries in arbitrary order; sort so the ledger is
    # reproducible between executions.
    for run in sorted(os.listdir(outputs_dir)):
        if run.startswith('.'):
            # Hidden entries (e.g. checkpoint folders) are not runs.
            continue
        with open(os.path.join(outputs_dir, run, 'config.json')) as f:
            config = json.load(f)
        fitting = config['fitting']
        records.append({
            'run': run,
            'start_date': fitting['split']['start_date'],
            'seed': fitting['optimiser_params']['seed'],
        })

    # Build the frame once: skipped entries no longer leave gaps in the index,
    # and the frame is not grown row by row.
    return pd.DataFrame(records, columns=columns)

In [None]:
def calculate_ess(run_ledger, outputs_dir, use_beta=True):
    """Add ``ess`` and ``beta`` columns to ``run_ledger`` for every run.

    For each run the trial losses are turned into normalised weights
    ``w_i ∝ exp(-beta * loss_i)`` and ``ess = 1 / sum(w_i ** 2)`` (rounded to
    the nearest integer value) is stored next to the ``beta`` that was used.

    Args:
        run_ledger: DataFrame with a ``run`` column naming the sub-directories
            of ``outputs_dir``. It is modified in place.
        outputs_dir: Directory containing ``<run>/trials_losses.npy`` and,
            when ``use_beta`` is true, ``<run>/beta.npy``.
        use_beta: If false, ``beta = 0`` is used, i.e. all trials get the same
            weight.

    Returns:
        The same ``run_ledger`` object, with ``ess`` and ``beta`` filled in.

    Raises:
        OSError: if one of the ``.npy`` files cannot be read.
        ValueError: if a run has no trial losses.
    """
    # Iterate over the ledger's own index labels: the index is not guaranteed
    # to be 0..n-1, so a positional counter would write to the wrong rows.
    for idx, run in run_ledger['run'].items():
        run_dir = os.path.join(outputs_dir, str(run))

        if use_beta:
            beta = np.load(os.path.join(run_dir, 'beta.npy'))
            if beta.ndim == 0:
                # Store a plain scalar in the ledger rather than a 0-d array.
                beta = beta.item()
        else:
            beta = 0

        # NOTE: allow_pickle=True executes pickled payloads; only point this at
        # output directories you trust.
        trials_losses = np.asarray(
            np.load(os.path.join(run_dir, 'trials_losses.npy'), allow_pickle=True),
            dtype=float,
        )
        if trials_losses.size == 0:
            raise ValueError(f'run {run!r} has no trial losses')

        # Shift the log-weights by their maximum before exponentiating. The
        # shift cancels in the normalisation but prevents every weight from
        # underflowing to 0 (which would give 0/0 = NaN) for large losses.
        log_wt = -beta * trials_losses
        loss_wt = np.exp(log_wt - np.max(log_wt))
        loss_wt = loss_wt / np.sum(loss_wt)
        ess = 1 / np.sum(loss_wt ** 2)

        run_ledger.loc[idx, 'ess'] = round(ess, 0)
        run_ledger.loc[idx, 'beta'] = beta

    return run_ledger

In [None]:
import arviz as az
def _calculate_ensemble_params(run_ledger, outputs_dir, use_beta=True, use_hpdi=False):
    """Attach per-run ensemble parameter statistics to a copy of the ledger.

    For every run in ``run_ledger`` the trial parameters and losses are loaded
    from ``outputs_dir`` and passed to ``get_ensemble_params``. For each
    parameter the ensemble value is stored as ``<param>_mean`` and the returned
    deviation divided by ``sqrt(ess)`` as ``<param>_std_err``. With
    ``use_hpdi`` the 95% highest-density interval of the raw (unweighted) trial
    values is stored as ``<param>_bound_l`` / ``<param>_bound_u``.

    Args:
        run_ledger: DataFrame with at least ``run``, ``start_date`` and ``ess``
            columns. It is not modified.
        outputs_dir: Directory containing ``<run>/trials_params.npy``,
            ``<run>/trials_losses.npy`` and, when ``use_beta`` is true,
            ``<run>/beta.npy``.
        use_beta: If false, ``beta = 0`` is passed to ``get_ensemble_params``.
        use_hpdi: If true, also compute the HDI bounds with ``az.hdi``.

    Returns:
        Tuple ``(df_params_master, list_of_params)``: the enriched frame with
        flat column names and ``start_date`` parsed to datetime, and the list
        of parameter names of the last processed run (empty if no run was
        processed).

    Raises:
        KeyError: if ``run_ledger`` has no ``ess`` column.
        OSError: if one of the ``.npy`` files cannot be read.
    """
    if 'ess' not in run_ledger.columns:
        raise KeyError("run_ledger has no 'ess' column; run calculate_ess() first")

    df_params_master = run_ledger.copy(deep=True)
    # Two-level columns (name, statistic); ledger columns get an empty second level.
    df_params_master.columns = pd.MultiIndex.from_arrays(
        [df_params_master.columns, [''] * len(df_params_master.columns)])

    # Defined up front so the return statement works for an empty ledger.
    list_of_params = []

    # Drive the loop from the ledger itself, so each result is written to the
    # row it belongs to. Rows without a run name carry nothing to load.
    for idx, run in run_ledger['run'].dropna().items():
        run_dir = os.path.join(outputs_dir, str(run))

        beta = np.load(os.path.join(run_dir, 'beta.npy')) if use_beta else 0
        # NOTE: allow_pickle=True executes pickled payloads; only point this at
        # output directories you trust.
        trials_params = np.load(os.path.join(run_dir, 'trials_params.npy'), allow_pickle=True)
        trials_losses = np.load(os.path.join(run_dir, 'trials_losses.npy'), allow_pickle=True)

        em_params, em_params_dev = get_ensemble_params(
            trials_params, trials_losses, beta, return_dev=True)
        list_of_params = list(em_params.keys())
        ess = df_params_master.loc[idx, ('ess', '')]

        for param in list_of_params:
            df_params_master.loc[idx, (param, 'mean')] = em_params[param]
            if use_hpdi:
                # The HDI is taken over the raw trial values, independent of beta.
                samples = np.array([trial[param] for trial in trials_params])
                lower, upper = az.hdi(samples, hdi_prob=0.95)
                df_params_master.loc[idx, (param, 'bound_l')] = lower
                df_params_master.loc[idx, (param, 'bound_u')] = upper
            df_params_master.loc[idx, (param, 'std_err')] = em_params_dev[param] / np.sqrt(ess)

    df_params_master[('start_date', '')] = pd.to_datetime(
        df_params_master[('start_date', '')], format='%Y-%m-%d')
    # Flatten ('name', 'stat') -> 'name_stat'; ledger columns ('name', '') -> 'name'.
    flat_names = ['_'.join(col) for col in df_params_master.columns]
    df_params_master.columns = [name[:-1] if name.endswith('_') else name for name in flat_names]

    return df_params_master, list_of_params

In [None]:
# Collect the runs found under `outputs_dir`, then add the `ess` and `beta`
# columns to that ledger.
run_ledger = calculate_ess(create_run_ledger(outputs_dir), outputs_dir)
run_ledger

In [None]:
# Ensemble statistics per run with beta fixed to 0 and HDI bounds enabled;
# shown ordered by start date.
df_params_master, list_of_params = _calculate_ensemble_params(
    run_ledger,
    outputs_dir,
    use_beta=False,
    use_hpdi=True,
)
df_params_master.sort_values(by='start_date')

In [None]:
# Average every numeric column over the runs that share a start date.
# numeric_only=True keeps string columns such as `run` out of the mean; without
# it pandas >= 2.0 raises a TypeError instead of silently dropping them.
df_params_comb = (
    df_params_master
    .groupby('start_date')
    .mean(numeric_only=True)
    .reset_index()
)

In [None]:
# Display the per-start-date averages computed in the previous cell.
df_params_comb

# KDE Plot

In [None]:
# fig, axs = plt.subplots(figsize=(16, 21), nrows=4, ncols=2)
# for i, param in enumerate(list_of_params):
#     ax = axs.flat[i]
#     sns.kdeplot(data=df_params, y=param, weights='loss_wt', ax=ax)

In [None]:
# One panel per fitted parameter: the ensemble mean of every run against its
# start date, coloured by seed.
# The grid is sized from the number of parameters so that more than eight
# parameters no longer overflow a fixed 4x2 layout (8 parameters still give the
# original 4x2 grid at 16x21 inches).
n_cols = 2
n_rows = max(1, (len(list_of_params) + n_cols - 1) // n_cols)
fig, axs = plt.subplots(figsize=(16, 5.25 * n_rows), nrows=n_rows, ncols=n_cols, squeeze=False)

for ax, param in zip(axs.flat, list_of_params):
    sns.scatterplot(data=df_params_master, x='start_date', y=f'{param}_mean', hue='seed', ax=ax)
    ax.set_title(f'Ensemble Mean of {param} for different seeds')
    ax.set_ylabel(param)
    ax.set_xlabel('Starting Date')

    # Major tick every two weeks, minor tick every week, labelled month-day.
    ax.xaxis.set_major_locator(mdates.DayLocator(interval=14))
    ax.xaxis.set_minor_locator(mdates.DayLocator(interval=7))
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d'))
    ax.tick_params('x', labelrotation=45)
    ax.grid(alpha=0.5)

    # Leave a week of empty space on the right of the last point.
    ax.set_xlim(ax.get_xlim()[0], ax.get_xlim()[1]+7)

# Hide panels that are left over when the parameters do not fill the grid.
for ax in axs.flat[len(list_of_params):]:
    ax.set_visible(False)

fig.suptitle('Ensemble Mean of all params vs time for different seeds (2k trials each)')
fig.subplots_adjust(top=0.95, hspace=0.3)

In [None]:
# Write the current figure to a PNG in the working directory.
fig.savefig(fname='em-mean-params-diff-seeds.png')

In [None]:
# One panel per fitted parameter: the seed-averaged ensemble mean per start
# date with error bars (HDI bounds or standard error, depending on `use_hpdi`).
use_hpdi = True

# Size the grid from the number of parameters (8 parameters give the original
# 4x2 grid at 16x21 inches) instead of overflowing a fixed layout.
n_cols = 2
n_rows = max(1, (len(list_of_params) + n_cols - 1) // n_cols)
fig, axs = plt.subplots(figsize=(16, 5.25 * n_rows), nrows=n_rows, ncols=n_cols, squeeze=False)

for ax, param in zip(axs.flat, list_of_params):
    mean = df_params_comb[f'{param}_mean']
    if use_hpdi:
        # errorbar expects asymmetric errors as a (2, N) array of
        # [distance below, distance above]. The previous np.array(lower, upper)
        # passed `upper` as the dtype argument and raised a TypeError.
        lower = mean - df_params_comb[f'{param}_bound_l']
        upper = df_params_comb[f'{param}_bound_u'] - mean
        # matplotlib rejects negative error lengths; clip in case a mean lies
        # outside its interval so the plot still renders.
        yerr = np.clip(np.vstack([lower, upper]), 0, None)
        ax.errorbar(x=df_params_comb['start_date'], y=mean, yerr=yerr, ecolor='lightblue')
        ax.set_title(f'Ensemble Mean of {param} with 95% HPDI')
    else:
        ax.errorbar(x=df_params_comb['start_date'], y=mean.rolling(window=4).mean(),
                    yerr=df_params_comb[f'{param}_std_err'])
        ax.set_title(f'Ensemble Mean of {param} +- std error')
    ax.set_ylabel(param)
    ax.set_xlabel('Starting Date')
    if param == 'lockdown_R0':
        ax.set_ylim(.65, 1.4)

    # Major tick every two weeks, minor tick every week, labelled month-day.
    ax.xaxis.set_major_locator(mdates.DayLocator(interval=14))
    ax.xaxis.set_minor_locator(mdates.DayLocator(interval=7))
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d'))
    ax.tick_params('x', labelrotation=45)
    ax.grid(alpha=0.5)

    # Leave a week of empty space on the right of the last point.
    ax.set_xlim(ax.get_xlim()[0], ax.get_xlim()[1]+7)

# Hide panels that are left over when the parameters do not fill the grid.
for ax in axs.flat[len(list_of_params):]:
    ax.set_visible(False)

fig.suptitle('Ensemble Mean of all params vs time (5 seeds combined)')
fig.subplots_adjust(top=0.95, hspace=0.3)

In [None]:
# Write the current figure to a PNG in the working directory.
# (A stray `rolling` token after the call made this cell a SyntaxError.)
fig.savefig('em-mean-params-std-error-v1.png')

In [None]:
# Dump the run names as an integer table with five runs per row.
# NOTE(review): rows are filled in ledger order; presumably each row is meant
# to hold the five seeds of one start date — confirm the ledger is ordered so.
np.savetxt(
    fname='../../configs/exper/runs.txt',
    X=run_ledger['run'].to_numpy().reshape(-1, 5).astype(int),
    fmt='%d',
)

In [None]:
# Read the saved table back and show it as nested Python lists.
np.loadtxt(
    fname='../../configs/exper/runs.txt',
    dtype=int,
    delimiter=' ',
).tolist()