MQ-CNN: Bound context_length by the max_ts_len - prediction_length #1037

Merged · 8 commits · Sep 17, 2020
4 changes: 4 additions & 0 deletions src/gluonts/dataset/stat.py
@@ -114,6 +114,7 @@ class DatasetStatistics(NamedTuple):
mean_abs_target: float
mean_target: float
mean_target_length: float
max_target_length: int
min_target: float
feat_static_real: List[Set[float]]
feat_static_cat: List[Set[int]]
@@ -173,6 +174,7 @@ def calculate_dataset_statistics(ts_dataset: Any) -> DatasetStatistics:
scale_histogram = ScaleHistogram()

with tqdm(enumerate(ts_dataset, start=1), total=len(ts_dataset)) as it:
max_target_length = 0
for num_time_series, ts in it:

# TARGET
@@ -190,6 +192,7 @@ def calculate_dataset_statistics(ts_dataset: Any) -> DatasetStatistics:
)

num_time_observations += num_observations
max_target_length = max(num_observations, max_target_length)
min_target = float(min(min_target, observed_target.min()))
max_target = float(max(max_target, observed_target.max()))
num_missing_values += int(np.isnan(target).sum())
@@ -401,6 +404,7 @@ def calculate_dataset_statistics(ts_dataset: Any) -> DatasetStatistics:
mean_abs_target=mean_abs_target,
mean_target=mean_target,
mean_target_length=mean_target_length,
max_target_length=max_target_length,
min_target=min_target,
num_missing_values=num_missing_values,
feat_static_real=observed_feat_static_real
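
For reference, a minimal sketch of how the new statistic surfaces. The toy ListDataset below is hypothetical and not part of this PR:

# Sketch: exercising the new max_target_length field on a toy dataset
from gluonts.dataset.common import ListDataset
from gluonts.dataset.stat import calculate_dataset_statistics

dataset = ListDataset(
    [
        {"start": "2020-01-01", "target": [1.0] * 30},
        {"start": "2020-01-01", "target": [2.0] * 50},
    ],
    freq="D",
)

stats = calculate_dataset_statistics(dataset)
assert stats.max_target_length == 50  # length of the longest target
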
23 changes: 17 additions & 6 deletions src/gluonts/model/seq2seq/_forking_estimator.py
@@ -156,6 +156,7 @@ def __init__(
scaling_decoder_dynamic_feature: bool = False,
dtype: DType = np.float32,
num_forking: Optional[int] = None,
max_ts_len: Optional[int] = None,
) -> None:
super().__init__(trainer=trainer)

@@ -187,8 +188,18 @@ def __init__(
if context_length is not None
else 4 * self.prediction_length
)
if max_ts_len is not None:
max_pad_len = max(max_ts_len - self.prediction_length, 0)
# Don't allow context_length to be longer than the max pad length
self.context_length = (
min(max_pad_len, self.context_length)
if max_pad_len > 0
else self.context_length
)
self.num_forking = (
num_forking if num_forking is not None else self.context_length
min(num_forking, self.context_length)
if num_forking is not None
else self.context_length
)
self.use_past_feat_dynamic_real = use_past_feat_dynamic_real
self.use_feat_dynamic_real = use_feat_dynamic_real
@@ -252,7 +263,7 @@ def create_transformation(self) -> Transformation:
time_features=time_features_from_frequency_str(self.freq),
pred_length=self.prediction_length,
dtype=self.dtype,
),
)
)
dynamic_feat_fields.append(FieldName.FEAT_TIME)

@@ -263,7 +274,7 @@ def create_transformation(self) -> Transformation:
output_field=FieldName.FEAT_AGE,
pred_length=self.prediction_length,
dtype=self.dtype,
),
)
)
dynamic_feat_fields.append(FieldName.FEAT_AGE)

@@ -290,7 +301,7 @@ def create_transformation(self) -> Transformation:
pred_length=self.prediction_length,
const=0.0, # For consistency in case with no dynamic features
dtype=self.dtype,
),
)
)
dynamic_feat_fields.append(FieldName.FEAT_CONST)

@@ -315,7 +326,7 @@ def create_transformation(self) -> Transformation:
SetField(
output_field=FieldName.FEAT_STATIC_CAT,
value=np.array([0], dtype=np.int32),
),
)
)

# --- SAMPLE AND CUT THE TIME-SERIES ---
@@ -361,7 +372,7 @@ def create_transformation(self) -> Transformation:
else []
),
prediction_time_decoder_exclude=[FieldName.OBSERVED_VALUES],
),
)
)

# past_feat_dynamic features generated above in ForkingSequenceSplitter from those under feat_dynamic - we need
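
The effect of the new bound is plain arithmetic; a sketch that mirrors the logic added above (all values hypothetical):

# Sketch of the context_length cap, with hypothetical hyperparameters
prediction_length = 7
context_length = 4 * prediction_length  # default when none is given: 28
max_ts_len = 20  # longest target length in the training data

max_pad_len = max(max_ts_len - prediction_length, 0)  # 20 - 7 = 13
if max_pad_len > 0:
    context_length = min(max_pad_len, context_length)  # capped at 13

# num_forking is now likewise bounded by context_length
num_forking = min(25, context_length)
assert (context_length, num_forking) == (13, 13)
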
64 changes: 34 additions & 30 deletions src/gluonts/model/seq2seq/_forking_network.py
@@ -11,13 +11,13 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

# Third-party imports
# Standard library imports
from typing import List, Optional, Tuple

# Third-party imports
import mxnet as mx
import numpy as np
from mxnet import gluon
import numpy as np

# First-party imports
from gluonts.core.component import DType, validated
@@ -73,14 +73,14 @@ def __init__(
enc2dec: Seq2SeqEnc2Dec,
decoder: Seq2SeqDecoder,
context_length: int,
num_forking: int,
cardinality: List[int],
embedding_dimension: List[int],
distr_output: Optional[DistributionOutput] = None,
quantile_output: Optional[QuantileOutput] = None,
scaling: bool = False,
scaling_decoder_dynamic_feature: bool = False,
dtype: DType = np.float32,
num_forking: Optional[int] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -92,14 +92,13 @@ def __init__(
self.decoder = decoder
self.distr_output = distr_output
self.quantile_output = quantile_output
self.context_length = context_length
self.num_forking = num_forking
self.cardinality = cardinality
self.embedding_dimension = embedding_dimension
self.scaling = scaling
self.scaling_decoder_dynamic_feature = scaling_decoder_dynamic_feature
self.scaling_decoder_dynamic_feature_axis = 1
self.dtype = dtype
self.num_forking = (
num_forking if num_forking is not None else context_length
)

if self.scaling:
self.scaler = MeanScaler(keepdims=True)
@@ -146,37 +145,30 @@ def get_decoder_network_output(
past_target: Tensor
shape (batch_size, encoder_length, 1)
past_feat_dynamic
shape (batch_size, encoder_length, num_past_feature_dynamic)
shape (batch_size, encoder_length, num_past_feat_dynamic)
future_feat_dynamic
shape (batch_size, num_forking, decoder_length, num_feature_dynamic)
shape (batch_size, num_forking, decoder_length, num_feat_dynamic)
feat_static_cat
shape (batch_size, encoder_length, num_feature_static_cat)
shape (batch_size, num_feat_static_cat)
past_observed_values: Tensor
shape (batch_size, encoder_length, 1)
Returns
-------
decoder output tensor of size (batch_size, num_forking, dec_len, final_dims)
decoder output tensor of size (batch_size, num_forking, dec_len, decoder_mlp_dim_seq[0])
"""

# scale is computed on the context length last units of the past target
# scale shape is (batch_size, 1, *target_shape)
# scale shape: (batch_size, 1, 1)
scaled_past_target, scale = self.scaler(
past_target.slice_axis(
axis=1, begin=-self.context_length, end=None
),
past_observed_values.slice_axis(
axis=1, begin=-self.context_length, end=None
),
past_target, past_observed_values
)

# (batch_size, num_features)
# (batch_size, sum(embedding_dimension) = num_feat_static_cat)
embedded_cat = self.embedder(feat_static_cat)

# in addition to embedding features, use the log scale as it can help prediction too
# (batch_size, num_features + prod(target_shape))
# TODO: Check why different from DeepAR case
# (batch_size, num_feat_static = sum(embedding_dimension) + 1)
feat_static_real = F.concat(
embedded_cat, F.log(scale.squeeze(axis=1)), dim=1,
embedded_cat, F.log(scale.squeeze(axis=1)), dim=1
)

# Passing past_observed_values as a feature would allow the network to
@@ -186,6 +178,8 @@ def get_decoder_network_output(
)

# arguments: target, static_features, dynamic_features
# enc_output_static shape: (batch_size, channels_seq[-1] + 1)
# enc_output_dynamic shape: (batch_size, encoder_length, channels_seq[-1] + 1)
enc_output_static, enc_output_dynamic = self.encoder(
scaled_past_target, feat_static_real, past_feat_dynamic_extended
)
@@ -197,8 +191,11 @@ def get_decoder_network_output(
)

# arguments: encoder_output_static, encoder_output_dynamic, future_features
# dec_input_static shape: (batch_size, channels_seq[-1] + 1)
# dec_input_dynamic shape: (batch_size, num_forking, channels_seq[-1] + 1 + decoder_length * num_feat_dynamic)
dec_input_static, dec_input_dynamic = self.enc2dec(
enc_output_static,
# slice axis 1 from encoder_length = context_length to num_forking
enc_output_dynamic.slice_axis(
axis=1, begin=-self.num_forking, end=None
),
@@ -210,7 +207,7 @@ def get_decoder_network_output(
# where we only need to pass the encoder output for the last time step
dec_output = self.decoder(dec_input_dynamic, dec_input_static)

# the output shape should be: (batch_size, num_forking, dec_len, final_dims)
# the output shape should be: (batch_size, num_forking, dec_len, decoder_mlp_dim_seq[0])
return dec_output, scale


@@ -237,11 +234,11 @@ def hybrid_forward(
future_target: Tensor
shape (batch_size, num_forking, decoder_length)
past_feat_dynamic
shape (batch_size, encoder_length, num_past_feature_dynamic)
shape (batch_size, encoder_length, num_past_feat_dynamic)
future_feat_dynamic
shape (batch_size, num_forking, decoder_length, num_feature_dynamic)
shape (batch_size, num_forking, decoder_length, num_feat_dynamic)
feat_static_cat
shape (batch_size, encoder_length, num_feature_static_cat)
shape (batch_size, num_feat_static_cat)
past_observed_values: Tensor
shape (batch_size, encoder_length, 1)
future_observed_values: Tensor
@@ -251,6 +248,7 @@ def hybrid_forward(
-------
loss with shape (batch_size, prediction_length)
"""
# shape: (batch_size, num_forking, decoder_length, decoder_mlp_dim_seq[0])
dec_output, scale = self.get_decoder_network_output(
F,
past_target,
@@ -261,7 +259,9 @@ def hybrid_forward(
)

if self.quantile_output is not None:
# shape: (batch_size, num_forking, decoder_length, len(quantiles))
dec_dist_output = self.quantile_proj(dec_output)
# shape: (batch_size, num_forking, decoder_length = prediction_length)
loss = self.loss(future_target, dec_dist_output)
else:
assert self.distr_output is not None
@@ -270,6 +270,7 @@ def hybrid_forward(
loss = distr.loss(future_target)

# mask the loss based on observed indicator
# shape: (batch_size, decoder_length)
weighted_loss = weighted_average(
F=F, x=loss, weights=future_observed_values, axis=1
)
@@ -296,11 +297,11 @@ def hybrid_forward(
past_target: Tensor
shape (batch_size, encoder_length, 1)
feat_static_cat
shape (batch_size, encoder_length, num_feature_static_cat)
shape (batch_size, num_feat_static_cat)
past_feat_dynamic
shape (batch_size, encoder_length, num_past_feature_dynamic)
shape (batch_size, encoder_length, num_past_feat_dynamic)
future_feat_dynamic
shape (batch_size, num_forking, decoder_length, num_feature_dynamic)
shape (batch_size, num_forking, decoder_length, num_feat_dynamic)
past_observed_values: Tensor
shape (batch_size, encoder_length, 1)

@@ -309,6 +310,7 @@ def hybrid_forward(
prediction tensor with shape (batch_size, prediction_length)
"""

# shape: (batch_size, num_forking, decoder_length, decoder_mlp_dim_seq[0])
dec_output, _ = self.get_decoder_network_output(
F,
past_target,
@@ -319,9 +321,11 @@ def hybrid_forward(
)

# We only care about the output of the decoder for the last time step
# shape: (batch_size, decoder_length, decoder_mlp_dim_seq[0])
fcst_output = F.slice_axis(dec_output, axis=1, begin=-1, end=None)
fcst_output = F.squeeze(fcst_output, axis=1)

# shape: (batch_size, len(quantiles), decoder_length = prediction_length)
predictions = self.quantile_proj(fcst_output).swapaxes(2, 1)

return predictions
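
A small shape check of the forecast-path reshaping above; the tensor sizes are assumptions for illustration, not values from the PR:

# Sketch: slice the last fork, squeeze it, and swap quantile/time axes
import mxnet as mx

batch, num_forking, dec_len, dim, n_quantiles = 4, 10, 7, 3, 5
dec_output = mx.nd.zeros((batch, num_forking, dec_len, dim))

# keep only the decoder output of the last time step (last fork)
fcst = mx.nd.slice_axis(dec_output, axis=1, begin=-1, end=None)  # (4, 1, 7, 3)
fcst = mx.nd.squeeze(fcst, axis=1)  # (4, 7, 3)

# stand-in for quantile_proj(fcst), which maps dim -> len(quantiles)
quantile_out = mx.nd.zeros((batch, dec_len, n_quantiles))  # (4, 7, 5)
predictions = quantile_out.swapaxes(2, 1)  # (4, 5, 7)
assert predictions.shape == (batch, n_quantiles, dec_len)
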
5 changes: 5 additions & 0 deletions src/gluonts/model/seq2seq/_mq_dnn_estimator.py
@@ -141,6 +141,7 @@ def __init__(
scaling: bool = False,
scaling_decoder_dynamic_feature: bool = False,
num_forking: Optional[int] = None,
max_ts_len: Optional[int] = None,
) -> None:

assert (distr_output is None) or (quantiles is None)
@@ -239,6 +240,7 @@ def __init__(
scaling=scaling,
scaling_decoder_dynamic_feature=scaling_decoder_dynamic_feature,
num_forking=num_forking,
max_ts_len=max_ts_len,
)

@classmethod
Expand All @@ -250,6 +252,7 @@ def derive_auto_fields(cls, train_iter):
"use_feat_dynamic_real": stats.num_feat_dynamic_real > 0,
"use_feat_static_cat": bool(stats.feat_static_cat),
"cardinality": [len(cats) for cats in stats.feat_static_cat],
"max_ts_len": stats.max_target_length,
}

@classmethod
@@ -318,6 +321,7 @@ def __init__(
distr_output: Optional[DistributionOutput] = None,
scaling: bool = False,
scaling_decoder_dynamic_feature: bool = False,
num_forking: Optional[int] = None,
) -> None:

assert (
@@ -373,4 +377,5 @@ def __init__(
trainer=trainer,
scaling=scaling,
scaling_decoder_dynamic_feature=scaling_decoder_dynamic_feature,
num_forking=num_forking,
)
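
End to end, the new argument reaches MQCNNEstimator, and derive_auto_fields fills it from stats.max_target_length automatically. A minimal usage sketch, with hypothetical hyperparameters:

# Sketch: max_ts_len caps the derived context_length at construction time
from gluonts.model.seq2seq import MQCNNEstimator

estimator = MQCNNEstimator(
    freq="D",
    prediction_length=7,
    max_ts_len=20,  # e.g. DatasetStatistics.max_target_length
)
# the default context_length would be 4 * 7 = 28; the bound caps it at 20 - 7 = 13
assert estimator.context_length == 13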