### Training RNN using World's data

This is based on `COVID-19 growth prediction using multivariate
long short term memory` by `Novanto Yudistira`

https://arxiv.org/pdf/2005.04809.pdf

https://github.com/VICS-CORE/lstmcorona/blob/master/lstm.py

- We've aligned all countries' inputs rather than using an absolute timeline: each country's series starts on the first day its number of confirmed cases reaches 100.
- We've normalised the data by dividing by a population factor (daily cases per 1,000 inhabitants). That way the network can learn what fraction of the population will be affected.
- Rather than using the entire timeline as an input, as suggested by N. Yudistira, we train on a fixed window (e.g. 20 days) so that the model learns to predict the future by looking at present data. The problem with the fixed-window approach is that some countries have peaked while others have not. Also, a few countries start early, and some start late.

#### Ideas
- One idea is to train a network to predict SIR buckets
- Another is to train only with most populous countries

In [None]:
import pandas as pd
import numpy as np
import requests as rq
import datetime as dt
import torch

tnn = torch.nn
top = torch.optim
from torch.utils import data as tdt

In [None]:
# Device identifiers; the first GPU is used whenever CUDA is available.
CUDA = "cuda:0"
CPU = "cpu"

if torch.cuda.is_available():
    device = torch.device(CUDA)
else:
    device = torch.device(CPU)
print(device)

### Read OWID data

In [None]:
!head -n1 csv/owid-covid-data.csv

In [None]:
# Only the columns needed downstream are loaded; 'date' is parsed as a datetime.
dates = ['date']
cols = [
    'location',
    'date',
    'total_cases',
    'new_cases',
    'total_deaths',
    'new_deaths',
    'population',
]
df = pd.read_csv("csv/owid-covid-data.csv", parse_dates=dates, usecols=cols)
# Show one random row as a quick sanity check.
df.sample()

### Prepare dataset

In [None]:
# Window sizes (in days) and the share of samples held out for validation.
IP_SEQ_LEN = 20
OP_SEQ_LEN = 10
VAL_RATIO = 0.3

ip_trn = []
op_trn = []

countries = df['location'].unique()
window_len = IP_SEQ_LEN + OP_SEQ_LEN
c = 0  # running count of countries with enough data
for country in countries:
    # Countries to be skipped
    if country in ('World', 'International'):
        continue

    # Keep only the rows from the point where the country has at least 100 total cases.
    country_df = df.loc[(df.location == country) & (df['total_cases'] >= 100)]

    # A country must supply at least one full input+output window.
    if len(country_df) < window_len:
        continue

    c += 1
    pop = country_df['population'].iloc[0]
    print(c, country, len(country_df), pop)

    # 7-day centred rolling mean of new cases, expressed per 1000 inhabitants.
    smoothed = country_df['new_cases'].rolling(7, center=True, min_periods=1).mean()
    daily_cases = np.array(smoothed * 1000 / pop, dtype=np.float32)

    # Countries to be tested. Not included in training data.
    if country in ('India',):
        continue

    # Slide a window one day at a time: IP_SEQ_LEN inputs followed by OP_SEQ_LEN targets.
    for start in range(len(country_df) - window_len + 1):
        split = start + IP_SEQ_LEN
        ip_trn.append(daily_cases[start:split])
        op_trn.append(daily_cases[split:split + OP_SEQ_LEN])

ip_trn = torch.from_numpy(np.array(ip_trn, dtype=np.float32))
op_trn = torch.from_numpy(np.array(op_trn, dtype=np.float32))
dataset = tdt.TensorDataset(ip_trn, op_trn)

# Random split of all windows into training and validation subsets.
val_len = int(VAL_RATIO * len(dataset))
trn_len = len(dataset) - val_len
trn_set, val_set = tdt.random_split(dataset, [trn_len, val_len])
print("Training data:", trn_len, "Validation data:", val_len)

# One sample per batch, in shuffled order, for both loaders.
trn_loader = tdt.DataLoader(trn_set, batch_size=1, shuffle=True)
val_loader = tdt.DataLoader(val_set, batch_size=1, shuffle=True)

### LSTM

