In [1]:
import os
import random
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from matplotlib import pyplot as plt
from sklearn import preprocessing

# Local project modules. `import utils` is explicit because the data cell calls
# `utils.get_data(...)`; the star imports are kept last and in their original
# order because later ones shadow names from earlier ones.
import utils
from utils import *
from models import *
from Experiments import *

# Global seed for every RNG in this notebook; passed to set_seeds() below.
seed = 0 

def set_seeds(seed):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: integer seed applied to `random`, `numpy.random`, the torch CPU
            generator and the CUDA generators (current device and all devices).

    Returns:
        None. Works only through global RNG / backend side effects.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # Give up cuDNN autotuning speed in exchange for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seeds(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

%load_ext autoreload
%autoreload 2

cuda


In [2]:
# Dataset registry: name -> (series CSV, aggregation-matrix CSV).
# Switch datasets by changing `dataset` instead of (un)commenting path blocks.
DATASET_FILES = {
    "traffic": ("traffic_data/data.csv", "traffic_data/agg_mat.csv"),
    "tourism": ("tourism_data/data.csv", "tourism_data/agg_mat.csv"),
    "labor": ("labor/data.csv", "labor/agg_mat.csv"),
}

dataset = "tourism"
data_file, hier_file = DATASET_FILES[dataset]

# Raw series and hierarchy aggregation matrix (first CSV column used as the index).
data_raw = pd.read_csv(data_file, index_col=0)
agg_mat_df = pd.read_csv(hier_file, index_col=0)
base_agg_mat = agg_mat_df.values
full_agg = format_aggregation_matrix(base_agg_mat).float().to(device)

# Divide every series by one global maximum so the largest value becomes 1
# (values land in [0, 1] assuming non-negative data).
# NOTE(review): the maximum is taken over the whole table, including the
# validation/test periods, so the scale leaks information from held-out data.
maximum = np.max(data_raw.values)
assert np.isfinite(maximum) and maximum > 0, f"cannot scale by maximum={maximum!r}"
data = (data_raw / maximum).values  # ndarray from here on; data_raw keeps the DataFrame

batch_size = data.shape[0]  # batch size equals the number of rows in `data`
context_window = 10
train_split = 0.8
val_split = 0.1

num_runs = 10
n_epochs = 1000

X_train, y_train, X_val, y_val, X_test, y_test = utils.get_data(data, train_split, val_split, context_window)

In [7]:
# Hyper-parameters for JSDDistribution. Data-related entries reuse the values
# from the data/config cell so the split and context window stay in sync.
params = {
    'n_series': data.shape[1],
    'hidden_dim': 128,
    'latent_dim': 128,
    'lr': 1e-4,
    # NOTE(review): this hard-coded 2000 differs from `n_epochs = 1000` set in the
    # data/config cell; the logged run used 2000. Unify once the intended value is known.
    'n_epochs': 2000,
    'batch_size': batch_size,
    'train_split': train_split,
    'val_split': val_split,
    'context_window': context_window,
    'aggregate': False,
    'coherency_loss': False,
    'profhit': False,
    'coherency_weight': 1e-2,
    'project': False,
    'jsd': True,
}

jsd_model = JSDDistribution(base_agg_mat, params)
# Re-split using the model's own windowing; this overwrites the X_*/y_* split
# produced by utils.get_data in the data/config cell.
X_train, y_train, X_val, y_val, X_test, y_test = jsd_model.make_data(data)

# Keep the run's return value under a name instead of relying on Out[...] / `_`,
# and still display it as the cell output.
jsd_results = jsd_model.run(data)
jsd_results

  train_means = torch.tensor(torch.mean(data, dim=0), device=device).float()
  train_std = torch.tensor(torch.std(data, dim=0), device=device).float()
  0%|          | 3/2000 [00:00<01:43, 19.21it/s]

0.2707916498184204 0.9932851791381836


  3%|▎         | 55/2000 [00:01<00:50, 38.69it/s]

0.21142251789569855 0.9421935081481934


  5%|▌         | 108/2000 [00:02<00:46, 40.73it/s]

0.08519061654806137 1.5271872282028198


  8%|▊         | 158/2000 [00:04<00:44, 41.11it/s]

