In [3]:
# Clear the interactive namespace (without a confirmation prompt) before anything else runs.
%reset -f
from jupyterthemes import get_themes
import jupyterthemes as jt
from jupyterthemes.stylefx import set_nb_theme
# Select the 'onedork' jupyterthemes theme for this notebook (presumably cosmetic only).
set_nb_theme('onedork')

In [68]:
import torch
# Ask PyTorch to release cached, currently unused CUDA memory before the run starts.
torch.cuda.empty_cache()

In [69]:
from IPython.display import display, HTML
# Inject a CSS rule that sets the notebook's `.container` element to 60% width.
display(HTML("<style>.container { width:60% !important; }</style>"))

In [70]:
import os
import sys
# file_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# print(file_dir)
# sys.path.append(file_dir)

import torch
import numpy as np
import torch.nn as nn
import argparse
import configparser
from datetime import datetime
from model.AGCRN import AGCRN_UQ as Network
from model.BasicTrainer import Trainer
from lib.TrainInits import init_seed
from lib.dataloader import get_dataloader
from lib.TrainInits import print_model_parameters
from tqdm import tqdm
from copy import deepcopy
import torch.nn.functional as F
import torchcontrib
from model.train_methods import swa_train_combined, swa_train, train_cali, train_cali_mc, train_fge
from model.test_methods import regular_test, heter_test, mc_test, combined_test, ensemble_test,quantile_test

In [71]:
#*************************************************************************#
# Run configuration. These constants are used as argparse defaults when the parser is built.
Mode = 'train'
DEBUG = 'True'          # kept as a string: the '--debug' option converts it with type=eval
DATASET = 'PEMS04'      #PEMS04/8/3/7
DEVICE = 'cuda:0'
MODEL = 'AGCRN'
MODEL_NAME = "combined"#"combined" #"combined"#"basic/dropout/heter/combined_swa"
P1= 0.1 #04/03/07: 0.1; 08: 0.05

#get configuration
config_file = 'model/{}_{}.conf'.format(DATASET, MODEL)
config = configparser.ConfigParser()
# ConfigParser.read() silently skips files it cannot open and returns the list of files it
# actually parsed. Fail here with a clear message instead of with a KeyError on
# config['data'] when the parser defaults are built.
if not config.read(config_file):
    raise FileNotFoundError('Cannot read configuration file: %s' % config_file)

from lib.utils import enable_dropout,save_model_,load_model_
from lib.metrics import All_Metrics


In [72]:
from lib.metrics import MAE_torch
def masked_mae_loss(scaler, mask_value):
    """Return a loss callable that computes MAE_torch with the given mask value.

    When `scaler` is truthy, predictions and labels are both passed through
    `scaler.inverse_transform` before the error is computed; otherwise they are
    compared as given.
    """
    def loss(preds, labels):
        if not scaler:
            # No scaler supplied: compare the tensors exactly as they were passed in.
            return MAE_torch(pred=preds, true=labels, mask_value=mask_value)
        restored_preds = scaler.inverse_transform(preds)
        restored_labels = scaler.inverse_transform(labels)
        return MAE_torch(pred=restored_preds, true=restored_labels, mask_value=mask_value)
    return loss

# Argument parser used as a typed configuration container: parse_args([]) at the end is
# given an empty argv, so every option ends up with its default value.
# Defaults taken from `config` are strings; argparse passes string defaults through `type`.
# NOTE(review): `type=eval` evaluates text from the .conf file as Python code — only use
# trusted configuration files.
# NOTE(review): `type=bool` maps any non-empty string (including "False") to True. It has no
# effect while argv is empty, but should be replaced before real command-line parsing is used.
args = argparse.ArgumentParser(description='arguments')
args.add_argument('--dataset', default=DATASET, type=str)
args.add_argument('--mode', default=Mode, type=str)
args.add_argument('--device', default=DEVICE, type=str, help='indices of GPUs')
args.add_argument('--debug', default=DEBUG, type=eval)
args.add_argument('--model', default=MODEL, type=str)
args.add_argument('--cuda', default=True, type=bool)
# data: options read from the [data] section of the config file
args.add_argument('--val_ratio', default=config['data']['val_ratio'], type=float)
args.add_argument('--test_ratio', default=config['data']['test_ratio'], type=float)
#args.add_argument('--val_ratio', default=0.1, type=float)
#args.add_argument('--test_ratio', default=0.85, type=float)

