## A beautiful way to implement advanced indexing in TensorFlow with 'tf.stack' and 'tf.gather_nd'

TensorFlow does not support advanced indexing the way numpy does. Here I give an example of advanced indexing in numpy first, and then the equivalent in TensorFlow.

In [1]:
# In numpy, we can pass a list of row indices and a list of column indices to pick individual components
# The two lists must have the same length n
# The result is an array of length n, with result[i] = x[row[i], column[i]]:
import numpy as np
x = np.array([[1, 2, 3], [4, 5, 6]])
row = [0, 1, 1]
column = [0, 1, 2]
x[row, column]  # select x[0,0] x[1,1] x[1,2]

array([1, 5, 6])
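
To make the pairing rule concrete, here is a minimal sketch that compares the advanced-indexing expression with an explicit loop over the (row, column) pairs. The names `paired` and `looped` are only illustrative and do not appear in the original cells.

```python
import numpy as np

x = np.array([[1, 2, 3], [4, 5, 6]])
row = [0, 1, 1]
column = [0, 1, 2]

# Advanced indexing pairs the two lists element by element
paired = x[row, column]

# The same selection written as an explicit loop over (row, column) pairs
looped = np.array([x[r, c] for r, c in zip(row, column)])

print(paired)  # [1 5 6]
print(looped)  # [1 5 6]
```

Both expressions pick x[0, 0], x[1, 1] and x[1, 2], which is the behaviour the TensorFlow version needs to reproduce.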

In [3]:
# Another example
x1 = np.array([[0, 1, 2],
               [3, 4, 5],
               [6, 7, 8],
               [9, 10, 11]])
# With advanced indexing, every selected element is specified explicitly:
# the index arrays have the same shape as the result
row1 = np.array([[0, 0],
                 [3, 3]])
column1 = np.array([[0, 2],
                    [0, 2]])
x1[row1, column1]  # specify [0,0] [0,2] [3,0] [3,2]

array([[ 0,  2],
       [ 9, 11]])
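
The fully written-out index arrays above can also be produced by broadcasting. The sketch below is an aside, not part of the original cells: it assumes the same 4x3 array and uses the illustrative names `rows`, `cols`, `by_broadcast` and `by_ix`.

```python
import numpy as np

x1 = np.arange(12).reshape(4, 3)  # same values as x1 above

# A (2, 1) column of row indices and a (1, 2) row of column indices
# broadcast to the same 2x2 grid as row1 / column1
rows = np.array([[0], [3]])
cols = np.array([[0, 2]])
by_broadcast = x1[rows, cols]

# np.ix_ builds such a broadcastable pair from two plain lists
by_ix = x1[np.ix_([0, 3], [0, 2])]

print(by_broadcast)
# [[ 0  2]
#  [ 9 11]]
```

The result is identical to x1[row1, column1]. The selection still happens element by element, numpy just fills in the repeated indices.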

> But indexing a TensorFlow tensor with index tensors in this way does not work

In [4]:
# tensorflow example
# here, we try to use the same method used in numpy
import tensorflow as tf
x = tf.constant([[1, 2, 3], [4, 5, 6]])  # shape [2, 3]; suppose it is the logit matrix
num = tf.range(2)  # row indices [0, 1]
label = tf.placeholder(tf.int32, shape=[2])  # the label (column index) of each row
out = x[num, label]  # fails when the graph is built, with the ValueError shown below

ValueError: Shape must be rank 1 but is rank 2 for 'strided_slice' (op: 'StridedSlice') with input shapes: [2,3], [2,2], [2,2], [2].

In [5]:
# In this case, we need 'tf.stack' and 'tf.gather_nd'.
# tf.gather_nd selects the components listed in 'indices',
# but it differs from numpy: each entry of 'indices' must hold both indices of one component.

# Therefore tf.stack is used to combine the row indices and column indices into 'indices':
indices = tf.stack([num, label], axis=1)
output = tf.gather_nd(x, indices)
with tf.Session() as sess:
    ind, out = sess.run([indices, output], feed_dict={label: [1, 2]})
print('try to understand the format of its indices:')
print('indices: ', ind)  # [[0, 1], [1, 2]] to obtain x[0][1] and x[1][2]
# so if 'indices' contains n [row, column] pairs, n components are selected
print('output: ', out)

try to understand the format of its indices:
indices:  [[0 1]
 [1 2]]
output:  [2 6]
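
To see why the indices take this shape, the following sketch reproduces the same selection in plain numpy. It is only an analogue of what tf.gather_nd does when every entry of 'indices' addresses a single element. It is not TensorFlow code, and the fed value [1, 2] is hard-coded as `label`.

```python
import numpy as np

x = np.array([[1, 2, 3], [4, 5, 6]])
num = np.arange(2)         # row index of each example: [0, 1]
label = np.array([1, 2])   # the values fed for 'label' above

# Same construction as tf.stack([num, label], axis=1):
# one [row, column] pair per line
indices = np.stack([num, label], axis=1)

# numpy analogue of tf.gather_nd(x, indices): split the pairs back
# into one index array per axis and use advanced indexing
output = x[tuple(indices.T)]

print(indices)
# [[0 1]
#  [1 2]]
print(output)  # [2 6]
```

In other words, tf.gather_nd expects the (row, column) pairs packed along the last axis of one tensor, whereas numpy expects one index array per axis.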


___

> PS: tf.stack and tf.concat are different. tf.stack changes the rank: it stacks a list of rank-R tensors into one rank-(R+1) tensor, while tf.concat joins tensors along an existing axis and keeps the rank.

In [7]:
x = tf.constant([1, 4])
y = tf.constant([2, 5])
stacked_0 = tf.stack([x, y])
stacked_1 = tf.stack([x, y], axis=1)
concated_0 = tf.concat([x, y], axis=0)
# concated_1 = tf.concat([x, y], axis=1)  # would raise an error:
# a rank-1 tensor has no axis=1 (a rank-2 tensor has columns, which is axis=1)
with tf.Session() as sess:
    sta_0, sta_1 = sess.run([stacked_0, stacked_1])
    cat_0 = sess.run([concated_0])
    print('sta_0: ', sta_0)
    print('sta_1: ', sta_1)
    print()
    print('cat_0', cat_0)
    #print('cat_1', cat_1)

sta_0:  [[1 4]
 [2 5]]
sta_1:  [[1 2]
 [4 5]]

cat_0 [array([1, 4, 2, 5])]
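
The rank difference is easier to see from the shapes than from the printed values. Below is a minimal numpy sketch of the same comparison. It assumes np.stack and np.concatenate as stand-ins for tf.stack and tf.concat, which follow the same stacking and joining rules for this example.

```python
import numpy as np

x = np.array([1, 4])
y = np.array([2, 5])

stacked_0 = np.stack([x, y])           # new leading axis
stacked_1 = np.stack([x, y], axis=1)   # new trailing axis
concated_0 = np.concatenate([x, y], axis=0)

print(stacked_0.shape)   # (2, 2)
print(stacked_1.shape)   # (2, 2)
print(concated_0.shape)  # (4,)
```

Stacking two rank-1 inputs gives a rank-2 result, while concatenation stays rank-1 and only gets longer.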
