In [1]:
# Standard library imports
import os
import datetime as dt

# Data handling and numerical computations
import numpy as np
import pandas as pd
from scipy.stats import pearsonr
from scipy import interpolate

# PyTorch related imports
import torch
import torch.nn as nn
from torch.distributions import Categorical, Normal, kl_divergence
from torch.profiler import profile, record_function, ProfilerActivity
from torchdiffeq import odeint

# Visualization library
import matplotlib.pyplot as plt

# Utilities and custom modules
from itertools import chain
import lib.utils as utils
import lib.models as models
import lib.train_functions as train_functions
import lib.encoders as encoders
from lib.HHS_data import *
import tqdm

# Setting the number of threads for PyTorch and specifying the device
torch.set_num_threads(1)

# root = 'checkpoints/HHS_SIR_Big/'   
# enc.load_state_dict(torch.load(root+'enc_' + '.pth'))
# ode.load_state_dict(torch.load(root+'sir_' + '.pth'))
# dec.load_state_dict(torch.load(root+'dec_' + '.pth'))

In [2]:
def eval(x_in, y_in, t, n_samples = 128, dtype = torch.float32):
    batch_size = x_in.shape[0]
    eps = torch.randn(n_samples, batch_size, n_regions, latent_dim-1, dtype=dtype, device=device)
    ode.clear_tracking()
    mean, std = enc(x_in)
    z = reparam(eps, std, mean, n_samples, batch_size)
    latent = odeint(ode, z, t, method='rk4', options=dict(step_size = 1.0))
    y_pred = dec(latent[..., :3]).reshape((t.shape[0], n_samples, batch_size, n_regions)).permute(2,1,0,3)

    nll = train_functions.nll_loss(y_pred, y_in).detach().cpu().numpy()
    return nll

In [3]:
# Encoder_BiDirectionalGRU(n_regions, 
#                          n_qs=n_qs,
#                          latent_dim = latent_dim-1,    
#                          q_sizes=q_sizes, 
#                          ili_sizes=ili_sizes, 
#                          ff_sizes = ff_sizes, 
#                          SIR_scaler = SIR_scaler, 
#                          device=device, 
#                          dtype=torch.float32)

# enc = Encoder_MISO_GRU(n_regions = n_regions,
#                        n_qs=n_qs,
#                        latent_dim=latent_dim-1, 
#                        q_sizes=q_sizes, 
#                        ili_sizes=ili_sizes, 
#                        ff_sizes = ff_sizes, 
#                        SIR_scaler = SIR_scaler,
#                        device=device, 
#                        dtype=torch.float32)

# enc = Encoder_Back_GRU(n_regions=n_regions, 
#                        input_size=n_qs+1, 
#                        latent_dim = latent_dim-1, 
#                        q_sizes=q_sizes, 
#                        ili_sizes=ili_sizes, 
#                        ff_sizes = ff_sizes, 
#                        SIR_scaler = SIR_scaler, 
#                        device=device, 
#                        dtype=torch.float32)

**Get prior working better - choose sigma1 and sigma2:**
- [0.1, 0.01]
- [0.05, 0.005]
  
**Encoder Architecture:**
- GRU MISO Bidirectional
- GRU SISO Backwards
- GRU Bidirectional
  
**Prior Sigma1/2:**
- [0.1, 0.01]
- [0.05, 0.005]
  
**ODE:**
- HHS SIR
- HHS SIR Fa
  
**Decoder:**
- Single Layer FF
  
**Epochs:**
- 1000
  
**Latent Dim:**
- 6

In [4]:
# Variable
n_qs = 5
window = 42
latent_dim = 6
batch_size = 32
means=[0.8, 0.55]
stds = [0.2, 0.2]
ff_sizes = [64,32]
ili_sizes = [32, 16]
SIR_scaler = [0.1, 0.05, 1.0]
q_sizes=[128, 64]

encoder_model = encoders.Encoder_BiDirectionalGRU
# encoder_model = encoders.Encoder_MISO_GRU
# encoder_model = encoders.Encoder_Back_GRU

lag = 14
n_regions = 10
season = 2016
lr = 1e-3
n_samples = 128
epochs = 1000

root = 'checkpoints/HHS_SIR_Big_new/'      
device = 'cpu'
dtype=torch.float32

tmax = 8

if encoder_model == encoders.Encoder_Back_GRU:   
    gamma = 28
    t = torch.linspace(1,gamma+window, gamma+window, device=device)/7
    
