# SEIR & Machine Learning Model for COVID-19 Global Forecast

In [1]:
import dill
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from numbers import Number
from scipy.optimize import differential_evolution
from scipy.optimize import brute
from scipy.integrate import solve_ivp
from sklearn.metrics import mean_squared_log_error
from sklearn.metrics import mean_squared_error

# Definição das funções utilizadas

### Função para leitura dos dados

In [2]:
def get_country_data(country):
    """Load the JHU confirmed/deaths/recovered time series for one country.

    Reads the three global CSV files under ``../jhu_data/``, keeps the rows
    whose ``Country/Region`` equals ``country``, and sums them column-wise
    (one value per date column, from ``'1/22/20'`` onwards).

    Parameters
    ----------
    country : str
        Value to match against the ``Country/Region`` column.

    Returns
    -------
    pandas.DataFrame
        Indexed by date, with columns ``confirmed``, ``deaths``,
        ``recovered`` and ``infection_days`` (1, 2, 3, ... counted from the
        first date with at least one confirmed case). Dates with zero
        confirmed cases are dropped.
    """
    var_files = {
        'confirmed': '../jhu_data/time_series_covid19_confirmed_global.csv',
        'deaths': '../jhu_data/time_series_covid19_deaths_global.csv',
        'recovered': '../jhu_data/time_series_covid19_recovered_global.csv',
    }

    country_data = pd.DataFrame()

    for var_name, var_file in var_files.items():
        var_global = pd.read_csv(var_file)

        var_country = var_global[var_global['Country/Region'] == country]
        var_country = var_country.loc[:, '1/22/20':]

        # Sum over all matching rows -> one total per date column.
        country_data[var_name] = var_country.sum()

    country_data.index = pd.to_datetime(country_data.index)

    # Take an explicit copy: adding a column to a boolean-mask slice would
    # otherwise trigger pandas' SettingWithCopyWarning.
    country_data = country_data[country_data['confirmed'] > 0].copy()
    country_data['infection_days'] = np.arange(1, len(country_data)+1)

    return country_data

### Carregando datasets para avaliação

In [3]:
# Scaled model inputs (x_*) for the train / validation / test splits.
x_train_scaled = pd.read_csv(
    '../data_eval/out(confirmed)in(confirmed-infection_days)/lf20_lb13/'
    'x_train_lf20_lb13.csv'
)
x_val_scaled = pd.read_csv(
    '../data_eval/out(confirmed)in(confirmed-infection_days)/lf20_lb13/'
    'x_val_lf20_lb13.csv'
)
x_test_scaled = pd.read_csv(
    '../data_eval/out(confirmed)in(confirmed-infection_days)/lf20_lb13/'
    'x_test_lf20_lb13.csv'
)

# Scaled targets (y_*) for the same three splits.
y_train_scaled = pd.read_csv(
    '../data_eval/out(confirmed)in(confirmed-infection_days)/lf20_lb13/'
    'y_train_lf20_lb13.csv'
)
y_val_scaled = pd.read_csv(
    '../data_eval/out(confirmed)in(confirmed-infection_days)/lf20_lb13/'
    'y_val_lf20_lb13.csv'
)
y_test_scaled = pd.read_csv(
    '../data_eval/out(confirmed)in(confirmed-infection_days)/lf20_lb13/'
    'y_test_lf20_lb13.csv'
)

### Carregando objeto "Scaler" para converter datasets de volta para a escala original

In [4]:
# Restore the serialized scaler object used to undo the dataset scaling.
# SECURITY: dill.load (like pickle.load) can execute arbitrary code while
# deserializing — only load scaler files that come from a trusted source.
with open('../data_eval/scaler.dill', 'rb') as scaler_file:
    scaler = dill.load(scaler_file)

### Convertendo dados para a escala original

In [5]:
# Convert the scaled input frames back to the original scale.
# NOTE(review): `lb` / `lf` are presumably the look-back / look-forward window
# sizes expected by the scaler (its implementation is not visible here).
# The files are named "lb13" while lb=14 is passed — presumably 13 lags plus t0; confirm.
x_train = scaler.get_original_scale(x_train_scaled, lb=14, lf=None)
x_val = scaler.get_original_scale(x_val_scaled, lb=14, lf=None)
x_test = scaler.get_original_scale(x_test_scaled, lb=14, lf=None)

# Same conversion for the target frames (lf=20, no look-back).
y_train = scaler.get_original_scale(y_train_scaled, lb=None, lf=20)
y_val = scaler.get_original_scale(y_val_scaled, lb=None, lf=20)
y_test = scaler.get_original_scale(y_test_scaled, lb=None, lf=20)

### Selecionando dataset para EUA

In [6]:
# Lagged infection-day columns (t-1 ... t-13) are dropped; only infection_days_t0 is kept.
cols_to_drop = ['infection_days_t-{}'.format(lag) for lag in range(1, 14)]

# Rows of the training split whose region is 'US'.
x_usa = x_train[x_train['region'] == 'US'].drop(columns=cols_to_drop)
y_usa = y_train[y_train['region'] == 'US']

### Selecionando dataset para Brasil

In [7]:
# Lagged infection-day columns (t-1 ... t-13) are dropped; only infection_days_t0 is kept.
cols_to_drop = ['infection_days_t-{}'.format(lag) for lag in range(1, 14)]

