In [6]:
import copy
import os
import time
import warnings

import dgl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils import data

from ode_nn import Seq2Seq, Auto_FC, Transformer, Latent_ODE, Transformer_EncoderOnly, GAT, GCN
from ode_nn import Dataset, train_epoch, eval_epoch, get_lr, Dataset_graph, train_epoch_graph, eval_epoch_graph
warnings.filterwarnings("ignore")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Data Loader

In [2]:
# Window lengths passed to `Dataset` and the mini-batch size for the sequence models.
input_length, mid, output_length = 14, 14, 7
batch_size = 128

# Path prefixes of the stored subsequence samples
# (presumably `Dataset` appends the sample index — confirm in ode_nn).
train_direc = '/global/cscratch1/sd/rwang2/ODEs/Data/covid_seq/train/sample_'
test_direc = '/global/cscratch1/sd/rwang2/ODEs/Data/covid_seq/test/sample_'

# Train and validation samples both come from the training directory:
# 0-3999 for training, 4000-5249 for validation. Test uses samples 0-149.
train_indices = list(range(0, 4000))
valid_indices = list(range(4000, 5250))
test_indices = list(range(0, 150))

train_set, valid_set, test_set = (
    Dataset(indices, input_length, mid, output_length, direc, entire_target = True)
    for indices, direc in (
        (train_indices, train_direc),
        (valid_indices, train_direc),
        (test_indices, test_direc),
    )
)

# Only the training loader shuffles; validation/test keep a fixed order.
train_loader = data.DataLoader(train_set, batch_size = batch_size, shuffle = True)
valid_loader, test_loader = (
    data.DataLoader(split, batch_size = batch_size, shuffle = False)
    for split in (valid_set, test_set)
)

## Train Deep Sequence Models

In [None]:
# TODO: set a real run name — with "..." the checkpoints are written as "....pth" / "....pt".
name = "..."
model = Auto_FC(input_length = input_length, input_dim = 3, output_dim = 3, hidden_dim = 16, quantile = True).to(device)
# model = Seq2Seq(input_dim = 3, output_dim = 3, hidden_dim = 128, num_layers = 1, quantile = True).to(device)
# model = Transformer(input_dim = 3, output_dim = 3, nhead = 4, d_model = 32, num_layers = 3, dim_feedforward = 64, quantile = True).to(device)
# model = Latent_ODE(latent_dim = 64, obs_dim = 3, nhidden = 128, rhidden = 128, quantile = True, aug = False).to(device)

####################################
learning_rate = 0.01
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)
# Multiply the learning rate by 0.95 after every epoch.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 1, gamma = 0.95)
loss_fun = nn.MSELoss()
print(sum(p.numel() for p in model.parameters() if p.requires_grad))  # trainable parameter count
train_rmse = []
valid_rmse = []
test_rmse = []
# Start at +inf so the first finite validation score is always checkpointed
# (a fixed threshold such as 1 can leave `best_model` undefined).
min_rmse = float("inf")
best_model = None

for i in range(1, 200):
    start = time.time()
    lr = get_lr(optimizer)  # learning rate actually used during this epoch
    model.train()
    train_rmse.append(train_epoch(model, train_loader, optimizer, loss_fun)[-1])#, feed_tgt = True
    # Since PyTorch 1.1 the scheduler must step *after* the optimizer has
    # stepped; stepping first silently skips the initial learning rate.
    scheduler.step()
    model.eval()
    preds, trues, rmse = eval_epoch(model, valid_loader, loss_fun, concat_input = True)
    valid_rmse.append(rmse)
    if valid_rmse[-1] < min_rmse:
        min_rmse = valid_rmse[-1]
        # Snapshot the weights. A plain `best_model = model` only aliases the
        # live model, which keeps training and ends up holding the LAST epoch.
        best_model = copy.deepcopy(model)
        torch.save(best_model, name + ".pth")
    end = time.time()
    print("Epoch " + str(i) + ": ", "train rmse:", train_rmse[-1], "valid rmse:", valid_rmse[-1],
          "time:", round((end - start) / 60, 3), "Learning rate:", format(lr, "5.2e"))
    # Early stopping: after 30 epochs, stop once the mean validation RMSE of the
    # last 5 epochs is no better than the mean of the 5 epochs before them.
    if len(train_rmse) > 30 and np.mean(valid_rmse[-5:]) >= np.mean(valid_rmse[-10:-5]):
        break

assert best_model is not None, "validation RMSE never improved (NaN?); no checkpoint was saved"
preds, trues, rmses = eval_epoch(best_model, test_loader, loss_fun, concat_input = False)

