In [1]:

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

import os


In [2]:
import matplotlib.pyplot as plt
import seaborn as sns
# Apply seaborn's default theme to all matplotlib figures created afterwards.
sns.set()

from sklearn.metrics import r2_score, median_absolute_error, mean_absolute_error
from sklearn.metrics import median_absolute_error, mean_squared_error, mean_squared_log_error

from scipy.optimize import minimize
import statsmodels.tsa.api as smt
import statsmodels.api as sm

from tqdm import tqdm_notebook

from itertools import product

def mean_absolute_percentage_error(y_true, y_pred):
    """
        Return the mean absolute percentage error (MAPE), in percent

        y_true - observed values (array-like)
        y_pred - predicted values (array-like, same length as y_true)

        Both inputs are converted to float arrays, so plain lists/tuples are
        accepted and values are compared by position. The error is relative
        to y_true, so a zero in y_true yields inf (or nan for 0/0).
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return np.mean(np.abs((y_true - y_pred) / y_true)) * 100

import warnings
# Silence every warning for the rest of the session.
# NOTE(review): this blanket filter also hides deprecation warnings from the
# libraries used below (e.g. pandas/statsmodels) — consider narrowing it.
warnings.filterwarnings('ignore')

In [3]:
# Load the raw table from the CSV file in the working directory.
data = pd.read_csv(filepath_or_buffer="state_wise_daily.csv")

In [4]:
# Keep only the rows whose Status column is "Confirmed".
confirm_data = data.loc[data["Status"] == "Confirmed"]

In [5]:
# Train many SARIMA models to find the best set of parameters
def optimize_SARIMA(series,parameters_list, d, D, s):
    """
        Return dataframe with parameters and corresponding AIC

        series - time series the candidate models are fitted to
        parameters_list - list with (p, q, P, Q) tuples
        d - integration order
        D - seasonal integration order
        s - length of season

        Every tuple in parameters_list is fitted as
        SARIMAX(order=(p, d, q), seasonal_order=(P, D, Q, s)).
        Combinations whose fit raises are skipped.

        The returned frame has the columns 'parameters' and 'aic' and is
        sorted by AIC ascending, so row 0 holds the best combination. If no
        combination could be fitted, the frame is empty but still has both
        columns.
    """
    results = []

    for param in tqdm_notebook(parameters_list):
        try:
            model = sm.tsa.statespace.SARIMAX(
                series,
                order=(param[0], d, param[1]),
                seasonal_order=(param[2], D, param[3], s),
            ).fit(disp=-1)
        except Exception:
            # Skip combinations that cannot be estimated. Catching Exception
            # (not a bare except) keeps Ctrl-C / kernel interrupt working
            # during a long grid search.
            continue

        results.append([param, model.aic])

    # Passing the column names to the constructor also works for an empty
    # result list; assigning .columns afterwards raises in that case.
    result_table = pd.DataFrame(results, columns=['parameters', 'aic'])

    # Sort in ascending order, lower AIC is better
    return result_table.sort_values(by='aic', ascending=True).reset_index(drop=True)


In [33]:
def predict_arima(series, start_date='2020-07-21', n_days=30, lookback=10):
    """
        Fit the best SARIMA model (by AIC) to series and return its predictions

        series - observed values, one per day, in chronological order
        start_date - calendar date assigned to the first returned row
        n_days - number of consecutive daily rows to return
        lookback - how many of the last observations the prediction window
                   starts before the end of series

        The model is chosen by optimize_SARIMA from all (p, q, P, Q)
        combinations in 0..2 with d = D = 1 and a season length of 7.

        Returns a dataframe with the columns 'Date' and 'Cases' and n_days
        rows. The first `lookback` rows are in-sample predictions for the
        last `lookback` observations, the remaining rows are out-of-sample
        forecasts. The defaults give 30 rows dated 2020-07-21 .. 2020-08-19.

        NOTE: the function cannot check that start_date really is the date
        of observation len(series) - lookback; the caller must keep the two
        in sync.

        Raises ValueError if no parameter combination could be fitted.
    """
    # Set initial values and some bounds
    ps = range(0, 3)
    d = 1
    qs = range(0, 3)
    Ps = range(0, 3)
    D = 1
    Qs = range(0, 3)
    s = 7

    # Create a list with all possible combinations of parameters
    parameters_list = list(product(ps, qs, Ps, Qs))
    result_table = optimize_SARIMA(series, parameters_list, d, D, s)
    if result_table.empty:
        raise ValueError("no SARIMA parameter combination could be fitted to the series")

    # Row 0 is the lowest-AIC combination; refit it to get a model object.
    p, q, P, Q = result_table.parameters[0]
    best_model = sm.tsa.statespace.SARIMAX(series, order=(p, d, q),
                                           seasonal_order=(P, D, Q, s)).fit(disp=-1)

    # Request exactly n_days steps so values and dates line up one-to-one.
    first_step = series.shape[0] - lookback
    cases = best_model.predict(start=first_step, end=first_step + n_days - 1)

    # Build the dates directly from the first day; this avoids date_range's
    # `closed=` argument, which newer pandas versions no longer accept.
    predictions = pd.DataFrame({'Date': pd.date_range(start=start_date, periods=n_days, freq='D')})
    predictions["Cases"] = np.asarray(cases)
    return predictions

In [34]:
# Display the most recently processed series.
# NOTE(review): in the visible cells `series` is only assigned inside the
# per-state forecasting loop further down, so this cell presumably fails on a
# fresh kernel (Restart & Run All) — confirm, then move or remove it.
series

0        0
1        0
2        0
3        0
4        0
5        0
6        0
7        0
8        0
9        0
10       0
11       0
12       0
13       0
14       0
15       0
16       0
17       0
18       0
19       0
20       0
21     307
22     217
23     279
24     439
25     304
26     328
27    1043
28     319
29     341
30     370
31     448
32    -138
33   -1295
34     183
35     126
36     352
37    -370
38    -100
39    -184
40    -554
41    -281
42       0
43       0
44       0
45       0
46       0
47       0
48       0
49       0
50       0
Name: UN, dtype: int64

In [35]:
# Column labels at positions 3..40 of confirm_data; these are the columns the
# forecasting loop iterates over.
colmns = confirm_data.columns[3:41]

In [36]:
# (rows, columns) of confirm_data before the date filter is applied.
confirm_data.shape

(51, 41)

In [37]:
# Keep only rows whose Date is "2020-06-01" or later.
# NOTE(review): no date parsing is visible on load, so this is presumably a
# string comparison, which is only chronological for ISO (YYYY-MM-DD) dates —
# confirm the format of the Date column.
confirm_data = confirm_data.loc[confirm_data["Date"] >= "2020-06-01"]

In [38]:
# (rows, columns) of confirm_data after the date filter.
confirm_data.shape

(51, 41)

In [39]:
# Forecast every state column and stack the per-state results into one frame.
cols = ["State", "Date", 'Cases']
state_frames = []

for state in colmns:
    # Drop the original row labels so the model sees a plain 0..n-1 index.
    series = confirm_data[state].reset_index(drop=True)
    state_prediction = predict_arima(series)
    state_prediction["State"] = state
    state_frames.append(state_prediction[cols])

# Concatenate once at the end: DataFrame.append is gone from current pandas
# and re-copied the whole frame on every iteration. Each state's rows keep
# their own 0-based index, as before.
if state_frames:
    final_predict = pd.concat(state_frames)
else:
    final_predict = pd.DataFrame(columns=cols)


HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=81.0), HTML(value='')))




In [40]:
# Display the stacked per-state predictions (State, Date, Cases).
final_predict

Unnamed: 0,State,Date,Cases
0,AN,2020-07-21,8.528091
1,AN,2020-07-22,11.886967
2,AN,2020-07-23,5.200128
3,AN,2020-07-24,19.801460
4,AN,2020-07-25,3.065276
...,...,...,...
25,UN,2020-08-15,29.082602
26,UN,2020-08-16,-85.057850
27,UN,2020-08-17,-295.380244
28,UN,2020-08-18,50.183144


In [41]:
# Write the predictions to CSV. The row index is written as the first column.
final_predict.to_csv(path_or_buf="pred_30_july_arima.csv")