In [1]:
import pandas as pd
import numpy as np
import os
import pickle

from datetime import datetime

from sklearn.metrics import make_scorer
from scipy import optimize
from scipy import integrate

import matplotlib as mpl
import matplotlib.pyplot as plt

import seaborn as sns
from sklearn.base import BaseEstimator,RegressorMixin
from sklearn.metrics import mean_squared_error, mean_absolute_error

from sklearn.model_selection import GridSearchCV

sns.set(style="darkgrid")

mpl.rcParams['figure.figsize'] = (16, 9)
pd.set_option('display.max_rows', 500)

%matplotlib inline

df_analyse=pd.read_csv('COVID-19/COVID_small_flat_table.csv',sep=';')

In [2]:
class SIR_modelling(object):
    '''Fit a simple SIR model to an observed infected-people time series.

    Parameters beta/gamma are estimated with scipy.optimize.curve_fit, either
    over the whole series or piecewise over consecutive intervals of days.

    Attributes:
        country_data: observed infected counts, one value per day (indexable,
            with len(); a 1-D numpy array in this notebook).
        N0: total population size, kept constant by the model.
        I0: infected people at the start of the integrated span.
        S0: susceptible people at the start of the integrated span.
        R0: recovered people at the start of the integrated span. NOTE: this is
            the initial *recovered* count, not the basic reproduction number.
    '''

    # number of free parameters (beta, gamma) estimated by curve_fit; a fit
    # needs at least this many observations
    _N_PARAMS = 2

    def __init__(self,country_data,country_population):
        self.country_data = country_data

        # initialization for SIR_model_t
        self.N0=country_population #population size
        self.I0=country_data[0] #Initial infected
        self.S0=self.N0-self.I0 #Susceptible people
        self.R0=0 #Recovered

    def SIR_model_t(self,SIR,t,beta,gamma):
        ''' Simple SIR model, right-hand side for integrate.odeint.

            SIR: current state (S, I, R)
                S: susceptible population
                I: infected people
                R: recovered people
            t: time step, mandatory for integrate.odeint (unused, the model is autonomous)
            beta: infection rate (S -> I transitions scale with beta*S*I/N0)
            gamma: recovery rate (I -> R transitions scale with gamma*I)

            Returns the derivatives (dS/dt, dI/dt, dR/dt).

            The overall condition is that the changes sum up to 0:
            dS+dI+dR=0
            S+I+R= N (constant size of population)
        '''
        S,I,R=SIR
        dS_dt=-beta*S*I/self.N0
        dI_dt=beta*S*I/self.N0-gamma*I
        dR_dt=gamma*I
        return dS_dt,dI_dt,dR_dt

    def fit_odeint(self,x, beta, gamma):
        '''Integrate the SIR model from (S0, I0, R0) over time points x.

        The signature (x, beta, gamma) is what curve_fit inspects to find the
        two free parameters, so it must not gain extra arguments.

        Returns the infected curve I(t) evaluated at x.
        '''
        # column 1 of the odeint solution is I(t) itself (not its derivative)
        return integrate.odeint(self.SIR_model_t, (self.S0, self.I0, self.R0), x, args=(beta, gamma))[:,1]

    def interval_curve_fit(self,interval,maxfev=1000):
        '''Fit the SIR model piecewise over consecutive intervals of days.

        For every interval the initial conditions are re-initialized from the
        data (I0 = first observation of the interval, S0 = N0 - I0, R0 kept as
        set in __init__), and beta/gamma are fitted independently.

        Args:
            interval: number of days per fitted interval. -1 (or None, or any
                value <= 0) fits the whole series at once. Intervals shorter
                than the number of fitted parameters (2) cannot be fitted, so a
                too-short trailing remainder is merged into the preceding
                interval.
            maxfev: maximum number of model evaluations per curve_fit call.

        Returns:
            (fitted, mean_r0): fitted infected counts with the same length as
            country_data, and the mean of beta/gamma over all intervals.

        Raises:
            RuntimeError: if curve_fit does not converge within maxfev calls.
            TypeError: from curve_fit if the whole series has a single point.
        '''
        n_days = len(self.country_data)
        t = np.arange(n_days)

        if interval is None or interval <= 0:
            interval = n_days
        interval = max(int(interval), self._N_PARAMS)

        # interval boundaries; merge a too-short trailing remainder into the
        # previous interval so every curve_fit call has enough observations
        starts = list(range(0, n_days, interval))
        if len(starts) > 1 and n_days - starts[-1] < self._N_PARAMS:
            starts.pop()
        ends = starts[1:] + [n_days]

        fitted_parts = []
        r0 = []
        # fit_odeint reads S0/I0 from the instance, so they are overwritten per
        # interval; remember the originals to leave the object unchanged
        initial_state = (self.S0, self.I0)
        try:
            for start, end in zip(starts, ends):
                interval_data = self.country_data[start:end]
                interval_t = t[start:end]

                # Re-initialize SIR for this interval
                self.I0=interval_data[0] #Initial infected
                self.S0=self.N0-self.I0 #Susceptible people

                popt, _ = optimize.curve_fit(self.fit_odeint,interval_t,interval_data,maxfev=maxfev)
                r0.append(popt[0]/popt[1])
                fitted_parts.append(self.fit_odeint(interval_t,*popt))
        finally:
            self.S0, self.I0 = initial_state

        fitted = np.concatenate(fitted_parts) if fitted_parts else np.array([])
        return fitted, sum(r0)/len(r0)

