In [1]:
import numpy as np
from scipy.integrate import solve_ivp
from scipy.optimize import least_squares
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import SIR_utils as sir

In [2]:
# Per-state mandate table: postal code, full name, population and mask-mandate date.
mandate = pd.read_csv("stay_at_home_and_masks.csv")
def get_date(dataframe, state):
    """Return a state's mask-mandate date reordered to ``YYYY-MM-DD``.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Table with a ``state`` column and a ``mask_date`` column holding
        ``DD-MM-YYYY`` strings, or NaN where no date is recorded.
    state : str
        Value matched against the ``state`` column (e.g. ``"NJ"``).

    Returns
    -------
    str or None
        The first matching row's date as ``YYYY-MM-DD``, or ``None`` when that
        row's ``mask_date`` is missing (NaN).

    Raises
    ------
    IndexError
        If no row matches ``state``.
    ValueError
        If the date string does not have exactly three ``-``-separated parts.
    """
    # Use a separate name so the `state` argument is not shadowed.
    rows = dataframe.loc[dataframe['state'] == state]
    raw_date = rows['mask_date'].iloc[0]
    # pandas reads empty CSV cells as float NaN; calling .split on them used
    # to raise AttributeError for every state without a recorded date.
    if pd.isna(raw_date):
        return None
    day, month, year = str(raw_date).split('-')
    return '-'.join((year, month, day))

# Show the whole mandate table (print renders the frame's str() form).
print(str(mandate))

   state            Unnamed: 1  Population   mask_date  mask_in_effect  \
0     AL               Alabama     4934190  16-07-2021           False   
1     AK                Alaska      724357         NaN           False   
2     AZ               Arizona     7520100         NaN           False   
3     AR              Arkansas     3033950  20-07-2020           False   
4     CA            California    39613500  18-06-2020            True   
5     CO              Colorado     5893630  17-07-2021            True   
6     CT           Connecticut     3552820  14-08-2020            True   
7     DE              Delaware      990334  05-01-2021            True   
8     DC  District of Columbia      714153  22-07-2020            True   
9     FL               Florida    21944600         NaN           False   
10    GA               Georgia    10830000         NaN           False   
11    HI                Hawaii     1406430  17-04-2020            True   
12    ID                 Idaho     186

In [3]:
# State under study: a slug-style name and its postal abbreviation.
# NOTE(review): `state` is not used in the cells shown here; presumably used later.
state, state_abbr = "new-jersey", "NJ"

# Full history table covering every state; filtered per state below.
all_states_history = pd.read_csv("all-states-history.csv")
def stateData(dataframe, state_abbr):
    """Return the rows of ``dataframe`` whose ``state`` column equals ``state_abbr``.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Table with a ``state`` column. It is not modified.
    state_abbr : str
        Value to match in the ``state`` column, e.g. ``"NJ"``.

    Returns
    -------
    pandas.DataFrame
        The matching rows in their original order, indexed by ``state`` while
        also keeping ``state`` as a regular column (``drop=False``).
    """
    # Filter first so only the selected state's rows are materialised; the
    # previous version deep-copied the whole multi-state table and then
    # mutated the copy in place before filtering.
    selected = dataframe.loc[dataframe['state'] == state_abbr]
    return selected.set_index('state', drop=False)


time_dataframe = stateData(all_states_history, state_abbr)
print(time_dataframe)

# The history rows run newest-first (see the output below); reverse them into
# chronological order.
dates = time_dataframe["date"][::-1]
removed = time_dataframe["positiveCasesViral"][::-1]
# Boolean mask of rows with a reported value. `np.bool` was removed in
# NumPy 1.24, so the previous `astype(np.bool)` raised AttributeError;
# `notna()` builds the same mask directly.
indices = removed.notna().to_numpy()

# Fall back to the `positive` column when `positiveCasesViral` is entirely empty.
if not indices.any():
    removed = time_dataframe["positive"][::-1]
    indices = removed.notna().to_numpy()

dates = np.array(dates)[indices]
rem = np.array(removed[indices])

             date state    death  deathConfirmed  deathIncrease  \
