In [None]:
import pandas as pd
import numpy as np

recalc_everything = True

import altair as alt
print(alt.__version__)
alt.data_transformers.disable_max_rows()
alt.renderers.enable('default')
import altair_saver
import selenium
print(selenium.__version__)

save_figures = True
import vl_convert as vlc
def save_chart(chart_to_save, filename):
    """Render an Altair chart to SVG with vl-convert and write it to ``filename``.

    Args:
        chart_to_save: Chart object; only its ``to_json()`` method is used.
        filename: Path of the SVG file to write (overwritten if it exists).
    """
    svg_str = vlc.vegalite_to_svg(chart_to_save.to_json())
    # Write UTF-8 explicitly so that non-ASCII text in the chart does not depend
    # on the platform's default locale encoding.
    with open(filename, "w", encoding="utf-8") as f:
        f.write(svg_str)

import termcolor
import io
import os

mainstandirname = '../../../Taiwan_Backup/Monkeypox_2022'
os.makedirs(mainstandirname, exist_ok=True)

from IPython.display import Markdown as md

%matplotlib inline
%config matplotlib_inline.matplotlib_formats = 'retina'
%config InlineBackend.figure_format = 'retina'
import matplotlib
import matplotlib.pyplot as plt

# chinese font
from matplotlib import font_manager
fontP = font_manager.FontProperties(fname="./NotoSerifTC-Regular.otf")
fontP.set_size(10)

import pathlib
import platform

import cmdstanpy as cmdstan
import arviz as az
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning, module='arviz')
# Location of the CmdStan installation, relative to this notebook.
standistribdir = '../../../../CmdStan'
# On macOS (platform.system() == 'Darwin') the first three characters ('../') are
# dropped, i.e. the installation is looked up one directory level closer.
cmdstan.set_cmdstan_path(standistribdir[3:] if platform.system()=='Darwin' else standistribdir)

def _percentile_stat(pct):
    """Return a one-argument function computing the ``pct``-th percentile."""
    return lambda x: np.percentile(x, pct)


# Extra summary statistics handed to ``az.summary(..., stat_funcs=func_dict)``;
# each key names one percentile of the draws.
func_dict = {
    label: _percentile_stat(pct)
    for label, pct in (("q2.5", 2.5), ("q25", 25), ("median", 50), ("q75", 75), ("q97.5", 97.5))
}

def get_stats(cmdstan_data, varnames, round_to_=5):
    """Summarise posterior variables into one table.

    Three ``az.summary`` tables are joined on the variable label: mean, 95% HDI
    and diagnostics; the 50% HDI; and the percentiles defined in ``func_dict``.

    Args:
        cmdstan_data: Posterior draws accepted by ``az.summary``.
        varnames: Names of the variables to summarise.
        round_to_: Passed to ``az.summary`` as ``round_to``.

    Returns:
        DataFrame with ``var`` (label without its bracketed index), ``time``
        (bracketed index plus one, or ``'NA'`` for labels without brackets),
        then the mean, HDI bounds, percentiles and ess/r_hat diagnostics.
    """
    shared = dict(round_to=round_to_, var_names=varnames)

    # mean, 95% HDI and sampling diagnostics
    wide = az.summary(cmdstan_data, hdi_prob=0.95, **shared)
    wide = wide.loc[:, ['mean','hdi_2.5%','hdi_97.5%','ess_bulk','ess_tail','r_hat']].reset_index()
    wide = wide.rename(columns={'index':'var', 'hdi_2.5%':'hdi2.5', 'hdi_97.5%':'hdi97.5'})

    # 50% HDI
    narrow = az.summary(cmdstan_data, hdi_prob=0.50, **shared)
    narrow = narrow.loc[:, ['hdi_25%','hdi_75%']].reset_index()
    narrow = narrow.rename(columns={'index':'var', 'hdi_25%':'hdi25', 'hdi_75%':'hdi75'})

    # percentiles
    pct = az.summary(cmdstan_data, stat_funcs=func_dict, extend=False, **shared)
    pct = pct.reset_index().rename(columns={'index': 'var'})

    stats = pct.merge(narrow.merge(wide, on='var'), on='var')

    def _index_plus_one(label):
        # labels without brackets carry no index
        if "[" not in label:
            return 'NA'
        return int(label[label.find("[")+1:label.find("]")]) + 1

    def _base_name(label):
        return label[:label.find("[")] if "[" in label else label

    labels = stats['var']
    stats['time'] = [_index_plus_one(label) for label in labels]
    stats['var'] = labels.apply(_base_name)

    ordered = ['var','time','mean','hdi2.5','hdi25','hdi75','hdi97.5',
               'q2.5','q25','median','q75','q97.5','ess_bulk','ess_tail','r_hat']
    return stats.loc[:, ordered]

def get_stats_2d(cmdstan_data, varnames, rounding=2):
    """Summarise posterior variables, keeping the bracketed index as text.

    Same table layout as ``get_stats``, but ``time`` holds the raw text found
    between the square brackets of the label (``'NA'`` for labels without
    brackets) instead of an integer.

    Args:
        cmdstan_data: Posterior draws accepted by ``az.summary``.
        varnames: Names of the variables to summarise.
        rounding: Passed to ``az.summary`` as ``round_to``.

    Returns:
        DataFrame with ``var``, ``time``, the mean, HDI bounds, percentiles and
        ess/r_hat diagnostics.
    """
    shared = dict(var_names=varnames, round_to=rounding)

    # mean, 95% HDI and sampling diagnostics
    wide = az.summary(cmdstan_data, hdi_prob=0.95, **shared)
    wide = wide.loc[:, ['mean','hdi_2.5%','hdi_97.5%','ess_bulk','ess_tail','r_hat']].reset_index()
    wide = wide.rename(columns={'index':'var', 'hdi_2.5%':'hdi2.5', 'hdi_97.5%':'hdi97.5'})

    # 50% HDI
    narrow = az.summary(cmdstan_data, hdi_prob=0.50, **shared)
    narrow = narrow.loc[:, ['hdi_25%','hdi_75%']].reset_index()
    narrow = narrow.rename(columns={'index':'var', 'hdi_25%':'hdi25', 'hdi_75%':'hdi75'})

    # percentiles
    pct = az.summary(cmdstan_data, stat_funcs=func_dict, extend=False, **shared)
    pct = pct.reset_index().rename(columns={'index': 'var'})

    stats = pct.merge(narrow.merge(wide, on='var'), on='var')

    def _bracket_text(label):
        # labels without brackets carry no index
        if "[" not in label:
            return 'NA'
        return label[label.find("[")+1:label.find("]")]

    def _base_name(label):
        return label[:label.find("[")] if "[" in label else label

    labels = stats['var']
    stats['time'] = [_bracket_text(label) for label in labels]
    stats['var'] = labels.apply(_base_name)

    ordered = ['var','time','mean','hdi2.5','hdi25','hdi75','hdi97.5',
               'q2.5','q25','median','q75','q97.5','ess_bulk','ess_tail','r_hat']
    return stats.loc[:, ordered]

clrs_ = ["#00a1d5", "#fee391", "#d8daeb", "#bababa", "k"] # via https://nanx.me/ggsci/index.html #blue = #74add1 #yellow = #fee391

num_warmup = 1000
num_iterations = 1250
num_chains = 4

import rpy2.rinterface
%load_ext rpy2.ipython
rpy2.robjects.r['options'](warn=-1)

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning, module='arviz')
warnings.filterwarnings("ignore", category=RuntimeWarning, module='arviz')
from pytz_deprecation_shim import PytzUsageWarning
warnings.filterwarnings('ignore', category=PytzUsageWarning)

import pyreadr

In [None]:
# Placeholder credentials/address of a remote machine (values redacted as 'XXX').
# NOTE(review): none of these names is referenced in the visible part of the notebook.
remote_userID = 'XXX'
remote_userPSW = 'XXX'
remote_IP = 'XXX'
remote_ip = remote_IP

# Sampler settings; the same values are also assigned in the setup cell above.
num_warmup = 1000
num_iterations = 1250
num_chains = 4

# Data

In [None]:
# 45-day window ending on 2023-07-05.
# NOTE(review): date_upper/date_lower are not referenced in the visible part of the notebook.
date_upper = pd.to_datetime("2023-07-05", format="%Y-%m-%d")
date_lower = date_upper - pd.DateOffset(days = 45)
date_lower

In [None]:
# Day 1 of the integer day index used for the Stan models (T = (date - mindate).days + 1);
# also the lower bound of the epicurve's x axis.
mindate = pd.to_datetime("2023-05-15", format="%Y-%m-%d")
# NOTE(review): mindate_to_show is not referenced in the visible part of the notebook.
mindate_to_show = pd.to_datetime("2023-05-01", format="%Y-%m-%d")
# Upper bound of the epicurve's x axis.
maxdate_to_show = pd.to_datetime("2023-07-20", format="%Y-%m-%d")
# NOTE(review): cutoff_date and truncation_date are not referenced in the visible part
# of the notebook.
cutoff_date = pd.to_datetime("2023-07-15", format="%Y-%m-%d")
truncation_date = pd.to_datetime("2023-08-15", format="%Y-%m-%d")

## A. Main epicurve

In [None]:
# Daily case counts for the epicurve; 'Date' is converted from text to timestamps
df_cases = pd.read_csv('../../data/epicurve_figure_1.csv').assign(
    Date=lambda frame: pd.to_datetime(frame['Date'], format="%Y-%m-%d")
)
df_cases

In [None]:
# Epicurve of cases by date of symptom onset, drawn as a step-shaped area chart.
clrs_ = ["#00a1d5", "#fee391", "#d8daeb", "#bababa", "k"]
# NOTE(review): `colors` is not referenced in the visible part of the notebook.
colors = ['#4575b4', '#fdae61']
# upper limit of the y axis
ymx = 23

# shared x encoding: onset date shown as month/day, from mindate to maxdate_to_show
base = alt.Chart(df_cases).encode(
    alt.X('Date', axis=alt.Axis(title="Date of symptom onset (month/day in 2023)", format = ("%m/%d")), 
          scale=alt.Scale(domain=[mindate, maxdate_to_show]))
)

# 'Onset' counts on a fixed 0..ymx scale; clip=True hides marks outside the axis domains
bar_cases_obs = base.mark_area(interpolate='step-before', color=clrs_[0], binSpacing=0, width=5.25, clip=True).encode(
    alt.Y('Onset:Q', scale=alt.Scale(domain=[0, ymx]), sort=['local', 'imported'][::-1])
).configure_range(
        category=alt.RangeScheme(clrs_)
).resolve_scale(y = 'independent').properties(width=400, height=300).configure_axis(grid=False)

bar_cases_obs

## B. Loading linelist data from Hong Kong

In [None]:
# Hong Kong linelist without its 'Source' and 'Remarks' columns
hk_linelist_path = os.path.join("../../data", "20230823-HongKong-linelist.xlsx")
df_linelist_hk = pd.read_excel(hk_linelist_path).drop(columns=['Source', 'Remarks'])
df_linelist_hk

## C. Identifying some variables for the simulations

In [None]:
# Day indices (mindate is day 1) of the first day of July, August and September 2023
TJul1_, TAug1_, TSep1_ = (
    (pd.to_datetime(first_day) - mindate).days + 1
    for first_day in ('2023-07-01', '2023-08-01', '2023-09-01')
)

# <font color="green">2. Analysis</font>

## <font color="orange">2a. Identifying the reporting delay from HK data</font>

In [None]:
# Its YYYYMMDD form prefixes the names of the Stan output folders used below.
# NOTE(review): presumably the snapshot date of the Hong Kong linelist
# (the file loaded above is named 20230823-...) — confirm.
truncation_date_hk = pd.to_datetime('2023-08-23')

### By reporting date

In [None]:
# Reporting delay (generalized gamma): load the saved CmdStan traces and tabulate them
basename = truncation_date_hk.strftime("%Y%m%d")+f'_reporting_delay_HK_gengamma'
standirname = os.path.join(mainstandirname, basename)

