# PROJECT FOR DATA SCIENCE THEORY AND PRACTICE
## EPIDEMIOLOGICAL MODELS 
## STOCKHOLM UNIVERSITY


#### Importing libraries

In [None]:
import numpy as np
import pandas as pd
pd.options.mode.chained_assignment = None  # default='warn'

import matplotlib.pyplot as plt
import matplotlib.dates as mdates
%matplotlib inline 
!pip install mpld3
import mpld3
mpld3.enable_notebook()

from scipy.integrate import odeint
!pip install lmfit
import lmfit
from lmfit.lineshapes import gaussian, lorentzian

import warnings
warnings.filterwarnings('ignore')


## Definition of the plot function

In [251]:
plt.gcf().subplots_adjust(bottom=0.15)


def _apply_axis_style(axis):
    """Give *axis* the notebook's shared look: no tick marks, white grid, no spines."""
    axis.yaxis.set_tick_params(length=0)
    axis.xaxis.set_tick_params(length=0)
    # The visibility flag is passed positionally: the old ``b=`` keyword was
    # renamed to ``visible=`` in newer Matplotlib releases.
    axis.grid(True, which='major', c='w', lw=2, ls='-')
    for spine in ('top', 'right', 'bottom', 'left'):
        axis.spines[spine].set_visible(False)


def plot_it(t,S,E,I,R,D=None,R0=None,x_ticks=None,Q=None,Alpha=None,CFR=None):
    """Plot the compartments of an SEIR-type simulation.

    The main figure shows S, E, I and R, plus either the dead (``D``) or the
    quarantined (``Q``) compartment when given (``D`` wins if both are given),
    and a dashed 'Total' curve that should stay flat as a sanity check.
    When ``R0`` and/or ``Alpha`` are given they are drawn side by side in a
    second figure.

    Parameters
    ----------
    t : sequence of time points (days) used as x-axis when ``x_ticks`` is None.
    S, E, I, R : arrays with one value per time point.
    D, Q : optional arrays for the dead / quarantined compartment.
    R0 : optional sequence with R_0 per time point (plotted against ``t``).
    x_ticks : optional x values (e.g. dates) used instead of ``t`` in the
        main figure.
    Alpha : optional sequence with the fatality rate per time point.
    CFR : accepted for backward compatibility; it is not plotted.

    Notes
    -----
    The printed percentages are relative to the module-level ``N_sv``.
    """
    x = t if x_ticks is None else x_ticks

    f, ax = plt.subplots(1,1,figsize=(10,4))

    print("percentage of Infected", (int(np.amax(I))/N_sv)*100)

    ax.plot(x, S, 'b', alpha=0.7, linewidth=2, label='Susceptible')
    ax.plot(x, E, 'y', alpha=0.7, linewidth=2, label='Exposed')
    ax.plot(x, I, 'r', alpha=0.7, linewidth=2, label='Infected')
    ax.plot(x, R, 'g', alpha=0.7, linewidth=2, label='Recovered')

    # Sum of all plotted compartments; it should equal the population at every
    # time step, so the dashed line should be straight.
    total = S+E+I+R
    if D is not None:
        print("percentage of Deceased", (int(np.amax(D))/N_sv)*100)
        ax.plot(x, D, 'k', alpha=0.7, linewidth=2, label='Dead')
        total = total+D
    elif Q is not None:
        print("percentage of Quarantined", (int(np.amax(Q))/N_sv)*100)
        ax.plot(x, Q, 'm', alpha=0.7, linewidth=2, label='Quarantined')
        total = total+Q
    ax.plot(x, total, 'c--', alpha=0.7, linewidth=2, label='Total')

    ax.set_xlabel('Time (days)')
    _apply_axis_style(ax)
    legend = ax.legend(borderpad=2.0)
    legend.get_frame().set_alpha(0.5)

    if x_ticks is None:
        # Month labels only make sense on the day-index axis; forcing them on
        # caller-supplied x values (e.g. dates) would misplace every tick.
        # NOTE(review): plotter() labels these same positions starting at
        # 'Jan' - confirm which offset is intended.
        month_starts = [1,32,61,92,122,153,183,214,245,275,306,336]
        month_names = ['Feb','Mar','Apr','May','Jun',
                   'Jul','Aug','Sep','Oct','Nov','Dec','Jan']
        ax.set_xticks(month_starts)
        ax.set_xticklabels(month_names)

    if R0 is not None or Alpha is not None:
        # Both optional panels live on their own figure, created before
        # anything is shown so that neither panel is lost.
        f2 = plt.figure(figsize=(12,4))

        if R0 is not None:
            ax1 = f2.add_subplot(121)
            ax1.plot(t, R0, 'b--', alpha=0.7, linewidth=2, label='R_0')
            ax1.set_xlabel('Time (days)')
            ax1.title.set_text('R_0 over time')
            _apply_axis_style(ax1)
            legend = ax1.legend()
            legend.get_frame().set_alpha(0.5)
            f2.subplots_adjust(bottom=0.15)

        if Alpha is not None:
            ax2 = f2.add_subplot(122)
            ax2.plot(t, Alpha, 'r--', alpha=0.7, linewidth=2, label='alpha')
            ax2.set_xlabel('Time (days)')
            ax2.title.set_text('fatality rate over time')
            _apply_axis_style(ax2)
            legend = ax2.legend()
            legend.get_frame().set_alpha(0.5)
            f2.autofmt_xdate()

    plt.show();

   