# Rows of the test split whose region is 'Brazil'.
x_brazil = x_test[x_test['region'] == 'Brazil'].drop(columns=cols_to_drop)
y_brazil = y_test[y_test['region'] == 'Brazil']

### Selecionando dataset para Coreia do Sul

In [8]:
# Lagged infection-day columns (t-1 ... t-13) are dropped; only infection_days_t0 is kept.
cols_to_drop = ['infection_days_t-{}'.format(lag) for lag in range(1, 14)]

# Rows of the training split whose region is 'Korea, South'.
x_skorea = x_train[x_train['region'] == 'Korea, South'].drop(columns=cols_to_drop)
y_skorea = y_train[y_train['region'] == 'Korea, South']

In [9]:
class SEIR:
    """Thin estimator-style wrapper around the SEIR helper functions.

    Attributes
    ----------
    params : dict or None
        Fitted parameters (keys 'R_0', 'T_inc', 'T_inf', 'N_inf') as returned
        by ``SEIR_fit``; ``None`` until ``fit`` succeeds.
    bounds : dict or None
        Bounds used by the last call to ``fit``; ``None`` before fitting.
    """

    def __init__(self):
        # Defined up front so an unfitted model can be detected explicitly
        # instead of failing with an AttributeError.
        self.params = None
        self.bounds = None

    def _check_fitted(self):
        """Raise RuntimeError when no fitted parameters are available."""
        if self.params is None:
            raise RuntimeError(
                'SEIR model has no parameters: call fit() first '
                '(fit() leaves params as None when the bounds are invalid).'
            )

    def fit(self, data, pop, bounds=None, **kwargs):
        """Estimate the model parameters and store them in ``self.params``.

        Parameters
        ----------
        data : pandas.DataFrame
            Windowed frame forwarded to ``SEIR_fit``.
        pop : number
            Population size.
        bounds : dict, optional
            Per-parameter bounds; when omitted the defaults below are used.
        **kwargs
            Forwarded unchanged to ``SEIR_fit``.
        """
        if bounds is None:
            self.bounds = {
                'R_0': (0.1, 8),
                'T_inc': (0.1, 8),
                'T_inf': (0.1, 15),
                'N_inf': (1, 100)
            }
        else:
            self.bounds = bounds

        self.params = SEIR_fit(data=data, population=pop, bounds=self.bounds, **kwargs)

    def evaluate(self, data, pop, **kwargs):
        """Return the loss of the fitted parameters on ``data`` (see ``SEIR_evaluate``)."""
        self._check_fitted()
        return SEIR_evaluate(self.params, data=data, population=pop, **kwargs)

    def predict(self, data, pop, look_forward):
        """Predict confirmed cases for the ``look_forward`` days after each row.

        Parameters
        ----------
        data : pandas.DataFrame
            Must contain the columns 'region' and 'infection_days_t0'.
        pop : number
            Population size.
        look_forward : int
            Number of future days predicted per row.

        Returns
        -------
        pandas.DataFrame
            One row per input row, with 'region' followed by
            'confirmed_t+1' ... 'confirmed_t+<look_forward>'.
        """
        self._check_fitted()

        region = data['region'].values

        # Day counts are used as array indices, so force them to integers
        # (they may come back as floats from an inverse scaling step).
        days = np.rint(data['infection_days_t0'].values).astype(int)

        # Use the largest day rather than the last row so that unsorted
        # input still gets a long enough simulation.
        n_days = int(days.max()) + look_forward

        solution = SEIR_solve(self.params, population=pop, n_days=n_days)

        _, _, inf, res = solution
        pred_confirmed = inf + res

        # Simulation index `day` holds the prediction for t+1 of a row whose
        # infection_days_t0 equals `day`.
        y_pred_confirm = np.vstack(
            [pred_confirmed[day:(day + look_forward)] for day in days]
        )

        confirmed_cols = ['confirmed_t+{}'.format(x) for x in range(1, look_forward+1)]
        y_pred_confirm_df = pd.DataFrame(y_pred_confirm, columns=confirmed_cols)
        y_pred_confirm_df.insert(0, column='region', value=region)

        return y_pred_confirm_df

### Funções do modelo SEIR

In [10]:
def dS_dt(S, I, R_t, T_inf):
    """Rate of change of the susceptible compartment."""
    transmission_rate = R_t / T_inf
    contacts = I * S
    return -transmission_rate * contacts

def dE_dt(S, E, I, R_t, T_inf, T_inc):
    """Rate of change of the exposed compartment."""
    newly_exposed = (R_t / T_inf) * (I * S)
    leaving_incubation = (1/T_inc) * E
    return newly_exposed - leaving_incubation

def dI_dt(I, E, T_inc, T_inf):
    """Rate of change of the infected compartment."""
    becoming_infectious = (1/T_inc) * E
    leaving_infection = (1/T_inf) * I
    return becoming_infectious - leaving_infection

def dR_dt(I, T_inf):
    """Rate of change of the resistant compartment."""
    removal_rate = 1/T_inf
    return removal_rate * I

