In [1]:
from math import inf
from jupyter_dash import JupyterDash
import dash_core_components as dcc
import dash_html_components as html
from dash.dependencies import Input, Output
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
from datetime import date, timedelta
import numpy as np

from itertools import chain

from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import PolynomialFeatures

import statsmodels.api as sm

from tqdm.notebook import tqdm

In [2]:
def extract_data(from_df, level, data_type):
    """Turn the wide cumulative dataset into a long frame of daily new counts.

    Parameters
    ----------
    from_df : pandas.DataFrame
        Wide dataset containing ``countyFIPS``, ``County Name``, ``State``,
        ``StateFIPS``, ``population`` and one ``<date>_cases`` / ``<date>_deaths``
        column per day. The code slices by position, so it assumes those five
        metadata columns come first and the date columns follow in date order
        as cumulative totals. TODO: confirm against the data source.
    level : {"nation", "state", "county"}
        Aggregation level.
    data_type : {"cases", "deaths"}
        Which series to extract. The other series' columns are dropped.

    Returns
    -------
    pandas.DataFrame
        Long frame with columns ``[*groupby, 'population', 'date', data_type]``.
        ``groupby`` is ``['state', 'county']`` for "county", ``['state']`` for
        "state", and empty for "nation". The value column is always last.
        The first date keeps its raw value. Later dates are day-over-day
        differences, with negative differences replaced by 1e-7.
    """
    groupby = ['state', 'county'] if level == "county" else ['state']
    opposite = "_deaths" if data_type == "cases" else "_cases"
    type_data = from_df[[col for col in from_df.columns if opposite not in col]].copy()
    type_data.columns = [col.replace("_{0}".format(data_type), "") for col in type_data.columns]
    # Cumulative -> daily: keep the first date column (position 5) as-is and
    # difference the remaining date columns.
    type_data = pd.concat([type_data.iloc[:, :6], type_data.iloc[:, 5:].diff(axis=1).iloc[:, 1:]], axis=1)
    # Replace negative daily values (downward corrections) with a tiny positive
    # number. mask() builds a new frame instead of assigning into a slice.
    # NOTE(review): downstream np.ceil (rolling_avg, the fits) rounds such tiny
    # positives up to 1. Confirm 1e-7 (rather than 0) is intended.
    daily = type_data.iloc[:, 5:]
    daily = daily.mask(daily < 0, 0.0000001)
    type_data = pd.concat([type_data.iloc[:, :5], daily], axis=1)
    type_data = type_data.drop(['countyFIPS', 'StateFIPS'], axis=1)
    if level == "nation":
        # Sum population and every date column over all rows. Row 0 of the
        # result is the population total, and the remaining rows are the dates.
        totals = type_data.iloc[:, 2:].sum().reset_index()
        totals['population'] = totals[0][0]
        totals = totals.rename(columns={"index": "date", 0: data_type}).drop(0)
        # Same column order as the other levels, so the value column is last.
        return totals[['population', 'date', data_type]]
    if level == 'state':
        type_data = type_data.drop(['County Name'], axis=1)
        type_data = type_data.groupby('State').sum().reset_index()
        type_data = type_data.set_index(['State', 'population'])
    else:
        # Counties with zero population are dropped.
        type_data = type_data[type_data.population != 0]
        type_data = type_data.set_index(['State', 'County Name', 'population'])
    type_data = type_data.stack().reset_index()
    type_data.columns = groupby + ['population', 'date', data_type]
    return type_data

In [3]:
def remove_before_first(data, data_type=None, inplace=False):
    """Drop the leading rows before the first non-zero value of a column.

    Parameters
    ----------
    data : pandas.DataFrame
        Rows are assumed to be in chronological order.
    data_type : str, optional
        Column to inspect. Defaults to the last column of ``data``.
    inplace : bool
        Modify ``data`` directly and return ``None``.

    Returns
    -------
    pandas.DataFrame or None
        The trimmed copy, or ``None`` when ``inplace``. If the column is
        entirely zero, nothing is dropped.
    """
    temp = data if inplace else data.copy()
    if not data_type:
        data_type = temp.columns[-1]
    nonzero = (temp[data_type] != 0).to_numpy()
    if nonzero.any():
        # Trim by position so any index works (not only contiguous integer labels).
        temp.drop(temp.index[:nonzero.argmax()], inplace=True)
    if not inplace:
        return temp