else:
    gamma = 63
    t = torch.linspace(1,gamma, gamma, device=device)/7
eval_pts = [0,6,13,20,27,34,40,47,54][:tmax]

In [5]:
ili = load_ili('hhs')
ili = intepolate_ili(ili)

hhs_dict = {}
qs_dict = {}

ignore = ['AZ', 'ND', 'AL', 'RI', 'VI', 'PR']
for i in range(1,1+n_regions):
    hhs_dict[i] = get_hhs_query_data(i, ignore=ignore, smooth_after = True)
    qs_dict[i] = choose_qs(hhs_dict, ili, i, season, n_qs)

    hhs_dict[i] = hhs_dict[i].loc[:, list(qs_dict[i])]
    hhs_dict[i] = hhs_dict[i].div(hhs_dict[i].max())
    
ili = ili.loc[hhs_dict[i].index[0] : hhs_dict[i].index[-1]]
ili = ili.div(ili.max())



In [6]:
run_backward = False
if encoder_model == encoders.Encoder_Back_GRU:
    run_backward = True

inputs = []
outputs = []
for batch in range(ili.shape[0] - (window+gamma)):
    batch_inputs = []
    for i in range(1,11):
        batch_inputs.append(hhs_dict[i].iloc[batch:batch+window])
    
    t_ili = ili.iloc[batch:batch+window].copy()
    t_ili.iloc[-lag:, :] = -1
    batch_inputs.append(t_ili)
    batch_inputs = np.concatenate(batch_inputs, -1)

    if run_backward:
        gamma = 28
        batch_outputs = ili.iloc[batch:batch+window-lag+gamma].values
        t = torch.linspace(1, batch_outputs.shape[0], batch_outputs.shape[0])/7
    else:
        gamma = 56
        batch_outputs = ili.iloc[batch+window-lag:batch+window-lag+gamma].values
        t = torch.linspace(1, batch_outputs.shape[0], batch_outputs.shape[0])/7
        
    inputs.append(batch_inputs)
    outputs.append(batch_outputs)
inputs = torch.tensor(np.asarray(inputs), dtype=torch.float32)
outputs = torch.tensor(np.asarray(outputs), dtype=torch.float32)

In [7]:
enc = encoder_model(n_regions, 
             n_qs=n_qs,
             latent_dim = latent_dim-1,    
             q_sizes=q_sizes, 
             ili_sizes=ili_sizes, 
             ff_sizes = ff_sizes, 
             SIR_scaler = SIR_scaler, 
             device=device, 
             dtype=torch.float32)

ode = models.Fp(n_regions, latent_dim, nhidden=64)
dec = models.Decoder(n_regions, 3, 1, device=device)

enc.to(device)
ode.to(device)
dec.to(device)

num = np.sum([np.prod(_.shape) for _ in list(enc.parameters())])
print('encoder parameters:', num)

num = np.sum([np.prod(_.shape) for _ in list(ode.parameters())])
print('ode parameters:', num)

num = np.sum([np.prod(_.shape) for _ in list(dec.parameters())])
print('decoder parameters:', num)

batch_size = 32
new_inputs = torch.tensor(np.asarray(inputs), dtype=torch.float32).to(device)
new_outputs = torch.tensor(np.asarray(outputs), dtype=torch.float32).to(device)

train_size = len(new_inputs) - 365
x_tr, y_tr = new_inputs[:train_size], new_outputs[:train_size]
x_test, y_test = new_inputs[train_size:], new_outputs[train_size:]

# batch it all 
x_train = []
y_train = []
for b in range(int(np.ceil(x_tr.shape[0]/batch_size))):
    x_train.append(torch.tensor(x_tr[b*batch_size:(b+1)*batch_size], dtype=torch.float32))
    y_train.append(torch.tensor(y_tr[b*batch_size:(b+1)*batch_size], dtype=torch.float32))

encoder parameters: 283172
ode parameters: 9364
decoder parameters: 310


  x_train.append(torch.tensor(x_tr[b*batch_size:(b+1)*batch_size], dtype=torch.float32))
  y_train.append(torch.tensor(y_tr[b*batch_size:(b+1)*batch_size], dtype=torch.float32))