args.add_argument('--lag', default=config['data']['lag'], type=int)
args.add_argument('--horizon', default=config['data']['horizon'], type=int)
args.add_argument('--num_nodes', default=config['data']['num_nodes'], type=int)
args.add_argument('--tod', default=config['data']['tod'], type=eval)
args.add_argument('--normalizer', default=config['data']['normalizer'], type=str)
args.add_argument('--column_wise', default=config['data']['column_wise'], type=eval)
args.add_argument('--default_graph', default=config['data']['default_graph'], type=eval)
# model: options read from the [model] section (note '--cheb_k' reads the 'cheb_order' key)
args.add_argument('--input_dim', default=config['model']['input_dim'], type=int)
args.add_argument('--output_dim', default=config['model']['output_dim'], type=int)
args.add_argument('--embed_dim', default=config['model']['embed_dim'], type=int)
args.add_argument('--rnn_units', default=config['model']['rnn_units'], type=int)
args.add_argument('--num_layers', default=config['model']['num_layers'], type=int)
args.add_argument('--cheb_k', default=config['model']['cheb_order'], type=int)
# train: options read from the [train] section, plus two hard-coded teacher-forcing defaults
args.add_argument('--loss_func', default=config['train']['loss_func'], type=str)
#args.add_argument('--loss_func', default='mse', type=str)
args.add_argument('--seed', default=config['train']['seed'], type=int)
args.add_argument('--batch_size', default=config['train']['batch_size'], type=int)
args.add_argument('--epochs', default=config['train']['epochs'], type=int)
#args.add_argument('--epochs', default=500, type=int)
args.add_argument('--lr_init', default=config['train']['lr_init'], type=float)
#args.add_argument('--lr_init', default=1e-2, type=float)
args.add_argument('--lr_decay', default=config['train']['lr_decay'], type=eval)
args.add_argument('--lr_decay_rate', default=config['train']['lr_decay_rate'], type=float)
# Kept as a string here; it is split on ',' into integer milestones where the scheduler is built.
args.add_argument('--lr_decay_step', default=config['train']['lr_decay_step'], type=str)
args.add_argument('--early_stop', default=config['train']['early_stop'], type=eval)
args.add_argument('--early_stop_patience', default=config['train']['early_stop_patience'], type=int)
args.add_argument('--grad_norm', default=config['train']['grad_norm'], type=eval)
args.add_argument('--max_grad_norm', default=config['train']['max_grad_norm'], type=int)
args.add_argument('--teacher_forcing', default=False, type=bool)
args.add_argument('--tf_decay_steps', default=2000, type=int, help='teacher forcing decay steps')
args.add_argument('--real_value', default=config['train']['real_value'], type=eval, help = 'use real value for loss calculation')
# test: options read from the [test] section
args.add_argument('--mae_thresh', default=config['test']['mae_thresh'], type=eval)
args.add_argument('--mape_thresh', default=config['test']['mape_thresh'], type=float)
# log: logging options plus the notebook-level MODEL_NAME and P1 constants
args.add_argument('--log_dir', default='./', type=str)
args.add_argument('--log_step', default=config['log']['log_step'], type=int)
args.add_argument('--plot', default=config['log']['plot'], type=eval)
args.add_argument('--model_name', default=MODEL_NAME, type=str)
args.add_argument('--p1', default=P1, type=float)


# From here on `args` is rebound from the parser object to the parsed Namespace.
args = args.parse_args([])
init_seed(args.seed)
# Resolve the device through torch.device instead of reading one character of the string:
# indexing args.device[5] breaks for 'cuda' and 'cpu' (IndexError) and for two-digit
# indices such as 'cuda:10'.
if torch.cuda.is_available() and torch.device(args.device).type == 'cuda':
    # Only pin the current CUDA device when an explicit index was requested; a bare
    # 'cuda' keeps whichever device is already current.
    if torch.device(args.device).index is not None:
        torch.cuda.set_device(torch.device(args.device))