In [4]:
def choose_best_degree(data, fit_for, candidate_degrees=range(10)):
    """Return the polynomial degree whose OLS fit of ``data[fit_for]`` has the lowest RMSE.

    For each candidate degree, the column is fitted against the 1-based row
    number. Predictions are rounded up with ``np.ceil`` (as ``poly_fit`` does),
    and the RMSE against the observed values is compared. On ties the earlier
    candidate wins, because the comparison is a strict ``<``.

    NOTE(review): the RMSE is in-sample and the models are nested, so a higher
    degree can hardly fit worse. This likely returns the largest candidate in
    most cases. A holdout split or an information criterion would be a better
    selector.

    Parameters
    ----------
    data : pandas.DataFrame
    fit_for : str
        Column to fit.
    candidate_degrees : iterable of int, default ``range(10)``
        Degrees to try, in order. The default reproduces the original 0-9 search.

    Returns
    -------
    int
        The best degree, or 0 if no candidate is evaluated.
    """
    # Inputs that do not depend on the degree are built once.
    y = data[fit_for]
    x = np.arange(1, len(y) + 1).reshape(-1, 1)
    y_col = y.values.reshape(-1, 1)
    best_rmse, best_degree = inf, 0
    for degree in candidate_degrees:
        xp = PolynomialFeatures(degree=degree).fit_transform(x)
        fitted = np.ceil(sm.OLS(y_col, xp).fit().predict(xp))
        rmse = np.sqrt(mean_squared_error(y, fitted))
        if rmse < best_rmse:
            best_rmse, best_degree = rmse, degree
    return best_degree

In [5]:
def poly_fit(data, fit_for, inplace=False, degree=None, prediction_date=None):
    """Fit a polynomial trend of ``data[fit_for]`` against row number and optionally extrapolate.

    Row ``i`` (1-based) is treated as day ``i``. A ``poly_pred`` column with the
    ceiled in-sample fit is added.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain ``fit_for`` and, when forecasting, a ``date`` column whose
        last value is the last observed day.
    fit_for : str
        Column to fit.
    inplace : bool
        Add ``poly_pred`` to ``data`` itself instead of a copy.
    degree : int, optional
        Polynomial degree. ``None`` picks one via ``choose_best_degree``.
    prediction_date : str or date-like, optional
        Extrapolate from the last observed date up to and including this day.

    Returns
    -------
    With ``prediction_date``: ``(temp, dates, preds)``, or ``(dates, preds)``
    when ``inplace``. ``dates`` are ISO day strings from the last observed
    date through ``prediction_date``. ``preds`` are the matching ceiled,
    non-negative predictions and have the same length as ``dates``.
    Without it: the copy with ``poly_pred``, or ``None`` when ``inplace``.
    """
    temp = data if inplace else data.copy()
    if degree is None:
        degree = choose_best_degree(temp, fit_for)
    polynomial_features = PolynomialFeatures(degree=degree)
    y = temp[fit_for]
    x = np.arange(1, len(y) + 1)
    xp = polynomial_features.fit_transform(x.reshape(-1, 1))
    poly_model = sm.OLS(y.values.reshape(-1, 1), xp).fit()
    temp['poly_pred'] = np.ceil(poly_model.predict(xp))
    if prediction_date:
        # Work at day resolution so the day count and the date range agree.
        end_date = np.datetime64(temp['date'].values[-1]).astype('datetime64[D]')
        pred_date = np.datetime64(prediction_date).astype('datetime64[D]')
        end_days = temp.shape[0]
        days = int((pred_date - end_date).astype('int'))
        # Day numbers end_days .. end_days + days correspond to end_date .. pred_date
        # inclusive, so the date range must include pred_date as well.
        x_future = np.arange(end_days, end_days + days + 1)
        dates = np.arange(end_date, pred_date + np.timedelta64(1, 'D')).astype('str')
        xp_future = polynomial_features.fit_transform(x_future.reshape(-1, 1))
        preds = np.ceil(poly_model.predict(xp_future))
        preds = np.where(preds < 0, 0, preds)
        return (temp, dates, preds) if not inplace else (dates, preds)
    if not inplace:
        return temp

