In [None]:
import pandas as pd
import numpy as np
import h5py
import os
import datetime
import copy
import matplotlib.pyplot as plt
from matplotlib import ticker
# White figure and axes backgrounds for all plots
plt.rcParams.update({'figure.facecolor': 'w', 'axes.facecolor': 'w'})

### Model parameters

In [None]:
# Compartment index -> name; the index is the column of the compartment in
# the result arrays that are sliced as [:, comp_idx] when plotting.
secir_dict = dict(enumerate(['Susceptible', 'Exposed', 'Carrier', 'Infected', 'Hospitalized',
                             'ICU', 'Recovered', 'Death']))

# Age group labels, in the order of the result keys 'Group1', ..., 'Group6'
age_groups = ['0-4 Years', '5-14 Years', '15-34 Years', '35-59 Years', '60-79 Years', '80+ Years']

# Population per age group; dividing by `base` yields the scaling factors
# used to turn absolute numbers into values per 100,000 individuals.
base = 100000
age_group_sizes = np.array([3961376, 7429883, 19117865, 28919134, 18057318, 5681135])

relative_dict = {'Group' + str(idx + 1): size / base
                 for idx, size in enumerate(age_group_sizes)}
relative_dict['Total'] = age_group_sizes.sum() / base

### Simulation parameters 

In [None]:
# Start day of the simulation. The strings are also part of the result folder
# names (see date_str below), so they are kept as unpadded strings.
year, month, day = '2020', '6', '1'
# Build the timestamp explicitly instead of letting the parser guess the
# format of a 'YYYY.MM.DD' string.
start_date = pd.Timestamp(year=int(year), month=int(month), day=int(day))
tmax = '90'            # simulated days (string: it is part of the folder name)
daysPlot = int(tmax)   # number of days labelled on the x-axis; override to plot fewer

# Define scenario path and different folders that will be read and plotted
date_str = '_' + '_'.join([year, month, day, tmax])   # e.g. '_2020_6_1_90'
path_sim = 'data/'
path_rki = 'data/extrapolated_rki_results'
scenario_list = ['']

### Loading data

In [None]:
# Sanity check of the inputs before plotting. `open_files` and `plotRKI` are
# only defined in later cells, so calling them here fails on a fresh kernel
# (Restart & Run All). Instead, list the result files that are read with the
# default paths and whether each of them exists.
expected_files = [path_sim + 'results' + date_str + scenario + '/' + p + '/Results_sum.h5'
                  for scenario in scenario_list
                  for p in ['p50', 'p25', 'p75', 'p05', 'p95']]
expected_files.append(path_rki + date_str + '/Results_rki_sum.h5')  # extrapolated RKI data
{path: os.path.isfile(path) for path in expected_files}

In [None]:
def open_files(path_sim = path_sim, spec_str_sim = date_str, path_rki = path_rki, spec_str_rki1 = date_str, spec_str_rki2 = '',
               scenario_list = scenario_list, percentiles = ('p50','p25','p75','p05','p95'), read_casereports_extrapolation = False):
    '''Open the HDF5 result files of all scenarios read-only.

    Note: the defaults of the path/scenario parameters are captured from the
    global configuration when this cell runs; re-run it after changing them.

    @param path_sim Path where simulation files have been written
    @param spec_str_sim Specified string after 'results' (e.g. date) that points to a specific set of scenario folders
    @param path_rki Path where extrapolated real data have been written
    @param spec_str_rki1 Specified string appended to path_rki (e.g. date) that points to a specific RKI data folder
    @param spec_str_rki2 Specified string in the RKI results file name that points to a specific RKI data file
    @param scenario_list List of string indicators (folder suffixes) of the scenarios to be opened
    @param percentiles Percentiles to be opened (subset of ['p50','p25','p75','p05','p95'])
    @param read_casereports_extrapolation If True, also open the extrapolated reporting data (from RKI)
    @return Nested dict {scenario: {percentile or 'RKI': h5py.File}}; release it with close_files.

    Files are read from
        <path_sim>results<spec_str_sim><scenario>/<percentile>/Results_sum.h5
        <path_rki><spec_str_rki1>/Results_rki_sum<spec_str_rki2>.h5
    If any file cannot be opened, all handles opened so far are closed and the
    exception is re-raised.
    '''
    files = {}
    try:
        for scenario in scenario_list:
            files[scenario] = {}
            path = path_sim + 'results' + spec_str_sim + scenario

            for p in percentiles:
                files[scenario][p] = h5py.File(path + '/' + p + '/Results_sum.h5', 'r')

            if read_casereports_extrapolation:
                files[scenario]['RKI'] = h5py.File(path_rki + spec_str_rki1 + '/Results_rki_sum' + spec_str_rki2 + '.h5', 'r')
    except BaseException:
        # Do not leak the handles that were already opened before the failure.
        for handles in files.values():
            for handle in handles.values():
                handle.close()
        raise

    return files

