In [22]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from matplotlib.patches import ConnectionPatch, ArrowStyle
import os 
import datetime
import cupy as cp


import matplotlib.figure as figure
from model import configuration
from model import plots_funtions
from model import parameters_control
from model import simulation_functions

In [48]:
# Target country; the loaders and the configuration reader are keyed by this name.
COUNTRY = 'Spain'

# Reference daily deaths, smoothed with the project's smoothing helper.
# The raw series keeps its own name instead of being overwritten in place.
deaths_raw = configuration.load_deaths_list(COUNTRY)
deaths_ref = configuration.smooth_deaths_list(deaths_raw)

# Mobility signal p(t), smoothed with the same helper as the deaths series.
p_active = configuration.smooth_deaths_list(configuration.load_p_active(COUNTRY))

config = configuration.read_configuration(COUNTRY, prefix="used/", sufix="")

# Number of parallel simulations per execution.
# `specific_kernel` sizes its arrays from config["simulation"]["n_simulations"], so the
# override must be written there; the top-level key is kept in sync in case other
# helpers read it. NOTE(review): confirm which key the `model` package actually uses.
N_SIMULATIONS = 3_000_000
config["n_simulations"] = N_SIMULATIONS
config["simulation"]["n_simulations"] = N_SIMULATIONS

In [24]:

def specific_kernel(country, config):
    """Run the GPU fitting rounds for ``country`` and return the best simulations.

    Each of the ``config["simulation"]["n_executions"]`` rounds:
      1. resets ``log_diff`` to zero,
      2. calls ``parameters_manager.set_params`` on the preallocated parameter
         matrix (presumably drawing new parameter samples in place -- confirm),
      3. builds initial states with ``plots_funtions.prepare_states``,
      4. evolves them with ``simulation_functions.evolve_gpu`` against the
         smoothed deaths series and smoothed p(t),
      5. selects results with ``simulation_functions.get_best_parameters(params, log_diff, 0.1)``.

    Args:
        country: Country name used to load the deaths series and p(t).
        config: Configuration dict. Reads ``config["simulation"]["n_simulations"]``,
            ``config["simulation"]["n_executions"]`` and ``config["total_population"]``,
            and is forwarded unchanged to the model helpers.

    Returns:
        ``(best_params, best_log_diff)`` exactly as returned by ``get_best_parameters``
        for the LAST execution. NOTE(review): results of earlier executions are
        discarded; confirm whether they were meant to be accumulated.

    Raises:
        ValueError: if ``n_executions`` is less than 1 (previously this surfaced as
            an ``UnboundLocalError`` on the return statement).
    """
    n_executions = config["simulation"]["n_executions"]
    if n_executions < 1:
        raise ValueError(
            f"config['simulation']['n_executions'] must be >= 1, got {n_executions}"
        )

    parameters_manager = parameters_control.Params_Manager(config)

    fixed_params = cp.zeros(len(parameters_control.fixed_params_to_index), dtype=cp.float64)
    parameters_manager.set_fixed_params(fixed_params)

    n_simulations = config["simulation"]["n_simulations"]
    # Rows are indexed by parameters_control.param_to_index, one column per simulation.
    params = cp.zeros((len(parameters_control.param_to_index), n_simulations), dtype=cp.float64)
    # One score per simulation; written by evolve_gpu (presumably accumulated -- confirm).
    log_diff = cp.zeros(n_simulations, dtype=cp.float64)

    # Load the observed deaths and p(t) series and smooth both.
    deaths_list_smooth = configuration.smooth_deaths_list(configuration.load_deaths_list(country))
    p_active = configuration.smooth_deaths_list(configuration.load_p_active(country))

    for _ in range(n_executions):
        # Prepare this execution: clear scores, (re)set parameters, build states.
        log_diff[:] = 0
        parameters_manager.set_params(params)
        states = plots_funtions.prepare_states(params, config["total_population"])

        # Evolve all simulations on the GPU.
        simulation_functions.evolve_gpu(params, fixed_params, states, p_active, deaths_list_smooth, log_diff, config)

        # Select the best simulations of this execution.
        best_params, best_log_diff = simulation_functions.get_best_parameters(params, log_diff, 0.1)

    return best_params, best_log_diff

In [44]:
# Plot horizon: config["max_days"] extended by 40 extra days.
NEW_MAX_DAYS = config["max_days"] + 40

# Fit cut-off days; one prediction band is produced and drawn per entry.
fit_times = [50, 65, 80]

# Per-fit result arrays over the extended horizon; row i corresponds to fit_times[i].
_band_shape = (len(fit_times), NEW_MAX_DAYS)
_5p, _median, _95p = (np.zeros(_band_shape) for _ in range(3))

# Fixed parameters, filled by the manager and later passed to get_states_boundaries.
parameters_manager = parameters_control.Params_Manager(config)
fixed_params = cp.zeros(len(parameters_control.fixed_params_to_index), dtype=cp.float64)
parameters_manager.set_fixed_params(fixed_params)


