MXNet 1.8.0.post0 sparse feature segmentation fault #20197

Open · barry-jin opened this issue Apr 21, 2021 · 0 comments
Labels: Bug, Sparse, v1.x (Targeting v1.x branch)

Description

GluonNLP v0.x branch CI is blocked after switching from MXNet 1.7.0.post1 to MXNet 1.8.0.post0 (tracked in dmlc/gluon-nlp#1559). It appears that the sparse feature in MXNet 1.8.0 causes a segmentation fault.

Error Message

Segmentation fault: 11

terminate called without an active exception
Aborted (core dumped)

To Reproduce

import mxnet as mx
from mxnet import nd, gluon
from mxnet.gluon import Block, HybridBlock

class _Helper(HybridBlock):
    def __init__(self, num_classes, num_sampled, in_unit):
        super(_Helper, self).__init__()
        self._num_classes = num_classes
        self._num_sampled = num_sampled
        self._in_unit = in_unit

    def hybrid_forward(self, F, x, sampled_values, label, w_all, b_all):
        """Forward computation."""
        sampled_candidates, expected_count_sampled, expected_count_true = sampled_values
        # (num_sampled, in_unit)
        w_sampled = w_all.slice(begin=(0, 0), end=(self._num_sampled, None))
        w_true = w_all.slice(begin=(self._num_sampled, 0), end=(None, None))
        b_sampled = b_all.slice(begin=(0,), end=(self._num_sampled,))
        b_true = b_all.slice(begin=(self._num_sampled,), end=(None,))
        # true pred
        # (batch_size, 1)
        x = x.reshape((-1, self._in_unit))
        pred_true = (w_true * x).sum(axis=1) + b_true
        # samples pred
        # (batch_size, num_sampled)
        b_sampled = F.reshape(b_sampled, (-1,))
        pred_sampled = F.FullyConnected(x, weight=w_sampled, bias=b_sampled,
                                        num_hidden=self._num_sampled)

        # remove accidental hits
        label_vec = F.reshape(label, (-1, 1)).astype('int32')
        sample_vec = F.reshape(sampled_candidates, (1, -1)).astype('int32')
        mask = F.broadcast_equal(label_vec, sample_vec).astype('float32') * -1e37
        pred_sampled = pred_sampled + mask

        # subtract log(q)
        expected_count_sampled = expected_count_sampled.astype('float32')
        expected_count_sampled = expected_count_sampled.reshape(shape=(1, self._num_sampled))
        expected_count_true = expected_count_true.astype('float32').reshape((-1,))
        pred_true = pred_true - F.log(expected_count_true)
        pred_true = pred_true.reshape((-1, 1))
        pred_sampled = F.broadcast_sub(pred_sampled, F.log(expected_count_sampled))

        # pred and new_labels
        # (batch_size, 1+num_sampled)
        pred = F.concat(pred_true, pred_sampled, dim=1)
        new_label = F.zeros_like(label)
        return pred, new_label

class SimpleSparse(Block):
    def __init__(self, num_classes, num_sampled, in_unit):
        super(SimpleSparse, self).__init__()
        with self.name_scope():
            self.weight = self.params.get('weight', shape=(num_classes, in_unit),
                                          init=None, dtype='float32',
                                          grad_stype='row_sparse', stype='row_sparse')
            self.bias = self.params.get('bias', shape=(num_classes,), init='zeros',
                                        dtype='float32')
        self._num_classes = num_classes
        self._num_sampled = num_sampled
        self._in_unit = in_unit
        self._kwargs = {'input_dim': self._num_classes, 'output_dim': self._in_unit,
                        'sparse_grad': True}
        self._dense = _Helper(num_classes, num_sampled, in_unit)

    def forward(self, x, sampled_values, label): # pylint: disable=arguments-differ
        """Forward computation."""
        sampled_candidates, _, _ = sampled_values
        # (batch_size,)
        label = label.reshape(shape=(-1,))
        # (num_sampled+batch_size,)
        ids = nd.concat(sampled_candidates.astype('int32'), label.astype('int32'), dim=0)
        # lookup weights and biases
        weight = self.weight.row_sparse_data(ids)
        bias = self.bias.data(ids.context)
        # (num_sampled+batch_size, dim)
        w_all = nd.Embedding(data=ids, weight=weight, **self._kwargs)
        # (num_sampled+batch_size,)
        b_all = nd.take(bias, indices=ids)
        out, new_targets = self._dense(x, sampled_values, label, w_all, b_all)
        return out, new_targets

def test():
    ctx = mx.cpu()
    batch_size = 2
    num_sampled = 3
    vocab_size = 10
    num_hidden = 5
    model = SimpleSparse(vocab_size, num_sampled, num_hidden)
    loss = gluon.loss.SoftmaxCrossEntropyLoss()
    model.hybridize()
    model.initialize(mx.init.Xavier(), ctx=ctx)
    trainer = mx.gluon.Trainer(model.collect_params(), 'sgd')
    x = mx.nd.ones((batch_size, num_hidden))
    y = mx.nd.ones((batch_size,))
    sampled_cls = mx.nd.ones((num_sampled,), dtype='float32')
    sampled_cls_cnt = mx.nd.ones((num_sampled,), dtype='float32')
    true_cls_cnt = mx.nd.ones((batch_size,), dtype='float32')
    samples = (sampled_cls, sampled_cls_cnt, true_cls_cnt)
    with mx.autograd.record():
        pred, new_y = model(x, samples, y)
        l = loss(pred, new_y)
    l.backward()
    mx.nd.waitall()

if __name__ == '__main__':
    test()
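
The reproducer above exercises the full sampled-softmax model at once. As a purely hypothetical narrowing step (assumption, untested: the crash originates in the row_sparse parameter lookup plus Embedding with sparse_grad=True on the backward path), a stripped-down sketch might look like:

import mxnet as mx
from mxnet import nd, gluon

# Hypothetical minimal sketch, not part of the original report: exercise only
# the suspected row_sparse lookup + Embedding(sparse_grad=True) path.
param = gluon.Parameter('weight', shape=(10, 5),
                        stype='row_sparse', grad_stype='row_sparse')
param.initialize(mx.init.Xavier(), ctx=mx.cpu())
# row_sparse_data requires a Trainer to be attached to the parameter
trainer = gluon.Trainer({'weight': param}, 'sgd')
ids = nd.array([1, 2, 3], dtype='int32')
with mx.autograd.record():
    weight = param.row_sparse_data(ids)
    out = nd.Embedding(data=ids, weight=weight, input_dim=10, output_dim=5,
                       sparse_grad=True)
out.backward()
mx.nd.waitall()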

Steps to reproduce

Run the script above, or:

$ git clone https://github.com/dmlc/gluon-nlp
$ cd gluon-nlp
$ git checkout v0.x
$ python3 -m pip install -e .[extra,dev]
$ python3 -m pytest tests/unittest/test_sampled_logits.py::test_is_softmax_loss
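
To capture a Python-level stack trace at the point of the segmentation fault, the script can also be run with the standard-library faulthandler enabled (a general debugging aid, not part of the original report; repro.py is a placeholder name for the script above):

$ python3 -X faulthandler repro.py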

What have you tried to solve it?

Environment

We recommend using our script to collect the diagnostic information, with the following command:
curl --retry 10 -s https://raw.githubusercontent.com/apache/incubator-mxnet/master/tools/diagnose.py | python3

Environment Information
# Paste the diagnose.py command output here
szha added the Sparse and v1.x (Targeting v1.x branch) labels and removed the needs triage label on Apr 21, 2021.