# Differentiable Bubble Sort

Differentiable implementation of bubble sort with configurable (learnable) comparator function


In [1]:
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras import layers, Input
from tensorflow.keras.models import Model
import numpy as np

In [2]:
tf.executing_eagerly()

True

In [3]:
tf.config.list_physical_devices('GPU')

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

## Swap Function

Using linear interpolation for continious swap.

\begin{equation*}
new_a = a * t + b * (1 - t)
\end{equation*}
\begin{equation*}
new_b = b * t + a * (1 - t)
\end{equation*}

When $t = 0$, then $a$ and $b$ are swapped. When $t = 1$, they remain in place.

Other compare and swap strategies include [softmax approximation](https://github.com/johnhw/differentiable_sorting), [optimal transport](https://arxiv.org/pdf/1905.11885.pdf), [projecting into higher dimensional space](https://arxiv.org/pdf/2002.08871.pdf) etc


In [4]:
@tf.function
def swap(x, i, j, t=None):
    '''
        Linear interpolation swap
        x: Tensor - Expected dims: [2, feature_size]
        i: Tensor - Scalar, int-like
        j: Tensor - Scalar, int-like
        t: Tensor - Scalar, float-like
    '''
    x_shape = tf.shape(x)
    x_len = x_shape[0]
    
    if t is None:
        t = tf.zeros(x_shape)
    
    i_pos_mask = tf.expand_dims(tf.eye(x_len)[i],axis=-1)
    i_neg_mask = 1 - i_pos_mask
    i_element = K.sum(i_pos_mask * x, axis=0)
    
    j_pos_mask = tf.expand_dims(tf.eye(x_len)[j],axis=-1)
    j_neg_mask = 1 - j_pos_mask
    j_element = K.sum(j_pos_mask * x, axis=0)
    
    i_interp_element = t * i_element + (1 - t) * j_element
    j_interp_element = t * j_element + (1 - t) * i_element
    
    x = x * i_neg_mask + i_interp_element * i_pos_mask
    x = x * j_neg_mask + j_interp_element * j_pos_mask
    
    return x

In [5]:
x = tf.Variable([
    [1, 1, 0, 0],
    [1, 1, 0, 1],
    [1, 0, 0, 0],
    [1, 0, 1, 0],
    [1, 1, 1, 0]
],dtype=tf.float32)
t = tf.Variable(0 * tf.ones(tf.shape(x)),dtype=tf.float32)
i = tf.Variable(1,dtype=tf.int32)
j = tf.Variable(2,dtype=tf.int32)
with tf.GradientTape(persistent=True) as tape:
#     z = swap(x, i, j)
    z = swap(x, i, j, t)

print(z)
print(tape.gradient(z,x))
print(tape.gradient(z,t))

tf.Tensor(
[[1. 1. 0. 0.]
 [1. 0. 0. 0.]
 [1. 1. 0. 1.]
 [1. 0. 1. 0.]
 [1. 1. 1. 0.]], shape=(5, 4), dtype=float32)
tf.Tensor(
[[1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]], shape=(5, 4), dtype=float32)
tf.Tensor(
[[ 0.  0.  0.  0.]
 [ 0.  1.  0.  1.]
 [ 0. -1.  0. -1.]
 [ 0.  0.  0.  0.]
 [ 0.  0.  0.  0.]], shape=(5, 4), dtype=float32)


In [6]:
x = tf.Variable([[1],[2],[3]],dtype=tf.float32)
t = tf.Variable(0 * tf.ones(tf.shape(x)),dtype=tf.float32)
i = tf.Variable(0,dtype=tf.int32)
j = tf.Variable(1,dtype=tf.int32)
with tf.GradientTape(persistent=True) as tape:
#     z = swap(x, i, j)
    z = swap(x, i, j, t)

print(z)
print(tape.gradient(z,x))
print(tape.gradient(z,t))

tf.Tensor(
[[2.]
 [1.]
 [3.]], shape=(3, 1), dtype=float32)
tf.Tensor(
[[1.]
 [1.]
 [1.]], shape=(3, 1), dtype=float32)
tf.Tensor(
[[-1.]
 [ 1.]
 [ 0.]], shape=(3, 1), dtype=float32)


## Bubble sort

Standard bubble sort implementation with injectable comparator function. It is to be noted that the $t$ parameter is used to decide whether to swap or not instead of having explicit conditionals.

In [7]:
@tf.function
def bubble_sort(x, cmp_fun):
    '''
        Bubble sort
        x: Tensor - Expected dims: [array_length, feature_size]
        cmp_fun: Function
    '''
    x_len = tf.shape(x)[0]
    for i in range(x_len):
        for j in range(i+1, x_len):
            cmp_x = tf.concat([x[i], x[j]], axis=0)
            cmp_x = tf.reshape(cmp_x, [1, 2, -1])
            t = cmp_fun(cmp_x)[0]
            x = swap(x, i, j, t)
    return x

### Sample comparator function

A sample comparator function for testing. The `tf.sign` makes it non-differentiable.

For the sake of the example. It counts the number of $1$s in the array.

In [8]:
@tf.function
def sample_comparator(x):
    '''
        x: Tensor - Expected dims: [batch_size, 2, feature_size]
    '''
    sv = tf.reduce_sum(x, axis=-1)
    sv = tf.subtract(sv[:,0], sv[:,1])
    return 1 - (tf.sign(sv) + 1) / 2

with tf.GradientTape() as tape:
    x = tf.Variable([
        [1,0,0,0],
        [1,1,1,1]
    ], dtype=tf.float32)
    cmp_x = tf.concat([x[0], x[1]], axis=0)
    cmp_x = tf.reshape(cmp_x, [1, 2, -1])
    cmp_result = sample_comparator(cmp_x)
    print(cmp_result)
    grad = tape.gradient(cmp_result, x)
    print(grad)

tf.Tensor([1.], shape=(1,), dtype=float32)
tf.Tensor(
[[0. 0. 0. 0.]
 [0. 0. 0. 0.]], shape=(2, 4), dtype=float32)


In [9]:
x = tf.Variable([[3],[1],[2]],dtype=tf.float32)
with tf.GradientTape() as tape:
    z = bubble_sort(x, sample_comparator)
    print(z)
    print(tape.gradient(z,x))

tf.Tensor(
[[1.]
 [2.]
 [3.]], shape=(3, 1), dtype=float32)
tf.Tensor(
[[1.]
 [1.]
 [1.]], shape=(3, 1), dtype=float32)


In [10]:
x = tf.Variable([
    [1, 1, 0],
    [1, 0, 0],
    [1, 1, 1]
],dtype=tf.float32)
with tf.GradientTape(persistent=True) as tape:
    z = bubble_sort(x, sample_comparator)

print(z)
print(tape.gradient(z,x))

tf.Tensor(
[[1. 0. 0.]
 [1. 1. 0.]
 [1. 1. 1.]], shape=(3, 3), dtype=float32)
tf.Tensor(
[[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]], shape=(3, 3), dtype=float32)


In [11]:
x = tf.Variable([
    [1, 1, 0, 0],
    [1, 1, 0, 1],
    [1, 0, 0, 0],
    [1, 0, 1, 0],
    [1, 1, 1, 0]
],dtype=tf.float32)
with tf.GradientTape(persistent=True) as tape:
    z = bubble_sort(x, sample_comparator)

print(z)
print(tape.gradient(z,x))

tf.Tensor(
[[1.  0.  0.  0. ]
 [1.  0.5 0.5 0. ]
 [1.  0.5 0.5 0. ]
 [1.  1.  0.5 0.5]
 [1.  1.  0.5 0.5]], shape=(5, 4), dtype=float32)
tf.Tensor(
[[1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]], shape=(5, 4), dtype=float32)


In [12]:
data_gen = lambda: np.tril(np.ones((10,10),dtype=np.float32))
actual_data = data_gen()
actual_data

array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 1., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 0.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32)