def SEIR_equations(t, y, R_t, T_inf, T_inc):
    """Right-hand side of the SEIR system, in the form expected by solve_ivp.

    ``R_t`` may be a constant or a callable of time; a callable is
    evaluated at ``t``. ``y`` is unpacked as (S, E, I, R) and the four
    derivatives are returned in the same order.
    """
    reproduction = R_t(t) if callable(R_t) else R_t

    S, E, I, R = y

    return [
        dS_dt(S, I, reproduction, T_inf),
        dE_dt(S, E, I, reproduction, T_inf, T_inc),
        dI_dt(I, E, T_inc, T_inf),
        dR_dt(I, T_inf),
    ]

### Função para predições do modelo

In [11]:
def SEIR_solve(params, population, n_days, verbose=False):
    """Integrate the SEIR system and return daily compartment counts.

    Parameters
    ----------
    params : dict
        Must contain 'R_0', 'T_inc', 'T_inf' and 'N_inf' (initial number of
        infected individuals).
    population : number
        Population size; the system is solved on fractions of it and the
        result is scaled back to head counts.
    n_days : int
        Number of daily points requested (t = 0 ... n_days-1).
    verbose : bool
        When True, warn about negative values in any compartment.

    Returns
    -------
    numpy.ndarray
        Rows are susceptible, exposed, infected, resistant; values are
        rounded to whole individuals.
    """
    R_0, T_inc, T_inf, N_inf = (
        params[key] for key in ('R_0', 'T_inc', 'T_inf', 'N_inf')
    )

    # Initial state as population fractions: everyone susceptible except
    # the N_inf initially infected; nobody exposed or resistant yet.
    y0 = [(population-N_inf)/population, 0, N_inf/population, 0]

    ivp_result = solve_ivp(
        SEIR_equations, [0, n_days], y0,
        method='RK45', max_step=1,
        args=(R_0, T_inf, T_inc),
        t_eval=np.arange(n_days),
    )

    if not ivp_result.success:
        print('Status {}: {}'.format(ivp_result.status, ivp_result.message))

    solution = ivp_result.y*population

    # Diagnostic output when the solver returned fewer points than requested.
    if (solution[2].shape[0] < n_days) or (solution[3].shape[0] < n_days):
        print('n_days:', n_days)
        print(solution)

    solution = np.around(solution, 0)

    if verbose:
        compartment_names = ['susceptible', 'exposed', 'infected', 'resistant']
        for pred, pred_name in zip(solution, compartment_names):
            if not np.all(pred>=0):
                print('WARNING: Negative values on the {} predictions!'.format(pred_name))

    return solution

### Função para avaliação do modelo

In [12]:
def SEIR_evaluate(params, data, population, **kwargs):
    """Compute the loss of an SEIR parameter set against windowed true data.

    Parameters
    ----------
    params : dict
        SEIR parameters accepted by ``SEIR_solve``.
    data : pandas.DataFrame
        Must contain 'infection_days_t0' plus the columns whose names start
        with 'confirmed' (one window of true values per row).
    population : number
        Population size.
    **kwargs
        metric : {'mse', 'rmse', 'msle'}, default 'mse'
        weights : array-like or None, passed as ``sample_weight``.

    Returns
    -------
    float
        Loss between the true windows and the matching model windows.

    Raises
    ------
    ValueError
        If ``metric`` is not one of the supported names.
    """
    # Getting function parameters
    metric = kwargs.get('metric', 'mse')
    weights = kwargs.get('weights', None)

    # Fail early with a clear message instead of an UnboundLocalError on
    # `loss` after the (expensive) simulation has already run.
    if metric not in ('mse', 'rmse', 'msle'):
        raise ValueError(
            "Unknown metric {!r}; expected 'mse', 'rmse' or 'msle'.".format(metric)
        )

    n_days = data['infection_days_t0'].iloc[-1]

    solution = SEIR_solve(params, population, n_days)

    # Confirmed cases are modelled as infected + resistant.
    infect_pred = solution[2]
    resist_pred = solution[3]
    confirm_pred = infect_pred+resist_pred

    y_true_confirm = data.filter(regex='^confirmed', axis=1).values
    window = y_true_confirm.shape[1]

    # For each row, take the `window` simulated days ending at that row's
    # infection_days_t0 so they line up with the true window.
    y_pred_confirm = np.vstack(
        [confirm_pred[(day-window):(day)] for day in data['infection_days_t0']]
    )

    # sample_weight is passed by keyword: recent scikit-learn releases make
    # every argument after y_pred keyword-only.
    if metric == 'mse':
        loss = mean_squared_error(y_true_confirm, y_pred_confirm, sample_weight=weights)
    elif metric == 'rmse':
        # NOTE(review): `squared=False` is deprecated in recent scikit-learn
        # releases in favour of root_mean_squared_error — confirm the pinned version.
        loss = mean_squared_error(y_true_confirm, y_pred_confirm,
                                  sample_weight=weights, squared=False)
    else:  # metric == 'msle'
        loss = mean_squared_log_error(y_true_confirm, y_pred_confirm, sample_weight=weights)

    return loss

### Função para otimização dos parâmetros do modelo

In [13]:
def _loss_func_helper(params_vals, params_names, fixed_params_dict, data, population, weights, metric):
    """Objective function handed to the optimizers.

    Pairs the candidate values with their names, adds the fixed parameters
    (which take precedence on a name clash) and returns the resulting loss.
    """
    candidate = {name: value for name, value in zip(params_names, params_vals)}
    candidate.update(fixed_params_dict)
    return SEIR_evaluate(candidate, data, population, weights=weights, metric=metric)