In [6]:
def linear_fit(data, fit_for, inplace=False, prediction_date=None):
    """Fit a linear trend of ``data[fit_for]`` against row number and optionally extrapolate.

    Row ``i`` (1-based) is treated as day ``i``. An ``lr_pred`` column with the
    ceiled in-sample fit is added.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain ``fit_for`` and, when forecasting, a ``date`` column whose
        last value is the last observed day.
    fit_for : str
        Column to fit.
    inplace : bool
        Add ``lr_pred`` to ``data`` itself instead of a copy.
    prediction_date : str or date-like, optional
        Extrapolate from the last observed date up to and including this day.

    Returns
    -------
    With ``prediction_date``: ``(temp, dates, preds)``, or ``(dates, preds)``
    when ``inplace``. ``dates`` are ISO day strings from the last observed
    date through ``prediction_date``. ``preds`` are the matching ceiled,
    non-negative predictions and have the same length as ``dates``.
    Without it: the copy with ``lr_pred``, or ``None`` when ``inplace``.
    """
    temp = data if inplace else data.copy()
    y = temp[fit_for]
    x = np.arange(1, len(y) + 1)
    # NOTE(review): the design matrix is only the day number, with no constant
    # column. statsmodels OLS adds no intercept by default, so this line is forced
    # through the origin. Confirm this is intended.
    lr_model = sm.OLS(y, x).fit()
    temp['lr_pred'] = np.ceil(lr_model.predict(x))
    if prediction_date:
        # Work at day resolution so the day count and the date range agree.
        end_date = np.datetime64(temp['date'].values[-1]).astype('datetime64[D]')
        pred_date = np.datetime64(prediction_date).astype('datetime64[D]')
        end_days = temp.shape[0]
        days = int((pred_date - end_date).astype('int'))
        # Day numbers end_days .. end_days + days correspond to end_date .. pred_date
        # inclusive, so the date range must include pred_date as well.
        x_future = np.arange(end_days, end_days + days + 1)
        dates = np.arange(end_date, pred_date + np.timedelta64(1, 'D')).astype('str')
        preds = np.ceil(lr_model.predict(x_future))
        preds = np.where(preds < 0, 0, preds)
        return (temp, dates, preds) if not inplace else (dates, preds)
    if not inplace:
        return temp

In [7]:
def linear_normalize(data, data_type, norm_val=100000, inplace=False):
    """Add a ``linear_norm`` column: ``data[data_type]`` per ``norm_val`` people.

    The denominator is the largest ``population`` value in ``data``.

    Returns
    -------
    pandas.DataFrame or None
        The copy with ``linear_norm`` added, or ``None`` when ``inplace``
        (``data`` itself is modified).
    """
    temp = data if inplace else data.copy()
    temp['linear_norm'] = temp[data_type] / temp.population.max() * norm_val
    if not inplace:
        # Return the modified copy; returning ``data`` would discard the new column.
        return temp

In [8]:
def log_normalize(data, data_type, inplace=False):
    """Add a ``log_norm`` column holding ``log10(data[data_type])``.

    Zeros become ``-inf`` and negatives become ``NaN``. The matching NumPy
    divide/invalid warnings are silenced for this computation only.

    Returns
    -------
    pandas.DataFrame or None
        The copy with ``log_norm`` added, or ``None`` when ``inplace``.
    """
    temp = data if inplace else data.copy()
    # Scoped suppression: np.seterr(...) would change NumPy's error handling for
    # the whole process for the rest of the session.
    with np.errstate(divide='ignore', invalid='ignore'):
        temp['log_norm'] = np.log10(temp[data_type])
    if not inplace:
        return temp

