# PySIR - Python implementation of SIR-based models

#### Author: Lucas Vilas Boas Alves <lucas.vbalves@gmail.com>

In [1]:
import pandas as pd
import numpy as np
from numbers import Number
from scipy.optimize import differential_evolution
from scipy.optimize import brute
from scipy.integrate import solve_ivp
from sklearn.metrics import mean_squared_log_error
from sklearn.metrics import mean_squared_error

### Base class for estimators of SIR-based models

In [2]:
class BaseEstimator:
    """Shared fit / evaluate / solve / predict API for SIR-family estimators.

    Subclasses assign ``self.model`` (an object such as ``SIR_model``) in
    their constructor. ``fit`` stores the fitted parameters in
    ``self.params``, which the other methods then read.
    """

    def fit(self, data, bounds=None, **kwargs):
        """Fit the model parameters to ``data``.

        When ``bounds`` is None the model's default ``bounds`` are used.
        Extra keyword arguments are forwarded to ``fit_model``.
        """
        self.bounds = self.model.bounds if bounds is None else bounds
        self.params = fit_model(self.model, data, bounds=self.bounds, **kwargs)

    def evaluate(self, data, **kwargs):
        """Return the loss of the fitted parameters on ``data`` (see ``evaluate_model``)."""
        return evaluate_model(self.model, self.params, data, **kwargs)

    def solve(self, n_days):
        """Integrate the fitted model for ``n_days``, warning on negative values."""
        return solve_model(self.model, self.params, n_days, verbose=True)

    def predict(self, data, look_forward):
        """Predict ``look_forward`` steps ahead for each row of ``data``."""
        return self.model.get_outputs(self.params, data, look_forward)

### Estimator class for the SIR model

In [3]:
class SIR(BaseEstimator):
    """Estimator backed by the SIR model; see ``BaseEstimator`` for the API."""

    def __init__(self, population):
        # Build the ODE model for this population; every estimator method
        # delegates to it.
        self.model = SIR_model(population)
        self.population = population

### Estimator class for the SEIR model

In [4]:
class SEIR(BaseEstimator):
    """Estimator backed by the SEIR model; see ``BaseEstimator`` for the API."""

    def __init__(self, population):
        # Build the ODE model for this population; every estimator method
        # delegates to it.
        self.model = SEIR_model(population)
        self.population = population

### Class containing the ordinary differential equations (ODEs), attributes and methods of the SIR model

In [5]:
class SIR_model:
    """Susceptible-Infected-Resistant (SIR) compartmental model.

    The ODE state is expressed as fractions of ``population``; ``solve_model``
    multiplies the solution by ``population`` to obtain head counts.

    Parameters (keys of ``bounds`` / ``params``):
        R_t   -- reproduction number; a number or a callable ``R_t(t, params)``
        T_inf -- sets the recovery rate ``1 / T_inf``
        N_inf -- number of infected individuals at time 0
    """

    def __init__(self, population):

        self.population = population

        # Default (min, max) bounds used by the estimator when none are given
        self.bounds = {
            'R_t': (0.1, 8),
            'T_inf': (0.1, 15),
            'N_inf': (1, 100)
        }

        # Expected keys in bounds dictionary
        self.params_names = list(self.bounds.keys())

        # Compartment names, in the same order as the ODE state vector
        self.compart_names = ['susceptible', 'infected', 'resistant']

    def get_init_state(self, params):
        """Return the state at time 0 as population fractions ``[S, I, R]``."""

        N_inf = params.get('N_inf', None)

        susceptible_0 = (self.population-N_inf)/self.population
        infected_0 = N_inf/self.population
        resistant_0 = 0

        return [susceptible_0, infected_0, resistant_0]

    def get_outputs(self, params, data, look_forward=None):
        """Solve the model and format predicted confirmed cases.

        Confirmed cases are infected + resistant. With ``look_forward=None``
        the result is ``[y_true, y_pred]`` aligned with the ``confirmed*``
        columns of ``data`` (used for loss evaluation). Otherwise it is a
        DataFrame with ``look_forward`` predicted steps for each row of
        ``data``.
        """

        # Integrate up to the largest day in `data` so that every row's window
        # fits inside the solution, even when `data` is not sorted by day.
        past_days = data['infection_days_t0'].max()

        if look_forward is None:
            n_days = past_days
        else:
            n_days = past_days + look_forward

        solution = solve_model(self, params, n_days)

        # Confirmed cases = currently infected + resistant
        confirm_pred = solution[1] + solution[2]

        return _confirmed_formatter(data, confirm_pred, look_forward)

    def ode(self, t, y, params):
        """Right-hand side of the SIR equations, as used by ``solve_ivp``.

        ``R_t`` may be a number or a callable ``R_t(t, params)`` that returns
        the reproduction number at time ``t``.
        """

        R_t = params.get('R_t', None)
        T_inf = params.get('T_inf', None)

        # Resolve a time-varying reproduction number at time t. The resolved
        # value must be used below: dividing the raw callable fails.
        if callable(R_t):
            reproduction = R_t(t, params)
        else:
            reproduction = R_t

        S, I, _R = y

        # Flow S -> I (new infections) and flow I -> R (recoveries)
        infection_flow = (reproduction / T_inf) * (I * S)
        recovery_flow = (1/T_inf) * I

        S_out = -infection_flow
        I_out = infection_flow - recovery_flow
        R_out = recovery_flow

        return [S_out, I_out, R_out]

