From c3150eb7b3f58181ddda385a2ef273349003c782 Mon Sep 17 00:00:00 2001 From: Danielle Robinson Date: Thu, 17 Sep 2020 10:20:10 -0700 Subject: [PATCH] MQ-CNN: Bound context_length by the max_ts_len - prediction_length (#1037) * Bound context_length by the max_ts_len - prediction_length to remove unnecessary zero padding * Fixing test_accuracy * Reverted variable_length=False in data_loader and did update in place in instance splitter, updated pad_to_size util function to support right and left padding * Updating the comments for right padding * Revert back to from_hyperparameters in the tests * Reverting data loader changes Co-authored-by: Danielle Robinson --- src/gluonts/dataset/stat.py | 4 + .../model/seq2seq/_forking_estimator.py | 23 +++-- src/gluonts/model/seq2seq/_forking_network.py | 64 ++++++------ .../model/seq2seq/_mq_dnn_estimator.py | 5 + src/gluonts/model/seq2seq/_transform.py | 99 ++++++++----------- src/gluonts/mx/block/decoder.py | 22 ++--- src/gluonts/mx/block/enc2dec.py | 49 ++++----- src/gluonts/mx/block/encoder.py | 38 +++---- src/gluonts/mx/block/feature.py | 2 +- src/gluonts/support/util.py | 26 +++-- test/dataset/test_stat.py | 47 +++++---- 11 files changed, 192 insertions(+), 187 deletions(-) diff --git a/src/gluonts/dataset/stat.py b/src/gluonts/dataset/stat.py index 8728037153..bbb40450a1 100644 --- a/src/gluonts/dataset/stat.py +++ b/src/gluonts/dataset/stat.py @@ -114,6 +114,7 @@ class DatasetStatistics(NamedTuple): mean_abs_target: float mean_target: float mean_target_length: float + max_target_length: int min_target: float feat_static_real: List[Set[float]] feat_static_cat: List[Set[int]] @@ -173,6 +174,7 @@ def calculate_dataset_statistics(ts_dataset: Any) -> DatasetStatistics: scale_histogram = ScaleHistogram() with tqdm(enumerate(ts_dataset, start=1), total=len(ts_dataset)) as it: + max_target_length = 0 for num_time_series, ts in it: # TARGET @@ -190,6 +192,7 @@ def calculate_dataset_statistics(ts_dataset: Any) -> DatasetStatistics: ) num_time_observations += num_observations + max_target_length = max(num_observations, max_target_length) min_target = float(min(min_target, observed_target.min())) max_target = float(max(max_target, observed_target.max())) num_missing_values += int(np.isnan(target).sum()) @@ -401,6 +404,7 @@ def calculate_dataset_statistics(ts_dataset: Any) -> DatasetStatistics: mean_abs_target=mean_abs_target, mean_target=mean_target, mean_target_length=mean_target_length, + max_target_length=max_target_length, min_target=min_target, num_missing_values=num_missing_values, feat_static_real=observed_feat_static_real diff --git a/src/gluonts/model/seq2seq/_forking_estimator.py b/src/gluonts/model/seq2seq/_forking_estimator.py index 0f77ed63c4..f21e9b7447 100644 --- a/src/gluonts/model/seq2seq/_forking_estimator.py +++ b/src/gluonts/model/seq2seq/_forking_estimator.py @@ -156,6 +156,7 @@ def __init__( scaling_decoder_dynamic_feature: bool = False, dtype: DType = np.float32, num_forking: Optional[int] = None, + max_ts_len: Optional[int] = None, ) -> None: super().__init__(trainer=trainer) @@ -187,8 +188,18 @@ def __init__( if context_length is not None else 4 * self.prediction_length ) + if max_ts_len is not None: + max_pad_len = max(max_ts_len - self.prediction_length, 0) + # Don't allow context_length to be longer than the max pad length + self.context_length = ( + min(max_pad_len, self.context_length) + if max_pad_len > 0 + else self.context_length + ) self.num_forking = ( - num_forking if num_forking is not None else self.context_length + 
min(num_forking, self.context_length) + if num_forking is not None + else self.context_length ) self.use_past_feat_dynamic_real = use_past_feat_dynamic_real self.use_feat_dynamic_real = use_feat_dynamic_real @@ -252,7 +263,7 @@ def create_transformation(self) -> Transformation: time_features=time_features_from_frequency_str(self.freq), pred_length=self.prediction_length, dtype=self.dtype, - ), + ) ) dynamic_feat_fields.append(FieldName.FEAT_TIME) @@ -263,7 +274,7 @@ def create_transformation(self) -> Transformation: output_field=FieldName.FEAT_AGE, pred_length=self.prediction_length, dtype=self.dtype, - ), + ) ) dynamic_feat_fields.append(FieldName.FEAT_AGE) @@ -290,7 +301,7 @@ def create_transformation(self) -> Transformation: pred_length=self.prediction_length, const=0.0, # For consistency in case with no dynamic features dtype=self.dtype, - ), + ) ) dynamic_feat_fields.append(FieldName.FEAT_CONST) @@ -315,7 +326,7 @@ def create_transformation(self) -> Transformation: SetField( output_field=FieldName.FEAT_STATIC_CAT, value=np.array([0], dtype=np.int32), - ), + ) ) # --- SAMPLE AND CUT THE TIME-SERIES --- @@ -361,7 +372,7 @@ def create_transformation(self) -> Transformation: else [] ), prediction_time_decoder_exclude=[FieldName.OBSERVED_VALUES], - ), + ) ) # past_feat_dynamic features generated above in ForkingSequenceSplitter from those under feat_dynamic - we need diff --git a/src/gluonts/model/seq2seq/_forking_network.py b/src/gluonts/model/seq2seq/_forking_network.py index 06405f287d..87f58a7e29 100644 --- a/src/gluonts/model/seq2seq/_forking_network.py +++ b/src/gluonts/model/seq2seq/_forking_network.py @@ -11,13 +11,13 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. 
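For readers of this patch, a minimal sketch of the new bound applied in the estimator's __init__ above; bounded_context_length is a hypothetical helper written only for illustration, not a function added by this change:

def bounded_context_length(context_length, prediction_length, max_ts_len=None):
    # Mirrors the logic above: reserve prediction_length steps for the decoder
    # window and cap the encoder context at the history remaining in the
    # longest series, so short datasets no longer train on pure zero padding.
    if max_ts_len is not None:
        max_pad_len = max(max_ts_len - prediction_length, 0)
        if max_pad_len > 0:
            return min(max_pad_len, context_length)
    return context_length

assert bounded_context_length(96, 24, max_ts_len=60) == 36    # capped at 60 - 24
assert bounded_context_length(96, 24, max_ts_len=500) == 96   # long data: unchanged
assert bounded_context_length(96, 24) == 96                   # no statistics: unchanged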
-# Third-party imports +# Standard library imports from typing import List, Optional, Tuple # Third-party imports import mxnet as mx -import numpy as np from mxnet import gluon +import numpy as np # First-party imports from gluonts.core.component import DType, validated @@ -73,7 +73,6 @@ def __init__( enc2dec: Seq2SeqEnc2Dec, decoder: Seq2SeqDecoder, context_length: int, - num_forking: int, cardinality: List[int], embedding_dimension: List[int], distr_output: Optional[DistributionOutput] = None, @@ -81,6 +80,7 @@ def __init__( scaling: bool = False, scaling_decoder_dynamic_feature: bool = False, dtype: DType = np.float32, + num_forking: Optional[int] = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -92,14 +92,13 @@ def __init__( self.decoder = decoder self.distr_output = distr_output self.quantile_output = quantile_output - self.context_length = context_length - self.num_forking = num_forking - self.cardinality = cardinality - self.embedding_dimension = embedding_dimension self.scaling = scaling self.scaling_decoder_dynamic_feature = scaling_decoder_dynamic_feature self.scaling_decoder_dynamic_feature_axis = 1 self.dtype = dtype + self.num_forking = ( + num_forking if num_forking is not None else context_length + ) if self.scaling: self.scaler = MeanScaler(keepdims=True) @@ -146,37 +145,30 @@ def get_decoder_network_output( past_target: Tensor shape (batch_size, encoder_length, 1) past_feat_dynamic - shape (batch_size, encoder_length, num_past_feature_dynamic) + shape (batch_size, encoder_length, num_past_feat_dynamic) future_feat_dynamic - shape (batch_size, num_forking, decoder_length, num_feature_dynamic) + shape (batch_size, num_forking, decoder_length, num_feat_dynamic) feat_static_cat - shape (batch_size, encoder_length, num_feature_static_cat) + shape (batch_size, num_feat_static_cat) past_observed_values: Tensor shape (batch_size, encoder_length, 1) Returns ------- - decoder output tensor of size (batch_size, num_forking, dec_len, final_dims) + decoder output tensor of size (batch_size, num_forking, dec_len, decoder_mlp_dim_seq[0]) """ - # scale is computed on the context length last units of the past target - # scale shape is (batch_size, 1, *target_shape) + # scale shape: (batch_size, 1, 1) scaled_past_target, scale = self.scaler( - past_target.slice_axis( - axis=1, begin=-self.context_length, end=None - ), - past_observed_values.slice_axis( - axis=1, begin=-self.context_length, end=None - ), + past_target, past_observed_values ) - # (batch_size, num_features) + # (batch_size, sum(embedding_dimension) = num_feat_static_cat) embedded_cat = self.embedder(feat_static_cat) # in addition to embedding features, use the log scale as it can help prediction too - # (batch_size, num_features + prod(target_shape)) - # TODO: Check why different from DeepAR case + # (batch_size, num_feat_static = sum(embedding_dimension) + 1) feat_static_real = F.concat( - embedded_cat, F.log(scale.squeeze(axis=1)), dim=1, + embedded_cat, F.log(scale.squeeze(axis=1)), dim=1 ) # Passing past_observed_values as a feature would allow the network to @@ -186,6 +178,8 @@ def get_decoder_network_output( ) # arguments: target, static_features, dynamic_features + # enc_output_static shape: (batch_size, channels_seq[-1] + 1) + # enc_output_dynamic shape: (batch_size, encoder_length, channels_seq[-1] + 1) enc_output_static, enc_output_dynamic = self.encoder( scaled_past_target, feat_static_real, past_feat_dynamic_extended ) @@ -197,8 +191,11 @@ def get_decoder_network_output( ) # arguments: 
encoder_output_static, encoder_output_dynamic, future_features + # dec_input_static shape: (batch_size, channels_seq[-1] + 1) + # dec_input_dynamic shape:(batch_size, num_forking, channels_seq[-1] + 1 + decoder_length * num_feat_dynamic) dec_input_static, dec_input_dynamic = self.enc2dec( enc_output_static, + # slice axis 1 from encoder_length = context_length to num_forking enc_output_dynamic.slice_axis( axis=1, begin=-self.num_forking, end=None ), @@ -210,7 +207,7 @@ def get_decoder_network_output( # where we we only need to pass the encoder output for the last time step dec_output = self.decoder(dec_input_dynamic, dec_input_static) - # the output shape should be: (batch_size, num_forking, dec_len, final_dims) + # the output shape should be: (batch_size, num_forking, dec_len, decoder_mlp_dim_seq[0]) return dec_output, scale @@ -237,11 +234,11 @@ def hybrid_forward( future_target: Tensor shape (batch_size, num_forking, decoder_length) past_feat_dynamic - shape (batch_size, encoder_length, num_past_feature_dynamic) + shape (batch_size, encoder_length, num_past_feat_dynamic) future_feat_dynamic - shape (batch_size, num_forking, decoder_length, num_feature_dynamic) + shape (batch_size, num_forking, decoder_length, num_feat_dynamic) feat_static_cat - shape (batch_size, encoder_length, num_feature_static_cat) + shape (batch_size, num_feat_static_cat) past_observed_values: Tensor shape (batch_size, encoder_length, 1) future_observed_values: Tensor @@ -251,6 +248,7 @@ def hybrid_forward( ------- loss with shape (batch_size, prediction_length) """ + # shape: (batch_size, num_forking, decoder_length, decoder_mlp_dim_seq[0]) dec_output, scale = self.get_decoder_network_output( F, past_target, @@ -261,7 +259,9 @@ def hybrid_forward( ) if self.quantile_output is not None: + # shape: (batch_size, num_forking, decoder_length, len(quantiles)) dec_dist_output = self.quantile_proj(dec_output) + # shape: (batch_size, num_forking, decoder_length = prediction_length) loss = self.loss(future_target, dec_dist_output) else: assert self.distr_output is not None @@ -270,6 +270,7 @@ def hybrid_forward( loss = distr.loss(future_target) # mask the loss based on observed indicator + # shape: (batch_size, decoder_length) weighted_loss = weighted_average( F=F, x=loss, weights=future_observed_values, axis=1 ) @@ -296,11 +297,11 @@ def hybrid_forward( past_target: Tensor shape (batch_size, encoder_length, 1) feat_static_cat - shape (batch_size, encoder_length, num_feature_static_cat) + shape (batch_size, num_feat_static_cat) past_feat_dynamic - shape (batch_size, encoder_length, num_past_feature_dynamic) + shape (batch_size, encoder_length, num_past_feat_dynamic) future_feat_dynamic - shape (batch_size, num_forking, decoder_length, num_feature_dynamic) + shape (batch_size, num_forking, decoder_length, num_feat_dynamic) past_observed_values: Tensor shape (batch_size, encoder_length, 1) @@ -309,6 +310,7 @@ def hybrid_forward( prediction tensor with shape (batch_size, prediction_length) """ + # shape: (batch_size, num_forking, decoder_length, decoder_mlp_dim_seq[0]) dec_output, _ = self.get_decoder_network_output( F, past_target, @@ -319,9 +321,11 @@ def hybrid_forward( ) # We only care about the output of the decoder for the last time step + # shape: (batch_size, decoder_length, decoder_mlp_dim_seq[0]) fcst_output = F.slice_axis(dec_output, axis=1, begin=-1, end=None) fcst_output = F.squeeze(fcst_output, axis=1) + # shape: (batch_size, len(quantiles), decoder_length = prediction_length) predictions = 
self.quantile_proj(fcst_output).swapaxes(2, 1) return predictions diff --git a/src/gluonts/model/seq2seq/_mq_dnn_estimator.py b/src/gluonts/model/seq2seq/_mq_dnn_estimator.py index d4c9b1fa93..050f7e88af 100644 --- a/src/gluonts/model/seq2seq/_mq_dnn_estimator.py +++ b/src/gluonts/model/seq2seq/_mq_dnn_estimator.py @@ -141,6 +141,7 @@ def __init__( scaling: bool = False, scaling_decoder_dynamic_feature: bool = False, num_forking: Optional[int] = None, + max_ts_len: Optional[int] = None, ) -> None: assert (distr_output is None) or (quantiles is None) @@ -239,6 +240,7 @@ def __init__( scaling=scaling, scaling_decoder_dynamic_feature=scaling_decoder_dynamic_feature, num_forking=num_forking, + max_ts_len=max_ts_len, ) @classmethod @@ -250,6 +252,7 @@ def derive_auto_fields(cls, train_iter): "use_feat_dynamic_real": stats.num_feat_dynamic_real > 0, "use_feat_static_cat": bool(stats.feat_static_cat), "cardinality": [len(cats) for cats in stats.feat_static_cat], + "max_ts_len": stats.max_target_length, } @classmethod @@ -318,6 +321,7 @@ def __init__( distr_output: Optional[DistributionOutput] = None, scaling: bool = False, scaling_decoder_dynamic_feature: bool = False, + num_forking: Optional[int] = None, ) -> None: assert ( @@ -373,4 +377,5 @@ def __init__( trainer=trainer, scaling=scaling, scaling_decoder_dynamic_feature=scaling_decoder_dynamic_feature, + num_forking=num_forking, ) diff --git a/src/gluonts/model/seq2seq/_transform.py b/src/gluonts/model/seq2seq/_transform.py index c4ce354f90..248dbebf1f 100644 --- a/src/gluonts/model/seq2seq/_transform.py +++ b/src/gluonts/model/seq2seq/_transform.py @@ -26,16 +26,6 @@ from gluonts.transform import FlatMapTransformation, shift_timestamp -def pad_to_size(xs: np.array, size: int): - """Pads `xs` with 0 on the left on the first axis.""" - pad_length = size - xs.shape[0] - if pad_length <= 0: - return xs - - pad_width = [(pad_length, 0)] + ([(0, 0)] * (xs.ndim - 1)) - return np.pad(xs, mode="constant", pad_width=pad_width) - - class ForkingSequenceSplitter(FlatMapTransformation): """Forking sequence splitter.""" @@ -132,35 +122,39 @@ def flatmap_transform( ) for sampling_idx in sampling_indices: - # ensure start index is not negative - start_idx = max(0, sampling_idx - self.enc_len) - # irrelevant data should have been removed by now in the # transformation chain, so copying everything is ok out = data.copy() + enc_len_diff = sampling_idx - self.enc_len + dec_len_diff = sampling_idx - self.num_forking + + # ensure start indices are not negative + start_idx_enc = max(0, enc_len_diff) + start_idx_dec = max(0, dec_len_diff) + + # Define pad length indices for shorter time series of variable length being updated in place + pad_length_enc = max(0, -enc_len_diff) + pad_length_dec = max(0, -dec_len_diff) + for ts_field in list(ts_fields_counter.keys()): # target is 1d, this ensures ts is always 2d ts = np.atleast_2d(out[ts_field]).T + ts_len = ts.shape[1] if ts_fields_counter[ts_field] == 1: del out[ts_field] else: ts_fields_counter[ts_field] -= 1 - # take enc_len values from ts, depending on sampling_idx - slice = ts[start_idx:sampling_idx, :] - - ts_len = ts.shape[1] - past_piece = np.zeros( + out[self._past(ts_field)] = np.zeros( shape=(self.enc_len, ts_len), dtype=ts.dtype ) - if ts_field not in self.encoder_disabled_fields: - # if we have less than enc_len values, pad_left with 0 - past_piece = pad_to_size(slice, self.enc_len) - out[self._past(ts_field)] = past_piece + out[self._past(ts_field)][pad_length_enc:] = ts[ + start_idx_enc:sampling_idx, : + 
] # exclude some fields at prediction time if ( @@ -169,61 +163,46 @@ def flatmap_transform( ): continue - # This is were some of the forking magic happens: - # For each of the encoder_len time-steps at which the decoder is applied we slice the - # corresponding inputs called decoder_fields to the appropriate dec_len if ts_field in self.decoder_series_fields: - - forking_dec_field = np.zeros( + out[self._future(ts_field)] = np.zeros( shape=(self.num_forking, self.dec_len, ts_len), dtype=ts.dtype, ) - # in case it's not disabled we copy the actual values if ts_field not in self.decoder_disabled_fields: - # In case we sample and index too close to the beginning of the time series we would run out of - # bounds (i.e. try to copy non existent time series data) to prepare the input for the decoder. - # Instead of copying the partially available data from the time series and padding it with - # zeros, we simply skip copying the partial data. Since copying data would result in overriding - # the 0 pre-initialized 3D array, the end result of skipping is that the affected 2D decoder - # inputs (entries of the 3D array - of which there are skip many) will still be all 0." - skip = max(0, self.num_forking - sampling_idx) - start_idx = max(0, sampling_idx - self.num_forking) - # For 2D column-major (Fortran) ordering transposed array strides = (dtype, dtype*n_rows) - # For standard row-major arrays, strides = (dtype*n_cols, dtype) - stride = ts.strides - forking_dec_field[skip:, :, :] = as_strided( - ts[ - start_idx - + 1 : start_idx - + 1 - + self.num_forking - - skip, - :, - ], + # This is where some of the forking magic happens: + # For each of the num_forking time-steps at which the decoder is applied we slice the + # corresponding inputs called decoder_fields to the appropriate dec_len + decoder_fields = ts[ + start_idx_dec + 1 : sampling_idx + 1, : + ] + # For default row-major arrays, strides = (dtype*n_cols, dtype). Since this array is transposed, + # it is stored in column-major (Fortran) ordering with strides = (dtype, dtype*n_rows) + stride = decoder_fields.strides + out[self._future(ts_field)][ + pad_length_dec: + ] = as_strided( + decoder_fields, shape=( - self.num_forking - skip, + self.num_forking - pad_length_dec, self.dec_len, ts_len, ), # strides for 2D array expanded to 3D array of shape (dim1, dim2, dim3) = - # strides for 2D array expanded to 3D array of shape (dim1, dim2, dim3) = - # (1, n_rows, n_cols). Note since this array has been transposed, it is stored in - # column-major (Fortan) ordering, i.e. for transposed data of shape (dim1, dim2, dim3), - # strides = (dtype, dtype * dim1, dtype*dim1*dim2) = (dtype, dtype, dtype*n_rows). + # (1, n_rows, n_cols). For transposed data, strides = + # (dtype, dtype * dim1, dtype*dim1*dim2) = (dtype, dtype, dtype*n_rows). 
strides=stride[0:1] + stride, ) + # edge case for prediction_length = 1 - if forking_dec_field.shape[-1] == 1: + if out[self._future(ts_field)].shape[-1] == 1: out[self._future(ts_field)] = np.squeeze( - forking_dec_field, axis=-1 + out[self._future(ts_field)], axis=-1 ) - else: - out[self._future(ts_field)] = forking_dec_field - # So far pad indicator not in use + # So far encoder pad indicator not in use - + # Marks that left padding for the encoder will occur on shorter time series pad_indicator = np.zeros(self.enc_len) - pad_length = max(0, self.enc_len - sampling_idx) - pad_indicator[:pad_length] = True + pad_indicator[:pad_length_enc] = True out[self._past(self.is_pad_out)] = pad_indicator # So far pad forecast_start not in use diff --git a/src/gluonts/mx/block/decoder.py b/src/gluonts/mx/block/decoder.py index 8a20f7cf11..a2f84b1ead 100644 --- a/src/gluonts/mx/block/decoder.py +++ b/src/gluonts/mx/block/decoder.py @@ -43,11 +43,12 @@ def hybrid_forward( Parameters ---------- dynamic_input - dynamic_features, shape (batch_size, sequence_length, num_features) + dynamic_features, shape (batch_size, sequence_length, channels_seq[-1] + + 1 + decoder_length * num_feat_dynamic) or (N, T, C) static_input - static features, shape (batch_size, num_features) or (N, C) + static features, shape (batch_size, channels_seq[-1] + 1) or (N, C) """ pass @@ -121,14 +122,14 @@ def hybrid_forward( dynamic_input dynamic_features, shape (batch_size, sequence_length, num_features) or (N, T, C) where sequence_length is equal to the encoder length, and num_features is equal - to channel_seq[-1] for the MQCNN for example. + to channels_seq[-1] + 1 + decoder_length * num_feat_dynamic for the MQ-CNN for example. static_input not used in this decoder. Returns ------- Tensor - mlp output, shape (0, 0, dec_len, final_dims). + mlp output, shape (batch_size, sequence_length, decoder_length, decoder_mlp_dim_seq[0]). """ mlp_output = self.model(dynamic_input) @@ -169,10 +170,7 @@ def __init__( ) def hybrid_forward( - self, - F, - static_input: Tensor, # (batch_size, static_input_dim) - dynamic_input: Tensor, # (batch_size, + self, F, static_input: Tensor, dynamic_input: Tensor ) -> Tensor: """ OneShotDecoder forward call @@ -184,15 +182,17 @@ def hybrid_forward( API in MXNet. 
static_input - static features, shape (batch_size, num_features) or (N, C) + static features, shape (batch_size, channels_seq[-1] + 1) or (N, C) dynamic_input - dynamic_features, shape (batch_size, sequence_length, num_features) + dynamic_features, shape (batch_size, sequence_length, channels_seq[-1] + + 1 + decoder_length * num_feat_dynamic) or (N, T, C) + Returns ------- Tensor - mlp output, shape (batch_size, dec_len, size of last layer) + mlp output, shape (batch_size, decoder_length, size of last layer) """ static_input_tile = self.expander(static_input).reshape( (0, self.decoder_length, self.static_outputs_per_time_step) diff --git a/src/gluonts/mx/block/enc2dec.py b/src/gluonts/mx/block/enc2dec.py index 9bc58f8567..de81424d99 100644 --- a/src/gluonts/mx/block/enc2dec.py +++ b/src/gluonts/mx/block/enc2dec.py @@ -45,22 +45,13 @@ def hybrid_forward( ---------- encoder_output_static - shape (batch_size, num_features) or (N, C) + shape (batch_size, channels_seq[-1] + 1) or (N, C) encoder_output_dynamic - shape (batch_size, sequence_length, num_features) or (N, T, C) + shape (batch_size, sequence_length, channels_seq[-1] + 1) or (N, T, C) future_features_dynamic - shape (batch_size, sequence_length, prediction_length, num_features) or (N, T, P, C`) - - - Returns - ------- - Tensor - shape (batch_size, num_features) or (N, C) - - Tensor - shape (batch_size, sequence_length, num_features) or (N, T, C) + shape (batch_size, sequence_length, prediction_length=decoder_length, num_feat_dynamic) or (N, T, P, C`) """ pass @@ -83,29 +74,29 @@ def hybrid_forward( ---------- encoder_output_static - shape (batch_size, num_features) or (N, C) + shape (batch_size, channels_seq[-1] + 1) or (N, C) encoder_output_dynamic - shape (batch_size, sequence_length, num_features) or (N, T, C) + shape (batch_size, sequence_length, channels_seq[-1] + 1) or (N, T, C) future_features_dynamic - shape (batch_size, sequence_length, prediction_length, num_features) or (N, T, P, C`) + shape (batch_size, sequence_length, prediction_length=decoder_length, num_feat_dynamic) or (N, T, P, C`) Returns ------- Tensor - shape (batch_size, num_features) or (N, C) + shape (batch_size, channels_seq[-1] + 1) or (N, C) Tensor - shape (batch_size, prediction_length, num_features_02) or (N, T, C) + shape (batch_size, sequence_length, channels_seq[-1] + 1) or (N, T, C) """ return encoder_output_static, encoder_output_dynamic class FutureFeatIntegratorEnc2Dec(Seq2SeqEnc2Dec): """ - Integrates the encoder_ouput_dynamic and future_features_dynamic into one + Integrates the encoder_output_dynamic and future_features_dynamic into one and passes them through as the dynamic input to the decoder. 
""" @@ -121,40 +112,36 @@ def hybrid_forward( ---------- encoder_output_static - shape (batch_size, num_features) or (N, C) + shape (batch_size, channels_seq[-1] + 1) or (N, C) encoder_output_dynamic - shape (batch_size, sequence_length, num_features) or (N, T, C) + shape (batch_size, sequence_length, channels_seq[-1] + 1) or (N, T, C) future_features_dynamic - shape (batch_size, sequence_length, prediction_length, num_features) or (N, T, P, C`) + shape (batch_size, sequence_length, prediction_length=decoder_length, num_feat_dynamic) or (N, T, P, C`) Returns ------- Tensor - shape (batch_size, num_features) or (N, C) + shape (batch_size, channels_seq[-1] + 1) or (N, C) Tensor - shape (batch_size, prediction_length, num_features_02) or (N, T, C) + shape (batch_size, prediction_length=decoder_length, channels_seq[-1] + 1 + decoder_length * num_feat_dynamic) or (N, T, C) - Tensor - shape (1,) """ # flatten the last two dimensions: - # => (batch_size, encoder_length, decoder_length * num_feature_dynamic) + # => (batch_size, sequence_length, decoder_length * num_feat_dynamic), where + # num_future_feat_dynamic = decoder_length * num_feat_dynamic future_features_dynamic = F.reshape( future_features_dynamic, shape=(0, 0, -1) ) # concatenate output of decoder and future_feat_dynamic covariates: - # => (batch_size, encoder_length, num_dec_input_dynamic + num_future_feat_dynamic) + # => (batch_size, sequence_length, num_dec_input_dynamic + num_future_feat_dynamic) total_dec_input_dynamic = F.concat( encoder_output_dynamic, future_features_dynamic, dim=2 ) - return ( - encoder_output_static, - total_dec_input_dynamic, - ) + return (encoder_output_static, total_dec_input_dynamic) diff --git a/src/gluonts/mx/block/encoder.py b/src/gluonts/mx/block/encoder.py index 3aeae78c8d..0fabcb2e54 100644 --- a/src/gluonts/mx/block/encoder.py +++ b/src/gluonts/mx/block/encoder.py @@ -52,19 +52,19 @@ def hybrid_forward( shape (batch_size, sequence_length) static_features static features, - shape (batch_size, num_static_features) + shape (batch_size, num_feat_static) dynamic_features dynamic_features, - shape (batch_size, sequence_length, num_dynamic_features) + shape (batch_size, sequence_length, num_feat_dynamic) Returns ------- Tensor static code, - shape (batch_size, num_static_features) + shape (batch_size, num_feat_static) Tensor dynamic code, - shape (batch_size, sequence_length, num_dynamic_features) + shape (batch_size, sequence_length, num_feat_dynamic) """ raise NotImplementedError @@ -89,17 +89,17 @@ def _assemble_inputs( shape (batch_size, sequence_length, 1) static_features static features, - shape (batch_size, num_static_features) + shape (batch_size, num_feat_static) dynamic_features dynamic_features, - shape (batch_size, sequence_length, num_dynamic_features) + shape (batch_size, sequence_length, num_feat_dynamic) Returns ------- Tensor combined features, shape (batch_size, sequence_length, - num_static_features + num_dynamic_features + 1) + num_feat_static + num_feat_dynamic + 1) """ helper_ones = F.ones_like(target) # Ones of (N, T, 1) @@ -194,18 +194,18 @@ def hybrid_forward( shape (batch_size, sequence_length, 1) static_features static features, - shape (batch_size, num_static_features) + shape (batch_size, num_feat_static) dynamic_features dynamic_features, - shape (batch_size, sequence_length, num_dynamic_features) + shape (batch_size, sequence_length, num_feat_dynamic) Returns ------- Tensor static code, - shape (batch_size, num_static_features) + shape (batch_size, channel_seqs + (1) if 
use_residual) Tensor dynamic code, - shape (batch_size, sequence_length, num_dynamic_features) + shape (batch_size, sequence_length, channel_seqs + (1) if use_residual) """ if self.use_dynamic_feat and self.use_static_feat: @@ -302,19 +302,19 @@ def hybrid_forward( shape (batch_size, sequence_length, 1) static_features static features, - shape (batch_size, num_static_features) + shape (batch_size, num_feat_static) dynamic_features dynamic_features, - shape (batch_size, sequence_length, num_dynamic_features) + shape (batch_size, sequence_length, num_feat_dynamic) Returns ------- Tensor static code, - shape (batch_size, num_static_features) + shape (batch_size, num_feat_static) Tensor dynamic code, - shape (batch_size, sequence_length, num_dynamic_features) + shape (batch_size, sequence_length, num_feat_dynamic) """ if self.use_dynamic_feat and self.use_static_feat: inputs = self._assemble_inputs( @@ -367,19 +367,19 @@ def hybrid_forward( shape (batch_size, sequence_length) static_features static features, - shape (batch_size, num_static_features) + shape (batch_size, num_feat_static) dynamic_features dynamic_features, - shape (batch_size, sequence_length, num_dynamic_features) + shape (batch_size, sequence_length, num_feat_dynamic) Returns ------- Tensor static code, - shape (batch_size, num_static_features) + shape (batch_size, num_feat_static) Tensor dynamic code, - shape (batch_size, sequence_length, num_dynamic_features) + shape (batch_size, sequence_length, num_feat_dynamic) """ inputs = self._assemble_inputs( diff --git a/src/gluonts/mx/block/feature.py b/src/gluonts/mx/block/feature.py index 36eb29b8a3..d0f20b4246 100644 --- a/src/gluonts/mx/block/feature.py +++ b/src/gluonts/mx/block/feature.py @@ -93,7 +93,7 @@ def hybrid_forward(self, F, features: Tensor) -> Tensor: Returns ------- concatenated_tensor: Tensor - Concatenated tensor of embeddings whth shape: (N,T,C) or (N,C), + Concatenated tensor of embeddings with shape: (N,T,C) or (N,C), where C is the sum of the embedding dimensions for each categorical feature, i.e. C = sum(self.config.embedding_dims). """ diff --git a/src/gluonts/support/util.py b/src/gluonts/support/util.py index 68823d5503..989f549166 100644 --- a/src/gluonts/support/util.py +++ b/src/gluonts/support/util.py @@ -12,22 +12,14 @@ # permissions and limitations under the License. 
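To make the shape annotations added to enc2dec.py above concrete, here is a small NumPy sketch of what FutureFeatIntegratorEnc2Dec does to the dynamic inputs; the sizes are made up for illustration and are not taken from the patch:

import numpy as np

# Illustrative sizes only: batch_size=4, num_forking=6, decoder_length=3,
# num_feat_dynamic=2, and channels_seq[-1] + 1 = 5.
enc_output_dynamic = np.zeros((4, 6, 5))          # (N, T, C)
future_feat_dynamic = np.zeros((4, 6, 3, 2))      # (N, T, P, C`)

# The enc2dec block flattens the last two axes of the future features and
# concatenates them with the dynamic encoder output along the channel axis,
# yielding the dynamic decoder input whose shape is documented above.
dec_input_dynamic = np.concatenate(
    [enc_output_dynamic, future_feat_dynamic.reshape(4, 6, -1)], axis=2
)
assert dec_input_dynamic.shape == (4, 6, 5 + 3 * 2)  # channels + dec_len * num_feat_dynamic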
# Standard library imports +import functools import inspect import os import signal import tempfile import time from pathlib import Path -from typing import ( - Any, - Callable, - Dict, - List, - Optional, - Tuple, - Union, - cast, -) +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast # Third-party imports import mxnet as mx @@ -42,6 +34,20 @@ MXNET_HAS_ERFINV = hasattr(mx.nd, "erfinv") +def pad_to_size( + x: np.array, size: int, axis: int = 0, is_right_pad: bool = True +): + """Pads `xs` with 0 on the right (default) on the specified axis, which is the first axis by default.""" + pad_length = size - x.shape[axis] + if pad_length <= 0: + return x + + pad_width = [(0, 0)] * x.ndim + right_pad = (0, pad_length) + pad_width[axis] = right_pad if is_right_pad else right_pad[::-1] + return np.pad(x, mode="constant", pad_width=pad_width) + + class Timer: """Context manager for measuring the time of enclosed code fragments.""" diff --git a/test/dataset/test_stat.py b/test/dataset/test_stat.py index 78616ba387..bd1571bfbc 100644 --- a/test/dataset/test_stat.py +++ b/test/dataset/test_stat.py @@ -102,34 +102,43 @@ def ts( class DatasetStatisticsTest(unittest.TestCase): def test_dataset_statistics(self) -> None: - n = 2 - T = 10 + num_time_series = 3 + num_time_observations = 10 + num_feat_dynamic_real = 2 + num_past_feat_dynamic_real = 3 + num_feat_dynamic_cat = 2 + num_missing_values = 0 # use integers to avoid float conversion that can fail comparison np.random.seed(0) - targets = np.random.randint(0, 10, (n, T)) + targets = np.random.randint( + 0, 10, (num_time_series - 1, num_time_observations) + ) scale_histogram = ScaleHistogram() - for i in range(n): + for i in range(num_time_series - 1): scale_histogram.add(targets[i, :]) scale_histogram.add([]) expected = DatasetStatistics( integer_dataset=True, - num_time_series=n + 1, + num_time_series=num_time_series, # includes empty array num_time_observations=targets.size, - mean_target_length=T * 2 / 3, + mean_target_length=num_time_observations + * (num_time_series - 1) + / num_time_series, + max_target_length=num_time_observations, min_target=targets.min(), mean_target=targets.mean(), mean_abs_target=targets.mean(), max_target=targets.max(), feat_static_real=[{0.1}, {0.2, 0.3}], feat_static_cat=[{1}, {2, 3}], - num_feat_dynamic_real=2, - num_past_feat_dynamic_real=3, - num_feat_dynamic_cat=2, - num_missing_values=0, + num_feat_dynamic_real=num_feat_dynamic_real, + num_past_feat_dynamic_real=num_past_feat_dynamic_real, + num_feat_dynamic_cat=num_feat_dynamic_cat, + num_missing_values=num_missing_values, scale_histogram=scale_histogram, ) @@ -141,25 +150,25 @@ def test_dataset_statistics(self) -> None: target=targets[0, :], feat_static_cat=[1, 2], feat_static_real=[0.1, 0.2], - num_feat_dynamic_cat=2, - num_feat_dynamic_real=2, - num_past_feat_dynamic_real=3, + num_feat_dynamic_cat=num_feat_dynamic_cat, + num_feat_dynamic_real=num_feat_dynamic_real, + num_past_feat_dynamic_real=num_past_feat_dynamic_real, ), make_time_series( target=targets[1, :], feat_static_cat=[1, 3], feat_static_real=[0.1, 0.3], - num_feat_dynamic_cat=2, - num_feat_dynamic_real=2, - num_past_feat_dynamic_real=3, + num_feat_dynamic_cat=num_feat_dynamic_cat, + num_feat_dynamic_real=num_feat_dynamic_real, + num_past_feat_dynamic_real=num_past_feat_dynamic_real, ), make_time_series( target=np.array([]), feat_static_cat=[1, 3], feat_static_real=[0.1, 0.3], - num_feat_dynamic_cat=2, - num_feat_dynamic_real=2, - num_past_feat_dynamic_real=3, + 
num_feat_dynamic_cat=num_feat_dynamic_cat, + num_feat_dynamic_real=num_feat_dynamic_real, + num_past_feat_dynamic_real=num_past_feat_dynamic_real, ), ], )
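Finally, a quick sanity check of the pad_to_size helper relocated to gluonts.support.util by this patch; the expected outputs in the comments follow from the implementation above (right padding by default, left padding with is_right_pad=False), assuming the patch is applied:

import numpy as np
from gluonts.support.util import pad_to_size

x = np.arange(3, dtype=np.float32)                  # [0. 1. 2.]
print(pad_to_size(x, 5))                            # [0. 1. 2. 0. 0.]  right pad
print(pad_to_size(x, 5, is_right_pad=False))        # [0. 0. 0. 1. 2.]  left pad
print(pad_to_size(x, 2))                            # [0. 1. 2.]  already long enough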