This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[2.0] Gluon2.0: switch to use forward interface #20262

Merged
merged 85 commits into apache:master from issue-19138 on Jun 21, 2021

Conversation

barry-jin
Contributor

@barry-jin barry-jin commented May 12, 2021

Description

#19138

Checklist

Essentials

  • PR's title starts with a category (e.g. [BUGFIX], [MODEL], [TUTORIAL], [FEATURE], [DOC], etc)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • Code is well-documented

Changes

  • Adopt packed_func-based FFI for npx.group_norm
  • Gluon 2.0 upgrade: use the forward interface in blocks (a minimal sketch is shown after this list)
    • gluon/data/vision/*
    • gluon/loss.py
    • gluon/model_zoo/*
    • gluon/nn/*
    • gluon/rnn/*
  • Use the np/npx interface in gluon/metric.py
  • Implement infer_shape method
    • gluon/nn/basic_layers.py::Dense
    • gluon/nn/basic_layers.py::BatchNorm
    • gluon/nn/basic_layers.py::InstanceNorm
    • gluon/nn/basic_layers.py::LayerNorm
    • gluon/nn/basic_layers.py::GroupNorm
    • gluon/nn/conv_layers.py::_Conv
    • gluon/nn/conv_layers.py::DeformableConvolution
    • gluon/nn/conv_layers.py::ModulatedDeformableConvolution
  • Remove the F-based hybrid mode in gluon/probability/*
  • gluon/rnn/*
    • conv_rnn_cell.py: implement forward, infer_shape, np/npx
    • rnn_cell.py: implement forward, np/npx; use a specialized infer_shape method based on the layer, input_size, and whether the cell is bidirectional.
    • rnn_layer.py: implement forward, infer_shape, np/npx
  • Fix issue where np.average returns an ADT type; fix npx.pooling
  • Copy control flow ops (while_loop, cond, foreach) from ndarray.contrib to mx.npx.control_flow
  • Register some legacy ops in npx
    • stes_op
    • sync_batch_norm
    • legacy pad (np.pad doesn't have backward computation for 'reflect' mode)
    • Some RNN-related ops: sequence_last, sequence_reverse, slice_channel, broadcast_greater, softsign
  • Use forward, np/npx for all the gluon related tests; remove gluon tests with symbol inputs
    • remove test_gluon_data_vision.py and test_gluon_probability_v1.py as related tests are covered in test_numpy_gluon_data_vision.py and test_gluon_probability_v2.py
    • Run test_numpy_op.py::test_np_nan_to_num only with the copy argument set to True, since in-place operations are not supported when recording in deferred compute mode
  • Remove hybrid_block interface in gluon/block.py
  • Remove hybrid_block interface in documentation and docstring
    • update docs/python_docs/python/tutorials/packages/gluon/blocks/custom_layers
    • remove python_tutorials/packages/gluon/blocks/custom_layers_beginners.md as it duplicates custom_layers
    • remove docs/python_docs/python/tutorials/packages/legacy/ndarray/sparse/train_gluon as Gluon 2.0 does not support sparse
    • update docs/python_docs/python/tutorials/packages/gluon/blocks/custom_loss
    • update docs/python_docs/python/tutorials/packages/gluon/blocks/hybridize
  • Turn on NumPy mode by default ([NumPy] turn on set_np #18631)
  • Fix Gluon 2.0 reference leak.
  • Migrate control flow operators to npx namespace
    • foreach
    • while_loop
    • cond
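For illustration, here is a minimal sketch (not part of this PR's diff; the block, names, and shapes below are made up) of the Gluon 2.0 style these changes migrate to: forward() replaces hybrid_forward(F, ...), np/npx operators are called directly, and infer_shape() completes deferred parameter shapes once the input shape is known.

    import mxnet as mx
    from mxnet.gluon import HybridBlock, Parameter

    class ScaledDense(HybridBlock):  # hypothetical example block, not from this PR
        def __init__(self, units):
            super().__init__()
            self._units = units
            # -1 marks in_units as unknown until infer_shape fills it in
            self.weight = Parameter('weight', shape=(units, -1),
                                    allow_deferred_init=True)

        def forward(self, x):
            # np/npx operators are used directly; there is no F argument any more
            return mx.npx.fully_connected(x, self.weight.data(x.ctx),
                                          no_bias=True, num_hidden=self._units)

        def infer_shape(self, x, *args):
            # called once the input shape is known, to complete the parameter shape
            self.weight.shape = (self._units, x.shape[-1])

    net = ScaledDense(16)
    net.initialize()
    net.hybridize()
    out = net(mx.np.ones((2, 8)))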

Some skipped tests

  • tests/python/mkl/subgraphs/test_conv_subgraph.py::test_pos_concat_scale_align
    • Reason: the quantization scale doesn't align for numpy operators in NumPy mode
    def check_qsym_scale_align(qsym):
      assert ''.join(qsym.attr_dict().keys()).find('quantized_sg_mkldnn_conv') != -1
      init = False
      for k, v in qsym.attr_dict().items():
        if k.find('quantized_sg_mkldnn_conv') != -1:
          assert 'min_calib_range' in v
          assert 'max_calib_range' in v
          if not init:
            min_calib_range = v['min_calib_range']
            max_calib_range = v['max_calib_range']
            init = True
          else:
>           assert min_calib_range == v['min_calib_range']
E           AssertionError
  • tests/python/mkl/subgraphs/test_fc_subgraph.py::test_fc_eltwise
    • Reason: Operators square, square_root, abs, and exp cannot be found in NumPy mode
    def check_fusion(net_original, data_shape, attrs_dict, check_fp32_fusion=True, check_quantization=True,
                     out_types=['uint8', 'int8', 'auto'], dedup_subgraph=True):
      net_original.initialize()
      net_original.hybridize(static_alloc=False, static_shape=False)
      data = mx.np.random.uniform(size=data_shape, dtype='float32', ctx=mx.current_context())
      net_original(data)
      net_fusion = copy.copy(net_original)
      sym, params = net_original.export(None)
    
      if check_fp32_fusion:
        data_min = -1.0
        data_max = 1.0
        if ''.join(sym.get_internals().list_outputs()).find('sqrt') != -1:
          check_quantization = False
          data_min = 0
    
        sym_sg = sym.optimize_for(SG_PASS_NAME, dedup_subgraph=dedup_subgraph, skip_infer=True)
        for name, attrs in attrs_dict.items():
          if name in config:
            op_name = config[name][OP_NAME]
          else:
            op_name = name
          assert ''.join(sym_sg.get_internals().list_outputs()).find(op_name) != -1
          if len(attrs):
              found = False
              for k, v in sym_sg.attr_dict().items():
                if k.find(op_name) != -1:
                  found = True
                  for attr_name, attr_value in attrs.items():
                    assert v[attr_name].lower() == attr_value.lower()
>             assert found
E             AssertionError

@barry-jin barry-jin requested a review from szha as a code owner May 12, 2021 19:29
@mxnet-bot

Hey @barry-jin , Thanks for submitting the PR
All tests are already queued to run once. If tests fail, you can trigger one or more tests again with the following commands:

  • To trigger all jobs: @mxnet-bot run ci [all]
  • To trigger specific jobs: @mxnet-bot run ci [job1, job2]

CI supported jobs: [website, windows-cpu, windows-gpu, miscellaneous, centos-gpu, edge, unix-gpu, sanity, unix-cpu, clang, centos-cpu]


Note:
Only the following 3 categories can trigger CI: PR Author, MXNet Committer, Jenkins Admin.
All CI tests must pass before the PR can be merged.

@mseth10 mseth10 added the pr-work-in-progress PR is still work in progress label May 12, 2021
@mseth10 mseth10 added pr-awaiting-testing PR is reviewed and waiting CI build and test and removed pr-work-in-progress PR is still work in progress labels Jun 18, 2021
@barry-jin
Contributor Author

@mxnet-bot run ci [all]

@mxnet-bot

Jenkins CI successfully triggered : [clang, centos-gpu, unix-cpu, website, sanity, edge, centos-cpu, unix-gpu, windows-cpu, miscellaneous, windows-gpu]

@mseth10 mseth10 added pr-work-in-progress PR is still work in progress pr-awaiting-testing PR is reviewed and waiting CI build and test and removed pr-awaiting-testing PR is reviewed and waiting CI build and test pr-work-in-progress PR is still work in progress labels Jun 18, 2021
Contributor

@TristonC TristonC left a comment


I vote for the forward interface compared to the hybrid_forward interface. Thanks.

@@ -24,8 +24,8 @@
from time import time

import mxnet as mx
import numpy as np
from mxnet import gluon
import numpy as onp
Contributor


Does the 'o' in onp stand for 'original', as in original numpy? It could be confusing, since np is well known as the short name for numpy.

Contributor Author

@barry-jin barry-jin Jun 19, 2021


Yes, the 'o' in onp stands for 'official', and it is used to distinguish official NumPy from MXNet numpy. Usually, users will do
from mxnet import np and build their models with NumPy operators from MXNet. This provides a NumPy-compatible coding experience in MXNet.
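For reference, a small illustrative snippet of that convention (not from this PR):

    from mxnet import np   # MXNet's NumPy-compatible interface
    import numpy as onp    # official NumPy, kept distinct as 'onp'

    a = np.ones((2, 3))    # MXNet ndarray, usable in Gluon blocks and on GPU
    b = onp.ones((2, 3))   # plain NumPy array on the host
    c = a.asnumpy()        # cross over from MXNet numpy to official NumPy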

tock = time()
times.append((tock - tick) * 1000.0)
times = times[args.warmup_rounds: ]
print("Time used: mean = %.3f ms, std = %.3f ms" % (np.mean(times), np.std(times)))
print("Time used: mean = %.3f ms, std = %.3f ms" % (onp.mean(times), onp.std(times)))
Contributor


Will mxnet np provide the mean and std functions?

Contributor Author


Yes. mean and std operators are implemented as mxnet.np.mean and mxnet.np.std.
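For example (a quick illustrative sketch, not taken from the benchmark script):

    import mxnet as mx
    times = mx.np.array([1.2, 1.5, 1.1])
    print(mx.np.mean(times), mx.np.std(times))  # mxnet.np.mean / mxnet.np.std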

@@ -56,41 +56,38 @@ The rest of methods of the `Block` class are already implemented, and majority o

Looking into implementation of [existing layers](https://mxnet.apache.org/api/python/gluon/nn.html), one may find that more often a block inherits from a [HybridBlock](https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/gluon/block.py#L428), instead of directly inheriting from `Block`.

The reason for that is that `HybridBlock` allows to write custom layers that can be used in imperative programming as well as in symbolic programming. It is convinient to support both ways, because the imperative programming eases the debugging of the code and the symbolic one provides faster execution speed. You can learn more about the difference between symbolic vs. imperative programming from [this article](https://mxnet.apache.org/api/architecture/overview.html).
The reason for that is that `HybridBlock` allows to write custom layers in imperative programming style, while computing in a symbolic way. It unifies the flexibility of imperative programming with the performance benefits of symbolic programming. You can learn more about the difference between symbolic vs. imperative programming from [this article](https://mxnet.apache.org/api/architecture/overview.html).
Contributor


Should be "between ... and" for "the difference between symbolic vs. imperative programming".

Usually, a layer has a set of associated parameters, sometimes also referred as weights. This is an internal state of a layer. Most often, these parameters are the ones, that we want to learn during backpropogation step, but sometimes these parameters might be just constants we want to use during forward pass.

All parameters of a block are stored and accessed via [ParameterDict](https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/gluon/parameter.py#L508) class. This class helps with initialization, updating, saving and loading of the parameters. Each layer can have multiple set of parameters, and all of them can be stored in a single instance of the `ParameterDict` class. On a block level, the instance of the `ParameterDict` class is accessible via `self.params` field, and outside of a block one can access all parameters of the network via [collect_params()](https://mxnet.apache.org/api/python/gluon/gluon.html#mxnet.gluon.Block.collect_params) method called on a `container`. `ParameterDict` uses [Parameter](https://mxnet.apache.org/api/python/gluon/gluon.html#mxnet.gluon.Parameter) class to represent parameters inside of Apache MxNet neural network. If parameter doesn't exist, trying to get a parameter via `self.params` will create it automatically.
Usually, a layer has a set of associated parameters, sometimes also referred as weights. This is an internal state of a layer. Most often, these parameters are the ones, that we want to learn during backpropogation step, but sometimes these parameters might be just constants we want to use during forward pass. The parameters are usually represented as [Parameter](https://mxnet.apache.org/api/python/gluon/gluon.html#mxnet.gluon.Parameter) class inside of Apache MxNet neural network.
Contributor


MxNet -> MXNet


for key, value in hybridlayer_params.items():
print('{} = {}\n'.format(key, value.data()))

net = gluon.nn.HybridSequential() # Define a Neural Network as a sequence of hybrid blocks
net.add(Dense(5)) # Add Dense layer with 5 neurons
net.add(NormalizationHybridLayer(hidden_units=5,
scales = nd.array([2]))) # Add our custom layer
scales = np.array([2]))) # Add our custom layer
Contributor


What about just saying # Add a custom layer?

@@ -1399,6 +1399,14 @@ MXNET_DLL int MXNDArrayGetDeferredComputeSymbol(NDArrayHandle *output_handles,
int num_outputs,
SymbolHandle *out);

/*!
* \brief Clear the info node associated with the arrays.
Contributor


The brief is not obvious given the function name; it is more about how the deferred compute is handled.

* \brief Clear the info node associated with the arrays.
* \param arrays array handles of arrays
* \param num number of arrays
* \return 0 when success, -1 when failure happens
Contributor


-1 otherwise

@@ -1039,7 +1038,9 @@ def forward(self, x):
"""
def __init__(self):
super(HybridBlock, self).__init__()
self._v2 = inspect.unwrap(self.hybrid_forward.__func__) is HybridBlock.hybrid_forward
assert hasattr(self, "hybrid_forward") is False, (
"Starting from MXNet2.0, Gluon2.0 with forward interface will be used instead of "
Contributor


Do both MXNet 2.0 and Gluon 2.0 need to be mentioned at the same time? Propose:
'forward' instead of 'hybrid_forward' interface needs to be used starting from Gluon 2.0. ......
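A hedged sketch of what the new assertion means for user code (the class below is made up):

    import mxnet as mx

    class OldStyle(mx.gluon.HybridBlock):
        def hybrid_forward(self, F, x):  # Gluon 1.x style
            return x

    OldStyle()  # raises AssertionError directing users to the forward interface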

@mseth10 mseth10 added pr-awaiting-testing PR is reviewed and waiting CI build and test and removed pr-work-in-progress PR is still work in progress labels Jun 19, 2021
@barry-jin
Contributor Author

@TristonC Thanks for your suggestions on improving the documentation!

@mseth10 mseth10 added pr-work-in-progress PR is still work in progress and removed pr-awaiting-testing PR is reviewed and waiting CI build and test labels Jun 19, 2021
@barry-jin
Contributor Author

@mxnet-bot run ci [centos-cpu]

@mxnet-bot

Jenkins CI successfully triggered : [centos-cpu]

@mseth10 mseth10 added pr-awaiting-testing PR is reviewed and waiting CI build and test pr-awaiting-review PR is waiting for code review and removed pr-work-in-progress PR is still work in progress pr-awaiting-testing PR is reviewed and waiting CI build and test labels Jun 21, 2021
@leezu leezu merged commit 7152685 into apache:master Jun 21, 2021
@barry-jin barry-jin deleted the issue-19138 branch August 4, 2021 21:27