Skip to content

Commit

Permalink
Update utils
Browse files Browse the repository at this point in the history
  • Loading branch information
ericup committed May 9, 2022
1 parent eeb7bfe commit f25d2bb
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions celldetection/util/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,9 @@ def inject_extra_repr_(module, name, fn):
Note:
This is an inplace operation.
Notes:
- This op may impair pickling.
Args:
module: Module.
name: Name of the injected function (only used to avoid duplicate injection).
Expand All @@ -506,7 +509,7 @@ def wrap_module_(module: nn.Module, class_or_tuple, wrapper, recursive=True, **k
handle[k] = wrapper(handle[k], **kwargs)


def spectral_norm_(module, class_or_tuple=nn.Conv2d, recursive=True, name='weight', **kwargs):
def spectral_norm_(module, class_or_tuple=nn.Conv2d, recursive=True, name='weight', add_repr=False, **kwargs):
"""Spectral normalization.
Applies spectral normalization to parameters of all occurrences of ``class_or_tuple`` in the given module.
Expand All @@ -522,6 +525,8 @@ def spectral_norm_(module, class_or_tuple=nn.Conv2d, recursive=True, name='weigh
class_or_tuple: Class or tuple of classes whose parameters are to be normalized.
recursive: Whether to search for modules recursively.
name: Name of weight parameter.
add_repr: Whether to indicate use of spectral norm in a module's representation.
Note that this may impair pickling.
**kwargs: Additional keyword arguments for ``torch.nn.utils.spectral_norm``.
"""

Expand All @@ -532,10 +537,11 @@ def extra_repr(self):
for handle, k, mod in iter_submodules(module, class_or_tuple, recursive=recursive):
if mod._parameters.get(name) is not None:
handle[k] = nn.utils.spectral_norm(handle[k], name=name, **kwargs)
inject_extra_repr_(handle[k], 'spectral_norm', extra_repr)
if add_repr:
inject_extra_repr_(handle[k], 'spectral_norm', extra_repr)


def weight_norm_(module, class_or_tuple=nn.Conv2d, recursive=True, name='weight', **kwargs):
def weight_norm_(module, class_or_tuple=nn.Conv2d, recursive=True, name='weight', add_repr=False, **kwargs):
"""Weight normalization.
Applies weight normalization to parameters of all occurrences of ``class_or_tuple`` in the given module.
Expand All @@ -551,6 +557,8 @@ def weight_norm_(module, class_or_tuple=nn.Conv2d, recursive=True, name='weight'
class_or_tuple: Class or tuple of classes whose parameters are to be normalized.
recursive: Whether to search for modules recursively.
name: Name of weight parameter.
add_repr: Whether to indicate use of weight norm in a module's representation.
Note that this may impair pickling.
**kwargs: Additional keyword arguments for ``torch.nn.utils.weight_norm``.
"""

Expand All @@ -561,7 +569,8 @@ def extra_repr(self):
for handle, k, mod in iter_submodules(module, class_or_tuple, recursive=recursive):
if mod._parameters.get(name) is not None:
handle[k] = nn.utils.weight_norm(handle[k], name=name, **kwargs)
inject_extra_repr_(handle[k], 'weight_norm', extra_repr)
if add_repr:
inject_extra_repr_(handle[k], 'weight_norm', extra_repr)


def get_device(module: Union[nn.Module, Tensor, torch.device]):
Expand Down

0 comments on commit f25d2bb

Please sign in to comment.