## Reading Data and preprocessing

In [None]:
#You can find the data from here:  https://data.humdata.org/dataset/novel-coronavirus-2019-ncov-cases
def _load_long_format(csv_path, value_name):
    """Read one wide time-series CSV and return it in long format.

    The wide file has one row per region and one column per date. The result
    has one row per (region, date) with the count stored in *value_name*,
    ordered by country (ascending) and date (newest first), with a fresh
    0..n-1 index and 'Country/Region' as the first column.
    """
    id_columns = ["Province/State", "Country/Region", "Lat", "Long"]
    long_df = (pd.read_csv(csv_path)
               .set_index(id_columns)
               .stack()
               .reset_index(name=value_name)
               # The stacked (unnamed) level holding the former date headers
               # comes back as 'level_4' because there are four id columns.
               .rename(columns={'level_4': 'Date'}))
    long_df['Date'] = pd.to_datetime(long_df['Date'])
    # A plain two-key sort gives the per-country, newest-first ordering without
    # groupby.apply, whose handling of the grouping column is deprecated.
    long_df = long_df.sort_values(by=["Country/Region", "Date"], ascending=[True, False])
    column_order = ["Country/Region"] + [c for c in long_df.columns if c != "Country/Region"]
    return long_df[column_order].reset_index(drop=True)


covid_data_confirmed = _load_long_format(".../time_series_covid19_confirmed_global.csv", "Confirmed")
Confirmed = covid_data_confirmed.Confirmed

covid_data_deaths = _load_long_format(".../time_series_covid19_deaths_global.csv", "Deaths")

# The two frames are glued together by row position.
# NOTE(review): this assumes both CSV files list the same regions and dates in
# the same order - confirm, or switch to a key-based merge.
covid = pd.concat([covid_data_deaths, Confirmed], axis=1, sort=True)
covid.rename(columns={'Country/Region': 'Location'}, inplace=True)


In [None]:
# Per-country slices of the combined table, reduced to the columns used below.
_country_columns = ["Deaths", "Confirmed", "Date"]
dataItaly = covid.loc[covid["Location"] == "Italy", _country_columns]
dataSweden = covid.loc[covid["Location"] == "Sweden", _country_columns]

In [None]:
#statistics about sweden age groups and beds: https://www.scb.se/en/finding-statistics/statistics-by-subject-area/population/population-composition/population-statistics/

# Single-row table ('population') with the number of inhabitants per
# ten-year age band; one column per band.
agegroupsSV = pd.DataFrame(
    [[1194000., 1127000., 1277000., 1320000., 1264000., 1296000.,
      1094000.,  994000.,  432000.,   98000.]],
    index=['population'],
    columns=['0-10', '10-20', '20-30', '30-40', '40-50',
             '50-60', '60-70', '70-80', '80-90', '90-100'],
)

In [None]:
    # Population of Sweden: three age classes (columns 0-1, 2-5 and 6-9 of
    # agegroupsSV, i.e. 0-19, 20-59 and 60+) and the overall total.
    # .iloc[0] picks the single 'population' row by position; a bare [0] on
    # this label-indexed Series relies on a deprecated positional fallback.
    N1 = agegroupsSV.iloc[:,0:2].sum(axis=1).iloc[0]
    N2 = agegroupsSV.iloc[:,2:6].sum(axis=1).iloc[0]
    N3 = agegroupsSV.iloc[:,6:10].sum(axis=1).iloc[0]
    N_sv = agegroupsSV.sum(axis=1).iloc[0]

In [None]:
# Population of Italy.
N_it = 60_360_000

In [None]:
# Sweden: cumulative fatalities and confirmed cases as a share of the population.
f, ax = plt.subplots(1,1,figsize=(8,4))
# Adjust the figure that was just created; calling plt.gcf() before
# plt.subplots() would spawn (and adjust) a separate empty figure.
f.subplots_adjust(bottom=0.15)

