### TODO 
1. Test for all countries and add country-specific attributes. 
2. Ablation study with **pd.fillna()** and **pd.notna()**. 
3. Ablation study with and w/o growth_rate. 

### Common libraries 

In [1]:
import torch.nn.functional as F
import torch 
import torch.optim as optim
import sys
import torch.nn as nn
import pandas as pd 
import os
import numpy as np
from dateutil.parser import parse 

def dateConvertor(date):
    """Return *date* re-formatted as a ``YYYY-MM-DD`` string.

    The input string is parsed with ``dateutil.parser.parse``; the parsed
    value is then rendered with ``strftime('%Y-%m-%d')``.
    """
    return parse(date).strftime('%Y-%m-%d')

### Reading data (for all countries)
TODO: 
1. Data cleaning has to be done more efficiently. The following code simply removes the Jan, Feb and Dec data and linearly interpolates all NaN values. 
2. If there is not enough history before the requested date, the following program will exit (sys.exit call). Replace this with try/except handling.

In [2]:
# Three-letter country codes. A code's position in this list is used further
# down as the integer label of that country's row in the timeseries CSVs
# (presumably the list mirrors the CSV row order -- confirm against the data).
country_codes = ['ABW','AFG','AGO','ALB','AND','ARE','ARG','AUS','AUT','AZE','BDI','BEL','BEN','BFA','BGD','BGR','BHR','BHS','BIH','BLR','BLZ','BMU','BOL','BRA','BRB','BRN','BTN','BWA','CAF','CAN','CHE','CHL','CHN','CIV','CMR','COD','COG','COL','COM','CPV','CRI','CUB','CYP','CZE','DEU','DJI','DMA','DNK','DOM','DZA','ECU','EGY','ERI','ESP','EST','ETH','FIN','FJI','FRA','FRO','GAB','GBR','GEO','GHA','GIN','GMB','GRC','GRL','GTM','GUM','GUY','HKG','HND','HRV','HTI','HUN','IDN','IND','IRL','IRN','IRQ','ISL','ISR','ITA','JAM','JOR','JPN','KAZ','KEN','KGZ','KHM','KOR','KWT','LAO','LBN','LBR','LBY','LKA','LSO','LTU','LUX','LVA','MAC','MAR','MCO','MDA','MDG','MEX','MLI','MMR','MNG','MOZ','MRT','MUS','MWI','MYS','NAM','NER','NGA','NIC','NLD','NOR','NPL','NZL','OMN','PAK','PAN','PER','PHL','PNG','POL','PRI','PRT','PRY','PSE','QAT','RKS','ROU','RUS','RWA','SAU','SDN','SEN','SGP','SLB','SLE','SLV','SMR','SOM','SRB','SSD','SUR','SVK','SVN','SWE','SWZ','SYC','SYR','TCD','TGO','THA','TJK','TKM','TLS','TTO','TUN','TUR','TWN','TZA','UGA','UKR','URY','USA','UZB','VEN','VIR','VNM','VUT','YEM','ZAF','ZMB','ZWE']

# One CSV per series, all read from the ``timeseries`` directory.
filenames = [
    "c1_school_closing.csv",
    "c2_workplace_closing.csv",
    "c3_cancel_public_events.csv",
    "c4_restrictions_on_gatherings.csv",
    "c5_close_public_transport.csv",
    "c6_stay_at_home_requirements.csv",
    "c7_movementrestrictions.csv",
    "c8_internationaltravel.csv",
    "confirmed_cases.csv",
]

# Country code -> position of that code in ``country_codes``.
country_code2id = {code: position for position, code in enumerate(country_codes)}

# date extraction
# After transposing, the frame's index holds the CSV column headers; the
# first three headers are skipped and the remaining ones are normalised to
# 'YYYY-MM-DD' strings.
npi_df = pd.read_csv(os.path.join('timeseries', filenames[0])).T
npi_date = pd.DataFrame({'Date': npi_df.index.values[3:]})
npi_date['Date'] = npi_date['Date'].map(dateConvertor)

# extract data
# ``dataframes`` maps a series name (the file name without ".csv", plus the
# derived "growth_rate") to a DataFrame indexed by date with one column per
# selected country.
dataframes = {}
countries_to_extract = ['ITA','IND','USA'] # countries code for which you want data.
index = [country_code2id[code] for code in countries_to_extract]
for file in filenames:
    name = os.path.splitext(file)[0]
    npi_df = pd.read_csv(os.path.join('timeseries', file)).T[3:]
    npi_df['Date'] = npi_date['Date'].values
    npi_df.set_index('Date', drop=True, inplace=True)
    npi_df = npi_df[index] # selecting countries
    npi_df = npi_df[64:335] # removing Jan, Feb and Dec data
    # The transposed frame holds objects; convert to numbers (unparseable
    # entries become NaN). ``apply`` returns a new frame, so nothing is
    # assigned on a slice.
    npi_df = npi_df.apply(pd.to_numeric, errors='coerce')
    dataframes[name] = npi_df.interpolate(method='linear') # interpolate missing values

    # computing growth rate
    if name == 'confirmed_cases':
        # Derived from the numeric but not-yet-interpolated counts, so the
        # file does not have to be read and preprocessed a second time.
        growth = 100 * npi_df.diff() / npi_df
        # A zero case count makes the ratio +/-inf, which interpolate() does
        # not treat as missing; turn it into NaN first.
        growth = growth.replace([np.inf, -np.inf], np.nan)
        growth = growth.interpolate(method='linear') # interpolate missing values
        dataframes['growth_rate'] = growth.rolling(3).mean() # 3-row moving average