def close_files(files):
    '''Close all file handles in ``files``.

    @param files Nested dict {scenario: {key: handle}} of open HDF5 files,
        as returned by open_files
    '''
    for handles in files.values():
        for handle in handles.values():
            handle.close()

### Plot parameters

In [None]:
# define colors for age groups
def get_cmap(n, name='hsv'):
    '''Return a colormap resampled to ``n`` discrete colors.

    The result maps each index in 0, 1, ..., n-1 to a distinct RGB color;
    the keyword argument ``name`` must be a standard matplotlib colormap name.

    Uses ``plt.get_cmap(name, lut)``: ``plt.cm.get_cmap`` (i.e.
    ``matplotlib.cm.get_cmap``) was deprecated in matplotlib 3.7 and removed
    in 3.9.
    '''
    return plt.get_cmap(name, n)

# Current matplotlib color cycle: the first color is used for the total
# population, the following ones for 'Group1', 'Group2', ...
plt_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

colors = {'Total': plt_colors[0],
          **{'Group' + str(g + 1): plt_colors[g + 1] for g in range(len(age_groups))}}

In [None]:
# Plot switches:
#   plotRKI         - plots (extrapolated) RKI data if true
#   plotRelative    - plots incidence values (per 100,000 individuals) if true
#   plotPercentiles - plots 25 and 75 percentiles if true
#   plotConfidence  - plots 05 and 95 percentiles if true
plotRKI, plotRelative, plotPercentiles, plotConfidence = True, False, True, True

In [None]:
savePlot = True          # saves plot files to the directory "Plots" if true
if savePlot:
    # exist_ok: an existing directory is fine. Other failures (e.g. missing
    # permissions, or a file named "Plots") are raised instead of being
    # reported as "already exists" by a bare `except:`.
    os.makedirs('Plots', exist_ok=True)

In [None]:
opacity = 0.15           # alpha of the shaded percentile bands
lineWidth = 3.5
fontsize = 18
figsize = (16, 10)

# Define x-ticks for plots: one date label ('MM-DD') per plotted day
datelist = np.array(pd.date_range(start_date.date(), periods=daysPlot, freq='D').strftime('%m-%d').tolist())

# Tick every `tick_step` days plus the last plotted day (index daysPlot - 1).
# Decrementing the last multiple of 10 is only correct when daysPlot is a
# multiple of 10. Otherwise it drops the final days, or gives tick -1 for
# daysPlot < 10.
tick_step = 10
tick_range = np.arange(0, daysPlot, tick_step)
if daysPlot > 0 and tick_range[-1] != daysPlot - 1:
    tick_range = np.append(tick_range, daysPlot - 1)

### Plot functions

In [None]:
def _select_region(h5file, regionid):
    '''Return the HDF5 node holding the 'Time' and 'Total'/'GroupN' data.

    Result files with one group per region are indexed by ``regionid``.
    Files written with the IO as of 2020/2021 store 'Group1', ... at the top
    level and are returned unchanged (backward stability).
    '''
    if 'Group1' not in h5file.keys():
        return h5file[regionid]
    return h5file


def _plot_band(ax, X, files, lower, upper, comp_idx, regionid, key, factor):
    '''Draw two percentile curves as dashed lines and shade the band between them.'''
    lower_values = _select_region(files[lower], regionid)[key][:, comp_idx] / factor
    upper_values = _select_region(files[upper], regionid)[key][:, comp_idx] / factor
    ax.plot(X, lower_values, '--', label=lower, color=colors[key], linewidth=lineWidth)
    ax.plot(X, upper_values, '--', label=upper, color=colors[key], linewidth=lineWidth)
    ax.fill_between(X, lower_values, upper_values, color=colors[key], alpha=opacity)