else:
    args.device = 'cpu'

# Instantiate the network and move it to the configured device.
model = Network(args).to(args.device)
# Parameters with more than one dimension get Xavier-uniform initialisation; all other
# parameters are filled by nn.init.uniform_ with its default range.
for p in model.parameters():
    (nn.init.xavier_uniform_ if p.dim() > 1 else nn.init.uniform_)(p)
print_model_parameters(model, only_num=False)

# Data loaders for the three splits, plus the scaler returned alongside them.
train_loader, val_loader, test_loader, scaler = get_dataloader(
    args,
    normalizer=args.normalizer,
    tod=args.tod,
    dow=False,
    weather=False,
    single=False,
)

# Loss selection: the masked MAE works through the scaler, the other two are stock
# torch.nn criteria moved to the target device. Any other name is rejected.
if args.loss_func == 'mask_mae':
    loss = masked_mae_loss(scaler, mask_value=0.0)
elif args.loss_func in ('mae', 'mse'):
    loss = {'mae': torch.nn.L1Loss, 'mse': torch.nn.MSELoss}[args.loss_func]().to(args.device)
else:
    raise ValueError

# Adam without weight decay (an earlier, commented-out variant used weight_decay=1e-6).
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr_init, eps=1.0e-8,
                             weight_decay=0, amsgrad=False)

# Optional step-wise learning-rate decay at the comma-separated milestone epochs.
lr_scheduler = None
if args.lr_decay:
    print('Applying learning rate decay.')
    lr_decay_steps = list(map(int, args.lr_decay_step.split(',')))
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        milestones=lr_decay_steps,
                                                        gamma=args.lr_decay_rate)

# Bundle everything into the trainer; training itself is started from a later cell.
trainer = Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler,
                  args, lr_scheduler=lr_scheduler)


*****************Model Parameter*****************
T torch.Size([1]) True
node_embeddings torch.Size([307, 10]) True
encoder.dcrnn_cells.0.gate.weights_pool torch.Size([10, 2, 65, 128]) True
encoder.dcrnn_cells.0.gate.bias_pool torch.Size([10, 128]) True
encoder.dcrnn_cells.0.update.weights_pool torch.Size([10, 2, 65, 64]) True
encoder.dcrnn_cells.0.update.bias_pool torch.Size([10, 64]) True
encoder.dcrnn_cells.1.gate.weights_pool torch.Size([10, 2, 128, 128]) True
encoder.dcrnn_cells.1.gate.bias_pool torch.Size([10, 128]) True
encoder.dcrnn_cells.1.update.weights_pool torch.Size([10, 2, 128, 64]) True
encoder.dcrnn_cells.1.update.bias_pool torch.Size([10, 64]) True
get_mu.0.weight torch.Size([32, 1, 1, 1]) True
get_mu.0.bias torch.Size([32]) True
get_mu.3.weight torch.Size([12, 32, 1, 64]) True
get_mu.3.bias torch.Size([12]) True
get_log_var.0.weight torch.Size([32, 1, 1, 1]) True
get_log_var.0.bias torch.Size([32]) True
get_log_var.3.weight torch.Size([12, 32, 1, 64]) True
get_log_var

# Pre-train model

In [73]:
#trainer.train()

# AWA re-train model

In [None]:
#trainer.model = swa_train_combined(trainer,epoch_swa=20)

# Save and load model

In [1]:
#save_model_(model,args.model_name,args.dataset,args.horizon)
#trainer.model = load_model_(model,args.model_name,args.dataset,args.horizon)

# MHCC: online calibration

In [93]:

"""
Online inference as validation.
"""

