-
Notifications
You must be signed in to change notification settings - Fork 21
/
centralized_gradients.py
55 lines (45 loc) · 1.98 KB
/
centralized_gradients.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import tensorflow as tf
import keras.backend as K
def get_centralized_gradients(optimizer, loss, params):
    """Compute gradients of ``loss`` w.r.t. ``params`` and centralize them.

    Gradient centralization subtracts, from every gradient tensor of rank > 1,
    the mean taken over all axes except the last one, as described in the
    reference paper. Modified version of
    ``tf.keras.optimizers.Optimizer.get_gradients``.

    Reference:
        https://arxiv.org/abs/2004.01461

    # Arguments
        optimizer: a tf.keras optimizer; its `clipnorm` / `clipvalue`
            attributes (if present and > 0) are applied after centralization.
        loss: scalar loss tensor to differentiate.
        params: list of variables to differentiate with respect to.

    # Returns
        A list of centralized (and possibly clipped) gradient tensors,
        one per entry of `params`.

    # Raises
        ValueError: if any gradient is `None` (a non-differentiable op).
    """
    # We here just provide a modified get_gradients() function since we are trying to just compute the centralized
    # gradients at this stage which can be used in other optimizers.
    grads = []
    for grad in K.gradients(loss, params):
        # Check for None *before* touching `.shape`: a None gradient has no
        # shape attribute and would otherwise raise AttributeError instead of
        # the intended ValueError.
        if grad is None:
            raise ValueError('An operation has `None` for gradient. '
                             'Please make sure that all of your ops have a '
                             'gradient defined (i.e. are differentiable). '
                             'Common ops without gradient: '
                             'K.argmax, K.round, K.eval.')
        rank = len(grad.shape)
        if rank > 1:
            # Centralize: remove the mean over every axis except the last.
            # NOTE: the keyword is `keepdims`; the TF1 alias `keep_dims`
            # was removed in TF 2.x.
            grad -= tf.reduce_mean(grad,
                                   axis=list(range(rank - 1)),
                                   keepdims=True)
        grads.append(grad)
    if hasattr(optimizer, 'clipnorm') and optimizer.clipnorm > 0:
        norm = K.sqrt(sum([K.sum(K.square(g)) for g in grads]))
        # `tf.keras.optimizers.clip_norm` is not a public TF2 API; use
        # tf.clip_by_global_norm, passing the precomputed global norm so the
        # whole gradient list is rescaled when `norm` exceeds `clipnorm`.
        grads, _ = tf.clip_by_global_norm(grads,
                                          optimizer.clipnorm,
                                          use_norm=norm)
    if hasattr(optimizer, 'clipvalue') and optimizer.clipvalue > 0:
        grads = [K.clip(g, -optimizer.clipvalue, optimizer.clipvalue)
                 for g in grads]
    return grads
def centralized_gradients_for_optimizer(optimizer):
    """Build a `get_gradients` replacement bound to a given optimizer.

    # Arguments
        optimizer: a tf.keras.optimizer object. The optimizer you are using.

    # Returns
        A callable with signature `(loss, params)` that delegates to
        `get_centralized_gradients` with `optimizer` already bound.
    """

    def _bound_get_gradients(loss, params):
        # Closure over `optimizer` so the returned callable matches the
        # `(loss, params)` signature Keras expects for `get_gradients`.
        return get_centralized_gradients(optimizer, loss, params)

    return _bound_get_gradients