In [1]:
from model_complex import Forecats, FactoryBRModel, EpidData, Calibration
import matplotlib.pyplot as plt
import datetime
from sklearn.metrics import r2_score
import csv

def forecast_plot(type, st_time, end_time, end_prog, city, method):
    epid_data = EpidData(city=city, path='./', 
                    start_time=st_time, end_time=end_time)
    epid_data.get_wave_data(regime=type)
    data = epid_data.prepare_for_calibration()
    rho = epid_data.get_rho()//10
    plot_data = epid_data.prepare_for_plot()

    if type == 'age':
        init_infect = [100, 100]
        model = FactoryBRModel.age_group()
        label = {0: '0-14 years', 1: '15+ years'}
    else:
        init_infect = [100]
        model = FactoryBRModel.total()
        label = {0: 'total'}

    calibration = Calibration(init_infect, model, data, rho)

    if method == 'mcmc':
        alpha, beta = calibration.mcmc_calibration()
    else:
        alpha, beta = calibration.abc_calibration()

    dur = (datetime.datetime.strptime(end_prog, "%d-%m-%Y") - datetime.datetime.strptime(end_time, "%d-%m-%Y")).days//7

    Res = Forecats(data, model, init_infect, alpha, beta, rho, dur).forecast()

    color = {0: 'blue', 1: 'orange'}

    epid_data = EpidData(city=city, path='./', 
                    start_time=end_time, end_time=end_prog)
    epid_data.get_wave_data(regime=type)
    new_plot_data = epid_data.prepare_for_plot()

    model.simulate(
        alpha=[a.mean() for a in alpha], 
        beta=[b.mean() for b in beta], 
        initial_infectious=init_infect, 
        rho=rho, 
        modeling_duration=len(data)//len(init_infect)
    )

    res = model.get_result()

    model.simulate(
        alpha=[a.mean() for a in alpha], 
        beta=[b.mean() for b in beta], 
        initial_infectious=init_infect, 
        rho=rho, 
        modeling_duration=(len(data)//len(init_infect) +len(new_plot_data[:,0]))
    )

    new_res = model.get_result()

    return plot_data, new_plot_data, res, new_res, Res





In [2]:
times = [
    ['01-07-2010', '20-01-2011', '30-06-2011'],
    ['01-07-2015', '20-01-2016', '30-06-2016'],
    ['01-08-2012', '20-03-2013', '30-06-2013'],
    ['01-08-2013', '20-03-2014', '30-06-2014'],
    ['01-10-2011', '10-04-2012', '30-06-2012']
]

In [3]:

import matplotlib.dates as mdates
import numpy as np
import locale
locale.setlocale(locale.LC_TIME, 'ru_RU.UTF-8')


fig, ax = plt.subplots(figsize=(12,8))

for i in range(len(times)):
        start, end_time, end = times[i]
        label_fontsize = 18
        tick_fontsize = 16
        type = 'age'
        city = 'spb'
        method = 'abc'

        start_date = np.datetime64(f'{start.split("-")[2]}-{start.split("-")[1]}-{start.split("-")[0]}')
        end_date = np.datetime64(f'{end.split("-")[2]}-{end.split("-")[1]}-{end.split("-")[0]}')


        # Generate an array of dates with a weekly step
        date_array = np.arange(start_date, end_date + np.timedelta64(1, 'D'), np.timedelta64(1, 'W'))

        plot_data, new_plot_data, res, new_res, Res = forecast_plot(type, start, end_time, end, city, method)
        r2 = []
        data_lab = {0: '0-14', 1:'15+'}
        age_color = {0: 'forestgreen', 1: 'royalblue'}
        mod_col = {0: "crimson", 1: 'orangered'}
        for i in range(2):
                actual_data = plot_data[:, i]
                forecast_data = new_plot_data[:, i]
                model_curve = new_res[i]
                lower_ci = Res[i, :, 0]
                upper_ci = Res[i, :, 1]

                ax.set_xticklabels(date_array, rotation=20)
                ax.xaxis.set_major_locator(mdates.MonthLocator())  # Set major ticks to months
                ax.xaxis.set_major_formatter(mdates.DateFormatter('%b %Y'))  # Format ticks as 'Month Year'
                # ax.set_ylim([-50, 35000])
                ax.tick_params(axis='both', which='major', labelsize=tick_fontsize)
                ax.set_ylabel('Заболеваемость, случаи', fontsize=label_fontsize)

                ax.plot(date_array[:len(actual_data)], actual_data, '--o', color=age_color[i], label='Доступные данные ' + data_lab[i])
                ax.fill_between(date_array[len(actual_data):-1], lower_ci[len(actual_data):], 
                                upper_ci[len(actual_data):], color=age_color[i], alpha=0.3, label='Доверительный интервал ' + data_lab[i])
                ax.plot(date_array[len(actual_data):-1], forecast_data, '--o', color=age_color[i], 
                        alpha=0.6, label='Прогнозируемые данные ' + data_lab[i])
                r2 += [round(r2_score(actual_data, model_curve[:len(actual_data)]), 2)]
                ax.plot(date_array[:-1], model_curve, linewidth=3, label='Модельная кривая ' + data_lab[i], color=age_color[i], alpha=0.8)
                ax.legend(fontsize=tick_fontsize)


        ax.set_title(r'$R^2' + '_{0-14}' + '={}$'.format(r2[0]) +  '  $R^2' + '_{15+}' +'={}$'.format(r2[1]), fontsize=label_fontsize)


        fig.savefig(r'Fplots/forecast_total_abc_{}-{}.png'.format(str(start_date), str(end_date)), dpi=400, bbox_inches='tight')
        fig.savefig(r'Fplots/forecast_total_abc_{}-{}.pdf'.format(str(start_date), str(end_date)), bbox_inches='tight')


Error: unsupported locale setting