def combined_conf_val(model,num_samples,args, data_loader, scaler, q,logger=None, path=None):
    """Run repeated dropout-enabled inference over ``data_loader`` for calibration.

    The model is put in eval mode and ``enable_dropout`` is applied, then the loader is
    swept ``num_samples`` times. Predicted means are averaged over the sweeps; the
    predictive variance is ``exp(mean(log_var))`` only (the spread of the sampled means
    is deliberately not added here). Intervals ``y_pred ± std * q`` are compared against
    the inverse-transformed labels.

    Args:
        model: network whose forward returns ``(mu, log_var)``.
        num_samples: number of sweeps over the loader; must be >= 1.
        args: namespace providing ``output_dim``, ``input_dim`` and ``real_value``.
        data_loader: iterable of ``(data, target)`` batches. It must yield batches in
            the same order on every sweep, because labels and predictions are collected
            in separate sweeps.
        scaler: object with ``inverse_transform``.
        q: interval multiplier (scalar or tensor broadcastable against the predictions).
        logger, path: unused; kept for interface compatibility.

    Returns:
        ``(y_true, y_pred, total_std, picp)`` where the tensors have the trailing feature
        axis squeezed away (presumably [sample, horizon, node] — TODO confirm) and
        ``picp`` is a NumPy array with the empirical coverage for each index of axis 1.

    Raises:
        ValueError: if ``num_samples`` is smaller than 1.
    """
    if num_samples < 1:
        raise ValueError("num_samples must be at least 1")

    model.eval()
    enable_dropout(model)

    with torch.no_grad():
        # Ground truth, mapped back through the scaler.
        labels = [target[..., :args.output_dim] for _, target in data_loader]
        y_true = scaler.inverse_transform(torch.cat(labels, dim=0)).squeeze(3)

        # Collect one full-dataset prediction per sweep in lists and stack once at the
        # end; growing a tensor with vstack re-copies everything on every sweep.
        mc_mus, mc_log_vars = [], []
        for _ in tqdm(range(num_samples)):
            mu_batches, log_var_batches = [], []
            for data, target in data_loader:
                data = data[..., :args.input_dim]
                # Pass this batch's own target rather than a stale one left over from
                # the label-collection loop (whose last batch may have a different size).
                mu, log_var = model.forward(data, target, teacher_forcing_ratio=0)
                mu_batches.append(mu.squeeze(3))
                log_var_batches.append(log_var.squeeze(3))

            mu_pred = torch.cat(mu_batches, dim=0)
            if not args.real_value:
                mu_pred = scaler.inverse_transform(mu_pred)
            mc_mus.append(mu_pred)
            mc_log_vars.append(torch.cat(log_var_batches, dim=0))

    # Stacking keeps the tensors on whatever device the model produced them on, so no
    # hard-coded .cuda() is needed.
    mc_mus = torch.stack(mc_mus, dim=0)
    mc_log_vars = torch.stack(mc_log_vars, dim=0)

    y_pred = torch.mean(mc_mus, dim=0)
    total_var = torch.exp(torch.mean(mc_log_vars, dim=0))
    total_std = total_var**0.5

    half_width = torch.mul(total_std, q)
    covered = (y_true >= y_pred - half_width) & (y_true <= y_pred + half_width)
    # Coverage per index of axis 1: count hits over axes 0 and 2.
    picp = covered.sum(dim=(0, 2)) / (y_true.size(0)*y_true.size(2))
    return y_true, y_pred, total_std, picp.detach().cpu().numpy()
    

In [94]:
# Calibration pass over the validation loader with num_samples=10 and an initial
# multiplier q=1.96. `p` receives the coverage array that the function returns last.
y_true_val, y_pred_val, std_val, p = combined_conf_val(model,10,args,val_loader, scaler, q=1.96) 
#y_true_val, mc_mus_val, mc_log_vars_val = combined_conf_val(model,10,args,test_loader, scaler, q=1.96) #

100%|██████████| 10/10 [00:55<00:00,  5.59s/it]


In [96]:
# Nonconformity score for every validation entry: absolute error scaled by the predicted std.
scores = (y_pred_val - y_true_val).abs() / std_val
# Number of calibration samples (length of the first axis).
n = y_true_val.size(0)
def quantile_lwci(scores,n,alpha):
    """Return one score quantile per index of axis 1, using a single level for all.

    The quantile level is the finite-sample corrected ``(n + 1) * (1 - alpha) / n``,
    capped at 1. For each index ``i`` of axis 1 the quantile is taken over all values
    of ``scores[:, i, :]``.

    Args:
        scores: 3-D tensor of nonconformity scores.
        n: number of calibration samples used in the finite-sample correction.
        alpha: target miscoverage level.

    Returns:
        1-D tensor of length ``scores.size(1)`` on the same device as ``scores``.
    """
    level = min((n+1.0)*(1-alpha)/n,1)
    # Convert once; the size and device are taken from `scores` itself rather than from
    # a notebook global or a hard-coded CUDA device.
    scores_np = scores.detach().cpu().numpy()
    q = torch.empty(scores.size(1), device=scores.device)
    for i in range(scores.size(1)):
        q[i] = float(np.quantile(scores_np[:, i, :], level))
    return q