In [9]:
def rolling_avg(data, data_type, window_size, inplace=False):
    """Add a ``roll_avg`` column: trailing mean of ``data[data_type]`` rounded up to int.

    The window spans ``window_size`` rows. The first rows average over however
    many values are available (``min_periods=1``).

    Returns
    -------
    pandas.DataFrame or None
        The copy with ``roll_avg`` added, or ``None`` when ``inplace``.
    """
    target = data.copy() if not inplace else data
    window_mean = target[data_type].rolling(window=window_size, min_periods=1).mean()
    target['roll_avg'] = np.ceil(window_mean).astype('int')
    return None if inplace else target

In [10]:
def remove_negatives(data, inplace=False):
    """Clip negative values in every numeric column of ``data`` to 0.

    Non-numeric columns are left untouched, and column order is preserved.

    Returns
    -------
    pandas.DataFrame or None
        The cleaned copy, or ``None`` when ``inplace`` (``data`` itself is modified).
    """
    temp = data if inplace else data.copy()
    # Write back into ``temp`` column by column. Assigning into the frame returned
    # by _get_numeric_data() is not guaranteed to reach ``temp`` (it may be a copy),
    # which would make inplace=True silently do nothing.
    for col in temp._get_numeric_data().columns:
        negative = temp[col] < 0
        if negative.any():
            temp.loc[negative, col] = 0
    if not inplace:
        return temp

In [11]:
# Input files. These are relative paths, resolved against the kernel's working directory.
DATASET_PATH = "../../../data/stage_IV/superDataset.csv"
STATE_NAMES_PATH = "../../../data/stage_I/name-abbr.csv"

super_dataset = pd.read_csv(DATASET_PATH)
# Read without a header row. The two columns are named to match the dropdown
# option format ({'label': ..., 'value': ...}).
state_names = pd.read_csv(STATE_NAMES_PATH, header=None, names=['label', 'value'])

# Create dashboard

In [12]:
# Dash application object (JupyterDash variant for running from this notebook),
# loading a single external stylesheet.
app = JupyterDash(
    __name__,
    external_stylesheets=['https://codepen.io/chriddyp/pen/bWLwgP.css'],
)

# Create dashboard layout

In [13]:
def _build_layout():
    """Assemble the dashboard component tree: title header, control menu, and graph card.

    The component ids created here (``state_dropdown``, ``county_dropdown``,
    ``viz_graph``, ...) are the ones the callbacks bind to.
    """
    # Dataset days, taken from the "<date>_cases" column names in column order.
    case_dates = [col.replace("_cases", "") for col in super_dataset.columns if "_cases" in col]
    first_date, last_date = case_dates[0], case_dates[-1]

    def titled(title, *controls, **div_props):
        # One menu cell: a "menu-title" label followed by its control(s).
        return html.Div(children=[html.Div(children=title, className="menu-title"), *controls], **div_props)

    # Title banner.
    header = html.Div(
        children=[
            html.Div(
                children=[html.H1("COVID - 19 Data Visualization", className="header-title")],
                className="header",
            )
        ]
    )

    # Menu column 1: state selection plus graph/data type toggles.
    state_column = html.Div(children=[
        titled(
            "State",
            dcc.Dropdown(
                id='state_dropdown',
                placeholder='Select a state',
                options=[{'label': 'All States', 'value': 'all'}] + state_names.to_dict('records'),
                value=['all'],
                multi=True,
            ),
            className="give-space",
        ),
        html.Div(
            children=[
                titled(
                    "Graph Type",
                    dcc.RadioItems(
                        id='graph_type_picker',
                        options=[
                            {'label': 'Graph', 'value': 'graph'},
                            {'label': 'Map', 'value': 'map'},
                        ],
                        value='graph',
                    ),
                    className="flex-child",
                ),
                titled(
                    "Data Type",
                    dcc.RadioItems(
                        id='data_type_picker',
                        options=[
                            {'label': 'Cases', 'value': 'cases'},
                            {'label': 'Deaths', 'value': 'deaths'},
                        ],
                        value='cases',
                    ),
                    className="flex-child",
                ),
            ],
            className="flex-container",
        ),
    ])

    # Menu column 2: county and prediction model.
    county_column = html.Div(children=[
        titled(
            "County",
            dcc.Dropdown(id='county_dropdown', placeholder='Select a county'),
            className="give-space",
        ),
        titled(
            "Prediction Model",
            dcc.Dropdown(
                id='pred_dropdown',
                placeholder='Select the Prediction Model',
                options=[
                    {'label': 'Linear Model', 'value': 'linear'},
                    {'label': 'Non-Linear Model', 'value': 'poly'},
                ],
            ),
            className="dropdown",
        ),
    ])

    # Menu column 3: plotted date window and forecast horizon.
    date_column = html.Div(children=[
        titled(
            "Date Range",
            dcc.DatePickerSingle(
                id='pre_date_picker',
                min_date_allowed=first_date,
                max_date_allowed=date.today(),
                date=first_date,
            ),
            dcc.DatePickerSingle(
                id='post_date_picker',
                min_date_allowed=first_date,
                max_date_allowed=date.today(),
                date=date.today(),
            ),
            className="give-space",
        ),
        titled(
            "Prediction Date",
            dcc.DatePickerSingle(id='prediction_date_picker', date=last_date),
        ),
    ])

    # Menu column 4: normalization and moving-average toggle.
    norm_column = html.Div(children=[
        titled(
            "Normalization Type",
            dcc.Dropdown(
                id='norm_dropdown',
                placeholder='Select the normalization',
                options=[
                    {'label': 'Linear Normalization', 'value': 'norm'},
                    {'label': 'Log Normalization', 'value': 'log_norm'},
                ],
                value='norm',
            ),
            className="give-space",
        ),
        titled(
            "Moving Average",
            dcc.Checklist(
                id='roll_avg_checklist',
                options=[{'label': ' 7-day moving average', 'value': 'true'}],
            ),
        ),
    ])

    menu = html.Div(children=[state_column, county_column, date_column, norm_column], className="menu")

    # Main graph card, plus a text paragraph the graph callback writes into.
    graph_card = html.Div(
        children=[
            html.Div(
                children=[
                    html.P(id='para_text', children='Dash converts Python classes into HTML'),
                    dcc.Graph(id='viz_graph', figure={}),
                ],
                className="card",
            ),
        ],
        className="wrapper",
    )

    return html.Div([header, menu, graph_card])


