In [1]:
import numpy as np
import tensorflow as tf

# NumPy

In [2]:
# channels: twelve channel indices 0..11, to be split into `groups` groups
x = np.arange(12)
groups = 4

In [3]:
def channel_shuffle(x, groups):
    """Shuffle the channels of ``x`` across ``groups`` equal-sized groups.

    The channels are viewed as a ``(groups, channels_per_group)`` matrix,
    transposed and flattened again, so the output cycles through the input
    groups: the first channel of every group, then the second channel of
    every group, and so on.

    Parameters
    ----------
    x : array_like
        Channel values. ``len(x)`` must be a multiple of ``groups``.
    groups : int
        Number of contiguous groups the channels are split into (>= 1).

    Returns
    -------
    numpy.ndarray
        1-D array of length ``len(x)`` holding the shuffled channels.

    Raises
    ------
    ValueError
        If ``groups`` is not positive or does not evenly divide ``len(x)``.
    """
    # Accept plain sequences as well as ndarrays.
    x = np.asarray(x)
    n_channels = len(x)
    if groups < 1:
        raise ValueError("groups must be a positive integer, got %r" % (groups,))
    if n_channels % groups != 0:
        # Fail with an explicit message instead of an opaque reshape error.
        raise ValueError(
            "number of channels (%d) is not divisible by groups (%d)"
            % (n_channels, groups)
        )
    # Exact integer division; divisibility was checked above.
    channels_per_group = n_channels // groups
    # main algorithm:
    x = x.reshape(groups, channels_per_group)
    x = x.T
    x = x.reshape(n_channels)
    return x

In [4]:
def print_grouped(x, groups):
    """Format the items of ``x`` as one string with ``|`` between groups.

    A ``|`` token is emitted before every item whose index is a multiple of
    ``int(len(x) / groups)`` and once more at the very end; all tokens are
    joined with single spaces. The string is returned, not printed.
    """
    group_size = int(len(x) / groups)
    tokens = []
    for position, value in enumerate(x):
        starts_group = position % group_size == 0
        if starts_group:
            tokens.append('|')
        tokens.append(str(value))
    tokens.append('|')
    return ' '.join(tokens)

In [5]:
# before: channels in their original order, with `|` marking group boundaries
print_grouped(x, groups)

'| 0 1 2 | 3 4 5 | 6 7 8 | 9 10 11 |'

In [6]:
# and after channel shuffle: the output cycles through the input groups
# (first channel of each group, then the second channel of each group, ...)
print_grouped(channel_shuffle(x, groups), groups)

'| 0 3 6 | 9 1 4 | 7 10 2 | 5 8 11 |'

In [7]:
# change number of groups: 3 groups of 4 channels for the 12-element `x`
groups = 3
# and try again: the next cells repeat the before/after comparison

In [8]:
# before: the same channels, now delimited using the new `groups` value
print_grouped(x, groups)

'| 0 1 2 3 | 4 5 6 7 | 8 9 10 11 |'

In [9]:
# and after channel shuffle with the new `groups` value
print_grouped(channel_shuffle(x, groups), groups)

'| 0 4 8 1 | 5 9 2 6 | 10 3 7 11 |'

# TensorFlow

In [10]:
def tf_channel_shuffle(X, groups):
    """Channel-shuffle a rank-4 tensor along its last (channel) axis.

    The last axis is split into ``(groups, in_channels_per_group)``, the two
    resulting axes are swapped, and the tensor is reshaped back so that the
    output has the same shape as the input.

    Args:
        X: Tensor whose static shape has rank 4; dimensions 1, 2 and 3 are
            treated as height, width and channels. The channel dimension
            must be statically known; height and width may be unknown.
        groups: Positive Python int that evenly divides the channel count.

    Returns:
        A tensor with the same shape as ``X`` and shuffled channels.

    Raises:
        ValueError: If the channel dimension is not statically known, or if
            ``groups`` is not positive or does not divide the channel count.
    """
    static_height, static_width, in_channels = X.shape.as_list()[1:]
    if in_channels is None:
        raise ValueError(
            "the channel dimension of X must be statically known to split it "
            "into groups"
        )
    if groups < 1 or in_channels % groups != 0:
        # With a truncated group size and a -1 batch dimension, the reshape
        # can still succeed at runtime for some batch sizes and mix values
        # across samples, so reject the configuration up front.
        raise ValueError(
            "number of channels (%d) is not divisible by groups (%r)"
            % (in_channels, groups)
        )
    in_channels_per_group = in_channels // groups

    # Use the runtime shape only for spatial dimensions that are unknown at
    # graph-construction time; statically known sizes are kept as Python ints.
    height, width = static_height, static_width
    if static_height is None or static_width is None:
        dynamic_shape = tf.shape(X)
        if static_height is None:
            height = dynamic_shape[1]
        if static_width is None:
            width = dynamic_shape[2]

    shape = tf.stack([-1, height, width, groups, in_channels_per_group])
    X = tf.reshape(X, shape)

    # Swap the group axis with the per-group channel axis.
    X = tf.transpose(X, [0, 1, 2, 4, 3])

    shape = tf.stack([-1, height, width, in_channels])
    X = tf.reshape(X, shape)
    return X