### Class containing the ordinary differential equations (ODEs), attributes and methods of the SEIR model

In [6]:
class SEIR_model:
    """Susceptible-Exposed-Infected-Resistant (SEIR) compartmental model.

    The ODE state is expressed as fractions of ``population``; ``solve_model``
    multiplies the solution by ``population`` to obtain head counts.

    Parameters (keys of ``bounds`` / ``params``):
        R_t   -- reproduction number; a number or a callable ``R_t(t, params)``
        T_inc -- sets the E -> I transition rate ``1 / T_inc``
        T_inf -- sets the recovery rate ``1 / T_inf``
        N_inf -- number of infected individuals at time 0
    """

    def __init__(self, population):

        self.population = population

        # Default (min, max) bounds used by the estimator when none are given
        self.bounds = {
            'R_t': (0.1, 8),
            'T_inc': (0.1, 8),
            'T_inf': (0.1, 15),
            'N_inf': (1, 100)
        }

        # Expected keys in bounds dictionary
        self.params_names = list(self.bounds.keys())

        # Compartment names, in the same order as the ODE state vector
        self.compart_names = ['susceptible', 'exposed', 'infected', 'resistant']

    def get_init_state(self, params):
        """Return the state at time 0 as population fractions ``[S, E, I, R]``."""

        N_inf = params.get('N_inf', None)

        susceptible_0 = (self.population-N_inf)/self.population
        exposed_0 = 0
        infected_0 = N_inf/self.population
        resistant_0 = 0

        return [susceptible_0, exposed_0, infected_0, resistant_0]

    def get_outputs(self, params, data, look_forward=None):
        """Solve the model and format predicted confirmed cases.

        Confirmed cases are infected + resistant (exposed are not counted).
        With ``look_forward=None`` the result is ``[y_true, y_pred]`` aligned
        with the ``confirmed*`` columns of ``data`` (used for loss
        evaluation). Otherwise it is a DataFrame with ``look_forward``
        predicted steps for each row of ``data``.
        """

        # Integrate up to the largest day in `data` so that every row's window
        # fits inside the solution, even when `data` is not sorted by day.
        past_days = data['infection_days_t0'].max()

        if look_forward is None:
            n_days = past_days
        else:
            n_days = past_days + look_forward

        solution = solve_model(self, params, n_days)

        # Confirmed cases = currently infected + resistant
        confirm_pred = solution[2] + solution[3]

        return _confirmed_formatter(data, confirm_pred, look_forward)

    def ode(self, t, y, params):
        """Right-hand side of the SEIR equations, as used by ``solve_ivp``.

        ``R_t`` may be a number or a callable ``R_t(t, params)`` that returns
        the reproduction number at time ``t``.
        """

        R_t = params.get('R_t', None)
        T_inf = params.get('T_inf', None)
        T_inc = params.get('T_inc', None)

        # Resolve a time-varying reproduction number at time t. The resolved
        # value must be used below: dividing the raw callable fails.
        if callable(R_t):
            reproduction = R_t(t, params)
        else:
            reproduction = R_t

        S, E, I, _R = y

        # Flows S -> E (exposure), E -> I (incubation end), I -> R (recovery)
        exposure_flow = (reproduction / T_inf) * (I * S)
        incubation_flow = (1/T_inc) * E
        recovery_flow = (1/T_inf) * I

        S_out = -exposure_flow
        E_out = exposure_flow - incubation_flow
        I_out = incubation_flow - recovery_flow
        R_out = recovery_flow

        return [S_out, E_out, I_out, R_out]

