Quick Start example with "NegativeBinomialOutput" throws an exception #91

Closed
houchangtao opened this issue Jun 7, 2019 · 2 comments

houchangtao commented Jun 7, 2019

System: Ubuntu 16.04
Python: 3.6.4
mxnet: 1.4.1

Code to reproduce:

import pandas as pd
from gluonts.dataset.common import ListDataset
from gluonts.distribution import NegativeBinomialOutput
from gluonts.model.deepar import DeepAREstimator
from gluonts.trainer import Trainer

# Load the Twitter volume time series.
url = "https://raw.githubusercontent.com/numenta/NAB/master/data/realTweets/Twitter_volume_AMZN.csv"
df = pd.read_csv(url, header=0, index_col=0)

# Wrap the series up to the given timestamp as a single-entry training dataset.
training_data = ListDataset(
    [{"start": df.index[0], "target": df.value[:"2015-04-05 00:00:00"]}],
    freq="5min",
)

# DeepAR estimator with a negative binomial output distribution.
estimator = DeepAREstimator(freq="5min", prediction_length=12, distr_output=NegativeBinomialOutput(),
                            trainer=Trainer(epochs=10))
predictor = estimator.train(training_data=training_data)  # raises the exception below

Exception:
DeferredInitializationError               Traceback (most recent call last)
/usr/local/lib/python3.6/site-packages/mxnet/gluon/block.py in _call_cached_op(self, *args)
    802             cargs = [args[i] if is_arg else i.data()
--> 803                      for is_arg, i in self._cached_op_args]
    804         except DeferredInitializationError:

/usr/local/lib/python3.6/site-packages/mxnet/gluon/block.py in <listcomp>(.0)
    802             cargs = [args[i] if is_arg else i.data()
--> 803                      for is_arg, i in self._cached_op_args]
    804         except DeferredInitializationError:

/usr/local/lib/python3.6/site-packages/mxnet/gluon/parameter.py in data(self, ctx)
    493                                "instead." % (self.name, str(ctx), self._stype))
--> 494         return self._check_and_get(self._data, ctx)
    495 

/usr/local/lib/python3.6/site-packages/mxnet/gluon/parameter.py in _check_and_get(self, arr_list, ctx)
    207                 "You can also avoid deferred initialization by specifying in_units, " \
--> 208                 "num_features, etc., for network layers."%(self.name))
    209         raise RuntimeError(

DeferredInitializationError: Parameter 'deepartrainingnetwork0_lstm0_i2h_weight' has not been initialized yet because initialization was deferred. Actual initialization happens during the first forward pass. Please pass one batch of data through the network before accessing Parameters. You can also avoid deferred initialization by specifying in_units, num_features, etc., for network layers.

During handling of the above exception, another exception occurred:

MXNetError                                Traceback (most recent call last)
/usr/local/lib/python3.6/site-packages/mxnet/gluon/block.py in _deferred_infer_shape(self, *args)
    788         try:
--> 789             self.infer_shape(*args)
    790         except Exception as e:

/usr/local/lib/python3.6/site-packages/mxnet/gluon/block.py in infer_shape(self, *args)
    861         """Infers shape of Parameters from inputs."""
--> 862         self._infer_attrs('infer_shape', 'shape', *args)
    863 

/usr/local/lib/python3.6/site-packages/mxnet/gluon/block.py in _infer_attrs(self, infer_fn, attr, *args)
    850             arg_attrs, _, aux_attrs = getattr(out, infer_fn)(
--> 851                 **{i.name: getattr(j, attr) for i, j in zip(inputs, args)})
    852             if arg_attrs is None:

/usr/local/lib/python3.6/site-packages/mxnet/symbol/symbol.py in infer_shape(self, *args, **kwargs)
    995         try:
--> 996             res = self._infer_shape_impl(False, *args, **kwargs)
    997             if res[1] is None:

/usr/local/lib/python3.6/site-packages/mxnet/symbol/symbol.py in _infer_shape_impl(self, partial, *args, **kwargs)
   1125             ctypes.byref(aux_shape_data),
-> 1126             ctypes.byref(complete)))
   1127         if complete.value != 0:

/usr/local/lib/python3.6/site-packages/mxnet/base.py in check_call(ret)
    251     if ret != 0:
--> 252         raise MXNetError(py_str(_LIB.MXGetLastError()))
    253 

MXNetError: Error in operator deepartrainingnetwork0__mul1: [10:04:54] /home/travis/build/dmlc/mxnet-distro/mxnet-build/3rdparty/mshadow/../../src/operator/tensor/../elemwise_op_common.h:135: Check failed: assign(&dattr, vec.at(i)) Incompatible attr in node deepartrainingnetwork0__mul1 at 1-th input: expected [32,24], got [32,32,24]

Stack trace returned 10 entries:
[bt] (0) /usr/local/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x23d55a) [0x7f44332ff55a]
[bt] (1) /usr/local/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x23dbc1) [0x7f44332ffbc1]
[bt] (2) /usr/local/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x4ee78d) [0x7f44335b078d]
[bt] (3) /usr/local/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x5587d6) [0x7f443361a7d6]
[bt] (4) /usr/local/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x762338) [0x7f4433824338]
[bt] (5) /usr/local/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2be2cea) [0x7f4435ca4cea]
[bt] (6) /usr/local/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2be5658) [0x7f4435ca7658]
[bt] (7) /usr/local/lib/python3.6/site-packages/mxnet/libmxnet.so(MXSymbolInferShape+0x15ba) [0x7f4435c140ca]
[bt] (8) /usr/local/lib/python3.6/lib-dynload/_ctypes.cpython-36m-x86_64-linux-gnu.so(ffi_call_unix64+0x4c) [0x7f44791201da]
[bt] (9) /usr/local/lib/python3.6/lib-dynload/_ctypes.cpython-36m-x86_64-linux-gnu.so(ffi_call+0x26b) [0x7f447911c5cb]



During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
<ipython-input-6-7512b68a283e> in <module>()
      1 estimator = DeepAREstimator(freq="5min", prediction_length=12, distr_output=NegativeBinomialOutput(),
      2                             trainer=Trainer(epochs=10))
----> 3 predictor = estimator.train(training_data=training_data)

/usr/local/lib/python3.6/site-packages/gluonts-0.1.1-py3.6.egg/gluonts/model/estimator.py in train(self, training_data)
    187     def train(self, training_data: Dataset) -> Predictor:
    188 
--> 189         training_transformation, trained_net = self.train_model(training_data)
    190 
    191         # ensure that the prediction network is created within the same MXNet

/usr/local/lib/python3.6/site-packages/gluonts-0.1.1-py3.6.egg/gluonts/model/estimator.py in train_model(self, training_data)
    180             net=trained_net,
    181             input_names=get_hybrid_forward_input_names(trained_net),
--> 182             train_iter=training_data_loader,
    183         )
    184 

/usr/local/lib/python3.6/site-packages/gluonts-0.1.1-py3.6.egg/gluonts/trainer/_base.py in __call__(self, net, input_names, train_iter)
    256 
    257                             with mx.autograd.record():
--> 258                                 output = net(*inputs)
    259 
    260                                 # network can returns several outputs, the first being always the loss

/usr/local/lib/python3.6/site-packages/mxnet/gluon/block.py in __call__(self, *args)
    538             hook(self, args)
    539 
--> 540         out = self.forward(*args)
    541 
    542         for hook in self._forward_hooks.values():

/usr/local/lib/python3.6/site-packages/mxnet/gluon/block.py in forward(self, x, *args)
    905             with x.context as ctx:
    906                 if self._active:
--> 907                     return self._call_cached_op(x, *args)
    908 
    909                 try:

/usr/local/lib/python3.6/site-packages/mxnet/gluon/block.py in _call_cached_op(self, *args)
    803                      for is_arg, i in self._cached_op_args]
    804         except DeferredInitializationError:
--> 805             self._deferred_infer_shape(*args)
    806             cargs = []
    807             for is_arg, i in self._cached_op_args:

/usr/local/lib/python3.6/site-packages/mxnet/gluon/block.py in _deferred_infer_shape(self, *args)
    791             error_msg = "Deferred initialization failed because shape"\
    792                         " cannot be inferred. {}".format(e)
--> 793             raise ValueError(error_msg)
    794 
    795     def _call_cached_op(self, *args):

ValueError: Deferred initialization failed because shape cannot be inferred. Error in operator deepartrainingnetwork0__mul1: [10:04:54] /home/travis/build/dmlc/mxnet-distro/mxnet-build/3rdparty/mshadow/../../src/operator/tensor/../elemwise_op_common.h:135: Check failed: assign(&dattr, vec.at(i)) Incompatible attr in node deepartrainingnetwork0__mul1 at 1-th input: expected [32,24], got [32,32,24]

Stack trace returned 10 entries:
[bt] (0) /usr/local/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x23d55a) [0x7f44332ff55a]
[bt] (1) /usr/local/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x23dbc1) [0x7f44332ffbc1]
[bt] (2) /usr/local/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x4ee78d) [0x7f44335b078d]
[bt] (3) /usr/local/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x5587d6) [0x7f443361a7d6]
[bt] (4) /usr/local/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x762338) [0x7f4433824338]
[bt] (5) /usr/local/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2be2cea) [0x7f4435ca4cea]
[bt] (6) /usr/local/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2be5658) [0x7f4435ca7658]
[bt] (7) /usr/local/lib/python3.6/site-packages/mxnet/libmxnet.so(MXSymbolInferShape+0x15ba) [0x7f4435c140ca]
[bt] (8) /usr/local/lib/python3.6/lib-dynload/_ctypes.cpython-36m-x86_64-linux-gnu.so(ffi_call_unix64+0x4c) [0x7f44791201da]
[bt] (9) /usr/local/lib/python3.6/lib-dynload/_ctypes.cpython-36m-x86_64-linux-gnu.so(ffi_call+0x26b) [0x7f447911c5cb]
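
The traceback above contains three chained exceptions (DeferredInitializationError, then MXNetError, then ValueError), and the shape mismatch in the MXNetError is the informative part. As a rough sketch of how such a chain can be inspected, the following uses only the standard library with stand-in exception types rather than MXNet itself; the names `exception_chain` and `failing_forward` are hypothetical and exist only for this illustration.

```python
def exception_chain(exc):
    """Return the chained exceptions, outermost first, following
    explicit (__cause__) or implicit (__context__) chaining."""
    chain = []
    seen = set()
    while exc is not None and id(exc) not in seen:
        seen.add(id(exc))
        chain.append(exc)
        exc = exc.__cause__ or exc.__context__
    return chain


def failing_forward():
    """Stand-in that mimics the nesting seen in the traceback above."""
    try:
        try:
            # Plays the role of DeferredInitializationError.
            raise KeyError("parameter not initialized")
        except KeyError:
            # Plays the role of MXNetError (the shape mismatch).
            raise RuntimeError("expected [32,24], got [32,32,24]")
    except RuntimeError as e:
        # Plays the role of the final ValueError wrapping the message.
        raise ValueError("Deferred initialization failed: {}".format(e))


try:
    failing_forward()
except ValueError as err:
    names = [type(e).__name__ for e in exception_chain(err)]
    print(" <- ".join(names))
```

Reading the chain from the outermost exception inwards, the second entry is the one that carries the shape information, just as the MXNetError does in the report above.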

lostella added the "bug" (Something isn't working) label on Jun 7, 2019
lostella self-assigned this on Jun 8, 2019
Contributor

lostella commented Jun 8, 2019

The bug is in the expand_dims here, which should just be omitted.

I'll open a PR as soon as I have tests in place for this.
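
To illustrate how a stray expand_dims could produce the shapes in the error message ([32,24] expected, [32,32,24] received), here is a minimal shape-only sketch in plain Python. It is not the GluonTS code: the tensor names `mu` and `scale` and their shapes are assumptions made for the example, and the helper functions are hypothetical. It also assumes right-aligned (NumPy-style) broadcasting for the multiplication.

```python
def expand_dims(shape, axis):
    """Shape after inserting a length-1 axis at position `axis`."""
    shape = list(shape)
    shape.insert(axis, 1)
    return tuple(shape)


def broadcast_shape(a, b):
    """Result shape of a broadcasting elementwise op, using
    right-aligned (NumPy-style) broadcasting rules."""
    ndim = max(len(a), len(b))
    # Left-pad the shorter shape with 1s so both have the same rank.
    a = (1,) * (ndim - len(a)) + tuple(a)
    b = (1,) * (ndim - len(b)) + tuple(b)
    out = []
    for x, y in zip(a, b):
        if x != y and 1 not in (x, y):
            raise ValueError("incompatible shapes {} and {}".format(a, b))
        out.append(max(x, y))
    return tuple(out)


batch_size, seq_len = 32, 24  # sizes taken from the error message
mu = (batch_size, seq_len)    # assumed shape of a per-step parameter
scale = (batch_size, 1)       # assumed shape of a per-series scale

# Without the extra axis the product keeps the expected shape.
ok = broadcast_shape(mu, scale)

# With an extra expand_dims the scale becomes (32, 1, 1), and the
# product picks up a spurious leading axis of size 32.
bad = broadcast_shape(mu, expand_dims(scale, axis=1))

print(ok)   # (32, 24)
print(bad)  # (32, 32, 24)
```

Under these assumptions, a later non-broadcasting elementwise multiplication against a [32,24] tensor would fail in the way the MXNetError reports, which is consistent with simply omitting the expand_dims.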

lostella commented Jun 9, 2019

Thanks for reporting this! The fix will be included in the next release.
