Skip to content

Commit

Permalink
Add weight normalization
Browse files Browse the repository at this point in the history
  • Loading branch information
ericup committed May 4, 2022
1 parent 7dc145d commit b4401b6
Showing 1 changed file with 36 additions and 1 deletion.
37 changes: 36 additions & 1 deletion celldetection/util/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
'random_code_name_dir', 'get_device', 'num_params', 'count_submodules', 'train_epoch', 'Bytes', 'Percent',
'GpuStats', 'trainable_params', 'frozen_params', 'Tiling', 'load_image', 'gaussian_kernel',
'iter_submodules', 'replace_module_', 'wrap_module_', 'spectral_norm_', 'to_h5', 'to_tiff',
'exponential_moving_average_', 'from_json', 'to_json']
'exponential_moving_average_', 'from_json', 'to_json', 'weight_norm_']


class Dict(dict):
Expand Down Expand Up @@ -483,6 +483,41 @@ def spectral_norm_(module, class_or_tuple=nn.Conv2d, recursive=True, **kwargs):
wrap_module_(module, class_or_tuple, recursive=recursive, wrapper=nn.utils.spectral_norm, **kwargs)


def weight_norm_(module, class_or_tuple=nn.Conv2d, recursive=True, name='weight', **kwargs):
"""Weight normalization.
Applies weight normalization to parameters of all occurrences of ``class_or_tuple`` in the given module.
Note:
This is an inplace operation.
References:
- https://proceedings.neurips.cc/paper/2016/file/ed265bc903a5a097f61d3ec064d96d2e-Paper.pdf
Args:
module: Module.
class_or_tuple: Class or tuple of classes whose parameters are to be normalized.
recursive: Whether to search for modules recursively.
name: Name of weight parameter.
**kwargs: Additional keyword arguments for ``torch.nn.utils.weight_norm``.
"""
for handle, k, mod in iter_submodules(module, class_or_tuple, recursive=recursive):
if mod._parameters.get(name) is not None:
handle[k] = nn.utils.weight_norm(handle[k], name=name, **kwargs)

def extra_repr(_mod=handle[k]):
s = _mod._extra_repr()
if s is None:
s = ''
if len(s) > 0:
s += ', '
s += 'weight_norm=True'
return s

handle[k]._extra_repr = handle[k].extra_repr
handle[k].extra_repr = extra_repr


def get_device(module: Union[nn.Module, Tensor, torch.device]):
"""Get device.
Expand Down

0 comments on commit b4401b6

Please sign in to comment.