def readData(attributes, history, date, frames=None):
    """Build one (features, target) sample ending just before *date*.

    Args:
        attributes: non-empty sequence of series names (keys of *frames*) to
            use as features.
        history: number of rows immediately preceding *date* to include.
        date: index label of the target row.
        frames: optional mapping of series name -> DataFrame sharing one date
            index; it must contain ``'growth_rate'``. Defaults to the
            module-level ``dataframes``.

    Returns:
        ``(x, y)`` of double tensors. ``x`` has shape
        ``(history, n_columns, len(attributes))`` -- always 3-D, also for a
        single attribute. ``y`` holds the ``'growth_rate'`` row at *date*.

    Raises:
        ValueError: if *attributes* is empty or fewer than *history* rows
            precede *date*.
        KeyError: if *date* or a requested series is missing.
    """
    if frames is None:
        frames = dataframes
    if not attributes:
        raise ValueError('attributes must contain at least one series name')

    # Every frame is expected to share the same date index, so the target
    # frame is used to locate the row.
    row = frames['growth_rate'].index.get_loc(date)
    if history > row:
        # Raise instead of sys.exit() so callers can catch and recover.
        raise ValueError('Not sufficient history')

    layers = [
        np.asarray(frames[att].iloc[row - history:row].values, dtype=np.float64)
        for att in attributes
    ]
    # Stack along a new trailing axis: rows | columns | attributes.
    data = np.stack(layers, axis=-1)
    x = torch.from_numpy(data).to(dtype=torch.double)
    # np.array copies, so the tensor never aliases the DataFrame's buffer.
    target = np.array(frames['growth_rate'].iloc[row].values, dtype=np.float64)
    y = torch.from_numpy(target).to(dtype=torch.double)
    return x,y

# Sanity check: build one sample from four of the NPI series and print the
# resulting tensor shapes in both axis orders.
training_attributes = [
    'c1_school_closing',
    'c2_workplace_closing',
    'c3_cancel_public_events',
    'c5_close_public_transport',
]
x, y = readData(training_attributes, 30, '2020-07-15')
print('days | country | attributes = ', x.shape, y.shape)
print('country | days | attributes = ', x.permute(1, 0, 2).shape)

days | country | attributes =  torch.Size([30, 3, 4]) torch.Size([3])
country | days | attributes =  torch.Size([3, 30, 4])


### Reading data (for single country)

In [None]:
filenames = ["c1_school_closing.csv", "c2_workplace_closing.csv", "c3_cancel_public_events.csv",
            "c4_restrictions_on_gatherings.csv", "c5_close_public_transport.csv", "c6_stay_at_home_requirements.csv", "c7_movementrestrictions.csv",
            "c8_internationaltravel.csv", "confirmed_cases.csv"]

country_name = 'India'  # value of the CSVs' 'country_name' column to extract


def _country_series(file, country):
    """Return *country*'s row of ``timeseries/<file>`` as a Series.

    The first three CSV columns are dropped; the Series index holds the
    remaining column headers. The row is chosen by position, so the result
    does not depend on where the country sits in the file.
    """
    table = pd.read_csv(os.path.join('timeseries', file))
    rows = table[table['country_name'] == country].iloc[:, 3:]
    if rows.empty:
        raise ValueError('no rows for country %r in %s' % (country, file))
    return rows.iloc[0]


# Read every file exactly once.
country_series = {file[:-4]: _country_series(file, country_name) for file in filenames}

# date extraction
npi_data = pd.DataFrame({'Date': country_series[filenames[0][:-4]].index.values})
npi_data['Index'] = npi_data['Date'].index.values/100.0  # row position scaled by 1/100

# other attributes extraction
for name, series in country_series.items():
    # Coerce to numbers (unparseable entries become NaN), matching the
    # all-country loader.
    npi_data[name] = pd.to_numeric(series, errors='coerce').values

# compute growth rate
npi_data['growth_rate'] = npi_data['confirmed_cases'].diff()
npi_data['growth_rate'] = 100*npi_data['growth_rate']/npi_data['confirmed_cases'] # (0,1) or (0,100)

# smoothing growth_rate
npi_data['growth_rate'] = npi_data['growth_rate'].rolling(3).mean()

