In [1]:
import tensorflow as tf

In [2]:

# tf.print('-'*80)
# tf.print(r1)
# tf.print(tape.gradient(r1, k1))
# tf.print(tape.gradient(loss_1, k1))
# tf.print('-'*80)
# tf.print(tape.gradient(r1, k2))
# tf.print(tape.gradient(loss_1, k2))

# tf.print('-'*80)
# tf.print(r2)
# tf.print(loss_2)
# tf.print(tape.gradient(r2, k1))
# tf.print(tape.gradient(loss_2, k1))
# tf.print('-'*80)
# tf.print(tape.gradient(r2, k2))
# tf.print(tape.gradient(loss_2, k2))

In [14]:
# @tf.function
@tf.custom_gradient
def asymmetrical_vectored_lookup(v, k):
    """Hard per-row lookup into ``v`` at ``argmax(k)`` with a surrogate gradient.

    Forward pass: for every row ``i`` return ``v[i, argmax(k[i])]``.

    Backward pass (hand-written, not the true derivative of the forward op):
      * gradient w.r.t. ``v`` is ``upstream * k``;
      * gradient w.r.t. ``k`` is a +1/-1 direction vector per row (-1 at the
        column of ``v`` closest to ``forward_result - upstream``, +1 elsewhere)
        scaled by ``clip(|upstream|, 0, 1)``.

    Args:
        v: Rank-2 tensor of values.
        k: Rank-2 tensor with the same shape as ``v``; the argmax of each row
            selects the column to read from ``v``. ``v`` and ``k`` must share
            a dtype because the backward pass multiplies ``k`` with the
            upstream gradient.

    Returns:
        The rank-1 tensor of looked-up values (one per row) and the gradient
        function, as required by ``tf.custom_gradient``.
    """
    tf.debugging.assert_rank(v, 2)
    tf.debugging.assert_rank(k, 2)
    tf.debugging.assert_equal(tf.shape(v), tf.shape(k))

    k_shape = tf.shape(k)

    # Pick the value at the most likely index, non-differentiably
    b_idx = tf.argmax(k, axis=-1)
    idx_len = tf.shape(b_idx)[0]
    a_idx = tf.range(idx_len, dtype=tf.int64)
    # (row, column) pairs consumed by gather_nd
    idx = tf.stack([a_idx, b_idx], axis=1)
    forward_result = tf.gather_nd(v, idx)

    def grad(upstream_grads):
        # NOTE: the f-strings below are formatted by Python at the moment the
        # line runs, so these debug prints only show concrete values when the
        # function executes eagerly (the @tf.function decorator is disabled).
        tf.print(f'[upstream_grads] {upstream_grads}')
        tf.print(f'[forward_result] {forward_result}')
        # Estimate the target scalar which we want to look up
        target = forward_result - upstream_grads
        tf.print(f'[target] {target}')
        target = tf.expand_dims(target, -1)

        # Find the index of element in the array which is closest to target
        diff_vector = tf.math.squared_difference(v, target)
        d_idx = tf.argmin(diff_vector, axis=-1)
        tf.print(f'[d_idx] {d_idx}')

        # Create a vector which is 1 everywhere except the idx
        # of the target, where it is -1.
        # Build it in the upstream gradient's dtype: the float32 default of
        # tf.ones / tf.one_hot would make the multiplication at the end fail
        # for any non-float32 input.
        grad_dtype = upstream_grads.dtype
        ones = tf.ones(k_shape, dtype=grad_dtype)
        eyes = tf.one_hot(d_idx, k_shape[-1], dtype=grad_dtype)
        k_grad = ones - 2 * eyes

        # d/dv (v . k) = k
        v_grad = k

        # Add a trailing axis so the per-row gradient broadcasts over columns
        upstream_grads = tf.expand_dims(upstream_grads, -1)

        # 1. The k_grad should dictate the direction of the vector. So, upstream grad is always positive
        # 2. We want it to scale to zero as it gets closer to target. So we clip it between 0 and 1.
        # 3. If there is an exact match in the vector, then we dont send the gradients downstream for other entries. (disabled)
        min_clipped_abs_grad = tf.abs(upstream_grads)
        min_clipped_abs_grad = tf.clip_by_value(min_clipped_abs_grad, 0, 1)
        # min_clipped_abs_grad = tf.reduce_min(min_clipped_abs_grad)
        tf.print(f'[k_grad] {tf.squeeze(k_grad)}')
        tf.print(f'[clipped_abs_grad] {tf.squeeze(min_clipped_abs_grad)}')
        tf.print(f'[final] {tf.squeeze(min_clipped_abs_grad * k_grad )}')
        tf.print(f'[v_grad] {tf.squeeze(v_grad)}')

        tf.print('-'*80)
        # Order matches the positional inputs: (grad w.r.t. v, grad w.r.t. k)
        return upstream_grads * v_grad, min_clipped_abs_grad * k_grad

    return forward_result, grad