def quantile_mhcc(scores,n,alpha_new):
    """Return one score quantile per index of axis 1, each with its own level.

    For index ``i`` of axis 1 the quantile level is
    ``(n + 1) * (1 - alpha_new[i]) / n``, capped at 1, and the quantile is taken over
    all values of ``scores[:, i, :]``.

    Args:
        scores: 3-D tensor of nonconformity scores.
        n: number of calibration samples used in the finite-sample correction.
        alpha_new: sequence with one miscoverage level per index of axis 1.

    Returns:
        1-D tensor of length ``scores.size(1)`` on the same device as ``scores``.
    """
    # Convert once; the size and device are taken from `scores` itself rather than from
    # a notebook global or a hard-coded CUDA device.
    scores_np = scores.detach().cpu().numpy()
    q = torch.empty(scores.size(1), device=scores.device)
    for i in range(scores.size(1)):
        level = min((n+1.0)*(1-alpha_new[i])/n,1)
        q[i] = float(np.quantile(scores_np[:, i, :], level))
    return q

In [97]:
def combined_conf_test(model,num_samples,args, data_loader, scaler, q,logger=None, path=None, return_outputs=False):
    """Evaluate calibrated prediction intervals with repeated dropout-enabled inference.

    The model is put in eval mode and ``enable_dropout`` is applied, then the loader is
    swept ``num_samples`` times. The predictive variance is the variance of the sampled
    means plus ``exp(mean(log_var))``. Intervals ``y_pred ± std * q`` are compared
    against the inverse-transformed labels, and the coverage per index of axis 1 (in
    percent), its mean (in percent) and the mean interval width are printed.

    Args:
        model: network whose forward returns ``(mu, log_var)``.
        num_samples: number of sweeps over the loader; must be >= 1.
        args: namespace providing ``output_dim``, ``input_dim`` and ``real_value``.
        data_loader: iterable of ``(data, target)`` batches. It must yield batches in
            the same order on every sweep, because labels and predictions are collected
            in separate sweeps.
        scaler: object with ``inverse_transform``.
        q: interval multiplier (scalar or tensor broadcastable against the predictions).
        logger, path: unused; kept for interface compatibility.
        return_outputs: when True, also return ``(y_true, y_pred, total_std)``.

    Returns:
        ``None`` by default; ``(y_true, y_pred, total_std)`` when ``return_outputs``
        is True.

    Raises:
        ValueError: if ``num_samples`` is smaller than 1.
    """
    if num_samples < 1:
        raise ValueError("num_samples must be at least 1")

    model.eval()
    enable_dropout(model)

    with torch.no_grad():
        # Ground truth, mapped back through the scaler.
        labels = [target[..., :args.output_dim] for _, target in data_loader]
        y_true = scaler.inverse_transform(torch.cat(labels, dim=0)).squeeze(3)

        # Collect one full-dataset prediction per sweep in lists and stack once at the
        # end; growing a tensor with vstack re-copies everything on every sweep.
        mc_mus, mc_log_vars = [], []
        for _ in tqdm(range(num_samples)):
            mu_batches, log_var_batches = [], []
            for data, target in data_loader:
                data = data[..., :args.input_dim]
                # Pass this batch's own target rather than a stale one left over from
                # the label-collection loop (whose last batch may have a different size).
                mu, log_var = model.forward(data, target, teacher_forcing_ratio=0)
                mu_batches.append(mu.squeeze(3))
                log_var_batches.append(log_var.squeeze(3))

            mu_pred = torch.cat(mu_batches, dim=0)
            if not args.real_value:
                mu_pred = scaler.inverse_transform(mu_pred)
            mc_mus.append(mu_pred)
            mc_log_vars.append(torch.cat(log_var_batches, dim=0))

    # Stacking keeps the tensors on whatever device the model produced them on, so no
    # hard-coded .cuda() is needed.
    mc_mus = torch.stack(mc_mus, dim=0)
    mc_log_vars = torch.stack(mc_log_vars, dim=0)

    y_pred = torch.mean(mc_mus, dim=0)
    total_var = torch.exp(torch.mean(mc_log_vars, dim=0))
    if num_samples > 1:
        # The sample variance is undefined (NaN) for a single sweep, so the spread of
        # the sampled means is only added when there is more than one.
        total_var = torch.var(mc_mus, dim=0) + total_var
    total_std = total_var**0.5

    half_width = torch.mul(total_std, q)
    mpiw = 2*torch.mean(half_width)
    covered = (y_true >= y_pred - half_width) & (y_true <= y_pred + half_width)
    # Coverage per index of axis 1: count hits over axes 0 and 2.
    picp = covered.sum(dim=(0, 2)) / (y_true.size(0)*y_true.size(2))
    print(picp*100, torch.mean(picp).item()*100, mpiw.item())
    if return_outputs:
        return y_true, y_pred, total_std

