<a href="https://colab.research.google.com/github/geande/covid-19-predictor/blob/main/LSTM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import pandas
import jax.numpy as np
from jax import random, vmap, grad, jit
from jax.experimental import optimizers
from jax.ops import index_update, index

import itertools
from functools import partial
from tqdm import trange
import numpy.random as npr
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.interpolate import griddata
from scipy.integrate import odeint

In [2]:
# Build supervised (window -> next value) pairs from a multivariate series.
def create_lags(data, L):
    """Split a (T, D) series into L-step input windows and next-step targets.

    Args:
        data: array of shape (T, D) -- T time steps, D features.
        L: window length (number of lags).

    Returns:
        X: array of shape (L, T - L, D) with ``X[:, i, :] == data[i:i + L]``.
        Y: array of shape (T - L, D) with ``Y[i] == data[i + L]``.

    Raises:
        ValueError: if the series has fewer than ``L`` time steps.
    """
    # Cast to the default float dtype, matching the zero-initialized arrays
    # the previous implementation filled in.
    data = np.asarray(data, dtype=float)
    N = data.shape[0] - L
    if N < 0:
        raise ValueError(
            "need at least L={} time steps, got {}".format(L, data.shape[0]))
    # idx[l, i] = i + l, so data[idx][:, i, :] is the window starting at i.
    # One gather replaces a per-sample functional-update loop that copied X
    # and Y on every iteration (quadratic in N) and depended on the removed
    # jax.ops.index_update API.
    idx = np.arange(L)[:, None] + np.arange(N)[None, :]
    X = data[idx]
    Y = data[L:]
    return X, Y

# Logistic sigmoid used by the LSTM gates.
def sigmoid(x):
    """Return the element-wise logistic sigmoid 1 / (1 + exp(-x)).

    Computed via the identity sigmoid(x) = 0.5 * (tanh(x / 2) + 1). In the
    naive form, exp(-x) overflows to inf for large negative x (about x < -88
    in float32). The forward value still rounds to 0, but the reverse-mode
    gradient becomes inf * 0 = NaN and poisons training. tanh saturates
    cleanly in both directions, so both value and gradient stay finite.
    """
    return 0.5 * (np.tanh(0.5 * x) + 1.)