paths = [str(trace) for trace in pathlib.Path(standirname).glob("trace*.csv")]
print(paths)
fit_rep_delay = cmdstan.from_csv(paths)
idata_rep_delay = az.from_cmdstanpy(posterior=fit_rep_delay)

mod_delay_summary_output = get_stats(
    idata_rep_delay.posterior,
    ['mean_delay', 'sd_delay', 'a', 'sigma', 'mu', 'q', 'loga', 'logsigma'],
)
mod_delay_summary_output['var'] = mod_delay_summary_output['var'].astype('string')
# every column after 'var' and 'time' is cast to float
for col in mod_delay_summary_output.columns[2:]:
    mod_delay_summary_output[col] = mod_delay_summary_output[col].astype('float')
mod_delay_summary_output['Mean (95% CI)'] = [
    "%.2f (%.2f, %.2f)" % triple
    for triple in zip(
        mod_delay_summary_output['mean'],
        mod_delay_summary_output['q2.5'],
        mod_delay_summary_output['q97.5'],
    )
]
# display labels, listed in the same order as the variable names requested above
mod_delay_summary_output['Parameter'] = [
    'Mean delay, days', 'SD, days', 'a', 'sigma', 'mu', 'q', 'loga', 'logsigma',
]

az.plot_trace(idata_rep_delay, var_names=('mean_delay', 'sd_delay'))
plt.tight_layout()
mod_delay_summary = az.summary(idata_rep_delay, var_names=['mean_delay', 'sd_delay'], hdi_prob=0.95)

display(mod_delay_summary_output.loc[:, ['Parameter', 'Mean (95% CI)']])

In [None]:
# Posterior means of the generalized-gamma reporting-delay parameters. sim() passes
# these means to Stan as data, so request raw values: by default az.summary rounds
# its output for display, which would feed rounded parameters into the model.
stats_summary_rep_delay = az.summary(idata_rep_delay, var_names=['mu', 'loga', 'logsigma'], round_to='none')
stats_summary_rep_delay

### By diagnosis date

In [None]:
# Delay to diagnosis (generalized gamma): load the saved CmdStan traces and tabulate them
basename = truncation_date_hk.strftime("%Y%m%d")+f'_reporting_delay_HK_gengamma-diangosis'
standirname = os.path.join(mainstandirname, basename)

paths = [str(trace) for trace in pathlib.Path(standirname).glob("trace*.csv")]
print(paths)
fit_rep_delay_diagnosis = cmdstan.from_csv(paths)
idata_rep_delay_diagnosis = az.from_cmdstanpy(posterior=fit_rep_delay_diagnosis)

mod_delay_diagnosis_summary_output = get_stats(
    idata_rep_delay_diagnosis.posterior,
    ['mean_delay', 'sd_delay', 'a', 'sigma', 'mu', 'q', 'loga', 'logsigma'],
)
mod_delay_diagnosis_summary_output['var'] = mod_delay_diagnosis_summary_output['var'].astype('string')
# every column after 'var' and 'time' is cast to float
for col in mod_delay_diagnosis_summary_output.columns[2:]:
    mod_delay_diagnosis_summary_output[col] = mod_delay_diagnosis_summary_output[col].astype('float')
mod_delay_diagnosis_summary_output['Mean (95% CI)'] = [
    "%.2f (%.2f, %.2f)" % triple
    for triple in zip(
        mod_delay_diagnosis_summary_output['mean'],
        mod_delay_diagnosis_summary_output['q2.5'],
        mod_delay_diagnosis_summary_output['q97.5'],
    )
]
# display labels, listed in the same order as the variable names requested above
mod_delay_diagnosis_summary_output['Parameter'] = [
    'Mean delay, days', 'SD, days', 'a', 'sigma', 'mu', 'q', 'loga', 'logsigma',
]

az.plot_trace(idata_rep_delay_diagnosis, var_names=('mean_delay', 'sd_delay'))
plt.tight_layout()
mod_delay_diagnosis_summary = az.summary(idata_rep_delay_diagnosis, var_names=['mean_delay', 'sd_delay'], hdi_prob=0.95)

display(mod_delay_diagnosis_summary_output.loc[:, ['Parameter', 'Mean (95% CI)']])

In [None]:
# Posterior means of the generalized-gamma diagnosis-delay parameters. sim() passes
# these means to Stan as data, so request raw values: by default az.summary rounds
# its output for display, which would feed rounded parameters into the model.
stats_summary_rep_delay_diagnosis = az.summary(idata_rep_delay_diagnosis, var_names=['mu', 'loga', 'logsigma'], round_to='none')
stats_summary_rep_delay_diagnosis

## <font color="orange">2b. Estimating the R0 in Mainland China</font>

In [None]:
# Stan program: exponential growth fitted to daily case counts by date of symptom onset.
#   functions            - dgengamma: discretised generalized-gamma delay probabilities
#                          (plus a final tail entry); dgamma is defined but not called
#                          within this program.
#   transformed data     - cases known only by report or diagnosis date are randomly
#                          back-projected to onset days with multinomial_rng and added
#                          to cases_onset.
#   model                - negative-binomial likelihood with mean i0 * exp(r * t) over
#                          days Tlower..Tupper; days with zero cases are skipped.
#   generated quantities - R0, doubling time, simulated onsets after Tupper, and the
#                          casesJul/casesAug and reportedJul/reportedAug sums.
# NOTE(review): the comment inside the program cites Guzetta et al. 2022 for the
# generation time, while sim() labels the values it passes as Miura et al. 2023.
stan_code_exp_growth = """functions {
    real gengamma_cdf(real x, real q, real mu, real sigma) {
        real logx = log(x),
            z = (logx - mu) / sigma,
            a = inv_square(q),
            value = gamma_cdf(a * exp(q * z) | a, 1);

        return value;
    }

    /* discretized version */
    vector dgengamma(real q, real mu, real sigma, int D) {
        vector[D] res;
        for (k in 1:D)
            res[k] = gengamma_cdf(k - 0.5 | q, mu, sigma);

        if (D > 1)
            return append_row(append_row(res[1], tail(res, D-1) - head(res, D-1)), 1.0 - res[D]);
        else 
            return to_vector({res[1], 1 - res[1]});
    }

    vector dgamma(real param1, real param2, int K) {
        vector[K] res;
        for (k in 1:K)
            res[k] = gamma_cdf(k - 0.5 | param1, param2);

        return append_row(res[1], tail(res, K-1) - head(res, K-1));
    }
}

data {
    int<lower = 1> Tlower, Tupper; // the cutoffday for the estimation of the exponential growth 
    array[Tupper] int<lower = 0> cases_onset; // number of cases by date of symptom onset starting to be recorded till the day Tupper

    int<lower = Tupper> T; // total number of days for which we have the data on number of cases by date of notification (their symptom onset date is missed)
    array[T] int<lower = 0> cases_reported, cases_diagnosis;

    int<lower = Tupper> Tpred;
    int<upper = Tpred> TJul1, TAug1, TSep1;

    // reporting delay described by the generalized gamma distribution
    real loga, logsigma, mu;

    // delay b/w onset and diagnosis described by the generalized gamma distribution
    real loga_diagnosis, logsigma_diagnosis, mu_diagnosis;

    // generation time (scale and shape of the Gamma distribution estimated in Guzetta et al. 2022)
    real<lower = 0> genalpha, geninvbeta;
}

transformed data {
    // reporting delay
    real a = exp(loga), q = inv_sqrt(a), sigma = exp(logsigma);

    // delay b/w onset and diagnosis
    real a_diagnosis = exp(loga_diagnosis), q_diagnosis = inv_sqrt(a_diagnosis), sigma_diagnosis = exp(logsigma_diagnosis);

    // generation time
    real genmean = genalpha * geninvbeta,
        gensigma = sqrt(genalpha) * geninvbeta;

    // backprojecting from the reporting date
    array[T] int cases_onset_backprj = rep_array(0, T);
    for (t in 1:T) 
        if (cases_reported[t] > 0) {
            vector[t+1] probs_for_backprojection = dgengamma(q, mu, sigma, t);
            array[t+1] int counts_backprj = multinomial_rng(probs_for_backprojection, cases_reported[t]);
            for (i in 1:t) 
                cases_onset_backprj[t-i+1] += counts_backprj[i];
        }

    // backprojecting from the diagnosis date
    for (t in 1:T) 
        if (cases_diagnosis[t] > 0) {
            vector[t+1] probs_for_backprojection = dgengamma(q_diagnosis, mu_diagnosis, sigma_diagnosis, t);
            array[t+1] int counts_backprj = multinomial_rng(probs_for_backprojection, cases_diagnosis[t]);
            for (i in 1:t) 
                cases_onset_backprj[t-i+1] += counts_backprj[i];
        }


    array[Tupper] int cases;
    for (t in 1:Tupper) 
        cases[t] = cases_onset[t] + cases_onset_backprj[t];

    real jitter = 1e-9;
}

parameters {
    // exponential growth rate
    real logr;
    // initial incidence
    real<lower = 0> i0;
    // process error
    real<lower = 0> phi;
}

transformed parameters {
    real r = exp(logr);
}

model {
    logr ~ std_normal();
    i0 ~ normal(5, 10);
    phi ~ gamma(1, 1);

    for (t in Tlower:Tupper)
        if (cases[t] > 0)
            target += neg_binomial_2_lupmf(cases[t] | i0 * exp(r * t) + jitter, phi); 
}

generated quantities {
    // basic reproduction number
    real R0_norm_approx = exp(r * genmean - 0.5 * square(r) * square(gensigma)),
        R0 = pow(1 + r * geninvbeta, genalpha);

    // doubling time
    real doubling_time = log(2) / r; 

    array[Tpred] int cases_prj = rep_array(0, Tpred);
    for (t in 1:Tpred)
        cases_prj[t] = (t > Tupper) ? neg_binomial_2_rng(i0 * exp(r * t), phi) : cases[t];

    int casesJul = sum(cases_prj[TJul1:TAug1-1]), 
        casesAug = sum(cases_prj[TAug1:TSep1-1]);

    array[Tpred] int cases_reported_prj = rep_array(0, Tpred);
    for (t in 1:Tpred) 
        if (cases_prj[t] > 0) {
            vector[Tpred-t+2] probs_for_prj = dgengamma(q, mu, sigma, Tpred-t+1);
            array[Tpred-t+2] int counts = multinomial_rng(probs_for_prj, cases_prj[t]);
            for (s in 1:Tpred-t+1)
                cases_reported_prj[t+s-1] += counts[s];
        }

    int reportedJul = sum(cases_reported_prj[TJul1:TAug1-1]), 
        reportedAug = sum(cases_reported_prj[TAug1:TSep1-1]);
}"""