# Correct the target significance level $\alpha$

In [103]:
# Horizon indices 0 .. horizon-1, used to scale the correction linearly with the step.
h = np.arange(args.horizon)
alpha = 0.05
gamma = 0.03#04/07/08:0.03, 03:0
# Shift alpha by the gap between the measured coverage `p` and the 1 - alpha target, then
# add a term proportional to the coverage difference between the first and last entry of `p`.
alpha_new = p - (1 - alpha) + alpha + (p[0] - p[-1]) * gamma * h * 2 #04/07/08:0.03
# Per-horizon multipliers taken from the validation scores at the corrected levels.
q = quantile_mhcc(scores, n, alpha_new)

[0.05340891 0.05050783 0.05023377 0.05054503 0.04958587 0.04828899
 0.04636179 0.04609852 0.0462424  0.04623555 0.04712426 0.04645186]
[0.05340891 0.05092526 0.05106862 0.0517973  0.05125556 0.05037611
 0.04886633 0.04902048 0.04958179 0.04999236 0.05129849 0.05104351]


In [104]:
# `q` is already a tensor (quantile_mhcc returns one), so wrapping it in torch.tensor()
# triggers PyTorch's copy-construct UserWarning. detach().clone() makes the same
# gradient-free copy explicitly before reshaping it to (1, horizon, 1) for broadcasting.
q = q.detach().clone().reshape(-1,args.horizon,1).to(args.device)
combined_conf_test(model,10,args,test_loader,scaler,q)

  q = torch.tensor(q.reshape(-1,12,1)).to(args.device)
100%|██████████| 10/10 [00:56<00:00,  5.67s/it]

tensor([95.0162, 95.1789, 95.1004, 94.9924, 94.9851, 95.0227, 95.1640, 95.0996,
        95.0323, 95.0006, 94.8244, 94.8621], device='cuda:0') 95.0232207775116 103.91312408447266





# Online MHCC: online calibration

### 1. Update nonconformity score

### 2. Update empirical prediction interval percentage coverage

### 3. Update alpha

In [None]:
# Size of the calibration window that seeds the online procedure (all validation samples).
online_size = len(y_true_val)#1000
# The first `online_size` entries of each validation tensor.
y_true_online, y_pred_online, std_online, score_online = (
    tensor[:online_size] for tensor in (y_true_val, y_pred_val, std_val, scores)
)


gamma_online = 0.03 #04/07/08:0.03, 03:0
h = np.arange(args.horizon)
# Number of test steps between two refreshes of the multipliers.
update_freq = 1000
# Starting value for the running coverage: the coverage measured on the validation set.
picp_online = p

In [None]:
alpha = 0.05
scores_online = scores