# Demo: chain two lookups and inspect the custom gradient w.r.t. the keys.
# v  : 2x3 value matrix for the first lookup
# k1 : 2x3 keys for the first lookup (argmax picks column 0 and column 1)
# k2 : 1x2 keys for the second lookup, applied to the first lookup's output
# t1 / t2 : regression targets for the first / second lookup result
# v = tf.constant([[1,2,3], [10,20,30]], dtype=tf.float32)
v = tf.constant([[1,11,3], [12,20,30]], dtype=tf.float32)
k1 = tf.constant([[1,0,0], [0,1,0]], dtype=tf.float32)
k2 = tf.constant([[1,0]], dtype=tf.float32)
t1 = tf.constant([2, 20], dtype=tf.float32)
t2 = tf.constant([10], dtype=tf.float32)

# persistent=True so that tape.gradient can be called more than once
# (several alternative gradient queries are kept commented out below).
with tf.GradientTape(persistent=True) as tape:
    # k1 and k2 are constants, so they are watched explicitly; v is not watched.
    tape.watch(k1)
    tape.watch(k2)

    # First lookup: one value per row of v, compared against t1
    r1 = asymmetrical_vectored_lookup(v, k1)
    loss_1 = tf.nn.l2_loss(r1 - t1)

    # Re-use the first result as a single-row value matrix (adds a leading
    # axis) so it can be fed to the second lookup. Note that r1 is rebound.
    r1 = tf.expand_dims(r1, 0)

    # Second lookup: pick one entry of r1 using k2, compared against t2
    r2 = asymmetrical_vectored_lookup(r1, k2)
    loss_2 = tf.nn.l2_loss(r2 - t2)

tf.print('-'*80)
tf.print(r2)
tf.print(r1)
tf.print('-'*80)
# Alternative gradient queries kept for experimentation:
# tf.print(tape.gradient(loss_2, v))
# tf.print(tape.gradient(loss_2, k1))
# tf.print(tape.gradient(loss_1, k1))
# tf.print('-'*80)
# Gradient of the second loss w.r.t. the second lookup's keys; calling this
# triggers the debug prints inside the custom grad function.
tf.print(tape.gradient(loss_2, k2))
# tf.print('-'*80)


--------------------------------------------------------------------------------
[1]
[[1 20]]
--------------------------------------------------------------------------------
[upstream_grads] [-9.]
[forward_result] [1.]
[target] [10.]
[d_idx] [0]
[k_grad] [-1.  1.]
[clipped_abs_grad] 1.0
[final] [-1.  1.]
[v_grad] [1. 0.]
--------------------------------------------------------------------------------
[upstream_grads] [-9. -0.]
[forward_result] [ 1. 20.]
[target] [10. 20.]
[d_idx] [1 1]
[k_grad] [[ 1. -1.  1.]
 [ 1. -1.  1.]]
[clipped_abs_grad] [1. 0.]
[final] [[ 1. -1.  1.]
 [ 0. -0.  0.]]
[v_grad] [[1. 0. 0.]
 [0. 1. 0.]]
--------------------------------------------------------------------------------
[[-1 1]]


In [12]:
@tf.custom_gradient
def bar(x, y):
  """Element-wise product ``x * y`` with an explicit, logging gradient.

  The gradient function prints the upstream gradient it receives and then
  applies the product rule, returning the gradients for ``x`` and ``y`` in
  that order.
  """
  product = x * y

  def product_grad(upstream):
    # Show what arrives from the downstream op during backprop.
    tf.print(f'[upstream] {upstream}')
    # Product rule: d(x*y)/dx = y and d(x*y)/dy = x.
    return upstream * y, upstream * x

  return product, product_grad
# Sanity check for chained custom gradients: z = bar(w, w) with w = bar(x, y),
# i.e. z = (x * y) ** 2.
x = tf.constant([2.0,3.0], dtype=tf.float32)
y = tf.constant([3.0,4.0], dtype=tf.float32)
# persistent=True allows more than one tape.gradient call on this tape.
with tf.GradientTape(persistent=True) as tape:
  # x and y are constants, so they are watched explicitly.
  tape.watch(x)
  tape.watch(y)
  w = bar(x, y)
  # w is passed as both arguments, so its gradient receives both contributions.
  z = bar(w, w)

tf.print(z)
# Gradient through both custom-gradient calls ...
tf.print(tape.gradient(z, x))
# tf.print(tape.gradient(z, y))
# ... compared with the analytic derivative dz/dx = 2 * x * y**2.
tf.print(2*x*y*y)

[36 144]
[upstream] [1. 1.]
[upstream] [12. 24.]
[36 96]
[36 96]
