In [None]:
import tensorflow as tf

In [None]:
# Fix the global TF seed once, before any tf.random.* call, so A, B and the later
# random tensors (C, E, F) come out the same on every Restart & Run All.
# Values in previously saved outputs may differ from a fresh seeded run.
tf.random.set_seed(42)

# Integer draws: minval is inclusive and maxval exclusive, i.e. values in -1..9.
A = tf.random.uniform((3,3), minval=-1, maxval=10, dtype=tf.int32)
B = tf.random.uniform((3,3), minval=-1, maxval=10, dtype=tf.int32)

A, B

(<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
 array([[ 4,  6,  0],
        [ 9,  8,  0],
        [-1,  4,  2]], dtype=int32)>,
 <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
 array([[ 7,  4,  1],
        [ 8,  7, -1],
        [ 6,  2,  8]], dtype=int32)>)

## gather

In [3]:
# tf.gather picks whole slices of `params` along the chosen `axis`, in the order
# given by the integer `indices` (usually a 1-D list). Here: row 2, then row 1 of A.
tf.gather(params=A, indices=[2, 1], axis=0)

<tf.Tensor: shape=(2, 3), dtype=int32, numpy=
array([[-1,  4,  2],
       [ 9,  8,  0]], dtype=int32)>

In [4]:
# axis=1 selects columns instead of rows: columns 0 and 1 of A.
tf.gather(params=A, indices=[0, 1], axis=1)

<tf.Tensor: shape=(3, 2), dtype=int32, numpy=
array([[ 4,  6],
       [ 9,  8],
       [-1,  4]], dtype=int32)>

In [8]:
# Rank-3 integer tensor: 2 blocks, each 3x3, values drawn from [-1, 10).
C = tf.random.uniform(shape=(2, 3, 3), minval=-1, maxval=10, dtype=tf.int32)
C

<tf.Tensor: shape=(2, 3, 3), dtype=int32, numpy=
array([[[ 8,  8,  7],
        [ 0, -1,  3],
        [ 4, -1,  4]],

       [[ 5,  5,  5],
        [-1,  0, -1],
        [ 1, -1,  5]]], dtype=int32)>

In [12]:
# Along axis 0, C only has indices 0 and 1, so [0, 1] returns both 3x3 blocks.
# WARNING: out-of-range indices (e.g. 2 here) behave differently per device, not per
# version. Per the tf.gather docs, on CPU an error is raised, while on GPU a 0 is
# silently stored in the corresponding output position.
tf.gather(C, [0, 1], axis=0)

<tf.Tensor: shape=(2, 3, 3), dtype=int32, numpy=
array([[[ 8,  8,  7],
        [ 0, -1,  3],
        [ 4, -1,  4]],

       [[ 5,  5,  5],
        [-1,  0, -1],
        [ 1, -1,  5]]], dtype=int32)>

In [14]:
# axis=1 keeps both leading blocks and picks rows 0 and 2 inside each one.
tf.gather(params=C, indices=[0, 2], axis=1)

<tf.Tensor: shape=(2, 2, 3), dtype=int32, numpy=
array([[[ 8,  8,  7],
        [ 4, -1,  4]],

       [[ 5,  5,  5],
        [ 1, -1,  5]]], dtype=int32)>

In [16]:
# axis=2 picks columns 0 and 2 inside every row of every block.
tf.gather(params=C, indices=[0, 2], axis=2)

<tf.Tensor: shape=(2, 3, 2), dtype=int32, numpy=
array([[[ 8,  7],
        [ 0,  3],
        [ 4,  4]],

       [[ 5,  5],
        [-1, -1],
        [ 1,  5]]], dtype=int32)>

## gather_nd

In [28]:
# gather_nd
# Unlike tf.gather, each entry of `indices` is a coordinate prefix: its last
# dimension N = indices.shape[-1] addresses the first N dimensions of `params`.
E = tf.random.uniform(shape=(3, 4), minval=-1, maxval=10, dtype=tf.int32)
E

<tf.Tensor: shape=(3, 4), dtype=int32, numpy=
array([[ 6,  9,  4,  7],
       [ 4, -1,  0,  7],
       [ 9,  6,  9,  9]], dtype=int32)>