In [3]:
import dash
import plotly
dash.__version__
import dash_core_components as dcc
import dash_html_components as html
import plotly.graph_objects as go

# Empty placeholder figure; the callback whose Output targets
# 'main_window_slope' supplies the actual traces.
fig = go.Figure()
app = dash.Dash()

# Dropdown choices: every column after the first is offered as a country.
country_options = [dict(label=col, value=col) for col in df_analyse.columns.values[1:]]
fit_type_options = [dict(label=kind, value=kind) for kind in ('General Curve Fit', 'Interval Curve Fit')]

# Configure dashboard: country picker, fit-type picker, interval size, graph.
app.layout = html.Div(children=[
    dcc.Markdown('''
    ## Multi-Select Country for visualization
    '''),
    dcc.Dropdown(id='country_drop_down', options=country_options,
                 value=['India'],  # pre-selected
                 multi=True),
    dcc.Markdown('''
    ## Select type of SIR model curve fit
    '''),
    dcc.Dropdown(id='curve_fit_drop_down', options=fit_type_options,
                 value=['General Curve Fit'],
                 multi=True),
    dcc.Markdown('''
    ### Select interval (days) size
    '''),
    dcc.Input(id="interval_input", type="number", value=5,
              min=5, max=100, step=5),
    dcc.Graph(figure=fig, id='main_window_slope'),
])


In [4]:
from dash.dependencies import Input, Output

@app.callback(
    Output('main_window_slope', 'figure'),
    [Input('country_drop_down', 'value'),
    Input('curve_fit_drop_down','value'),
    Input('interval_input','value')])
def update_figure(country_list,curve_fit_type,interval_input):
    '''Build the figure of observed infections plus the selected SIR fits.

    Args:
        country_list: country column names selected in the dropdown; may be
            empty or None when the selection is cleared.
        curve_fit_type: selected fit types ('General Curve Fit' and/or
            'Interval Curve Fit'); may be empty or None.
        interval_input: interval size in days for the interval fit; when the
            input box holds no valid value (None) the interval fit is skipped.

    Returns:
        A Plotly figure dict with, per country, the observed trace followed by
        one trace per selected fit, all on the same date window.
    '''
    # the first days are skipped for every trace so observed and fitted values
    # line up on the same dates (presumably the early near-zero counts make the
    # SIR fit ill-conditioned -- TODO confirm)
    start_day = 10
    # NOTE(review): this is India's population and it is used for every
    # selected country; per-country population data is needed for other fits
    population = 1350000000

    selected_fits = curve_fit_type or []
    dates = df_analyse.date[start_day:]

    # (dropdown label, interval passed to interval_curve_fit, trace suffix)
    fit_specs = [('General Curve Fit', -1, '_general_fit')]
    if interval_input is not None:
        fit_specs.append(('Interval Curve Fit', int(interval_input), '_interval_fit'))

    traces = []
    for country in country_list or []:
        observed = df_analyse[country][start_day:]
        # Trace actual infection
        traces.append(dict(x=dates,
                           y=observed,
                           mode='markers+lines',
                           opacity=0.9,
                           line_width=4,
                           marker_size=8,
                           name=country+"_actual"))

        for label, interval, suffix in fit_specs:
            if label not in selected_fits:
                continue
            # fitted values have the same length as `observed`, so they share `dates`
            fitted, _ = SIR_modelling(np.array(observed), population).interval_curve_fit(interval)
            traces.append(dict(x=dates,
                               y=fitted,
                               mode='lines',
                               line_width=2,
                               name=country+suffix))

    return {
        'data': traces,
        'layout': dict(
            width=1280,
            height=720,
            xaxis={'title':'Timeline',
                   'tickangle':-45,
                   'nticks':20,
                   'tickfont':dict(size=14,color="#7f7f7f"),
                   },
            yaxis={'type':"log",
                   'title':'Infected people (log-scale)'
                   }
        )
    }

In [None]:
# Serve the dashboard from the notebook kernel; the reloader is presumably
# disabled because re-spawning the process does not work inside Jupyter.
app.run_server(use_reloader=False, debug=True)

Dash is running on http://127.0.0.1:8050/

 in production, use a production WSGI server like gunicorn instead.

 * Serving Flask app "__main__" (lazy loading)
 * Environment: production
   Use a production WSGI server instead.
 * Debug mode: on



Excess work done on this call (perhaps wrong Dfun type). Run with full_output = 1 to get quantitative information.


overflow encountered in double_scalars


overflow encountered in double_scalars


Illegal input detected (internal error). Run with full_output = 1 to get quantitative information.


overflow encountered in double_scalars

