# Differentiable indexing of arrays
Table or array lookup is inherently non-differentiable with respect to the index, so autograd cannot compute the gradient of the result with respect to the index.

Below are some strategies that provide a usable gradient with respect to both the index and the input array. The choice of strategy is problem-specific. This is by no means an exhaustive list; you can help by expanding it.

Abbreviation: WRT = With respect to

Further Reading: [Neural Turing Machines](https://arxiv.org/abs/1410.5401)

In [1]:
import tensorflow as tf

## Naive lookup
Naive lookup produces a gradient wrt its input array but not wrt the index, because rounding and the cast to an integer cut the gradient path back to the index.

In [2]:
@tf.function
def naive_lookup(arr, index):
    index = tf.round(index)
    index = tf.cast(index, tf.int32)
    result = arr[index]
    return result

arr = tf.Variable([1,2,3,4,5],dtype=tf.float32)
index = tf.Variable(1.5, dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
    z = naive_lookup(arr, index)

print(z)
print(tape.gradient(z, arr))
print(tape.gradient(z, index))

tf.Tensor(3.0, shape=(), dtype=float32)
tf.Tensor([0. 0. 1. 0. 0.], shape=(5,), dtype=float32)
None
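
The missing index gradient is not specific to TensorFlow: rounding turns the lookup into a step function of the index, so its slope is zero wherever it is defined. The plain-Python sketch below checks this numerically. It does not use TensorFlow, and `hard_lookup` and `finite_difference` are illustrative names that are not part of the notebook.

```python
def hard_lookup(arr, index):
    # Python's round(), like tf.round, rounds halves to the nearest even integer.
    return arr[round(index)]

def finite_difference(f, x, eps=1e-3):
    # Central difference approximation of df/dx.
    return (f(x + eps) - f(x - eps)) / (2 * eps)

arr = [1.0, 2.0, 3.0, 4.0, 5.0]

# Away from a rounding boundary the output does not react to the index at all.
slope_inside = finite_difference(lambda i: hard_lookup(arr, i), 1.2)

# Across the boundary at 1.5 the output jumps from arr[1] to arr[2],
# so the estimate grows without bound as eps shrinks.
slope_at_jump = finite_difference(lambda i: hard_lookup(arr, i), 1.5)

print(slope_inside)
print(slope_at_jump)
```

The estimate is zero inside a step and only becomes large at the jump itself, which is why a gradient-based optimiser receives no useful signal for the index.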


## Superposition lookup
In this method, the index is a distribution over positions instead of a single integer. This distribution usually comes from a softmax operation. The result is the dot product of the distribution and the input array. This is a very popular method in the deep learning literature.

In [3]:
@tf.function
def superposition_lookup(arr, indices):
    result = arr * indices
    return tf.reduce_sum(result, axis=0)

arr = tf.Variable([1,2,3,5,4],dtype=tf.float32)
indices = tf.Variable([0.0, 0.1, 0.8, 0.0, 0.1], dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
#     indices = tf.nn.softmax(indices)
    z = superposition_lookup(arr, indices)

print(z)
print(tape.gradient(z, arr))
print(tape.gradient(z, indices))

tf.Tensor(3.0000002, shape=(), dtype=float32)
tf.Tensor([0.  0.1 0.8 0.  0.1], shape=(5,), dtype=float32)
tf.Tensor([1. 2. 3. 5. 4.], shape=(5,), dtype=float32)
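
The index distribution above is written out by hand. In practice it usually comes from a softmax over raw scores, and a temperature parameter controls how close the soft lookup gets to a hard one. The following is a minimal plain-Python sketch of that idea. The `logits` values and the `temperature` argument are assumptions chosen for illustration; they do not appear in the notebook.

```python
import math

def softmax(logits, temperature=1.0):
    # Divide by the temperature, then subtract the max for numerical stability.
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def superposition(arr, weights):
    # Dot product of the weights with the array.
    return sum(a * w for a, w in zip(arr, weights))

arr = [1.0, 2.0, 3.0, 5.0, 4.0]
logits = [2.0, 1.0, 3.0, 0.0, 1.0]  # hypothetical unnormalised scores

soft = superposition(arr, softmax(logits, temperature=1.0))
sharp = superposition(arr, softmax(logits, temperature=0.05))

print(soft)   # a blend of several entries
print(sharp)  # very close to arr[2], the entry with the largest logit
```

A low temperature concentrates almost all weight on the largest logit, so the lookup behaves nearly like a hard index while remaining differentiable, at the cost of very small gradients for the other positions.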


## Linear lookup
In this method we use [linear interpolation](https://en.wikipedia.org/wiki/Linear_interpolation) to interpolate between the two nearest candidates. For 2D arrays, [Bilinear interpolation](https://en.wikipedia.org/wiki/Bilinear_interpolation) can be used.

This gives a well-defined gradient wrt both the input and the index. However, it is a soft lookup and can return values not present in the array itself.

In [4]:
@tf.function
def linear_lookup(arr, index):
    t1 = tf.math.floor(index)
    t2 = tf.math.ceil(index)
    
    t = tf.math.divide_no_nan((index - t1), (t2 - t1))
    
    i1 = tf.cast(t1, tf.int32)
    i2 = tf.cast(t2, tf.int32)
    
    # Linear interpolation: t = 0 returns arr[i1], t = 1 returns arr[i2]
    result = (1 - t) * arr[i1] + t * arr[i2]
    
    return result

arr = tf.Variable([1,2,3,4,5],dtype=tf.float32)
index = tf.Variable(1.5, dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
    z = linear_lookup(arr, index)

print(z)
print(tape.gradient(z, arr))
print(tape.gradient(z, index))

tf.Tensor(2.5, shape=(), dtype=float32)
tf.Tensor([0.  0.5 0.5 0.  0. ], shape=(5,), dtype=float32)
tf.Tensor(1.0, shape=(), dtype=float32)
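
As a TensorFlow-independent sanity check, standard linear interpolation weights the lower neighbour by `1 - t` and the upper neighbour by `t`, so its derivative wrt the index equals the local slope `arr[ceil] - arr[floor]`. The sketch below verifies this with a central finite difference. `lerp_lookup` is an illustrative name, not a function from the notebook.

```python
import math

def lerp_lookup(arr, index):
    lo = math.floor(index)
    hi = math.ceil(index)
    # Fractional position between the two neighbours (0 when index is an integer).
    t = (index - lo) / (hi - lo) if hi != lo else 0.0
    return (1 - t) * arr[lo] + t * arr[hi]

arr = [1.0, 2.0, 3.0, 4.0, 5.0]
eps = 1e-4

value = lerp_lookup(arr, 1.5)
# Central difference estimate of d(value)/d(index) at index = 1.5.
slope = (lerp_lookup(arr, 1.5 + eps) - lerp_lookup(arr, 1.5 - eps)) / (2 * eps)

print(value, slope)
```

Because the array increases by 1 between positions 1 and 2, moving the index to the right should increase the result at the same rate, so the sign of the index gradient is a quick way to spot swapped interpolation weights.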


## Residual lookup
In this method, we return two tensors: the result and the residue. Although the result is not differentiable wrt the index, the residue is. This lets us propagate some extra information in parallel, which a downstream algorithm can then consume. The benefit is that the result always exists in the original array and is never an interpolation.

In [5]:
@tf.function
def residual_lookup(arr, index):
    i = tf.round(index)
    residue = index - i
    i = tf.cast(i, tf.int32)
    
    result = arr[i]
    
    return result, residue

arr = tf.Variable([1,2,3,4,5],dtype=tf.float32)
index = tf.Variable(1.5, dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
    result, residue = residual_lookup(arr, index)

tf.print(result, residue)
tf.print(tape.gradient(result, arr), tape.gradient(residue, arr))
tf.print(tape.gradient(result, index), tape.gradient(residue, index))

3 -0.5
[0 0 1 0 0] [0 0 0 0 0]
0 1
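
How the residue is consumed depends on the problem. One hypothetical option, sketched here in plain Python, is to penalise the squared residue so that an optimiser is nudged towards integer index values while the looked-up result stays an actual array element. The penalty term is an assumption made for illustration; the notebook does not prescribe any particular downstream use.

```python
def residual_lookup(arr, index):
    i = round(index)     # rounds halves to even, like tf.round
    residue = index - i  # signed offset from the chosen slot, in [-0.5, 0.5]
    return arr[i], residue

arr = [1.0, 2.0, 3.0, 4.0, 5.0]
result, residue = residual_lookup(arr, 1.5)

# Hypothetical downstream term: the squared offset from the chosen slot.
# Since d(residue)/d(index) = 1, its derivative wrt index is 2 * residue.
penalty = residue ** 2
penalty_grad = 2 * residue

print(result, residue, penalty, penalty_grad)
```

With index 1.5 the chosen slot is 2 and the penalty gradient is negative, so a gradient-descent step would increase the index towards 2.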


## Array assignment
TensorFlow tensors do not support direct index assignment (`arr[i] = x`), so we use a masking technique instead.

In [6]:
@tf.function
def assign_index(arr, index, element):
    arr_shape = tf.shape(arr)
    
    pos_mask = tf.eye(arr_shape[0])[index]
    pos_mask = tf.transpose(tf.expand_dims(pos_mask, 0))
    neg_mask = 1 - pos_mask
    
    tiled_element = tf.reshape(tf.tile(element, [arr_shape[0]]), arr_shape)
    
    arr = arr * neg_mask + tiled_element * pos_mask
    
    return arr

arr = tf.Variable([
    [1,1,1],
    [2,2,2],
    [3,3,3]
],dtype=tf.float32)
index = tf.constant(1)
element = tf.Variable([4,4,4],dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
    new_arr = assign_index(arr, index, element)
    
print(new_arr)
print(tape.gradient(new_arr, arr))
print(tape.gradient(new_arr, index))
print(tape.gradient(new_arr, element))

tf.Tensor(
[[1. 1. 1.]
 [4. 4. 4.]
 [3. 3. 3.]], shape=(3, 3), dtype=float32)
tf.Tensor(
[[1. 1. 1.]
 [0. 0. 0.]
 [1. 1. 1.]], shape=(3, 3), dtype=float32)
None
tf.Tensor([1. 1. 1.], shape=(3,), dtype=float32)
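
The masking formula is not restricted to a one-hot mask. If the mask is a distribution over rows, for example the output of a softmax as in the superposition lookup, the same blend performs a soft write whose position is also differentiable. The write operation in the Neural Turing Machines paper has a similar erase-then-add shape. The following plain-Python sketch illustrates the arithmetic only; `blend_write` is an illustrative name, not part of the notebook.

```python
def blend_write(arr, mask, element):
    # Row r becomes (1 - mask[r]) * arr[r] + mask[r] * element.
    return [
        [(1 - m) * a + m * e for a, e in zip(row, element)]
        for row, m in zip(arr, mask)
    ]

arr = [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]]
element = [4.0, 4.0, 4.0]

# One-hot mask: row 1 is overwritten, the other rows are untouched.
hard = blend_write(arr, [0.0, 1.0, 0.0], element)

# Soft mask: rows 1 and 2 each move halfway towards the element.
soft = blend_write(arr, [0.0, 0.5, 0.5], element)

print(hard)
print(soft)
```

With a one-hot mask this reproduces the hard assignment shown above; with a soft mask, the written values are blends and may not match `element` exactly.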
