In [1]:
import numpy as np
import pytorch_lightning as pl
import torch as t
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning import seed_everything

from nixtla.data.datasets.m4 import M4, M4Info
from nixtla.data.tsdataset import TimeSeriesDataset
from nixtla.data.tsloader import TimeSeriesLoader
from nixtla.models.esrnn.esrnn import ESRNN

## Load the M4 Yearly data

In [2]:
# Metadata for the yearly M4 subset (provides .name, .horizon and .n_ts used below).
group = M4Info['Yearly']
# Keep only the first value returned by the loader (the series frame); discard the rest.
Y_df, *_ = M4.load(group=group.name, directory='data')

In [3]:
# Training windows. `ds_in_test=group.horizon` presumably holds out the final
# horizon of each series; a small window_sampling_limit caps the history the
# model unrolls over, keeping backprop-through-time cheap.
train_ts_dataset = TimeSeriesDataset(
    Y_df=Y_df,
    mode='full',
    ds_in_test=group.horizon,
    input_size=4,
    output_size=group.horizon,
    len_sample_chunks=group.horizon * 3,
    idx_to_sample_freq=1,
    window_sampling_limit=25,
    complete_inputs=True,
    skip_nonsamplable=True,
)

In [5]:
# Prediction dataset. ds_in_test=0 plus last_samplable_window=True presumably
# yield only the final window of each series, and skip_nonsamplable=False keeps
# every series so forecasts can be aligned one-to-one with series — confirm
# against TimeSeriesDataset.
test_ts_dataset = TimeSeriesDataset(Y_df=Y_df, ds_in_test=0,
                                    mode='full',
                                    # No gradients flow at prediction time; this very large
                                    # limit presumably keeps the full history available.
                                    window_sampling_limit=500_000,
                                    input_size=4,
                                    output_size=group.horizon,
                                    idx_to_sample_freq=1,
                                    complete_inputs=True,
                                    complete_outputs=True,
                                    len_sample_chunks=group.horizon * 3,
                                    skip_nonsamplable=False,
                                    last_samplable_window=True)

In [94]:
# Shuffled mini-batches of 32 training windows, loaded by 8 worker processes.
# eq_batch_size=True presumably forces every batch to the same size — confirm.
train_ts_loader = TimeSeriesLoader(
    dataset=train_ts_dataset,
    shuffle=True,
    batch_size=32,
    eq_batch_size=True,
    num_workers=8,
)

In [95]:
# Unshuffled, large batches for inference so the prediction order stays stable.
test_ts_loader = TimeSeriesLoader(
    dataset=test_ts_dataset,
    shuffle=False,
    batch_size=1_024,
    eq_batch_size=False,
    num_workers=8,
)

In [96]:
# Seed *before* constructing the model. ESRNN presumably initialises its
# weights in the constructor, so seeding only afterwards leaves the initial
# parameters unseeded, and training is not reproducible even with
# `deterministic=True` on the Trainer.
seed_everything(117982, workers=True)

# input_size / output_size must agree with the datasets built above
# (input_size=4, output_size=group.horizon).
model = ESRNN(n_series=group.n_ts,
              n_x=0, n_s=0,  # presumably no exogenous / static covariates — confirm
              idx_to_sample_freq=1,
              input_size=4,
              output_size=group.horizon,
              learning_rate=1e-4,
              lr_scheduler_step_size=10,
              lr_decay=0.1,
              per_series_lr_multip=0.8,
              gradient_eps=1e-8,
              gradient_clipping_threshold=50,
              rnn_weight_decay=0,
              level_variability_penalty=100,
              testing_percentile=50,
              training_percentile=50,
              cell_type='LSTM',
              state_hsize=50,
              dilations=[[1, 2], [2, 6]],
              add_nl_layer=False,
              loss='SMYL',
              val_loss='SMAPE',
              seasonality=[])  # empty: presumably no seasonal component for yearly data

In [97]:
# Fix the global seed; `workers=True` also seeds dataloader worker processes.
seed_everything(seed=117982, workers=True)

Global seed set to 117982


117982

In [105]:
# Train for 10 epochs with deterministic algorithms requested.
# NOTE(review): `progress_bar_refresh_rate` is deprecated in newer Lightning
# releases (superseded by the TQDMProgressBar callback) — revisit on upgrade.
trainer = pl.Trainer(
    deterministic=True,
    max_epochs=10,
    progress_bar_refresh_rate=50,
)
trainer.fit(model, train_ts_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name  | Type   | Params
---------------------------------
0 | esrnn | _ESRNN | 95.7 K
---------------------------------
95.7 K    Trainable params
0         Non-trainable params
95.7 K    Total params
0.383     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Global seed set to 117982


Training: 0it [00:00, ?it/s]

In [106]:
# Run inference over the test loader. Each entry is presumably one batch's
# (y, y_hat) pair — it is unpacked as a pair in the following cells.
outputs = trainer.predict(
    model,
    test_ts_loader,
)

Predicting: 719it [00:00, ?it/s]

In [107]:
# Transpose the per-batch pairs into two parallel tuples and keep the second (forecasts).
_, y_hat = [*zip(*outputs)]

In [108]:
# Concatenate the per-batch forecasts (second element of each `outputs` entry)
# and keep index -1 along dim 1 — presumably the last window of each series;
# TODO confirm the tensor layout.
# Rebuilding from `outputs` rather than overwriting `y_hat` with a function of
# itself keeps this cell idempotent, so re-running it cannot fail. detach/cpu
# make `.numpy()` safe if prediction ran on a GPU or with autograd enabled.
y_hat = t.cat([batch_y_hat for _, batch_y_hat in outputs])[:, -1].detach().cpu().numpy()

In [109]:
# Ground truth: the last `group.horizon` observations of each series, one row
# per series. GroupBy.tail keeps Y_df's original row order, so this assumes
# Y_df is ordered by series (and by time within each series) in the same order
# the test loader yields series — TODO confirm.
y = Y_df.groupby('unique_id').tail(group.horizon)['y'].to_numpy().reshape(-1, group.horizon)
# reshape(-1, ...) cannot detect a series shorter than the horizon, which would
# silently shift every later row; require exactly one full row per series.
assert y.shape == (Y_df['unique_id'].nunique(), group.horizon), y.shape

In [110]:
from nixtla.losses.numpy import smape

In [112]:
# Forecasts and targets must line up element-wise; depending on how `smape`
# handles mismatched shapes it could otherwise broadcast or fail obscurely.
assert y.shape == y_hat.shape, f'shape mismatch: y {y.shape} vs y_hat {y_hat.shape}'
smape(y, y_hat)

13.414962723225404