#ax.plot(dataItaly.groupby(["Date"]).sum()[["Deaths"]]/N_it, 'r', alpha=0.7, linewidth=2, label='Italy_Fatalities')
#ax.plot(dataItaly.groupby(["Date"]).sum()[["Confirmed"]]/N_it, 'g', alpha=0.7, linewidth=2, label='Italy_Confirmed')

# Aggregate once per date and reuse the result for both curves.
sweden_by_date = dataSweden.groupby(["Date"]).sum()
ax.plot(sweden_by_date[["Deaths"]]/N_sv, 'b', alpha=0.7, linewidth=2, label='Sweden_Fatalities')
ax.plot(sweden_by_date[["Confirmed"]]/N_sv, 'y', alpha=0.7, linewidth=2, label='Sweden_Confirmed')

ax.set_yscale('linear')
plt.style.use('fivethirtyeight')

ax.title.set_text('Sweden Fatalities and Total Cases')
# Visibility flag passed positionally: the ``b=`` keyword was renamed to
# ``visible=`` in newer Matplotlib releases.
ax.grid(True, which='major', c='w', lw=2, ls='-')
legend = ax.legend()
legend.get_frame().set_alpha(0.5)
for spine in ('top', 'right', 'bottom', 'left'):
    ax.spines[spine].set_visible(False)
f.autofmt_xdate()

plt.show();

# SEIRD

In [None]:
def deriv_seird(y,t,beta,gamma,N,delta,alpha,rho):
    """Right-hand side of the SEIRD system, in the form expected by odeint.

    y is the state (S, E, I, R, D) and t the current time. beta is a callable
    returning the transmission rate at time t, gamma the recovery rate, N the
    population size, delta the rate of leaving the exposed state, alpha the
    fatality rate and rho the rate at which people die.
    Returns the tuple (dS/dt, dE/dt, dI/dt, dR/dt, dD/dt).
    """
    susceptible, exposed, infected, _recovered, _dead = y

    newly_exposed = beta(t) * susceptible * infected / N
    newly_infectious = delta * exposed
    recoveries = (1 - alpha) * gamma * infected   # share that survives, leaving at rate gamma
    deaths = alpha * rho * infected               # share that dies, leaving at rate rho

    return (
        -newly_exposed,
        newly_exposed - newly_infectious,
        newly_infectious - recoveries - deaths,
        recoveries,
        deaths,
    )

In [None]:
import datetime

D = 10.0  # duration of the infectious period in days (gamma = 1/D)
gamma = 1.0 / D
delta = 1.0 / 5.0  # incubation period of five days

rho = 1/9  # 9 days from infection until death
alpha = 0.01  #according to: https://www.folkhalsomyndigheten.se/contentassets/53c0dc391be54f5d959ead9131edb771/infection-fatality-rate-covid-19-stockholm-technical-report.pdf
L=60 #starting day of lockdown

def Model(days):
    """Simulate the SEIRD model for Sweden over *days* daily time steps.

    R_0 is a step function: 5.7 before day ``L`` and 2.5 from day ``L`` on;
    the transmission rate is beta(t) = R_0(t) * gamma. The run starts with a
    single exposed person in an otherwise susceptible population.

    Returns (t, S, E, I, R, D, R_0_over_time); the last entry is always None,
    which makes plot_it skip the R_0 panel.
    """
    def R_0(t):
        return 5.7 if t < L else 2.5

    def beta(t):
        return R_0(t) * gamma

    # Total population of Sweden. .iloc[0] selects the single 'population'
    # row by position; a bare [0] relies on a deprecated positional fallback.
    N = agegroupsSV.sum(axis=1).iloc[0]

    # Initial state (S, E, I, R, D): one exposed person, everyone else susceptible.
    y0 = N-1.0, 1.0, 0.0, 0.0, 0.0
    t = np.linspace(0, days-1, days)

    ret = odeint(deriv_seird, y0, t, args=(beta, gamma, N, delta, alpha, rho))
    S, E, I, R, D = ret.T
    R_0_over_time=None

    return t, S, E, I, R, D, R_0_over_time

# Example run: R_0 drops from 5.7 to 2.5 at day L, i.e. a not-so-strict lockdown.
plot_it(*Model(days=360))


# SEIR WITH QUARANTINE



