# Facebook Prophet Model

# Imports

In [22]:
import sys
import pandas as pd
import numpy as np
import math
import plotly.graph_objects as go
import plotly

# Ignoring warnings
import warnings
warnings.filterwarnings("ignore")

# Put the parent directory on sys.path so python_files/main.py can be imported
sys.path.append('../')
from python_files import main
from fbprophet import Prophet
from fbprophet.diagnostics import cross_validation

# Data Preprocessing Functions

In [23]:
def get_data():
    """Load the global tables via ``main.collect_data()`` and pivot them per country.

    Each input table is grouped by its ``country`` column and summed, then
    transposed so that the former column labels become the row index (parsed
    with ``pd.to_datetime``) and each country becomes a column.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(deaths, recovered, confirmed)``, in that order.
    """
    confirmed_global, deaths_global, recovered_global, _country_cases = main.collect_data()

    def _per_country(frame):
        # Sum all rows belonging to the same country, then flip so the
        # date-like column labels become the index.
        wide = frame.groupby("country").sum().T
        # NOTE(review): `infer_datetime_format` is deprecated since pandas 2.0
        # (strict format inference is the default there); drop it once the
        # pinned pandas version allows.
        wide.index = pd.to_datetime(wide.index, infer_datetime_format=True)
        return wide

    return (
        _per_country(deaths_global),
        _per_country(recovered_global),
        _per_country(confirmed_global),
    )

In [24]:
def create_data_frame(dataframe, country):
    """Return one country's series from ``get_data()`` as a Prophet-ready frame.

    Parameters
    ----------
    dataframe : str
        Which table to use: ``'deaths'``, ``'recovered'`` or ``'confirmed'``.
    country : str
        Column label of the country in that table.

    Returns
    -------
    pandas.DataFrame
        Columns ``ds`` (the dates) and ``y`` (the values), indexed by the same
        dates, with rows whose value is exactly zero removed.

    Raises
    ------
    ValueError
        If *dataframe* is not one of the supported table names.
    KeyError
        If *country* is not a column of the selected table.
    """
    # Position of each table in the (deaths, recovered, confirmed) tuple
    # returned by get_data().
    table_positions = {'deaths': 0, 'recovered': 1, 'confirmed': 2}
    # Validate up front: an unknown name used to fall through every branch and
    # crash later with UnboundLocalError, after paying for a full data fetch.
    if dataframe not in table_positions:
        raise ValueError(
            f"dataframe must be one of {sorted(table_positions)}, got {dataframe!r}"
        )
    table = get_data()[table_positions[dataframe]]

    totals = table[country]
    # Drop zero rows; NaN != 0 is True, so missing values are kept.
    totals = totals[totals != 0]
    # Prophet expects the date column to be named `ds` and the target `y`.
    return pd.DataFrame({'ds': totals.index, 'y': totals.values}, index=totals.index)

# Model Functions

In [25]:
def build_model():
    """Create an unfitted Prophet model with explicitly configured seasonalities.

    Prophet's built-in daily/weekly/yearly seasonalities are disabled and
    replaced by custom Fourier seasonalities named ``monthly`` (30.5-day
    period), ``weekly`` (7-day) and ``daily`` (1-day).

    Returns
    -------
    Prophet
        The configured model; it must be fitted before predicting.
    """
    model = Prophet(
        growth="linear",
        seasonality_mode="additive",
        changepoint_prior_scale=30,
        seasonality_prior_scale=35,
        daily_seasonality=False,
        weekly_seasonality=False,
        yearly_seasonality=False,
    )
    # (name, period in days, fourier_order), added in this order.
    # NOTE(review): the forecasts in this file step in 'D' frequency; if the
    # training data is also one row per day, a 1-day-period seasonality cannot
    # be learned from it, so the "daily" term may be inert — confirm.
    seasonal_terms = (
        ("monthly", 30.5, 55),
        ("weekly", 7, 15),
        ("daily", 1, 15),
    )
    for term_name, term_period, term_order in seasonal_terms:
        model.add_seasonality(name=term_name, period=term_period, fourier_order=term_order)
    return model