In [14]:
def SEIR_fit(data, population, bounds, **kwargs):
    """Find the SEIR parameters that minimise the loss on ``data``.

    Parameters
    ----------
    data : pandas.DataFrame
        Windowed frame forwarded to the loss function.
    population : number
        Population size.
    bounds : dict
        Must have exactly the keys 'R_0', 'T_inc', 'T_inf', 'N_inf'. Each
        value is either a number (parameter held fixed), a ``(min, max)``
        list/tuple (parameter optimised) or a ``slice`` (parameter
        optimised, only valid with ``method='brute'``).
    **kwargs
        metric : str, default 'mse' — loss name forwarded to the loss function.
        weights : array-like or None — sample weights for the loss.
        method : {'diff_evol', 'brute'}, default 'diff_evol'.
        Ns : int, default 10 — grid points per axis for 'brute'.
        verbose : bool, default False — print optimiser diagnostics.

    Returns
    -------
    dict or None
        Best parameters (optimised and fixed ones merged), or ``None``
        after printing a message when the arguments are invalid.
    """
    # Getting function parameters
    metric = kwargs.get('metric', 'mse')
    weights = kwargs.get('weights', None)
    method = kwargs.get('method', 'diff_evol')
    Ns = kwargs.get('Ns', 10)
    verbose = kwargs.get('verbose', False)

    # Expected keys in bounds dictionary
    params_list = ['R_0', 'T_inc', 'T_inf', 'N_inf']

    if set(bounds.keys()) != set(params_list):
        print('Dict keys must be {}. Returning None.'.format(params_list))
        return None

    optim_params_names = []
    optim_params_bounds = []
    fixed_params_dict = {}

    # Checking if params are fixed or to optimize
    for param in params_list:
        param_bounds = bounds[param]

        # Param bounds is a number, so it's a fixed param.
        if isinstance(param_bounds, Number):
            fixed_params_dict[param] = param_bounds

        # Param bounds is a sequence (min, max), so it's a param to optimize.
        elif isinstance(param_bounds, (list, tuple)):

            if len(param_bounds) == 2:
                optim_params_names.append(param)
                optim_params_bounds.append(param_bounds)
            else:
                print(
                    'ERROR: {} bounds has length {}.'.format(param, len(param_bounds)),
                    'Bounds must be in the form (min,max).',
                    'Returning None.'
                )
                return None

        # Param bounds is a slice (start, stop, step), so it's a param to optimize using 'brute'.
        elif isinstance(param_bounds, slice):

            if method == 'brute':
                optim_params_names.append(param)
                optim_params_bounds.append(param_bounds)
            else:
                print(
                    'ERROR: {} bounds are in a slice object.'.format(param),
                    'Slice object must be used with method="brute".',
                    'Returning None.'
                )
                return None
        else:
            print(
                'ERROR: {} bounds specified as {}.'.format(param, type(param_bounds)),
                'Bounds must be a number, for a fixed parameter, or',
                'a sequence in the form (min,max) or a slice object',
                'in the form (start, stop, step) if method="brute".',
                'Returning None.'
                )
            return None

    if len(optim_params_bounds) == 0:
        print('No parameters to optimize. Returning fixed parameters.')
        return fixed_params_dict

    # Reject unsupported optimisers explicitly; previously this fell through
    # both branches and crashed with UnboundLocalError on `best_params`.
    if method not in ('diff_evol', 'brute'):
        print(
            'ERROR: unknown method {!r}.'.format(method),
            'Method must be "diff_evol" or "brute".',
            'Returning None.'
        )
        return None

    # Extra positional arguments handed to the loss function after the
    # candidate parameter vector.
    optim_args = (optim_params_names, fixed_params_dict,
                  data, population, weights, metric)

    if method == 'diff_evol':

        optim_res = differential_evolution(
            _loss_func_helper, bounds=optim_params_bounds,
            args=optim_args,
            popsize=40, mutation=0.8, recombination=0.9,
            updating='deferred', polish=True, workers=-1, seed=0
        )

        if verbose:
            print('Value of objective function:', optim_res.fun)
            print('Number of evaluations of the objective function:', optim_res.nfev)
            print('Number of iterations performed by the optimizer:', optim_res.nit)
            print('Optimizer exited successfully:', optim_res.success)
            print('Message:', optim_res.message)
            print()

        best_vals = optim_res.x

    else:  # method == 'brute'
        x_out, fval, _, _ = brute(
            _loss_func_helper, ranges=optim_params_bounds,
            args=optim_args,
            Ns=Ns, full_output=True, finish=None, workers=-1)

        if verbose:
            print('Value of objective function:', fval)

        # With a single optimised parameter brute may hand back a scalar;
        # make it iterable so it can be zipped with the names.
        best_vals = np.atleast_1d(x_out)

    best_params = dict(zip(optim_params_names, best_vals))
    best_params.update(fixed_params_dict)

    return best_params

### Função para plotar predições do modelo

