In [1]:
# Notebook setup: automatically reload imported modules before each cell
# executes (autoreload mode 2), and render matplotlib figures inline.
# Comments are kept on their own lines because line magics receive the rest
# of the line as arguments.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import sagemaker
import boto3
import json

import pandas as pd
import numpy as np
from matplotlib import pyplot as plt

import mxnet as mx
from mxnet import gluon
from gluonts.dataset.common import ListDataset
from gluonts.dataset.loader import (
    TrainDataLoader, ValidationDataLoader, InferenceDataLoader
)
from gluonts.dataset.repository.datasets import get_dataset
from gluonts.model.deepar import DeepAREstimator
from gluonts.support.util import get_hybrid_forward_input_names

In [3]:
from gluonts.model.san import *

In [4]:
# Self-attention (SAN) estimator on hourly data: forecast the next 24 hours
# from a 168-hour (one-week) context window.
san_hyperparams = dict(
    # time axis
    freq='h',
    prediction_length=24,
    context_length=168,
    # network size / architecture
    model_dim=64,
    ffn_dim_multiplier=2,
    num_heads=4,
    num_layers=3,
    num_outputs=3,
    kernel_sizes=[5, 9],
    distance_encoding='dot',
    # covariates: only the static categorical feature is enabled.
    # NOTE(review): 370 is presumably the number of series in the
    # electricity dataset — TODO confirm / derive from dataset metadata.
    cardinalities=[370],
    use_feat_static_cat=True,
    use_feat_static_real=False,
    use_feat_dynamic_cat=False,
    use_feat_dynamic_real=False,
)
estimator = SelfAttentionEstimator(**san_hyperparams)

In [5]:
from gluonts.dataset.repository.datasets import get_dataset

In [6]:
# Load the electricity dataset from the GluonTS dataset repository;
# `tds.train` / `tds.test` feed the loaders below.
DATASET_NAME = 'electricity'
tds = get_dataset(DATASET_NAME)

In [7]:
# Training loader over the train split, used here only to pull a sample
# batch for a forward-pass smoke test of the training network.
train_transform = estimator.create_transformation()
loader = TrainDataLoader(
    dataset=tds.train,
    transform=train_transform,
    batch_size=3,
    ctx=mx.cpu(),
    num_batches_per_epoch=100,
)

In [8]:
# Take a single batch (a dict keyed by field name) from the training loader.
train_batches = iter(loader)
out = next(train_batches)

In [9]:
# Inspect which fields the transformation produced for a training batch.
out.keys()

dict_keys(['start', 'feat_static_cat', 'item_id', 'source', 'feat_static_real', 'past_observed_values', 'future_observed_values', 'past_feat_dynamic_real', 'future_feat_dynamic_real', 'past_target', 'future_target', 'past_is_pad', 'forecast_start', 'past_feat_dynamic_cat', 'future_feat_dynamic_cat'])

In [10]:
# Build the training network and initialize its parameters on CPU.
# MXNet's RNG is seeded first so the parameter initialization (and hence the
# value shown by `network(*inputs)`) is reproducible under Restart & Run All.
# NOTE(review): batch/window sampling in TrainDataLoader is presumably driven
# by a different RNG (e.g. numpy) that this does not seed — TODO confirm if
# full determinism of the sample batch is required.
MX_SEED = 42
mx.random.seed(MX_SEED)

network = estimator.create_training_network()
network.initialize(ctx=mx.cpu())

In [11]:
# Names of the training network's `hybrid_forward` inputs. They are splatted
# positionally into `network(*inputs)`, so the batch fields are selected in
# exactly that order.
input_names = get_hybrid_forward_input_names(network)
inputs = [out[field] for field in input_names]

In [17]:
# Smoke test: forward pass of the (untrained) training network on one batch.
# The single-element NDArray displayed is its output (presumably the batch
# loss — TODO confirm against the SAN training network).
network(*inputs)


[952.7119]
<NDArray 1 @cpu(0)>

In [18]:
# Build a predictor from a fresh transformation and the freshly initialized
# (untrained) network, to exercise the inference path.
inference_transform = estimator.create_transformation()
predictor = estimator.create_predictor(inference_transform, network)

In [23]:
# Pull out the predictor's prediction network so it can be called directly.
pnet = predictor.prediction_net

In [21]:
# Inference loader over the test split, for exercising the prediction network.
test_transform = estimator.create_transformation()
tload = InferenceDataLoader(
    dataset=tds.test,
    transform=test_transform,
    batch_size=3,
    ctx=mx.cpu(),
)

In [22]:
# Take one batch from the *inference* loader (`tload`, over the test split).
# Drawing from the training `loader` here would feed the prediction network
# training windows and leave `tload` unused.
out = next(iter(tload))

In [24]:
# Names of the prediction network's `hybrid_forward` inputs. They are splatted
# positionally into `pnet(*inputs)`, so the batch fields are selected in
# exactly that order.
input_names = get_hybrid_forward_input_names(pnet)
inputs = [out[field] for field in input_names]

In [29]:
# Forward pass of the prediction network on the batch. The result is kept in
# `pr` for inspection (presumably forecast samples — TODO check its shape).
pr = pnet(*inputs)