In [26]:
def predict_future(prophet, data, *, periods=14, freq='D'):
    """Fit *prophet* on *data* and return only the out-of-sample forecast.

    Parameters
    ----------
    prophet : Prophet
        An unfitted model; it is fitted here.
    data : pandas.DataFrame
        Training history with ``ds`` and ``y`` columns.
    periods : int, optional
        Number of future steps to forecast (default 14).
    freq : str, optional
        Step frequency passed to ``make_future_dataframe`` and set on the
        returned index (default ``'D'``).

    Returns
    -------
    pandas.Series
        Predicted ``yhat`` for the *periods* future dates, named ``'Total'``
        and indexed by those dates.
    """
    prophet.fit(data)
    future = prophet.make_future_dataframe(freq=freq, periods=periods)
    forecast = prophet.predict(future)
    # make_future_dataframe appends exactly `periods` new dates after the
    # fitted history, so the tail is the out-of-sample part. Slicing at
    # len(data) breaks whenever Prophet's history length differs from `data`
    # (Prophet drops rows with missing `y` and de-duplicates dates).
    future_rows = forecast.tail(periods)
    # A Series has no `.columns`; naming it is what labels the values.
    ftr = pd.Series(data=future_rows['yhat'].values, index=future_rows['ds'], name='Total')
    ftr.index.freq = freq
    return ftr

# Graphing Functions

In [27]:
def plot_forecast(data, forecast):
    """Draw the observed history and the forecast as two lines on one figure.

    Parameters
    ----------
    data : pandas.DataFrame
        History with ``ds`` (x values) and ``y`` (y values) columns.
    forecast : pandas.Series
        Predicted values; its index supplies the x values.

    Returns
    -------
    plotly.graph_objects.Figure
        Dark-themed figure with the history trace first, then the forecast.
    """
    line_specs = (
        (data["ds"], data["y"], 'Up till now '),
        (forecast.index, forecast.values, 'Prediction*'),
    )
    fig = go.Figure()
    for xs, ys, label in line_specs:
        fig.add_trace(go.Scatter(x=xs, y=ys, mode='lines', name=label))

    title_spec = dict(
        text="Forecasted results",
        y=0.9,
        x=0.5,
        xanchor='center',
        yanchor='top',
    )
    fig.update_layout(
        title=title_spec,
        template="plotly_dark",
        xaxis_title="Date",
        yaxis_title="Cases",
        legend_title="Legend ",
        font=dict(family="Arial", size=15, color="white"),
    )
    return fig

# Error Function

In [28]:
def mape(y_true, y_pred):
    """Return the mean absolute percentage error of *y_pred* against *y_true*.

    Computed as ``mean(|(y_true - y_pred) / y_true|) * 100``.

    Parameters
    ----------
    y_true, y_pred : array-like
        Actual and predicted values; must be broadcast-compatible.

    Returns
    -------
    numpy.float64
        The error as a percentage.

    Raises
    ------
    ValueError
        If the inputs are empty or *y_true* contains a zero. In those cases the
        percentage error is undefined. NumPy would otherwise return nan/inf, and
        this file filters warnings globally, so that would happen silently.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    if y_true.size == 0:
        raise ValueError("mape() requires at least one observation")
    if np.any(y_true == 0):
        raise ValueError("mape() is undefined when y_true contains zeros")
    return np.mean(np.abs((y_true - y_pred) / y_true)) * 100

# Main Function

In [29]:
def prophet_predict(df_name, country):
    """Forecast one country's series with Prophet and return the plot.

    Parameters
    ----------
    df_name : str
        Table name passed to ``create_data_frame`` (e.g. ``'confirmed'``).
    country : str
        Country column to forecast.

    Returns
    -------
    plotly.graph_objects.Figure
        History and forecast plotted together by ``plot_forecast``.
    """
    history = create_data_frame(df_name, country)
    forecast = predict_future(build_model(), history)
    return plot_forecast(history, forecast)

# Example

In [32]:
# Example: forecast India's confirmed cases. As the cell's last expression, the
# returned figure is rendered by Jupyter.
prophet_predict(df_name="confirmed", country="India")

# Cross Validation

In [37]:
# prophet_predict() keeps its model and training frame local, so `prophet` and
# `data` do not exist at module level. Build and fit them explicitly here,
# because cross_validation() needs a fitted model.
data = create_data_frame("confirmed", "India")
prophet = build_model()
prophet.fit(data)
# Train on roughly the first 80% of the history (assumes one row per day).
# A whole number of days keeps the Timedelta string simple, e.g. "186 days"
# rather than "186.40000000000003 days".
initial_days = int(0.8 * len(data))
df_cv = cross_validation(prophet, initial=f'{initial_days} days', period='7 days', horizon='1 days')
# Use a distinct name: `mape = mape(...)` would shadow the mape() function and
# break any later call, including re-running this cell.
cv_error = mape(df_cv.y, df_cv.yhat)
print('ALLOW UPTO AN ERROR OF', cv_error)

INFO:fbprophet:Making 7 forecasts with cutoffs between 2020-08-07 00:00:00 and 2020-09-18 00:00:00


ALLOW UPTO AN ERROR OF 0.6939822703744413