def mhcc_online_(y_true_test, y_pred_test, std_test, q_online = q):
    """Replay the test set step by step while re-calibrating the interval multipliers.

    For every test step the current multipliers build the interval ``pred ± std * q``,
    whose coverage and width are recorded. The step's nonconformity score
    ``|pred - true| / std`` then replaces the oldest entry of the score window. Every
    ``update_freq`` steps the per-horizon significance levels are re-derived from the
    running coverage and the multipliers are refreshed from the window's quantiles.
    Finally the whole test set is scored with the last multipliers and the mean coverage
    and mean interval width are printed.

    Reads the notebook globals ``scores_online``, ``picp_online``, ``online_size``,
    ``update_freq``, ``gamma_online``, ``alpha`` and ``h``; none of them is modified.

    Args:
        y_true_test, y_pred_test, std_test: tensors indexed as [step, horizon, node]
            (assumed from how they are sliced here — TODO confirm).
        q_online: initial multipliers, indexed as ``[0, horizon, 0]``. The tensor is
            copied, so the caller's tensor is left untouched.

    Returns:
        ``(picp_test, mpiw_test, q_online)``: per-step coverage stacked along axis 1,
        per-step mean interval width, and the final multipliers.
    """
    # Work on private copies. Assigning to the global names inside the function would
    # make them local (UnboundLocalError on first read), and updating the caller's
    # tensor in place would make a re-run start from already-updated multipliers.
    q_online = q_online.detach().clone()
    picp_running = np.array(picp_online, copy=True)
    score_window = scores_online.detach().clone()
    oldest = 0  # position of the oldest score in the window

    mpiw_ls = []
    picp_ls = []
    steps = zip(y_true_test, y_pred_test, std_test)
    for ii, (y_t, p_t, s_t) in enumerate(tqdm(steps, total=y_true_test.size(0))):

        # empirical coverage and width of this step under the current multipliers
        half_width = torch.mul(s_t, q_online)
        mpiw = 2*torch.mean(half_width)
        covered = (y_t >= p_t - half_width) & (y_t <= p_t + half_width)
        picp = covered.sum(dim=(0, 2)) / y_true_test.size(2)

        mpiw_ls.append(mpiw)
        picp_ls.append(picp)
        picp_running = (picp_running*(online_size+ii) + picp.detach().cpu().numpy())/(online_size+ii+1)

        # Slide the score window: overwrite the oldest entry in place. The quantiles do
        # not depend on the order of the window, so this matches dropping the first row
        # and appending the new one, without copying the whole window every step.
        score_window[oldest] = abs(p_t - y_t)/s_t
        oldest = (oldest + 1) % score_window.size(0)

        if (ii + 1) % update_freq == 0:
            # update alpha from the running coverage, then refresh the multipliers
            alpha_new = picp_running-(1-alpha)+alpha
            alpha_new = alpha_new+(picp_running[0]-picp_running[-1])*gamma_online*h*2 #04/07/08:0.03
            window_np = score_window.detach().cpu().numpy()
            for i in range(y_true_test.shape[1]):
                level = min((online_size+1.0)*(1-alpha_new[i])/online_size,1)
                q_online[0, i, 0] = float(np.quantile(window_np[:, i, :], level))

    picp_test = torch.stack(picp_ls, dim=1)
    mpiw_test = torch.stack(mpiw_ls, dim=0)

    # Score the full test set with the final multipliers.
    half_width = torch.mul(std_test, q_online)
    mpiw = 2*torch.mean(half_width)
    covered = (y_true_test >= y_pred_test - half_width) & (y_true_test <= y_pred_test + half_width)
    picp = covered.sum(dim=(0, 2)) / (y_true_test.size(0)*y_true_test.size(2))

    # `picp` holds one value per horizon; a multi-element tensor cannot be formatted
    # with '{:.4f}', so report its mean.
    print("PICP: {:.4f}, MPIW: {:.4f}".format(torch.mean(picp).item(), mpiw.item()))
    return picp_test, mpiw_test, q_online
    
        

In [None]:
 # NOTE(review): y_true_test, y_pred_test and std_test are not assigned in any visible
 # cell of this notebook; they must exist before this cell runs — TODO confirm their source.
 mhcc_online_(y_true_test, y_pred_test, std_test, q_online = q)