In [None]:
def sim(Tlower_, Tupper_):
    """Fit the exponential-growth Stan model for one fitting window and save its output.

    Output goes to the folder
    ``<mainstandirname>/<YYYYMMDD>_exp_growth_China_expwindow_<Tlower_>-<Tupper_>_Miura2023``;
    files already present in that folder are removed before the run.

    Reads the notebook-level ``df_cases``, ``stats_summary_rep_delay``,
    ``stats_summary_rep_delay_diagnosis``, ``TJul1_``/``TAug1_``/``TSep1_``,
    ``num_warmup`` and ``stan_code_exp_growth``.

    Args:
        Tlower_: First day index of the fitting window (Stan data ``Tlower``).
        Tupper_: Last day index of the fitting window (Stan data ``Tupper``).

    Returns:
        None. The data, inits, Stan program and CmdStan CSV files are written to disk.
    """
    import glob

    basename = truncation_date_hk.strftime("%Y%m%d")+f'_exp_growth_China_expwindow_{Tlower_}-{Tupper_}_Miura2023'
    standirname = os.path.join(mainstandirname, basename)
    # Start from a clean folder so that files of an earlier run are not mixed with
    # the new output (portable replacement for the former shell `rm <dir>/*`).
    os.makedirs(standirname, exist_ok=True)
    for stale in glob.glob(os.path.join(standirname, '*')):
        if os.path.isfile(stale):
            os.remove(stale)

    # Generation time from Miura et al. 2023, converted from mean/sd to the
    # shape (genalpha) and scale (geninvbeta) of a gamma distribution.
    mean_gt_ = 10.1
    sd_gt_ = 6.1
    genalpha_ = (mean_gt_ / sd_gt_)**2
    geninvbeta_ = (sd_gt_**2) / mean_gt_

    stan_data = {
        'Tlower': Tlower_,
        'Tupper': Tupper_,
        'Tpred': TSep1_+1,
        'TJul1': TJul1_,
        'TAug1': TAug1_,
        'TSep1': TSep1_,
        'cases_onset': df_cases['Onset'].astype('int64').values[:Tupper_],
        'T': df_cases.shape[0],
        'cases_reported': df_cases['Reported'].astype('int64').values,
        'loga': stats_summary_rep_delay.loc['loga']['mean'],
        'logsigma': stats_summary_rep_delay.loc['logsigma']['mean'],
        'mu': stats_summary_rep_delay.loc['mu']['mean'],
        'cases_diagnosis': df_cases['Diagnosis'].astype('int64').values,
        'loga_diagnosis': stats_summary_rep_delay_diagnosis.loc['loga']['mean'],
        'logsigma_diagnosis': stats_summary_rep_delay_diagnosis.loc['logsigma']['mean'],
        'mu_diagnosis': stats_summary_rep_delay_diagnosis.loc['mu']['mean'],
        'genalpha': genalpha_,
        'geninvbeta': geninvbeta_,
    }
    stan_data_file = os.path.join(standirname, 'Data.json')
    cmdstan.write_stan_json(stan_data_file, stan_data)

    # The sampled parameter is `logr`; `r` is a transformed parameter, so an initial
    # value given for `r` has no effect. Start the growth rate at r = 0.1 through logr.
    stan_inits = {
        'logr': float(np.log(0.1)),
        'phi': 1.0,
    }
    stan_init_file = os.path.join(standirname, 'Inits.json')
    cmdstan.write_stan_json(stan_init_file, stan_inits)

    stan_code_file = os.path.join(standirname, 'fit_exp_growth.stan')
    with open(stan_code_file, "w") as f:
        f.write(stan_code_exp_growth)

    model = cmdstan.CmdStanModel(stan_file=stan_code_file, cpp_options={'STAN_THREADS': 'TRUE'}, compile='force')
    # 4000 chains with a single post-warmup draw each. The Stan program back-projects
    # cases with multinomial_rng in `transformed data`.
    # NOTE(review): presumably one draw per chain is used so that every draw comes
    # from its own random back-projection — confirm.
    fit = model.sample(data=stan_data_file, seed=1, iter_warmup=num_warmup, iter_sampling=1,
                       inits=stan_init_file, parallel_chains=10,
                       show_console=False, show_progress=False, chains=4000)
    fit.save_csvfiles(dir=standirname)

In [None]:
%%time
# Run the fit for every combination of lower and upper bound of the fitting window
lower_dates = pd.date_range(pd.to_datetime('2023-05-16'), pd.to_datetime('2023-05-26'))
upper_dates = pd.date_range(pd.to_datetime('2023-06-30'), pd.to_datetime('2023-07-10'))
for date_lower_ in lower_dates:
    for date_upper_ in upper_dates:
        # day indices counted from mindate (mindate itself is day 1)
        Tlower_ = (date_lower_-mindate).days + 1
        Tupper_ = (date_upper_-mindate).days + 1
        print([date_lower_, date_upper_], [Tlower_, Tupper_])
        sim(Tlower_, Tupper_)

In [None]:
# Load the saved fit of a single window and summarise its main parameters.
# NOTE: Tlower_ and Tupper_ are not set in this cell; it uses whatever values an
# earlier cell left behind (e.g. the last window of the simulation loop).
basename = truncation_date_hk.strftime("%Y%m%d")+f'_exp_growth_China_expwindow_{Tlower_}-{Tupper_}_Miura2023'
standirname = os.path.join(mainstandirname, basename)
idata_chn = az.from_cmdstanpy(cmdstan.from_csv(path=standirname))
df_stats_chn = get_stats(idata_chn.posterior, ['r', 'i0', 'phi', 'R0', 'doubling_time']) 
df_stats_chn

In [None]:
%%time
# Gather R0, doubling time and the July/August reported-case sums from every saved
# window fit. Folder names end in "_expwindow_<Tlower>-<Tupper>_Miura2023".
# The folders are listed in Python instead of through a shell pipeline so the cell
# does not depend on `ls`/`grep` being available.
fldrs = sorted(entry for entry in os.listdir(mainstandirname)
               if '_expwindow_' in entry and 'Miura2023' in entry)
frames = []
for fldr in fldrs:
    print(fldr)
    standirname = os.path.join(mainstandirname, fldr)
    window = fldr.split("_expwindow_")[-1]
    Tlower_ = int(window.split("-")[0])
    Tupper_ = int(window.split("-")[-1].split("_")[0])
    idata_chn = az.from_cmdstanpy(cmdstan.from_csv(path=standirname))
    df_stats_ = get_stats(idata_chn.posterior, ['R0', 'doubling_time', 'reportedJul', 'reportedAug'])
    df_stats_['Tlower'] = Tlower_
    df_stats_['Tupper'] = Tupper_
    df_stats_ = df_stats_.drop(['ess_bulk', 'ess_tail', 'r_hat'], axis=1)
    frames.append(df_stats_)
# A single concat avoids re-copying the accumulated table on every iteration;
# Df_stats stays None when no folder matched.
Df_stats = pd.concat(frames, ignore_index=True) if frames else None

In [None]:
Df_stats

In [None]:
Df_stats.Tlower.drop_duplicates()

In [None]:
# Convert the day indices of the window bounds back to calendar dates (day 1 is mindate)
for _index_col, _date_col in (('Tlower', 'date_lower'), ('Tupper', 'date_upper')):
    Df_stats[_date_col] = [mindate + pd.DateOffset(days = day - 1) for day in Df_stats[_index_col]]

In [None]:
# Median R0 arranged as a grid: rows are Tlower, columns are Tupper
r0_medians = Df_stats.loc[Df_stats['var'] == 'R0', ['Tlower', 'Tupper', 'median']]
Df_stats_table = pd.pivot_table(r0_medians, values='median', index='Tlower', columns='Tupper')
Df_stats_table

In [None]:
Df_stats.date_lower.drop_duplicates()

In [None]:
Df_stats.date_upper.drop_duplicates()

### Estimates of Re

In [None]:
# Heatmap of the median of R0 for every pair of fitting-window bounds
wd_, ht_ = 300, 270

Df_stats_ = Df_stats.loc[lambda d: d['var']=='R0', ['date_lower', 'date_upper', 'median']]
print("Min median:", Df_stats_['median'].min())
print("Max median:", Df_stats_['median'].max())

heatmap = alt.Chart(Df_stats_).mark_rect().encode(
    x=alt.X(
        'yearmonthdate(date_lower)',
        title="Lower bound date",
        axis=alt.Axis(format='%m/%d', labelAngle=-90),
    ),
    y=alt.Y(
        'yearmonthdate(date_upper)',
        title="Upper bound date",
        axis=alt.Axis(format='%m/%d', labelAngle=0),
        scale=alt.Scale(reverse=True),
    ),
    color=alt.Color('median:Q', scale=alt.Scale(scheme='greens', reverse=False), title="median"),
)

chart = alt.layer(heatmap).properties(width=wd_, height=ht_)
save_chart(chart, f'../../figures/sensitivity-Re-median.svg')
chart

In [None]:
# Heatmap of the 2.5th percentile of R0 for every pair of fitting-window bounds
Df_stats_ = Df_stats.loc[lambda d: d['var']=='R0', ['date_lower', 'date_upper', 'q2.5']]
print("Min lower:", Df_stats_['q2.5'].min())
print("Max lower:", Df_stats_['q2.5'].max())

heatmap = alt.Chart(Df_stats_.rename(columns={'q2.5':'lower'})).mark_rect().encode(
    x=alt.X(
        'yearmonthdate(date_lower)',
        title="Lower bound date",
        axis=alt.Axis(format='%m/%d', labelAngle=-90),
    ),
    y=alt.Y(
        'yearmonthdate(date_upper)',
        title="Upper bound date",
        axis=alt.Axis(format='%m/%d', labelAngle=0),
        scale=alt.Scale(reverse=True),
    ),
    color=alt.Color('lower:Q', scale=alt.Scale(scheme='greens', reverse=False), title="lower"),
)

chart = alt.layer(heatmap).properties(width=wd_, height=ht_)
save_chart(chart, f'../../figures/sensitivity-Re-lower.svg')
chart

In [None]:
# Heatmap of the 97.5th percentile of R0 for every pair of fitting-window bounds
Df_stats_ = Df_stats.loc[lambda d: d['var']=='R0', ['date_lower', 'date_upper', 'q97.5']]
print("Min upper:", Df_stats_['q97.5'].min())
print("Max upper:", Df_stats_['q97.5'].max())

heatmap = alt.Chart(Df_stats_.rename(columns={'q97.5':'upper'})).mark_rect().encode(
    x=alt.X(
        'yearmonthdate(date_lower)',
        title="Lower bound date",
        axis=alt.Axis(format='%m/%d', labelAngle=-90),
    ),
    y=alt.Y(
        'yearmonthdate(date_upper)',
        title="Upper bound date",
        axis=alt.Axis(format='%m/%d', labelAngle=0),
        scale=alt.Scale(reverse=True),
    ),
    color=alt.Color('upper:Q', scale=alt.Scale(scheme='greens', reverse=False), title="upper"),
)

chart = alt.layer(heatmap).properties(width=wd_, height=ht_)
save_chart(chart, f'../../figures/sensitivity-Re-upper.svg')
chart

### Reported cases in July

In [None]:
# Heatmap of the median of reportedJul for every pair of fitting-window bounds
Df_stats_ = Df_stats.loc[lambda d: d['var']=='reportedJul', ['date_lower', 'date_upper', 'median']]
print("Min median:", Df_stats_['median'].min())
print("Max median:", Df_stats_['median'].max())

heatmap = alt.Chart(Df_stats_).mark_rect().encode(
    x=alt.X(
        'yearmonthdate(date_lower)',
        title="Lower bound date",
        axis=alt.Axis(format='%m/%d', labelAngle=-90),
    ),
    y=alt.Y(
        'yearmonthdate(date_upper)',
        title="Upper bound date",
        axis=alt.Axis(format='%m/%d', labelAngle=0),
        scale=alt.Scale(reverse=True),
    ),
    color=alt.Color('median:Q', scale=alt.Scale(scheme='lighttealblue', reverse=False), title="median"),
)

chart = alt.layer(heatmap).properties(width=wd_, height=ht_)
save_chart(chart, f'../../figures/sensitivity-reportedJul-median.svg')
chart

In [None]:
# Heatmap of the 2.5th percentile of reportedJul for every pair of fitting-window bounds
Df_stats_ = Df_stats.loc[lambda d: d['var']=='reportedJul', ['date_lower', 'date_upper', 'q2.5']]
print("Min lower:", Df_stats_['q2.5'].min())
print("Max lower:", Df_stats_['q2.5'].max())

heatmap = alt.Chart(Df_stats_.rename(columns={'q2.5':'lower'})).mark_rect().encode(
    x=alt.X(
        'yearmonthdate(date_lower)',
        title="Lower bound date",
        axis=alt.Axis(format='%m/%d', labelAngle=-90),
    ),
    y=alt.Y(
        'yearmonthdate(date_upper)',
        title="Upper bound date",
        axis=alt.Axis(format='%m/%d', labelAngle=0),
        scale=alt.Scale(reverse=True),
    ),
    color=alt.Color('lower:Q', scale=alt.Scale(scheme='lighttealblue', reverse=False), title="lower"),
)

