# Differentiable indexing of arrays
Since, table or array lookup is an inherently non-differentiable process, the autograd is unable to resolve gradients of the result with respect to the output.

Here we have some strategies which gives a valid gradient with respect to both the index and input array. The choice of the strategy is problem specific. Also, this is by no means an exhaustive list. You can help by expanding it.

Abbreviation: WRT = With respect to

Further Reading: [Neural Turing Machines](https://arxiv.org/abs/1410.5401)

In [1]:
import tensorflow as tf

## Naive lookup
Naive lookup does produce a gradient wrt its input array but not wrt the index.

In [2]:
@tf.function
def naive_lookup(arr, index):
    index = tf.round(index)
    index = tf.cast(index, tf.int32)
    result = arr[index]
    return result

arr = tf.Variable([1,2,3,4,5],dtype=tf.float32)
index = tf.Variable(1.5, dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
    z = naive_lookup(arr, index)

print(z)
print(tape.gradient(z, arr))
print(tape.gradient(z, index))

tf.Tensor(3.0, shape=(), dtype=float32)
tf.Tensor([0. 0. 1. 0. 0.], shape=(5,), dtype=float32)
None


## Linear lookup
In this method we use [linear interpolation](https://en.wikipedia.org/wiki/Linear_interpolation) to interpolate between the two nearest candidates. For 2D arrays, [Bilinear interpolation](https://en.wikipedia.org/wiki/Bilinear_interpolation) can be used.

This gives a well defined gradient wrt to both the input and index. However, it is a soft lookup and can return values not present in the array itself.

One of the downsides of using this method is that it can only lookup adjacent cells in the [number line](https://en.wikipedia.org/wiki/Number_line).

![number line](images/1125px-Number-line.svg.png)


In [3]:
@tf.function
def interp_factor(index):
    t1 = tf.math.floor(index)
    t2 = tf.math.ceil(index)
    
    t = tf.math.divide_no_nan((index - t1), (t2 - t1))
    
    i1 = tf.cast(t1, tf.int32)
    i2 = tf.cast(t2, tf.int32)
    
    return t, i1, i2

In [4]:
@tf.function
def linear_lookup(arr, index):
    t, i1, i2 = interp_factor(index)
    
    # Linear interpolation
    result = t * arr[i1] + (1 - t) * arr[i2]
    
    return result

arr = tf.Variable([1,2,3,4,5],dtype=tf.float32)
index = tf.Variable(1.5, dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
    z = linear_lookup(arr, index)

print(z)
print(tape.gradient(z, arr))
print(tape.gradient(z, index))

tf.Tensor(2.5, shape=(), dtype=float32)
tf.Tensor([0.  0.5 0.5 0.  0. ], shape=(5,), dtype=float32)
tf.Tensor(-1.0, shape=(), dtype=float32)


## Superposition lookup
In this method, we have a distribution instead of an integer index. This distribution usually comes after a softmax operation. The result is the dot product of the index and the input array. This is a very popular method in DNN literature.

In [5]:
@tf.function
def superposition_lookup_vectored(arr, indices):
    if tf.rank(arr) == 1:
        arr = tf.expand_dims(arr, -1)
    indices = tf.expand_dims(indices, -1)
    result = arr * indices
    return tf.reduce_sum(result, axis=0)

arr = tf.Variable([
    [0, 0, 1, 0, 0],
    [0, 0, 0, 0, 1],
    [0, 1, 0, 0, 0],
    [1, 0, 0, 0, 0],
    [0, 0, 0, 1, 0],
],dtype=tf.float32)
indices = tf.Variable([0.5, 0.5, 0, 0, 0], dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
#     indices = tf.nn.softmax(indices)
    z = superposition_lookup_vectored(arr, indices)

print(z)
print(tape.gradient(z, arr))
print(tape.gradient(z, indices))

tf.Tensor([0.  0.  0.5 0.  0.5], shape=(5,), dtype=float32)
tf.Tensor(
[[0.5 0.5 0.5 0.5 0.5]
 [0.5 0.5 0.5 0.5 0.5]
 [0.  0.  0.  0.  0. ]
 [0.  0.  0.  0.  0. ]
 [0.  0.  0.  0.  0. ]], shape=(5, 5), dtype=float32)
tf.Tensor([1. 1. 1. 1. 1.], shape=(5,), dtype=float32)


In [6]:
arr = tf.Variable([1,2,3,5,4],dtype=tf.float32)
indices = tf.Variable([0.0, 0.1, 0.8, 0.0, 0.1], dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
#     indices = tf.nn.softmax(indices)
    z = superposition_lookup_vectored(arr, indices)

print(z)
print(tape.gradient(z, arr))
print(tape.gradient(z, indices))

tf.Tensor([3.0000002], shape=(1,), dtype=float32)
tf.Tensor([0.  0.1 0.8 0.  0.1], shape=(5,), dtype=float32)
tf.Tensor([1. 2. 3. 5. 4.], shape=(5,), dtype=float32)


In [7]:
@tf.function
def bandwidthify(index, bandwidth):
    t, i1, i2 = interp_factor(index)
    
    # Prevent array out of bounds
    i1 = tf.clip_by_value(i1, 0, bandwidth - 1)
    i2 = tf.clip_by_value(i2, 0, bandwidth - 1)
    t = tf.clip_by_value(t, 0, 1)
    
    # Linear interpolation
    eye = tf.eye(bandwidth)
    result = t * eye[i1] + (1 - t) * eye[i2]
    
    return result

index = tf.Variable(2.5, dtype=tf.float32)
bandwidth = tf.constant(5, dtype=tf.int32)
dummy_array = tf.cast(tf.range(bandwidth), tf.float32)
with tf.GradientTape() as tape:
    z = bandwidthify(index, bandwidth)
    nz = superposition_lookup_vectored(dummy_array, z) # Lookup operation

print(z)
print(nz)
print(tape.gradient(nz, index))

tf.Tensor([0.  0.  0.5 0.5 0. ], shape=(5,), dtype=float32)
tf.Tensor([2.5], shape=(1,), dtype=float32)
tf.Tensor(-1.0, shape=(), dtype=float32)


In [8]:
@tf.function
def superposition_lookup(arr, index):
    bandwidth = tf.shape(arr)[0]
    vectored_index = bandwidthify(index, bandwidth)
    result = superposition_lookup_vectored(arr, vectored_index)
    
    return result

arr = tf.Variable([1,2,3,4,5],dtype=tf.float32)
index = tf.Variable(1.5, dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
    z = superposition_lookup(arr, index)

print(z)
print(tape.gradient(z, arr))
print(tape.gradient(z, index))

tf.Tensor([2.5], shape=(1,), dtype=float32)
tf.Tensor([0.  0.5 0.5 0.  0. ], shape=(5,), dtype=float32)
tf.Tensor(-1.0, shape=(), dtype=float32)


In [9]:
@tf.function
def bulk_bandwidthify(indices, bandwidth):
    num_indices = tf.shape(indices)[0]
    
    indices = tf.unstack(indices)
    result = tf.zeros((num_indices, bandwidth), dtype=tf.float32)
    result = tf.unstack(result)
    
    for i, index in enumerate(indices):
        b_index = bandwidthify(index, bandwidth)
        result[i] += b_index
    
    result = tf.stack(result)
    return result

indices = tf.Variable([1,2,3.5,0,4],dtype=tf.float32)
bandwidth = tf.constant(5, dtype=tf.int32)
dummy_array = tf.cast(tf.range(bandwidth), tf.float32)
with tf.GradientTape() as tape:
    z = bulk_bandwidthify(indices, bandwidth)
    nz = superposition_lookup_vectored(dummy_array, z) # Lookup operation

print(z)
print(nz)
print(tape.gradient(nz, indices))

tf.Tensor(
[[0.  1.  0.  0.  0. ]
 [0.  0.  1.  0.  0. ]
 [0.  0.  0.  0.5 0.5]
 [1.  0.  0.  0.  0. ]
 [0.  0.  0.  0.  1. ]], shape=(5, 5), dtype=float32)
tf.Tensor(
[[0. ]
 [1. ]
 [2. ]
 [1.5]
 [6. ]], shape=(5, 1), dtype=float32)
tf.Tensor([ 0.  0. -1.  0.  0.], shape=(5,), dtype=float32)


## Residual lookup
In this method, we return two tensors, the result and the residue. So, although the result is not differentiable wrt to index, the residue is. This allows us to propagate some extra information in parallel which can then be consumed intelligently by some algorithm in downstream. This has the benefit that the result always exists in the original array and is never an interpolation.

In [10]:
@tf.function
def residual_lookup(arr, index):
    i = tf.round(index)
    residue = index - i
    i = tf.cast(i, tf.int32)
    
    result = arr[i]
    
    return result, residue

arr = tf.Variable([1,2,3,4,5],dtype=tf.float32)
index = tf.Variable(1.5, dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
    result, residue = residual_lookup(arr, index)

tf.print(result, residue)
tf.print(tape.gradient(result, arr), tape.gradient(residue, arr))
tf.print(tape.gradient(result, index), tape.gradient(residue, index))

3 -0.5
[0 0 1 0 0] [0 0 0 0 0]
0 1


## Array assignment
Tensorflow does not support direct index assignment of variables. So, instead we use a masking technique

In [11]:
@tf.function
def assign_index(arr, index, element):
    arr_shape = tf.shape(arr)
    
    pos_mask = tf.eye(arr_shape[0])[index]
    pos_mask = tf.transpose(tf.expand_dims(pos_mask, 0))
    neg_mask = 1 - pos_mask
    
    tiled_element = tf.reshape(tf.tile(element, [arr_shape[0]]), arr_shape)
    
    arr = arr * neg_mask + tiled_element * pos_mask
    
    return arr

arr = tf.Variable([
    [1,1,1],
    [2,2,2],
    [3,3,3]
],dtype=tf.float32)
index = tf.constant(1)
element = tf.Variable([4,4,4],dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
    new_arr = assign_index(arr, index, element)
    
print(new_arr)
print(tape.gradient(new_arr, arr))
print(tape.gradient(new_arr, index))
print(tape.gradient(new_arr, element))

tf.Tensor(
[[1. 1. 1.]
 [4. 4. 4.]
 [3. 3. 3.]], shape=(3, 3), dtype=float32)
tf.Tensor(
[[1. 1. 1.]
 [0. 0. 0.]
 [1. 1. 1.]], shape=(3, 3), dtype=float32)
None
tf.Tensor([1. 1. 1.], shape=(3,), dtype=float32)


Superpositioned assignment in case of a vector like index

In [12]:
@tf.function
def assign_index_vectored(arr, index, element):
    arr_shape = tf.shape(arr)
    
    pos_mask = tf.transpose(tf.expand_dims(index, 0))
    neg_mask = 1 - pos_mask
    
    tiled_element = tf.reshape(tf.tile(element, [arr_shape[0]]), arr_shape)

    arr = arr * neg_mask + tiled_element * pos_mask
    
    return arr

arr = tf.Variable([
    [1,1,1],
    [2,2,2],
    [3,3,3]
],dtype=tf.float32)
index1 = tf.Variable([0,1,0], dtype=tf.float32)
index2 = tf.Variable([0.5,0.5,0], dtype=tf.float32)
element = tf.Variable([4,4,4],dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
    new_arr1 = assign_index_vectored(arr, index1, element)
    new_arr2 = assign_index_vectored(arr, index2, element)

tf.print(new_arr1, tape.gradient(new_arr1, index1))
tf.print(new_arr2, tape.gradient(new_arr2, index2))
print(tape.gradient(new_arr1, arr))
print(tape.gradient(new_arr1, element))

[[1 1 1]
 [4 4 4]
 [3 3 3]] [9 6 3]
[[2.5 2.5 2.5]
 [3 3 3]
 [3 3 3]] [9 6 3]
tf.Tensor(
[[1. 1. 1.]
 [0. 0. 0.]
 [1. 1. 1.]], shape=(3, 3), dtype=float32)
tf.Tensor([1. 1. 1.], shape=(3,), dtype=float32)


## Higher dimensional arrays

In [13]:
@tf.function
def match_shapes(x, y):
    # Find which one needs to be broadcasted
    low, high = (y, x) if tf.rank(x) > tf.rank(y) else (x, y)
    l_rank, l_shape = tf.rank(low), tf.shape(low)
    h_rank, h_shape = tf.rank(high), tf.shape(high)
    
    # Find the difference in ranks
    common_shape = h_shape[:l_rank]
    tf.debugging.assert_equal(common_shape, l_shape, 'No common shape to broadcast')
    padding = tf.ones(h_rank - l_rank, dtype=tf.int32)
    
    # Pad the difference with ones and reshape
    new_shape = tf.concat((common_shape, padding),axis=0)
    low = tf.reshape(low, new_shape)

    return high, low

@tf.function
def broadcast_multiply(x, y):
    x, y = match_shapes(x, y)
    return x * y
    
x = tf.ones((3, 3, 2)) * 3
y = tf.ones((3, 3)) * 2
broadcast_multiply(x, y)

<tf.Tensor: shape=(3, 3, 2), dtype=float32, numpy=
array([[[6., 6.],
        [6., 6.],
        [6., 6.]],

       [[6., 6.],
        [6., 6.],
        [6., 6.]],

       [[6., 6.],
        [6., 6.],
        [6., 6.]]], dtype=float32)>

In [14]:
@tf.function
def tensor_lookup_2d(arr, x_index, y_index):
    # Calculate outer product
    mask = tf.tensordot(x_index, y_index, axes=0)
    
    # Broadcast the mask to match dimensions with arr
    masked_arr = broadcast_multiply(mask, arr)
    
    # Reduce max to extract the cell
    element = tf.math.reduce_max(masked_arr, axis=[0,1])
    return element

arr = tf.Variable([
    [[1,1],[1,11],[1,111]],
    [[2,2],[2,22],[2,222]],
    [[3,3],[3,33],[3,333]]
],dtype=tf.float32)
x_index = tf.Variable(tf.one_hot(1, 3),dtype=tf.float32)
y_index = tf.Variable(tf.one_hot(2, 3),dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
    element = tensor_lookup_2d(arr, x_index, y_index)
    
print(element)
print(tape.gradient(element, arr))
print(tape.gradient(element, x_index))
print(tape.gradient(element, y_index))

tf.Tensor([  2. 222.], shape=(2,), dtype=float32)
tf.Tensor(
[[[0. 0.]
  [0. 0.]
  [0. 0.]]

 [[0. 0.]
  [0. 0.]
  [1. 1.]]

 [[0. 0.]
  [0. 0.]
  [0. 0.]]], shape=(3, 3, 2), dtype=float32)
tf.Tensor([  0. 224.   0.], shape=(3,), dtype=float32)
tf.Tensor([  0.   0. 224.], shape=(3,), dtype=float32)


In [15]:
@tf.function
def tensor_write_2d(arr, element, x_index, y_index):
    arr_shape = tf.shape(arr)
    mask = tf.tensordot(x_index, y_index, axes=0)
    
    # Broadcast the mask to match dimensions with arr
    _, mask = match_shapes(arr, mask)
    
    element = tf.reshape(element,[1,1,-1])
    element = tf.tile(element, [arr_shape[0], arr_shape[1], 1])
    
    result = (1.0 - mask) * arr + mask * element
    
    return result

arr = tf.Variable([
    [[1,1],[1,11],[1,111]],
    [[2,2],[2,22],[2,222]],
    [[3,3],[3,33],[3,333]]
],dtype=tf.float32)
element = tf.Variable([5,555], dtype=tf.float32)
x_index = tf.Variable(tf.one_hot(1, 3),dtype=tf.float32)
y_index = tf.Variable(tf.one_hot(2, 3),dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
    new_arr = tensor_write_2d(arr, element, x_index, y_index)
    
print(new_arr)
print(tape.gradient(new_arr, arr))
print(tape.gradient(new_arr, element))
print(tape.gradient(new_arr, x_index))
print(tape.gradient(new_arr, y_index))

tf.Tensor(
[[[  1.   1.]
  [  1.  11.]
  [  1. 111.]]

 [[  2.   2.]
  [  2.  22.]
  [  5. 555.]]

 [[  3.   3.]
  [  3.  33.]
  [  3. 333.]]], shape=(3, 3, 2), dtype=float32)
tf.Tensor(
[[[1. 1.]
  [1. 1.]
  [1. 1.]]

 [[1. 1.]
  [1. 1.]
  [0. 0.]]

 [[1. 1.]
  [1. 1.]
  [1. 1.]]], shape=(3, 3, 2), dtype=float32)
tf.Tensor([1. 1.], shape=(2,), dtype=float32)
tf.Tensor([448. 336. 224.], shape=(3,), dtype=float32)
tf.Tensor([556. 536. 336.], shape=(3,), dtype=float32)