In [None]:
class YudistirNet(tnn.Module):
    """LSTM -> linear -> sigmoid network mapping a fixed input window to a fixed output window.

    The LSTM reads a univariate sequence (``input_size=1``, sequence-first
    layout). The hidden outputs of all ``ip_seq_len`` time steps are
    flattened and projected by a linear layer to ``op_seq_len`` values, which
    are squashed into (0, 1) by a sigmoid.

    Args:
        ip_seq_len: Number of time steps in the input window.
        op_seq_len: Number of values to predict.
        hidden_size: Hidden size of the LSTM.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, ip_seq_len=1, op_seq_len=1, hidden_size=1, num_layers=1):
        super().__init__()

        self.ip_seq_len = ip_seq_len
        self.op_seq_len = op_seq_len
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.lstm = tnn.LSTM(input_size=1, hidden_size=self.hidden_size, num_layers=self.num_layers)
        self.linear = tnn.Linear(self.hidden_size * self.ip_seq_len, self.op_seq_len)
        self.sigmoid = tnn.Sigmoid()

    def forward(self, ip):
        """Run the network on one window or on a batch of windows.

        Args:
            ip: Tensor of shape ``(ip_seq_len, batch, 1)``.

        Returns:
            Tensor of shape ``(op_seq_len,)`` when ``batch == 1`` (the
            original behaviour), otherwise ``(batch, op_seq_len)``.
        """
        lstm_out, _ = self.lstm(ip)
        if lstm_out.dim() == 3 and lstm_out.size(1) > 1:
            # Batched input: bring the batch axis first so that every row is
            # flattened in the same (time step, hidden unit) order as the
            # single-sample path below.
            flat = lstm_out.transpose(0, 1).reshape(lstm_out.size(1), -1)
        else:
            flat = lstm_out.reshape(self.hidden_size * self.ip_seq_len)
        return self.sigmoid(self.linear(flat))

    def predict(self, ip):
        """Same as ``forward`` but with gradient tracking disabled."""
        with torch.no_grad():
            preds = self.forward(ip)
        return preds

### Checkpoint

In [None]:
def save_checkpoint(epoch, model, optimizer, trn_losses, val_losses):
    """Write the training state to ``latest.pt`` in the working directory.

    Args:
        epoch: Epoch number the checkpoint belongs to.
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        trn_losses: Training-loss history, stored as given.
        val_losses: Validation-loss history, stored as given.
    """
    checkpoint = {}
    checkpoint['epoch'] = epoch
    checkpoint['model_state_dict'] = model.state_dict()
    checkpoint['optimizer_state_dict'] = optimizer.state_dict()
    checkpoint['trn_losses'] = trn_losses
    checkpoint['val_losses'] = val_losses
    torch.save(checkpoint, "latest.pt")
    print("Checkpoint saved")
    
def load_checkpoint():
    """Read the training state stored in ``latest.pt``.

    Tensors are mapped to the CPU while loading, so a checkpoint written on a
    GPU machine can also be opened where CUDA is unavailable. Passing the
    returned state dicts to ``load_state_dict`` copies the values into the
    existing model/optimizer.

    Returns:
        Tuple ``(epoch, model_state_dict, optimizer_state_dict, trn_losses,
        val_losses)`` as stored by ``save_checkpoint``.
    """
    cp = torch.load("latest.pt", map_location="cpu")
    print("Checkpoint loaded")
    return cp['epoch'], cp['model_state_dict'], cp['optimizer_state_dict'], cp['trn_losses'], cp['val_losses']

### Train

In [None]:
# Hyperparameters.
HIDDEN_SIZE = 1
NUM_LAYERS = 1
LEARNING_RATE = 0.01
NUM_EPOCHS = 31

model = YudistirNet(ip_seq_len=IP_SEQ_LEN, op_seq_len=OP_SEQ_LEN, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS)
model = model.to(device)

loss_fn = tnn.MSELoss()
optimizer = top.Adam(model.parameters(), lr=LEARNING_RATE)

# Per-epoch average losses (multiplied by 10000, see below) and the epoch counter.
trn_loss_vals = []
val_loss_vals = []
e = 0

# Set to True to continue from the state stored in "latest.pt".
resume = False
if resume:
    e, model_dict, optimizer_dict, trn_loss_vals, val_loss_vals = load_checkpoint()
    e+=1 # continue with the epoch after the checkpointed one
    model.load_state_dict(model_dict)
    optimizer.load_state_dict(optimizer_dict)

# TRAIN
print("BEGIN: [", dt.datetime.now(), "]")
while e < NUM_EPOCHS:
    model.train()
    trn_losses = []
    for data in trn_loader:
        ip, op = data
        ip = ip.to(device)
        op = op.to(device)
        optimizer.zero_grad() # set grads to 0
        preds = model(ip.view(IP_SEQ_LEN, 1, 1)) # predict
        loss = loss_fn(preds, op.view(OP_SEQ_LEN)) # calc loss
        loss.backward() # calc and assign grads
        optimizer.step() # update weights
        # Log a detached copy: keeping the attached loss would hold on to
        # autograd history for every step of the epoch.
        trn_losses.append(loss.detach())
    # Average loss of the epoch, multiplied by 10000 before it is logged.
    avg_trn_loss = torch.stack(trn_losses).mean().item() * 10000
    trn_loss_vals.append(avg_trn_loss)

    model.eval()
    with torch.no_grad():
        val_losses = []
        for data in val_loader:
            ip, op = data
            ip = ip.to(device)
            op = op.to(device)
            preds = model(ip.view(IP_SEQ_LEN, 1, 1))
            loss = loss_fn(preds, op.view(OP_SEQ_LEN))
            val_losses.append(loss)
        avg_val_loss = torch.stack(val_losses).mean().item() * 10000
        val_loss_vals.append(avg_val_loss)

    # Report progress and checkpoint every 10th epoch.
    if e%10==0:
        print("[", dt.datetime.now(), "] epoch:", f"{e:3}", "avg_val_loss:", f"{avg_val_loss: .5f}", "avg_trn_loss:", f"{avg_trn_loss: .5f}")
        save_checkpoint(e, model, optimizer, trn_loss_vals, val_loss_vals)
    e+=1

print("END: [", dt.datetime.now(), "]")

# Collect the per-epoch loss histories and draw each in its own log-scaled subplot.
loss_history = {}
loss_history['trn_loss'] = trn_loss_vals
loss_history['val_loss'] = val_loss_vals
df_trn_loss = pd.DataFrame(loss_history)

plot_options = dict(
    y=['trn_loss', 'val_loss'],
    title=['Training loss per epoch', 'Validation loss per epoch'],
    subplots=True,
    figsize=(5,8),
    sharex=False,
    logy=True,
)
_ = df_trn_loss.plot(**plot_options)

### Load saved model

In [None]:
# Must match the architecture the checkpoint was trained with.
HIDDEN_SIZE = 1
NUM_LAYERS = 1

# Only the model weights are used here; the other checkpoint entries are ignored.
md = load_checkpoint()[1]
model = YudistirNet(
    ip_seq_len=IP_SEQ_LEN,
    op_seq_len=OP_SEQ_LEN,
    hidden_size=HIDDEN_SIZE,
    num_layers=NUM_LAYERS,
).to(device)
model.load_state_dict(md)

### Evaluate fit

In [None]:
# Country to evaluate. pop_fct converts between absolute case counts and
# cases per 1000 inhabitants (the scale the model works in).
c = "Brazil"
pop_fct = df.loc[df.location==c, 'population'].iloc[0] / 1000

all_preds = []  # full OP_SEQ_LEN-day forecast of every window, in absolute cases
pred_vals = []  # first forecast day of every window
out_vals = []   # actual values aligned with pred_vals

# Same smoothing and per-1000 scaling as used for the training data.
test_data = np.array(df.loc[(df.location==c) & (df.total_cases>=100), 'new_cases'].rolling(7, center=True, min_periods=1).mean() / pop_fct, dtype=np.float32)

# Slide over the series one day at a time and forecast from each input window.
for i in range(len(test_data) - IP_SEQ_LEN - OP_SEQ_LEN + 1):
    ip = torch.tensor(test_data[i : i+IP_SEQ_LEN])
    op = torch.tensor(test_data[i+IP_SEQ_LEN : i+IP_SEQ_LEN+OP_SEQ_LEN])
    ip = ip.to(device)
    op = op.to(device)

    pred = model.predict(ip.view(IP_SEQ_LEN, 1, 1))
    all_preds.append(pred.view(OP_SEQ_LEN).cpu().numpy() * pop_fct)
    pred_vals.append(pred.view(OP_SEQ_LEN).cpu().numpy()[0] * pop_fct)
    out_vals.append(op.view(OP_SEQ_LEN).cpu().numpy()[0] * pop_fct)

# last N-1 values: the remaining actuals of the final window have no
# first-day prediction, so the prediction column is padded with NaN.
# np.nan is used because the np.NaN alias no longer exists in NumPy 2.0.
out_vals.extend(op.view(OP_SEQ_LEN).cpu().numpy()[1:] * pop_fct)
pred_vals.extend([np.nan] * (OP_SEQ_LEN - 1))

cmp_df = pd.DataFrame({
    'actual': out_vals,
    'predicted0': pred_vals
})
ax = cmp_df.plot(
    figsize=(20,8),
    lw=3,
    title=c
)

# plot predictions: overlay each window's full forecast at its position on the time axis
for i, pred in enumerate(all_preds):
    cmp_df['predicted_cases'] = np.nan
    cmp_df.loc[i:i+OP_SEQ_LEN-1, 'predicted_cases'] = pred  # .loc slicing is end-inclusive
    cmp_df.plot(y='predicted_cases', ax=ax, legend=False)

### Test (predict)

In [None]:
# Country to forecast. pop_fct converts between absolute case counts and
# cases per 1000 inhabitants (the scale the model works in).
c = "Brazil"
pop_fct = df.loc[df.location==c, 'population'].iloc[0] / 1000
test_data = np.array(df.loc[(df.location==c) & (df.total_cases>=100), 'new_cases'].rolling(7, center=True, min_periods=1).mean() / pop_fct, dtype=np.float32)

# Forecast the next OP_SEQ_LEN days from the most recent IP_SEQ_LEN days.
ip = torch.tensor(
    test_data[-IP_SEQ_LEN:],
    dtype=torch.float32
)
ip = ip.to(device)
pred = model.predict(ip.view(IP_SEQ_LEN, 1, 1))
orig_df = pd.DataFrame({
    'actual': test_data * pop_fct
})
fut_df = pd.DataFrame({
    'predicted': pred.cpu().numpy() * pop_fct
})
# Stack the forecast below the history; DataFrame.append no longer exists in
# pandas 2.0, pd.concat produces the same frame.
orig_df = pd.concat([orig_df, fut_df], ignore_index=True, sort=False)
_ = orig_df.plot(title=c)