0.017257066443562508 1.4534869194030762


 10%|█         | 208/2000 [00:05<00:44, 40.39it/s]

0.005298672243952751 0.9349079728126526


 13%|█▎        | 255/2000 [00:06<00:44, 39.07it/s]

0.0027191038243472576 0.7321904897689819


 15%|█▌        | 307/2000 [00:07<00:41, 40.51it/s]

0.0018448668997734785 0.6431750059127808


 18%|█▊        | 356/2000 [00:09<00:40, 40.67it/s]

0.0014692958211526275 0.5969110131263733


 20%|██        | 406/2000 [00:10<00:40, 39.58it/s]

0.00125330057926476 0.5707597136497498


 23%|██▎       | 456/2000 [00:11<00:37, 41.16it/s]

0.0011328947730362415 0.5541549324989319


 25%|██▌       | 506/2000 [00:12<00:36, 40.84it/s]

0.0010528761195018888 0.5431044697761536


 28%|██▊       | 556/2000 [00:13<00:35, 40.29it/s]

0.0010048161493614316 0.5351189970970154


 30%|███       | 608/2000 [00:15<00:32, 42.45it/s]

0.0009671401348896325 0.5292444229125977


 33%|███▎      | 654/2000 [00:16<00:32, 40.93it/s]

0.0009409341146238148 0.5246919989585876


 35%|███▌      | 709/2000 [00:17<00:29, 44.19it/s]

0.0009230907307937741 0.5213062167167664


 38%|███▊      | 754/2000 [00:18<00:29, 42.20it/s]

0.000908001558855176 0.5185909271240234


 40%|████      | 804/2000 [00:19<00:29, 40.94it/s]

0.0008977164980024099 0.5163733959197998


 43%|████▎     | 854/2000 [00:21<00:27, 41.56it/s]

0.0008880009409040213 0.5146058797836304


 45%|████▌     | 904/2000 [00:22<00:26, 40.73it/s]

0.0008797526243142784 0.5131310224533081


 48%|████▊     | 954/2000 [00:23<00:26, 39.67it/s]

0.0008761381614021957 0.5119962692260742


 50%|█████     | 1004/2000 [00:24<00:24, 40.77it/s]

0.0008701412589289248 0.5110464096069336


 53%|█████▎    | 1054/2000 [00:25<00:22, 41.43it/s]

0.000865944370161742 0.5101805329322815


 55%|█████▌    | 1104/2000 [00:27<00:21, 40.80it/s]

0.0011386815458536148 4.673553943634033


 58%|█████▊    | 1154/2000 [00:28<00:20, 41.15it/s]

0.0009722504182718694 4.344911575317383


 60%|██████    | 1204/2000 [00:29<00:19, 41.19it/s]

0.0009153939317911863 4.304514408111572


 63%|██████▎   | 1254/2000 [00:30<00:18, 41.32it/s]

0.0008972459472715855 4.004683971405029


 65%|██████▌   | 1304/2000 [00:32<00:16, 41.35it/s]

0.0008709283429197967 4.2006683349609375


 68%|██████▊   | 1357/2000 [00:33<00:16, 38.44it/s]

0.0008661731262691319 3.774613857269287


 70%|███████   | 1406/2000 [00:34<00:14, 39.98it/s]

0.0008494569920003414 3.8908350467681885


 73%|███████▎  | 1455/2000 [00:35<00:13, 41.14it/s]

0.0008444800623692572 3.764763593673706


 75%|███████▌  | 1505/2000 [00:37<00:11, 42.25it/s]

0.0008479708922095597 3.565192699432373


 78%|███████▊  | 1555/2000 [00:38<00:10, 41.82it/s]

0.000820866203866899 3.53470516204834


 80%|████████  | 1605/2000 [00:39<00:09, 41.21it/s]

0.0008164542960003018 3.6460089683532715


 83%|████████▎ | 1655/2000 [00:40<00:08, 41.00it/s]

0.0008194731781259179 3.6781110763549805


 85%|████████▌ | 1705/2000 [00:41<00:07, 42.14it/s]

0.0008166456245817244 3.4572510719299316


 88%|████████▊ | 1755/2000 [00:43<00:05, 41.21it/s]