chart = alt.layer(heatmap).properties(width=wd_, height=ht_)
save_chart(chart, f'../../figures/sensitivity-reportedJul-lower.svg')
chart

In [None]:
# Heatmap of the 97.5th percentile of reportedJul for every pair of fitting-window bounds
Df_stats_ = Df_stats.loc[lambda d: d['var']=='reportedJul', ['date_lower', 'date_upper', 'q97.5']]
print("Min upper:", Df_stats_['q97.5'].min())
print("Max upper:", Df_stats_['q97.5'].max())

heatmap = alt.Chart(Df_stats_.rename(columns={'q97.5':'upper'})).mark_rect().encode(
    x=alt.X(
        'yearmonthdate(date_lower)',
        title="Lower bound date",
        axis=alt.Axis(format='%m/%d', labelAngle=-90),
    ),
    y=alt.Y(
        'yearmonthdate(date_upper)',
        title="Upper bound date",
        axis=alt.Axis(format='%m/%d', labelAngle=0),
        scale=alt.Scale(reverse=True),
    ),
    color=alt.Color('upper:Q', scale=alt.Scale(scheme='lighttealblue', reverse=False), title="upper"),
)

chart = alt.layer(heatmap).properties(width=wd_, height=ht_)
save_chart(chart, f'../../figures/sensitivity-reportedJul-upper.svg')
chart

### Reported cases in August

In [None]:
# Heatmap of the median of reportedAug for every pair of fitting-window bounds
Df_stats_ = Df_stats.loc[lambda d: d['var']=='reportedAug', ['date_lower', 'date_upper', 'median']]
print("Min median:", Df_stats_['median'].min())
print("Max median:", Df_stats_['median'].max())

heatmap = alt.Chart(Df_stats_).mark_rect().encode(
    x=alt.X(
        'yearmonthdate(date_lower)',
        title="Lower bound date",
        axis=alt.Axis(format='%m/%d', labelAngle=-90),
    ),
    y=alt.Y(
        'yearmonthdate(date_upper)',
        title="Upper bound date",
        axis=alt.Axis(format='%m/%d', labelAngle=0),
        scale=alt.Scale(reverse=True),
    ),
    color=alt.Color('median:Q', scale=alt.Scale(scheme='blues', reverse=False), title="median"),
)

chart = alt.layer(heatmap).properties(width=wd_, height=ht_)
save_chart(chart, f'../../figures/sensitivity-reportedAug-median.svg')
chart

In [None]:
# Heatmap of the 2.5th percentile of reportedAug for every pair of fitting-window bounds
Df_stats_ = Df_stats.loc[lambda d: d['var']=='reportedAug', ['date_lower', 'date_upper', 'q2.5']]
print("Min lower:", Df_stats_['q2.5'].min())
print("Max lower:", Df_stats_['q2.5'].max())

heatmap = alt.Chart(Df_stats_.rename(columns={'q2.5':'lower'})).mark_rect().encode(
    x=alt.X(
        'yearmonthdate(date_lower)',
        title="Lower bound date",
        axis=alt.Axis(format='%m/%d', labelAngle=-90),
    ),
    y=alt.Y(
        'yearmonthdate(date_upper)',
        title="Upper bound date",
        axis=alt.Axis(format='%m/%d', labelAngle=0),
        scale=alt.Scale(reverse=True),
    ),
    color=alt.Color('lower:Q', scale=alt.Scale(scheme='blues', reverse=False), title="lower"),
)

chart = alt.layer(heatmap).properties(width=wd_, height=ht_)
save_chart(chart, f'../../figures/sensitivity-reportedAug-lower.svg')
chart

In [None]:
# Heatmap of the 97.5th percentile of reportedAug for every pair of fitting-window bounds
Df_stats_ = Df_stats.loc[lambda d: d['var']=='reportedAug', ['date_lower', 'date_upper', 'q97.5']]
print("Min upper:", Df_stats_['q97.5'].min())
print("Max upper:", Df_stats_['q97.5'].max())

heatmap = alt.Chart(Df_stats_.rename(columns={'q97.5':'upper'})).mark_rect().encode(
    x=alt.X(
        'yearmonthdate(date_lower)',
        title="Lower bound date",
        axis=alt.Axis(format='%m/%d', labelAngle=-90),
    ),
    y=alt.Y(
        'yearmonthdate(date_upper)',
        title="Upper bound date",
        axis=alt.Axis(format='%m/%d', labelAngle=0),
        scale=alt.Scale(reverse=True),
    ),
    color=alt.Color('upper:Q', scale=alt.Scale(scheme='blues', reverse=False), title="upper"),
)

chart = alt.layer(heatmap).properties(width=wd_, height=ht_)
save_chart(chart, f'../../figures/sensitivity-reportedAug-upper.svg')
chart

## <font color="orange">2c. Considering only symptomatic cases</font>

In [None]:
# Stan program: exponential growth fitted to cases_onset only (no back-projection of
# cases that lack an onset date).
#   functions            - dgengamma: discretised generalized-gamma delay probabilities
#                          (plus a final tail entry); dgamma is defined but not called
#                          within this program.
#   model                - negative-binomial likelihood with mean i0 * exp(r * t) over
#                          days Tlower..Tupper; days with zero cases are skipped.
#   generated quantities - R0, doubling time, simulated onsets after Tupper, and the
#                          casesJul/casesAug and reportedJul/reportedAug sums.
# NOTE(review): sim_only_sympt() also writes 'T' and the *_diagnosis values to the data
# file; this program does not declare them. The comment inside the program cites
# Guzetta et al. 2022 for the generation time, while sim_only_sympt() labels the values
# it passes as Miura et al. 2023.
stan_code_exp_growth_only_symptomatic = """functions {
    real gengamma_cdf(real x, real q, real mu, real sigma) {
        real logx = log(x),
            z = (logx - mu) / sigma,
            a = inv_square(q),
            value = gamma_cdf(a * exp(q * z) | a, 1);

        return value;
    }

    /* discretized version */
    vector dgengamma(real q, real mu, real sigma, int D) {
        vector[D] res;
        for (k in 1:D)
            res[k] = gengamma_cdf(k - 0.5 | q, mu, sigma);

        if (D > 1)
            return append_row(append_row(res[1], tail(res, D-1) - head(res, D-1)), 1.0 - res[D]);
        else 
            return to_vector({res[1], 1 - res[1]});
    }

    vector dgamma(real param1, real param2, int K) {
        vector[K] res;
        for (k in 1:K)
            res[k] = gamma_cdf(k - 0.5 | param1, param2);

        return append_row(res[1], tail(res, K-1) - head(res, K-1));
    }
}

data {
    int<lower = 1> Tlower, Tupper; // the cutoffday for the estimation of the exponential growth 
    array[Tupper] int<lower = 0> cases_onset; // number of cases by date of symptom onset starting to be recorded till the day Tupper

    int<lower = Tupper> Tpred;
    int<upper = Tpred> TJul1, TAug1, TSep1;

    // reporting delay described by the generalized gamma distribution
    real loga, logsigma, mu;

    // generation time (scale and shape of the Gamma distribution estimated in Guzetta et al. 2022)
    real<lower = 0> genalpha, geninvbeta;
}

transformed data {
    // reporting delay
    real a = exp(loga), q = inv_sqrt(a), sigma = exp(logsigma);

    // generation time
    real genmean = genalpha * geninvbeta,
        gensigma = sqrt(genalpha) * geninvbeta;
          
    real jitter = 1e-9;
}

parameters {
    // exponential growth rate
    real logr;
    // initial incidence
    real<lower = 0> i0;
    // process error
    real<lower = 0> phi;
}

transformed parameters {
    real r = exp(logr);
}

model {
    logr ~ std_normal();
    i0 ~ normal(5, 10);
    phi ~ gamma(1, 1);

    for (t in Tlower:Tupper)
        if (cases_onset[t] > 0)
            target += neg_binomial_2_lupmf(cases_onset[t] | i0 * exp(r * t) + jitter, phi); 
}

generated quantities {
    // basic reproduction number
    real R0_norm_approx = exp(r * genmean - 0.5 * square(r) * square(gensigma)),
        R0 = pow(1 + r * geninvbeta, genalpha);

    // doubling time
    real doubling_time = log(2) / r; 

    array[Tpred] int cases_prj = rep_array(0, Tpred);
    for (t in 1:Tpred)
        cases_prj[t] = (t > Tupper) ? neg_binomial_2_rng(i0 * exp(r * t), phi) : cases_onset[t];

    int casesJul = sum(cases_prj[TJul1:TAug1-1]), 
        casesAug = sum(cases_prj[TAug1:TSep1-1]);

    array[Tpred] int cases_reported_prj = rep_array(0, Tpred);
    for (t in 1:Tpred) 
        if (cases_prj[t] > 0) {
            vector[Tpred-t+2] probs_for_prj = dgengamma(q, mu, sigma, Tpred-t+1);
            array[Tpred-t+2] int counts = multinomial_rng(probs_for_prj, cases_prj[t]);
            for (s in 1:Tpred-t+1)
                cases_reported_prj[t+s-1] += counts[s];
        }

    int reportedJul = sum(cases_reported_prj[TJul1:TAug1-1]), 
        reportedAug = sum(cases_reported_prj[TAug1:TSep1-1]);
}"""

In [None]:
def sim_only_sympt(Tlower_, Tupper_):
    """Fit the exponential-growth Stan model to onset counts for one day window.

    A run directory named from ``truncation_date_hk`` and the window is created
    under ``mainstandirname``; the Stan program, the data and the initial values
    are written into it, four chains are sampled with CmdStanPy and the chain
    CSV files are saved in the same directory.

    Parameters
    ----------
    Tlower_ : int
        Passed to Stan as ``Tlower`` and embedded in the directory name.
    Tupper_ : int
        Passed to Stan as ``Tupper``; ``df_cases['Onset']`` is truncated to its
        first ``Tupper_`` entries.

    Returns
    -------
    None
        Results are available only through the files written to disk.
    """
    basename = truncation_date_hk.strftime("%Y%m%d")+f'_exp_growth_China_only_sympt_window_{Tlower_}-{Tupper_}_Miura2023'
    standirname = os.path.join(mainstandirname, basename)

    # Create the directory first, then clear regular files left by a previous run.
    # Pure-Python stand-in for `rm <dir>/*`: hidden files and sub-directories are kept.
    os.makedirs(standirname, exist_ok=True)
    for name_ in os.listdir(standirname):
        path_ = os.path.join(standirname, name_)
        if name_.startswith('.') or (os.path.isdir(path_) and not os.path.islink(path_)):
            continue
        os.remove(path_)

    # Miura et al. 2023: mean and SD of the generation time, converted to the
    # shape (genalpha) and scale (geninvbeta) of a Gamma distribution.
    mean_gt_ = 10.1
    sd_gt_ = 6.1
    genalpha_ = (mean_gt_ / sd_gt_)**2
    geninvbeta_ = (sd_gt_**2) / mean_gt_

    stan_data = {
        'Tlower': Tlower_,
        'Tupper': Tupper_,
        'Tpred': TSep1_+1,
        'TJul1': TJul1_,
        'TAug1': TAug1_,
        'TSep1': TSep1_,
        'cases_onset': df_cases['Onset'].astype('int64').values[:Tupper_],
        'T': df_cases.shape[0],
        # Delay distributions enter as fixed point estimates (posterior means).
        'loga': stats_summary_rep_delay.loc['loga']['mean'],
        'logsigma': stats_summary_rep_delay.loc['logsigma']['mean'],
        'mu': stats_summary_rep_delay.loc['mu']['mean'],
        'loga_diagnosis': stats_summary_rep_delay_diagnosis.loc['loga']['mean'],
        'logsigma_diagnosis': stats_summary_rep_delay_diagnosis.loc['logsigma']['mean'],
        'mu_diagnosis': stats_summary_rep_delay_diagnosis.loc['mu']['mean'],
        'genalpha': genalpha_,
        'geninvbeta': geninvbeta_
    }
    stan_data_file = os.path.join(standirname, 'Data.json')
    cmdstan.write_stan_json(stan_data_file, stan_data)

    # The Stan program samples `logr` and derives r = exp(logr), so the starting
    # growth rate of 0.1 has to be supplied on the log scale; an entry named `r`
    # does not correspond to any sampled parameter.
    stan_inits = {
        'logr': float(np.log(0.1)),
        'phi': 1.0
    }
    stan_init_file = os.path.join(standirname, 'Inits.json')
    cmdstan.write_stan_json(stan_init_file, stan_inits)

    stan_code_file = os.path.join(standirname, 'fit_exp_growth.stan')
    with open(stan_code_file, "w") as f:
        f.write(stan_code_exp_growth_only_symptomatic)

    model = cmdstan.CmdStanModel(stan_file=stan_code_file)
    fit = model.sample(data=stan_data_file, seed=1, iter_warmup=num_warmup, iter_sampling=1000, inits=stan_init_file,
                       show_console=False, show_progress=False, chains=4)
    fit.save_csvfiles(dir=standirname)