In [13]:
shuffled_data = data_gen()
np.random.shuffle(shuffled_data)
shuffled_data

array([[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
       [1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 1., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 0., 0., 0., 0.],
       [1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 0.]], dtype=float32)

In [14]:
z = bubble_sort(shuffled_data, sample_comparator)
z



<tf.Tensor: shape=(10, 10), dtype=float32, numpy=
array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 1., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 0.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32)>

## Learnable Comparator Function

Since the setup is end-to-end differentiable. We can use a DNN as the comparator function and expect it to learn using backpropagation.

In [15]:
class ComparatorBlock(layers.Layer):
    def __init__(self):
        super(ComparatorBlock, self).__init__()
        self.dense1 = layers.Dense(10, kernel_initializer="he_normal",activation='relu')
        self.dense2 = layers.Dense(10, kernel_initializer="he_normal",activation='relu')
        self.dense3 = layers.Dense(1, activation='sigmoid')

    def build(self, input_shape):
        super(ComparatorBlock, self).build(input_shape)

    def call(self, x):
        vector_len = tf.shape(x)[-1]
        h = tf.reshape(x, [-1, 2 * vector_len])
        h = self.dense1(h)
        h = self.dense2(h)
        h = self.dense3(h)
        return h

In [16]:
# temp_comparator = ComparatorBlock()
# batch_size = 10
# vector_length = 10
# input_shape = (batch_size, 2, vector_length)
# output_shape = (batch_size, 1)
# x = tf.random.normal(input_shape)
# y = tf.math.round(tf.random.uniform(output_shape, minval=0, maxval=1))
# result = temp_comparator(x)
# print(x.shape, result.shape, y.shape)
# # print(len(temp_comparator.trainable_variables))