### Function for solving the model's ordinary differential equations (ODEs)

In [7]:
def solve_model(model, params, n_days, verbose=False):
    """Integrate the model ODEs and return rounded head counts.

    Parameters
    ----------
    model : object
        Provides ``get_init_state(params)``, ``ode(t, y, params)``,
        ``population`` and ``compart_names``.
    params : dict
        Parameter values forwarded to ``get_init_state`` and ``ode``.
    n_days : int
        Integration horizon; the solution is sampled at t = 0, 1, ..., n_days - 1.
    verbose : bool, default False
        If True, print a warning for each compartment that contains a value
        that is not >= 0 (negative values or NaN).

    Returns
    -------
    numpy.ndarray
        One row per compartment, one column per sampled time: the state
        fractions scaled by ``model.population`` and rounded to integers.
    """
    result = solve_ivp(
        model.ode,
        [0, n_days],
        model.get_init_state(params),
        method='RK45',
        max_step=1,
        args=(params,),
        t_eval=np.arange(n_days),
    )

    if not result.success:
        print('WARNING! solve_ivp status {}: {}'.format(result.status, result.message))

    counts = np.around(result.y * model.population, 0)

    if verbose:
        for name, series in zip(model.compart_names, counts):
            # `>= 0` is False for NaN, so NaN values also trigger the warning.
            if not (series >= 0).all():
                print('WARNING! Negative values on the {} predictions!'.format(name))

    return counts

### Function to evaluate model error metrics

In [8]:
def evaluate_model(model, params, data, **kwargs):
    """Compute the error between observed and predicted confirmed cases.

    Parameters
    ----------
    model : object
        Exposes ``get_outputs(params, data)``, which returns ``(y_true, y_pred)``.
    params : dict
        Model parameters forwarded to ``model.get_outputs``.
    data : pandas.DataFrame
        Observations forwarded to ``model.get_outputs``.
    **kwargs
        metric : {'mse', 'rmse', 'msle'}, default 'mse'
            Error metric to compute.
        weights : array-like or None, default None
            Per-sample weights forwarded to scikit-learn as ``sample_weight``.

    Returns
    -------
    float
        The loss value. For 'rmse' on 2-D targets this is the RMSE of each
        output column averaged over columns, which is the value that
        ``mean_squared_error(..., squared=False)`` returned.

    Raises
    ------
    ValueError
        If ``metric`` is not one of the supported names.
    """

    # Getting function parameters
    metric = kwargs.get('metric', 'mse')
    weights = kwargs.get('weights', None)

    # Validate before solving the ODEs. An unknown metric previously caused
    # UnboundLocalError on `loss`.
    if metric not in ('mse', 'rmse', 'msle'):
        raise ValueError(
            "metric must be 'mse', 'rmse' or 'msle', got {!r}.".format(metric)
        )

    y_true, y_pred = model.get_outputs(params, data)

    # `sample_weight` is keyword-only in current scikit-learn; passing it
    # positionally raises TypeError.
    if metric == 'mse':
        return mean_squared_error(y_true, y_pred, sample_weight=weights)

    if metric == 'rmse':
        # `squared=False` was removed from scikit-learn. Compute the RMSE of
        # each output column, then average, which is the same value.
        per_output_mse = mean_squared_error(
            y_true, y_pred, sample_weight=weights, multioutput='raw_values'
        )
        return np.mean(np.sqrt(per_output_mse))

    return mean_squared_log_error(y_true, y_pred, sample_weight=weights)

### Function to fit the model parameters to data

