In [None]:
import fbprophet as pr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import folium
from sklearn.metrics import mean_squared_error
from statsmodels.tsa.statespace.sarimax import SARIMAX


In [None]:
def get_covid_data(data_url='https://raw.githubusercontent.com/OxCGRT/covid-policy-tracker/master/data/OxCGRT_latest.csv'):
    """Load the OxCGRT policy-tracker CSV and add daily NewCases / NewDeaths columns.

    Parameters
    ----------
    data_url : str or file-like, optional
        Anything ``pd.read_csv`` accepts. Defaults to the latest OxCGRT snapshot
        on GitHub (network access required). Malformed lines are skipped.

    Returns
    -------
    pandas.DataFrame
        The raw table plus two float columns:

        * ``NewCases``  -- day-over-day change in ``ConfirmedCases``
        * ``NewDeaths`` -- day-over-day change in ``ConfirmedDeaths``

        They are computed separately for each ``STATE_TOTAL`` region (keyed by
        country *and* region name) and each ``NAT_TOTAL`` country, in date order.
        A day whose change cannot be computed (first day of a series, or next to
        a missing cumulative value) gets 0. Rows of any other jurisdiction, or
        with a missing key, keep NaN.
    """
    read_kwargs = dict(parse_dates=['Date'],
                       encoding="ISO-8859-1",
                       dtype={"RegionName": str})
    try:
        full_df = pd.read_csv(data_url, on_bad_lines='skip', **read_kwargs)
    except TypeError:
        # pandas < 1.3 has no `on_bad_lines`; its predecessor `error_bad_lines`
        # was removed in pandas 2.0, so it is only used as a fallback.
        full_df = pd.read_csv(data_url, error_bad_lines=False, **read_kwargs)

    full_df['NewCases'] = np.nan
    full_df['NewDeaths'] = np.nan

    # Region names can repeat across countries, so state series are keyed by both.
    _fill_daily_changes(full_df, 'STATE_TOTAL', ['CountryName', 'RegionName'])
    _fill_daily_changes(full_df, 'NAT_TOTAL', ['CountryName'])

    return full_df


def _fill_daily_changes(full_df, jurisdiction, keys):
    """Write per-series day-over-day diffs of the cumulative counts into full_df (in place)."""
    rows = full_df[(full_df['Jurisdiction'] == jurisdiction)
                   & full_df[keys].notna().all(axis=1)]
    if rows.empty:
        return
    # Diff in chronological order regardless of file order; results are written
    # back by index label, so full_df's own row order is untouched.
    rows = rows.sort_values('Date', kind='mergesort')
    deltas = (rows.groupby(keys, sort=False)[['ConfirmedCases', 'ConfirmedDeaths']]
                  .diff()
                  .fillna(0))
    full_df.loc[deltas.index, 'NewCases'] = deltas['ConfirmedCases']
    full_df.loc[deltas.index, 'NewDeaths'] = deltas['ConfirmedDeaths']

def split_data(full_df):
    """Partition the OxCGRT frame by jurisdiction level.

    Returns a ``(state_df, country_df)`` pair: the ``STATE_TOTAL`` rows and the
    ``NAT_TOTAL`` rows of ``full_df``, each a boolean-filtered subset that keeps
    the original index labels. Rows with any other ``Jurisdiction`` value appear
    in neither.
    """
    level = full_df['Jurisdiction']
    return full_df[level == 'STATE_TOTAL'], full_df[level == 'NAT_TOTAL']

In [None]:
# Download the latest OxCGRT snapshot (network access required); the cells below select from this frame.
full_df = get_covid_data()

In [None]:
division = 'NAT_TOTAL'
region = 'United States'

# Keep one national series, ordered by date, so that dropping the final row
# removes the most recent day regardless of file order (presumably that day is
# dropped because its counts are still incomplete -- TODO confirm).
selection = (full_df['Jurisdiction'] == division) & (full_df['CountryName'] == region)
df = (full_df.loc[selection, ['Date', 'NewDeaths', 'NewCases']]
      .sort_values('Date')
      .iloc[:-1])
df.tail()

In [None]:
# Everything before SPLIT_DATE trains the models; everything on/after it is held out.
# Columns are renamed to `ds` (timestamp) and `y` (target), the names Prophet's
# `fit` requires.
SPLIT_DATE = '2020-12-01'
PROPHET_COLUMNS = {'Date': 'ds', 'NewCases': 'y'}