def plot_results(files, comp_idx, title, regionid='0', key='Total'):
    '''Plot one compartment of one scenario over time and optionally save it.

    @param files Dict of open HDF5 files of one scenario with the keys 'p50',
        'p25', 'p75', 'p05', 'p95' and optionally 'RKI' (one entry of the dict
        returned by open_files); only the keys required by the active plot
        switches must be present
    @param comp_idx Column index of the compartment (see secir_dict)
    @param title Plot title; also the file name 'Plots/<title>.png' if savePlot
    @param regionid Region group to read from result files with one group per region
    @param key 'Total' or 'Group1', ..., 'Group6'

    The curves that are drawn depend on the global switches plotRelative,
    plotPercentiles, plotConfidence and plotRKI.
    '''
    fig, ax = plt.subplots(figsize=figsize)

    # Values per 100,000 individuals if requested
    factor = relative_dict[key] if plotRelative else 1

    data_p50 = _select_region(files['p50'], regionid)
    X = data_p50['Time'][:]
    ax.plot(X, data_p50[key][:, comp_idx] / factor, label='p50',
            color=colors[key], linewidth=lineWidth)

    if plotPercentiles:
        _plot_band(ax, X, files, 'p25', 'p75', comp_idx, regionid, key, factor)
    if plotConfidence:
        _plot_band(ax, X, files, 'p05', 'p95', comp_idx, regionid, key, factor)

    if plotRKI:
        if 'RKI' in files:
            # Detect the layout from the RKI file itself, not from another
            # percentile file (p05 need not be opened at all).
            data_rki = _select_region(files['RKI'], regionid)
            ax.plot(X, data_rki[key][:, comp_idx] / factor, '--', label='RKI',
                    color='gray', linewidth=lineWidth)
        else:
            print('Error: Plotting extrapolated real data demanded but not read in.')

    ax.set_title(title, fontsize=fontsize)
    ax.set_xticks(tick_range)
    ax.set_xticklabels(datelist[tick_range], rotation=45, fontsize=fontsize)
    if plotRelative:
        ax.set_ylabel('individuals relative per 100.000', fontsize=fontsize)
    else:
        ax.set_ylabel('number of individuals', fontsize=fontsize)
    ax.legend(fontsize=fontsize)
    # Applies to all y tick labels of this axes, including ones created later
    ax.tick_params(axis='y', labelsize=fontsize)
    ax.grid(linestyle='dotted')

    # Scientific notation with a shared offset label on the y-axis
    formatter = ticker.ScalarFormatter(useMathText=True)
    formatter.set_scientific(True)
    formatter.set_powerlimits((-1, 1))
    ax.yaxis.set_major_formatter(formatter)
    ax.yaxis.offsetText.set_fontsize(fontsize)

    if savePlot:
        fig.savefig('Plots/' + title + '.png')


### Plot total population 

In [None]:
files = open_files(read_casereports_extrapolation=plotRKI)
try:
    for scenario in scenario_list:
        for compart, compart_name in secir_dict.items():
            plot_results(files[scenario], compart, compart_name, key='Total')
finally:
    # Release the HDF5 handles even if a plot fails part-way.
    close_files(files)

### Plot results per age group

In [None]:
# The handles of the previous cell are already closed (close_files), and
# reading keys from a closed h5py file fails. Reopen only the median results,
# list their top-level keys and close them again.
p50_files = open_files(percentiles=['p50'])
try:
    for scenario in scenario_list:
        print(p50_files[scenario]['p50'].keys())
finally:
    close_files(p50_files)

In [None]:
files = open_files(read_casereports_extrapolation=plotRKI)
try:
    for scenario in scenario_list:
        for compart, compart_name in secir_dict.items():
            for group, group_name in enumerate(age_groups):
                plot_results(files[scenario], compart, compart_name + ' ' + group_name,
                             key='Group' + str(group + 1))
finally:
    # Release the HDF5 handles even if a plot fails part-way.
    close_files(files)