In [None]:
import arviz
import numpy as np
import os

In [None]:
def summarize_and_save(country, date, file, func_dict, out_dir=None):
    """
    Calculates and saves summary for trace data.

    Only files whose name ends in "trace.nc" are processed; for any other
    file a hint is printed and nothing is written.

    :param country: the country code / folder from which to load
    :param date: name of subfolder containing trace data for date of interest
    :param file: name of the file to process
    :param func_dict: dictionary defining which statistics should be contained in the summary
        (handed to ``arviz.summary`` as ``stat_funcs``)
    :param out_dir: directory the resulting .csv is written to; when omitted, the
        notebook-level ``save_path`` is used
    :return: None
    """
    # guard clause: nothing to do for files that are not trace data
    if not file.endswith("trace.nc"):
        print("Please pass trace data to return summary.")
        return

    if out_dir is None:
        # fall back to the notebook-level configuration
        out_dir = save_path

    # path where trace data is located
    trace_path = os.path.join(country, date, file)

    # define name of resulting .csv file; only the trailing ".nc" is stripped
    # (str.replace would also drop a ".nc" occurring elsewhere in the name)
    stem = file[: -len(".nc")]
    save_file = "_".join([country, date, stem]) + "_summary.csv"

    # read data
    idata = arviz.from_netcdf(trace_path)

    # calculate summary
    summary = arviz.summary(
        idata,
        var_names=["r_t"],
        stat_funcs=func_dict,
        extend=False
    )

    # Label the rows with the posterior's dates. The posterior is read directly
    # instead of being stacked over (chain, draw) first.
    # NOTE(review): assumes `date` is a coordinate without chain/draw dimensions,
    # in which case stacking would not change it - confirm against the trace files.
    summary = summary.set_index(np.array(idata.posterior["date"]))

    # save as .csv; os.path.join is correct with or without a trailing separator
    summary.to_csv(os.path.join(out_dir, save_file))

In [None]:
# folder from which to process files:
country = "DE"

# Substring a file name must contain to be processed:
# "" (empty string) processes all files, "all" processes the national level only.
tag = ""

# define path where resulting summaries are saved
save_path = "rtlive_summaries/"

# Make sure the folder exists. makedirs also creates missing parent folders
# and, with exist_ok=True, does nothing when the folder is already there.
os.makedirs(save_path, exist_ok=True)

In [None]:
def _percentile_of(q):
    """Return a one-argument function computing the q-th percentile of its input."""
    return lambda x: np.percentile(x, q)


# Statistics reported in each summary: column label -> function of the samples.
func_dict = {
    "mean": np.mean,
    "std": np.std,
    "2.5%": _percentile_of(2.5),
    "25%": _percentile_of(25),
    "median": _percentile_of(50),
    "75%": _percentile_of(75),
    "97.5%": _percentile_of(97.5),
}

In [None]:
# Process every date subfolder of the country folder. Entries are sorted so
# the processing order (and the progress output) is the same on every run.
for date in sorted(os.listdir(country)):
    date_dir = os.path.join(country, date)
    if not os.path.isdir(date_dir):
        # stray files in the country folder cannot be listed as a directory
        continue
    print("Starting " + date + "...")
    for file in sorted(os.listdir(date_dir)):
        if tag in file:  # process only files containing a specific tag
            summarize_and_save(country, date, file, func_dict)