In [45]:
for index, MAX_DAYS in enumerate(fit_times):
    # Fit with data only up to this cut-off. Previously MAX_DAYS was never passed to
    # specific_kernel (which takes no cut-off argument), so every iteration fitted the
    # same window and the bands differed only by sampling noise.
    # NOTE(review): assumes the model helpers read the fit window from
    # config["max_days"]; confirm in simulation_functions.evolve_gpu.
    # A shallow copy is used so the shared `config` (used elsewhere) is not mutated.
    fit_config = {**config, "max_days": MAX_DAYS}
    params, log_diff = specific_kernel(COUNTRY, fit_config)

    # get_states_boundaries keeps receiving the unmodified config, as before.
    p05, p50, p95 = plots_funtions.get_states_boundaries(MAX_DAYS, params, fixed_params, p_active, config)

    _5p[index] = p05
    _median[index] = p50
    _95p[index] = p95


In [50]:
# Colours for the fit cut-offs, in fit_times order.
colors = ['#FF2C00', '#0C5DA5', '#00B945']

fontsize = 'small'

# Copy the device arrays to host memory once; the cell previously called .get()
# repeatedly and also ran builtin max() directly over the CuPy array (element-wise
# device iteration, yielding a 0-d CuPy array rather than a float).
deaths_host = deaths_ref.get()
p_active_host = p_active.get()
# Latest fit cut-off. Previously this was the loop variable MAX_DAYS leaking out of
# the for-loop below (whose last value is the last entry of fit_times).
last_fit_day = fit_times[-1]

with plt.style.context('science'):
    w, h = figure.figaspect(2.5 / 4)
    # Built without pyplot, so the figure is only written to disk, not shown inline.
    fig_ = figure.Figure(figsize=(w, h))
    ax_ = fig_.add_axes([0.1, 0.1, 0.8, 0.8])
    _visibility = 0.1

    time_list = np.arange(NEW_MAX_DAYS)
    linewidth = 1
    # Cut-off arrows are a quarter of the peak of the reference series.
    arrow_height = float(deaths_host.max()) / 4

    ls_modelo = []
    for index, MAX_DAYS in enumerate(fit_times):
        # Vertical arrow marking this fit's cut-off day.
        ax_.annotate("", xy=(MAX_DAYS, 0), xycoords='data',
                     xytext=(MAX_DAYS, arrow_height), textcoords='data',
                     arrowprops=dict(
                         arrowstyle='->',
                         connectionstyle='arc3',
                         linewidth=linewidth,
                         color=colors[index],
                         alpha=1,
                     ), zorder=5)

        # Band between the p05 and p95 curves, plus the p50 curve, for this fit.
        ax_.fill_between(time_list, _5p[index], _95p[index], color=colors[index], alpha=0.15, zorder=0, lw=0)
        l_modelo, = ax_.plot(time_list, _median[index], '-.', color=colors[index], label=f"{MAX_DAYS}", zorder=1)
        ls_modelo.append(l_modelo)

    # Smoothed reference deaths: thick up to the last cut-off, thin over the full window.
    ax_.plot(time_list[:last_fit_day], deaths_host[:last_fit_day], linewidth=linewidth * 1.5, color='black', zorder=1)
    l_usados, = ax_.plot(time_list[:NEW_MAX_DAYS], deaths_host[:NEW_MAX_DAYS], linestyle='solid', linewidth=linewidth, color='black', zorder=1, label='Datos suavizados')

    # Secondary y-axis for p(t).
    ax_p_active = ax_.twinx()
    l_active, = ax_p_active.plot(time_list[:NEW_MAX_DAYS], p_active_host[:NEW_MAX_DAYS], linestyle='dotted', color='green', zorder=-1, label='p(t) Google')
    ax_p_active.set_ylabel('$p(t)$')
    ax_p_active.set_ylim([0, 1.1])
    ax_p_active.set_yticks([0.1 * i for i in range(0, 11, 5)])

    ax_.set_ylabel('Muertes diarias', fontsize=fontsize)

    # Data legend above the axes; fit cut-off legend at the left of the twin axis.
    ax_.legend(handles=[l_usados, l_active], loc='upper center', bbox_to_anchor=(0.5, 1.13), ncol=2, frameon=False, fontsize=fontsize)
    ax_p_active.legend(handles=ls_modelo, loc='center left', frameon=False, fontsize=fontsize)

    ax_.set_xlim(left=0)
    ax_.set_ylim([0, deaths_host[:last_fit_day].max() * 1.3])

    # X ticks every 14 days from day 10, labelled via configuration.date_to_spanish.
    first_day_deaths_list = datetime.datetime(*map(int, config["first_day_deaths_list"].split('-')))
    ticks = list(range(10, NEW_MAX_DAYS, 14))
    ticks_labels = [
        configuration.date_to_spanish((first_day_deaths_list + datetime.timedelta(days=i)).strftime("%b\n%d"))
        for i in ticks
    ]
    ax_.set_xticks(ticks)
    ax_.set_xticklabels(ticks_labels, fontsize=fontsize)
    ax_.grid(alpha=_visibility * 5)

    # Save PNG and PDF under images/images_by_country/<COUNTRY>/.
    filename_path = os.path.join("images", "images_by_country", COUNTRY)
    os.makedirs(filename_path, exist_ok=True)
    fig_.savefig(os.path.join(filename_path, 'predictions_landscape.png'), dpi=600)
    fig_.savefig(os.path.join(filename_path, 'predictions_landscape.pdf'))
    plt.close(fig_)