# cleaning df
# Drop the first 64 rows; copy so later writes do not target a slice of the
# original frame.
npi_data = npi_data[64:].copy()
# for col in npi_data.columns:
#     npi_data= npi_data[npi_data[col].notna()]

# interpolating instead of skipping
# Only numeric columns are interpolated: 'Date' holds strings, and pandas
# deprecates calling interpolate() on non-numeric (object) columns.
numeric_cols = npi_data.select_dtypes(include='number').columns
npi_data[numeric_cols] = npi_data[numeric_cols].interpolate(method='linear')

### Baseline linear model 

In [None]:
class LinearModel(torch.nn.Module):
    """Three-layer fully connected baseline (in_ -> 64 -> 8 -> out_).

    Args:
        in_: number of elements in the flattened input.
        out_: number of output values (previously accepted but ignored;
            the default of 1 keeps the original architecture).
    """

    def __init__(self, in_=300, out_=1):
        super().__init__()
        self.linear1 = torch.nn.Linear(in_, 64)
        self.linear2 = torch.nn.Linear(64, 8)
        self.linear3 = torch.nn.Linear(8, out_)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Flatten *x* into one vector and return a tensor of shape ``(out_,)``.

        The whole input is treated as a single sample (no batch dimension),
        so it must contain exactly ``in_`` elements.
        """
        x = x.reshape(-1)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x

def init_weights(m):
    """Initialise a linear layer; meant to be passed to ``Module.apply``.

    ``nn.Linear`` modules get Xavier-uniform weights and, when they have a
    bias, a bias filled with 0.01. Any other module is left untouched.
    """
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        # Layers built with bias=False have ``m.bias is None``.
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0.01)

### Training baseline model without positional info

1. Add country-specific info here when creating the dataset for the model. 

In [None]:
history = 30  # number of consecutive rows fed to the model per prediction

# growth_rate stays among the features; it is the last column of npi_data and
# therefore the last feature column, which the validation rollout overwrites.
y_total = npi_data['growth_rate'].to_numpy(dtype=np.float64)
# np.float was removed in NumPy 1.24; np.float64 is the same dtype.
x_total = npi_data.drop(columns=['Date', 'confirmed_cases']).to_numpy(dtype=np.float64)

x_total = torch.tensor(x_total, dtype=torch.double)
y_total = torch.tensor(y_total, dtype=torch.double)

model = LinearModel(x_total.shape[1]*history).to(dtype=torch.double)
model.apply(init_weights)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
mse_loss = torch.nn.MSELoss()

split_index = 150  # rows [0, split_index) are used for training

x_train = torch.clone(x_total[0:split_index,:])
y_train = torch.clone(y_total[0:split_index])
# The test slice starts `history` rows early so the first test window is full.
x_test = torch.clone(x_total[split_index-history:,:])
y_test = torch.clone(y_total[split_index-history:])

print('attributes | ', x_train.shape[1])
print('training | ', len(x_train))
print('test | ', len(x_test)-history)
print('-'*20)


def _rollout_validation_loss(model, x_source, y_source, history, loss_fn):
    """Return the summed loss of an autoregressive rollout over *x_source*.

    For every step ``i`` the window ``x[i-history:i]`` is scored against
    ``y_source[i+1]`` (the same alignment as the training loop). From the
    second step on, the previous prediction replaces the last feature of the
    window's last row. Works on a clone, so *x_source* is not modified.
    """
    x_roll = torch.clone(x_source)
    total = 0.0
    y_pred = None
    with torch.no_grad():  # evaluation only: no autograd graph needed
        for i in range(history, len(x_roll)-1):
            x_temp = x_roll[i-history:i,:]
            if y_pred is not None:
                # x_temp is a view, so the fed-back value persists in x_roll
                # for the following windows.
                x_temp[-1,-1] = y_pred.item()
            y_pred = model(x_temp)
            # reshape(1) makes the target match the (1,)-shaped prediction.
            total += loss_fn(y_pred, y_source[i+1].reshape(1)).item()
    return total


model.eval()
validation_loss = _rollout_validation_loss(model, x_test, y_test, history, mse_loss)
print('validation loss before training %0.2f' %(validation_loss))

model.train()
for epoch in range(200):
    training_loss = 0.0
    for i in range(history,len(x_train)-1):
        optimizer.zero_grad() # make gradients zero
        x_temp = x_train[i-history:i,:]
        y_pred = model(x_temp)
        loss = mse_loss(y_pred, y_train[i+1].reshape(1))
        loss.backward() # computing gradients
        optimizer.step() # updating weights
        training_loss += loss.item()
    if((epoch+1)%5 == 0):
        model.eval()
        validation_loss = _rollout_validation_loss(model, x_test, y_test, history, mse_loss)
        print('epoch %d | training loss %0.2f | validation loss %0.2f'%(epoch, training_loss, validation_loss))
        model.train()