In [15]:
def plot_model(solution, data, title=''):
    """Plot the SEIR compartments and compare predicted vs. true confirmed cases.

    Parameters
    ----------
    solution : sequence of four arrays
        Susceptible, exposed, infected and resistant counts per day.
    data : pandas.DataFrame
        Must contain a 'confirmed' column with the true daily values.
    title : str
        Figure-level title.
    """
    sus, exp, inf, res = solution
    pred_confirmed = inf+res
    million = 1000000

    f = plt.figure(figsize=(16,5))

    # Subplot 1: the four compartments, in millions of people.
    ax = f.add_subplot(1,2,1)
    ax.plot(sus/million, 'b', label='Susceptible')
    ax.plot(exp/million, 'salmon', label='Exposed')
    ax.plot(inf/million, 'magenta', label='Infected')
    ax.plot(res/million, 'green', label='Resistant')

    ax.set_title('SEIR model')
    ax.set_xlabel("Days", fontsize=10)
    ax.set_ylabel("Population (millions)", fontsize=10)
    ax.legend(loc='best')

    # Subplot 2: prediction vs. reported data.
    ax2 = f.add_subplot(1,2,2)

    # The simulation may be shorter than the data; build the x axis from
    # the slice actually available so x and y always have equal length.
    pred_window = pred_confirmed[:len(data)]
    ax2.plot(range(len(pred_window)), pred_window, 'cyan', label='Predicted confirmed cases')
    ax2.plot(range(len(data)), data['confirmed'].values, 'black', label='True confirmed cases')

    ax2.set_title('Model prediction and real data')
    ax2.set_ylabel("Population", fontsize=10)
    ax2.set_xlabel("Days", fontsize=10)
    ax2.legend(loc='best')

    f.suptitle(title)
    plt.show()

In [16]:
# Rebuild one long confirmed-cases series for the US from the sliding windows:
# every confirmed value of the first window, then the second-to-last column of
# each later row (presumably confirmed_t0 — verify the column order).
a = x_usa.iloc[0:1, :].filter(regex='^confirmed', axis=1).values.flatten()
b = x_usa.iloc[1:, -2].values
c = np.concatenate([a,b])

# Column labels run from confirmed_t-(len(c)-1) up to confirmed_t0.
confirmed_cols = ['confirmed_t{}'.format(offset) for offset in range(-len(c)+1, 1, 1)]
cols = ['region'] + confirmed_cols + ['infection_days_t0']

region = x_usa.iloc[0,0]
inf_days = x_usa.iloc[-1,-1]
vals = [region, *c, inf_days]

df_dict = dict(zip(cols, vals))

x_usa_long = pd.DataFrame(df_dict, index=[0])
x_usa_long

Unnamed: 0,region,confirmed_t-68,confirmed_t-67,confirmed_t-66,confirmed_t-65,confirmed_t-64,confirmed_t-63,confirmed_t-62,confirmed_t-61,confirmed_t-60,...,confirmed_t-8,confirmed_t-7,confirmed_t-6,confirmed_t-5,confirmed_t-4,confirmed_t-3,confirmed_t-2,confirmed_t-1,confirmed_t0,infection_days_t0
0,US,1,1,2,2,5,5,5,5,5,...,33276,43843,53736,65778,83836,101657,121465,140909,161831,69


In [17]:
# Rebuild one long confirmed-cases series for Brazil from the sliding windows:
# every confirmed value of the first window, then the second-to-last column of
# each later row (presumably confirmed_t0 — verify the column order).
a = x_brazil.iloc[0:1, :].filter(regex='^confirmed', axis=1).values.flatten()
b = x_brazil.iloc[1:, -2].values
c = np.concatenate([a,b])

# Column labels run from confirmed_t-(len(c)-1) up to confirmed_t0.
confirmed_cols = ['confirmed_t{}'.format(offset) for offset in range(-len(c)+1, 1, 1)]
cols = ['region'] + confirmed_cols + ['infection_days_t0']

region = x_brazil.iloc[0,0]
inf_days = x_brazil.iloc[-1,-1]
vals = [region, *c, inf_days]

df_dict = dict(zip(cols, vals))

x_brazil_long = pd.DataFrame(df_dict, index=[0])
x_brazil_long

Unnamed: 0,region,confirmed_t-33,confirmed_t-32,confirmed_t-31,confirmed_t-30,confirmed_t-29,confirmed_t-28,confirmed_t-27,confirmed_t-26,confirmed_t-25,...,confirmed_t-8,confirmed_t-7,confirmed_t-6,confirmed_t-5,confirmed_t-4,confirmed_t-3,confirmed_t-2,confirmed_t-1,confirmed_t0,infection_days_t0
0,Brazil,1,1,1,2,2,2,2,4,4,...,1546,1924,2247,2554,2985,3417,3904,4256,4579,34


In [18]:
import datetime

In [19]:
# Fit the SEIR model on the Brazil windows with the default bounds, time the
# optimisation, and report the loss on the same data.

# country_data = get_country_data('Brazil')
# init_infect = country_data['confirmed'].iloc[0]

# Population used for the run — switch together with the dataset passed to fit().
# Brazil
pop=210147125

# USA
# pop=330744872

# South Korea
# pop=51263393

# Wall-clock timing of the fit.
begin = datetime.datetime.now()

model = SEIR()

model.fit(data=x_brazil, pop=pop)