In [11]:
# two images with height=3, width=4, channels=9:
# every pixel of the first image holds the channel ramp 0..8,
# every pixel of the second image holds the same ramp shifted by 0.5
batch = np.stack([
    np.tile(np.arange(9, dtype=np.float64) + offset, (3, 4, 1))
    for offset in (0.0, 0.5)
])
print(batch)

[[[[ 0.   1.   2.   3.   4.   5.   6.   7.   8. ]
   [ 0.   1.   2.   3.   4.   5.   6.   7.   8. ]
   [ 0.   1.   2.   3.   4.   5.   6.   7.   8. ]
   [ 0.   1.   2.   3.   4.   5.   6.   7.   8. ]]

  [[ 0.   1.   2.   3.   4.   5.   6.   7.   8. ]
   [ 0.   1.   2.   3.   4.   5.   6.   7.   8. ]
   [ 0.   1.   2.   3.   4.   5.   6.   7.   8. ]
   [ 0.   1.   2.   3.   4.   5.   6.   7.   8. ]]

  [[ 0.   1.   2.   3.   4.   5.   6.   7.   8. ]
   [ 0.   1.   2.   3.   4.   5.   6.   7.   8. ]
   [ 0.   1.   2.   3.   4.   5.   6.   7.   8. ]
   [ 0.   1.   2.   3.   4.   5.   6.   7.   8. ]]]


 [[[ 0.5  1.5  2.5  3.5  4.5  5.5  6.5  7.5  8.5]
   [ 0.5  1.5  2.5  3.5  4.5  5.5  6.5  7.5  8.5]
   [ 0.5  1.5  2.5  3.5  4.5  5.5  6.5  7.5  8.5]
   [ 0.5  1.5  2.5  3.5  4.5  5.5  6.5  7.5  8.5]]

  [[ 0.5  1.5  2.5  3.5  4.5  5.5  6.5  7.5  8.5]
   [ 0.5  1.5  2.5  3.5  4.5  5.5  6.5  7.5  8.5]
   [ 0.5  1.5  2.5  3.5  4.5  5.5  6.5  7.5  8.5]
   [ 0.5  1.5  2.5  3.5  4.5  5.5  6.5  

In [12]:
# Clear the default graph before building the ops for this cell
# (TF1 graph-mode API).
tf.reset_default_graph()

# Placeholder for a batch of unspecified size with height=3, width=4,
# channels=9, i.e. the per-image shape of `batch` built above.
P = tf.placeholder(tf.float32, shape=[None, 3, 4, 9])
X = tf_channel_shuffle(P, groups=3)

# Evaluate the shuffled tensor on the NumPy batch and print the result.
with tf.Session() as sess:
    # NOTE(review): this rebinds `x`, which the NumPy cells above use for the
    # 1-D channel array; re-running those cells afterwards would see this
    # result instead.
    x = sess.run(X, feed_dict={P: batch})
    print(x)

[[[[ 0.   3.   6.   1.   4.   7.   2.   5.   8. ]
   [ 0.   3.   6.   1.   4.   7.   2.   5.   8. ]
   [ 0.   3.   6.   1.   4.   7.   2.   5.   8. ]
   [ 0.   3.   6.   1.   4.   7.   2.   5.   8. ]]

  [[ 0.   3.   6.   1.   4.   7.   2.   5.   8. ]
   [ 0.   3.   6.   1.   4.   7.   2.   5.   8. ]
   [ 0.   3.   6.   1.   4.   7.   2.   5.   8. ]
   [ 0.   3.   6.   1.   4.   7.   2.   5.   8. ]]

  [[ 0.   3.   6.   1.   4.   7.   2.   5.   8. ]
   [ 0.   3.   6.   1.   4.   7.   2.   5.   8. ]
   [ 0.   3.   6.   1.   4.   7.   2.   5.   8. ]
   [ 0.   3.   6.   1.   4.   7.   2.   5.   8. ]]]


 [[[ 0.5  3.5  6.5  1.5  4.5  7.5  2.5  5.5  8.5]
   [ 0.5  3.5  6.5  1.5  4.5  7.5  2.5  5.5  8.5]
   [ 0.5  3.5  6.5  1.5  4.5  7.5  2.5  5.5  8.5]
   [ 0.5  3.5  6.5  1.5  4.5  7.5  2.5  5.5  8.5]]

  [[ 0.5  3.5  6.5  1.5  4.5  7.5  2.5  5.5  8.5]
   [ 0.5  3.5  6.5  1.5  4.5  7.5  2.5  5.5  8.5]
   [ 0.5  3.5  6.5  1.5  4.5  7.5  2.5  5.5  8.5]
   [ 0.5  3.5  6.5  1.5  4.5  7.5  2.5  