In [1]:
import tensorflow as tf

In [2]:
# @tf.function
@tf.custom_gradient
def differentiable_sign(x):
    """Binarise a scalar to +1 / -1 with a hand-written surrogate gradient.

    Forward pass: returns +1 when ``x > 0`` and -1 otherwise. Note that
    ``x == 0`` therefore maps to -1, unlike ``tf.sign`` which returns 0.

    Backward pass: the true derivative of a sign function is zero almost
    everywhere, so a surrogate is used instead:
    ``|dy| * (x + 1e-7 * sign(x))``. Because the upstream gradient enters
    through its absolute value, the result always has the sign of ``x``.

    The forward value is selected with ``tf.where`` rather than a Python
    ``if`` so that the function does not need to evaluate a tensor as a
    Python bool, and the +1 / -1 outputs are created with the dtype of ``x``
    instead of being hard-coded to float32.

    Args:
        x: Rank-0 floating-point tensor or variable.

    Returns:
        The pair ``(y, grad)`` expected by ``tf.custom_gradient``: ``y`` is
        the scalar +1 / -1 in the dtype of ``x`` and ``grad`` maps the
        upstream gradient ``dy`` to the gradient with respect to ``x``.
    """
    tf.debugging.assert_rank(x, 0)

    def grad(dy):
        # Nudge x away from zero by a tiny epsilon in its own direction.
        # NOTE: tf.sign(0) is 0, so at exactly x == 0 the surrogate gradient
        # is still 0.
        dx = x + (1e-7 * tf.sign(x))
        return tf.math.abs(dy) * dx

    one = tf.ones_like(x)
    return tf.where(x > 0, one, -one), grad


# Sanity check at x = 3: forward value, surrogate gradient, and the gradient
# chained through an L2 loss against a target of -1.
x = tf.constant(3.0, dtype=tf.float32)

# persistent=True because tape.gradient() is called twice on the same tape.
with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)  # x is a constant, so it has to be watched explicitly
    y = differentiable_sign(x)
    loss = tf.nn.l2_loss(y - tf.constant(-1.0))

dy_dx = tape.gradient(y, x)
dloss_dx = tape.gradient(loss, x)

tf.print(y)
tf.print(dy_dx)
tf.print(loss)
tf.print(dloss_dx)

1
3
2
6


In [3]:
# Trainable scalar, initialised on the positive side (train_step targets -1).
x = tf.Variable(1.0)
opt = tf.keras.optimizers.Adam(learning_rate=1e-1)
# Alternative optimiser: opt = tf.keras.optimizers.SGD(1)

def train_step():
    """Run one optimisation step on the module-level variable ``x``.

    Evaluates ``differentiable_sign(x)``, measures its L2 loss against a
    target of -1, and applies the resulting gradient to ``x`` through the
    module-level optimiser ``opt``.

    Returns:
        Tuple ``(loss, y, grads)``: the loss, the output of
        ``differentiable_sign``, and the gradient of the loss with respect
        to ``x``, all computed before ``x`` is updated.
    """
    target = tf.constant(-1.0)
    with tf.GradientTape() as tape:
        output = differentiable_sign(x)
        step_loss = tf.nn.l2_loss(output - target)
    gradient = tape.gradient(step_loss, x)
    opt.apply_gradients([(gradient, x)])
    return step_loss, output, gradient

# Train for 100 steps, logging every 10th one.
# Logged columns: step, loss, gradient, x (already updated by the step), y.
for i in range(100):
    loss, y, grads = train_step()
    if i % 10 != 0:
        continue
    tf.print(i, loss, grads, x, y)

0 2 2.00000024 0.899997175 1
10 2 0.152491018 0.0051279515 1
20 0 -0 -0.420691907 -1
30 0 -0 -0.585012257 -1
40 0 -0 -0.649645507 -1
50 0 -0 -0.674813747 -1
60 0 -0 -0.684475422 -1
70 0 -0 -0.688136876 -1
80 0 -0 -0.689510047 -1
90 0 -0 -0.69002068 -1
