# Importing Libraries

In [91]:
import sys
import pandas as pd
import numpy as np
import plotly
import math
import plotly.express as px
import plotly.graph_objects as go
import plotly.offline as pyo
from pmdarima import auto_arima
from sklearn.model_selection import train_test_split
from statsmodels.tsa.statespace.sarimax import SARIMAX

# To import the main.py file
sys.path.append('../')
from python_files import main

In [92]:
# Setting up plotly to work offline and in jupyter notebooks
pyo.init_notebook_mode(connected = True)
%matplotlib inline

# Data Preprocessing Functions

In [93]:
def get_data():
    confirmed_global, deaths_global, recovered_global, country_cases = main.collect_data()
    recovered = recovered_global.groupby("country").sum().T
    deaths = deaths_global.groupby("country").sum().T
    confirmed = confirmed_global.groupby("country").sum().T
    deaths.index = pd.to_datetime(deaths.index, infer_datetime_format = True)
    recovered.index = pd.to_datetime(recovered.index, infer_datetime_format = True)
    confirmed.index = pd.to_datetime(confirmed.index, infer_datetime_format = True)
    return deaths, recovered, confirmed

In [94]:
def create_data_frame(dataframe,country):

    if dataframe == 'deaths':
        data = pd.DataFrame(index = deaths.index, data = deaths[country].values, columns = ["Total"])

    elif dataframe == 'recovered':
        data = pd.DataFrame(index = recovered.index, data = recovered[country].values, columns = ["Total"])

    elif dataframe == 'confirmed':
        data = pd.DataFrame(index = confirmed.index, data = confirmed[country].values, columns = ["Total"])

    data = data[(data != 0).all(1)]
    
    data['Date'] = data.index
    cols = [data.columns[-1]] + [col for col in data if col != data.columns[-1]]
    data = data[cols]   

    return data

# Graphing Functions

In [95]:
def plot_forecast(data,forecast):
    fig = go.Figure()

    fig.add_trace(go.Scatter(x=data["Date"], y=data["Total"],   
                            mode='lines',
                            name='Up till now '))
        
    fig.add_trace(go.Scatter(x=forecast.index, y=forecast.values,   
                            mode='lines',
                            name='prediction'))

    fig.update_layout(title={
                'text': "Forecasted results",
                'y':0.9,
                'x':0.5,
                'xanchor': 'center',
                'yanchor': 'top'},
                        template = "plotly_dark",
                        xaxis_title="Date",
                        yaxis_title="Cases",
                        legend_title="Legend ",
                        font=dict(
                                family="Arial",
                                size=15,
                                color="white"
                                )
                        )
    return fig

# Functions to Train the Model

In [96]:
def find_params(train_set):
    stepwise_model = auto_arima(train_set,method='nm',start_p = 0, start_q = 0,
                               max_p = 2, max_q = 2, m = 7,
                               start_P = 0,max_P=0,start_Q=1,max_Q=1,seasonal = False,
                               d = None, D = 1, n_jobs=-1,trace = True,
                               error_action = 'ignore',  
                               suppress_warnings = True, 
                               stepwise = True)
    return stepwise_model

In [97]:
def Predict(stepwise_model,train,test):
    
    stepwise_model.fit(train)
    
    pred = stepwise_model.predict(n_periods=len(test))
    
    pred = pd.DataFrame(pred,index = test.index,columns=['Prediction'])
   
    return pred

In [98]:
def pred_to_int(pred):
    L=[]
    for i in  (pred['Prediction'].to_frame())['Prediction']:
        if i%1>=0.5:
            L.append((math.ceil(i)))
        else:
            L.append((math.floor(i)))
    pred['Prediction']=L
    return pred

# Error Functions

In [99]:
def mape(y_true, y_pred): 
    y_true, y_pred = np.array(y_true), np.array(y_pred)
    return np.mean(np.abs((y_true - y_pred) / y_true)) * 100

# Functions for Forecasting

In [100]:
def Future(df,order,seasonal_order,train,test,data):
    
    model = SARIMAX(df['Total'],  
                        order = order,  
                        seasonal_order = seasonal_order) 
    result = model.fit() 
  
    forecast = result.predict(start = len(df),  
                          end = (len(df)-1) + 14).rename('Forecast') 
    
 
    error_check = result.predict(start = len(train), end = len(train) - 1 +len(test))
    error = mape(error_check,test)
    
 
    graph = plot_forecast(data,forecast)

    return forecast,graph,error

# Calling Function

In [101]:
def arima_predict(df,country):
    deaths, recovered, confirmed = get_data()
    data = create_data_frame(df,country)
    train = data["Total"][:len(data)//2]
    test = data["Total"][len(data)//2:]
    model = find_params(train)
    order=model.get_params()['order']
    seasonal_order=model.get_params()['seasonal_order']
    pred=Predict(model,train,test)
    forecast,graph,error = Future(data,order, seasonal_order, train, test,data)
    return forecast,graph,error



# Example

In [102]:
forecast,graph,error = arima_predict("deaths","Russia")

Performing stepwise search to minimize aic
Fit ARIMA(0,2,0)x(0,0,0,0) [intercept=True]; AIC=786.618, BIC=791.504, Time=0.032 seconds
Fit ARIMA(1,2,0)x(0,0,0,0) [intercept=True]; AIC=781.760, BIC=789.088, Time=0.057 seconds
Fit ARIMA(0,2,1)x(0,0,0,0) [intercept=True]; AIC=765.079, BIC=772.407, Time=0.059 seconds
Fit ARIMA(0,2,0)x(0,0,0,0) [intercept=False]; AIC=784.880, BIC=787.323, Time=0.012 seconds
Fit ARIMA(1,2,1)x(0,0,0,0) [intercept=True]; AIC=785.516, BIC=795.287, Time=0.059 seconds
Fit ARIMA(0,2,2)x(0,0,0,0) [intercept=True]; AIC=765.837, BIC=775.608, Time=0.078 seconds
Fit ARIMA(1,2,2)x(0,0,0,0) [intercept=True]; AIC=789.224, BIC=801.437, Time=0.065 seconds
Total fit time: 0.368 seconds

No frequency information was provided, so inferred frequency D will be used.


No frequency information was provided, so inferred frequency D will be used.



In [103]:
print("PREDICTIONS : \n")
for i in range(len(forecast)):
    print(forecast.index[i] , "---->", int(forecast.values[i]))
  

PREDICTIONS : 

2020-09-10 00:00:00 ----> 18186
2020-09-11 00:00:00 ----> 18292
2020-09-12 00:00:00 ----> 18398
2020-09-13 00:00:00 ----> 18504
2020-09-14 00:00:00 ----> 18610
2020-09-15 00:00:00 ----> 18716
2020-09-16 00:00:00 ----> 18822
2020-09-17 00:00:00 ----> 18928
2020-09-18 00:00:00 ----> 19034
2020-09-19 00:00:00 ----> 19140
2020-09-20 00:00:00 ----> 19246
2020-09-21 00:00:00 ----> 19352
2020-09-22 00:00:00 ----> 19458
2020-09-23 00:00:00 ----> 19564


In [104]:
graph.show()

In [105]:
print("MAPE :",error)

MAPE : 0.23621619105897457
