In [1]:
from IPython.display import HTML

import math
import pandas as pd
import numpy as np

#from scipy.optimize import minimize
#from scipy.optimize import differential_evolution    

from lmfit import Minimizer, Parameters, report_fit

from bokeh.plotting import figure, output_notebook, show
from bokeh.models import LinearAxis, Range1d
from bokeh.models import DatetimeTickFormatter, MonthsTicker, NumeralTickFormatter, Legend


output_notebook()

In [2]:
import sys
sys.path.insert(1,'../covid-website')

import SISV as s
from data import Data
from ContactRate import contact_rate

from maxlik import fit_model, loglik_leastsquare , loglik_poisson, loglik_negbin
from maxlik import piecewiseexp_diffevol

In [3]:
#---------------------------------------------------------------------------------
#transform auxiliary values (which are bounded [0,1] into the breakpoint dates)
def aux_to_breakpoints(aux, x0, xn, minwindow):
    breakpoints = []
    bi1 = x0
    n=len(aux)
    for i,a in enumerate(aux):
        bi = a * (xn-(n-i)*minwindow - (bi1+minwindow)) + bi1+minwindow
        breakpoints.append(bi)
        bi1=bi   
    return breakpoints

#---------------------------------------------------------------------------------
def breakpoints_to_aux(breakpoints, x0, xn, minwindow):
    aux = []
    bi1 = x0
    n=len(breakpoints)
    for i,b in enumerate(breakpoints):
        temp = xn - (n-i)*minwindow - (bi1+minwindow)
        a = 1.0 if temp==0 else (b - bi1 - minwindow) / temp
        aux.append(a)
        bi1=b
    return aux  
    
#aux_to_breakpoints([0.5, 0.5], 0, 100, 10)

#---------------------------------------------------------------------------------
#b = aux_to_breakpoints([0.5,0.5,0.5], 5, 100, 10)
#a = breakpoints_to_aux(b, 5, 100, 10)
#print(b)
#print(a)
#---------------------------------------------------------------------------------


def override(key, overrides, default):
    return default if key not in overrides else overrides[key]

In [4]:
source = 'Johns Hopkins'    
region = 'Europe'
state = 'Belgium'

cutoff_positive = 1
cutoff_death = 1
truncate = 0


d = Data(source=source, region=region, state=state, county="", cutoff_positive=cutoff_positive, cutoff_death=cutoff_death, truncate=truncate) 


In [29]:
params, breakpoints, likelihood, fit  = piecewiseexp_diffevol(d.x[d.minD+1:], d.dfatalities, breaks=4, minwindow=7)

p = figure(title='Piecewise Exponential Growth', plot_width=800, plot_height=600 , y_axis_type="log")
p.y_range.start = 1

#historical data
r0 = p.line(d.xd[d.minD+1:], d.dfatalities, line_width=1, line_color='red', line_dash='dotted', alpha=0.3)
r1 = p.circle(d.xd[d.minD+1:], d.dfatalities, size=5, color="red", alpha=0.3)

#plot 7-day rolling average
rolling = pd.DataFrame(data = d.dfatalities).interpolate().rolling(7).mean()
r2 = p.line(d.xd[d.minD+1:], rolling.loc[:,0].values, line_width=1, line_color='red')

#fit
r3_1 = p.line(d.xd[d.minD+1:], fit, line_width=1, line_color='black', line_dash='solid', alpha=0.7)

print('f0:', params[0])
print('doubling:', f'{math.log(2)/params[1]:.1f}')
for i in range(len(breakpoints)):
    print(f'{breakpoints[i]:.0f}', '\t', f'{math.log(2)/params[i+2]:.1f}')
    

show(p)

f0: 1.6421362821474952
doubling: 3.0
35 	 -15.1
135 	 -223.8
186 	 11.6
254 	 -39.2


In [9]:
breaks = 5
results = []
for i in range(50):
    try:
        params, breakpoints, likelihood, fit  = piecewiseexp_diffevol(d.x[d.minD+1:], d.dfatalities, breaks=breaks, minwindow=7)
        res = {"params": params,
              "breakpoints": breakpoints,
              "likelihood": likelihood,
              "fit": fit
              }
        results.append(res)
    except:
        pass

import pickle
a_file = open("data.pkl", "wb")
pickle.dump(results, a_file)
a_file.close()

In [24]:
import numpy as np
import scipy.special

from bokeh.layouts import gridplot
from bokeh.plotting import figure, output_file, show


