# Implement Sinkhorn iteration with a tolerance terminate condition 

In [1]:
import numpy as np 
import matplotlib.pylab as plt
import tensorflow as tf

%matplotlib inline

In [2]:
sess = None

## 1. define Cuturi function

In [1]:
def sinkhorn_iteration(_K, tol):
    M, N = _K.get_shape().as_list()
    
    epsilon = 0
    
    p = tf.ones((M,1), dtype=tf.float64)/M
    q = tf.ones((N,1), dtype=tf.float64)/N
    def terminate_cond(_itr, _u, _v):
        _phat = _u * tf.matmul(_K, _v)
        _qhat = _v * tf.matmul(tf.transpose(_K), _u)
        _err = tf.reduce_sum(tf.abs(_phat - p)) + tf.reduce_sum(tf.abs(_qhat - q)) 
        return tf.greater(_err, tol)
    def sinkhorn_update(_itr, _u, _v):
        _u = 1/tf.matmul(_K, _v)/M + epsilon
        _v = 1/tf.matmul(tf.transpose(_K), _u)/N + epsilon
        return (_itr+1, _u, _v)
    
    u0 = tf.ones((M,1), dtype=tf.float64)
    v0 = tf.ones((N,1), dtype=tf.float64)
    
    _, _u, _v = tf.while_loop(terminate_cond, sinkhorn_update, [0, u0, v0])
    _P = tf.matmul(
        tf.diag(tf.squeeze(_u)), 
        tf.matmul(_K, tf.diag(tf.squeeze(_v))))
    return _P, _u, _v

In [171]:
def transport_cost(_X, _Y):
    M = _X.get_shape().as_list()[0]
    Naxis = len(_X.get_shape().as_list())
    vecs = []
    for k1 in range(M):
        vecs.append( tf.reduce_sum(tf.abs(_Y - _X[k1,:]), axis= [*range(1, Naxis)]))
    _TC = tf.stack(vecs)
    return _TC

In [172]:
def cuturi_func(_X, _Y, eps_thermal, tol):
    _M = transport_cost(_X, _Y)
    _K = tf.exp(-_M/eps_thermal)
    _, _u, _v = sinkhorn_iteration(_K, tol)
    _alpha = tf.log(_u) * eps_thermal + 1/2
    _beta = tf.log(_v) * eps_thermal + 1/2
    return tf.reduce_mean(_alpha) + tf.reduce_mean(_beta) - eps_thermal

## 2 test

### 2-1.  Sinkhorn iteration can work?

In [173]:
M = 2**4
N = 2**3
K = np.random.rand(M, N)

In [174]:
_K = tf.placeholder(dtype=tf.float64, shape=(M, N))

In [175]:
tol = 1e-8

In [176]:
_P, _u, _v = sinkhorn_iteration(_K, tol)

In [177]:
if sess is not None:
    sess.close()
sess = tf.InteractiveSession()

In [None]:
P, u, v = sess.run((_P, _u, _v), feed_dict={_K: K} )

check if P = diag(u) * K * diag(v)

In [49]:
Pref = np.dot(np.diag(u.squeeze()), np.dot(K, np.diag(v.squeeze())))
assert np.max(np.abs(P - Pref)) < 1e-16

Check if ...
* P * ones(N,1) = ones(M,1)/M 
* P.T * ones(M,1) = ones(N, 1)/N

In [50]:
phat = np.dot(P, np.ones((N, 1)))
p = np.ones((M, 1))/M
qhat = np.dot(P.T, np.ones((M, 1)))
q = np.ones((N, 1))/N

In [51]:
np.sum(np.abs(phat - p)), np.sum(np.abs(qhat-q))

(8.446181323784607e-09, 4.163336342344337e-17)

### 2-2.  Check transport cost

In [82]:
N = 2**1
M = 2**1
Ndim = [2**2, 2**2,]
X = np.random.randn(M,*Ndim)
Y = np.random.randn(N,*Ndim)

In [83]:
_X = tf.placeholder(dtype=tf.float64, shape=(M, *Ndim))
_Y = tf.placeholder(dtype=tf.float64, shape=(N, *Ndim))

In [84]:
_M = transport_cost(_X, _Y)

In [85]:
M = sess.run(_M, feed_dict={_X: X, _Y: Y})

In [97]:
assert np.abs(M[0,0] - np.sum(np.abs(X[0,None]-Y[0,None]))) < 1e-16

### 2-3.  Can Cuturi function be differentiated?

In [137]:
if sess is not None:
    sess.close()
sess = tf.InteractiveSession()

In [148]:
N = 2**5
M = 2**5
Ndim = [2**2, 2**2,]
X = np.random.randn(M,*Ndim)
Y = np.random.randn(N,*Ndim)

In [150]:
eps_thermal = 1.0
tol = 1e-4

In [151]:
_X = tf.placeholder(dtype=tf.float64, shape=(M, *Ndim))
_Y = tf.placeholder(dtype=tf.float64, shape=(N, *Ndim))

In [152]:
_theta = tf.Variable(initial_value=0., dtype=tf.float64)

In [153]:
_f = cuturi_func(_X, _theta * _Y, eps_thermal, tol)

In [163]:
trainer = tf.train.AdadeltaOptimizer(1e-2).minimize(_f)

In [164]:
sess.run(tf.global_variables_initializer())

In [169]:
for k1 in range(10):
    f = sess.run([_f, trainer], feed_dict={_X: X, _Y: Y})
    print(f)

[6.14857607394796, None]
[6.148575939099349, None]
[6.148575803313703, None]
[6.148575670282152, None]
[6.148575573765574, None]
[6.148575476580117, None]
[6.1485753787230415, None]
[6.148575280191494, None]
[6.148575194289385, None]
[6.14857512330403, None]
