# Simplify mxnet.gluon Block APIs (#18413)
## Motivations
Currently the implementation of `mxnet.gluon.block` is not very Pythonic and contains many redundancies.

### 1. Overlaps between Block._params and Block._reg_params
When we want to define a custom model, we currently need to write code like the following:
```
class Net(nn.HybridBlock):
    def __init__(self, **kwargs):
        super(Net, self).__init__(**kwargs)
        with self.name_scope():
            self.hidden1 = nn.Dense(256, activation='relu')
            self.a = self.params.get('a', shape=(1,))
```
There are several shortcomings when using this form of registration:
a. Adding the parameter 'a' records it twice, in both `self._params` and `self._reg_params`, which is redundant. There is also a discrepancy inside Block:
      i. the method `collect_params` uses `_params` to gather all parameters,
     ii. while the method `_collect_params_with_prefix` (and hence `load_parameters`) uses `_reg_params`.
b. If we do not wrap child blocks in `with self.name_scope():`, they get wrong name scopes. In the following example, the parameters of `self.hidden1` miss the `hybridnet0_` prefix, so `collect_params` does not report them under the names we would expect:
```
class HybridNet(nn.HybridBlock):
    def __init__(self, **kwargs):
        super(HybridNet, self).__init__(**kwargs)
        self.hidden1 = nn.Dense(256, activation='relu')
        with self.name_scope():
            self.hidden2 = nn.Dense(10, activation='relu')

    def hybrid_forward(self, F, x):
        x = self.hidden2(self.hidden1(x))
        return x
    
>>> net = HybridNet()
>>> net.initialize()
>>> print(net.collect_params())
hybridnet0_ (
  Parameter dense0_weight (shape=(256, -1), dtype=float32)
  Parameter dense0_bias (shape=(256,), dtype=float32)
  Parameter hybridnet0_dense0_weight (shape=(10, -1), dtype=float32)
  Parameter hybridnet0_dense0_bias (shape=(10,), dtype=float32)
)
```
The example above also shows that the parameter names are unrelated to the attribute names, which is not intuitive.

In summary, we find that using `name_scope` and `ParameterDict` is not user-friendly. We therefore plan to remove these redundancies and simplify the definition of child blocks and parameters, for example:
```
class Net(nn.HybridBlock):
    def __init__(self, **kwargs):
        super(Net, self).__init__(**kwargs)
        self.hidden1 = nn.Dense(256, activation='relu')
        self.a = gluon.parameter.Parameter(name="a", shape=(1,))
```
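As a rough sketch of the intended behaviour (using the `Net` class defined above; the exact printed output is illustrative), parameters registered this way are picked up automatically and addressed by structural names:
```
net = Net()
net.initialize()
# Parameters are keyed by structural names derived from attribute names,
# e.g. 'hidden1.weight', 'hidden1.bias' and 'a' (see collect_params() below).
print(net.collect_params())
```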

### 2. Parameter sharing
Currently, the `params` argument of Block's constructor is used for parameter sharing. This means that shared parameters are already recorded in `self._params.shared` before the Block's `__init__` runs, and Block currently forbids overriding parameter attributes.
We find this inconvenient. The most common way to share parameters is the PyTorch style:
```
self.hidden1.weight = self.hidden2.weight
```
Note, however, that when a `HybridBlock` has already been hybridized, overriding its parameters should not be allowed; the user should un-hybridize the block first.
To further allow sharing parameters recursively, we plan to add an API:
```
    def share_parameters(self, params: Dict):
```
We plan to use the structure-based form (as used in `_collect_params_with_prefix()`) to address each parameter recursively. For example, `self.hidden1.weight` is denoted as `hidden1.weight`.
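A hedged sketch of how `share_parameters` could be used with such structure-based keys; the model class and attribute names are the hypothetical ones from the examples above:
```
net1 = Net()
net1.initialize()
net2 = Net()
net2.initialize()
# Share net1's first dense layer with net2, addressing it by structural names.
net2.share_parameters({'hidden1.weight': net1.hidden1.weight,
                       'hidden1.bias': net1.hidden1.bias})
```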

In summary, we plan to make the following improvements:

1. Remove the `prefix` and `params` arguments from the `__init__` function.
2. Remove the use of `self._params` (a `ParameterDict`) in Block.
3. Allow overriding parameter attributes when the block is not hybridized.
4. Add the method `share_parameters` to recursively share parameters in child blocks.

## Parameter naming
Once a parameter is created, `param.name` does not change in subsequent operations. It has the form `param_{uuid4}_{name}`, where `name` comes from the `__init__` argument; `name` is optional and defaults to `weight`. It is mainly used to indicate which default initialization should be used.
We use `param.name` as the name of a parameter's symbol representation.
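For illustration, a minimal sketch of the naming scheme (the concrete uuid differs on every run):
```
from mxnet import gluon

w = gluon.Parameter('weight', shape=(2, 2))
print(w.name)   # something like 'param_<uuid4>_weight'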
## collect_params()
It returns a `dict`, where the keys are structural names of parameters, like 
`{'hidden1.weight': Parameter (shape=(3, -1), dtype=float32), 'hidden1.bias': Parameter (shape=(3,), dtype=float32)}`
Note that we can use `.` as the linking character here because the structure-based naming scheme is no longer used in the symbol representation.
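Since the result is a plain `dict` keyed by structural names, it can still be passed to a `Trainer` or filtered by regex as before; a short sketch assuming a `net` defined as above:
```
from mxnet import gluon

params = net.collect_params()
print(list(params))                            # e.g. ['hidden1.weight', 'hidden1.bias', 'a']
trainer = gluon.Trainer(params, 'sgd', {'learning_rate': 0.1})
weights_only = net.collect_params('.*weight')  # regex selection over the new keys
```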

## Save and Load
For `HybridBlock`, there are two ways to save and load parameters:
### save_parameters() and load_parameters()
`save_parameters()` saves parameters under their structural names, and such files should be loaded with `load_parameters()`, which matches parameters based on the model's structure.
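A minimal sketch of this structure-based round trip, using a stand-alone `Dense` block purely for illustration:
```
from mxnet.gluon import nn

net = nn.Dense(3, in_units=4)
net.initialize()
net.save_parameters('dense.params')    # stored under structural names, here 'weight' and 'bias'

new_net = nn.Dense(3, in_units=4)      # must have the same structure
new_net.load_parameters('dense.params')
```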
### HybridBlock.export and SymbolBlock.imports
`export` saves parameters under `param.name` only, without structural names. The resulting param file should be loaded with `SymbolBlock.imports`.
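A sketch of this path, assuming a hybridized `net` with a single `data` input of shape (1, 4):
```
import mxnet as mx
from mxnet.gluon import SymbolBlock

net.hybridize()
net(mx.nd.ones((1, 4)))          # one forward pass so the cached graph exists
net.export('model', epoch=0)     # writes model-symbol.json and model-0000.params,
                                 # parameters keyed by param.name

deserialized = SymbolBlock.imports('model-symbol.json', ['data'], 'model-0000.params')
```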
## SymbolBlock
When using `SymbolBlock.imports`, the keys in `self.params` are the loaded parameters' names, i.e. `param.name`.
When using `SymbolBlock(outputs, inputs, params=None)` and passing, for example, `params=net.collect_params()`, the keys in `self.params` are the structural names of `net`'s parameters (the keys of `net.collect_params()`). This is typically used when a `SymbolBlock` is a child block of another `HybridBlock`. Otherwise, the keys in `self.params` are the loaded parameters' names, i.e. `param.name`.
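A hedged sketch of the child-block case mentioned above; `backbone` stands for a previously imported `SymbolBlock` (e.g. `deserialized` from the export example), and the wrapper class is hypothetical:
```
class Wrapper(nn.HybridBlock):
    def __init__(self, backbone, **kwargs):
        super(Wrapper, self).__init__(**kwargs)
        self.backbone = backbone        # a SymbolBlock, registered like any other child block
        self.output = nn.Dense(10)

    def hybrid_forward(self, F, x):
        return self.output(self.backbone(x))

model = Wrapper(deserialized)
# The backbone's parameters are expected to appear under structural keys
# such as 'backbone.<param.name>' in model.collect_params().
print(list(model.collect_params()))
```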
acphile committed Jun 19, 2020
1 parent 5585606 commit cb54a4a
Showing 54 changed files with 1,746 additions and 2,482 deletions.
2 changes: 1 addition & 1 deletion example/gluon/style_transfer/main.py
@@ -24,7 +24,7 @@
from PIL import Image

from mxnet import autograd, gluon
from mxnet.gluon import nn, Block, HybridBlock, Parameter, ParameterDict
from mxnet.gluon import nn, Block, HybridBlock, Parameter
import mxnet.ndarray as F

import net
5 changes: 3 additions & 2 deletions python/mxnet/contrib/amp/amp.py
@@ -686,7 +686,8 @@ def convert_hybrid_block(block, target_dtype="float16", target_dtype_ops=None,
# If dtype for the param was set in the json, cast the
# param to this dtype
attr_dict = converted_sym.attr_dict()
for name, param in block.collect_params().items():
for param in block.collect_params().values():
name = param.name
if name in arg_names:
arg_dict['arg:%s'%name] = param._reduce()
if name in attr_dict and "__dtype__" in attr_dict[name]:
@@ -719,7 +720,7 @@ def convert_hybrid_block(block, target_dtype="float16", target_dtype_ops=None,
if aux_param_name in arg_dict and param.dtype != arg_dict[aux_param_name].dtype:
param.cast(arg_dict[aux_param_name].dtype)

ret.collect_params().load_dict(arg_dict, ctx=ctx)
ret.load_dict(arg_dict, ctx=ctx)
return ret

def list_lp16_ops(target_dtype):
455 changes: 253 additions & 202 deletions python/mxnet/gluon/block.py

Large diffs are not rendered by default.

297 changes: 148 additions & 149 deletions python/mxnet/gluon/contrib/cnn/conv_layers.py
@@ -23,6 +23,7 @@

from .... import symbol
from ...block import HybridBlock
from ...parameter import Parameter
from ....base import numeric_types
from ...nn import Activation

@@ -103,80 +104,79 @@ def __init__(self, channels, kernel_size=(1, 1), strides=(1, 1), padding=(0, 0),
num_deformable_group=1, layout='NCHW', use_bias=True, in_channels=0, activation=None,
weight_initializer=None, bias_initializer='zeros',
offset_weight_initializer='zeros', offset_bias_initializer='zeros', offset_use_bias=True,
op_name='DeformableConvolution', adj=None, prefix=None, params=None):
super(DeformableConvolution, self).__init__(prefix=prefix, params=params)
with self.name_scope():
self._channels = channels
self._in_channels = in_channels

assert layout in ('NCHW', 'NHWC'), "Only supports 'NCHW' and 'NHWC' layout for now"
if isinstance(kernel_size, numeric_types):
kernel_size = (kernel_size,) * 2
if isinstance(strides, numeric_types):
strides = (strides,) * len(kernel_size)
if isinstance(padding, numeric_types):
padding = (padding,) * len(kernel_size)
if isinstance(dilation, numeric_types):
dilation = (dilation,) * len(kernel_size)
self._op_name = op_name

offset_channels = 2 * kernel_size[0] * kernel_size[1] * num_deformable_group
self._kwargs_offset = {
'kernel': kernel_size, 'stride': strides, 'dilate': dilation,
'pad': padding, 'num_filter': offset_channels, 'num_group': groups,
'no_bias': not offset_use_bias, 'layout': layout}

self._kwargs_deformable_conv = {
'kernel': kernel_size, 'stride': strides, 'dilate': dilation,
'pad': padding, 'num_filter': channels, 'num_group': groups,
'num_deformable_group': num_deformable_group,
'no_bias': not use_bias, 'layout': layout}

if adj:
self._kwargs_offset['adj'] = adj
self._kwargs_deformable_conv['adj'] = adj

dshape = [0] * (len(kernel_size) + 2)
dshape[layout.find('N')] = 1
dshape[layout.find('C')] = in_channels

op = getattr(symbol, 'Convolution')
offset = op(symbol.var('data', shape=dshape), **self._kwargs_offset)

offsetshapes = offset.infer_shape_partial()[0]

self.offset_weight = self.params.get('offset_weight', shape=offsetshapes[1],
init=offset_weight_initializer,
allow_deferred_init=True)

if offset_use_bias:
self.offset_bias = self.params.get('offset_bias', shape=offsetshapes[2],
init=offset_bias_initializer,
allow_deferred_init=True)
else:
self.offset_bias = None

deformable_conv_weight_shape = [0] * (len(kernel_size) + 2)
deformable_conv_weight_shape[0] = channels
deformable_conv_weight_shape[2] = kernel_size[0]
deformable_conv_weight_shape[3] = kernel_size[1]

self.deformable_conv_weight = self.params.get('deformable_conv_weight',
shape=deformable_conv_weight_shape,
init=weight_initializer,
allow_deferred_init=True)

if use_bias:
self.deformable_conv_bias = self.params.get('deformable_conv_bias', shape=(channels,),
init=bias_initializer,
allow_deferred_init=True)
else:
self.deformable_conv_bias = None

if activation:
self.act = Activation(activation, prefix=activation + '_')
else:
self.act = None
op_name='DeformableConvolution', adj=None):
super(DeformableConvolution, self).__init__()
self._channels = channels
self._in_channels = in_channels

assert layout in ('NCHW', 'NHWC'), "Only supports 'NCHW' and 'NHWC' layout for now"
if isinstance(kernel_size, numeric_types):
kernel_size = (kernel_size,) * 2
if isinstance(strides, numeric_types):
strides = (strides,) * len(kernel_size)
if isinstance(padding, numeric_types):
padding = (padding,) * len(kernel_size)
if isinstance(dilation, numeric_types):
dilation = (dilation,) * len(kernel_size)
self._op_name = op_name

offset_channels = 2 * kernel_size[0] * kernel_size[1] * num_deformable_group
self._kwargs_offset = {
'kernel': kernel_size, 'stride': strides, 'dilate': dilation,
'pad': padding, 'num_filter': offset_channels, 'num_group': groups,
'no_bias': not offset_use_bias, 'layout': layout}

self._kwargs_deformable_conv = {
'kernel': kernel_size, 'stride': strides, 'dilate': dilation,
'pad': padding, 'num_filter': channels, 'num_group': groups,
'num_deformable_group': num_deformable_group,
'no_bias': not use_bias, 'layout': layout}

if adj:
self._kwargs_offset['adj'] = adj
self._kwargs_deformable_conv['adj'] = adj

dshape = [0] * (len(kernel_size) + 2)
dshape[layout.find('N')] = 1
dshape[layout.find('C')] = in_channels

op = getattr(symbol, 'Convolution')
offset = op(symbol.var('data', shape=dshape), **self._kwargs_offset)

offsetshapes = offset.infer_shape_partial()[0]

self.offset_weight = Parameter('offset_weight', shape=offsetshapes[1],
init=offset_weight_initializer,
allow_deferred_init=True)

if offset_use_bias:
self.offset_bias = Parameter('offset_bias', shape=offsetshapes[2],
init=offset_bias_initializer,
allow_deferred_init=True)
else:
self.offset_bias = None

deformable_conv_weight_shape = [0] * (len(kernel_size) + 2)
deformable_conv_weight_shape[0] = channels
deformable_conv_weight_shape[2] = kernel_size[0]
deformable_conv_weight_shape[3] = kernel_size[1]

self.deformable_conv_weight = Parameter('deformable_conv_weight',
shape=deformable_conv_weight_shape,
init=weight_initializer,
allow_deferred_init=True)

if use_bias:
self.deformable_conv_bias = Parameter('deformable_conv_bias', shape=(channels,),
init=bias_initializer,
allow_deferred_init=True)
else:
self.deformable_conv_bias = None

if activation:
self.act = Activation(activation)
else:
self.act = None

def hybrid_forward(self, F, x, offset_weight, deformable_conv_weight, offset_bias=None, deformable_conv_bias=None):
if offset_bias is None:
@@ -296,81 +296,80 @@ def __init__(self, channels, kernel_size=(1, 1), strides=(1, 1), padding=(0, 0),
num_deformable_group=1, layout='NCHW', use_bias=True, in_channels=0, activation=None,
weight_initializer=None, bias_initializer='zeros',
offset_weight_initializer='zeros', offset_bias_initializer='zeros', offset_use_bias=True,
op_name='ModulatedDeformableConvolution', adj=None, prefix=None, params=None):
super(ModulatedDeformableConvolution, self).__init__(prefix=prefix, params=params)
with self.name_scope():
self._channels = channels
self._in_channels = in_channels

assert layout in ('NCHW', 'NHWC'), "Only supports 'NCHW' and 'NHWC' layout for now"
if isinstance(kernel_size, numeric_types):
kernel_size = (kernel_size,) * 2
if isinstance(strides, numeric_types):
strides = (strides,) * len(kernel_size)
if isinstance(padding, numeric_types):
padding = (padding,) * len(kernel_size)
if isinstance(dilation, numeric_types):
dilation = (dilation,) * len(kernel_size)
self._op_name = op_name

offset_channels = num_deformable_group * 3 * kernel_size[0] * kernel_size[1]
self.offset_split_index = num_deformable_group * 2 * kernel_size[0] * kernel_size[1]
self._kwargs_offset = {
'kernel': kernel_size, 'stride': strides, 'dilate': dilation,
'pad': padding, 'num_filter': offset_channels, 'num_group': groups,
'no_bias': not offset_use_bias, 'layout': layout}

self._kwargs_deformable_conv = {
'kernel': kernel_size, 'stride': strides, 'dilate': dilation,
'pad': padding, 'num_filter': channels, 'num_group': groups,
'num_deformable_group': num_deformable_group,
'no_bias': not use_bias, 'layout': layout}

if adj:
self._kwargs_offset['adj'] = adj
self._kwargs_deformable_conv['adj'] = adj

deformable_conv_weight_shape = [0] * (len(kernel_size) + 2)
deformable_conv_weight_shape[0] = channels
deformable_conv_weight_shape[2] = kernel_size[0]
deformable_conv_weight_shape[3] = kernel_size[1]

self.deformable_conv_weight = self.params.get('deformable_conv_weight',
shape=deformable_conv_weight_shape,
init=weight_initializer,
allow_deferred_init=True)

if use_bias:
self.deformable_conv_bias = self.params.get('deformable_conv_bias', shape=(channels,),
init=bias_initializer,
allow_deferred_init=True)
else:
self.deformable_conv_bias = None

dshape = [0] * (len(kernel_size) + 2)
dshape[layout.find('N')] = 1
dshape[layout.find('C')] = in_channels

op = getattr(symbol, 'Convolution')
offset = op(symbol.var('data', shape=dshape), **self._kwargs_offset)

offsetshapes = offset.infer_shape_partial()[0]

self.offset_weight = self.params.get('offset_weight', shape=offsetshapes[1],
init=offset_weight_initializer,
allow_deferred_init=True)

if offset_use_bias:
self.offset_bias = self.params.get('offset_bias', shape=offsetshapes[2],
init=offset_bias_initializer,
allow_deferred_init=True)
else:
self.offset_bias = None

if activation:
self.act = Activation(activation, prefix=activation + '_')
else:
self.act = None
op_name='ModulatedDeformableConvolution', adj=None):
super(ModulatedDeformableConvolution, self).__init__()
self._channels = channels
self._in_channels = in_channels

assert layout in ('NCHW', 'NHWC'), "Only supports 'NCHW' and 'NHWC' layout for now"
if isinstance(kernel_size, numeric_types):
kernel_size = (kernel_size,) * 2
if isinstance(strides, numeric_types):
strides = (strides,) * len(kernel_size)
if isinstance(padding, numeric_types):
padding = (padding,) * len(kernel_size)
if isinstance(dilation, numeric_types):
dilation = (dilation,) * len(kernel_size)
self._op_name = op_name

offset_channels = num_deformable_group * 3 * kernel_size[0] * kernel_size[1]
self.offset_split_index = num_deformable_group * 2 * kernel_size[0] * kernel_size[1]
self._kwargs_offset = {
'kernel': kernel_size, 'stride': strides, 'dilate': dilation,
'pad': padding, 'num_filter': offset_channels, 'num_group': groups,
'no_bias': not offset_use_bias, 'layout': layout}

self._kwargs_deformable_conv = {
'kernel': kernel_size, 'stride': strides, 'dilate': dilation,
'pad': padding, 'num_filter': channels, 'num_group': groups,
'num_deformable_group': num_deformable_group,
'no_bias': not use_bias, 'layout': layout}

if adj:
self._kwargs_offset['adj'] = adj
self._kwargs_deformable_conv['adj'] = adj

deformable_conv_weight_shape = [0] * (len(kernel_size) + 2)
deformable_conv_weight_shape[0] = channels
deformable_conv_weight_shape[2] = kernel_size[0]
deformable_conv_weight_shape[3] = kernel_size[1]

self.deformable_conv_weight = Parameter('deformable_conv_weight',
shape=deformable_conv_weight_shape,
init=weight_initializer,
allow_deferred_init=True)

if use_bias:
self.deformable_conv_bias = Parameter('deformable_conv_bias', shape=(channels,),
init=bias_initializer,
allow_deferred_init=True)
else:
self.deformable_conv_bias = None

dshape = [0] * (len(kernel_size) + 2)
dshape[layout.find('N')] = 1
dshape[layout.find('C')] = in_channels

op = getattr(symbol, 'Convolution')
offset = op(symbol.var('data', shape=dshape), **self._kwargs_offset)

offsetshapes = offset.infer_shape_partial()[0]

self.offset_weight = Parameter('offset_weight', shape=offsetshapes[1],
init=offset_weight_initializer,
allow_deferred_init=True)

if offset_use_bias:
self.offset_bias = Parameter('offset_bias', shape=offsetshapes[2],
init=offset_bias_initializer,
allow_deferred_init=True)
else:
self.offset_bias = None

if activation:
self.act = Activation(activation)
else:
self.act = None

def hybrid_forward(self, F, x, offset_weight, deformable_conv_weight, offset_bias=None, deformable_conv_bias=None):
if offset_bias is None:
14 changes: 7 additions & 7 deletions python/mxnet/gluon/contrib/data/vision/dataloader.py
@@ -92,7 +92,7 @@ def create_image_augment(data_shape, resize=0, rand_crop=False, rand_resize=Fals
"""
if inter_method == 10:
inter_method = np.random.randint(0, 5)
augmenter = HybridSequential('default_img_augment_')
augmenter = HybridSequential()
if resize > 0:
augmenter.add(transforms.image.Resize(resize, interpolation=inter_method))
crop_size = (data_shape[2], data_shape[1])
@@ -220,9 +220,9 @@ def __init__(self, batch_size, data_shape, path_imgrec=None, path_imglist=None,
augmenter = create_image_augment(data_shape, **kwargs)
elif isinstance(aug_list, list):
if all([isinstance(a, HybridBlock) for a in aug_list]):
augmenter = HybridSequential('user_img_augment_')
augmenter = HybridSequential()
else:
augmenter = Sequential('user_img_augment_')
augmenter = Sequential()
for aug in aug_list:
augmenter.add(aug)
elif isinstance(aug_list, Block):
@@ -316,7 +316,7 @@ def create_bbox_augment(data_shape, rand_crop=0, rand_pad=0, rand_gray=0,
"""
if inter_method == 10:
inter_method = np.random.randint(0, 5)
augmenter = Sequential('default_bbox_aug_')
augmenter = Sequential()
if rand_crop > 0:
augmenter.add(bbox.ImageBboxRandomCropWithConstraints(
p=rand_crop, min_scale=area_range[0], max_scale=1.0,
@@ -439,17 +439,17 @@ def __init__(self, batch_size, data_shape, path_imgrec=None, path_imglist=None,
augmenter = create_bbox_augment(data_shape, **kwargs)
elif isinstance(aug_list, list):
if all([isinstance(a, HybridBlock) for a in aug_list]):
augmenter = HybridSequential('user_bbox_augment_')
augmenter = HybridSequential()
else:
augmenter = Sequential('user_bbox_augment_')
augmenter = Sequential()
for aug in aug_list:
augmenter.add(aug)
elif isinstance(aug_list, Block):
augmenter = aug_list
else:
raise ValueError('aug_list must be a list of Blocks')
augmenter.hybridize()
wrapper_aug = Sequential('wrapper_bbox_aug_')
wrapper_aug = Sequential()
wrapper_aug.add(BboxLabelTransform(coord_normalized))
wrapper_aug.add(augmenter)