train_df = df.loc[df['Date'] < SPLIT_DATE, ['Date', 'NewCases']].rename(columns=PROPHET_COLUMNS)
test_df = df.loc[df['Date'] >= SPLIT_DATE, ['Date', 'NewCases']].rename(columns=PROPHET_COLUMNS)

assert not train_df.empty, f'no training rows before {SPLIT_DATE}'
assert not test_df.empty, f'no held-out rows on/after {SPLIT_DATE}'

In [None]:
def mean_percent_error(y_test, y_hat):
    """Mean absolute percentage error of a forecast, as a fraction (0.1 == 10%).

    Computes ``mean(|y_test - y_hat| / |y_test|)``.

    Parameters
    ----------
    y_test : array-like
        Observed values.
    y_hat : array-like
        Predicted values, same length and order as ``y_test``. Inputs are
        compared positionally; any pandas index labels are ignored.

    Returns
    -------
    float
        The mean relative absolute error. A zero in ``y_test`` makes the result
        ``inf`` (or ``nan`` if the prediction there is also exactly zero), and
        numpy emits a RuntimeWarning.

    Raises
    ------
    ValueError
        If the inputs are empty or their shapes differ.
    """
    actual = np.asarray(y_test, dtype=float)
    predicted = np.asarray(y_hat, dtype=float)
    # Guard against silent broadcasting (e.g. (n,) vs (n, 1)) and mismatched lengths.
    if actual.shape != predicted.shape:
        raise ValueError(f'shape mismatch: y_test {actual.shape} vs y_hat {predicted.shape}')
    if actual.size == 0:
        raise ValueError('mean_percent_error needs at least one observation')
    # |actual| in the denominator keeps negative observations (e.g. downward data
    # revisions) from producing negative "errors" that cancel out real ones.
    relative_error = np.abs(actual - predicted) / np.abs(actual)
    return float(relative_error.mean())

In [None]:
# Prophet with multiplicative seasonality, plus US public holidays as extra effects.
m = pr.Prophet(seasonality_mode='multiplicative')
m.add_country_holidays(country_name='US')
m.fit(train_df)

# Extend the training dates by len(test_df) periods and forecast over the whole span.
future = m.make_future_dataframe(periods=len(test_df))
forecast = m.predict(future)


In [None]:
# SARIMAX(1,1,4) with a seasonal AR(4) term of period 7 (weekly, given freq='D'),
# fitted only on training days after 2020-03-31 (presumably to skip the sparse
# early-pandemic period -- TODO confirm).
sarimax_train = train_df.loc[train_df['ds'] > '2020-03-31'].set_index('ds')
model2 = SARIMAX(
    sarimax_train,
    order=(1, 1, 4),
    seasonal_order=(4, 0, 0, 7),
    freq='D',
)
fit_model2 = model2.fit(maxiter=200, disp=False)
yhat = fit_model2.forecast(len(test_df))


In [None]:
# Score both models on the held-out days. The Prophet forecast is looked up by
# date rather than sliced positionally, so a gap or reordering in test_df shows
# up as NaN instead of silently comparing the wrong days.
prophet_pred = forecast.set_index('ds')['yhat'].reindex(test_df['ds'])
prophet_mpse = mean_percent_error(test_df['y'].values, prophet_pred.values)
print(f'Prophet MPSE = {prophet_mpse}')
sarimax_mpse = mean_percent_error(test_df['y'].values, yhat)
print(f'SARIMAX MPSE = {sarimax_mpse}')

In [None]:
fig, ax = plt.subplots(figsize=(20, 10))
m.plot(forecast, ax=ax)
# Overlay the held-out observations so forecast and reality can be compared.
ax.scatter(test_df['ds'], test_df['y'], marker='.', c='red', label='held-out actuals')
ax.set(title=f'Prophet forecast of daily new cases vs. held-out data ({region})',
       xlabel='Date', ylabel='New cases')
ax.legend();

In [None]:
import plotly
import plotly.express as px

In [None]:
from fbprophet.plot import plot_plotly, plot_components_plotly

# Only a cell's last expression is displayed, so the held-out points are added
# to the Prophet figure itself and that figure is what the cell returns.
fig_plotly = plot_plotly(m, forecast, changepoints=True)
fig_plotly.add_scatter(x=test_df['ds'], y=test_df['y'], mode='markers',
                       marker=dict(color='red'), name='held-out actuals')
fig_plotly

In [None]:
# Interactive plotly figure of the fitted model's components (per Prophet docs: trend, holidays, seasonalities).
plot_components_plotly(m, forecast)