In [None]:
def deriv_seirdq(y,t,beta,gamma,N,delta,rho, pi, lamdal):
    """Right-hand side of the SEIR model with a quarantine compartment.

    y is the state (S, E, I, R, Q). beta is a constant transmission rate,
    gamma the recovery rate, N the population size, delta the rate of leaving
    the exposed state, pi the quarantine entrance rate and lamdal the
    quarantine exit rate. t and rho are accepted but not used.
    Returns the tuple (dS/dt, dE/dt, dI/dt, dR/dt, dQ/dt).
    """
    susceptible, exposed, infected, _recovered, quarantined = y

    newly_exposed = beta * susceptible * infected / N
    entering_quarantine = pi * susceptible
    leaving_quarantine = lamdal * quarantined
    newly_infectious = delta * exposed
    recoveries = gamma * infected

    return (
        -newly_exposed - entering_quarantine + leaving_quarantine,
        newly_exposed - newly_infectious,
        newly_infectious - recoveries,
        recoveries,
        entering_quarantine - leaving_quarantine,
    )

If pi = lamda = 0, the model reduces to the plain SEIR model (this variant has no death compartment).


pi quarantine entrance rate


lamda quarantine exit rate

The quarantine formulation comes from this Brazilian paper: https://www.ufpe.br/documents/2744135/2765338/covid_quarentena.pdf/d0115080-ddea-466a-8a0f-22e4304f761a


Quarantines will be characterized by two values: the entrance rate
p and the exit rate λ. p is composed of two terms, γq and ξ. γq is the
average time it takes for a person to enter quarantine and ξ
is a dimensionless multiplicative factor representing the percentage of
individuals that in fact voluntarily quarantine. With this notation


$p = \frac{\xi}{\gamma_q}$


As an example, suppose that 70% of the population quarantine in an
interval of 2 days. Then p = 0.70/2 = 0.35. It will be assumed that
p ∈ (0.0 , 0.40) . 

p = 0 means that there is no quarantine. It
will be assumed that the time to leave quarantine is between 30 and
60 days, giving λ ∈ (1/60 , 1/30).



In [None]:
# Population per age class, taken from the agegroupsSV table defined above.
#age groups: 0-19, 20-59, 60+
# .iloc[0] selects the single 'population' row by position; a bare [0] on
# this label-indexed Series relies on a deprecated positional fallback.
N1 = agegroupsSV.iloc[:,0:2].sum(axis=1).iloc[0]
N2 = agegroupsSV.iloc[:,2:6].sum(axis=1).iloc[0]
N3 = agegroupsSV.iloc[:,6:10].sum(axis=1).iloc[0]
N = agegroupsSV.sum(axis=1).iloc[0]

In [None]:

D = 10  # duration of the infectious period in days (gamma = 1/D)
gamma = 1.0 / D
delta = 1.0 / 5.0  # incubation period of five days
beta = 0.85 #from fhm



def Model_Q(days, xi, g, time_to_leave_quarantine):
    """Simulate the SEIR model with quarantine for Sweden.

    Parameters
    ----------
    days : number of daily time steps to simulate.
    xi, g : the quarantine entrance rate is ``pi = xi / g``.
    time_to_leave_quarantine : the quarantine exit rate is its reciprocal.

    Returns
    -------
    (t, S, E, I, R, None, None, None, Q); the three ``None`` entries fill the
    D, R0 and x_ticks positions of plot_it so the result can be unpacked into it.
    """
    pi=xi/g
    lamdal = 1/time_to_leave_quarantine

    # Total population of Sweden. .iloc[0] selects the single 'population'
    # row by position; a bare [0] relies on a deprecated positional fallback.
    N = agegroupsSV.sum(axis=1).iloc[0]

    # Initial state (S, E, I, R, Q): one exposed person, nobody quarantined.
    y0 = N-1.0, 1.0, 0.0, 0.0, 0.0
    t = np.linspace(0, days-1, days)

    # rho is the module-level value set for the SEIRD model; deriv_seirdq
    # accepts it but does not use it.
    ret = odeint(deriv_seirdq, y0, t, args=(beta, gamma, N, delta, rho, pi, lamdal))
    S, E, I, R, Q = ret.T

    return t, S, E, I, R, None, None, None, Q


In [None]:
# Example run of the quarantine model with hand-picked parameters.
plot_it(
    *Model_Q(
        days=360,
        xi=0.5,
        g=15,
        time_to_leave_quarantine=30,
    )
)


# SEIR AGE STRUCTURED