app.layout = _build_layout()

In [14]:
# app.layout = html.Div([
#     html.H1("COVID - 19 Data Visualization", style={'text-align': 'center'}),
#     dcc.Dropdown(
#         id='state_dropdown',
#         placeholder='Select a state',
#         options=[{'label': 'All States', 'value': 'all'}] + state_names.to_dict('records'),
#         value=['all'],
#         multi=True,
#     ),
#     dcc.Dropdown(
#         id='county_dropdown',
#         placeholder='Select a county',
#     ),
#     dcc.DatePickerSingle(
#         id='pre_date_picker',
#         min_date_allowed=[i for i in super_dataset.columns if "_cases" in i][0].replace("_cases", ""),
#         max_date_allowed=date.today(),
#         date = [i for i in super_dataset.columns if "_cases" in i][0].replace("_cases", ""),
#     ),
#     dcc.DatePickerSingle(
#         id='post_date_picker',
#         min_date_allowed=[i for i in super_dataset.columns if "_cases" in i][0].replace("_cases", ""),
#         max_date_allowed=date.today(),
#         date=date.today()
#     ),
#     dcc.RadioItems(
#         id='graph_type_picker',
#         options=[
#                 {'label' : 'Plot', 'value' : 'Plot'},
#                 {'label' : 'Map', 'value': 'Map'}
#             ],
#         value='Plot'
#     ),
#     dcc.Dropdown(
#         id='norm_dropdown',
#         placeholder = 'Select the normalization',
#         options=[
#             {'label' : 'Linear Normalization', 'value' : 'norm'},
#             {'label' : 'Log Normalization', 'value' : 'log_norm'}
#         ],
#         value='norm'
#     ),
#     dcc.RadioItems(
#         id='data_type_picker',
#         options=[
#             {'label' : 'Cases', 'value' : 'cases'},
#             {'label' : 'Deaths', 'value' : 'deaths'}
#         ],
#         value='cases'
#     ),
#     dcc.Dropdown(
#         id='pred_dropdown',
#         placeholder = 'Select the Prediction Model',
#         options=[
#             {'label' : 'Linear Model', 'value' : 'linear'},
#             {'label' : 'Non-Linear Model', 'value' : 'poly'}
#         ],
#     ),
#     dcc.DatePickerSingle(
#         id='prediction_date_picker',
#         date=[i for i in super_dataset.columns if "_cases" in i][-1].replace("_cases", ""),
#     ),
#     dcc.Checklist(
#         id = 'roll_avg_checklist',
#         options = [
#             {'label' : ' 7-day moving average', 'value': 'true'},
#         ],
#     ),
#     html.P(
#         id='para_text',
#         children='Dash converts Python classes into HTML'
#     ),
#     dcc.Graph(
#         id = 'viz_graph',
#         figure = {},
#     ),
# ])