In [None]:
# Fit window: the 45 days ending on 2023-07-05, expressed as 1-based day
# indices counted from `mindate`.
date_upper_ = pd.to_datetime("2023-07-05", format="%Y-%m-%d")
date_lower_ = date_upper_ - pd.DateOffset(days = 45)

Tlower_, Tupper_ = [(day_ - mindate).days + 1 for day_ in (date_lower_, date_upper_)]
print([date_lower_, date_upper_], [Tlower_, Tupper_])

# Run the fit for this window; results are written to disk by the function.
sim_only_sympt(Tlower_, Tupper_)

In [None]:
%%time
Df_stats_only_sympt = None
fldrs = !ls {mainstandirname} | grep _window_ | grep Miura2023
for fldr in fldrs:
    print(fldr)
    standirname = os.path.join(mainstandirname, fldr)
    Tlower_ = int(fldr.split("_window_")[-1].split("-")[0])
    Tupper_ = int(fldr.split("_window_")[-1].split("-")[-1].split("_")[0])
    idata_chn = az.from_cmdstanpy(cmdstan.from_csv(path=standirname))
    df_stats_ = get_stats(idata_chn.posterior, ['R0', 'doubling_time', 'reportedJul', 'reportedAug']) 
    df_stats_['Tlower'] = Tlower_
    df_stats_['Tupper'] = Tupper_
    df_stats_ = df_stats_.drop(['ess_bulk', 'ess_tail', 'r_hat'], axis=1)
    Df_stats_only_sympt = df_stats_ if Df_stats is None else pd.concat([Df_stats_only_sympt, df_stats_], ignore_index=True)

In [None]:
# Display the summaries collected from the saved window runs.
Df_stats_only_sympt

## <font color="orange">3a. Selected national outbreaks of 2022: varying the exp. window</font>

In [None]:
# Stan program: exponential growth fitted to onset counts that are augmented with
# back-projected cases.
#   - `dgengamma` discretises a generalised-gamma delay into D daily bins plus one
#     final bin holding the remaining tail mass (a vector of length D + 1).
#   - `transformed data` draws onset days (multinomial_rng) for cases known only by
#     report or diagnosis date, then adds those counts to `cases_onset` for days 1..Tupper.
#   - The model block fits negative-binomial counts with mean i0 * exp(r * t) and
#     skips days whose count is zero.
#   - `dgamma` is defined but not called anywhere in this program.
# NOTE(review): the in-program comment credits the generation time to Guzetta et al. 2022;
# the values actually used are whatever the caller passes as genalpha / geninvbeta.
stan_code_exp_growth_global = """functions {
    real gengamma_cdf(real x, real q, real mu, real sigma) {
        real logx = log(x),
            z = (logx - mu) / sigma,
            a = inv_square(q),
            value = gamma_cdf(a * exp(q * z) | a, 1);

        return value;
    }

    /* discretized version */
    vector dgengamma(real q, real mu, real sigma, int D) {
        vector[D] res;
        for (k in 1:D)
            res[k] = gengamma_cdf(k - 0.5 | q, mu, sigma);

        if (D > 1)
            return append_row(append_row(res[1], tail(res, D-1) - head(res, D-1)), 1.0 - res[D]);
        else 
            return to_vector({res[1], 1 - res[1]});
    }

    vector dgamma(real param1, real param2, int K) {
        vector[K] res;
        for (k in 1:K)
            res[k] = gamma_cdf(k - 0.5 | param1, param2);

        return append_row(res[1], tail(res, K-1) - head(res, K-1));
    }
}

data {
    int<lower = 1> Tupper; // the cutoffday for the estimation of the exponential growth 
    array[Tupper] int<lower = 0> cases_onset; // number of cases by date of symptom onset starting to be recorded till the day Tupper

    int<lower = Tupper> T; // total number of days for which we have the data on number of cases by day of notification (their symptom onset date is missed)
    array[T] int<lower = 0> cases_reported, cases_diagnosis;

    // reporting delay described by the generalized gamma distribution
    real loga, logsigma, mu;

    // delay b/w onset and diagnosis described by the generalized gamma distribution
    real loga_diagnosis, logsigma_diagnosis, mu_diagnosis;

    // generation time (scale and shape of the Gamma distribution estimated in Guzetta et al. 2022)
    real<lower = 0> genalpha, geninvbeta;
}

transformed data {
    // reporting delay
    real a = exp(loga), q = inv_sqrt(a), sigma = exp(logsigma);

    // delay b/w onset and diagnosis
    real a_diagnosis = exp(loga_diagnosis), q_diagnosis = inv_sqrt(a_diagnosis), sigma_diagnosis = exp(logsigma_diagnosis);

    // generation time
    real genmean = genalpha * geninvbeta,
        gensigma = sqrt(genalpha) * geninvbeta;

    // backprojecting from the reporting date
    array[T] int cases_onset_backprj = rep_array(0, T);
    for (t in 1:T) 
        if (cases_reported[t] > 0) {
            vector[t+1] probs_for_backprojection = dgengamma(q, mu, sigma, t);
            array[t+1] int counts_backprj = multinomial_rng(probs_for_backprojection, cases_reported[t]);
            for (i in 1:t) 
                cases_onset_backprj[t-i+1] += counts_backprj[i];
        }

    // backprojecting from the diagnosis date
    for (t in 1:T) 
        if (cases_diagnosis[t] > 0) {
            vector[t+1] probs_for_backprojection = dgengamma(q_diagnosis, mu_diagnosis, sigma_diagnosis, t);
            array[t+1] int counts_backprj = multinomial_rng(probs_for_backprojection, cases_diagnosis[t]);
            for (i in 1:t) 
                cases_onset_backprj[t-i+1] += counts_backprj[i];
        }


    array[Tupper] int cases;
    for (t in 1:Tupper) 
        cases[t] = cases_onset[t] + cases_onset_backprj[t];

    print(cases);
    real jitter = 1e-9;
}

parameters {
    // exponential growth rate
    real logr;
    // initial incidence
    real<lower = 0> i0;
    // process error
    real<lower = 0> phi;
}

transformed parameters {
    real r = exp(logr);
}

model {
    logr ~ std_normal();
    i0 ~ normal(5, 10);
    phi ~ gamma(1, 1);

    for (t in 1:Tupper)
        if (cases[t] > 0)
            target += neg_binomial_2_lupmf(cases[t] | i0 * exp(r * t) + jitter, phi); 
}

generated quantities {
    // basic reproduction number
    real R0_norm_approx = exp(r * genmean - 0.5 * square(r) * square(gensigma)),
        R0 = pow(1 + r * geninvbeta, genalpha);

    // doubling time
    real doubling_time = log(2) / r; 
}"""

In [None]:
# List the files in ../../data/WHO whose names start with "epicurve_"
# (IPython shell capture: one entry per line of `ls` output).
fls = !ls ../../data/WHO | grep ^epicurve_
fls

In [None]:
%%time
# Calendar window applied to every national epicurve.
starting_date_global = pd.to_datetime('2022-01-01', format='%Y-%m-%d')
ending_date_global = pd.to_datetime('2022-12-31', format='%Y-%m-%d')

# One summary row per epicurve file that contains onset-dated records.
df_cumcases = None
for fl_ in fls:
    df_ = pd.read_csv(os.path.join('../../data/WHO', fl_))
    df_['reference_date'] = pd.to_datetime(df_['reference_date'], format='%Y-%m-%d')
    df_ = df_[df_['reference_date'] <= ending_date_global]
    # Country label: second "_"-separated piece of the file name minus its last four characters.
    country_ = fl_.split('_')[1][:-4]
    who_region_ = df_['who_region'].values[0]
    print(country_)

    onset_rows_ = df_[df_['date_type'] == 'Onset']
    if len(onset_rows_) == 0:
        continue

    # Peak = earliest date on which the onset count equals its maximum.
    at_max_ = onset_rows_[onset_rows_['cases'] == onset_rows_['cases'].max()]
    df_mxs_ = at_max_[at_max_['reference_date'] == at_max_['reference_date'].min()]

    # 'cases' sums every row of the file within the window, whatever its date type.
    df_cumcases_ = pd.DataFrame({'country': [country_], 'who_region': who_region_, 'cases': [df_['cases'].sum()],
                                 'peak_cases': df_mxs_.cases, 'peak_date': df_mxs_.reference_date, 'peak_date_type': df_mxs_.date_type,
                                 'file': fl_})
    if df_cumcases is None:
        df_cumcases = df_cumcases_
    else:
        df_cumcases = pd.concat([df_cumcases, df_cumcases_], ignore_index=True)
df_cumcases

In [None]:
# Keep only countries whose total case count exceeds 700.
# if no selection by symptomatic, then +Ireland 227 cases
df_cumcases_selection = df_cumcases[df_cumcases['cases'] > 700]
print("number of countries: ", df_cumcases_selection.shape[0])
df_cumcases_selection