end = datetime.datetime.now()

# Keep the fitted parameters in a global: later cells read `params`.
params = model.params
model_bounds = model.bounds

print('Bounds:', model_bounds)
print('Params:', params)
print('Time:', end-begin)
print()

# Loss on the data the model was fitted to (default metric).
loss = model.evaluate(data=x_brazil, pop=pop)
print(loss)

Bounds: {'R_0': (0.1, 8), 'T_inc': (0.1, 8), 'T_inf': (0.1, 15), 'N_inf': (1, 100)}
Params: {'R_0': 2.375904085947246, 'T_inc': 0.3116838334066516, 'T_inf': 6.807167280472915, 'N_inf': 8.043504084551898}
Time: 0:01:41.308117

15012.210884353743


In [20]:
# Predict the 20 days following each Brazil window with the fitted model.
y_pred_df = model.predict(data=x_brazil, pop=pop, look_forward=20)
display(y_pred_df)

Unnamed: 0,region,confirmed_t+1,confirmed_t+2,confirmed_t+3,confirmed_t+4,confirmed_t+5,confirmed_t+6,confirmed_t+7,confirmed_t+8,confirmed_t+9,...,confirmed_t+11,confirmed_t+12,confirmed_t+13,confirmed_t+14,confirmed_t+15,confirmed_t+16,confirmed_t+17,confirmed_t+18,confirmed_t+19,confirmed_t+20
0,Brazil,147.0,173.0,200.0,260.0,324.0,365.0,442.0,517.0,602.0,...,1134.0,1394.0,1698.0,2092.0,2513.0,2822.0,3219.0,3784.0,4522.0,5491.0
1,Brazil,173.0,200.0,260.0,324.0,365.0,442.0,517.0,602.0,829.0,...,1394.0,1698.0,2092.0,2513.0,2822.0,3219.0,3784.0,4522.0,5491.0,6639.0
2,Brazil,200.0,260.0,324.0,365.0,442.0,517.0,602.0,829.0,1134.0,...,1698.0,2092.0,2513.0,2822.0,3219.0,3784.0,4522.0,5491.0,6639.0,8011.0
3,Brazil,260.0,324.0,365.0,442.0,517.0,602.0,829.0,1134.0,1394.0,...,2092.0,2513.0,2822.0,3219.0,3784.0,4522.0,5491.0,6639.0,8011.0,9641.0
4,Brazil,324.0,365.0,442.0,517.0,602.0,829.0,1134.0,1394.0,1698.0,...,2513.0,2822.0,3219.0,3784.0,4522.0,5491.0,6639.0,8011.0,9641.0,11578.0
5,Brazil,365.0,442.0,517.0,602.0,829.0,1134.0,1394.0,1698.0,2092.0,...,2822.0,3219.0,3784.0,4522.0,5491.0,6639.0,8011.0,9641.0,11578.0,13942.0
6,Brazil,442.0,517.0,602.0,829.0,1134.0,1394.0,1698.0,2092.0,2513.0,...,3219.0,3784.0,4522.0,5491.0,6639.0,8011.0,9641.0,11578.0,13942.0,16859.0
7,Brazil,517.0,602.0,829.0,1134.0,1394.0,1698.0,2092.0,2513.0,2822.0,...,3784.0,4522.0,5491.0,6639.0,8011.0,9641.0,11578.0,13942.0,16859.0,20283.0
8,Brazil,602.0,829.0,1134.0,1394.0,1698.0,2092.0,2513.0,2822.0,3219.0,...,4522.0,5491.0,6639.0,8011.0,9641.0,11578.0,13942.0,16859.0,20283.0,24334.0
9,Brazil,829.0,1134.0,1394.0,1698.0,2092.0,2513.0,2822.0,3219.0,3784.0,...,5491.0,6639.0,8011.0,9641.0,11578.0,13942.0,16859.0,20283.0,24334.0,29263.0


In [21]:
# True targets for Brazil, shown for comparison with the predictions in y_pred_df.
y_brazil

Unnamed: 0,region,confirmed_t+1,confirmed_t+2,confirmed_t+3,confirmed_t+4,confirmed_t+5,confirmed_t+6,confirmed_t+7,confirmed_t+8,confirmed_t+9,...,confirmed_t+11,confirmed_t+12,confirmed_t+13,confirmed_t+14,confirmed_t+15,confirmed_t+16,confirmed_t+17,confirmed_t+18,confirmed_t+19,confirmed_t+20
443,Brazil,38,52,151,151,162,200,321,372,621,...,1021,1546,1924,2247,2554,2985,3417,3904,4256,4579
444,Brazil,52,151,151,162,200,321,372,621,793,...,1546,1924,2247,2554,2985,3417,3904,4256,4579,5717
445,Brazil,151,151,162,200,321,372,621,793,1021,...,1924,2247,2554,2985,3417,3904,4256,4579,5717,6836
446,Brazil,151,162,200,321,372,621,793,1021,1546,...,2247,2554,2985,3417,3904,4256,4579,5717,6836,8044
447,Brazil,162,200,321,372,621,793,1021,1546,1924,...,2554,2985,3417,3904,4256,4579,5717,6836,8044,9056
448,Brazil,200,321,372,621,793,1021,1546,1924,2247,...,2985,3417,3904,4256,4579,5717,6836,8044,9056,10360
449,Brazil,321,372,621,793,1021,1546,1924,2247,2554,...,3417,3904,4256,4579,5717,6836,8044,9056,10360,11130
450,Brazil,372,621,793,1021,1546,1924,2247,2554,2985,...,3904,4256,4579,5717,6836,8044,9056,10360,11130,12161
451,Brazil,621,793,1021,1546,1924,2247,2554,2985,3417,...,4256,4579,5717,6836,8044,9056,10360,11130,12161,14034
452,Brazil,793,1021,1546,1924,2247,2554,2985,3417,3904,...,4579,5717,6836,8044,9056,10360,11130,12161,14034,16170