def make_plot(title, hist, edges):
    p = figure(title=title, tools='', background_fill_color="#fafafa")
    p.quad(top=hist, bottom=0, left=edges[:-1], right=edges[1:],
           fill_color="navy", line_color="white", alpha=0.5)

    return p

val = [res['breakpoints'][0] for res in results]
hist, edges = np.histogram(val, density=True, bins=50)

p1 = make_plot("Likelihood", hist, edges)
show(p1)

In [33]:
p = figure(y_axis_type='log')
p.y_range.start = 1




for r in results:
    p.line(d.x[d.minD+1:], r['fit'])
    
#historical data
r0 = p.line(d.x[d.minD+1:], d.dfatalities, line_width=1, line_color='red', line_dash='dotted', alpha=0.3)
r1 = p.circle(d.x[d.minD+1:], d.dfatalities, size=5, color="red", alpha=0.3)

#plot 7-day rolling average
rolling = pd.DataFrame(data = d.dfatalities).interpolate().rolling(7).mean()
r2 = p.line(d.x[d.minD+1:], rolling.loc[:,0].values, line_width=1, line_color='red')


show(p)

In [8]:
a_file = open("data.pkl", "rb")
output = pickle.load(a_file)
print(output)
a_file.close()

[{'params': array([ 1.589704  ,  0.23118989, -0.054852  , -0.00641203,  0.08508679,
       -0.01918889,  0.00711448]), 'breakpoints': [35.5322668181704, 114.24371246350833, 201.1953805289242, 249.0991151616141, 337.64886558079326], 'likelihood': 1443.0667860234007, 'fit': array([  1.589704  ,   2.00318361,   2.52420864,   3.1807515 ,
         4.00806017,   5.05055057,   6.36419115,   8.01950767,
        10.10536951,  12.73376087,  16.04579286,  20.21927937,
        25.47828342,  32.1051465 ,  40.45564665,  50.9780993 ,
        64.23742601,  80.94548358, 101.99928171, 128.52914097,
       161.95937659, 204.08476605, 257.16690574, 324.05562986,
       408.3420102 , 450.11526299, 426.09047   , 403.34799452,
       381.81939314, 361.43987564, 342.14810995, 323.88603757,
       306.59869888, 290.23406769, 274.74289472, 260.07855936,
       246.19692934, 233.05622795, 220.61690831, 208.84153434,
       197.6946681 , 187.14276314, 177.15406355, 167.69850838,
       158.74764117, 150.27452433,

In [34]:

breakpoint_guess = breakpoints_to_aux(np.add(breakpoints, 0), d.x[0], d.x[-1], 7)
print(breakpoint_guess)

[0.08720848262357278, 0.3201128853183471, 0.22685962602284696, 0.3970362340690371]


In [35]:
#---------------------------------------------------------------------------------
def SISV_lmfit_fixedt(d, overrides, solver, breakpoint_guess):

    #--------------------------------------
    def merge_params(params, constants, solve_ti=False):
    
        p = params.valuesdict()  #params should be a LMFIT Parameters instance; constants should be a dict
        for i, (k,v) in enumerate(constants.items()):
            if k not in p:  #do not override values that may already be in the Parameters array
                p[k]=v
    
        if solve_ti:  #we are solving for time breakpoints
            #calculate "ti" variables from "auxi" variables
            n = p['segments']
            
            aux = []
            for i in range(1, n+1):
                aux.append(params['aux{}'.format(i-1)])
                
            breakpoints = aux_to_breakpoints(aux, d.x[0], d.x[-1], minwindow)
            
            for i in range(1,n+1):
                #p['beta{}'.format(i)] = params[len(param_list)+i-1]
                p['t{}'.format(i)] = breakpoints[i-1]
                
        return p


    
    #--------------------------------------
    def lmfit_inner(params, x, constants, column, data=None):
        p = merge_params(params, constants, solve_ti=True if column==s.cF else False)   #solve for time breakpoints Ti through Auxi variables when using fatalities data
        yhat = s.SISV_J(x, p)
        if data is None:
            return yhat
        else:
            return yhat[:,column] - data


    #--------------------------------------
    #first stage: calibrate initial infectious population and contact rate over time on fatalities data
    #--------------------------------------
    
        
    gamma = override('gamma', overrides, 1/3)
    segments = override('segments', overrides, 7)
    minwindow = override('minwindow', overrides, 7)

    params = Parameters()
    #(name, value, vary, min, max, expr)
    params.add_many( 
                     ('exp_stages',    1, False),
                     ('inf_stages',    1, False),
                     ('crit_stages',   1, False),
                     ('test_stages',   1, False),
                     
                     ('gamma_exp',      gamma, False),
                     ('gamma',          gamma, False),
                     ('gamma_pos',      1/14, False),
                     ('gamma_crit',     1/14, False),
                     
                     ('death_rate',     0.5e-2, False),
                     ('detection_rate', 5e-2, False),

                     ('population',     d.population, False),
                     ('i0',             1, True, 1, 100000),


                     ('immun',          0, False),
                     ('vacc_start',     365, False),
                     ('vacc_rate',      0, False),
                     ('vacc_immun',     1/180, False),
        
                     ('segments',       segments         , False),

                     ('beta0',          2*gamma       , True, 0.01*gamma, 8*gamma),                 
                   )

    for i in range(1, segments+1):  
        params.add('aux{}'.format(i-1),value=breakpoint_guess[i-1], vary=False, min=0.9*breakpoint_guess[i-1], max=1.1*breakpoint_guess[i-1])
        params.add('beta{}'.format(i), value= 0.8*gamma, vary=True, min=0.01*gamma, max=8*gamma)

    #lmfit Parameters cannot accept string values so they get passed in a separate argument
    constants = { 
        'interv':'piecewise linear',
        'init_beta':'',
    }

    for idx, (k,v) in enumerate(overrides.items()):
        if k in params:
            params[k].set(value=v)
        else:
            constants[k]=v
            
    fitter = Minimizer(lmfit_inner, params, fcn_args=(d.x, constants, s.cF, d.fatalities))
    result = fitter.minimize(method=solver)

    p = merge_params(result.params, constants, solve_ti=True)  #merge the calibrated variables into the dictionary of params
    #print(p)    
    
    #-------------------------
    #second stage: calibrate detection rate on positive test results data
    #------------------------

    params = Parameters()
    #(name, value, vary, min, max, expr)
    params.add_many( 
                     ('detection_rate', 15e-2, True, 0, 1),
                   )

    fitter = Minimizer(lmfit_inner, params, fcn_args=(d.x, p, s.cP, d.positives))
    result = fitter.minimize(method=solver)

    p = merge_params(result.params, p, solve_ti=False) #merge the calibrated variables into the dictionary of params




    
    #yy = s.SISV(d.x, p)
    return p
    


In [36]:
%%time

p1 = SISV_lmfit_fixedt(d, solver='leastsq',
                  overrides={
                      'exp_stages'  : 1,
                      'inf_stages'  : 1,
                      'crit_stages' : 1,
                      'test_stages' : 1,
                      'death_rate'  : 0.73e-2,
                      'gamma_exp'   : 1/4,
                      'gamma'       : 1/4,
                      'gamma_crit'  : 1/17, #1/(17+21),
                      'gamma_pos'   : 1/15,
                      'interv'      :'piecewise constant',
                      'init_beta'   : '',
                      'minwindow'   : 7,
                      'segments'    : len(breakpoint_guess),
                  }, breakpoint_guess=breakpoint_guess)



CPU times: user 131 ms, sys: 4 ms, total: 135 ms
Wall time: 153 ms


In [37]:
#---------------------------------------------------------------------------------
def SISV_lmfit(d, overrides, solver, breakpoint_guess):

    #--------------------------------------
    def merge_params(params, constants, solve_ti=False):
    
        p = params.valuesdict()  #params should be a LMFIT Parameters instance; constants should be a dict
        for i, (k,v) in enumerate(constants.items()):
            if k not in p:  #do not override values that may already be in the Parameters array
                p[k]=v
    
        if solve_ti:  #we are solving for time breakpoints
            #calculate "ti" variables from "auxi" variables
            n = p['segments']
            
            aux = []
            for i in range(1, n+1):
                aux.append(params['aux{}'.format(i-1)])
                
            breakpoints = aux_to_breakpoints(aux, d.x[0], d.x[-1], minwindow)
            
            for i in range(1,n+1):
                #p['beta{}'.format(i)] = params[len(param_list)+i-1]
                p['t{}'.format(i)] = breakpoints[i-1]
                
        return p


    #--------------------------------------
    def lmfit_inner_fixedt(params, x, constants, column, data=None):
        #print("--------")
        #print('lmfit_inner_fixedt():') 
        #params.pretty_print()
        p = merge_params(params, constants, solve_ti=True)   #solve for beta
        yhat = s.SISV_J(x, p)
        if data is None:
            return yhat
        else:
            return yhat[:,column] - data
    
    
    def lmfit_inner2(params, x, constants, column, data=None):
        #set up a new solver to find the beta for the given times
        #print("--------")
        #print('lmfit_inner2():')
        #params.pretty_print()
        
        p = params.copy()
        p['i0'].set(vary=True)
        p['beta0'].set(vary=True)
        for i in range(1, segments+1):  
            p['aux{}'.format(i-1)].set(vary=False)
            p['beta{}'.format(i)].set(vary=True)
    
        
        fitter = Minimizer(lmfit_inner_fixedt, p, fcn_args=(d.x, constants, s.cF, d.fatalities))
        result = fitter.minimize(method="leastsq")
    
        return lmfit_inner_fixedt(result.params, d.x, constants, s.cF, data)
    
    #--------------------------------------
    def lmfit_inner(params, x, constants, column, data=None):
        p = merge_params(params, constants, solve_ti=True if column==s.cF else False)   #solve for time breakpoints Ti through Auxi variables when using fatalities data
        yhat = s.SISV_J(x, p)
        if data is None:
            return yhat
        else:
            return yhat[:,column] - data


    #--------------------------------------
    #first stage: calibrate initial infectious population and contact rate over time on fatalities data
    #--------------------------------------
    
        
    gamma = override('gamma', overrides, 1/3)
    segments = override('segments', overrides, 7)
    minwindow = override('minwindow', overrides, 7)

    params = Parameters()
    #(name, value, vary, min, max, expr)
    params.add_many( 
                     ('exp_stages',    1, False),
                     ('inf_stages',    1, False),
                     ('crit_stages',   1, False),
                     ('test_stages',   1, False),
                     
                     ('gamma_exp',      gamma, False),
                     ('gamma',          gamma, False),
                     ('gamma_pos',      1/14, False),
                     ('gamma_crit',     1/14, False),
                     
                     ('death_rate',     0.5e-2, False),
                     ('detection_rate', 5e-2, False),

                     ('population',     d.population, False),
                     ('i0',             1, False, 1, 100000),


                     ('immun',          0, False),
                     ('vacc_start',     365, False),
                     ('vacc_rate',      0, False),
                     ('vacc_immun',     1/180, False),
        
                     ('segments',       segments         , False),

                     ('beta0',          2*gamma       , False, 0.01*gamma, 8*gamma),                 
                   )

    for i in range(1, segments+1):  
        params.add('aux{}'.format(i-1),value=0.1, vary=True, min=breakpoint_guess[i-1]*0.9, max=breakpoint_guess[i-1]*1.1)
        params.add('beta{}'.format(i), value= 0.8*gamma, vary=False, min=0.1*gamma, max=8*gamma)

    #lmfit Parameters cannot accept string values so they get passed in a separate argument
    constants = { 
        'interv':'piecewise linear',
        'init_beta':'',
    }

    for idx, (k,v) in enumerate(overrides.items()):
        if k in params:
            params[k].set(value=v)
        else:
            constants[k]=v
            
    fitter = Minimizer(lmfit_inner2, params, fcn_args=(d.x, constants, s.cF, d.fatalities))
    result = fitter.minimize(method=solver)

    p = merge_params(result.params, constants, solve_ti=True)  #merge the calibrated variables into the dictionary of params
    #print(p)    
    #-------------------------
    #second stage: calibrate detection rate on positive test results data
    #------------------------

    params = Parameters()
    #(name, value, vary, min, max, expr)
    params.add_many( 
                     ('detection_rate', 15e-2, True, 0, 1),
                   )

    fitter = Minimizer(lmfit_inner, params, fcn_args=(d.x, p, s.cP, d.positives))
    result = fitter.minimize(method=solver)

    p = merge_params(result.params, p, solve_ti=False) #merge the calibrated variables into the dictionary of params




    
    #yy = s.SISV(d.x, p)
    return p
    


In [38]:
%%time

p2 = SISV_lmfit(d, solver='differential_evolution',
                  overrides={
                      'exp_stages'  : 1,
                      'inf_stages'  : 1,
                      'crit_stages' : 1,
                      'test_stages' : 1,
                      'death_rate'  : 0.73e-2,
                      'gamma_exp'   : 1/4,
                      'gamma'       : 1/4,
                      'gamma_crit'  : 1/17, #1/(17+21),
                      'gamma_pos'   : 1/15,
                      'interv'      :'piecewise constant',
                      'init_beta'   : '',
                      'minwindow'   : 7,
                      'segments'    : len(breakpoint_guess),
                  }, breakpoint_guess=breakpoint_guess)





CPU times: user 29min 25s, sys: 112 ms, total: 29min 25s
Wall time: 29min 25s


In [39]:
y1 = s.SISV_J(d.x, p1) 
y2 = s.SISV_J(d.x, p2) 

def print_params(title, params):
    print('----------')
    print(title)
    print("detect:", params["detection_rate"])
    print("i0:", params["i0"])
    print("R0:", params['beta0']/params['gamma'])
    for i in range(1, 1+params['segments']):
          print(params['t{}'.format(i)],":", params['beta{}'.format(i)]/params['gamma'])

print_params("p1", p1)

print_params("p2", p2)



p = figure(title='Excess Fatalities', plot_width=800, plot_height=600 , y_axis_type="log")
p.y_range.start = 1

r0 = p.line(d.xd[d.minD+1:], d.dfatalities, line_width=1, line_color='red', line_dash='dotted', alpha=0.3)
r1 = p.circle(d.xd[d.minD+1:], d.dfatalities, size=5, color="red", alpha=0.3)

#plot 7-day rolling average
rolling = pd.DataFrame(data = d.dfatalities).interpolate().rolling(7).mean()
r2 = p.line(d.xd[d.minD+1:], rolling.loc[:,0].values, line_width=1, line_color='red')

r3 = p.line(d.xd[d.minD+1:], fit, line_width=1, line_color='black', line_dash='solid', alpha=0.7)

r4 = p.line(d.xd[d.minD+1:], np.diff(y1[d.minD:,s.cF]), line_width=1, line_color='green', line_dash='solid', alpha=0.7)

r5 = p.line(d.xd[d.minD+1:], np.diff(y2[d.minD:,s.cF]), line_width=1, line_color='blue', line_dash='solid', alpha=0.7)

#-------------

r0_p = p.line(d.xd[d.minP+1:], d.dpositives, line_width=1, line_color='red', line_dash='dotted', alpha=0.3)
r1_p = p.circle(d.xd[d.minP+1:], d.dpositives, size=5, color="red", alpha=0.3)

r4_p = p.line(d.xd[d.minP+1:], np.diff(y1[d.minP:,s.cP]), line_width=1, line_color='green', line_dash='dotted', alpha=0.7)
r5_p = p.line(d.xd[d.minP+1:], np.diff(y2[d.minP:,s.cP]), line_width=1, line_color='blue', line_dash='dotted', alpha=0.7)

legend = [
    ("COVID fatalities"   , [r0, r1]),
    ("COVID 7-day average"   , [r2]),
    
    ("expgrowth", [r3]),
    ("SISV_guess", [r4]),
    ("SISV_fit", [r5]),

#    ("COVID positives"   , [r0_p, r1_p]),
#    ("p1_p", [r3_1_p]),
]

    
p.add_layout(Legend(items=legend, location='center'), 'right')
   
p.yaxis[0].formatter = NumeralTickFormatter(format="0,0")

p.xaxis.ticker = MonthsTicker(months=list(range(1,13)))

p.xaxis.formatter=DatetimeTickFormatter(
        hours=["%d %B %Y"],
        days=["%d %B %Y"],
        months=["%d %B %Y"],
        years=["%d %B %Y"],
    )

p.xgrid.ticker = p.xaxis.ticker

p.xaxis.major_label_orientation = math.pi/4


p.extra_x_ranges['x2'] = Range1d(d.x[0], d.x[-1])
ax2 = LinearAxis(x_range_name="x2", axis_label="days")
p.add_layout(ax2, 'below')


p.ygrid.grid_line_color = 'navy'
p.ygrid.grid_line_alpha = 0.3
p.ygrid.minor_grid_line_color = 'navy'
p.ygrid.minor_grid_line_alpha = 0.1

p.yaxis[0].axis_label = 'Number of deaths per day'

#p.toolbar_location = None
show(p)

rplot = figure(title='R0')
R0_1 = contact_rate(d.x, p1) / p1['gamma']
rplot.line(d.x, R0_1, line_width=1, line_color='black', line_dash='solid', alpha=1)
show(rplot)


----------
p1
detect: 0.9999999955635523
i0: 209.2266682486937
R0: 2.000015732129831
34.732297474296146 : 0.8000122889188036
134.65072924452687 : 0.8002687852047449
186.42131100399507 : 0.8012424806215785
254.0005790821597 : 0.8131063652499033
----------
p2
detect: 1.0
i0: 1
R0: 2.0
32.66083241958919 : 0.8
124.49211692632124 : 0.8
174.0132328795381 : 0.8
239.99081176848853 : 0.8
