From b4401b639c68e0623339bad22a57f67b5092452f Mon Sep 17 00:00:00 2001 From: ericup Date: Wed, 4 May 2022 14:24:28 +0200 Subject: [PATCH] Add weight normalization --- celldetection/util/util.py | 37 ++++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/celldetection/util/util.py b/celldetection/util/util.py index dc7180a..997da2d 100644 --- a/celldetection/util/util.py +++ b/celldetection/util/util.py @@ -20,7 +20,7 @@ 'random_code_name_dir', 'get_device', 'num_params', 'count_submodules', 'train_epoch', 'Bytes', 'Percent', 'GpuStats', 'trainable_params', 'frozen_params', 'Tiling', 'load_image', 'gaussian_kernel', 'iter_submodules', 'replace_module_', 'wrap_module_', 'spectral_norm_', 'to_h5', 'to_tiff', - 'exponential_moving_average_', 'from_json', 'to_json'] + 'exponential_moving_average_', 'from_json', 'to_json', 'weight_norm_'] class Dict(dict): @@ -483,6 +483,41 @@ def spectral_norm_(module, class_or_tuple=nn.Conv2d, recursive=True, **kwargs): wrap_module_(module, class_or_tuple, recursive=recursive, wrapper=nn.utils.spectral_norm, **kwargs) +def weight_norm_(module, class_or_tuple=nn.Conv2d, recursive=True, name='weight', **kwargs): + """Weight normalization. + + Applies weight normalization to parameters of all occurrences of ``class_or_tuple`` in the given module. + + Note: + This is an inplace operation. + + References: + - https://proceedings.neurips.cc/paper/2016/file/ed265bc903a5a097f61d3ec064d96d2e-Paper.pdf + + Args: + module: Module. + class_or_tuple: Class or tuple of classes whose parameters are to be normalized. + recursive: Whether to search for modules recursively. + name: Name of weight parameter. + **kwargs: Additional keyword arguments for ``torch.nn.utils.weight_norm``. + """ + for handle, k, mod in iter_submodules(module, class_or_tuple, recursive=recursive): + if mod._parameters.get(name) is not None: + handle[k] = nn.utils.weight_norm(handle[k], name=name, **kwargs) + + def extra_repr(_mod=handle[k]): + s = _mod._extra_repr() + if s is None: + s = '' + if len(s) > 0: + s += ', ' + s += 'weight_norm=True' + return s + + handle[k]._extra_repr = handle[k].extra_repr + handle[k].extra_repr = extra_repr + + def get_device(module: Union[nn.Module, Tensor, torch.device]): """Get device.