In [None]:
def sim_country(idx, Tupper_):
    """Fit the back-projection exponential-growth model to one selected country.

    Reads the country's WHO epicurve file, builds a daily table covering
    ``starting_date_global``..``ending_date_global``, takes the ``Tupper_`` days
    of onset counts that end on the country's peak onset date, writes the Stan
    program/data/inits into a run directory under ``mainstandirname``, samples
    with CmdStanPy and saves the chain CSV files there.

    Parameters
    ----------
    idx : int
        Positional row index into ``df_cumcases_selection``.
    Tupper_ : int
        Length in days of the fitting window.

    Returns
    -------
    pandas.DataFrame
        Output of ``get_stats`` for ``r``, ``i0``, ``phi``, ``R0`` and
        ``doubling_time``, with ``country``, ``who_region`` and ``expwindow``
        columns added.
    """
    df_sel_ = df_cumcases_selection.iloc[idx]
    print(df_sel_)

    basename = truncation_date_hk.strftime("%Y%m%d")+f'_Tupper-{Tupper_}_exp_growth_Miura2023_' + df_sel_.country
    standirname = os.path.join(mainstandirname, basename)

    # Clear regular files left by a previous run. Done in Python rather than with an
    # unquoted shell `rm`, because the directory name embeds the country name and a
    # name containing spaces would be split into several shell arguments.
    # Hidden files and sub-directories are kept, as `rm <dir>/*` would keep them.
    os.makedirs(standirname, exist_ok=True)
    for name_ in os.listdir(standirname):
        path_ = os.path.join(standirname, name_)
        if name_.startswith('.') or (os.path.isdir(path_) and not os.path.islink(path_)):
            continue
        os.remove(path_)

    # Miura et al. 2023: mean and SD of the generation time, converted to the
    # shape (genalpha) and scale (geninvbeta) of a Gamma distribution.
    mean_gt_ = 10.1
    sd_gt_ = 6.1
    genalpha_ = (mean_gt_ / sd_gt_)**2
    geninvbeta_ = (sd_gt_**2) / mean_gt_

    fl_ = df_sel_['file']
    df_ = pd.read_csv(os.path.join('../../data/WHO', fl_))
    df_['reference_date'] = pd.to_datetime(df_.reference_date, format='%Y-%m-%d')
    df_ = df_.loc[lambda d: d.reference_date<=ending_date_global]

    # One row per calendar day, one column per date type; days without records become 0.
    df_country_ = df_.groupby(['date_type', 'reference_date'])['cases'].sum().reset_index().pivot_table(values='cases', columns='date_type', index='reference_date')
    df_country_ = pd.DataFrame(df_country_.to_records()).merge(pd.DataFrame({'reference_date': pd.date_range(starting_date_global, ending_date_global)}), how='outer').fillna(0)
    df_country_ = df_country_.sort_values('reference_date').set_index('reference_date').astype('int64')
    df_country_['reference_day'] = (df_country_.index - starting_date_global).days
    df_ = df_country_

    T_ = df_.shape[0]
    # Onset counts for the days strictly after (peak - Tupper_ days), capped at Tupper_ values.
    # NOTE(review): this array starts at the beginning of the fitting window, whereas the
    # Stan program adds it element-wise to a back-projection indexed from the first row of
    # the daily table -- confirm that this alignment is intended.
    cases_onset_ = df_.loc[lambda d: (d.index>df_sel_.peak_date-pd.DateOffset(days=Tupper_))].Onset.values[:Tupper_]
    cases_reported_ = df_['Reported'].astype('int64').values if 'Reported' in df_.columns else [0]*T_
    cases_diagnosis_ = df_['Diagnosis'].astype('int64').values if 'Diagnosis' in df_.columns else [0]*T_

    stan_data = {
        'Tupper': Tupper_,
        'cases_onset': cases_onset_,
        'T': T_,
        'cases_reported': cases_reported_,
        # Delay distributions enter as fixed point estimates (posterior means).
        'loga': stats_summary_rep_delay.loc['loga']['mean'],
        'logsigma': stats_summary_rep_delay.loc['logsigma']['mean'],
        'mu': stats_summary_rep_delay.loc['mu']['mean'],
        'cases_diagnosis': cases_diagnosis_,
        'loga_diagnosis': stats_summary_rep_delay_diagnosis.loc['loga']['mean'],
        'logsigma_diagnosis': stats_summary_rep_delay_diagnosis.loc['logsigma']['mean'],
        'mu_diagnosis': stats_summary_rep_delay_diagnosis.loc['mu']['mean'],
        'genalpha': genalpha_,
        'geninvbeta': geninvbeta_
    }
    stan_data_file = os.path.join(standirname, 'Data.json')
    cmdstan.write_stan_json(stan_data_file, stan_data)

    # The Stan program samples `logr` and derives r = exp(logr), so the starting
    # growth rate of 0.1 has to be supplied on the log scale; an entry named `r`
    # does not correspond to any sampled parameter.
    stan_inits = {
        'logr': float(np.log(0.1)),
        'phi': 1.0
    }
    stan_init_file = os.path.join(standirname, 'Inits.json')
    cmdstan.write_stan_json(stan_init_file, stan_inits)

    stan_code_file = os.path.join(standirname, 'fit_exp_growth.stan')
    with open(stan_code_file, "w") as f:
        f.write(stan_code_exp_growth_global)

    model = cmdstan.CmdStanModel(stan_file=stan_code_file, cpp_options={'STAN_THREADS': 'TRUE'}, compile='force')
    # 4000 chains that each keep a single post-warmup draw, at most 10 running at once.
    fit = model.sample(data=stan_data_file, seed=1, iter_warmup=num_warmup, iter_sampling=1, inits=stan_init_file, parallel_chains=10,
                       show_console=False, show_progress=False, chains=4000)
    fit.save_csvfiles(dir=standirname)

    idata = az.from_cmdstanpy(posterior=fit)
    df_stats = get_stats(idata.posterior, ['r', 'i0', 'phi', 'R0', 'doubling_time'])
    df_stats['country'] = df_sel_.country
    df_stats['who_region'] = df_sel_.who_region
    df_stats['expwindow'] = Tupper_

    return df_stats

In [None]:
%%time 
# Fit every selected country for each window length and stack the summaries.
df_stats_countries = None
for Tupper_ in [30, 60]:
    for idx in range(len(df_cumcases_selection)):
        df_stats_ = sim_country(idx, Tupper_)
        if df_stats_countries is None:
            df_stats_countries = df_stats_
        else:
            df_stats_countries = pd.concat([df_stats_countries, df_stats_], ignore_index=True)
df_stats_countries

In [None]:
# SD proxy: half the width of the 2.5%-97.5% interval divided by 1.96.
df_stats_countries['SD'] = (df_stats_countries['q97.5'] - df_stats_countries['q2.5']) / 1.96 / 2
df_stats_countries

In [None]:
stan_code_meta_summary = """data {
  int<lower=1> J;  // number of studies with available data
  vector[J] mu_known;  // means of known studies
  vector<lower=0>[J] stderr_known;  // standard errors of known studies
}

parameters {
  real<lower=0> mu, tau_squared;
}

model {
  // Priors
  mu ~ normal(4, 8);  // weakly informative prior for the overall mean
  tau_squared ~ cauchy(0, 5);  // weakly informative prior for the between-study variability

  // Likelihood
  mu_known ~ normal(mu, sqrt(square(stderr_known) + tau_squared));
}

generated quantities {
  real mu_pred;

    {
        // Predict a future observation for a hypothetical new study
        mu_pred = normal_rng(mu, sqrt(tau_squared));
    }
}"""

Df_meta_stats = None
for Tupper_ in [30, 60]:
    basename = f'R0_meta_Tupper-{Tupper_}_Miura2023'
    standirname = os.path.join(mainstandirname, basename)
    if recalc_everything:
        !rm {standirname}/*
        os.makedirs(standirname, exist_ok=True)
        stanscriptdir = '../Dropbox/'+standirname[9:]
            
        stan_code_file = os.path.join(standirname, 'fit_R0_meta.stan')
        with open(stan_code_file, "w+") as f:
            f.write(stan_code_meta_summary)
            f.close()
    
        df_ = df_stats_countries.copy().loc[lambda d: (d['var']=='R0')&(d['expwindow']==Tupper_)][::-1]
        df_ = df_.loc[lambda d: ~((d.who_region=='AMRO')&(d.country!='united states of america')&(d.country!='canada'))]
        
        stan_data = dict({
            'J': df_.shape[0],
            'mu_known': df_['mean'].values,
            'stderr_known': df_['SD'].values
        })
        stan_data_file = os.path.join(standirname, 'Data.json')
        cmdstan.write_stan_json(stan_data_file, stan_data)
        
        def bash_file(stanscriptdir):
            return f"""#!/bin/bash
cwd=$(pwd)
cd {standistribdir_}
make -j4 {stanscriptdir}/fit_R0_meta
cd {stanscriptdir}
mkdir -p diagnostics
for i in {{1..4}}
do
    echo Running ${{i}}
    SEEDNUMBER=$((1+$i))
    ./fit_R0_meta \\
        method=sample num_samples={num_iterations} num_warmup={num_warmup} thin=1 save_warmup=0 adapt delta=0.98 \\
            algorithm=hmc \\
                engine=nuts \\
        random seed=${{SEEDNUMBER}} \\
        id=$i \\
        data file=Data.json \\
        output file=trace-$i.csv \\
            diagnostic_file=diagnostics/diagnostics-$i.csv > diagnostics/output-$i.txt &
done
echo Finished haha!
"""            
        model = cmdstan.CmdStanModel(stan_file=stan_code_file)
        fit = model.sample(data=stan_data_file, seed = 1, iter_warmup=num_warmup, iter_sampling=1000,
                           show_console=False, show_progress=True, chains = 4)
        fit.save_csvfiles(dir=standirname)

        idata_meta = az.from_cmdstanpy(posterior=fit)
        df_meta_stats = get_stats(idata_meta.posterior, ['mu', 'tau_squared'])
        df_meta_stats['expwindow'] = Tupper_
        Df_meta_stats = df_meta_stats if Df_meta_stats is None else pd.concat([Df_meta_stats, df_meta_stats], ignore_index = True)
Df_meta_stats

In [None]:
# Display the per-country summaries joined with each country's case total and peak information.
df_stats_countries.merge(df_cumcases_selection.loc[:, ['country', 'cases', 'peak_cases', 'peak_date']])

In [None]:
from matplotlib.patches import Polygon

# Set the label weight before any figure is built so that every figure produced
# by this cell gets the same x-label weight (setting it after the label exists
# only affects figures created later).
plt.rcParams["axes.labelweight"] = "bold"

# One forest plot per window length: per-country R0 intervals plus the pooled estimate.
for Tupper_ in [30, 60]:
    fig = plt.figure(figsize = [5.5, 4.5])
    ax1 = fig.add_subplot()

    df_ = df_stats_countries.copy().merge(df_cumcases_selection.loc[:, ['country', 'cases', 'peak_cases', 'peak_date']])\
        .loc[lambda d: (d['var']=='R0')&(d['expwindow']==Tupper_)][::-1]

    # 'lightgrey' marks AMRO countries other than the United States and Canada.
    df_['color'] = ['lightgrey' if (who_region=='AMRO')&(country!='united states of america')&(country!='canada') else 'k' for country, who_region in zip(df_.country,df_.who_region)]
    df_ = df_.sort_values('color')

    df_['Source'] = [x.title().replace(' Of ', ' of ') for x in df_['country']]
    # Evenly spaced vertical positions in (0, 1].
    df_['Source_y'] = [(i+1)/df_.shape[0] for i in range(df_.shape[0])]

    # NOTE(review): markers and error bars use the median, while the text labels and the
    # column header use the mean -- confirm which is intended.
    for clr_ in df_['color'].drop_duplicates().values:
        df__ = df_.loc[lambda d: d['color']==clr_]
        eb1 = ax1.errorbar(x=list(df__['median']), y=df__['Source_y'],
                     xerr=[list(df__['median'] - df__['q2.5']), list(df__['q97.5'] - df__['median'])], color='k', capsize=2,
                     linestyle='None', linewidth=1, markersize=0)
        if clr_=='lightgrey':
            # Dashed interval lines and hollow markers for the 'lightgrey' group.
            eb1[-1][0].set_linestyle('--')
            clr0_ = 'w'
        else:
            clr0_ = 'k'
        ax1.plot(list(df__['median']), df__['Source_y'], color='k', mfc=clr0_,
                 linestyle='None', linewidth=1, marker="o", ms=5)

    # Pooled mean
    meta_ = Df_meta_stats.loc[lambda d: d['expwindow']==Tupper_]
    mu_row_ = meta_.loc[lambda d: d['var']=='mu']
    mean_meta = mu_row_['mean'].values[0]
    lower_meta = mu_row_['q2.5'].values[0]
    upper_meta = mu_row_['q97.5'].values[0]
    τsqr_meta = meta_.loc[lambda d: d['var']=='tau_squared', 'mean'].values[0]
    y_ = -0.15/df_.shape[0]
    h_ = 0.3/df_.shape[0]
    # Diamond spanning the pooled interval, centred on the pooled mean.
    pts = np.array([[lower_meta, y_], [mean_meta,y_+h_], [upper_meta, y_], [mean_meta, y_-h_], [lower_meta, y_]])
    p = Polygon(pts, closed=False, color='k')
    ax1.add_patch(p)

    ax1.spines['top'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    ax1.set_xlim(left = 0.8, right=3.0)
    ax1.set_xticks(np.arange(1, 3.5, .5))
    ax1.get_xaxis().tick_bottom()
    ax1.get_yaxis().tick_left()

    ax1.set_ylim(-1/df_.shape[0], np.max(df_['Source_y'])+.02)
    ax1.set_xlabel("Effective reproduction number")
    ax1.set_yticks(np.r_[df_['Source_y'], y_])
    ax1.set_yticklabels(np.r_[df_['Source'], [f'Pooled mean, τ² = {τsqr_meta:.2f}']] )

    # Second y-axis: numeric estimate next to each row.
    ax2 = ax1.secondary_yaxis("right")
    ax2.set_yticks(ax1.get_yticks())
    ax2.spines['right'].set_visible(False)
    ax2.tick_params(length=0)
    df_['label'] = [f"{mu:.2f} ({mu_lower:.2f}–{mu_upper:.2f})" for mu, mu_lower, mu_upper in zip(df_['mean'], df_['q2.5'], df_['q97.5'])]
    ax2.set_yticklabels(np.r_[df_['label'], [f'{mean_meta:.2f} ({lower_meta:.2f}–{upper_meta:.2f})']], ha='center', fontsize=9)
    ax2.yaxis.set_tick_params(pad=45)

    # Third y-axis: the fitting period (peak date minus the window, to the peak date).
    ax3 = ax1.secondary_yaxis("right")
    ax3.set_yticks(ax1.get_yticks())
    ax3.spines['right'].set_visible(False)
    ax3.tick_params(length=0)
    df_['label2'] = [f"{(peak_date-pd.DateOffset(days=Tupper_)).strftime('%m/%d')}-{peak_date.strftime('%m/%d')}" for peak_date in df_['peak_date']]
    ax3.set_yticklabels(np.r_[df_['label2'], ['']], ha='center', fontsize=9)
    ax3.yaxis.set_tick_params(pad=125)

    # Emphasise the pooled row and colour the China row, on the left labels and on
    # the matching estimate labels of the second axis.
    labels1_ = ax1.get_yticklabels()
    labels2_ = ax2.get_yticklabels()
    for label in labels1_:
        if 'Pool' in label.get_text():
            label.set_fontweight('bold')
            label.set_color('k')
        if 'China' in label.get_text():
            label.set_color(clr_chn_)
    for label, label2 in zip(labels1_, labels2_):
        if 'Pool' in label.get_text():
            label2.set_fontweight('bold')
            label2.set_color('k')
        if 'China' in label.get_text():
            label2.set_color(clr_chn_)

    plt.text(.44, 1.07, 'Country', horizontalalignment='left', fontsize=10)
    plt.text(3.59, 1.07, 'Mean (95% CrI)', horizontalalignment='right', fontsize=10)
    plt.text(3.89, 1.05, 'Period\n(m/d in 2022)', horizontalalignment='center', fontsize=10)

    if save_figures:
        plt.savefig(f"../../figures/fig-R0_meta_Tupper-{Tupper_}_Miura2023.pdf", format="pdf", bbox_inches="tight")

    plt.show()

## <font color="orange">3b. Selected national outbreaks of 2022: only symptomatic cases</font>

In [None]:
# Stan program: exponential growth fitted directly to onset counts (no back-projection).
#   - The model block fits negative-binomial counts with mean i0 * exp(r * t) over days
#     1..Tupper and skips days whose count is zero.
#   - The three functions in the `functions` block are defined but not called in this program.
# NOTE(review): the in-program comment credits the generation time to Guzetta et al. 2022;
# the values actually used are whatever the caller passes as genalpha / geninvbeta.
stan_code_exp_growth_global_only_symptomatic = """functions {
    real gengamma_cdf(real x, real q, real mu, real sigma) {
        real logx = log(x),
            z = (logx - mu) / sigma,
            a = inv_square(q),
            value = gamma_cdf(a * exp(q * z) | a, 1);

        return value;
    }

    /* discretized version */
    vector dgengamma(real q, real mu, real sigma, int D) {
        vector[D] res;
        for (k in 1:D)
            res[k] = gengamma_cdf(k - 0.5 | q, mu, sigma);

        if (D > 1)
            return append_row(append_row(res[1], tail(res, D-1) - head(res, D-1)), 1.0 - res[D]);
        else 
            return to_vector({res[1], 1 - res[1]});
    }

    vector dgamma(real param1, real param2, int K) {
        vector[K] res;
        for (k in 1:K)
            res[k] = gamma_cdf(k - 0.5 | param1, param2);

        return append_row(res[1], tail(res, K-1) - head(res, K-1));
    }
}