# Keep only the last `output_length` steps of the predictions.
torch.save({"preds": preds[:, -output_length:],
            "trues": trues,
            "model": best_model},
           name + ".pt")

## Train Graph Neural Network Models

In [4]:
input_length = 14
mid = 14
output_length = 7
batch_size = 4
####################################
# TODO: replace the "..." placeholders with the real graph-sample directories.
train_direc = '.../graph_train/sample_'
test_direc = '.../graph_test/sample_'

train_indices = list(range(80))
valid_indices = list(range(80, 100))
test_indices = list(range(3))
####################################
train_set = Dataset_graph(train_indices, input_length, mid, output_length, train_direc)
valid_set = Dataset_graph(valid_indices, input_length, mid, output_length, train_direc)
test_set = Dataset_graph(test_indices, input_length, mid, output_length, test_direc)

train_loader = data.DataLoader(train_set, batch_size = batch_size, shuffle = True)
valid_loader = data.DataLoader(valid_set, batch_size = batch_size, shuffle = False)
test_loader = data.DataLoader(test_set, batch_size = batch_size, shuffle = False)

################################
# U.S. states 1-0 adjacency matrix: keep the top-left num_states x num_states block.
num_states = 50
graph = torch.load("/global/cscratch1/sd/rwang2/ODEs/Main/ode_nn/mobility/us_graph.pt")[:num_states, :num_states]

# Every (i, j) with graph[i, j] == 1 becomes a directed edge i -> j. torch.nonzero
# yields them in row-major order, the same edge order the former nested
# add_edge loop produced. Adding them in one add_edges call replaces 2,500
# Python-level calls to the deprecated per-edge API.
src, dst = torch.nonzero(torch.as_tensor(graph).cpu() == 1, as_tuple = True)
# Build on the CPU, then move the finished graph to `device`.
G = dgl.DGLGraph()
G.add_nodes(num_states)
G.add_edges(src, dst)
G = G.to(device)

In [None]:
name = "covid_gcn"
# NOTE(review): in_dim = 42 presumably equals input_length * 3 features — confirm.
model = GCN(in_dim = 42, out_dim = 3, hidden_dim = 16, num_layer = 3).to(device)
#model = GAT(in_dim = 42, out_dim = 3, hidden_dim = 32, num_heads = 4, num_layer = 6).to(device)
print(sum(p.numel() for p in model.parameters() if p.requires_grad))  # trainable parameter count
learning_rate = 0.01
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)
# Multiply the learning rate by 0.9 after every epoch.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 1, gamma = 0.9)
loss_fun = nn.MSELoss()
train_rmse = []
valid_rmse = []
test_rmse = []
# Start at +inf so the first finite validation score is always checkpointed
# (a fixed threshold such as 1 can leave `best_model` undefined).
min_rmse = float("inf")
best_model = None
for i in range(1, 200):
    start = time.time()
    lr = get_lr(optimizer)  # learning rate actually used during this epoch
    model.train()
    train_rmse.append(train_epoch_graph(model, train_loader, optimizer, loss_fun, G)[-1])
    # Since PyTorch 1.1 the scheduler must step *after* the optimizer has
    # stepped; stepping first silently skips the initial learning rate.
    scheduler.step()
    model.eval()
    preds, trues, rmse = eval_epoch_graph(model, valid_loader, loss_fun, G)
    valid_rmse.append(rmse)
    if valid_rmse[-1] < min_rmse:
        min_rmse = valid_rmse[-1]
        # Snapshot the weights. A plain `best_model = model` only aliases the
        # live model, which keeps training and ends up holding the LAST epoch.
        best_model = copy.deepcopy(model)
        torch.save(best_model, name + ".pth")
    end = time.time()
    print("Epoch " + str(i) + ": ", "train rmse:", train_rmse[-1], "valid rmse:", valid_rmse[-1],
          "time:", round((end - start) / 60, 3), "Learning rate:", format(lr, "5.2e"))
    # Early stopping: after 30 epochs, stop once the mean validation RMSE of the
    # last 5 epochs is no better than the mean of the 5 epochs before them.
    if len(train_rmse) > 30 and np.mean(valid_rmse[-5:]) >= np.mean(valid_rmse[-10:-5]):
        break

assert best_model is not None, "validation RMSE never improved (NaN?); no checkpoint was saved"
preds, trues, rmse = eval_epoch_graph(best_model, test_loader, loss_fun, G)
# Test RMSE compares preds with the last `output_length` entries of trues along its third axis.
torch.save({"preds": preds,
            "trues": trues,
            "rmse": np.sqrt(np.mean((preds - trues[:, :, -output_length:])**2))},
           name + ".pt")