# Differentiable Array Indexing

Array indexing is inherently non-differentiable in traditional programming because the output is a piecewise-constant function of a discrete index: it stays flat between boundaries and jumps when the index crosses one. This notebook explores various strategies to make array lookups differentiable while maintaining gradient flow through both the array and the index.

## The Core Problem

Traditional array lookup `arr[index]` breaks differentiability because:
- **Discrete indices**: Small changes in the index don't affect the output until the rounded index moves to the next integer
- **Zero gradients**: The gradient w.r.t. the index is zero almost everywhere
- **Discontinuous**: The function jumps wherever the rounded index changes

## Why This Matters

Differentiable indexing enables:
- **Learnable attention mechanisms** in neural networks
- **Soft addressing** in memory-augmented models  
- **Gradient-based optimization** of indexing operations
- **End-to-end training** of algorithms that use array lookups

## Strategies Overview

We'll explore several approaches, each with different trade-offs:

1. **Naive Lookup**: Standard indexing (baseline, not fully differentiable)
2. **Linear Interpolation**: Smooth interpolation between adjacent elements
3. **Superposition Lookup**: Probabilistic weighted combination of all elements
4. **Residual Lookup**: Separate result and continuous residue
5. **Asymmetric Lookup**: Different forward/backward behaviors
6. **Array Assignment**: Differentiable element updates

**Note**: The choice of strategy depends on your specific application requirements.

**Further Reading**: [Neural Turing Machines](https://arxiv.org/abs/1410.5401) pioneered many of these techniques.

In [1]:
import tensorflow as tf
from library.statistical_math import to_prob_dist_all, entropy

## Strategy 1: Naive Lookup (Baseline)

Standard array indexing rounds the index to the nearest integer (`tf.round` rounds half to even, so `1.5` becomes `2`). This approach:
- ✅ **Gradient w.r.t. array**: Well-defined (one-hot vector)
- ❌ **Gradient w.r.t. index**: None (due to rounding operation)
- ✅ **Exact values**: Returns actual array elements
- ❌ **Limited differentiability**: Can't optimize the index
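A quick numerical check makes the missing index gradient concrete. The sketch below is plain Python rather than TensorFlow, and `naive_lookup` and `finite_diff` here are illustrative stand-ins, not the notebook's functions: it takes central finite differences of a rounding lookup.

```python
def naive_lookup(arr, index):
    # Round to the nearest integer (Python's round, like tf.round, sends ties to even)
    return arr[int(round(index))]

def finite_diff(f, x, eps=1e-3):
    # Central-difference approximation of df/dx
    return (f(x + eps) - f(x - eps)) / (2 * eps)

arr = [1.0, 2.0, 3.0, 4.0, 5.0]
lookup = lambda i: naive_lookup(arr, i)

# Away from a rounding boundary the output is flat, so the slope is zero
print(finite_diff(lookup, 1.2))  # 0.0
# Straddling the boundary at 1.5 the difference quotient blows up instead
print(finite_diff(lookup, 1.5))  # about 500
```

The slope is zero almost everywhere and undefined at the jumps, so a gradient-based optimizer gets no signal about which way to move the index.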

## Naive lookup
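Naive lookup produces a gradient w.r.t. its input array but not w.r.t. the index: casting the rounded index to `tf.int32` cuts the gradient path, so `tape.gradient(z, index)` returns `None`.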

## Strategy 2: Linear Interpolation Lookup

Linear interpolation creates a smooth, differentiable lookup by blending between adjacent array elements. This is the 1D version of [bilinear interpolation](https://en.wikipedia.org/wiki/Bilinear_interpolation) used in image processing.

### Mathematical Foundation

For index $i$ with fractional part, we interpolate between `floor(i)` and `ceil(i)`:

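$$\text{result} = (1-t) \cdot \text{arr}[\lfloor i \rfloor] + t \cdot \text{arr}[\lceil i \rceil]$$

where $t = \frac{i - \lfloor i \rfloor}{\lceil i \rceil - \lfloor i \rfloor}$ is the interpolation factor (taken as $t = 0$ when $i$ is an integer, since then $\lfloor i \rfloor = \lceil i \rceil$). Inside an interval the index gradient is $\partial\,\text{result} / \partial i = \text{arr}[\lceil i \rceil] - \text{arr}[\lfloor i \rfloor]$.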
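To sanity-check the formula, here is a framework-free sketch in plain Python. It is not the notebook's TensorFlow code, and the function name is illustrative. It evaluates the interpolation together with its hand-derived gradients, using the convention that the floor element gets weight $1-t$ and an integer index gets $t = 0$.

```python
import math

def linear_lookup(arr, index):
    """Linear interpolation lookup with hand-derived gradients (plain-Python sketch)."""
    lo, hi = math.floor(index), math.ceil(index)
    # Interpolation factor; an integer index (lo == hi) gets t = 0, like divide_no_nan
    t = (index - lo) / (hi - lo) if hi != lo else 0.0
    value = (1 - t) * arr[lo] + t * arr[hi]
    # d(value)/d(arr): weight 1 - t on arr[lo] and weight t on arr[hi]
    grad_arr = [0.0] * len(arr)
    grad_arr[lo] += 1 - t
    grad_arr[hi] += t
    # d(value)/d(index) = arr[hi] - arr[lo] inside an interval (hi - lo == 1)
    grad_index = arr[hi] - arr[lo] if hi != lo else 0.0
    return value, grad_arr, grad_index

arr = [1.0, 2.0, 3.0, 4.0, 5.0]
print(linear_lookup(arr, 1.5))  # (2.5, [0.0, 0.5, 0.5, 0.0, 0.0], 1.0)
print(linear_lookup(arr, 1.2))  # value about 2.2, closer to arr[1] than to arr[2]
```

A positive index gradient on an increasing array is the expected sign: moving the index to the right increases the looked-up value.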

### Properties
- ✅ **Fully differentiable**: Gradients w.r.t. both array and index
- ✅ **Smooth transitions**: No discontinuities
- ❌ **Soft values**: May return values not in the original array
- ❌ **Local scope**: Only uses adjacent elements

### Limitation
This method can only interpolate between neighboring elements on the [number line](https://en.wikipedia.org/wiki/Number_line), making it unsuitable for non-local lookups.

![number line](images/1125px-Number-line.svg.png)

In [2]:
@tf.function
def naive_lookup(arr, index):
    index = tf.round(index)
    index = tf.cast(index, tf.int32)
    result = arr[index]
    return result

arr = tf.Variable([1,2,3,4,5],dtype=tf.float32)
index = tf.Variable(1.5, dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
    z = naive_lookup(arr, index)

print(z)
print(tape.gradient(z, arr))
print(tape.gradient(z, index))

tf.Tensor(3.0, shape=(), dtype=float32)
tf.Tensor([0. 0. 1. 0. 0.], shape=(5,), dtype=float32)
None


### Interpolation Factor Calculation

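The `interp_factor` function computes the interpolation factor `t` together with the two neighboring integer indices `i1 = floor(index)` and `i2 = ceil(index)`. It uses `tf.math.divide_no_nan`, so an integer index (where `floor == ceil`) yields `t = 0` instead of NaN.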

## Strategy 3: Superposition Lookup

This method uses a probability distribution over all array elements rather than a single index. The result is a weighted combination of all elements, making it fully differentiable and commonly used in attention mechanisms.

### Mathematical Foundation

Given an array $\mathbf{a}$ and probability distribution $\mathbf{p}$:

$$\text{result} = \sum_{i=1}^{n} p_i \cdot a_i = \mathbf{p}^T \mathbf{a}$$
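The following plain-Python sketch implements this formula and the two gradients it implies: $\partial\,\text{result}/\partial a_i = p_i$ and $\partial\,\text{result}/\partial p_i = a_i$. The function name is illustrative, and the values are the 1D example used in this notebook.

```python
def superposition_lookup(a, p):
    """Soft lookup: result = sum_i p_i * a_i, with its gradients (plain-Python sketch)."""
    assert len(a) == len(p)
    result = sum(pi * ai for pi, ai in zip(p, a))
    grad_a = list(p)  # d(result)/d(a_i) = p_i
    grad_p = list(a)  # d(result)/d(p_i) = a_i
    return result, grad_a, grad_p

a = [1.0, 2.0, 3.0, 5.0, 4.0]
p = [0.0, 0.1, 0.8, 0.0, 0.1]
result, grad_a, grad_p = superposition_lookup(a, p)
print(round(result, 6), grad_a, grad_p)
# 3.0 [0.0, 0.1, 0.8, 0.0, 0.1] [1.0, 2.0, 3.0, 5.0, 4.0]
```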

### Properties
- ✅ **Fully differentiable**: Smooth gradients everywhere
- ✅ **Global scope**: Can access any combination of elements
- ✅ **Attention-like**: Widely used in neural networks
- ❌ **Soft values**: Result may not exist in original array
- ❌ **Computational cost**: Considers all elements

This is the foundation of attention mechanisms in transformers and memory-augmented neural networks!

## Linear lookup
In this method we use [linear interpolation](https://en.wikipedia.org/wiki/Linear_interpolation) to interpolate between the two nearest candidates. For 2D arrays, [bilinear interpolation](https://en.wikipedia.org/wiki/Bilinear_interpolation) can be used.

This gives a well-defined gradient w.r.t. both the input array and the index. However, it is a soft lookup and can return values not present in the array itself.

One downside of this method is that it can only look up adjacent cells on the [number line](https://en.wikipedia.org/wiki/Number_line).



### Multi-dimensional Array Example

Let's test with a 2D array where each row is a different vector. The indices `[0.5, 0.5, 0, 0, 0]` mean we're equally weighting the first two rows:

### 1D Array Example

With a 1D array, the superposition lookup computes: `0.0×1 + 0.1×2 + 0.8×3 + 0.0×5 + 0.1×4 = 3.0`

The gradients show:
- **Array gradient**: The weights used for each element
- **Index gradient**: The array values themselves (showing sensitivity to weight changes)

### Converting Scalar Index to Distribution

The `bandwidthify` function converts a scalar index into a probability distribution using linear interpolation, then applies it via superposition lookup:
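As a plain-Python sketch of the same idea, here is an illustrative re-implementation. It is not the notebook's TensorFlow `bandwidthify`. The scalar index becomes a weight vector with $1-t$ on the floor position and $t$ on the ceil position, and both positions are clipped to the valid range. Applying the lookup once per index also shows how a batch of indices should be handled: one weighted sum per weight vector.

```python
import math

def bandwidthify(index, bandwidth):
    """Scalar index -> weight vector of length `bandwidth` (plain-Python sketch)."""
    # Clip both neighbours into [0, bandwidth - 1] to avoid out-of-bounds positions
    i1 = min(max(math.floor(index), 0), bandwidth - 1)
    i2 = min(max(math.ceil(index), 0), bandwidth - 1)
    # Fractional part of the index (0 for an integer index), clipped to [0, 1]
    t = min(max(index - math.floor(index), 0.0), 1.0)
    weights = [0.0] * bandwidth
    weights[i1] += 1 - t  # floor neighbour
    weights[i2] += t      # ceil neighbour (same slot as i1 when they coincide)
    return weights

def superposition_lookup(arr, weights):
    # Dot product of the weight vector with the array
    return sum(w * a for w, a in zip(weights, arr))

arr = [1.0, 2.0, 3.0, 4.0, 5.0]
print(bandwidthify(1.5, 5))                             # [0.0, 0.5, 0.5, 0.0, 0.0]
print(superposition_lookup(arr, bandwidthify(1.5, 5)))  # 2.5, same as linear interpolation

# A batch of indices needs one weighted sum per index
dummy = [0.0, 1.0, 2.0, 3.0, 4.0]
print([superposition_lookup(dummy, bandwidthify(i, 5)) for i in [1, 2, 3.5, 0, 4]])
# [1.0, 2.0, 3.5, 0.0, 4.0]
```

Because both neighbours are clipped into the same range, the weights always sum to 1, even for an index past the end of the array.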

### Unified Superposition Lookup

This combines the best of both worlds: scalar index input with the power of superposition lookup. The result matches our linear interpolation example, showing the equivalence of the approaches for local lookups.

## Strategy 4: Residual Lookup

This approach separates the lookup into two parts: a discrete result (exact array element) and a continuous residue (fractional remainder). This preserves exact array values while maintaining some differentiability.

### Key Insight

- **Result**: Always an exact array element (maintains discrete behavior)
- **Residue**: The fractional part of the index (fully differentiable)
- **Use case**: When you need exact values but want to propagate some continuous information
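The split into result and residue takes only a few lines of plain Python. The sketch below is illustrative, not the notebook's code. Python's built-in `round` also rounds half to even, so it agrees with `tf.round` on ties.

```python
def residual_lookup(arr, index):
    """Return an exact array element plus the continuous remainder of the index."""
    i = round(index)     # nearest integer, ties to even (round(1.5) == 2, round(2.5) == 2)
    residue = index - i  # lies in [-0.5, 0.5]; d(residue)/d(index) = 1
    return arr[i], residue

arr = [1.0, 2.0, 3.0, 4.0, 5.0]
print(residual_lookup(arr, 1.5))  # (3.0, -0.5)
print(residual_lookup(arr, 2.5))  # (3.0, 0.5): the tie goes to the even index 2
```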

### Properties
- ✅ **Exact values**: Result is always from original array
- ✅ **Partial differentiability**: Residue provides gradient information
- ❌ **Limited optimization**: Can't directly optimize array access
- ✅ **Information preservation**: Residue can be used by downstream operations

In [3]:
@tf.function
def interp_factor(index):
    t1 = tf.math.floor(index)
    t2 = tf.math.ceil(index)
    
    t = tf.math.divide_no_nan((index - t1), (t2 - t1))
    
    i1 = tf.cast(t1, tf.int32)
    i2 = tf.cast(t2, tf.int32)
    
    return t, i1, i2

## Strategy 5: Asymmetric Vectored Lookup

This advanced technique uses different behaviors for forward and backward passes, enabling discrete forward computation with custom gradient behavior.

### Algorithm

**Forward Pass**:
1. Find the most likely index (argmax of the probability vector)
2. Return the value at that exact index (discrete, exact)

**Backward Pass**:
1. Estimate target value: `target = forward_result - upstream_gradients`
2. Find array element closest to target (argmin of squared differences)
3. Create gradient that encourages the "correct" index and discourages others

### Mathematical Formulation

For probability vector $\mathbf{k}$ and value array $\mathbf{v}$:

**Forward**: $\text{result} = v_{\arg\max(\mathbf{k})}$

**Backward**: 
- Target: $t = \text{result} - \nabla_{\text{upstream}}$
- Best index: $j = \arg\min_i |v_i - t|^2$  
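- Gradient w.r.t. $\mathbf{k}$: $\nabla k_i = |\nabla_{\text{upstream}}| \cdot \begin{cases} -1 & \text{if } i = j \\ +1 & \text{otherwise} \end{cases}$
- Gradient w.r.t. $\mathbf{v}$: $\nabla v_i = \nabla_{\text{upstream}} \cdot k_i$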
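The backward rule above can be traced by hand with the following plain-Python sketch. The helper names are illustrative. The upstream gradient here is the one `tf.nn.l2_loss(r - t)` produces, namely $r - t$, and the two rows are the notebook's test case.

```python
def argmax(xs):
    # First index of the maximum (ties resolve to the first occurrence)
    return max(range(len(xs)), key=lambda i: xs[i])

def argmin(xs):
    # First index of the minimum
    return min(range(len(xs)), key=lambda i: xs[i])

def forward(v, k):
    # Forward pass: exact value at the most likely index
    return v[argmax(k)]

def k_gradient(v, result, upstream):
    # Target = one gradient-descent step (unit step size) on the output
    target = result - upstream
    j = argmin([(vi - target) ** 2 for vi in v])
    # -|g| at the closest element, +|g| elsewhere; the optimizer subtracts this
    return [-abs(upstream) if i == j else abs(upstream) for i in range(len(v))]

rows = [([1.0, 2.0, 3.0], [0.0, 1.0, 0.0], 3.0),
        ([10.0, 20.0, 30.0], [1.0, 0.0, 0.0], 20.0)]
for v, k, t in rows:
    r = forward(v, k)
    upstream = r - t  # d/dr of 0.5 * (r - t) ** 2
    print(r, k_gradient(v, r, upstream))
# 2.0 [1.0, 1.0, -1.0]
# 10.0 [10.0, -10.0, 10.0]
```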

### Properties
- ✅ **Exact forward values**: Always returns actual array elements
- ✅ **Trainable**: Custom gradients enable optimization
- ✅ **Task-adaptive**: Backward pass optimizes for downstream objectives
- ❌ **Complex**: Requires careful gradient design
- ❌ **Asymmetric**: Forward ≠ backward behavior

In [4]:
@tf.function
def linear_lookup(arr, index):
    t, i1, i2 = interp_factor(index)
    
    # Linear interpolation: weight (1 - t) on arr[i1] (floor), t on arr[i2] (ceil)
    result = (1 - t) * arr[i1] + t * arr[i2]
    
    return result

arr = tf.Variable([1,2,3,4,5],dtype=tf.float32)
index = tf.Variable(1.5, dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
    z = linear_lookup(arr, index)

print(z)
print(tape.gradient(z, arr))
print(tape.gradient(z, index))

tf.Tensor(2.5, shape=(), dtype=float32)
tf.Tensor([0.  0.5 0.5 0.  0. ], shape=(5,), dtype=float32)
tf.Tensor(1.0, shape=(), dtype=float32)


### Testing Asymmetric Lookup

Let's test the asymmetric lookup with two examples:
- Row 1: `k=[0,1,0]` selects `v=[1,2,3]` → chooses `v[1]=2`
- Row 2: `k=[1,0,0]` selects `v=[10,20,30]` → chooses `v[0]=10`

The custom gradients will be computed based on the loss and target values.

## Strategy 6: Differentiable Array Assignment

Traditional array assignment `arr[index] = value` is also non-differentiable. We can make it differentiable using masking techniques that blend between the old and new array states.

### Core Concept

Instead of discrete assignment, we use weighted combination:
$$\text{new\_arr} = \text{mask} \cdot \text{new\_value} + (1 - \text{mask}) \cdot \text{old\_arr}$$

where the mask determines which elements to update.
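The following plain-Python sketch shows the blend. It is illustrative, and the mask holds one weight per row, as in the notebook's vectored assignment. It also computes the gradient of the sum of all outputs w.r.t. the mask, which is $\sum_j (\text{new\_value}_j - \text{old\_arr}_{ij})$ for row $i$.

```python
def soft_assign(arr, mask, new_value):
    """Row-wise blend: mask[i] * new_value + (1 - mask[i]) * arr[i]."""
    return [[m * nv + (1 - m) * old for nv, old in zip(new_value, row)]
            for m, row in zip(mask, arr)]

def mask_gradient(arr, new_value):
    # d(sum of all outputs)/d(mask[i]) = sum_j (new_value[j] - arr[i][j])
    return [sum(nv - old for nv, old in zip(new_value, row)) for row in arr]

arr = [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]]
value = [4.0, 4.0, 4.0]
print(soft_assign(arr, [0.0, 1.0, 0.0], value))  # [[1.0, 1.0, 1.0], [4.0, 4.0, 4.0], [3.0, 3.0, 3.0]]
print(soft_assign(arr, [0.5, 0.5, 0.0], value))  # [[2.5, 2.5, 2.5], [3.0, 3.0, 3.0], [3.0, 3.0, 3.0]]
print(mask_gradient(arr, value))                 # [9.0, 6.0, 3.0]
```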

## Superposition lookup
In this method, we have a distribution instead of an integer index. This distribution usually comes from a softmax operation. The result is the dot product of the index distribution and the input array. This is a very popular method in the DNN literature.

### Soft Assignment with Probability Vectors

For soft assignment using probability distributions instead of discrete indices:

In [5]:
@tf.function
def superposition_lookup_vectored(arr, indices):
    if tf.rank(arr) == 1:
        arr = tf.expand_dims(arr, -1)
    indices = tf.expand_dims(indices, -1)
    result = arr * indices
    return tf.reduce_sum(result, axis=0)

arr = tf.Variable([
    [0, 0, 1, 0, 0],
    [0, 0, 0, 0, 1],
    [0, 1, 0, 0, 0],
    [1, 0, 0, 0, 0],
    [0, 0, 0, 1, 0],
],dtype=tf.float32)
indices = tf.Variable([0.5, 0.5, 0, 0, 0], dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
#     indices = tf.nn.softmax(indices)
    z = superposition_lookup_vectored(arr, indices)

print(z)
print(tape.gradient(z, arr))
print(tape.gradient(z, indices))

tf.Tensor([0.  0.  0.5 0.  0.5], shape=(5,), dtype=float32)
tf.Tensor(
[[0.5 0.5 0.5 0.5 0.5]
 [0.5 0.5 0.5 0.5 0.5]
 [0.  0.  0.  0.  0. ]
 [0.  0.  0.  0.  0. ]
 [0.  0.  0.  0.  0. ]], shape=(5, 5), dtype=float32)
tf.Tensor([1. 1. 1. 1. 1.], shape=(5,), dtype=float32)


## Higher-Dimensional Arrays

The techniques extend naturally to higher-dimensional arrays. Here we implement 2D tensor operations using outer products and broadcasting.

In [6]:
arr = tf.Variable([1,2,3,5,4],dtype=tf.float32)
indices = tf.Variable([0.0, 0.1, 0.8, 0.0, 0.1], dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
#     indices = tf.nn.softmax(indices)
    z = superposition_lookup_vectored(arr, indices)

print(z)
print(tape.gradient(z, arr))
print(tape.gradient(z, indices))

tf.Tensor([3.0000002], shape=(1,), dtype=float32)
tf.Tensor([0.  0.1 0.8 0.  0.1], shape=(5,), dtype=float32)
tf.Tensor([1. 2. 3. 5. 4.], shape=(5,), dtype=float32)


### Shape Matching and Broadcasting

The `match_shapes` function handles broadcasting between tensors of different dimensions, enabling flexible operations on multi-dimensional arrays.
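The reshaping rule behind `match_shapes` can be sketched in plain Python. The function name `padded_shape` is hypothetical. The rule keeps the lower-rank tensor's dimensions as the leading axes and pads trailing ones. This matters because standard NumPy/TensorFlow broadcasting aligns shapes from the right, so multiplying shapes `(3, 3)` and `(3, 3, 2)` directly would compare `3` with `2` and fail.

```python
def padded_shape(low_shape, high_shape):
    """Shape to reshape the lower-rank tensor to, mirroring match_shapes' logic."""
    k = len(low_shape)
    # The lower-rank shape must equal the leading dimensions of the higher-rank shape
    if tuple(high_shape[:k]) != tuple(low_shape):
        raise ValueError("No common shape to broadcast")
    # Keep those leading dimensions and pad the remaining trailing axes with size 1
    return tuple(low_shape) + (1,) * (len(high_shape) - k)

print(padded_shape((3, 3), (3, 3, 2)))  # (3, 3, 1)
print(padded_shape((3,), (3, 4, 5)))    # (3, 1, 1)
```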

### 2D Tensor Lookup

The `tensor_lookup_2d` function implements differentiable 2D array indexing using outer products of the row and column indices. This creates a 2D mask that selects the desired element while maintaining gradients.
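Below is a plain-Python sketch of the outer-product mask. It is illustrative, not the notebook's implementation, and differs from it in one deliberate way. It reduces with a weighted sum, in the spirit of superposition lookup, whereas `tensor_lookup_2d` uses `tf.math.reduce_max`. The two agree for one-hot indices over non-negative values, and the sum also gives a sensible blend for soft indices.

```python
def outer(x, y):
    # mask[i][j] = x[i] * y[j]
    return [[xi * yj for yj in y] for xi in x]

def lookup_2d(arr, x_index, y_index):
    """Soft 2D lookup over arr[i][j][c] using an outer-product mask."""
    mask = outer(x_index, y_index)
    rows, cols, depth = len(arr), len(arr[0]), len(arr[0][0])
    # Weighted sum over the two index axes, computed separately for each channel c
    return [sum(mask[i][j] * arr[i][j][c] for i in range(rows) for j in range(cols))
            for c in range(depth)]

arr = [[[1, 1], [1, 11], [1, 111]],
       [[2, 2], [2, 22], [2, 222]],
       [[3, 3], [3, 33], [3, 333]]]
print(lookup_2d(arr, [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]))  # [2.0, 222.0]
print(lookup_2d(arr, [0.0, 0.5, 0.5], [0.0, 0.0, 1.0]))  # [2.5, 277.5]
```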

### 2D Tensor Assignment

Similarly, `tensor_write_2d` implements differentiable 2D array assignment, allowing us to update specific positions in multi-dimensional arrays while preserving gradient flow.
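The write is the same blend as the 1D assignment, except that the mask is now the outer product of the two index vectors. Here is a plain-Python sketch with illustrative names, using the notebook's example of writing `[5, 555]` into cell `(1, 2)`:

```python
def write_2d(arr, element, x_index, y_index):
    """Soft 2D write: (1 - mask) * arr + mask * element, with mask = outer(x_index, y_index)."""
    return [[[(1 - x * y) * arr[i][j][c] + x * y * element[c]
              for c in range(len(element))]
             for j, y in enumerate(y_index)]
            for i, x in enumerate(x_index)]

arr = [[[1, 1], [1, 11], [1, 111]],
       [[2, 2], [2, 22], [2, 222]],
       [[3, 3], [3, 33], [3, 333]]]
new_arr = write_2d(arr, [5, 555], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0])
print(new_arr[1][2])  # [5.0, 555.0]  (the written cell)
print(new_arr[0][2])  # [1.0, 111.0]  (an untouched neighbour)
```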

## Summary

This notebook demonstrated multiple strategies for making array indexing differentiable:

### Strategy Comparison

| Method | Exact Values | Index Gradient | Array Gradient | Use Case |
|--------|--------------|----------------|----------------|----------|
| **Naive** | ✅ | ❌ | ✅ | Baseline comparison |
| **Linear** | ❌ | ✅ | ✅ | Local interpolation |
| **Superposition** | ❌ | ✅ | ✅ | Attention mechanisms |
| **Residual** | ✅ | Partial | ✅ | Hybrid approaches |
| **Asymmetric** | ✅ | Custom (surrogate) | ✅ | Complex optimization |

### Key Insights

1. **Trade-offs**: Each method balances differentiability, exactness, and computational cost
2. **Application-dependent**: Choose based on whether you need exact values or smooth optimization
3. **Attention connection**: Superposition lookup is the foundation of attention mechanisms
4. **Custom gradients**: Enable sophisticated behaviors like asymmetric lookup
5. **Extensibility**: All techniques extend to higher-dimensional arrays

These differentiable indexing techniques are fundamental building blocks for neural network architectures, differentiable algorithms, and learnable data structures!

In [7]:
@tf.function
def bandwidthify(index, bandwidth):
    t, i1, i2 = interp_factor(index)
    
    # Prevent array out of bounds
    i1 = tf.clip_by_value(i1, 0, bandwidth - 1)
    i2 = tf.clip_by_value(i2, 0, bandwidth - 1)
    t = tf.clip_by_value(t, 0, 1)
    
    # Linear interpolation: weight (1 - t) on the floor index, t on the ceil index
    eye = tf.eye(bandwidth)
    result = (1 - t) * eye[i1] + t * eye[i2]
    
    return result

index = tf.Variable(2.5, dtype=tf.float32)
bandwidth = tf.constant(5, dtype=tf.int32)
dummy_array = tf.cast(tf.range(bandwidth), tf.float32)
with tf.GradientTape() as tape:
    z = bandwidthify(index, bandwidth)
    nz = superposition_lookup_vectored(dummy_array, z) # Lookup operation

print(z)
print(nz)
print(tape.gradient(nz, index))

tf.Tensor([0.  0.  0.5 0.5 0. ], shape=(5,), dtype=float32)
tf.Tensor([2.5], shape=(1,), dtype=float32)
tf.Tensor(1.0, shape=(), dtype=float32)


In [8]:
@tf.function
def superposition_lookup(arr, index):
    bandwidth = tf.shape(arr)[0]
    vectored_index = bandwidthify(index, bandwidth)
    result = superposition_lookup_vectored(arr, vectored_index)
    
    return result

arr = tf.Variable([1,2,3,4,5],dtype=tf.float32)
index = tf.Variable(1.5, dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
    z = superposition_lookup(arr, index)

print(z)
print(tape.gradient(z, arr))
print(tape.gradient(z, index))

tf.Tensor([2.5], shape=(1,), dtype=float32)
tf.Tensor([0.  0.5 0.5 0.  0. ], shape=(5,), dtype=float32)
tf.Tensor(1.0, shape=(), dtype=float32)


In [9]:
@tf.function
def bulk_bandwidthify(indices, bandwidth):
    num_indices = tf.shape(indices)[0]
    
    indices = tf.unstack(indices)
    result = tf.zeros((num_indices, bandwidth), dtype=tf.float32)
    result = tf.unstack(result)
    
    for i, index in enumerate(indices):
        b_index = bandwidthify(index, bandwidth)
        result[i] += b_index
    
    result = tf.stack(result)
    return result

indices = tf.Variable([1,2,3.5,0,4],dtype=tf.float32)
bandwidth = tf.constant(5, dtype=tf.int32)
dummy_array = tf.cast(tf.range(bandwidth), tf.float32)
with tf.GradientTape() as tape:
    z = bulk_bandwidthify(indices, bandwidth)
    # Batched lookup: one weighted sum of dummy_array per row of z
    nz = tf.linalg.matvec(z, dummy_array)

print(z)
print(nz)
print(tape.gradient(nz, indices))

tf.Tensor(
[[0.  1.  0.  0.  0. ]
 [0.  0.  1.  0.  0. ]
 [0.  0.  0.  0.5 0.5]
 [1.  0.  0.  0.  0. ]
 [0.  0.  0.  0.  1. ]], shape=(5, 5), dtype=float32)
tf.Tensor([1.  2.  3.5 0.  4. ], shape=(5,), dtype=float32)
tf.Tensor([0. 0. 1. 0. 0.], shape=(5,), dtype=float32)


## Residual lookup
In this method, we return two tensors: the result and the residue. Although the result is not differentiable w.r.t. the index, the residue is. This allows us to propagate some extra information in parallel, which a downstream algorithm can then consume. The benefit is that the result always exists in the original array and is never an interpolation.

In [10]:
@tf.function
def residual_lookup(arr, index):
    i = tf.round(index)
    residue = index - i
    i = tf.cast(i, tf.int32)
    
    result = arr[i]
    
    return result, residue

arr = tf.Variable([1,2,3,4,5],dtype=tf.float32)
index = tf.Variable(1.5, dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
    result, residue = residual_lookup(arr, index)

tf.print(result, residue)
tf.print(tape.gradient(result, arr), tape.gradient(residue, arr))
tf.print(tape.gradient(result, index), tape.gradient(residue, index))

3 -0.5
[0 0 1 0 0] [0 0 0 0 0]
0 1


## Asymmetrical Vectored Lookup

In this method, we calculate the result of the forward pass by finding the most likely index of the vector and returning the value associated with that index. The forward pass is thus non-differentiable, so we have to define our own backward pass. To compute it, we first estimate our target: the $target$ is the result obtained in the forward pass minus the upstream gradient from the loss, i.e. one gradient-descent step with unit step size. We then find the value in our array which is closest to the target. We want to increase the probability of this index while decreasing the probability of the other indices, so we create a vector which is $-1$ at the index of the target and $1$ everywhere else, scaled by the magnitude of the upstream gradient. The optimizer subtracts the gradient, so the entry we want to increase has to be negative.

The gradient returned by the backward pass is not the true derivative of the forward pass. In that sense, this lookup is asymmetric.

In [11]:
@tf.function
@tf.custom_gradient
def asymmetrical_vectored_lookup(v, k):
    k_shape = tf.shape(k)

    # Pick the value at the most likely index, non-differentiably
    b_idx = tf.argmax(k, axis=-1)
    idx_len = tf.shape(b_idx)[0]
    a_idx = tf.range(idx_len, dtype=tf.int64)
    idx = tf.stack([a_idx, b_idx], axis=1)
    forward_result = tf.gather_nd(v, idx)

    def grad(upstream_grads):
        # Estimate the target scalar which we want to look up
        target = forward_result - upstream_grads
        target = tf.expand_dims(target, -1)

        # Find the index of element in the array which is closest to target
        diff_vector = tf.math.squared_difference(v, target)
        d_idx = tf.argmin(diff_vector, axis=-1)

        # Create a vector which is 1 everywhere except the idx
        # of the target, where it is -1
        ones = tf.ones(k_shape)
        eyes = tf.one_hot([d_idx], k_shape[-1])[0]
        k_grad = -(2 * eyes - ones)

        # d/dv (v . k) = k
        v_grad = k

        upstream_grads = tf.expand_dims(upstream_grads, -1)
        return upstream_grads * v_grad, tf.math.abs(upstream_grads) * k_grad

    return forward_result, grad

v = tf.constant([[1,2,3], [10,20,30]], dtype=tf.float32)
k = tf.constant([[0,1,0], [1,0,0]], dtype=tf.float32)
t = tf.constant([3, 20], dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
    tape.watch(k)
    r = asymmetrical_vectored_lookup(v, k)
    loss = tf.nn.l2_loss(r - t)

tf.print(r)
tf.print(tape.gradient(r, k))
tf.print(tape.gradient(loss, k))

[2 10]
[[-1 1 1]
 [-1 1 1]]
[[1 1 -1]
 [10 -10 10]]


In [12]:
values = tf.constant([
    [20,40,30],
    [50,70,10],
], dtype=tf.float32)
choice = tf.Variable([
    [1,0,0],
    [0,0,1],
], dtype=tf.float32)

target = tf.constant([40, 70], dtype=tf.float32)

# opt = tf.keras.optimizers.Adam(3e-4)
# opt = tf.keras.optimizers.Adam(1e-2)
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    1e-1,
    decay_steps=100,
    decay_rate=1e-1,
    staircase=True)
opt = tf.keras.optimizers.Adam(lr_schedule)

steps = 10

for i in range(steps):
    with tf.GradientTape() as tape:
        out = asymmetrical_vectored_lookup(values, choice)
        target_loss = tf.nn.l2_loss(out - target)
        entropy_loss = entropy(choice)

        loss = target_loss + entropy_loss * 1e-2

    variables = [choice]
    grads = tape.gradient(loss, variables)

    opt.apply_gradients(zip(grads, variables))
    choice.assign(to_prob_dist_all(choice))

    if i % (steps // 10) == 0:
        tf.print(target_loss, entropy_loss, out, )

tf.round(choice * 100)

2000 [-0 -0] [20 10]
2000 [0.577248037 0.577248216] [20 10]
2000 [0.542812288 0.54283917] [20 10]
2000 [0.822554886 0.822548866] [20 10]
2000 [0.728512 0.728538215] [20 10]
0 [0.875905 0.875911474] [40 70]
0 [0.686750948 0.686771] [40 70]
0 [0.781492531 0.781499624] [40 70]
0 [0.552687764 0.552703917] [40 70]
0 [0.624352634 0.624367058] [40 70]


<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[12., 88.,  0.],
       [ 0., 88., 12.]], dtype=float32)>

## Array assignment
TensorFlow tensors are immutable and do not support item assignment such as `arr[index] = value`, so instead we use a masking technique:

In [13]:
@tf.function
def assign_index(arr, index, element):
    arr_shape = tf.shape(arr)
    
    pos_mask = tf.eye(arr_shape[0])[index]
    pos_mask = tf.transpose(tf.expand_dims(pos_mask, 0))
    neg_mask = 1 - pos_mask
    
    tiled_element = tf.reshape(tf.tile(element, [arr_shape[0]]), arr_shape)
    
    arr = arr * neg_mask + tiled_element * pos_mask
    
    return arr

arr = tf.Variable([
    [1,1,1],
    [2,2,2],
    [3,3,3]
],dtype=tf.float32)
index = tf.constant(1)
element = tf.Variable([4,4,4],dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
    new_arr = assign_index(arr, index, element)
    
print(new_arr)
print(tape.gradient(new_arr, arr))
print(tape.gradient(new_arr, index))
print(tape.gradient(new_arr, element))

tf.Tensor(
[[1. 1. 1.]
 [4. 4. 4.]
 [3. 3. 3.]], shape=(3, 3), dtype=float32)
tf.Tensor(
[[1. 1. 1.]
 [0. 0. 0.]
 [1. 1. 1.]], shape=(3, 3), dtype=float32)
None
tf.Tensor([1. 1. 1.], shape=(3,), dtype=float32)


Superposition assignment when the index is a probability vector rather than an integer:

In [14]:
@tf.function
def assign_index_vectored(arr, index, element):
    arr_shape = tf.shape(arr)
    
    pos_mask = tf.transpose(tf.expand_dims(index, 0))
    neg_mask = 1 - pos_mask
    
    tiled_element = tf.reshape(tf.tile(element, [arr_shape[0]]), arr_shape)

    arr = arr * neg_mask + tiled_element * pos_mask
    
    return arr

arr = tf.Variable([
    [1,1,1],
    [2,2,2],
    [3,3,3]
],dtype=tf.float32)
index1 = tf.Variable([0,1,0], dtype=tf.float32)
index2 = tf.Variable([0.5,0.5,0], dtype=tf.float32)
element = tf.Variable([4,4,4],dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
    new_arr1 = assign_index_vectored(arr, index1, element)
    new_arr2 = assign_index_vectored(arr, index2, element)

tf.print(new_arr1, tape.gradient(new_arr1, index1))
tf.print(new_arr2, tape.gradient(new_arr2, index2))
print(tape.gradient(new_arr1, arr))
print(tape.gradient(new_arr1, element))

[[1 1 1]
 [4 4 4]
 [3 3 3]] [9 6 3]
[[2.5 2.5 2.5]
 [3 3 3]
 [3 3 3]] [9 6 3]
tf.Tensor(
[[1. 1. 1.]
 [0. 0. 0.]
 [1. 1. 1.]], shape=(3, 3), dtype=float32)
tf.Tensor([1. 1. 1.], shape=(3,), dtype=float32)


## Higher dimensional arrays

In [15]:
@tf.function
def match_shapes(x, y):
    # Find which one needs to be broadcasted
    low, high = (y, x) if tf.rank(x) > tf.rank(y) else (x, y)
    l_rank, l_shape = tf.rank(low), tf.shape(low)
    h_rank, h_shape = tf.rank(high), tf.shape(high)
    
    # The lower-rank tensor must match the leading dimensions of the higher-rank one
    common_shape = h_shape[:l_rank]
    tf.debugging.assert_equal(common_shape, l_shape, 'No common shape to broadcast')
    padding = tf.ones(h_rank - l_rank, dtype=tf.int32)
    
    # Pad the difference with ones and reshape
    new_shape = tf.concat((common_shape, padding),axis=0)
    low = tf.reshape(low, new_shape)

    return high, low

@tf.function
def broadcast_multiply(x, y):
    x, y = match_shapes(x, y)
    return x * y
    
x = tf.ones((3, 3, 2)) * 3
y = tf.ones((3, 3)) * 2
broadcast_multiply(x, y)

<tf.Tensor: shape=(3, 3, 2), dtype=float32, numpy=
array([[[6., 6.],
        [6., 6.],
        [6., 6.]],

       [[6., 6.],
        [6., 6.],
        [6., 6.]],

       [[6., 6.],
        [6., 6.],
        [6., 6.]]], dtype=float32)>

In [16]:
@tf.function
def tensor_lookup_2d(arr, x_index, y_index):
    # Calculate outer product
    mask = tf.tensordot(x_index, y_index, axes=0)
    
    # Broadcast the mask to match dimensions with arr
    masked_arr = broadcast_multiply(mask, arr)
    
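    # Reduce max over the two index axes to extract the cell. This assumes a one-hot
    # mask and non-negative values: a negative selected value would lose to the zeros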
    element = tf.math.reduce_max(masked_arr, axis=[0,1])
    return element

arr = tf.Variable([
    [[1,1],[1,11],[1,111]],
    [[2,2],[2,22],[2,222]],
    [[3,3],[3,33],[3,333]]
],dtype=tf.float32)
x_index = tf.Variable(tf.one_hot(1, 3),dtype=tf.float32)
y_index = tf.Variable(tf.one_hot(2, 3),dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
    element = tensor_lookup_2d(arr, x_index, y_index)
    
print(element)
print(tape.gradient(element, arr))
print(tape.gradient(element, x_index))
print(tape.gradient(element, y_index))

tf.Tensor([  2. 222.], shape=(2,), dtype=float32)
tf.Tensor(
[[[0. 0.]
  [0. 0.]
  [0. 0.]]

 [[0. 0.]
  [0. 0.]
  [1. 1.]]

 [[0. 0.]
  [0. 0.]
  [0. 0.]]], shape=(3, 3, 2), dtype=float32)
tf.Tensor([  0. 224.   0.], shape=(3,), dtype=float32)
tf.Tensor([  0.   0. 224.], shape=(3,), dtype=float32)


In [17]:
@tf.function
def tensor_write_2d(arr, element, x_index, y_index):
    arr_shape = tf.shape(arr)
    mask = tf.tensordot(x_index, y_index, axes=0)
    
    # Broadcast the mask to match dimensions with arr
    _, mask = match_shapes(arr, mask)
    
    element = tf.reshape(element,[1,1,-1])
    element = tf.tile(element, [arr_shape[0], arr_shape[1], 1])
    
    result = (1.0 - mask) * arr + mask * element
    
    return result

arr = tf.Variable([
    [[1,1],[1,11],[1,111]],
    [[2,2],[2,22],[2,222]],
    [[3,3],[3,33],[3,333]]
],dtype=tf.float32)
element = tf.Variable([5,555], dtype=tf.float32)
x_index = tf.Variable(tf.one_hot(1, 3),dtype=tf.float32)
y_index = tf.Variable(tf.one_hot(2, 3),dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
    new_arr = tensor_write_2d(arr, element, x_index, y_index)
    
print(new_arr)
print(tape.gradient(new_arr, arr))
print(tape.gradient(new_arr, element))
print(tape.gradient(new_arr, x_index))
print(tape.gradient(new_arr, y_index))

tf.Tensor(
[[[  1.   1.]
  [  1.  11.]
  [  1. 111.]]

 [[  2.   2.]
  [  2.  22.]
  [  5. 555.]]

 [[  3.   3.]
  [  3.  33.]
  [  3. 333.]]], shape=(3, 3, 2), dtype=float32)
tf.Tensor(
[[[1. 1.]
  [1. 1.]
  [1. 1.]]

 [[1. 1.]
  [1. 1.]
  [0. 0.]]

 [[1. 1.]
  [1. 1.]
  [1. 1.]]], shape=(3, 3, 2), dtype=float32)
tf.Tensor([1. 1.], shape=(2,), dtype=float32)
tf.Tensor([448. 336. 224.], shape=(3,), dtype=float32)
tf.Tensor([556. 536. 336.], shape=(3,), dtype=float32)