In [None]:
def deriv_seir_age_2(y,t,beta1_1,beta1_2,beta1_3,beta2_2,beta2_3,beta3_3,gamma,N1,N2,N3,delta,rho):
    """Right-hand side of the SEIR model with three age classes.

    y holds (S, E, I, R) for class 1, then class 2, then class 3. The six
    beta parameters are the entries of a symmetric 3x3 transmission matrix
    (beta_i_j couples classes i and j). gamma is the recovery rate, delta the
    rate of leaving the exposed state and N1, N2, N3 the class sizes.
    t and rho are accepted but not used.

    Returns the twelve derivatives in the same order as y.
    """
    S1, E1, I1, R1,  S2, E2, I2, R2, S3, E3, I3, R3= y

    # Total population, built from the class sizes passed in rather than from
    # a module-level N, so the function works on its own.
    N = N1 + N2 + N3

    def new_exposures(S, beta_a, beta_b, beta_c):
        # Infection pressure from all three classes on a class with S
        # susceptibles, using that class's row of the transmission matrix.
        return ((beta_a*I1) + (beta_b*I2) + (beta_c*I3))/N * S

    exposed1 = new_exposures(S1, beta1_1, beta1_2, beta1_3)
    exposed2 = new_exposures(S2, beta1_2, beta2_2, beta2_3)
    exposed3 = new_exposures(S3, beta1_3, beta2_3, beta3_3)

    dS1dt, dS2dt, dS3dt = -exposed1, -exposed2, -exposed3

    dE1dt = exposed1 - delta * E1
    dE2dt = exposed2 - delta * E2
    dE3dt = exposed3 - delta * E3

    dI1dt = delta * E1 - gamma * I1
    dI2dt = delta * E2 - gamma * I2
    dI3dt = delta * E3 - gamma * I3

    dR1dt = gamma * I1
    dR2dt = gamma * I2
    dR3dt = gamma * I3

    return dS1dt, dE1dt, dI1dt, dR1dt, dS2dt, dE2dt, dI2dt, dR2dt, dS3dt, dE3dt, dI3dt, dR3dt
    
    
    

In [None]:

D = 10.0  # duration of the infectious period in days (gamma = 1/D)
gamma = 1.0 / D
delta = 1.0 / 5.0  # incubation period of five days


alpha_opt = 0.01


def Model_seir_age_2(days, beta1_1,beta1_2, beta1_3, beta2_2, beta2_3,beta3_3):
    """Simulate the age-structured SEIR model for Sweden.

    The population is split into three classes (columns 0-1, 2-5 and 6-9 of
    agegroupsSV). The six beta arguments are the entries of the symmetric
    transmission matrix handed to deriv_seir_age_2. The run starts with a
    single exposed person in the second class.

    Returns (t, S, E, I, R) with each compartment summed over the three classes.
    """
    # .iloc[0] selects the single 'population' row by position; a bare [0] on
    # these label-indexed Series relies on a deprecated positional fallback.
    N1 = agegroupsSV.iloc[:,0:2].sum(axis=1).iloc[0]
    N2 = agegroupsSV.iloc[:,2:6].sum(axis=1).iloc[0]
    N3 = agegroupsSV.iloc[:,6:10].sum(axis=1).iloc[0]

    # (S, E, I, R) per class; only class 2 starts with one exposed person.
    y0 = N1, 0.0, 0.0, 0.0,  N2-1.0, 1.0, 0.0, 0.0,  N3, 0.0, 0.0, 0.0

    t = np.linspace(0, days-1, days)

    #alpha_by_agegroup = {"0-19": 0.01, "20-59": 0.09, "60-100": 0.4}
    #proportion_of_agegroup = {"0-19": N1/N, "20-59": N2/N, "60-100": N3/N }
    #s = 0.01
    #alpha_opt = sum(alpha_by_agegroup[i] * proportion_of_agegroup[i] for i in list(alpha_by_agegroup.keys()))

    # rho is the module-level value set for the SEIRD model; deriv_seir_age_2
    # accepts it but does not use it.
    ret = odeint(deriv_seir_age_2, y0, t, args=(beta1_1,beta1_2,beta1_3,beta2_2,beta2_3,beta3_3,gamma,N1,N2,N3,delta,rho))
    S1, E1, I1, R1, S2, E2, I2, R2, S3, E3, I3, R3  = ret.T

    S, E, I, R = S1+S2+S3, E1+E2+E3, I1+I2+I3, R1+R2+R3

    return t, S, E, I, R


# Example run with a hand-picked transmission matrix.
plot_it(*Model_seir_age_2(days=360, beta1_1 = 0.84, beta1_2 = 0.57, beta1_3 = 0.2, beta2_2 = 0.638, beta2_3 = 0.459, beta3_3 = 0.5734 ))



# INCLUDE CRITICAL CASES AND FATALITIES




In [None]:
plt.gcf().subplots_adjust(bottom=0.15)