In [9]:
def fit_model(model, data, bounds, **kwargs):
    """Fit the free parameters of ``model`` to ``data``.

    Parameters
    ----------
    model : object
        Exposes ``params_names`` (the expected keys of ``bounds``) and the
        ``get_outputs`` interface used by ``evaluate_model``.
    data : pandas.DataFrame
        Observations forwarded to the loss function.
    bounds : dict
        One entry per name in ``model.params_names``, in any order. Each value
        is a number (fixed), ``(min, max)`` (optimised), ``(func,
        func_kwargs)`` (time-varying function), or a ``slice`` (``brute``
        only). See ``_params_parser``.
    **kwargs
        metric : str, default 'mse'
            Loss metric forwarded to ``evaluate_model``.
        weights : array-like or None, default None
            Sample weights forwarded to ``evaluate_model``.
        method : {'diff_evol', 'brute'}, default 'diff_evol'
            Global optimiser to use.
        Ns : int, default 10
            Grid points per axis for ``brute`` when bounds are (min, max) pairs.
        workers : int, default -1
            Parallel workers for the scipy optimiser (-1 uses all CPU cores).
        popsize, mutation, recombination, seed : optional
            ``differential_evolution`` settings; the defaults are 40, 0.8,
            0.9 and 0.
        verbose : bool, default False
            Print optimiser diagnostics.

    Returns
    -------
    dict
        Best-fit values of the optimised parameters merged with the fixed ones.

    Raises
    ------
    ValueError
        If ``method`` is unknown or the keys of ``bounds`` do not match
        ``model.params_names``.
    """

    # Getting function parameters
    metric = kwargs.get('metric', 'mse')
    weights = kwargs.get('weights', None)
    method = kwargs.get('method', 'diff_evol')
    Ns = kwargs.get('Ns', 10)
    workers = kwargs.get('workers', -1)
    verbose = kwargs.get('verbose', False)
    popsize = kwargs.get('popsize', 40)
    mutation = kwargs.get('mutation', 0.8)
    recombination = kwargs.get('recombination', 0.9)
    seed = kwargs.get('seed', 0)

    # Fail fast. An unknown method would otherwise skip both optimiser
    # branches and crash with UnboundLocalError on `best_params`.
    if method not in ('diff_evol', 'brute'):
        raise ValueError(
            "method must be 'diff_evol' or 'brute', got {!r}.".format(method)
        )

    # Expected keys in bounds dictionary (the order of the keys is irrelevant)
    params_names = model.params_names

    if set(bounds.keys()) != set(params_names):
        raise ValueError('Dict keys must be {}.'.format(params_names))

    optim_params_names, optim_params_bounds, fixed_params_dict = _params_parser(bounds, method)

    if len(optim_params_bounds) == 0:
        print('No parameters to optimize. Returning fixed parameters.')
        return fixed_params_dict

    # Extra positional arguments, unpacked in this order by _loss_func_wrapper
    loss_func_args = (optim_params_names, fixed_params_dict,
                      model, data, weights, metric)

    if method == 'diff_evol':

        optim_res = differential_evolution(
            _loss_func_wrapper, bounds=optim_params_bounds,
            args=loss_func_args,
            popsize=popsize, mutation=mutation, recombination=recombination,
            # 'deferred' updating is the mode compatible with parallel
            # evaluation (workers != 1)
            updating='deferred', polish=True, workers=workers, seed=seed
        )

        if verbose:
            print('Value of objective function:', optim_res.fun)
            print('Number of evaluations of the objective function:', optim_res.nfev)
            print('Number of iterations performed by the optimizer:', optim_res.nit)
            print('Optimizer exited successfully:', optim_res.success)
            print('Message:', optim_res.message)
            print()

        best_params = dict(zip(optim_params_names, optim_res.x))

    else:  # method == 'brute'
        x_out, fval, _, _ = brute(
            _loss_func_wrapper, ranges=optim_params_bounds,
            args=loss_func_args,
            Ns=Ns, full_output=True, finish=None, workers=workers)

        if verbose:
            print('Value of objective function:', fval)

        # brute returns a scalar (not an array) when there is only one free
        # parameter; make it iterable before zipping.
        best_params = dict(zip(optim_params_names, np.atleast_1d(x_out)))

    best_params.update(fixed_params_dict)

    return best_params