In [3]:
class RNN():
    """Single-layer LSTM regressor trained on lagged windows of a series.

    The dataset is standardized column-wise and turned into (window, next
    value) pairs with ``create_lags``. The model is fitted with Adam on a
    mean-squared-error loss.

    Args:
        dataset: array of shape (T, D) -- T time steps, D features.
        num_lags: window length L fed to the LSTM for each prediction.
        hidden_dim: size of the LSTM hidden and cell states.
        rng_key: JAX PRNG key used to initialize the weights.
    """

    def __init__(self, dataset, num_lags, hidden_dim, rng_key=random.PRNGKey(0)):
        # Standardize each feature over the time axis. A constant column has
        # std == 0, and dividing by it would fill the normalized data with
        # NaN/inf. Such columns are therefore divided by 1 (only centered).
        self.mean = dataset.mean(0)
        std = dataset.std(0)
        self.std = np.where(std > 0, std, 1.0)
        dataset = (dataset - self.mean) / self.std

        # Lagged normalized training data:
        #   X: (L, N, D) -- L lagged inputs for each of N samples
        #   Y: (N, D)    -- the value that follows each window
        self.X, self.Y = create_lags(dataset, num_lags)
        self.X_dim = self.X.shape[-1]
        self.Y_dim = self.Y.shape[-1]
        self.hidden_dim = hidden_dim
        self.num_lags = num_lags

        # Initialization and evaluation functions
        self.net_init, self.net_apply = self.init_RNN()

        # Initialize parameters, not committing to a batch shape
        self.net_params = self.net_init(rng_key)

        # Adam optimizer (learning rate 0.1): init/update/get_params triple.
        self.opt_init, self.opt_update, self.get_params = optimizers.adam(1e-1)
        self.opt_state = self.opt_init(self.net_params)

        # Per-epoch training loss and global optimizer step counter.
        self.loss_log = []
        self.itercount = itertools.count()

    def init_RNN(self):
        """Return the ``(init, apply)`` function pair for the LSTM."""

        def glorot_normal(rng_key, size):
            # Glorot/Xavier normal init for a (fan_in, fan_out) matrix:
            # stddev = sqrt(2 / (fan_in + fan_out)).
            in_dim, out_dim = size
            glorot_stddev = 1. / np.sqrt((in_dim + out_dim) / 2.)
            return glorot_stddev * random.normal(rng_key, (in_dim, out_dim))

        def _init(rng_key):
            # Give every random weight its own key. Reusing one key made all
            # four gate input matrices identical (Uo == Us == Ui == Uf), so
            # the gates started out perfectly correlated.
            k_o, k_s, k_i, k_f, k_v = random.split(rng_key, 5)

            # Input-to-gate weights: output, candidate state, input, forget.
            Uo = glorot_normal(k_o, (self.X_dim, self.hidden_dim))
            Us = glorot_normal(k_s, (self.X_dim, self.hidden_dim))
            Ui = glorot_normal(k_i, (self.X_dim, self.hidden_dim))
            Uf = glorot_normal(k_f, (self.X_dim, self.hidden_dim))

            # Biases all initialized to 0
            bo = np.zeros(self.hidden_dim)
            bs = np.zeros(self.hidden_dim)
            bi = np.zeros(self.hidden_dim)
            bf = np.zeros(self.hidden_dim)

            # Hidden-to-gate (transition) weights start as the identity.
            Wo = np.eye(self.hidden_dim)
            Ws = np.eye(self.hidden_dim)
            Wi = np.eye(self.hidden_dim)
            Wf = np.eye(self.hidden_dim)

            # Linear read-out
            V = glorot_normal(k_v, (self.hidden_dim, self.Y_dim))
            c = np.zeros(self.Y_dim)

            return (Uo, Us, Ui, Uf, bo, bs, bi, bf, Wo, Ws, Wi, Wf, V, c)

        def _apply(params, inputs):
            # inputs: (L, B, X_dim) -- L time steps for a batch of B windows.
            Uo, Us, Ui, Uf, bo, bs, bi, bf, Wo, Ws, Wi, Wf, V, c = params
            batch = inputs.shape[1]
            H = np.zeros((batch, self.hidden_dim))    # hidden state
            s_t = np.zeros((batch, self.hidden_dim))  # cell state
            for i in range(self.num_lags):
                x_t = inputs[i, :, :]
                s_t_tilde = np.tanh(np.matmul(H, Ws) + np.matmul(x_t, Us) + bs)
                f_t = sigmoid(np.matmul(H, Wf) + np.matmul(x_t, Uf) + bf)
                i_t = sigmoid(np.matmul(H, Wi) + np.matmul(x_t, Ui) + bi)
                s_t = f_t * s_t + i_t * s_t_tilde
                o_t = sigmoid(np.matmul(H, Wo) + np.matmul(x_t, Uo) + bo)
                H = o_t * np.tanh(s_t)
            # Linear read-out of the final hidden state: (B, Y_dim).
            return np.matmul(H, V) + c

        return _init, _apply

    def loss(self, params, batch):
        """Mean squared error of the model on ``batch = (X, y)``."""
        X, y = batch
        y_pred = self.net_apply(params, X)
        return np.mean((y - y_pred)**2)

    # Compiled update step; ``self`` is static so each instance is traced once.
    @partial(jit, static_argnums=(0,))
    def step(self, i, opt_state, batch):
        params = self.get_params(opt_state)
        g = grad(self.loss)(params, batch)
        return self.opt_update(i, g, opt_state)

    def data_stream(self, n, num_batches, batch_size):
        """Yield ``(X, Y)`` mini-batches forever, reshuffling every epoch.

        Uses a fixed seed, so every call restarts the same shuffle sequence.
        """
        rng = npr.RandomState(0)
        while True:
            perm = rng.permutation(n)
            for i in range(num_batches):
                batch_idx = perm[i*batch_size:(i+1)*batch_size]
                yield self.X[:, batch_idx, :], self.Y[batch_idx, :]

    def train(self, num_epochs=100, batch_size=64):
        """Run ``num_epochs`` passes of mini-batch Adam over the lagged data.

        After each epoch the current parameters are stored in
        ``self.net_params``. The MSE over the full training set is appended
        to ``self.loss_log``.
        """
        n = self.X.shape[1]
        num_complete_batches, leftover = divmod(n, batch_size)
        num_batches = num_complete_batches + bool(leftover)
        batches = self.data_stream(n, num_batches, batch_size)
        pbar = trange(num_epochs)
        for _ in pbar:
            for _ in range(num_batches):
                batch = next(batches)
                self.opt_state = self.step(next(self.itercount), self.opt_state, batch)
            self.net_params = self.get_params(self.opt_state)
            # Evaluate on the whole training set rather than on the last
            # mini-batch. That batch can be a tiny leftover batch and gives a
            # noisy, misleading loss curve.
            loss_value = float(self.loss(self.net_params, (self.X, self.Y)))
            self.loss_log.append(loss_value)
            pbar.set_postfix({'Loss': loss_value})

    @partial(jit, static_argnums=(0,))
    def predict(self, params, inputs):
        """Compiled forward pass: ``inputs`` (L, B, D) -> predictions (B, D)."""
        return self.net_apply(params, inputs)