0.0008063060813583434 3.5228707790374756


 90%|█████████ | 1805/2000 [00:44<00:04, 41.14it/s]

0.00080348108895123 3.564983606338501


 93%|█████████▎| 1856/2000 [00:45<00:03, 42.43it/s]

0.000818111002445221 3.60292649269104


 95%|█████████▌| 1906/2000 [00:46<00:02, 40.91it/s]

0.0008064495632424951 3.6282424926757812


 98%|█████████▊| 1956/2000 [00:48<00:01, 41.11it/s]

0.0007984137046150863 3.5246424674987793


100%|██████████| 2000/2000 [00:49<00:00, 40.73it/s]


(    Coherency      CRPS       MSE     WMAPE
 0    0.000035  0.002809  0.000018  1.039404
 1    0.000398  0.005256  0.000093  1.000268
 2    0.000760  0.007651  0.000154  0.998851
 3    0.000769  0.011896  0.000131  1.004116
 4    0.000708  0.010199  0.000264  1.004737
 5    0.000479  0.003722  0.000036  1.002776
 6    0.001591  0.003053  0.000024  1.059236
 7    0.000745  0.013336  0.000583  1.000900
 8    0.001448  0.024774  0.001277  0.998938
 9    0.006171  0.035740  0.001972  0.879192
 10   0.001318  0.013962  0.000308  0.999604
 11   0.001454  0.044169  0.003081  1.001823
 12   0.001772  0.025457  0.001298  0.990160
 13   0.001393  0.016017  0.000326  1.003667
 14   0.003909  0.009968  0.000216  0.996070
 15   0.001501  0.014325  0.000254  0.993051
 16   0.009427  0.077930  0.008941  0.925043
 17   0.056579  0.111755  0.017399  0.679360
 18   0.020015  0.120644  0.020402  0.883934
 19   0.002439  0.108557  0.013487  0.999580
 20   0.011555  0.560906  0.331362  1.000486,
 ([1,
   

In [10]:
# Inspect the mean predicted variance per series on the validation inputs.
# Inference only: no_grad avoids building an autograd graph (saves GPU memory
# and leaves `out` / `var` detached). Values are unchanged.
with torch.no_grad():
    out = jsd_model.network(X_val.to(device).float())
    # NOTE(review): assumes out[2] is a log-variance head (hence .exp()) — confirm
    # against the network's forward().
    var = out[2].exp()
print(var.mean(axis=0))

tensor([0.0153, 0.0468, 0.0090, 0.0240, 0.0086, 0.0125, 0.0032, 0.0227, 0.0054,
        0.0083, 0.0361, 0.0067, 0.0040, 0.0088, 0.0183, 0.0089, 0.0062, 0.0060,
        0.0066, 0.0112, 0.0102, 0.0045, 0.0237, 0.0080, 0.0052, 0.0086, 0.0036,
        0.0113, 0.0052, 0.0130, 0.0038, 0.0041, 0.0040, 0.0149, 0.0198, 0.0055,
        0.0044, 0.0058, 0.0096, 0.0375, 0.0044, 0.0045, 0.0047, 0.0065, 0.0052,
        0.0045, 0.0054, 0.0045, 0.0075, 0.0172, 0.0153, 0.0076, 0.0089, 0.0081,
        0.0069, 0.0076, 0.0085, 0.0047, 0.0057, 0.0035, 0.0059, 0.0051, 0.0037,
        0.0035, 0.0127, 0.0051, 0.0033, 0.0031, 0.0045, 0.0104, 0.0177, 0.0091,
        0.0086, 0.0058, 0.0049, 0.0063, 0.0074, 0.0047, 0.0052, 0.0047, 0.0125,
        0.0203, 0.0081, 0.0039, 0.0057, 0.0058, 0.0054, 0.0046, 0.0067, 0.0063,
        0.0042, 0.0096, 0.0042, 0.0037, 0.0091, 0.0297, 0.0051, 0.0057, 0.0134,
        0.0040, 0.0030, 0.0066, 0.0041, 0.0075, 0.0142, 0.0084, 0.0175, 0.0067,
        0.0152, 0.0258, 0.0075, 0.0450, 