In [8]:
optimizer = torch.optim.Adam(enc.parameters(), lr=lr)
for epoch in range(3):
    kls = 0
    pbar = tqdm.tqdm(x_train)
    num = 0
    for x_tr in pbar:
        optimizer.zero_grad()
        
        mean, std = enc(x_tr)
        prior = encoders.make_prior(mean, latent_dim=latent_dim, device=device)
        kl = kl_divergence(Normal(mean, std), prior).mean(0).sum()
        if torch.isnan(kl):
            break
        kl.backward()
        optimizer.step()
        kls += kl.cpu().detach().numpy()
        num += 1
        pbar.set_postfix({'Epoch':epoch, 'KL_z':kls/num})

100%|██████████| 176/176 [00:10<00:00, 16.85it/s, Epoch=0, KL_z=9.73]


In [None]:
optimizer = torch.optim.Adam(chain(enc.parameters(), ode.parameters(), dec.parameters()), lr=lr)
_history = train_functions.history()

for epoch in range(epochs):
    pbar = tqdm.tqdm(zip(x_train, y_train))
    for x_tr, y_tr in pbar:
        batch_size = x_tr.shape[0]
        eps = torch.randn(n_samples, batch_size, n_regions, latent_dim-1, dtype=dtype, device=device)
        ode.clear_tracking()
        optimizer.zero_grad()
        
        mean, std = enc(x_tr)
        z = encoders.reparam(eps, std, mean, n_samples, batch_size)
        latent = odeint(ode, z, t, method='rk4', options=dict(step_size = 1.0))
        y_pred = dec(latent[..., :3]).reshape((-1, n_samples, batch_size, n_regions)).permute(2,1,0,3)

        # nll = train_functions.nll_loss(y_pred, y_tr[:, eval_pts, :])
        nll = train_functions.nll_loss(y_pred, y_tr)
        kl_p = train_functions.get_kl_params(1, ode.posterior(), means=means, stds = stds,limit = 1e6, device=device)
        kl_z = kl_divergence(encoders.make_prior(mean, latent_dim=latent_dim, device=device), Normal(mean, std)).sum(-1).mean() / len(x_train)
        reg_loss = train_functions.latent_init_loss(latent[..., :3])

        loss = nll+kl_p+kl_z+reg_loss
        loss.backward()
        optimizer.step()
        _history.batch([loss.cpu(), nll.cpu(), kl_z.cpu(),kl_p.cpu(),reg_loss.cpu(), optimizer.param_groups[-1]['lr']], ['loss', 'nll', 'kl_latent', 'kl_params', 'reg_loss', 'lr'])
        pbar.set_postfix(_history.epoch())
    _history.reset()
        
    utils.update_learning_rate(optimizer, 0.999, lr/10)



176it [03:09,  1.07s/it, loss=786, nll=5.86, kl_latent=4.83, kl_params=5.93, reg_loss=770, lr=0.001]         
176it [03:15,  1.11s/it, loss=1.99, nll=0.54, kl_latent=0.0255, kl_params=1.3, reg_loss=0.12, lr=0.000999]    
176it [03:33,  1.21s/it, loss=0.826, nll=0.0129, kl_latent=0.0242, kl_params=0.742, reg_loss=0.0464, lr=0.000998]  
176it [03:47,  1.29s/it, loss=0.446, nll=-.24, kl_latent=0.0184, kl_params=0.571, reg_loss=0.0974, lr=0.000997] 
176it [03:35,  1.22s/it, loss=0.133, nll=-.411, kl_latent=0.0173, kl_params=0.507, reg_loss=0.0202, lr=0.000996]  
176it [03:35,  1.23s/it, loss=-.00745, nll=-.499, kl_latent=0.0164, kl_params=0.461, reg_loss=0.0134, lr=0.000995] 
176it [03:57,  1.35s/it, loss=0.0261, nll=-.461, kl_latent=0.0151, kl_params=0.412, reg_loss=0.0597, lr=0.000994] 
176it [04:06,  1.40s/it, loss=-.181, nll=-.559, kl_latent=0.0159, kl_params=0.351, reg_loss=0.0114, lr=0.000993] 
176it [04:19,  1.47s/it, loss=-.268, nll=-.587, kl_latent=0.0142, kl_params=0.289, reg_los

In [None]:
eval(x_test, y_test, t, n_samples = 128, dtype = torch.float32)

In [None]:
# enc.load_state_dict(torch.load(root+'enc_' + '.pth'))
# ode.load_state_dict(torch.load(root+'sir_' + '.pth'))
# dec.load_state_dict(torch.load(root+'dec_' + '.pth'))