### Wrapper function that adapts the optimizers' parameter vectors to the inputs of the loss evaluation function

In [10]:
def _loss_func_wrapper(params_vals, *args):
    """Objective for the scipy optimisers: map a parameter vector to a loss.

    ``params_vals`` holds the values of the optimised parameters. ``args``
    is the positional tuple built in ``fit_model``, in this order:
    (optimised parameter names, dict of fixed parameters, model, data,
    weights, metric).
    """
    optim_names, fixed, model, data = args[0], args[1], args[2], args[3]
    weights, metric = args[4], args[5]

    # Fixed parameters take precedence over same-named optimised values.
    params = {**dict(zip(optim_names, params_vals)), **fixed}

    return evaluate_model(model, params, data, weights=weights, metric=metric)

### Function to parse the bounds received by the function that fits the model parameters

In [11]:
def _params_parser(bounds, method, from_callable=False):
    
    optim_params_names = list()
    optim_params_bounds = list()
    fixed_params_dict = dict()
    
    # Checking if params are fixed or to optimize
    for param in bounds.keys():
        
        # Param bounds is a number, so it's a fixed param.
        if isinstance(bounds[param], Number):
            fixed_params_dict.update([(param, bounds[param])])
        
        # Param bounds is a sequence (min, max) or (func, func_kwargs)
        elif isinstance(bounds[param], (list, tuple)):
            
            # Can be a param to optimize (min, max) or a time varying function (func, func_kwargs)
            if len(bounds[param]) == 2:
                
                # Param is a time varying function
                if callable(bounds[param][0]):
                    
                    # Add function as a fixed param
                    fixed_params_dict.update([(param, bounds[param][0])])
                    
                    # Parse function param
                    callable_params = _params_parser(bounds[param][1], method, from_callable=True)
                    
                    # Add function params to be optmized
                    optim_params_names.extend(callable_params[0])
                    optim_params_bounds.extend(callable_params[1])
                    
                    # Add fixed function params
                    fixed_params_dict.update(callable_params[2])
                
                # Param is to be optmized
                else:
                    optim_params_names.append(param)
                    optim_params_bounds.append(bounds[param])
                    
            else:
                raise ValueError(
                    '''{} bounds has length {}. Bounds must be in the form (min,max) or (func, func_kwargs).
                    '''.format(param, len(bounds[param])),
                )
        
        # Param bounds is a slice (start, stop, step), so it's a param to optimize using 'brute'.
        elif isinstance(bounds[param], slice):
            
            if method == 'brute':
                optim_params_names.append(param)
                optim_params_bounds.append(bounds[param])
            else:
                raise ValueError(
                    '''{} bounds are in a slice object. Slice object must be used with method="brute".
                    '''.format(param),
                )
        else:
            # If param is for a callable, add as a fixed param (user defined)
            if from_callable:
                fixed_params_dict.update([(param, bounds[param])])
                
            else:
                raise TypeError(
                    '''{} bounds specified as {}. Bounds must be:
                    - A number, for a fixed parameter;
                    - A sequence in the form (min,max);
                    - A sequence in the form (func, func_kwargs);
                    - A slice object in the form (start, stop, step) if method="brute".
                    '''.format(param, type(bounds[param])),
                )
    
    return (optim_params_names, optim_params_bounds, fixed_params_dict)

### Function to format input data and model predictions for COVID-19 confirmed cases

In [12]:
def _confirmed_formatter(data, confirm_pred, look_forward):
    
    # Generates formatted outputs for the loss evaluation function
    if look_forward is None:
        y_true = data.filter(regex='^confirmed', axis=1).values
        look_back = y_true.shape[1]

        y_pred = np.vstack(
            [confirm_pred[(day-look_back):(day)] for day in data['infection_days_t0']]
        )
        return [y_true, y_pred]
    
    # Generates formatted output for the predictions
    else:
        region = data['region'].values

        y_pred = np.vstack(
            [confirm_pred[(day):(day+look_forward)] for day in data['infection_days_t0']]
        )

        confirmed_cols = ['confirmed_t+{}'.format(x) for x in range(1, look_forward+1)]
        y_pred = pd.DataFrame(y_pred, columns=confirmed_cols)
        y_pred.insert(0, column='region', value=region)

        return y_pred