In [22]:
# NOTE(review): `stop` is not defined, so this cell raises NameError — presumably a
# deliberate barrier that halts "Run All" before the scratch cells below; confirm.
stop

NameError: name 'stop' is not defined

In [None]:
# Flatten the long Brazil frame into a 1-D array of true confirmed cases.
true_confirmed = x_brazil_long.filter(regex='^confirmed', axis=1).values.flatten()
true_confirmed

In [None]:
# NOTE(review): `pred_confirmed` is only ever assigned inside functions in the visible
# cells, so on a fresh kernel this cell raises NameError — define it before plotting.
plt.plot(true_confirmed, label='True')
plt.plot(pred_confirmed, label='Pred')
plt.legend()

In [None]:
# NOTE(review): undefined name; presumably a deliberate NameError to halt "Run All" here — confirm.
stop

In [None]:
# plot_model(solution, x_brazil, title='')

In [None]:
# Inspect the Brazil input windows.
x_brazil

In [None]:
# NOTE(review): undefined name; presumably a deliberate NameError to halt "Run All" here — confirm.
stop

In [None]:
# Simulate 69 days with whatever `params` and `pop` currently hold.
# NOTE(review): 69 matches the US infection_days_t0 shown earlier, while params/pop
# were last set for Brazil — confirm which country this scratch cell targets.
solution = SEIR_solve(params,
                        population=pop,
                        n_days=69, verbose=False)

In [None]:
# Predicted confirmed cases = infected (row 2) + resistant (row 3) of the solution.
infect_pred = solution[2]
resist_pred = solution[3]
confirm_pred = infect_pred+resist_pred
confirm_pred

In [None]:
# For every US row, take the 14 simulated days ending at that row's infection_days_t0.
window=14
y_pred_confirm = np.vstack( [confirm_pred[(day-window):(day)] for day in x_usa['infection_days_t0']] )

In [None]:
# True confirmed-case windows for the US (all columns starting with 'confirmed').
y_true_confirm = x_usa.filter(regex='^confirmed', axis=1).values

In [None]:
# Mean squared error between the true and simulated windows.
mean_squared_error(y_true_confirm, y_pred_confirm)

In [None]:
# NOTE(review): undefined name; presumably a deliberate NameError to halt "Run All" here — confirm.
stop

In [None]:
# Scratch check: a (min, max) bounds tuple has length 2.
R_0_bounds = (0.1, 8)
len(R_0_bounds)

In [None]:
# Rolling evaluation for Brazil: for each cut-off day, fit on the data seen so
# far, predict the next `look_forward` days and record the MSLE.
#
# NOTE(review): this cell appears to target an earlier API and cannot run against
# the definitions visible in this notebook:
#   * `fit_model` is not defined in any visible cell;
#   * SEIR_solve is called positionally as (R_0, T_inf, T_inc, pop, init_infect, n_days),
#     but the visible signature is SEIR_solve(params, population, n_days, verbose=False).
# Port it to SEIR_fit / SEIR_solve(params_dict, ...) or delete it.
country_data = get_country_data('Brazil')
init_infect = country_data['confirmed'].iloc[0]
pop=209300000

R_0_guess = 4
R_0_bounds = (0.1, 8)

T_inf_guess = 7.5
T_inf_bounds = (0.1, 15)

T_inc_guess = 0.1
T_inc_bounds = (0.1, 8)

look_back=14
look_forward=20

sample=0

metrics = pd.DataFrame()

# i is the number of days available for fitting; the target is the next look_forward days.
for i in range(look_back, (len(country_data)-look_forward+1)):
    input_data = country_data.iloc[0:i, :]
    target_data = country_data.iloc[i:i+look_forward, :]
    
    params = fit_model(data=input_data, population=pop,
                       init_guess=[R_0_guess, T_inf_guess, T_inc_guess],
                       bounds=[R_0_bounds, T_inf_bounds, T_inc_bounds],
                       eps=0.1, window=10, weighted=False, metric='mape')
    
    R_0, T_inf, T_inc = params
    solution = SEIR_solve(R_0, T_inf, T_inc, pop, init_infect, len(country_data))

    # Predicted confirmed = infected + resistant, restricted to the target window.
    infect_pred = solution[2]
    resist_pred = solution[3]
    y_pred_confirm = infect_pred+resist_pred
    y_pred_confirm = y_pred_confirm[i:i+look_forward]
    
    y_true_confirm = target_data['confirmed'].values
    
    msle = mean_squared_log_error(y_true_confirm, y_pred_confirm)
    metrics.loc[sample, 'MSLE'] = msle
    
    sample=sample+1
    
    print('Sample ', sample)
    print('Params:', params)
    print()

