This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

BucketingModule crash with parameters conditioned on bucket key #11384

Open
fhieber opened this issue Jun 25, 2018 · 3 comments

Comments

@fhieber
Contributor

fhieber commented Jun 25, 2018

I am running into an issue with the BucketingModule when using a sym_gen function that uses a parameter variable only for certain bucket keys. Consider the following example code:

import mxnet as mx

def sym_gen(bucket_key):
    data = mx.sym.Variable('data')
    weight = mx.sym.Variable('weight')
    if bucket_key <= 1:
        out = mx.sym.FullyConnected(data=data, weight=weight, no_bias=True, num_hidden=2, flatten=False)
    else:
        out = data
    print(out.list_arguments())
    return out, ['data'], []


default_bucket_key = 10
mod = mx.module.BucketingModule(sym_gen=sym_gen, default_bucket_key=default_bucket_key)
data_shapes = [mx.io.DataDesc(name='data', shape=(2, default_bucket_key, 2))]
mod.bind(data_shapes=data_shapes, for_training=False, grad_req="null")
mod.init_params()
print("module initialized with default bucket key")

for bucket_key in range(1, default_bucket_key):
    print("BUCKET KEY", bucket_key)
    data_batch = mx.io.DataBatch(data=[mx.nd.ones((2, bucket_key, 2))],
                                 label=[],
                                 provide_data=[mx.io.DataDesc(name='data', shape=(2, bucket_key, 2))],
                                 bucket_key=bucket_key)

    mod.forward(data_batch=data_batch)
    print(mod.get_outputs())

The above code crashes with the following output:

['data']
['data']
module initialized with default bucket key
BUCKET KEY 1
['data', 'weight']
libc++abi.dylib: terminating with uncaught exception of type std::out_of_range: unordered_map::at: key not found

Process finished with exit code 134 (interrupted by signal 6: SIGABRT)

The crash happens in the mod.forward() call, when the module tries to bind and switch to the new bucket (key 1).

If the above sym_gen code is changed to use the weight variable for the default bucket key (e.g. by changing if bucket_key <= 1: to if bucket_key > 1:), everything runs fine, presumably because the default graph then allocates memory for the weight variable.
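The suspected mechanism can be illustrated with a small pure-Python toy model. It assumes, as suspected above, that BucketingModule allocates parameter storage when binding the default bucket and that later buckets look up their parameters by name in that storage. `bind_default`, `bind_bucket` and `args_for` are hypothetical helpers, not MXNet APIs, and a Python `KeyError` stands in for the C++ `std::out_of_range`.

```python
# Toy model only (hypothetical helpers, not MXNet code). Assumption: the
# default bucket allocates parameter storage and other buckets borrow it by name.

def bind_default(default_args, input_names):
    """Allocate a (dummy) array for every parameter of the default bucket."""
    return {name: [0.0] for name in default_args if name not in input_names}


def bind_bucket(bucket_args, input_names, shared_params):
    """Bind a non-default bucket by borrowing its parameters from shared storage."""
    bound = {}
    for name in bucket_args:
        if name in input_names:
            continue  # inputs such as 'data' are skipped; only parameters are borrowed here
        # A plain dict lookup raises KeyError for a parameter the default
        # bucket never allocated, standing in for unordered_map::at.
        bound[name] = shared_params[name]
    return bound


def args_for(bucket_key, weight_if_small=True):
    """Arguments of the issue's sym_gen: 'weight' only for some bucket keys."""
    uses_weight = bucket_key <= 1 if weight_if_small else bucket_key > 1
    return ['data', 'weight'] if uses_weight else ['data']


inputs = {'data'}

# Original code: default bucket 10 has no 'weight', so bucket 1 fails.
shared = bind_default(args_for(10), inputs)
try:
    bind_bucket(args_for(1), inputs, shared)
except KeyError as err:
    print('missing parameter:', err)  # missing parameter: 'weight'

# Flipped condition: default bucket 10 declares 'weight', so all buckets bind.
shared = bind_default(args_for(10, weight_if_small=False), inputs)
for key in range(1, 10):
    bind_bucket(args_for(key, weight_if_small=False), inputs, shared)
print('all buckets bound')
```

In this model the only thing that matters is whether the default bucket's argument list contains every parameter any other bucket will ask for, which matches the behavior observed above.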

My questions are as follows:

  • It took me a while to figure out the problem in my actual use case, as the low-level error message is not very helpful. It would be great if MXNet could guard against such parameter-related issues.
  • What is your general opinion about this kind of code? I can work around the issue by choosing a default bucket key whose graph uses all potential parameters (in this case default_bucket_key=1), but in my use case that probably hurts memory sharing between buckets: usually one sets the default bucket key so that it corresponds to the 'largest' computation graph (for example, in terms of sequence length). In this particular example the question is what 'largest' means: the longest sequence, or the most parameters/variables.
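The workaround from the second question can be automated with a small helper: compute the union of arguments over all bucket keys and keep only the keys whose graph declares all of them. This is a hypothetical sketch, not part of MXNet; `pick_default_bucket_key` and `issue_args` are illustrative names, and preferring the largest covering key encodes the assumption that a larger key means a longer sequence.

```python
def pick_default_bucket_key(bucket_keys, arg_names_for):
    """Return the largest bucket key whose symbol declares every argument used
    by any bucket, or None if no single bucket covers them all.

    arg_names_for(key) returns the argument names for `key`; with MXNet this
    could be sym_gen(key)[0].list_arguments().
    """
    args_by_key = {key: set(arg_names_for(key)) for key in bucket_keys}
    all_args = set().union(*args_by_key.values())
    covering = [key for key, args in args_by_key.items() if args >= all_args]
    # Assumption: a larger key means a longer sequence, i.e. a bigger graph.
    return max(covering) if covering else None


def issue_args(key):
    # Argument lists of the sym_gen from the issue (without building symbols).
    return ['data', 'weight'] if key <= 1 else ['data']


print(pick_default_bucket_key(range(1, 11), issue_args))  # 1
```

For the example it returns 1, which shows the trade-off raised above: the only key that covers all parameters is the smallest graph in terms of sequence length. A return value of None would mean no single bucket declares every parameter, so this workaround would not apply at all.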
@vrakesh
Contributor

vrakesh commented Jun 25, 2018

@fhieber Thank you for reporting the issue; I will look into this. @sandeep-krishnamurthy, please label this issue as a bug.

@fhieber
Contributor Author

fhieber commented Jun 28, 2018

Great, thanks! It may well be that this shouldn't be a supported use case (a different number of parameters for different buckets in a module), but MXNet should at least give a more informative error message.
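A more informative error could come from comparing argument names before a new bucket is bound. The following is a hypothetical user-side check in plain Python (not an MXNet API); with MXNet the argument lists could be obtained via `sym_gen(key)[0].list_arguments()` and checked before calling `mod.forward` with a new bucket key.

```python
def check_bucket_params(default_key, default_args, bucket_key, bucket_args, input_names):
    """Raise a descriptive error if a bucket uses a parameter that the default
    bucket does not declare (hypothetical pre-bind check, not an MXNet API).

    input_names: data and label names, which are not treated as parameters.
    """
    default_params = set(default_args) - set(input_names)
    missing = [name for name in bucket_args
               if name not in input_names and name not in default_params]
    if missing:
        raise ValueError(
            f"bucket key {bucket_key!r} uses parameter(s) {missing} that the "
            f"default bucket (key {default_key!r}) does not declare; choose a "
            f"default_bucket_key whose symbol declares every parameter"
        )


# The situation from the issue: default bucket 10 lacks 'weight', bucket 1 uses it.
try:
    check_bucket_params(10, ['data'], 1, ['data', 'weight'], ['data'])
except ValueError as err:
    print(err)
```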

@vishaalkapoor
Contributor

Hey @vrakesh, did you happen to look into the issue? Thanks!