state                                                             
NJ     2021-03-07    NJ  23574.0         21177.0             17   
NJ     2021-03-06    NJ  23557.0         21160.0             36   
NJ     2021-03-05    NJ  23521.0         21124.0             30   
NJ     2021-03-04    NJ  23491.0         21094.0             42   
NJ     2021-03-03    NJ  23449.0         21052.0            128   
...           ...   ...      ...             ...            ...   
NJ     2020-02-14    NJ      0.0             NaN              0   
NJ     2020-02-13    NJ      0.0             NaN              0   
NJ     2020-02-12    NJ      0.0             NaN              0   
NJ     2020-02-11    NJ      0.0             NaN              0   
NJ     2020-02-10    NJ      0.0             NaN              0   

       deathProbable  hospitalized  hospitalizedCumulative  \
state                                                        
NJ  

In [12]:
# Window length (in rows of `rem`; daily rows here) passed to the helper.
window = 14
infectious, infectious_rate = sir.compute_number_infectious(rem, window)

# Trim the removed series to the length of the infectious series so the
# compartments line up element-wise.
# NOTE(review): this rebinds `rem`; if the helper returns fewer points than
# it receives, re-running this cell alone shortens `rem` again. Re-run the
# data-extraction cell first.
rem = rem[:len(infectious)]

# Scalar population for the selected state; `.item()` raises ValueError
# unless exactly one row matches, instead of carrying a 1-element array.
population = mandate.loc[mandate['state'] == state_abbr, 'Population'].to_numpy().item()
sus = population - infectious - rem
print(sus)

298
[8874465.14285714 8874452.71428571 8874421.5        8874377.71428571
 8874302.92857143 8874242.64285714 8874184.21428571 8873999.14285714
 8873836.71428571 8873653.         8873472.85714286 8873166.42857143
 8872936.64285714 8872532.         8872003.14285714 8871558.21428571
 8870850.35714286 8870063.78571429 8868944.14285714 8867934.92857143
 8867044.07142857 8864490.35714286 8862426.21428571 8860036.92857143
 8857672.28571429 8854459.14285714 8852243.42857143 8848781.07142857
 8845378.07142857 8841155.57142857 8837017.71428571 8833601.
 8830021.71428571 8826677.5        8823623.85714286 8819995.71428571
 8816531.42857143 8812985.28571429 8809299.42857143 8806607.14285714
 8802479.42857143 8800259.         8796107.64285714 8793001.35714286
 8790078.78571429 8786258.78571429 8782895.21428571 8779404.
 8776081.78571429 8772127.71428571 8769948.42857143 8766742.57142857
 8763375.28571429 8761277.92857143 8758743.5        8756449.14285714
 8754150.         8751707.5        8749276.428

In [29]:
# Integer time axis 0 .. len(rem) - 1, one step per retained row.
times = np.arange(len(rem))
# Fit the SIR model over moving windows of the same width used above.
sir_fitting = sir.moving_averages_fits(
    times,
    sus,
    infectious,
    rem,
    window=window,
    a_guess=0.01,
    b_guess=0.01,
)

In [30]:
# Show the fitted parameter pairs (print renders the array's str() form).
print(str(sir_fitting))

[[5.17138174e-08 2.28569531e-05]
 [8.52956691e-09 4.43951957e-05]
 [4.33176677e-09 4.45547675e-05]
 [2.84706462e-09 3.94483883e-05]
 [1.78459632e-09 1.24506073e-04]
 [1.47098808e-09 1.04603826e-04]
 [1.22750098e-09 9.56700647e-05]
 [1.03941573e-09 9.75306559e-05]
 [8.90273804e-10 1.04985828e-04]
 [7.61887204e-10 1.33050134e-04]
 [6.68316512e-10 1.58205283e-04]
 [5.80179063e-10 1.98060942e-04]
 [5.10771193e-10 2.33755689e-04]
 [4.47064160e-10 2.92140340e-04]
 [3.91093541e-10 3.90531879e-04]
 [3.39142975e-10 5.35736673e-04]
 [2.99568055e-10 6.70752173e-04]
 [2.73556207e-10 8.15909892e-04]
 [2.40835075e-10 1.13026495e-03]
 [2.17084609e-10 1.49016842e-03]
 [1.95404416e-10 1.77448154e-03]
 [1.83333741e-10 2.02870224e-03]
 [1.72861761e-10 2.25084975e-03]]