def _finish_axis(axis, title):
    """Set the title, white grid and translucent legend of *axis* and hide its spines."""
    axis.title.set_text(title)
    # Visibility flag passed positionally: the ``b=`` keyword was renamed to
    # ``visible=`` in newer Matplotlib releases.
    axis.grid(True, which='major', c='w', lw=2, ls='-')
    legend = axis.legend()
    legend.get_frame().set_alpha(0.5)
    for spine in ('top', 'right', 'bottom', 'left'):
        axis.spines[spine].set_visible(False)


def _use_date_axis(figure, axis):
    """Format the x-axis of *axis* for dates: yearly major ticks, monthly minor ticks."""
    axis.xaxis.set_major_locator(mdates.YearLocator())
    axis.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
    axis.xaxis.set_minor_locator(mdates.MonthLocator())
    figure.autofmt_xdate()


def plotter(t, S, E, I, C, R, D, R_0, B, S_1=None, S_2=None, x_ticks=None):
    """Plot a run of the extended SEIR model (with critical cases and deaths).

    Figure 1 shows the six compartments. Figure 2 has three panels: R_0 over
    time, the total and daily fatality rate in percent, and deaths per day
    together with the number of critical cases above bed capacity.

    Parameters
    ----------
    t : time points (days); used as x-axis when ``x_ticks`` is None.
    S, E, I, C, R, D : arrays with one value per time point.
    R_0 : sequence with R_0 per time point.
    B : callable returning the number of available beds for a day index.
    S_1, S_2 : optional fractions that are only printed (share going to ICU,
        share dying in ICU).
    x_ticks : optional x values (dates) used instead of ``t``.

    Notes
    -----
    Reads the module-level ``N`` (for the printed percentages) and ``sigma``
    (for the total fatality rate).
    """
    print("percentage of Infected", (int(np.amax(I))/N)*100, ",", int(np.amax(I)) )
    print("percentage of Deceased", (int(np.amax(D))/N)*100, ",", int(np.amax(D)) )
    print("percentage of Critical", (int(np.amax(C))/N)*100, ",", int(np.amax(C)) )
    if S_1 is not None and S_2 is not None:
      print(f"percentage going to ICU: {S_1*100}; percentage dying in ICU: {S_2 * 100}")

    dated = x_ticks is not None
    x = x_ticks if dated else t

    # ---- figure 1: compartments -------------------------------------------
    f, ax = plt.subplots(1,1,figsize=(10,4))
    curves = (
        (S, 'b', 'Susceptible'),
        (E, 'y', 'Exposed'),
        (I, 'r', 'Infected'),
        (C, 'r--', 'Critical'),
        (R, 'g', 'Recovered'),
        (D, 'k', 'Dead'),
    )
    for values, fmt, label in curves:
        ax.plot(x, values, fmt, alpha=0.7, linewidth=2, label=label)
    ax.set_yscale('linear')

    if dated:
        _use_date_axis(f, ax)
    else:
        # Month labels belong to the day-index axis only; on a date axis they
        # would overwrite the date locators with meaningless positions.
        month_starts = [1,32,61,92,122,153,183,214,245,275,306,336]
        month_names = ['Jan','Feb','Mar','Apr','May','Jun',
                   'Jul','Aug','Sep','Oct','Nov','Dec']
        ax.set_xticks(month_starts)
        ax.set_xticklabels(month_names)

    _finish_axis(ax, 'extended SEIR-Model')

    plt.show();

    # ---- figure 2: R_0, fatality rate, deaths per day ----------------------
    f = plt.figure(figsize=(10,4))
    # Applied to this figure before it is shown (and before any date
    # formatting, which sets its own bottom margin).
    f.subplots_adjust(bottom=0.15)

    # sp1
    ax1 = f.add_subplot(131)
    ax1.plot(x, R_0, 'b--', alpha=0.7, linewidth=2, label='R_0')
    if dated:
        _use_date_axis(f, ax1)
    _finish_axis(ax1, 'R_0 over time')

    # sp2
    ax2 = f.add_subplot(132)
    # Running total of sigma*E up to and including each step, computed once
    # instead of re-summing a growing slice for every day.
    cumulative_infected = np.cumsum(sigma * np.asarray(E, dtype=float))
    new_R = np.diff(np.asarray(R, dtype=float))
    new_D = np.diff(np.asarray(D, dtype=float))
    total_CFR = [0] + [100 * D[i] / cumulative_infected[i-1] if cumulative_infected[i-1] > 0 else 0
                       for i in range(1, len(t))]
    # Days with at most 10 new recoveries and at most 10 new deaths are
    # reported as 0 rather than as a ratio of tiny numbers.
    daily_CFR = [0] + [100 * (d / (r + d)) if max(r, d) > 10 else 0
                       for r, d in zip(new_R, new_D)]
    ax2.plot(x, total_CFR, 'r--', alpha=0.7, linewidth=2, label='total')
    ax2.plot(x, daily_CFR, 'b--', alpha=0.7, linewidth=2, label='daily')
    if dated:
        _use_date_axis(f, ax2)
    _finish_axis(ax2, 'Fatality Rate (%)')

    # sp3
    ax3 = f.add_subplot(133)
    newDs = [0] + list(new_D)
    over_capacity = [max(0, C[i]-B(i)) for i in range(len(t))]
    ax3.plot(x, newDs, 'r--', alpha=0.7, linewidth=2, label='total')
    ax3.plot(x, over_capacity, 'b--', alpha=0.7, linewidth=2, label="over capacity")
    if dated:
        _use_date_axis(f, ax3)
    ax3.yaxis.set_tick_params(length=0)
    ax3.xaxis.set_tick_params(length=0)
    _finish_axis(ax3, 'Deaths per day')

    plt.show();