In [22]:
# A length-1 index addresses only the first dimension, so a full row comes back.
indices = [0]
tf.gather_nd(params=E, indices=indices, batch_dims=0)

<tf.Tensor: shape=(4,), dtype=int32, numpy=array([6, 9, 4, 7], dtype=int32)>

In [25]:
# A flat [row, col] pair addresses both dimensions -> a single scalar.
# Wrap pairs in an outer list ([[row, col], ...]) to fetch several values.
indices = [0, 3]
tf.gather_nd(params=E, indices=indices, batch_dims=0)

<tf.Tensor: shape=(), dtype=int32, numpy=7>

In [30]:
# A list of [row, col] pairs -> one scalar per pair, in the same order.
indices = [[0, 0], [1, 1]]
tf.gather_nd(params=E, indices=indices, batch_dims=0)

<tf.Tensor: shape=(2,), dtype=int32, numpy=array([ 6, -1], dtype=int32)>

In [37]:
indices = [0]
# batch_dims must be strictly less than rank(indices). `indices` is rank 1 here,
# so batch_dims=1 is rejected. The call is wrapped so the notebook keeps running
# under Restart & Run All while still showing the error message.
try:
    tf.gather_nd(E, indices, batch_dims=1)
except (ValueError, tf.errors.InvalidArgumentError) as err:
    print(f"{type(err).__name__}: {err}")

ValueError: Argument `batch_dims` = 1 must be less than rank(`indices`) = 1

In [41]:
# [[row, col], [row, col]] -> the order of the pairs sets the order of the output.
indices = [[0, 1], [0, 0]]
tf.gather_nd(params=E, indices=indices, batch_dims=0)

<tf.Tensor: shape=(2,), dtype=int32, numpy=array([9, 6], dtype=int32)>

In [42]:
# Rank-3 tensor for the batched examples: 2 batches of 3x4, values in [-1, 10).
F = tf.random.uniform(shape=(2, 3, 4), minval=-1, maxval=10, dtype=tf.int32)
F

<tf.Tensor: shape=(2, 3, 4), dtype=int32, numpy=
array([[[ 2,  2,  4,  3],
        [ 3, -1,  0,  1],
        [ 6,  2,  9, -1]],

       [[ 2,  7,  8,  0],
        [-1, -1,  5,  7],
        [-1,  3,  5,  6]]], dtype=int32)>

In [45]:
# [[batch, row], ...] -> each pair returns a whole row (the column dim is left free).
indices = [[0, 0], [0, 1]]
tf.gather_nd(params=F, indices=indices, batch_dims=0)

<tf.Tensor: shape=(2, 4), dtype=int32, numpy=
array([[ 2,  2,  4,  3],
       [ 3, -1,  0,  1]], dtype=int32)>

In [56]:
# Full [batch, row, col] coordinates -> one scalar per coordinate.
indices = [[0, 0, 0], [1, 2, 0], [0, 1, 0]]
tf.gather_nd(params=F, indices=indices, batch_dims=0)

<tf.Tensor: shape=(3,), dtype=int32, numpy=array([ 2, -1,  3], dtype=int32)>

In [52]:
# With batch_dims=1 the leading dimension is shared between params and indices:
# indices[i] is applied to F[i], so each entry only needs [row, col] inside its batch.

# indices[0] -> F[0][row, col], indices[1] -> F[1][row, col]: one value per batch.
# The next cell shows how to take several values from each batch.
indices = [[0, 2], [0, 1]]
tf.gather_nd(params=F, indices=indices, batch_dims=1)

<tf.Tensor: shape=(2,), dtype=int32, numpy=array([4, 7], dtype=int32)>

In [66]:
# Several [row, col] pairs per batch:
#   indices[0] -> pairs looked up in F[0]
#   indices[1] -> pairs looked up in F[1]
indices = [
    [[0, 1], [2, 2]],
    [[0, 0], [1, 2]],
]
tf.gather_nd(params=F, indices=indices, batch_dims=1)

<tf.Tensor: shape=(2, 2), dtype=int32, numpy=
array([[2, 9],
       [2, 5]], dtype=int32)>