NightlyTestsForBinaries tutorials test broken #15374

Open · lebeg opened this issue Jun 26, 2019 · 15 comments
@lebeg (Contributor) commented Jun 26, 2019

The nightly test for tutorials is broken, in particular test_tutorials.test_amp

http://jenkins.mxnet-ci.amazon-ml.com/blue/organizations/jenkins/NightlyTestsForBinaries/detail/master/355/pipeline

ERROR:root:An error occurred while executing the following cell:
------------------
mbox_loss = gcv.loss.SSDMultiBoxLoss()

for epoch in range(1):
    ce_metric.reset()
    smoothl1_metric.reset()
    tic = time.time()
    btic = time.time()

    for i, batch in enumerate(train_data):
        batch_size = batch[0].shape[0]
        data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
        cls_targets = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
        box_targets = gluon.utils.split_and_load(batch[2], ctx_list=ctx, batch_axis=0)
        with autograd.record():
            cls_preds = []
            box_preds = []
            for x in data:
                cls_pred, box_pred, _ = net(x)
                cls_preds.append(cls_pred)
                box_preds.append(box_pred)
            sum_loss, cls_loss, box_loss = mbox_loss(
                cls_preds, box_preds, cls_targets, box_targets)
            autograd.backward(sum_loss)
        trainer.step(1)
        ce_metric.update(0, [l * batch_size for l in cls_loss])
        smoothl1_metric.update(0, [l * batch_size for l in box_loss])
        if not (i + 1) % 50:
            name1, loss1 = ce_metric.get()
            name2, loss2 = smoothl1_metric.get()
            logger.info('[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}={:.3f}, {}={:.3f}'.format(
                epoch, i, batch_size/(time.time()-btic), name1, loss1, name2, loss2))
        btic = time.time()
------------------

---------------------------------------------------------------------------
MXNetError                                Traceback (most recent call last)
<ipython-input-6-67de7c0b11cc> in <module>
     20                 box_preds.append(box_pred)
     21             sum_loss, cls_loss, box_loss = mbox_loss(
---> 22                 cls_preds, box_preds, cls_targets, box_targets)
     23             autograd.backward(sum_loss)
     24         trainer.step(1)

/work/mxnet/python/mxnet/gluon/block.py in __call__(self, *args)
    546             hook(self, args)
    547 
--> 548         out = self.forward(*args)
    549 
    550         for hook in self._forward_hooks.values():

/usr/local/lib/python3.5/dist-packages/gluoncv/loss.py in forward(self, cls_pred, box_pred, cls_target, box_target)
    157             rank = (cls_loss * (pos - 1)).argsort(axis=1).argsort(axis=1)
    158             hard_negative = rank < nd.maximum(self._min_hard_negatives, pos.sum(axis=1)
--> 159                                               * self._negative_mining_ratio).expand_dims(-1)
    160             # mask out if not positive or negative
    161             cls_loss = nd.where((pos + hard_negative) > 0, cls_loss, nd.zeros_like(cls_loss))

/work/mxnet/python/mxnet/ndarray/ndarray.py in __lt__(self, other)
    342     def __lt__(self, other):
    343         """x.__lt__(y) <=> x<y <=> mx.nd.lesser(x, y) """
--> 344         return lesser(self, other)
    345 
    346     def __le__(self, other):

/work/mxnet/python/mxnet/ndarray/ndarray.py in lesser(lhs, rhs)
   3506         lambda x, y: 1 if x < y else 0,
   3507         _internal._lesser_scalar,
-> 3508         _internal._greater_scalar)
   3509     # pylint: enable= no-member, protected-access
   3510 

/work/mxnet/python/mxnet/ndarray/ndarray.py in _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar)
   2706         return lfn_scalar(lhs, float(rhs))
   2707     elif isinstance(rhs, NDArray):
-> 2708         return fn_array(lhs, rhs)
   2709     else:
   2710         raise TypeError('type %s not supported' % str(type(rhs)))

/work/mxnet/python/mxnet/ndarray/register.py in broadcast_lesser(lhs, rhs, out, name, **kwargs)

/work/mxnet/python/mxnet/_ctypes/ndarray.py in _imperative_invoke(handle, ndargs, keys, vals, out)
     90         c_str_array(keys),
     91         c_str_array([str(s) for s in vals]),
---> 92         ctypes.byref(out_stypes)))
     93 
     94     if original_output is not None:

/work/mxnet/python/mxnet/base.py in check_call(ret)
    251     """
    252     if ret != 0:
--> 253         raise MXNetError(py_str(_LIB.MXGetLastError()))
    254 
    255 

MXNetError: [20:54:33] /work/mxnet/3rdparty/mshadow/../../src/operator/tensor/../elemwise_op_common.h:135: Check failed: assign(&dattr, vec.at(i)): Incompatible attr in node  at 1-th input: expected int32, got float32
Stack trace:
  [bt] (0) /work/mxnet/python/mxnet/../../lib/libmxnet.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x3c) [0x7f465e8070bc]
  [bt] (1) /work/mxnet/python/mxnet/../../lib/libmxnet.so(bool mxnet::op::ElemwiseAttr<int, &mxnet::op::type_is_none, &mxnet::op::type_assign, true, &mxnet::op::type_string[abi:cxx11], -1, -1>(nnvm::NodeAttrs const&, std::vector<int, std::allocator<int> >*, std::vector<int, std::allocator<int> >*, int const&)::{lambda(std::vector<int, std::allocator<int> > const&, unsigned long, char const*)#1}::operator()(std::vector<int, std::allocator<int> > const&, unsigned long, char const*) const+0x346) [0x7f465e953636]
  [bt] (2) /work/mxnet/python/mxnet/../../lib/libmxnet.so(bool mxnet::op::ElemwiseAttr<int, &mxnet::op::type_is_none, &mxnet::op::type_assign, true, &mxnet::op::type_string[abi:cxx11], -1, -1>(nnvm::NodeAttrs const&, std::vector<int, std::allocator<int> >*, std::vector<int, std::allocator<int> >*, int const&)+0x25d) [0x7f465e954a0d]
  [bt] (3) /work/mxnet/python/mxnet/../../lib/libmxnet.so(bool mxnet::op::ElemwiseType<2, 1>(nnvm::NodeAttrs const&, std::vector<int, std::allocator<int> >*, std::vector<int, std::allocator<int> >*)+0x34f) [0x7f465eb9bc5f]
  [bt] (4) /work/mxnet/python/mxnet/../../lib/libmxnet.so(std::_Function_handler<bool (nnvm::NodeAttrs const&, std::vector<int, std::allocator<int> >*, std::vector<int, std::allocator<int> >*), bool (*)(nnvm::NodeAttrs const&, std::vector<int, std::allocator<int> >*, std::vector<int, std::allocator<int> >*)>::_M_invoke(std::_Any_data const&, nnvm::NodeAttrs const&, std::vector<int, std::allocator<int> >*&&, std::vector<int, std::allocator<int> >*&&)+0x1d) [0x7f465e8adfcd]
  [bt] (5) /work/mxnet/python/mxnet/../../lib/libmxnet.so(mxnet::imperative::SetShapeType(mxnet::Context const&, nnvm::NodeAttrs const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, mxnet::DispatchMode*)+0x22d5) [0x7f4661792ba5]
  [bt] (6) /work/mxnet/python/mxnet/../../lib/libmxnet.so(mxnet::Imperative::Invoke(mxnet::Context const&, nnvm::NodeAttrs const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&)+0x110) [0x7f466179e1b0]
  [bt] (7) /work/mxnet/python/mxnet/../../lib/libmxnet.so(MXImperativeInvokeImpl(void*, int, void**, int*, void***, int, char const**, char const**)+0x1c9) [0x7f466228acf9]
  [bt] (8) /work/mxnet/python/mxnet/../../lib/libmxnet.so(MXImperativeInvokeEx+0x8f) [0x7f466228b1ff]


======================================================================
FAIL: test_tutorials.test_amp
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/nose/case.py", line 198, in runTest
    self.test(*self.arg)
  File "/work/mxnet/tests/tutorials/test_tutorials.py", line 209, in test_amp
    assert _test_tutorial_nb('amp/amp_tutorial')
AssertionError

@mxnet-label-bot (Contributor) commented

Hey, this is the MXNet Label Bot.
Thank you for submitting the issue! I will try and suggest some labels so that the appropriate MXNet community members can help resolve it.
Here are my recommended labels: Test

@lebeg (Contributor, Author) commented Jun 26, 2019

The test was broken by PR #15170: [MXNET-1413] Adding Large Tensor support for sort operators.
The reason is that the default dtype of the indices returned by the ArgSort operator changed from kFloat32 to kInt64 or kInt32, depending on the build:

// From the ArgSort parameter definition: the default dtype of the
// returned indices now depends on the build flag.
#if USE_INT64_TENSOR_SIZE == 1
    .set_default(mshadow::kInt64)
#else
    .set_default(mshadow::kInt32)
#endif
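
To see which default applies on a given binary, a minimal check (a sketch; the printed dtype depends on whether the build was compiled with USE_INT64_TENSOR_SIZE=1):

import mxnet as mx

x = mx.nd.array([3.0, 1.0, 2.0])
# Before PR #15170 this printed float32; now it prints int32, or int64
# on builds compiled with USE_INT64_TENSOR_SIZE=1.
print(x.argsort().dtype)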

This caused a problem in GluonCV, specifically in SSDMultiBoxLoss, at these lines:

rank = (cls_loss * (pos - 1)).argsort(axis=1).argsort(axis=1)
hard_negative = rank < nd.maximum(self._min_hard_negatives, pos.sum(axis=1)
                                  * self._negative_mining_ratio).expand_dims(-1)

which fails with:

MXNetError: [20:54:33] /work/mxnet/3rdparty/mshadow/../../src/operator/tensor/../elemwise_op_common.h:135: Check failed: assign(&dattr, vec.at(i)): Incompatible attr in node  at 1-th input: expected int32, got float32
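
A standalone sketch of the failing pattern with one possible GluonCV-side adaptation (hypothetical shapes and threshold; this is not the actual fix in #15360, only an illustration of aligning the dtypes before the comparison):

from mxnet import nd

cls_loss = nd.random.uniform(shape=(2, 8))
pos = nd.random.uniform(shape=(2, 8)) > 0.5            # float32 0/1 mask
threshold = nd.maximum(nd.array([3.0]),
                       pos.sum(axis=1) * 3.0).expand_dims(-1)

rank = (cls_loss * (pos - 1)).argsort(axis=1).argsort(axis=1)
# On affected builds rank is int32/int64, so `rank < threshold` raises
# "Incompatible attr ... expected int32, got float32". Casting the ranks
# (or passing dtype='float32' to argsort) restores the comparison:
hard_negative = rank.astype('float32') < threshold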

@lebeg (Contributor, Author) commented Jun 26, 2019

@mxnet-label-bot add labels Test

@lebeg (Contributor, Author) commented Jun 26, 2019

@access2rohit could you take a look and say whether we should revert your PR or recommend GluonCV to adapt to the new behaviour?

@lebeg (Contributor, Author) commented Jun 26, 2019

In my opinion it's not very good that argsort changes the type of the tensor it's applied to.
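
Until that is settled, downstream code can make the result independent of the build flag by requesting the index dtype explicitly (a sketch; dtype is the existing argsort parameter whose default PR #15170 changed):

import mxnet as mx

x = mx.nd.array([3.0, 1.0, 2.0])
# Request float32 indices explicitly so downstream float arithmetic keeps
# working regardless of how the binary was built:
ranks = mx.nd.argsort(x, axis=-1, dtype='float32')
print(ranks.dtype)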

@vdantu (Contributor) commented Jun 26, 2019

@mxnet-label-bot add [Test]

@apeforest (Contributor) commented

#15360 will hopefully fix this.

@access2rohit (Contributor) commented

@lebeg I will merge my PR once the unix tests pass. Currently they are failing in the unix CI/CD build, which is independent of my code change.
http://jenkins.mxnet-ci.amazon-ml.com/blue/organizations/jenkins/mxnet-validation%2Funix-gpu/detail/PR-15360/3/pipeline

@lebeg (Contributor, Author) commented Jun 27, 2019

I've restarted both unix verification jobs.

@Chancebair (Contributor) commented Jul 9, 2019

Unfortunately this is still broken as of Jul 8th.


@apeforest (Contributor) commented

@access2rohit Can you help check whether the nightly test failure is still related to your topk change?

@access2rohit (Contributor) commented

I ran the amp tutorial manually and found no error.

Output:
[00:13:37] src/base.cc:84: Upgrade advisory: this mxnet has been built against cuDNN lib version 7401, which is older than the oldest version tested by CI (7600). Set MXNET_CUDNN_LIB_CHECKING=0 to quiet this warning.
Model file is not found. Downloading.
Downloading /home/ubuntu/.mxnet/models/resnet50_v1-cc729d95.zip from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/resnet50_v1-cc729d95.zip...
100%|████████████████████████████████████████| 57421/57421 [00:01<00:00, 55765.69KB/s]
[00:13:42] src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:97: Running performance tests to find the best convolution algorithm, this can take a while... (set the environment variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
INFO:root:[Epoch 0][Batch 49], Speed: 8.813 samples/sec, CrossEntropy=1.190, SmoothL1=0.688
INFO:root:[Epoch 0][Batch 99], Speed: 8.678 samples/sec, CrossEntropy=0.693, SmoothL1=0.536
INFO:root:[Epoch 0][Batch 149], Speed: 8.700 samples/sec, CrossEntropy=0.500, SmoothL1=0.453
INFO:root:[Epoch 0][Batch 199], Speed: 8.599 samples/sec, CrossEntropy=0.396, SmoothL1=0.400
INFO:root:Using AMP
[00:17:15] src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:97: Running performance tests to find the best convolution algorithm, this can take a while... (set the environment variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
INFO:root:[Epoch 0][Batch 49], Speed: 9.219 samples/sec, CrossEntropy=1.165, SmoothL1=0.684
INFO:root:[Epoch 0][Batch 99], Speed: 9.356 samples/sec, CrossEntropy=0.682, SmoothL1=0.533
INFO:root:[Epoch 0][Batch 149], Speed: 9.206 samples/sec, CrossEntropy=0.493, SmoothL1=0.451
INFO:root:[Epoch 0][Batch 199], Speed: 9.209 samples/sec, CrossEntropy=0.391, SmoothL1=0.398
INFO:root:downloaded http://data.mxnet.io/models/imagenet/resnet/18-layers/resnet-18-symbol.json into model/imagenet1k-resnet-18-symbol.json successfully
INFO:root:downloaded http://data.mxnet.io/models/imagenet/resnet/18-layers/resnet-18-0000.params into model/imagenet1k-resnet-18-0000.params successfully
[00:20:27] src/nnvm/legacy_json_util.cc:209: Loading symbol saved by previous version v0.8.0. Attempting to upgrade...
[00:20:27] src/nnvm/legacy_json_util.cc:217: Symbol successfully upgraded!
[00:20:29] src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:97: Running performance tests to find the best convolution algorithm, this can take a while... (set the environment variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
Conversion and Inference completed successfully
INFO:root:model/imagenet1k-resnet-18-symbol.json exists, skipping download
INFO:root:model/imagenet1k-resnet-18-0000.params exists, skipping download
[00:20:33] src/nnvm/legacy_json_util.cc:209: Loading symbol saved by previous version v0.8.0. Attempting to upgrade...
[00:20:33] src/nnvm/legacy_json_util.cc:217: Symbol successfully upgraded!
Conversion and Inference completed successfully
INFO:root:Saved checkpoint to "amp_tutorial_model-0000.params"

@access2rohit (Contributor) commented

$ python -c "from mxnet.runtime import Features; print(Features())"

[✔ CUDA, ✔ CUDNN, ✖ NCCL, ✖ CUDA_RTC, ✖ TENSORRT, ✔ CPU_SSE, ✔ CPU_SSE2, ✔ CPU_SSE3, ✔ CPU_SSE4_1, ✔ CPU_SSE4_2, ✖ CPU_SSE4A, ✔ CPU_AVX, ✖ CPU_AVX2, ✖ OPENMP, ✖ SSE, ✔ F16C, ✖ JEMALLOC, ✔ BLAS_OPEN, ✖ BLAS_ATLAS, ✖ BLAS_MKL, ✖ BLAS_APPLE, ✖ LAPACK, ✖ MKLDNN, ✖ OPENCV, ✖ CAFFE, ✖ PROFILER, ✖ DIST_KVSTORE, ✖ CXX14, ✔ INT64_TENSOR_SIZE, ✖ SIGNAL_HANDLER, ✖ DEBUG, ✖ TVM_OP]
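
Since the changed argsort default is gated on that flag, a test can also assert it programmatically (a sketch using mxnet.runtime.Features.is_enabled):

from mxnet.runtime import Features

features = Features()
# True on builds compiled with USE_INT64_TENSOR_SIZE=1 (as above), where
# argsort defaults to int64 indices; False where it defaults to int32.
print(features.is_enabled('INT64_TENSOR_SIZE'))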