# Callback for state and county dropdown

In [15]:
@app.callback(
    [Output(component_id='county_dropdown', component_property='options'),
     Output(component_id='county_dropdown', component_property='disabled')],
    [Input(component_id='state_dropdown', component_property='value')]
)
def update_county(value):
    """Populate the county dropdown from the state selection.

    The county dropdown is enabled only when exactly one concrete state (not
    ``'all'``) is selected. Its options are then "All Counties" followed by
    that state's counties. In every other case, including an empty or
    ``None`` selection, it is disabled and given no options.

    Returns
    -------
    (list of dict, bool)
        ``(options, disabled)`` for the county dropdown.
    """
    if not value or len(value) != 1 or value == ['all']:
        return [], True
    state = value[0]
    # Rows with countyFIPS == 0 are excluded (presumably statewide/unallocated
    # rows rather than real counties; verify against the dataset).
    in_state = super_dataset[(super_dataset.State == state) & (super_dataset.countyFIPS != 0)]
    options = [{'label': 'All Counties', 'value': 'all'}]
    options += [{'label': name, 'value': name} for name in in_state['County Name']]
    return options, False

@app.callback(
    Output(component_id='county_dropdown', component_property='value'),
    [Input(component_id='county_dropdown', component_property='options')],
)
def update_county_vals(value):
    """Reset the county selection to "all" whenever the county options change.

    ``value`` receives the new options list and is intentionally unused.
    """
    default_selection = 'all'
    return default_selection

@app.callback(
    [
        Output(component_id='pred_dropdown', component_property='disabled'),
        Output(component_id='pred_dropdown', component_property='value'),
    ],
    Input(component_id='norm_dropdown', component_property='value')
)
def update_pred_dropdown(value):
    """Sync the prediction-model dropdown with the normalization choice.

    Any change of normalization clears the selected model. The dropdown is
    disabled under log normalization.

    (Renamed from a second ``update_county_vals`` definition, which shadowed
    the county-reset callback of the same name.)

    Returns
    -------
    (bool, None)
        ``(disabled, value)`` for the prediction-model dropdown.
    """
    return value == 'log_norm', None

# Callback for graphs

