This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Updating SymbolBlock.imports to support different dtypes#15230

Merged
szha merged 5 commits into apache:master from ThomasDelteil:update_symbol_block
Jun 20, 2019

Conversation

@ThomasDelteil
Contributor

@ThomasDelteil ThomasDelteil commented Jun 13, 2019

Description

Previously it was impossible to call SymbolBlock.imports(...) with non-fp32 parameters: type inference would always fail and the parameter dtypes would default to fp32.
Now we use the parameter dtypes as stored in the parameter file and cast the parameters accordingly (a new dtype_source argument for the _load_init and load_parameters functions, complementing the cast_dtype flag).
If no parameters are provided, we attempt automatic type inference with an fp32 input.
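The dtype-selection rule described above can be sketched as a simplified model (resolve_dtype is a hypothetical helper, not an actual MXNet function; it only illustrates the choice between the saved dtype and the current in-memory dtype):

```python
import numpy as np

def resolve_dtype(saved_array, current_dtype, dtype_source="current"):
    # Hypothetical helper illustrating the dtype_source choice:
    # "saved"   -> trust the dtype stored in the parameter file
    # "current" -> cast loaded values to the in-memory parameter dtype
    if dtype_source == "saved":
        return saved_array.dtype
    return np.dtype(current_dtype)

saved = np.ones(4, dtype=np.float16)  # e.g. fp16 weights loaded from a file
print(resolve_dtype(saved, np.float32, dtype_source="saved"))    # float16
print(resolve_dtype(saved, np.float32, dtype_source="current"))  # float32
```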

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Code is well-documented:
  • To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

@ThomasDelteil ThomasDelteil requested a review from szha as a code owner June 13, 2019 02:00
def test_gluon_param_load_dtype_source():
    net = mx.gluon.nn.Dense(10, in_units=10)
    net.initialize()
    net.cast('float16')
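The fix relies on the fact that a serialized parameter file records each array's dtype, so the loader can recover it rather than defaulting to fp32. A minimal NumPy round-trip (illustration only, not the MXNet .params format) shows the idea:

```python
import io
import numpy as np

# Serialize an fp16 array; the dtype travels with the data.
buf = io.BytesIO()
np.save(buf, np.ones((10, 10), dtype=np.float16))
buf.seek(0)

loaded = np.load(buf)
print(loaded.dtype)  # float16 -- the saved dtype, not a default fp32
```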
Contributor

Would you mind adding a case to load INT8 weight?

Contributor

I also suggest adding a test for loading int8 weights. BTW, I just took a quick try and this works well with my internal patch :)

Contributor Author

@pengzhao-intel @xinyu-intel, thanks for the suggestion, I'm actually unable to find a single layer in Gluon that supports int8 or uint8.

import mxnet as mx
import numpy as np

net = mx.gluon.nn.Conv2D(channels=4, kernel_size=3, in_channels=3)
net.initialize()
net.cast(np.uint8)
net(mx.nd.ones((1, 3, 224, 224), dtype=np.uint8))

gives me a

MXNetError: std::exception

Any suggestions without resorting to symbolically fused graphs?

Contributor Author

I have added a test for the ParameterDict that should be sufficient since it is the same mechanism as the network parameters: see https://github.com/apache/incubator-mxnet/pull/15230/files#diff-962bd5bb7248659d7eb3be37ee8a4c6bR141

@vandanavk
Contributor

@mxnet-label-bot add [Gluon, pr-awaiting-review]

@marcoabreu marcoabreu added Gluon pr-awaiting-review PR is waiting for code review labels Jun 13, 2019
@ThomasDelteil
Contributor Author

@szha for review

@ThomasDelteil
Contributor Author

@pengzhao-intel is that good for merging?

@szha szha merged commit 145f82d into apache:master Jun 20, 2019
@pengzhao-intel
Contributor

LGTM

Sorry for the late reply.


Labels

Gluon pr-awaiting-review PR is waiting for code review


6 participants