In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import datetime
# Import Statsmodels
from statsmodels.tsa.api import VAR
from statsmodels.tsa.stattools import adfuller
from statsmodels.tools.eval_measures import rmse, aic

In [2]:
def parser(x):
    return datetime.datetime.strptime(x+'20', '%m/%d/%Y')

In [3]:
def avg_relative_error(y, p):
    error = 0
    er = []
    for i in range(len(y)):
        if y[i] != 0 and p[i] != 0:
            e = (((p[i] - y[i]) / max(y[i], p[i])))
            er.append(e)
            error += e
    error = error / len(y)
    return abs(error)

In [5]:
countries = ['US','Australia','Austria', 'Brazil', 'China', 'Czechia', 'Denmark', 'France', 'Germany', 'India', 'Italy', 'Korea_South', 'South_Africa', 'Spain', 'UK']
results = []
for country in countries:
    result = {'country':'', 'conf': 0, 'dead': 0, 'recov': 0}
    result['country'] = country
    file = '../Data_Preprocessing/Clean/' + country + '.csv'
    df = pd.read_csv(file, index_col = ['Date'], parse_dates = ['Date'], date_parser = parser, squeeze = True)
    df = df.drop(['Country/Region', 'Online_School', 'Social_Distancing', 'Travel_Ban', 'Army_Deployed'], axis = 1)
    train = df[:int(0.8*(len(df)))]
    valid = df[int(0.8*(len(df))):]
    model = VAR(endog=df)
    model_fit = model.fit()
    prediction = model_fit.forecast(model_fit.y, steps=len(valid))
    preds = pd.DataFrame(prediction, columns=['Confirmed_Cases', 'Recoveries', 'Deaths'], index=df.index[-len(valid):])
    result['conf'] = avg_relative_error(preds['Confirmed_Cases'].values, valid['Confirmed_Cases'].values)
    result['recov'] = avg_relative_error(preds['Recoveries'].values, valid['Recoveries'].values)
    result['dead'] = avg_relative_error(preds['Deaths'].values, valid['Deaths'].values)
    print(result)
    results.append(result)
results



{'country': 'US', 'conf': 0.10601946277437058, 'dead': 0.2552877196366616, 'recov': 0.1157967955989258}
{'country': 'Australia', 'conf': 0.5245396345074986, 'dead': 0.5069902199211461, 'recov': 0.2700111632589966}
{'country': 'Austria', 'conf': 0.4057279131458995, 'dead': 0.07340449503050571, 'recov': 0.060408610662115766}
{'country': 'Brazil', 'conf': 0.06046913587033022, 'dead': 0.017081875368857483, 'recov': 0.09259531130340133}
{'country': 'China', 'conf': 0.931525804988715, 'dead': 0.8551939068966494, 'recov': 0.5764020383330741}




{'country': 'Czechia', 'conf': 1.2582013969283503, 'dead': 0.30977477581145685, 'recov': 0.30139343338514163}
{'country': 'Denmark', 'conf': 0.23364711621912035, 'dead': 0.24242042784659087, 'recov': 0.03829340447496179}
{'country': 'France', 'conf': 0.21803245412126263, 'dead': 0.269923992570582, 'recov': 0.32637829498311494}
{'country': 'Germany', 'conf': 0.4092431237787822, 'dead': 0.5354949539962921, 'recov': 0.6050866703344022}




{'country': 'India', 'conf': 0.5402730896908006, 'dead': 0.39695892776735703, 'recov': 0.1284376089145622}
{'country': 'Italy', 'conf': 0.08506323662217798, 'dead': 0.19285530192910705, 'recov': 0.1897474524492156}
{'country': 'Korea_South', 'conf': 0.09964780591780928, 'dead': 0.44268625189069266, 'recov': 0.4472155197115772}
{'country': 'South_Africa', 'conf': 0.453168142883598, 'dead': 0.5031490075259724, 'recov': 0.19532228332383336}
{'country': 'Spain', 'conf': 0.5222697718272664, 'dead': 0.5070944635739084, 'recov': 0.363028848047085}
{'country': 'UK', 'conf': 0.5454838202642582, 'dead': 0.6038254573942733, 'recov': 0.5553429359645715}




[{'country': 'US',
  'conf': 0.10601946277437058,
  'dead': 0.2552877196366616,
  'recov': 0.1157967955989258},
 {'country': 'Australia',
  'conf': 0.5245396345074986,
  'dead': 0.5069902199211461,
  'recov': 0.2700111632589966},
 {'country': 'Austria',
  'conf': 0.4057279131458995,
  'dead': 0.07340449503050571,
  'recov': 0.060408610662115766},
 {'country': 'Brazil',
  'conf': 0.06046913587033022,
  'dead': 0.017081875368857483,
  'recov': 0.09259531130340133},
 {'country': 'China',
  'conf': 0.931525804988715,
  'dead': 0.8551939068966494,
  'recov': 0.5764020383330741},
 {'country': 'Czechia',
  'conf': 1.2582013969283503,
  'dead': 0.30977477581145685,
  'recov': 0.30139343338514163},
 {'country': 'Denmark',
  'conf': 0.23364711621912035,
  'dead': 0.24242042784659087,
  'recov': 0.03829340447496179},
 {'country': 'France',
  'conf': 0.21803245412126263,
  'dead': 0.269923992570582,
  'recov': 0.32637829498311494},
 {'country': 'Germany',
  'conf': 0.4092431237787822,
  'dead': 0.