In [None]:
# Training data for the LSTM: the analytical ODE solution (beta assigned
# using the MLP), used as validation data and optionally corrupted with
# Gaussian noise.
# NOTE(review): data_Analytical comes from an earlier notebook cell. It is
# transposed so that axis 0 is time, as create_lags expects -- confirm its
# original layout is (features, time).

rng_key = random.PRNGKey(0)
noise = 0.0  # noise level, relative to each column's standard deviation

dataset = data_Analytical.T
dataset = dataset + dataset.std(0)*noise*random.normal(rng_key, dataset.shape)

# Use the first 4/5 of the time series as training data; the rest is held out.
train_sizeL = int(len(dataset) * (4.0/5.0))
train_data = dataset[0:train_sizeL, :]

Determining the Optimal LSTM hyperparameters

In [None]:
def optimization(model2, lags, hidden, epochs, batchsize):
    """Continue training ``model2``, then roll it out over the global ``dataset``.

    Args:
        model2: an ``RNN`` instance; training resumes from its current state.
        lags: window length the model was built with.
        hidden: unused; kept so existing call sites keep working.
        epochs: number of additional epochs to train for.
        batchsize: mini-batch size used for training.

    Returns:
        pred2: de-normalized predictions of shape (N - lags, D), where (N, D)
            is the shape of the global ``dataset``.
        error2: relative error ||dataset[lags:] - pred2|| / ||dataset[lags:]||.
    """
    model2.train(num_epochs=epochs, batch_size=batchsize)

    opt_params2 = model2.net_params
    N, D = dataset.shape
    # Free-running (autoregressive) forecast. Seed with the first normalized
    # training window, then feed each prediction back in as the newest lag.
    # Only the first step sees real data; all later steps are extrapolation.
    X_tmp = model2.X[:, 0:1, :]  # (lags, 1, D)
    steps = []
    for _ in trange(N - lags):
        y_next = model2.net_apply(opt_params2, X_tmp)  # (1, D)
        steps.append(y_next)
        # Slide the window: drop the oldest lag, append the new prediction.
        X_tmp = np.concatenate([X_tmp[1:], y_next[None, :, :]], axis=0)
    pred2 = np.concatenate(steps, axis=0) if steps else np.zeros((0, D))

    # De-normalize predictions
    pred2 = pred2*model2.std + model2.mean
    # Relative error over the whole horizon (training and held-out part).
    # For 2-D arrays ord=2 is the spectral norm (largest singular value), not
    # the Frobenius norm.
    error2 = np.linalg.norm(dataset[lags:] - pred2, 2)/np.linalg.norm(dataset[lags:], 2)
    return pred2, error2

In [None]:
# Hyperparameter sweep. Each (hidden, lag, batchsize) combination gets a fresh
# model that is trained in stages. After every stage the cumulative epoch
# count, forecast, error and loss history are recorded, so one model yields
# results at 10, 25, 50, 75, 100, 150 and 200 total epochs.
epoch_tests = [10, 15, 25, 25, 25, 50, 50]  # epochs per stage; cumulative totals 10-200
hidden_tests = [1, 2, 3, 4, 5, 10, 15, 20]  # hidden dimensions (1-20)
lag_tests = [1, 2, 3, 4, 5, 10, 15, 20]     # number of lags (1-20)
batch_sizes = [16, 32, 64]                  # batch sizes (16-64)

hyper_coords = []  # one (lag, hidden, total_epochs, batchsize) per result
preds = []
errors = []
losses = []

counter = 0

for hidden in hidden_tests:
    for lag in lag_tests:
        for batchsize in batch_sizes:
            model2 = RNN(train_data, lag, hidden, rng_key)
            total_epochs = 0
            for epochs in epoch_tests:
                total_epochs += epochs
                pred, error = optimization(model2, lag, hidden, epochs, batchsize)
                hyper_coord = (lag, hidden, total_epochs, batchsize)

                # Snapshot the log. model2.loss_log keeps growing while later
                # stages train; appending the list object itself would make
                # every stage of this model show the full final history.
                x = list(model2.loss_log)
                hyper_coords.append(hyper_coord)
                preds.append(pred)
                errors.append(error)
                losses.append(x)

