diff --git a/celldetection/util/util.py b/celldetection/util/util.py
index 407acf8..a261c12 100644
--- a/celldetection/util/util.py
+++ b/celldetection/util/util.py
@@ -483,6 +483,9 @@ def inject_extra_repr_(module, name, fn):
     Note:
         This is an inplace operation.
 
+    Notes:
+        - This op may impair pickling.
+
     Args:
         module: Module.
         name: Name of the injected function (only used to avoid duplicate injection).
@@ -506,7 +509,7 @@ def wrap_module_(module: nn.Module, class_or_tuple, wrapper, recursive=True, **k
         handle[k] = wrapper(handle[k], **kwargs)
 
 
-def spectral_norm_(module, class_or_tuple=nn.Conv2d, recursive=True, name='weight', **kwargs):
+def spectral_norm_(module, class_or_tuple=nn.Conv2d, recursive=True, name='weight', add_repr=False, **kwargs):
     """Spectral normalization.
 
     Applies spectral normalization to parameters of all occurrences of ``class_or_tuple`` in the given module.
@@ -522,6 +525,8 @@ def spectral_norm_(module, class_or_tuple=nn.Conv2d, recursive=True, name='weigh
         class_or_tuple: Class or tuple of classes whose parameters are to be normalized.
         recursive: Whether to search for modules recursively.
         name: Name of weight parameter.
+        add_repr: Whether to indicate use of spectral norm in a module's representation.
+            Note that this may impair pickling.
         **kwargs: Additional keyword arguments for ``torch.nn.utils.spectral_norm``.
     """
 
@@ -532,10 +537,11 @@ def extra_repr(self):
     for handle, k, mod in iter_submodules(module, class_or_tuple, recursive=recursive):
         if mod._parameters.get(name) is not None:
             handle[k] = nn.utils.spectral_norm(handle[k], name=name, **kwargs)
-            inject_extra_repr_(handle[k], 'spectral_norm', extra_repr)
+            if add_repr:
+                inject_extra_repr_(handle[k], 'spectral_norm', extra_repr)
 
 
-def weight_norm_(module, class_or_tuple=nn.Conv2d, recursive=True, name='weight', **kwargs):
+def weight_norm_(module, class_or_tuple=nn.Conv2d, recursive=True, name='weight', add_repr=False, **kwargs):
     """Weight normalization.
 
     Applies weight normalization to parameters of all occurrences of ``class_or_tuple`` in the given module.
@@ -551,6 +557,8 @@ def weight_norm_(module, class_or_tuple=nn.Conv2d, recursive=True, name='weight'
         class_or_tuple: Class or tuple of classes whose parameters are to be normalized.
        recursive: Whether to search for modules recursively.
         name: Name of weight parameter.
+        add_repr: Whether to indicate use of weight norm in a module's representation.
+            Note that this may impair pickling.
         **kwargs: Additional keyword arguments for ``torch.nn.utils.weight_norm``.
     """
 
@@ -561,7 +569,8 @@ def extra_repr(self):
     for handle, k, mod in iter_submodules(module, class_or_tuple, recursive=recursive):
         if mod._parameters.get(name) is not None:
             handle[k] = nn.utils.weight_norm(handle[k], name=name, **kwargs)
-            inject_extra_repr_(handle[k], 'weight_norm', extra_repr)
+            if add_repr:
+                inject_extra_repr_(handle[k], 'weight_norm', extra_repr)
 
 
 def get_device(module: Union[nn.Module, Tensor, torch.device]):
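
A minimal usage sketch of the new opt-in `add_repr` flag, assuming `spectral_norm_` is importable from `celldetection.util` (the module changed above); the exact repr text depends on the injected `extra_repr` and is illustrative only.

import torch.nn as nn
from celldetection.util import spectral_norm_  # assumed import path

model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 16, 3, padding=1),
)

# Applies torch.nn.utils.spectral_norm to every Conv2d weight; with add_repr=True the
# wrapped modules also advertise spectral norm in their repr. Keep the default
# (add_repr=False) if the model needs to stay picklable.
spectral_norm_(model, add_repr=True)
print(model)  # Conv2d reprs now indicate spectral norm (exact text depends on extra_repr)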