Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Update Hybridize to use kwargs as backend opts
Browse files Browse the repository at this point in the history
Signed-off-by: Serge Panev <spanev@nvidia.com>
  • Loading branch information
Kh4L committed Nov 5, 2020
1 parent 7373172 commit 6708b59
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -1099,7 +1099,7 @@ def optimize_for(self, x, *args, backend=None, clear=True, static_alloc=False, s
"""

# do hybrize API call
self.hybridize(True, backend, kwargs, clear, static_alloc=static_alloc, static_shape=static_shape)
self.hybridize(True, backend, clear, static_alloc=static_alloc, static_shape=static_shape, **kwargs)

# do part of forward API call
has_symbol, has_ndarray, ctx_set, _ = _gather_type_ctx_info([x] + list(args))
Expand Down Expand Up @@ -1137,7 +1137,8 @@ def register_child(self, block, name=None):
super(HybridBlock, self).register_child(block, name)
self._clear_cached_op()

def hybridize(self, active=True, backend=None, backend_opts=None, clear=True, **kwargs):
def hybridize(self, active=True, backend=None, clear=True,
static_alloc=False, static_shape=False, **kwargs):
"""Activates or deactivates :py:class:`HybridBlock` s recursively. Has no effect on
non-hybrid children.
Expand All @@ -1156,23 +1157,25 @@ def hybridize(self, active=True, backend=None, backend_opts=None, clear=True, **
Optimize for invariant input shapes between iterations. Must also
set static_alloc to True. Change of input shapes is still allowed
but slower.
**kwargs : dict
Optional backend options when hybridized is called with an optimization backend.
"""

self._backend = backend
if backend_opts is not None:
assert isinstance(backend_opts, dict), \
"HybridBlock hybridize requires backend_opts to be a dictionary."
self._backend_opts = backend_opts
self._backend_opts = kwargs

self._active = active
self._flags = list(kwargs.items())
self._flags = [("static_alloc", static_alloc), ("static_shape", static_shape)]
if clear:
self._clear_cached_op()
if active and self._forward_hooks or self._forward_pre_hooks:
warnings.warn('"{block}" is being hybridized while still having forward hook/pre-hook. '
'If "{block}" is a child of HybridBlock, the hooks will not take effect.'
.format(block=self))
super(HybridBlock, self).hybridize(active, **kwargs)
super(HybridBlock, self).hybridize(active, static_alloc=static_alloc, static_shape=static_shape)

def cast(self, dtype):
self._clear_cached_op()
Expand Down

0 comments on commit 6708b59

Please sign in to comment.