# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle import _C_ops

from ...base.data_feeder import check_variable_and_dtype
from ...base.layer_helper import LayerHelper
from ...framework import in_dynamic_mode

__all__ = []


def l2_norm(x, axis, epsilon=1e-12, name=None):
    """Return the L2 norm of ``x`` computed along ``axis``.

    The normalized output of the underlying ``norm`` op is discarded; only
    the norm itself is returned, squeezed along ``axis``.
    """
    if len(x.shape) == 1:
        axis = 0
    if in_dynamic_mode():
        # `out` (the normalized tensor) is unused; only the norm is needed.
        out, norm = _C_ops.norm(x, 1 if axis is None else axis, epsilon, False)
        return paddle.squeeze(norm, axis=[axis])

    check_variable_and_dtype(x, "X", ("float32", "float64"), "norm")
    helper = LayerHelper("l2_normalize", **locals())
    out = helper.create_variable_for_type_inference(dtype=x.dtype)
    norm = helper.create_variable_for_type_inference(dtype=x.dtype)
    helper.append_op(
        type="norm",
        inputs={"X": x},
        outputs={"Out": out, "Norm": norm},
        attrs={
            "axis": 1 if axis is None else axis,
            "epsilon": epsilon,
        },
    )
    return paddle.squeeze(norm, axis=[axis])


def norm_except_dim(p, dim):
    """Compute the L2 norm of ``p`` over every dimension except ``dim``.

    ``dim == -1`` reduces over all elements and returns a scalar norm.
    """
    shape = p.shape
    ndims = len(shape)
    if dim == -1:
        return paddle.sqrt(paddle.sum(paddle.square(p)) + 1e-12)
    elif dim == 0:
        p_matrix = paddle.reshape(p, (shape[0], -1))
        return l2_norm(p_matrix, axis=1)
    elif dim == ndims - 1:
        p_matrix = paddle.reshape(p, (-1, shape[-1]))
        return l2_norm(p_matrix, axis=0)
    else:
        # Move `dim` to the front, then reduce over the remaining dimensions.
        perm = list(range(ndims))
        perm[0] = dim
        perm[dim] = 0
        p_transposed = paddle.transpose(p, perm)
        return norm_except_dim(p_transposed, 0)
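
# A quick shape sketch for `norm_except_dim` (illustrative, not executed on
# import; assumes a Conv2D-style weight):
#
#     >>> import paddle
#     >>> p = paddle.randn([5, 3, 3, 3])
#     >>> norm_except_dim(p, 0).shape   # one norm per slice along dim 0
#     [5]
#     >>> norm_except_dim(p, -1).shape  # scalar norm over all elements
#     []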


def _weight_norm(v, g, dim):
    """Reconstruct the weight as ``g * v / ||v||``, with the norm taken over
    every dimension except ``dim`` (or over all elements when ``dim == -1``).
    """
    shape = v.shape
    ndims = len(shape)
    if dim == -1:
        v_normalized = v / (paddle.sqrt(paddle.sum(paddle.square(v))) + 1e-12)
    elif dim == 0:
        p_matrix = paddle.reshape(v, (shape[0], -1))
        v_normalized = paddle.nn.functional.normalize(p_matrix, axis=1)
        v_normalized = paddle.reshape(v_normalized, shape)
    elif dim == ndims - 1:
        p_matrix = paddle.reshape(v, (-1, shape[-1]))
        v_normalized = paddle.nn.functional.normalize(p_matrix, axis=0)
        v_normalized = paddle.reshape(v_normalized, shape)
    else:
        # Move `dim` to the front, normalize row-wise, then restore the layout.
        perm = list(range(ndims))
        perm[0] = dim
        perm[dim] = 0
        p_transposed = paddle.transpose(v, perm)
        transposed_shape = p_transposed.shape
        p_matrix = paddle.reshape(p_transposed, (p_transposed.shape[0], -1))
        v_normalized = paddle.nn.functional.normalize(p_matrix, axis=1)
        v_normalized = paddle.reshape(v_normalized, transposed_shape)
        v_normalized = paddle.transpose(v_normalized, perm)
    # Scale the direction by the magnitude g, broadcast along `dim`.
    weight = paddle.tensor.math._multiply_with_axis(
        v_normalized, g, axis=dim if dim is not None else -1
    )
    return weight
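
# A minimal reconstruction sketch for `_weight_norm` (illustrative, not
# executed on import): with v of shape [5, 3, 3, 3] and dim=0, v is flattened
# to [5, 27], L2-normalized along axis=1, reshaped back, and scaled by g
# (shape [5]) broadcast along axis 0, so the result equals g * v / ||v||
# row by row, up to the epsilon used by `normalize`:
#
#     >>> import paddle
#     >>> v = paddle.randn([5, 3, 3, 3])
#     >>> g = paddle.randn([5])
#     >>> w = _weight_norm(v, g, dim=0)
#     >>> norms = paddle.sqrt(paddle.sum(paddle.square(v), axis=[1, 2, 3]))
#     >>> expected = g.reshape([-1, 1, 1, 1]) * v / norms.reshape([-1, 1, 1, 1])
#     >>> bool(paddle.allclose(w, expected))
#     True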


class WeightNorm:
    def __init__(self, name, dim):
        if dim is None:
            dim = -1
        self.name = name
        self.dim = dim

    def compute_weight(self, layer):
        g = getattr(layer, self.name + '_g')
        v = getattr(layer, self.name + '_v')
        return _weight_norm(v, g, self.dim)

    @staticmethod
    def apply(layer, name, dim):
        for k, hook in layer._forward_pre_hooks.items():
            if isinstance(hook, WeightNorm) and hook.name == name:
                raise RuntimeError(
                    "Cannot register two weight_norm hooks on "
                    f"the same parameter {name}"
                )

        if dim is None:
            dim = -1

        # Support negative dim values; dim == -1 is equivalent to dim == None.
        weight_dim = len(layer._parameters[name].shape)
        assert (
            dim < weight_dim and dim >= -1 * weight_dim
        ), "dim must be in the range [-R, R), where R is the rank of the weight."
        if dim != -1:
            dim = (dim + weight_dim) % weight_dim
        fn = WeightNorm(name, dim)

        # Replace the original parameter with the pair (name_g, name_v).
        w = getattr(layer, name)
        del layer._parameters[name]
        g_var = norm_except_dim(w, dim)
        v = layer.create_parameter(w.shape, dtype=w.dtype)
        layer.add_parameter(name + "_v", v)
        g = layer.create_parameter(g_var.shape, dtype=g_var.dtype)
        layer.add_parameter(name + '_g', g)
        with paddle.no_grad():
            paddle.assign(w, v)
            paddle.assign(g_var, g)
        setattr(layer, name, fn.compute_weight(layer))

        # Recompute the weight from (g, v) before every forward pass.
        layer.register_forward_pre_hook(fn)
        return fn

    def remove(self, layer):
        # Bake the current (g, v) pair back into a single weight parameter.
        w_var = self.compute_weight(layer)
        delattr(layer, self.name)
        del layer._parameters[self.name + '_g']
        del layer._parameters[self.name + '_v']
        w = layer.create_parameter(w_var.shape, dtype=w_var.dtype)
        layer.add_parameter(self.name, w)
        with paddle.no_grad():
            paddle.assign(w_var, w)

    def __call__(self, layer, inputs):
        setattr(layer, self.name, self.compute_weight(layer))


def weight_norm(layer, name='weight', dim=0):
    r"""
    Applies weight normalization to a parameter according to the
    following formula:

    .. math::

        \mathbf{w} = g \dfrac{v}{\|v\|}

    Weight normalization is a reparameterization of the weight vectors in a neural network that
    decouples the magnitude of those weight vectors from their direction. Weight normalization
    replaces the parameter specified by ``name`` (e.g. 'weight') with two parameters: one
    specifying the magnitude (e.g. 'weight_g') and one specifying the direction
    (e.g. 'weight_v'). Weight normalization is implemented as described in the paper
    `Weight Normalization: A Simple Reparameterization to Accelerate Training of Deep Neural Networks
    <https://arxiv.org/pdf/1602.07868.pdf>`_.

    Parameters:
        layer(Layer): Paddle layer which has a weight parameter.
        name(str, optional): Name of the weight parameter. Default: 'weight'.
        dim(int, optional): Dimension over which to compute the norm; the magnitude parameter
            keeps one entry per slice along ``dim``, and the norm is taken over all other
            dimensions. ``dim`` must be an integer in [-R, R), where R is the rank of the
            weight tensor; ``dim == -1`` behaves like ``dim == None``. For example, dim can
            be chosen from 0, 1, 2, or 3 for a convolution whose weight shape is
            [cout, cin, kh, kw] and rank is 4. If dim is set to None, all elements will be
            normalized. Default: 0.

    Returns:
        The origin layer with the weight norm hook registered.

    Examples:
        .. code-block:: python

            >>> from paddle.nn import Conv2D
            >>> from paddle.nn.utils import weight_norm

            >>> conv = Conv2D(3, 5, 3)
            >>> wn = weight_norm(conv)
            >>> print(conv.weight_g.shape)
            [5]
            >>> print(conv.weight_v.shape)
            [5, 3, 3, 3]
    """
    WeightNorm.apply(layer, name, dim)
    return layer
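
# Usage sketch (illustrative, not executed on import; assumes `paddle` is
# installed): after applying weight_norm, the optimizer sees `weight_g` and
# `weight_v`, and the effective `weight` is recomputed by the forward
# pre-hook on every call, so edits to g or v take effect at the next forward:
#
#     >>> import paddle
#     >>> linear = weight_norm(paddle.nn.Linear(4, 8), dim=0)
#     >>> with paddle.no_grad():
#     ...     linear.weight_g.set_value(2 * linear.weight_g)
#     >>> _ = linear(paddle.randn([1, 4]))  # pre-hook refreshes `weight`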


def remove_weight_norm(layer, name='weight'):
    """
    Remove weight normalization from a layer, restoring a single weight
    parameter.

    Parameters:
        layer(Layer): Paddle layer which has weight normalization applied.
        name(str, optional): Name of the weight parameter. Default: 'weight'.

    Returns:
        Layer, the origin layer without weight norm.

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> from paddle.nn import Conv2D
            >>> from paddle.nn.utils import weight_norm, remove_weight_norm

            >>> paddle.seed(2023)
            >>> conv = Conv2D(3, 5, 3)
            >>> wn = weight_norm(conv)
            >>> print(conv.weight_g)
            Parameter containing:
            Tensor(shape=[5], dtype=float32, place=Place(cpu), stop_gradient=False,
                   [1.35883713, 1.32126212, 1.56303072, 1.20874095, 1.22893476])

            >>> remove_weight_norm(conv)
            >>> # The following is the effect after removing the weight norm:
            >>> # print(conv.weight_g)
            >>> # AttributeError: 'Conv2D' object has no attribute 'weight_g'
    """
    for k, hook in layer._forward_pre_hooks.items():
        if isinstance(hook, WeightNorm) and hook.name == name:
            hook.remove(layer)
            del layer._forward_pre_hooks[k]
            return layer

    raise ValueError(f"weight_norm of '{name}' not found in {layer}")
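
# Round-trip sanity sketch (illustrative, not executed on import; import
# through the public API, since this module uses relative imports): applying
# and then removing weight norm should leave the weight unchanged up to
# numerical precision:
#
#     >>> import paddle
#     >>> from paddle.nn.utils import weight_norm, remove_weight_norm
#     >>> linear = paddle.nn.Linear(4, 8)
#     >>> w0 = linear.weight.clone()
#     >>> linear = weight_norm(linear, dim=1)
#     >>> linear = remove_weight_norm(linear)
#     >>> bool(paddle.allclose(w0, linear.weight))
#     True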