This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Embedding(sparse_grad=True) + hybridization fallback issues #11206

Closed
leezu opened this issue Jun 8, 2018 · 2 comments · Fixed by #11306

Comments

@leezu
Contributor

leezu commented Jun 8, 2018

With the current MXNet master branch, the code below fails. The failure is triggered by the combination of the following:

  • Use of Embedding(sparse_grad=True) + other operations in one HybridBlock
  • Combining multiple forward passes during a single autograd record session


import mxnet as mx


class Embedding(mx.gluon.HybridBlock):
    def __init__(self, num_tokens, embedding_size):
        super().__init__()
        self.num_tokens = num_tokens

        with self.name_scope():
            self.embedding = mx.gluon.nn.Embedding(
                num_tokens, embedding_size, sparse_grad=True)

    def hybrid_forward(self, F, words):
        emb = self.embedding(words)
        return emb + F.ones_like(emb)


ctx = mx.cpu()
embedding = Embedding(1000, 300)
embedding.initialize(ctx=ctx)
embedding.hybridize()

loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
with mx.autograd.record():
    emb_in = embedding(mx.nd.arange(10, ctx=ctx))
    emb_in_2 = embedding(mx.nd.arange(10, ctx=ctx))
    loss = emb_in.sum() + emb_in_2.sum()
loss.backward()

print(embedding.embedding.weight.grad().data)

Storage type fallback detected:
operator = add_n
input storage types = [default, default, ]
output storage types = [row_sparse, ]
params = {"num_args" : 2, }
context.dev_mask = cpu
The operator with default storage type will be dispatched for execution. You're seeing this warning message because the operator above is unable to process the given ndarrays with specified storage types, context and parameter. Temporary dense ndarrays are generated in order to execute the operator. This does not affect the correctness of the programme. You can set environment variable MXNET_STORAGE_FALLBACK_LOG_VERBOSE to 0 to suppress this warning.
Traceback (most recent call last):
  File "sparse_bug.py", line 30, in <module>
    print(embedding.embedding.weight.grad().data)
  File "/Users/lllausen/anaconda3/lib/python3.6/site-packages/mxnet/ndarray/sparse.py", line 728, in data
    return self._data()
  File "/Users/lllausen/anaconda3/lib/python3.6/site-packages/mxnet/ndarray/sparse.py", line 266, in _data
    self.wait_to_read()
  File "/Users/lllausen/anaconda3/lib/python3.6/site-packages/mxnet/ndarray/ndarray.py", line 1720, in wait_to_read
    check_call(_LIB.MXNDArrayWaitToRead(self.handle))
  File "/Users/lllausen/anaconda3/lib/python3.6/site-packages/mxnet/base.py", line 210, in check_call
    raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: [19:37:39] src/operator/contrib/../operator_common.h:493: Not implemented: operator = _backward_Embedding
input storage types = [default, default, ]
output storage types = [default, default, ]
params = {"input_dim" : 1000, "output_dim" : 300, "dtype" : float32, "sparse_grad" : True, }
context.dev_mask = cpu

@kalyc
Contributor

kalyc commented Jun 14, 2018

Hi @leezu thanks for submitting this issue.
@sandeep-krishnamurthy could you add the labels "Sparse" and "Python" to this?

@eric-haibin-lin
Member

The problem is that currently, if the output grad is not the immediate output of _backward_CachedOp, the storage types of the _backward_CachedOp outputs are inferred as dense storage. This is because _backward_CachedOp doesn't register FInferStorage, so it produces dense outputs by default.

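In this example, the outputs of _backward_CachedOp are passed to add_n to produce the row_sparse grad. This matches the fallback warning in the report: add_n receives two default (dense) inputs and is asked for a row_sparse output.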
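
For intuition, here is a toy, pure-Python sketch of that add_n step. It is not MXNet code: plain lists stand in for NDArrays, the sizes are shrunk, and the helpers `dense_embedding_grad`, `add_n` and `to_row_sparse` are hypothetical names made up for this illustration. It sums one dense weight gradient per forward pass and then keeps only the rows that were actually looked up, which is roughly what a row_sparse gradient stores.

```python
# Toy illustration only: plain Python lists stand in for NDArrays.
# All helper names here are hypothetical, not MXNet functions.

def dense_embedding_grad(num_tokens, dim, indices):
    """Dense gradient of sum(embedding(indices)) w.r.t. the weight matrix."""
    grad = [[0.0] * dim for _ in range(num_tokens)]
    for i in indices:
        for j in range(dim):
            grad[i][j] += 1.0  # each looked-up row receives a gradient of ones
    return grad

def add_n(*arrays):
    """Element-wise sum of equally shaped dense 2-D arrays."""
    rows, cols = len(arrays[0]), len(arrays[0][0])
    return [[sum(a[r][c] for a in arrays) for c in range(cols)]
            for r in range(rows)]

def to_row_sparse(dense):
    """Keep only the non-zero rows, as (row indices, row data)."""
    indices = [r for r, row in enumerate(dense) if any(v != 0.0 for v in row)]
    return indices, [dense[r] for r in indices]

# Two forward passes over the same 3 tokens, as in the report (sizes shrunk).
g1 = dense_embedding_grad(num_tokens=6, dim=2, indices=[0, 1, 2])
g2 = dense_embedding_grad(num_tokens=6, dim=2, indices=[0, 1, 2])
indices, data = to_row_sparse(add_n(g1, g2))
print(indices)  # [0, 1, 2]
print(data)     # [[2.0, 2.0], [2.0, 2.0], [2.0, 2.0]]
```

The two dense inputs and the row-sparse result correspond to the `[default, default]` inputs and `[row_sparse]` output listed in the warning.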

To solve this issue, we need to register FInferStorage for _backward_CachedOp. The registered function should perform storage type inference on the subgraph and return the stype inference result to the caller.
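
The idea can be sketched with a small toy model in Python. Everything in it is an assumption made for illustration: the `Op` class, `infer_output_stypes` and `make_subgraph_inference` are hypothetical and do not mirror MXNet's C++ registration API, and the `[default, row_sparse]` result for the Embedding backward is assumed from the sparse_grad=True setting rather than taken from the source. It only contrasts "nothing registered, so everything is dense" with "ask the subgraph".

```python
# Toy model only: these names are hypothetical and do not mirror MXNet's API.

DENSE, ROW_SPARSE = "default", "row_sparse"

class Op:
    def __init__(self, name, num_outputs, infer_storage=None):
        self.name = name
        self.num_outputs = num_outputs
        # Optional callable mapping input stypes to output stypes.
        self.infer_storage = infer_storage

def infer_output_stypes(op, in_stypes):
    """Use the op's registered inference function, else fall back to dense."""
    if op.infer_storage is None:
        return [DENSE] * op.num_outputs
    return op.infer_storage(in_stypes)

def embedding_backward_stypes(in_stypes):
    # Assumption: with sparse_grad=True the data gradient stays dense and
    # the weight gradient is row_sparse.
    return [DENSE, ROW_SPARSE]

def make_subgraph_inference(inner_ops):
    """Build an inference function that walks the inner ops in order."""
    def infer(in_stypes):
        stypes = in_stypes
        for inner in inner_ops:
            stypes = infer_output_stypes(inner, stypes)
        return stypes
    return infer

inner = [Op("_backward_Embedding", 2, embedding_backward_stypes)]

before = Op("_backward_CachedOp", 2)  # nothing registered
after = Op("_backward_CachedOp", 2, make_subgraph_inference(inner))

print(infer_output_stypes(before, [DENSE, DENSE]))  # ['default', 'default']
print(infer_output_stypes(after, [DENSE, DENSE]))   # ['default', 'row_sparse']
```

In this toy, the `before` case reproduces the dense-by-default behaviour described above, while the `after` case passes the subgraph's row_sparse result back to the caller.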