In [None]:
def deriv(y, t, beta, gamma, sigma, N, p_I_to_C, p_C_to_D, Beds):
    """Right-hand side of the extended SEIR model with critical cases and deaths.

    y is the state (S, E, I, C, R, D). beta and Beds are callables of t giving
    the transmission rate and the number of available beds; gamma is the
    recovery rate, sigma the rate of leaving the exposed state, N the
    population size, p_I_to_C the probability of an infected person becoming
    critical and p_C_to_D the probability of a critical person dying.
    Returns the tuple (dS/dt, dE/dt, dI/dt, dC/dt, dR/dt, dD/dt).
    """
    S, E, I, C, R, D = y

    beds_now = Beds(t)
    in_care = min(beds_now, C)             # critical cases that get a bed
    without_bed = max(0, C-beds_now)       # critical cases above capacity; all of them move to D

    newly_exposed = beta(t) * I * S / N
    newly_infectious = sigma * E
    turning_critical = 1/12.0 * p_I_to_C * I
    mild_recoveries = gamma * (1 - p_I_to_C) * I
    deaths_in_care = 1/7.5 * p_C_to_D * in_care
    recoveries_in_care = (1 - p_C_to_D) * 1/6.5 * in_care

    dSdt = -newly_exposed
    dEdt = newly_exposed - newly_infectious
    dIdt = newly_infectious - turning_critical - mild_recoveries
    dCdt = turning_critical - deaths_in_care - without_bed - recoveries_in_care
    dRdt = mild_recoveries + recoveries_in_care
    dDdt = deaths_in_care + without_bed
    return dSdt, dEdt, dIdt, dCdt, dRdt, dDdt

In [None]:
#model from: https://towardsdatascience.com/infectious-disease-modelling-beyond-the-basic-sir-model-216369c584c4
D = 10.0  # duration of the infectious period in days (gamma = 1/D)
gamma = 1.0 / D
sigma = 1.0/3.0  # rate of leaving the exposed state in the extended model


def logistic_R_0(t, R_0_start, k, x0, R_0_end):
    """Return R_0 at time t on a logistic curve from R_0_start to R_0_end.

    x0 is the midpoint of the transition (the value there is the mean of the
    two levels) and k controls how fast the transition happens.
    """
    drop = R_0_start - R_0_end
    return R_0_end + drop / (1 + np.exp(k * (t - x0)))

def Model_extended(days, agegroups, beds_per_100k, R_0_start, k, x0, R_0_end, prob_I_to_C, prob_C_to_D, s):
    """Simulate the extended SEIR model with critical cases, deaths and bed capacity.

    Parameters
    ----------
    days : number of daily time steps to simulate.
    agegroups : iterable of group sizes; their sum is the population N.
    beds_per_100k : beds available per 100,000 inhabitants at day 0.
    R_0_start, k, x0, R_0_end : parameters of logistic_R_0, which gives R_0(t).
    prob_I_to_C : probability of an infected person becoming critical.
    prob_C_to_D : probability of a critical person dying.
    s : daily growth of the bed capacity as a fraction of the initial capacity.

    Returns
    -------
    (t, S, E, I, C, R, D, R_0_over_time, Beds, prob_I_to_C, prob_C_to_D), in
    the positional order expected by plotter.
    """
    def beta(t):
        return logistic_R_0(t, R_0_start, k, x0, R_0_end) * gamma

    N = sum(agegroups)

    # Initial capacity is fixed for the whole run, so compute it once instead
    # of on every call the ODE solver makes to Beds.
    beds_0 = beds_per_100k / 100_000 * N

    def Beds(t):
        return beds_0 + s*beds_0*t  # 0.003

    # Initial state (S, E, I, C, R, D): a single exposed person.
    y0 = N-1.0, 1.0, 0.0, 0.0, 0.0, 0.0
    t = np.linspace(0, days-1, days)
    ret = odeint(deriv, y0, t, args=(beta, gamma, sigma, N, prob_I_to_C, prob_C_to_D, Beds))
    S, E, I, C, R, D = ret.T
    R_0_over_time = [beta(i)/gamma for i in range(len(t))]

    return t, S, E, I, C, R, D, R_0_over_time, Beds, prob_I_to_C, prob_C_to_D