In [16]:
@app.callback(
    [
        Output(component_id='viz_graph', component_property='figure'),
        Output(component_id='para_text', component_property='children'),
    ],
    [
        Input(component_id='state_dropdown', component_property='value'),
        Input(component_id='county_dropdown', component_property='value'),
        Input(component_id='pre_date_picker', component_property='date'),
        Input(component_id='post_date_picker', component_property='date'),
        Input(component_id='norm_dropdown', component_property='value'),
        Input(component_id='data_type_picker', component_property='value'),
        Input(component_id='pred_dropdown', component_property='value'),
        Input(component_id='roll_avg_checklist', component_property='value'),
        Input(component_id='prediction_date_picker', component_property='date'),
    ]
)
def update_graph(states, county, start_date, end_date, norm_type, data_type, pred_type, roll_avg_state, pred_date):
    """Rebuild the main figure for the current dashboard selections.

    Each selected state (or ``'all'`` for the nation) gets its own traces:
    - the daily series (bars for ``'norm'``, log-scaled line for ``'log_norm'``);
    - an optional 7-day moving average;
    - an optional in-sample model fit;
    - an optional dashed forecast line up to ``pred_date`` when that date lies
      after the last observed day.

    Returns
    -------
    (plotly.graph_objects.Figure, object)
        The figure and the value shown in ``para_text``.
    """
    fig = go.Figure()
    states = states or []
    show_roll_avg = roll_avg_state == ['true']
    # When the moving average is shown, the model is fitted to the smoothed series.
    fit_column = 'roll_avg' if show_roll_avg else data_type

    for state in states:
        if state == "all":
            pd_data = extract_data(super_dataset, 'nation', data_type)
        elif county != "all":
            pd_data = extract_data(super_dataset, 'county', data_type)
            pd_data = pd_data[(pd_data.state == state) & (pd_data.county == county)]
        else:
            pd_data = extract_data(super_dataset, 'state', data_type)
            pd_data = pd_data[pd_data.state == state]
        # Own the rows so the in-place helpers below never write into a filtered slice.
        pd_data = pd_data.copy()
        if pd_data.empty:
            # e.g. no county selected yet: nothing to plot for this state.
            continue

        # Name the column explicitly rather than relying on the helper's last-column default.
        remove_before_first(pd_data, data_type, inplace=True)
        rolling_avg(pd_data, data_type, 7, inplace=True)
        linear_normalize(pd_data, data_type, inplace=True)
        log_normalize(pd_data, data_type, inplace=True)
        pd_data = remove_negatives(pd_data)

        # Only forecast beyond the last observed day. A per-state variable keeps
        # one state from cancelling the forecast for the states after it.
        state_pred_date = pred_date
        if state_pred_date and np.datetime64(state_pred_date) <= np.datetime64(pd_data.date.values[-1]):
            state_pred_date = None

        preds = None
        if pred_type == 'linear':
            preds = linear_fit(pd_data, fit_column, inplace=True, prediction_date=state_pred_date)
        elif pred_type or state_pred_date:
            # The polynomial model also backs the forecast line when no model is selected.
            preds = poly_fit(pd_data, fit_column, inplace=True, prediction_date=state_pred_date)

        pd_data = pd_data[(pd_data.date >= start_date) & (pd_data.date <= end_date)]
        pd_data = remove_negatives(pd_data)

        label = "{0}-{1}".format(data_type, state)
        if norm_type == 'norm':
            fig.add_trace(go.Bar(x=pd_data.date, y=pd_data[data_type], name=label))
        elif norm_type:
            fig.add_trace(go.Scatter(x=pd_data.date, y=pd_data.log_norm, mode='lines+markers', name=label))
        if show_roll_avg:
            fig.add_trace(
                go.Scatter(x=pd_data.date, y=pd_data.roll_avg, mode='lines',
                           name='7 - day moving avg-{0}'.format(state)),
            )
        if pred_type == 'linear':
            fig.add_trace(
                go.Scatter(x=pd_data.date, y=pd_data.lr_pred, mode='lines',
                           name='linear-prediction-{0}'.format(state)),
            )
        elif pred_type:
            fig.add_trace(
                go.Scatter(x=pd_data.date, y=pd_data.poly_pred, mode='lines',
                           name='polynomial-prediction-{0}'.format(state)),
            )
        if state_pred_date:
            # With inplace=True the fit helpers return (dates, predictions).
            forecast_dates, forecast_values = preds
            fig.add_trace(
                go.Scatter(x=forecast_dates, y=forecast_values, mode='lines',
                           name='prediction-{0}'.format(state), line=dict(dash='dash')),
            )

    if states:
        fig.update_layout(
            title="COVID 19 {0} across {1} from {2} to {3}".format(data_type, ", ".join(states), start_date, end_date),
            title_x=0.5,
            xaxis_title="Date",
            yaxis_title="Frequency (NEW {0} per day)".format(data_type.upper()),
            legend=dict(orientation="h", yanchor="bottom", y=1, x=0.5),
        )
    # NOTE(review): para_text receives the raw checklist value; this looks like leftover debug output.
    return fig, roll_avg_state

# Run server

In [17]:
# Serve the dashboard on port 8086 in "external" mode (the URL is reported in the
# cell output rather than the app being embedded). The reloader is disabled,
# presumably because it cannot restart a notebook kernel.
app.run_server(
    mode="external",
    port=8086,
    debug=True,
    use_reloader=False,
)

Dash app running on http://127.0.0.1:8086/