In [None]:
def dataExtraction(ind):
    """Plot the forecast, reference curves and loss history of sweep result ``ind``.

    Draws a 1x3 figure:
      1. I, R, D, C predictions vs. the exact solution and the observed data.
      2. The same comparison for S.
      3. The training loss on a log scale.
    A vertical line marks the end of the training window.

    Args:
        ind: index into the global ``preds``, ``hyper_coords`` and ``losses``
            lists filled by the hyperparameter sweep.
    """
    pred2 = preds[ind]
    hypers = hyper_coords[ind]
    lags = hypers[0]
    loss_log = losses[ind]
    # hyper_coords entries are (lag, hidden, total_epochs, batchsize). The
    # titles previously read "lags, epochs, hidden_dim, batchsize", which
    # swapped the hidden size and the epoch count.
    hyper_label = "lags, hidden_dim, epochs, batchsize - {hypers}".format(hypers=hypers)

    # Prediction columns are ordered S, I, R, D, C.
    Sl, Il, Rl, Dl, Cl = (np.asarray(pred2[:, k], dtype=np.float32) for k in range(5))

    plt.figure(figsize=(25, 8), facecolor="w")

    # NOTE(review): the exact curves (Se, Ie, ...) are sliced from `lags`,
    # while the observed data start at `train_size + lags`, and the marker uses
    # `train_sizeL`. Confirm these offsets put all series on one time axis.
    plt.subplot(1, 3, 1)
    plt.plot(Il, 'r-.', linewidth=3, label="Currently Infected (I) - Prediction")
    plt.plot(Rl, 'r-*', linewidth=3, label="Total Recovered (R) - Prediction")
    plt.plot(Dl, 'r-+', linewidth=3, label="Total Deceased (D) - Prediction")
    plt.plot(Cl, 'r-', linewidth=3, label="Cumulative Caseload (C) - Prediction")

    plt.plot(Ie[lags:], 'b-.', linewidth=2, label="Currently Infected (I) - Exact")
    plt.plot(Re[lags:], 'b-*', linewidth=2, label="Total Recovered (R) - Exact")
    plt.plot(De[lags:], 'b-+', linewidth=2, label="Total Deceased (D) - Exact")
    plt.plot(Ce[lags:], 'b-', linewidth=2, label="Cumulative Caseload (C) - Exact")

    plt.plot(CurrentlyInfected_data[train_size + lags:], 'r--', label='Currently Infected (I) - Data')
    plt.plot(Recovered_data[train_size + lags:], 'g--', label='Recovered (R) - Data')
    plt.plot(Deceased_data[train_size + lags:], 'k--', label='Deceased (D) - Data')
    plt.plot(CumulativeCaseload_data[train_size + lags:], 'p--', label='Total Caseload (C) - Data')

    plt.title("COVID-19 Dynamics in NJ\n " + hyper_label)
    plt.xlabel("Time (days)")
    plt.ylabel("Number of Individuals")
    plt.legend()
    plt.axvline(train_sizeL - lags)

    plt.subplot(1, 3, 2)
    plt.plot(Sl, 'r-.', linewidth=3, label="Susceptible (S) - Prediction")
    plt.plot(Se[lags:], 'b-.', linewidth=2, label="Susceptible (S) - Exact")
    plt.plot(Susceptible_data[train_size + lags:], 'k-.', linewidth=1, label="Susceptible (S) - Data")

    plt.title("COVID-19 Susceptible Population in NJ\n " + hyper_label)
    plt.xlabel("Time (days)")
    plt.ylabel("Number of Individuals")
    plt.legend()
    plt.axvline(train_sizeL - lags)

    plt.subplot(1, 3, 3)
    plt.plot(loss_log)
    plt.yscale('log')
    plt.title("Loss over Time\n " + hyper_label)
    plt.xlabel("Epochs")
    plt.ylabel("Loss")

In [None]:
# Plot the best 10% of sweep results, in order of increasing relative error.
er = np.asarray(errors)
ordered_er = np.sort(er)
# argsort yields the run positions directly. The previous approach looked each
# sorted value back up with errors.index(), which was quadratic and mapped
# runs with tied errors to the same (first) index.
ranking = np.argsort(er)

for i in range(0, int(0.1*len(errors))):
    ind = int(ranking[i])
    print(i)
    dataExtraction(ind)