In [None]:
# Example run of the extended model on the three Swedish age classes,
# using hand-picked parameters.
plotter(
    *Model_extended(
        days=360,
        agegroups=[N1, N2, N3],
        beds_per_100k=5.8,
        R_0_start=5.7,
        k=0.01,
        x0=50,
        R_0_end=0.8,
        prob_I_to_C=0.05,
        prob_C_to_D=0.6,
        s=0.003,
    )
)

In [None]:
# parameters
# Swedish death counts; the rows of covid are stored newest-first within a
# country, so the series is reversed before use.
data = covid.loc[covid["Location"] == "Sweden", "Deaths"].to_numpy()[::-1]
agegroups = [
    1194000., 1127000., 1277000., 1320000., 1264000.,
    1296000., 1094000.,  994000.,  432000.,   98000.,
]
beds_per_100k = 5.8
outbreak_shift = 0
# For every fitted parameter: (initial value, lower bound, upper bound).
params_init_min_max = {
    "R_0_start": (3.9, 3.9, 5.7),
    "k": (0.09, 0.01, 0.2),
    "x0": (90, 60, 120),
    "R_0_end": (0.9, 0.9, 3.5),
    "prob_I_to_C": (0.05, 0.01, 0.1),
    "prob_C_to_D": (0.5, 0.07, 0.8),
    "s": (0.003, 0.001, 0.06),
}

In [None]:
# Align the observed series with the model clock: a positive outbreak_shift
# pads that many zero days in front of the data, a negative one drops that
# many days from its start. Either way the result has `days` entries.
days = outbreak_shift + len(data)
if outbreak_shift >= 0:
    y_data = np.concatenate((np.zeros(outbreak_shift), data))
else:
    # Slice the observed data; y_data does not exist yet in this branch.
    y_data = data[-outbreak_shift:]

x_data = np.linspace(0, days - 1, days, dtype=int)  # x_data is just [0, 1, ..., max_days] array

def fitter(x, R_0_start, k, x0, R_0_end, prob_I_to_C, prob_C_to_D, s):
    """Model function for the fit: simulated deaths at the day indices in x.

    Runs Model_extended for the module-level `days`, `agegroups` and
    `beds_per_100k` with the given parameters and returns element 6 of its
    result (the D compartment) indexed by x.
    """
    ret = Model_extended(days, agegroups, beds_per_100k, R_0_start, k, x0, R_0_end, prob_I_to_C, prob_C_to_D, s)
    return ret[6][x]

In [None]:
# Wrap fitter in an lmfit model and register start value and bounds for
# every parameter listed in params_init_min_max.
mod = lmfit.Model(fitter)

for param_name, bounds in params_init_min_max.items():
    start, lower, upper = bounds
    mod.set_param_hint(str(param_name), value=start, min=lower, max=upper, vary=True)

params = mod.make_params()
fit_method = "dual_annealing"

In [None]:
# Fit the model to the observed deaths with the method chosen in fit_method.
result = mod.fit(
    y_data,
    params,
    x=x_data,
    method=fit_method,
)

In [None]:
# Plot the fit result, drawing the data with a solid line ("-").
# The trailing semicolon keeps the notebook from echoing the return value.
result.plot_fit(datafmt="-");


In [None]:
# Echo the fitted parameter values as the cell's output.
result.best_values

In [None]:
# Re-run the extended model for a full 360 days using the fitted parameters.
full_days = 360
shift = np.timedelta64(outbreak_shift,'D')
first_date = np.datetime64(covid.Date.min()) - shift
# NOTE(review): x_ticks is prepared here but not handed to plotter, so the
# plots below use the day index on the x-axis - confirm whether that is intended.
x_ticks = pd.date_range(start=first_date, periods=full_days, freq="D")
print("Prediction for Sweden")
prediction = Model_extended(full_days, agegroups, beds_per_100k, **result.best_values)
plotter(*prediction);
#if it's not visible you can either zoom in or go to the plot and plot y axis in log scale 