Updating SymbolBlock.imports to support different dtypes #15230

szha merged 5 commits into apache:master
Conversation
def test_gluon_param_load_dtype_source():
    net = mx.gluon.nn.Dense(10, in_units=10)
    net.initialize()
    net.cast('float16')
Would you mind adding a case that loads INT8 weights?
I also suggest adding a test for loading int8 weights. BTW, I gave it a quick try and it works well with my internal patch :)
@pengzhao-intel @xinyu-intel, thanks for the suggestion. I'm actually unable to find a single layer in Gluon that supports int8 or uint8.

import mxnet as mx
import numpy as np

net = mx.gluon.nn.Conv2D(channels=4, kernel_size=3, in_channels=3)
net.initialize()
net.cast(np.uint8)
net(mx.nd.ones((1, 3, 224, 224), dtype=np.uint8))

gives me a

MXNetError: std::exception

Any suggestions without resorting to symbolically fused graphs?
I have added a test for the ParameterDict that should be sufficient, since it uses the same mechanism as the network parameters: see https://github.com/apache/incubator-mxnet/pull/15230/files#diff-962bd5bb7248659d7eb3be37ee8a4c6bR141
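For reference, here is a minimal sketch of what such a ParameterDict round-trip could look like for an integer dtype (the parameter name, file name, and Zero initializer are illustrative, not the exact code in the linked diff):

```python
import mxnet as mx
import numpy as np

# Save a uint8 parameter to disk. The Zero initializer avoids random
# initialization, which is not defined for integer dtypes.
params = mx.gluon.ParameterDict()
params.get('weight', shape=(10,), dtype=np.uint8, init=mx.init.Zero())
params.initialize(ctx=mx.cpu())
params.save('uint8_weight.params')

# Load into a ParameterDict whose parameter is currently float32.
# With cast_dtype=True and dtype_source='saved', the current parameter
# should be cast to the dtype stored in the file (uint8), not the
# other way around.
loaded = mx.gluon.ParameterDict()
loaded.get('weight', shape=(10,), dtype=np.float32)
loaded.initialize(ctx=mx.cpu())
loaded.load('uint8_weight.params', ctx=mx.cpu(),
            cast_dtype=True, dtype_source='saved')
assert loaded.get('weight').data().dtype == np.uint8
```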
@mxnet-label-bot add [Gluon, pr-awaiting-review]

@szha for review

@pengzhao-intel is that good for merging?

LGTM. Sorry for the late reply.
Description
Previously it was impossible to do SymbolBlock.imports(...) with non-fp32 parameters: type inference would always fail, so parameter types defaulted to fp32.

Now we use the parameter types as provided in the parameter file, by casting the parameters to them (a new dtype_source argument for the load, _load_init and load_parameters functions, complementing the cast_dtype flag). If the parameters are not provided, we attempt automatic type inference with fp32 input.
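To make the intended behavior concrete, here is a minimal sketch of the new flag in use (file names are illustrative, and the snippet assumes the dtype_source values 'current' and 'saved' described above):

```python
import mxnet as mx

# Save an fp16 network's parameters.
net = mx.gluon.nn.Dense(10, in_units=10)
net.initialize()
net.cast('float16')
net.save_parameters('dense_fp16.params')

# Load them into a fresh fp32 network. With dtype_source='saved', the
# network's parameters are cast to the fp16 dtypes found in the file;
# with the default dtype_source='current', the saved weights would
# instead be cast to the network's current fp32 dtypes.
net2 = mx.gluon.nn.Dense(10, in_units=10)
net2.load_parameters('dense_fp16.params',
                     cast_dtype=True, dtype_source='saved')
```

SymbolBlock.imports loads its parameter file through this same path, so importing a non-fp32 checkpoint no longer falls back to fp32.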
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.