# a = Input(shape=(2, vector_length))
# b = temp_comparator(a)
# m = Model(inputs=a, outputs=b)
# m.compile(loss='mse', optimizer='adam')
# m.fit(x=x,y=y,epochs=100,batch_size=batch_size)

In [17]:
learned_comparator = ComparatorBlock()
learned_comparator(tf.zeros((1,2,shuffled_data.shape[-1])))
z = bubble_sort(shuffled_data, learned_comparator)
# print(z)
print(tf.nn.l2_loss(z - actual_data))

tf.Tensor(8.614502, shape=(), dtype=float32)


In [18]:
x = tf.Variable(shuffled_data, dtype=tf.float32)
with tf.GradientTape() as tape:
    z = bubble_sort(x, learned_comparator)
    loss = tf.nn.l2_loss(z - actual_data)
    grads = tape.gradient(loss, learned_comparator.trainable_variables)
    print(grads)

[<tf.Tensor: shape=(20, 10), dtype=float32, numpy=
array([[-0.13441893,  0.        ,  0.        , -0.2395423 ,  0.        ,
         0.47758985,  0.10906857,  0.15641265,  0.01129907,  0.20773324],
       [-0.02731699,  0.        ,  0.        , -0.12215602,  0.        ,
         0.5126155 , -0.05847526,  0.15641265,  0.02561067,  0.1794533 ],
       [-0.02837466,  0.        ,  0.        , -0.09804745,  0.        ,
         0.4148222 , -0.06289126,  0.15641265,  0.01812573,  0.15368316],
       [ 0.02158301,  0.        ,  0.        , -0.0093091 ,  0.        ,
         0.4866098 , -0.18216284,  0.15641265,  0.01812573,  0.15059519],
       [ 0.03758298,  0.        ,  0.        ,  0.04376996,  0.        ,
         0.40505195, -0.23971845,  0.15641265,  0.01408147,  0.12554932],
       [ 0.00737026,  0.        ,  0.        ,  0.03066381,  0.        ,
         0.25486577, -0.14665182,  0.07815661,  0.00488025,  0.0834264 ],
       [ 0.0394711 ,  0.        ,  0.        ,  0.07257791,  0.    

## Training

We can train the setup end-to-end withing Adam optimizer.

In [19]:
x = tf.Variable(shuffled_data, dtype=tf.float32)
opt = tf.keras.optimizers.Adam(learning_rate=3e-4)

for i in range(1000):
    with tf.GradientTape() as tape:
        z = bubble_sort(x, learned_comparator)
        loss = tf.nn.l2_loss(z - actual_data)
    print(loss)
    var_list = learned_comparator.trainable_variables
    grads = tape.gradient(loss, var_list)
    opt.apply_gradients(zip(grads, var_list))


tf.Tensor(8.614502, shape=(), dtype=float32)
tf.Tensor(8.602819, shape=(), dtype=float32)
tf.Tensor(8.59146, shape=(), dtype=float32)
tf.Tensor(8.580219, shape=(), dtype=float32)
tf.Tensor(8.56898, shape=(), dtype=float32)
tf.Tensor(8.557745, shape=(), dtype=float32)
tf.Tensor(8.546511, shape=(), dtype=float32)
tf.Tensor(8.535281, shape=(), dtype=float32)
tf.Tensor(8.524055, shape=(), dtype=float32)
tf.Tensor(8.5128, shape=(), dtype=float32)
tf.Tensor(8.501542, shape=(), dtype=float32)
tf.Tensor(8.490296, shape=(), dtype=float32)
tf.Tensor(8.47906, shape=(), dtype=float32)
tf.Tensor(8.467796, shape=(), dtype=float32)
tf.Tensor(8.455988, shape=(), dtype=float32)
tf.Tensor(8.4439945, shape=(), dtype=float32)
tf.Tensor(8.431954, shape=(), dtype=float32)
tf.Tensor(8.419878, shape=(), dtype=float32)
tf.Tensor(8.407517, shape=(), dtype=float32)
tf.Tensor(8.394468, shape=(), dtype=float32)
tf.Tensor(8.3811245, shape=(), dtype=float32)
tf.Tensor(8.367964, shape=(), dtype=float32)
tf.Tensor(8.3

tf.Tensor(5.950186, shape=(), dtype=float32)
tf.Tensor(5.936161, shape=(), dtype=float32)
tf.Tensor(5.9220877, shape=(), dtype=float32)
tf.Tensor(5.9079485, shape=(), dtype=float32)
tf.Tensor(5.893783, shape=(), dtype=float32)
tf.Tensor(5.879616, shape=(), dtype=float32)
tf.Tensor(5.865402, shape=(), dtype=float32)
tf.Tensor(5.851142, shape=(), dtype=float32)
tf.Tensor(5.836833, shape=(), dtype=float32)
tf.Tensor(5.8224387, shape=(), dtype=float32)
tf.Tensor(5.808523, shape=(), dtype=float32)
tf.Tensor(5.794501, shape=(), dtype=float32)
tf.Tensor(5.780468, shape=(), dtype=float32)
tf.Tensor(5.7664156, shape=(), dtype=float32)
tf.Tensor(5.7523413, shape=(), dtype=float32)
tf.Tensor(5.7382317, shape=(), dtype=float32)
tf.Tensor(5.7240787, shape=(), dtype=float32)
tf.Tensor(5.7098794, shape=(), dtype=float32)
tf.Tensor(5.695656, shape=(), dtype=float32)
tf.Tensor(5.6814294, shape=(), dtype=float32)
tf.Tensor(5.6672025, shape=(), dtype=float32)
tf.Tensor(5.652959, shape=(), dtype=float32)


tf.Tensor(2.946125, shape=(), dtype=float32)
tf.Tensor(2.9282854, shape=(), dtype=float32)
tf.Tensor(2.9104905, shape=(), dtype=float32)
tf.Tensor(2.8928356, shape=(), dtype=float32)
tf.Tensor(2.8752415, shape=(), dtype=float32)
tf.Tensor(2.8577082, shape=(), dtype=float32)
tf.Tensor(2.8402457, shape=(), dtype=float32)
tf.Tensor(2.8229, shape=(), dtype=float32)
tf.Tensor(2.8058364, shape=(), dtype=float32)
tf.Tensor(2.7889352, shape=(), dtype=float32)
tf.Tensor(2.7723017, shape=(), dtype=float32)
tf.Tensor(2.75579, shape=(), dtype=float32)
tf.Tensor(2.7394018, shape=(), dtype=float32)
tf.Tensor(2.723153, shape=(), dtype=float32)
tf.Tensor(2.7070317, shape=(), dtype=float32)
tf.Tensor(2.6910408, shape=(), dtype=float32)
tf.Tensor(2.675521, shape=(), dtype=float32)
tf.Tensor(2.6602604, shape=(), dtype=float32)
tf.Tensor(2.6451366, shape=(), dtype=float32)
tf.Tensor(2.6301692, shape=(), dtype=float32)
tf.Tensor(2.6153824, shape=(), dtype=float32)
tf.Tensor(2.6007304, shape=(), dtype=float

tf.Tensor(1.3551619, shape=(), dtype=float32)
tf.Tensor(1.3512442, shape=(), dtype=float32)
tf.Tensor(1.347346, shape=(), dtype=float32)
tf.Tensor(1.3434651, shape=(), dtype=float32)
tf.Tensor(1.3395997, shape=(), dtype=float32)
tf.Tensor(1.3357493, shape=(), dtype=float32)
tf.Tensor(1.3319293, shape=(), dtype=float32)
tf.Tensor(1.3281424, shape=(), dtype=float32)
tf.Tensor(1.3243711, shape=(), dtype=float32)
tf.Tensor(1.3206136, shape=(), dtype=float32)
tf.Tensor(1.31687, shape=(), dtype=float32)
tf.Tensor(1.3131403, shape=(), dtype=float32)
tf.Tensor(1.3094232, shape=(), dtype=float32)
tf.Tensor(1.3057189, shape=(), dtype=float32)
tf.Tensor(1.302027, shape=(), dtype=float32)
tf.Tensor(1.2983459, shape=(), dtype=float32)
tf.Tensor(1.2946754, shape=(), dtype=float32)
tf.Tensor(1.2910151, shape=(), dtype=float32)
tf.Tensor(1.2873642, shape=(), dtype=float32)
tf.Tensor(1.2837267, shape=(), dtype=float32)
tf.Tensor(1.2800992, shape=(), dtype=float32)
tf.Tensor(1.2764795, shape=(), dtype=f

tf.Tensor(0.71648675, shape=(), dtype=float32)
tf.Tensor(0.7135684, shape=(), dtype=float32)
tf.Tensor(0.71064746, shape=(), dtype=float32)
tf.Tensor(0.7077364, shape=(), dtype=float32)
tf.Tensor(0.70482814, shape=(), dtype=float32)
tf.Tensor(0.7019225, shape=(), dtype=float32)
tf.Tensor(0.69901997, shape=(), dtype=float32)
tf.Tensor(0.6961206, shape=(), dtype=float32)
tf.Tensor(0.69322455, shape=(), dtype=float32)
tf.Tensor(0.69033295, shape=(), dtype=float32)
tf.Tensor(0.68744355, shape=(), dtype=float32)
tf.Tensor(0.6845613, shape=(), dtype=float32)
tf.Tensor(0.6816841, shape=(), dtype=float32)
tf.Tensor(0.6788125, shape=(), dtype=float32)
tf.Tensor(0.6759445, shape=(), dtype=float32)
tf.Tensor(0.6730809, shape=(), dtype=float32)
tf.Tensor(0.67022187, shape=(), dtype=float32)
tf.Tensor(0.6673676, shape=(), dtype=float32)
tf.Tensor(0.6645199, shape=(), dtype=float32)
tf.Tensor(0.6616778, shape=(), dtype=float32)
tf.Tensor(0.6588395, shape=(), dtype=float32)
tf.Tensor(0.656007, shape=

tf.Tensor(0.31263167, shape=(), dtype=float32)
tf.Tensor(0.31123933, shape=(), dtype=float32)
tf.Tensor(0.3098598, shape=(), dtype=float32)
tf.Tensor(0.30862337, shape=(), dtype=float32)
tf.Tensor(0.3071483, shape=(), dtype=float32)
tf.Tensor(0.3058139, shape=(), dtype=float32)
tf.Tensor(0.30448186, shape=(), dtype=float32)
tf.Tensor(0.30315372, shape=(), dtype=float32)
tf.Tensor(0.3018355, shape=(), dtype=float32)
tf.Tensor(0.30052578, shape=(), dtype=float32)
tf.Tensor(0.29921925, shape=(), dtype=float32)
tf.Tensor(0.2979167, shape=(), dtype=float32)
tf.Tensor(0.2966212, shape=(), dtype=float32)
tf.Tensor(0.2953369, shape=(), dtype=float32)
tf.Tensor(0.2940564, shape=(), dtype=float32)
tf.Tensor(0.29278517, shape=(), dtype=float32)
tf.Tensor(0.2915201, shape=(), dtype=float32)
tf.Tensor(0.2902639, shape=(), dtype=float32)
tf.Tensor(0.28903913, shape=(), dtype=float32)
tf.Tensor(0.2878207, shape=(), dtype=float32)
tf.Tensor(0.28660452, shape=(), dtype=float32)
tf.Tensor(0.2853909, sha

In [20]:
z = bubble_sort(x, learned_comparator)
z = tf.round(z)
print(z - actual_data)

tf.Tensor(
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]], shape=(10, 10), dtype=float32)