In [None]:
# Per-sample MSLE collected by the rolling evaluation loop.
metrics

In [None]:
# Inspect the country time series currently loaded.
country_data

In [None]:
# NOTE(review): undefined name; presumably a deliberate NameError to halt "Run All" here — confirm.
stop

# Previsão do modelo com parâmetros fixos

- ### Brasil

In [None]:
# Load Brazil's series and take the first confirmed count as the initial infected.
country_data = get_country_data('Brazil')
# Redundant but harmless: get_country_data already keeps only rows with confirmed > 0.
country_data = country_data[country_data['confirmed'] > 0]
init_infect = country_data['confirmed'].iloc[0]
pop=209300000

In [None]:
# SEIR_solve takes the parameters as one dict (keys R_0, T_inc, T_inf, N_inf);
# the former keyword form (R_0=..., n_infected=...) raises TypeError.
solution = SEIR_solve(
    {'R_0': 4.32, 'T_inc': 5.2, 'T_inf': 2.9, 'N_inf': init_infect},
    population=pop, n_days=120)

In [None]:
# Compartment curves plus predicted vs. reported confirmed cases for Brazil.
plot_model(solution, country_data, title='Predictions for Brazil')

- ### Itália

In [None]:
# Load Italy's series and take the first confirmed count as the initial infected.
country_data = get_country_data('Italy')
# Redundant but harmless: get_country_data already keeps only rows with confirmed > 0.
country_data = country_data[country_data['confirmed'] > 0]
init_infect = country_data['confirmed'].iloc[0]
pop=60480000

In [None]:
# SEIR_solve takes the parameters as one dict (keys R_0, T_inc, T_inf, N_inf);
# the former keyword form (R_0=..., n_infected=...) raises TypeError.
solution = SEIR_solve(
    {'R_0': 3.05, 'T_inc': 5.2, 'T_inf': 2.9, 'N_inf': init_infect},
    population=pop, n_days=120)

In [None]:
# Compartment curves plus predicted vs. reported confirmed cases for Italy.
plot_model(solution, country_data, title='Predictions for Italy')

- ### Espanha

In [None]:
# Load Spain's series and take the first confirmed count as the initial infected.
country_data = get_country_data('Spain')
# Redundant but harmless: get_country_data already keeps only rows with confirmed > 0.
country_data = country_data[country_data['confirmed'] > 0]
init_infect = country_data['confirmed'].iloc[0]
pop=46660000

In [None]:
# SEIR_solve takes the parameters as one dict (keys R_0, T_inc, T_inf, N_inf);
# the former keyword form (R_0=..., n_infected=...) raises TypeError.
solution = SEIR_solve(
    {'R_0': 3.21, 'T_inc': 5.2, 'T_inf': 2.9, 'N_inf': init_infect},
    population=pop, n_days=120)

In [None]:
# Compartment curves plus predicted vs. reported confirmed cases for Spain.
plot_model(solution, country_data, title='Predictions for Spain')

- ### China

In [None]:
# Load China's series and take the first confirmed count as the initial infected.
country_data = get_country_data('China')
# Redundant but harmless: get_country_data already keeps only rows with confirmed > 0.
country_data = country_data[country_data['confirmed'] > 0]
init_infect = country_data['confirmed'].iloc[0]
pop=1386000000

In [None]:
# SEIR_solve takes the parameters as one dict (keys R_0, T_inc, T_inf, N_inf);
# the former keyword form (R_0=..., n_infected=...) raises TypeError.
solution = SEIR_solve(
    {'R_0': 1.57, 'T_inc': 5.2, 'T_inf': 2.9, 'N_inf': init_infect},
    population=pop, n_days=120)

In [None]:
# Compartment curves plus predicted vs. reported confirmed cases for China.
plot_model(solution, country_data, title='Predictions for China')

In [None]:
# NOTE(review): undefined name; presumably a deliberate NameError to halt "Run All" here — confirm.
stop

In [None]:
# NOTE(review): written for an earlier SEIR_evaluate. With the visible definition,
# T_inf / T_inc / window / weighted are swallowed by **kwargs and ignored, and
# `country_data` has an 'infection_days' column rather than the 'infection_days_t0'
# column that SEIR_evaluate reads — this call is expected to fail; update or remove.
SEIR_evaluate(params=params, T_inf=2.9, T_inc=5.2, data=country_data, population=pop, window=1, weighted=False)

In [None]:
# NOTE(review): uses keyword arguments (R_0, T_inf, T_inc, n_infected) that the visible
# SEIR_solve(params, population, n_days, verbose=False) does not accept, so this raises
# TypeError. Presumably `params` was once a scalar R_0 — confirm the intent before porting.
solution = SEIR_solve(R_0=params, T_inf=2.9, T_inc=5.2,
                        population=pop, n_infected=init_infect, n_days=120)

# Predicted confirmed cases = infected + resistant.
print(solution[2]+solution[3])

In [None]:
# Plot the solution from the previous cell against the currently loaded country_data.
# NOTE(review): the title says Brazil, but `country_data` was last set to China above — confirm.
plot_model(solution, country_data, title='Predictions for Brazil')