data {
    int<lower = 1> Tupper; // the cutoffday for the estimation of the exponential growth 
    array[Tupper] int<lower = 0> cases_onset; // number of cases by date of symptom onset starting to be recorded till the day Tupper
    
    // generation time (scale and shape of the Gamma distribution estimated in Guzetta et al. 2022)
    real<lower = 0> genalpha, geninvbeta;
}

transformed data {
    // generation time
    real genmean = genalpha * geninvbeta,
        gensigma = sqrt(genalpha) * geninvbeta;
        
    real jitter = 1e-9;
}

parameters {
    // exponential growth rate
    real logr;
    // initial incidence
    real<lower = 0> i0;
    // process error
    real<lower = 0> phi;
}

transformed parameters {
    real r = exp(logr);
}

model {
    logr ~ std_normal();
    i0 ~ normal(5, 10);
    phi ~ gamma(1, 1);

    for (t in 1:Tupper)
        if (cases_onset[t] > 0)
            target += neg_binomial_2_lupmf(cases_onset[t] | i0 * exp(r * t) + jitter, phi); 
}

generated quantities {
    // basic reproduction number
    real R0_norm_approx = exp(r * genmean - 0.5 * square(r) * square(gensigma)),
        R0 = pow(1 + r * geninvbeta, genalpha);

    // doubling time
    real doubling_time = log(2) / r; 
}"""

In [None]:
# List the files in ../../data/WHO whose names start with "epicurve_"
# (IPython shell capture: one entry per line of `ls` output).
fls = !ls ../../data/WHO | grep ^epicurve_
fls

In [None]:
%%time
# Calendar window applied to every national epicurve.
starting_date_global = pd.to_datetime('2022-01-01', format='%Y-%m-%d')
ending_date_global = pd.to_datetime('2022-12-31', format='%Y-%m-%d')

# One summary row per epicurve file that contains onset-dated records.
df_cumcases = None
for fl_ in fls:
    df_ = pd.read_csv(os.path.join('../../data/WHO', fl_))
    df_['reference_date'] = pd.to_datetime(df_['reference_date'], format='%Y-%m-%d')
    df_ = df_[df_['reference_date'] <= ending_date_global]
    # Country label: second "_"-separated piece of the file name minus its last four characters.
    country_ = fl_.split('_')[1][:-4]
    who_region_ = df_['who_region'].values[0]
    print(country_)

    onset_rows_ = df_[df_['date_type'] == 'Onset']
    if len(onset_rows_) == 0:
        continue

    # Peak = earliest date on which the onset count equals its maximum.
    at_max_ = onset_rows_[onset_rows_['cases'] == onset_rows_['cases'].max()]
    df_mxs_ = at_max_[at_max_['reference_date'] == at_max_['reference_date'].min()]

    # 'cases' sums every row of the file within the window, whatever its date type.
    df_cumcases_ = pd.DataFrame({'country': [country_], 'who_region': who_region_, 'cases': [df_['cases'].sum()],
                                 'peak_cases': df_mxs_.cases, 'peak_date': df_mxs_.reference_date, 'peak_date_type': df_mxs_.date_type,
                                 'file': fl_})
    if df_cumcases is None:
        df_cumcases = df_cumcases_
    else:
        df_cumcases = pd.concat([df_cumcases, df_cumcases_], ignore_index=True)
df_cumcases

In [None]:
# Keep only countries whose total case count exceeds 700.
# if no selection by symptomatic, then +Ireland 227 cases
df_cumcases_selection = df_cumcases[df_cumcases['cases'] > 700]
print("number of countries: ", df_cumcases_selection.shape[0])
df_cumcases_selection

In [None]:
def sim_country_only_sympt(idx, Tupper_):
    """Fit the onset-only exponential-growth model to one selected country.

    Reads the country's WHO epicurve file, builds a daily table covering
    ``starting_date_global``..``ending_date_global``, takes the ``Tupper_`` days
    of onset counts that end on the country's peak onset date, writes the Stan
    program/data/inits into a run directory under ``mainstandirname``, samples
    with CmdStanPy and saves the chain CSV files there.

    Parameters
    ----------
    idx : int
        Positional row index into ``df_cumcases_selection``.
    Tupper_ : int
        Length in days of the fitting window.

    Returns
    -------
    pandas.DataFrame
        Output of ``get_stats`` for ``r``, ``i0``, ``phi``, ``R0`` and
        ``doubling_time``, with ``country``, ``who_region`` and ``expwindow``
        columns added.
    """
    df_sel_ = df_cumcases_selection.iloc[idx]
    print(df_sel_)

    basename = truncation_date_hk.strftime("%Y%m%d")+f'_Tupper-{Tupper_}_exp_growth_only_sympt_Miura2023_' + df_sel_.country
    standirname = os.path.join(mainstandirname, basename)

    # Clear regular files left by a previous run. Done in Python rather than with an
    # unquoted shell `rm`, because the directory name embeds the country name and a
    # name containing spaces would be split into several shell arguments.
    # Hidden files and sub-directories are kept, as `rm <dir>/*` would keep them.
    os.makedirs(standirname, exist_ok=True)
    for name_ in os.listdir(standirname):
        path_ = os.path.join(standirname, name_)
        if name_.startswith('.') or (os.path.isdir(path_) and not os.path.islink(path_)):
            continue
        os.remove(path_)

    # Miura et al. 2023: mean and SD of the generation time, converted to the
    # shape (genalpha) and scale (geninvbeta) of a Gamma distribution.
    mean_gt_ = 10.1
    sd_gt_ = 6.1
    genalpha_ = (mean_gt_ / sd_gt_)**2
    geninvbeta_ = (sd_gt_**2) / mean_gt_

    fl_ = df_sel_['file']
    df_ = pd.read_csv(os.path.join('../../data/WHO', fl_))
    df_['reference_date'] = pd.to_datetime(df_.reference_date, format='%Y-%m-%d')
    df_ = df_.loc[lambda d: d.reference_date<=ending_date_global]

    # One row per calendar day, one column per date type; days without records become 0.
    df_country_ = df_.groupby(['date_type', 'reference_date'])['cases'].sum().reset_index().pivot_table(values='cases', columns='date_type', index='reference_date')
    df_country_ = pd.DataFrame(df_country_.to_records()).merge(pd.DataFrame({'reference_date': pd.date_range(starting_date_global, ending_date_global)}), how='outer').fillna(0)
    df_country_ = df_country_.sort_values('reference_date').set_index('reference_date').astype('int64')
    df_country_['reference_day'] = (df_country_.index - starting_date_global).days
    df_ = df_country_

    # Onset counts for the days strictly after (peak - Tupper_ days), capped at Tupper_ values.
    # Report- and diagnosis-dated counts are not part of this model's data block.
    cases_onset_ = df_.loc[lambda d: (d.index>df_sel_.peak_date-pd.DateOffset(days=Tupper_))].Onset.values[:Tupper_]

    stan_data = {
        'Tupper': Tupper_,
        'cases_onset': cases_onset_,
        'genalpha': genalpha_,
        'geninvbeta': geninvbeta_
    }
    stan_data_file = os.path.join(standirname, 'Data.json')
    cmdstan.write_stan_json(stan_data_file, stan_data)

    # The Stan program samples `logr` and derives r = exp(logr), so the starting
    # growth rate of 0.1 has to be supplied on the log scale; an entry named `r`
    # does not correspond to any sampled parameter.
    stan_inits = {
        'logr': float(np.log(0.1)),
        'phi': 1.0
    }
    stan_init_file = os.path.join(standirname, 'Inits.json')
    cmdstan.write_stan_json(stan_init_file, stan_inits)

    stan_code_file = os.path.join(standirname, 'fit_exp_growth.stan')
    with open(stan_code_file, "w") as f:
        f.write(stan_code_exp_growth_global_only_symptomatic)

    model = cmdstan.CmdStanModel(stan_file=stan_code_file, cpp_options={'STAN_THREADS': 'TRUE'}, compile='force')
    # 4000 chains that each keep a single post-warmup draw, at most 10 running at once.
    # NOTE(review): this program has no random data step, so a few long chains may give
    # equivalent draws at far lower warm-up cost -- confirm before changing.
    fit = model.sample(data=stan_data_file, seed=1, iter_warmup=num_warmup, iter_sampling=1, inits=stan_init_file, parallel_chains=10,
                       show_console=False, show_progress=False, chains=4000)
    fit.save_csvfiles(dir=standirname)

    idata = az.from_cmdstanpy(posterior=fit)
    df_stats = get_stats(idata.posterior, ['r', 'i0', 'phi', 'R0', 'doubling_time'])
    df_stats['country'] = df_sel_.country
    df_stats['who_region'] = df_sel_.who_region
    df_stats['expwindow'] = Tupper_

    return df_stats

In [None]:
%%time 
# Fit every selected country for each window length and stack the summaries.
df_stats_countries = None
for Tupper_ in [45, 30, 60]:
    for idx in range(len(df_cumcases_selection)):
        df_stats_ = sim_country_only_sympt(idx, Tupper_)
        if df_stats_countries is None:
            df_stats_countries = df_stats_
        else:
            df_stats_countries = pd.concat([df_stats_countries, df_stats_], ignore_index=True)
df_stats_countries

In [None]:
# Display the stacked per-country summaries.
df_stats_countries

In [None]:
# SD proxy: half the width of the 2.5%-97.5% interval divided by 1.96.
df_stats_countries['SD'] = (df_stats_countries['q97.5'] - df_stats_countries['q2.5']) / 1.96 / 2
df_stats_countries

In [None]:
stan_code_meta_summary = """data {
  int<lower=1> J;  // number of studies with available data
  vector[J] mu_known;  // means of known studies
  vector<lower=0>[J] stderr_known;  // standard errors of known studies
}

