# Importing Libraries

In [827]:
import sys
import pandas as pd
import numpy as np
import math
import plotly.graph_objects as go
import plotly
from pmdarima import auto_arima
from sklearn.model_selection import train_test_split
from statsmodels.tsa.statespace.sarimax import SARIMAX

# Ignoring warnings
import warnings
warnings.filterwarnings("ignore")

# To import the main.py file
sys.path.append('../')
from python_files import main

In [828]:
%matplotlib inline
plotly.offline.init_notebook_mode(connected=True)

# Data Preprocessing Functions

In [829]:
def get_data():
    '''
Retrieve shaped data from the John Hopkins University COVID-19 database
    '''
    confirmed_global, deaths_global, recovered_global, country_cases = main.collect_data()
    
    recovered = recovered_global.groupby("country").sum().T
    deaths = deaths_global.groupby("country").sum().T
    confirmed = confirmed_global.groupby("country").sum().T
    
    deaths.index = pd.to_datetime(deaths.index, infer_datetime_format = True)
    recovered.index = pd.to_datetime(recovered.index, infer_datetime_format = True)
    confirmed.index = pd.to_datetime(confirmed.index, infer_datetime_format = True)
    
    return deaths, recovered, confirmed

In [874]:
def create_data_frame(df_name="confirmed",country="India"):
    '''
Create shaped DataFrames
df_name: Choice of prediction: "recovered", "confirmed", "deaths"
country: Name of the country you wish to predict for
The return value is a dataframe with columns 'Date','Total' containing Dates and Values respectively,
with a DatetimeIndex
    '''
    deaths, recovered, confirmed = get_data()
    if df_name == 'deaths':
        data = pd.DataFrame(index = deaths.index, data = deaths[country].values, columns = ["Total"])

    elif df_name == 'recovered':
        data = pd.DataFrame(index = recovered.index, data = recovered[country].values, columns = ["Total"])

    elif df_name == 'confirmed':
        data = pd.DataFrame(index = confirmed.index, data = confirmed[country].values, columns = ["Total"])

    data = data[(data != 0).all(1)]
    data['Date'] = data.index
    cols = [data.columns[-1]] + [col for col in data if col != data.columns[-1]]
    data = data[cols]   

    return data

# Graphing Functions

In [831]:
def plot_forecast(data,forecast):
        ''' 
Plot the forecasted values against the data at hand
data: Dataframe with columns 'Date','Total' containing Dates and Values of the data at hand respectively
forecast: Series with a Pandas DatetimeIndex of the dates on which prediction occurs, 
          containing the predicted values
The return value is a plotly.graph_objs._figure.Figure which showcases predictions 
    '''
    fig = go.Figure()

    fig.add_trace(go.Scatter(x=data["Date"], y=data["Total"],   
                            mode='lines',
                            name='Up till now '))
        
    fig.add_trace(go.Scatter(x=forecast.index, y=forecast.values,   
                            mode='lines',
                            name='Prediction*'))

    fig.update_layout(title={
                'text': "Forecasted results",
                'y':0.9,
                'x':0.5,
                'xanchor': 'center',
                'yanchor': 'top'},
                        template = "plotly_dark",
                        xaxis_title="Date",
                        yaxis_title="Cases",
                        legend_title="Legend ",
                        font=dict(
                                family="Arial",
                                size=15,
                                color="white"
                                )
                        )
    return fig

# Functions to Train and Test the Model

In [901]:
def find_params(train_set):
    '''
train_set: Training Dataset
Takes in Training dataset and creates a SARIMAX model with self determined best parameters
Return value is a pmdarima.arima.arima.ARIMA model
    '''
    stepwise_model = auto_arima(train_set, method='nm', start_p = 0, start_q = 0,
                               max_p = 2, max_q = 2, m = 7,
                               start_P = 0, max_P=1, start_Q=0, max_Q=1, seasonal = True,
                               d = None, D = 1, n_jobs=-1, trace = True,
                               error_action = 'ignore',  
                               suppress_warnings = True, 
                               stepwise = True)
    return stepwise_model

In [846]:
def Predict(stepwise_model,train,test):
    '''
stepwise_model: SARIMA model that has been fit with the training dataset
train: Training Dataset
test: Test Dataset
Returns a DataFrame with predicted values that can be compared against the test values
    '''
    stepwise_model.fit(train)
    
    pred = stepwise_model.predict(n_periods=len(test))
    
    pred = pd.DataFrame(pred,index = test.index,columns=['Prediction'])
   
    return pred

# Error Function