parameters {
  real<lower=0> mu, tau_squared;
}

model {
  // Priors
  mu ~ normal(4, 8);  // weakly informative prior for the overall mean
  tau_squared ~ cauchy(0, 5);  // weakly informative prior for the between-study variability

  // Likelihood
  mu_known ~ normal(mu, sqrt(square(stderr_known) + tau_squared));
}

generated quantities {
  real mu_pred;

    {
        // Predict a future observation for a hypothetical new study
        mu_pred = normal_rng(mu, sqrt(tau_squared));
    }
}"""

Df_meta_stats = None
for Tupper_ in [45, 30, 60]:
    basename = f'R0_meta_only_sympt_Miura2023_Tupper-{Tupper_}'
    standirname = os.path.join(mainstandirname, basename)
    if recalc_everything:
        !rm {standirname}/*
        os.makedirs(standirname, exist_ok=True)
        stanscriptdir = '../Dropbox/'+standirname[9:]
            
        stan_code_file = os.path.join(standirname, 'fit_R0_meta.stan')
        with open(stan_code_file, "w+") as f:
            f.write(stan_code_meta_summary)
            f.close()
    
        df_ = df_stats_countries.copy().loc[lambda d: (d['var']=='R0')&(d['expwindow']==Tupper_)][::-1]
        df_ = df_.loc[lambda d: ~((d.who_region=='AMRO')&(d.country!='united states of america')&(d.country!='canada'))]
        
        stan_data = dict({
            'J': df_.shape[0],
            'mu_known': df_['mean'].values,
            'stderr_known': df_['SD'].values
        })
        stan_data_file = os.path.join(standirname, 'Data.json')
        cmdstan.write_stan_json(stan_data_file, stan_data)
        
        def bash_file(stanscriptdir):
            return f"""#!/bin/bash
cwd=$(pwd)
cd {standistribdir_}
make -j4 {stanscriptdir}/fit_R0_meta
cd {stanscriptdir}
mkdir -p diagnostics
for i in {{1..4}}
do
    echo Running ${{i}}
    SEEDNUMBER=$((1+$i))
    ./fit_R0_meta \\
        method=sample num_samples={num_iterations} num_warmup={num_warmup} thin=1 save_warmup=0 adapt delta=0.98 \\
            algorithm=hmc \\
                engine=nuts \\
        random seed=${{SEEDNUMBER}} \\
        id=$i \\
        data file=Data.json \\
        output file=trace-$i.csv \\
            diagnostic_file=diagnostics/diagnostics-$i.csv > diagnostics/output-$i.txt &
done
echo Finished haha!
"""            
        model = cmdstan.CmdStanModel(stan_file=stan_code_file)
        fit = model.sample(data=stan_data_file, seed = 1, iter_warmup=num_warmup, iter_sampling=1000,
                           show_console=False, show_progress=True, chains = 4)
        fit.save_csvfiles(dir=standirname)

        idata_meta = az.from_cmdstanpy(posterior=fit)
        df_meta_stats = get_stats(idata_meta.posterior, ['mu', 'tau_squared'])
        df_meta_stats['expwindow'] = Tupper_
        Df_meta_stats = df_meta_stats if Df_meta_stats is None else pd.concat([Df_meta_stats, df_meta_stats], ignore_index = True)
Df_meta_stats

In [None]:
# Forest-style figure of the per-country R0 estimates together with the pooled
# (meta-analysis) estimate; one figure is drawn (and optionally saved as PDF)
# for each window length `Tupper_`.
from matplotlib.patches import Polygon

# The axis-label weight is picked up when an Axes is created, so it has to be set
# before the first figure is built. Setting it at the end of the loop body (as
# before) left the first figure of a fresh session with a non-bold x-label.
plt.rcParams["axes.labelweight"] = "bold"

for Tupper_ in [45, 30, 60]:
    fig = plt.figure(figsize=[5.5, 4.5])
    ax1 = fig.add_subplot()

    # R0 rows of the current window, joined with the per-country case summary and
    # reversed so that the first country ends up at the top of the plot.
    df_ = df_stats_countries.copy().merge(df_cumcases_selection.loc[:, ['country', 'cases', 'peak_cases', 'peak_date']])\
        .loc[lambda d: (d['var']=='R0')&(d['expwindow']==Tupper_)][::-1]

    # AMRO countries other than the USA and Canada get the 'lightgrey' tag, which
    # below selects a dashed interval and a hollow marker.
    df_['color'] = ['lightgrey' if (who_region == 'AMRO') and (country not in ('united states of america', 'canada')) else 'k'
                    for country, who_region in zip(df_.country, df_.who_region)]
    # Stable sort: keep the reversed row order inside each colour group regardless
    # of how many rows there are (the default sort kind gives no such guarantee).
    df_ = df_.sort_values('color', kind='stable')

    df_['Source'] = [x.title().replace(' Of ', ' of ') for x in df_['country']]
    # Evenly spaced vertical positions in (0, 1], bottom row first.
    n_rows_ = df_.shape[0]
    df_['Source_y'] = [(i + 1) / n_rows_ for i in range(n_rows_)]

    # NOTE(review): markers and intervals are centred on the median, whereas the
    # right-hand text column below reports the mean -- confirm this is intended.
    for clr_ in df_['color'].drop_duplicates().values:
        df__ = df_.loc[lambda d: d['color']==clr_]
        eb1 = ax1.errorbar(x=list(df__['median']), y=df__['Source_y'],
                     xerr=[list(df__['median'] - df__['q2.5']), list(df__['q97.5'] - df__['median'])], color='k', capsize=2,
                     linestyle='None', linewidth=1, markersize=0)
        if clr_=='lightgrey':
            # Last element of the errorbar container holds the bar line collections.
            eb1[-1][0].set_linestyle('--')
            clr0_ = 'w'
        else:
            clr0_ = 'k'
        ax1.plot(list(df__['median']), df__['Source_y'], color='k', mfc=clr0_,
                 linestyle='None', linewidth=1, marker="o", ms=5)

    # Pooled mean: filter the meta-analysis summary once and read the values from it.
    meta_ = Df_meta_stats.loc[lambda d: d['expwindow']==Tupper_]
    mu_ = meta_.loc[lambda d: d['var']=='mu']
    mean_meta = mu_['mean'].values[0]
    lower_meta = mu_['q2.5'].values[0]
    upper_meta = mu_['q97.5'].values[0]
    τsqr_meta = meta_.loc[lambda d: d['var']=='tau_squared', 'mean'].values[0]
    # Diamond below the country rows: horizontal extent = q2.5..q97.5, apex at the mean.
    y_ = -0.15/df_.shape[0]
    h_ = 0.3/df_.shape[0]
    pts = np.array([[lower_meta, y_], [mean_meta,y_+h_], [upper_meta, y_], [mean_meta, y_-h_], [lower_meta, y_]])
    p = Polygon(pts, closed=False, color='k')
    ax1.add_patch(p)

    ax1.spines['top'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    ax1.set_xlim(left = 0.8, right=3.0)
    ax1.set_xticks(np.arange(1, 3.5, .5))
    ax1.get_xaxis().tick_bottom()
    ax1.get_yaxis().tick_left()

    ax1.set_ylim(-1/df_.shape[0], np.max(df_['Source_y'])+.02)
    ax1.set_xlabel("Effective reproduction number")
    ax1.set_yticks(np.r_[df_['Source_y'], y_])
    ax1.set_yticklabels(np.r_[df_['Source'], [f'Pooled mean, τ² = {τsqr_meta:.2f}']] )

    # First right-hand text column: "mean (q2.5–q97.5)" per row, rendered as the
    # tick labels of an invisible secondary y-axis.
    ax2 = ax1.secondary_yaxis("right")
    ax2.set_yticks(ax1.get_yticks())
    ax2.spines['right'].set_visible(False)
    ax2.tick_params(length=0)
    df_['label'] = [f"{mu:.2f} ({mu_lower:.2f}–{mu_upper:.2f})" for mu, mu_lower, mu_upper in zip(df_['mean'], df_['q2.5'], df_['q97.5'])]
    ax2.set_yticklabels(np.r_[df_['label'], [f'{mean_meta:.2f} ({lower_meta:.2f}–{upper_meta:.2f})']], ha='center', fontsize=9)
    ax2.yaxis.set_tick_params(pad=45)

    # Second right-hand text column: the window (peak date minus `Tupper_` days up
    # to the peak date) as month/day; the pooled row is left blank.
    ax3 = ax1.secondary_yaxis("right")
    ax3.set_yticks(ax1.get_yticks())
    ax3.spines['right'].set_visible(False)
    ax3.tick_params(length=0)
    df_['label2'] = [f"{(peak_date-pd.DateOffset(days=Tupper_)).strftime('%m/%d')}-{peak_date.strftime('%m/%d')}" for peak_date in df_['peak_date']]
    ax3.set_yticklabels(np.r_[df_['label2'], ['']], ha='center', fontsize=9)
    ax3.yaxis.set_tick_params(pad=125)

    # Emphasise the pooled row (bold, black) and colour the China row, on both the
    # country-name labels (ax1) and the matching estimate labels (ax2).
    labels_left_ = ax1.get_yticklabels()
    labels_right_ = ax2.get_yticklabels()
    for label in labels_left_:
        text_ = label.get_text()
        if 'Pool' in text_:
            label.set_fontweight('bold')
            label.set_color('k')
        if 'China' in text_:
            label.set_color(clr_chn_)
    for label, label2 in zip(labels_left_, labels_right_):
        text_ = label.get_text()
        if 'Pool' in text_:
            label2.set_fontweight('bold')
            label2.set_color('k')
        if 'China' in text_:
            label2.set_color(clr_chn_)

    # Column headers, positioned by hand with plt.text.
    plt.text(.44, 1.07, 'Country', horizontalalignment='left', fontsize=10)
    plt.text(3.59, 1.07, 'Mean (95% CrI)', horizontalalignment='right', fontsize=10)
    plt.text(3.89, 1.05, 'Period\n(m/d in 2022)', horizontalalignment='center', fontsize=10)

    if save_figures:
        plt.savefig(f"../../figures/fig-R0_meta_only_sympt_Tupper-{Tupper_}_Miura2023.pdf", format="pdf", bbox_inches="tight")

    plt.show()