In [847]:
def mape(y_true, y_pred): 
    '''Return the Mean Absolute Percentage Error between any 2 Series'''
    y_true, y_pred = np.array(y_true), np.array(y_pred)
    return np.mean(np.abs((y_true - y_pred) / y_true)) * 100

# Function for Forecasting

In [848]:
def Future(order,seasonal_order,train,test,data):
    '''
order: (p,d,q) order of the SARIMA model
seasonal_order: (P,D,Q) the seasonal order of the SARIMA model
train: Training Dataset
test: Test Dataset
data: Original Dataset
    '''
    model = SARIMAX(data['Total'],  
                        order = order,  
                        seasonal_order = seasonal_order) 
    result = model.fit() 
  
    forecast = result.predict(start = len(data),  
                          end = (len(data)-1) + 14).rename('Forecast') 
    
 
    error_check = result.predict(start = len(train), end = len(train) - 1 +len(test))
    error = mape(error_check,test)
    error = error
    graph = plot_forecast(data,forecast)

    return forecast,graph,error

# Calling Function

In [896]:
def arima_predict(df_name,country):
    '''
Predict for any Country
df_name: Choice of prediction: "recovered", "confirmed", "deaths"
country: Name of the country you wish to predict for
The return values are: 
1)Pandas Series with Forecasted values
2)A plotly.graph_objs._figure.Figure which showcases predictions 
3)Mean Absolute Error Percentage of the prediction 
    '''
    data = create_data_frame(df_name,country)
    train = data["Total"][:len(data)*4//5]
    test = data["Total"][len(data)*4//5:]
    
    model = find_params(train)
    pred = Predict(model,train,test)
    mape_error = mape(test, pred["Prediction"])
    order=model.get_params()['order']
    seasonal_order=model.get_params()['seasonal_order']
    
    forecast,graph,error = Future(order, seasonal_order, train, test,data)
    return forecast,graph,(error + np.std([error, mape_error]))

# Example

In [902]:
forecast,graph,error = arima_predict("deaths","Australia")

Performing stepwise search to minimize aic
Fit ARIMA(0,2,0)x(0,1,0,7) [intercept=True]; AIC=745.303, BIC=751.377, Time=0.046 seconds
Fit ARIMA(1,2,0)x(1,1,0,7) [intercept=True]; AIC=690.601, BIC=702.749, Time=0.494 seconds
Fit ARIMA(0,2,1)x(0,1,1,7) [intercept=True]; AIC=622.182, BIC=634.330, Time=0.445 seconds
Near non-invertible roots for order (0, 2, 1)(0, 1, 1, 7); setting score to inf (at least one inverse root too close to the border of the unit circle: 0.993)
Fit ARIMA(0,2,0)x(0,1,0,7) [intercept=False]; AIC=743.347, BIC=746.384, Time=0.054 seconds
Fit ARIMA(1,2,0)x(0,1,0,7) [intercept=True]; AIC=723.576, BIC=732.687, Time=0.119 seconds
Fit ARIMA(1,2,0)x(1,1,1,7) [intercept=True]; AIC=638.885, BIC=654.069, Time=0.540 seconds
Fit ARIMA(1,2,0)x(0,1,1,7) [intercept=True]; AIC=637.899, BIC=650.047, Time=0.440 seconds
Near non-invertible roots for order (1, 2, 0)(0, 1, 1, 7); setting score to inf (at least one inverse root too close to the border of the unit circle: 0.992)
Fit ARIMA(

In [903]:
print("PREDICTIONS FOR THE NEXT 14 DAYS : \n")
for i in range(len(forecast)):
    print(forecast.index[i] , "---->", int(forecast.values[i]))

PREDICTIONS FOR THE NEXT 14 DAYS : 

2020-09-21 00:00:00 ----> 855
2020-09-22 00:00:00 ----> 861
2020-09-23 00:00:00 ----> 867
2020-09-24 00:00:00 ----> 874
2020-09-25 00:00:00 ----> 879
2020-09-26 00:00:00 ----> 884
2020-09-27 00:00:00 ----> 891
2020-09-28 00:00:00 ----> 896
2020-09-29 00:00:00 ----> 902
2020-09-30 00:00:00 ----> 908
2020-10-01 00:00:00 ----> 915
2020-10-02 00:00:00 ----> 920
2020-10-03 00:00:00 ----> 926
2020-10-04 00:00:00 ----> 933


In [904]:
graph.show()

In [905]:
print("ALLOW AN ERROR OF UPTO (Based on the data at hand) :",error,"%")

ALLOW AN ERROR OF UPTO (Based on the data at hand) : 12.129355914731594 %
