In [1]:
import tensorflow as tf
import numpy as np
import sys
import ot as pot

from tf_sinkhorn import sinkhorn_knopp_tf, ground_distance_tf

In [2]:
def sinkhorn_knopp_tf(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, **kwargs):
    K = tf.exp(-M/reg)
    #u = tf.ones(a.shape)/tf.reduce_sum(a,axis=-1,keepdims=True)
    u = tf.ones(a.shape)/tf.cast(tf.shape(a)[1],tf.float32)
    v = tf.ones(b.shape)/tf.cast(tf.shape(b)[1],tf.float32)
    uprev = tf.ones(a.shape)/tf.reduce_sum(a,axis=-1,keepdims=True)
    vprev = tf.ones(b.shape)/tf.reduce_sum(b,axis=-1,keepdims=True)
    Kp = tf.expand_dims(1/a,-1)*K
    
    err = tf.Variable(1.)
    cpt = tf.Variable(0)
    

    
    mycond = lambda err, cpt, Kp, u, v, uprev, vprev : tf.logical_and(tf.less(cpt, numItermax),tf.greater(err,stopThr))

    def loopfn(err, cpt, Kp, u, v, uprev, vprev):
        uprev = u
        vprev = v
        
        KtransposeU = tf.squeeze(tf.matmul(K,tf.expand_dims(u,-1),transpose_a=True))
        v = b / KtransposeU
        u = 1/tf.squeeze(tf.matmul(Kp, tf.expand_dims(v,-1)))
        
        
        error_cond = tf.reduce_any(tf.is_nan(u))
        error_cond = tf.logical_or(error_cond, tf.reduce_any(tf.is_nan(v)))
        error_cond = tf.logical_or(error_cond, tf.reduce_any(tf.is_inf(u)))
        error_cond = tf.logical_or(error_cond, tf.reduce_any(tf.is_inf(v)))
        error_cond = tf.logical_or(error_cond, tf.reduce_any(tf.equal(KtransposeU,0)))
        
        def error_function_true():
            if verbose:
                print("Reached machine precision")
            return tf.Variable(numItermax), uprev, vprev
        
        def error_function_false():
            return tf.Variable(numItermax), u, v
        
        cpt, u, v = tf.cond(error_cond,error_function_true,error_function_false)
        
        
        return err, cpt, Kp, u, v, uprev, vprev
    
    this = tf.while_loop(mycond, loopfn,[err, cpt, Kp, u, v, uprev, vprev])

    u = this[3]
    v = this[4]
    
    return tf.expand_dims(u,-1)*K*tf.expand_dims(v,-2)

In [123]:
def ground_distance_tf(pointsa,pointsb,return_gradients=False,clip=True,epsilon=1e-2):
    
    a_dim = pointsa.shape[-2]
    b_dim = pointsb.shape[-2]
    
    amat = tf.tile(tf.expand_dims(pointsa,2),[1,1,b_dim,1])
    bmat = tf.tile(tf.expand_dims(pointsb,1),[1,a_dim,1,1])
    
    diffmat = bmat - amat
    dist = tf.norm(diffmat,axis=3)

    zerogradients = tf.fill(amat.shape,0.)
    if clip:
        clipentries = tf.tile(tf.greater(epsilon, tf.expand_dims(dist,-1)),[1,1,1,2])
        diffmat = tf.where(clipentries,zerogradients,diffmat)
        gradients = tf.where(clipentries, zerogradients, diffmat/tf.expand_dims(dist,-1))
    elif return_gradients:
        gradients = diffmat / tf.expand_dims(dist,-1)

    if return_gradients:
        return tf.norm(diffmat,axis=3), gradients
    else:
        return tf.norm(bmat - amat,axis=3)

In [126]:
aval = np.ones((10,4,2))*0.8
#aval[:,:,0] = 0
bval = np.ones((10,3,2))

with tf.Session() as sess:
    a = tf.placeholder(tf.float32, shape=(10,4,2))
    b = tf.placeholder(tf.float32, shape=(10,3,2))
    result = sess.run(ground_distance_tf(a,b,return_gradients=True),feed_dict={a:aval,b:bval})
print(result[0][0])
print(result[1][0])


[[0.2828427 0.2828427 0.2828427]
 [0.2828427 0.2828427 0.2828427]
 [0.2828427 0.2828427 0.2828427]
 [0.2828427 0.2828427 0.2828427]]
[[[0.70710677 0.70710677]
  [0.70710677 0.70710677]
  [0.70710677 0.70710677]]

 [[0.70710677 0.70710677]
  [0.70710677 0.70710677]
  [0.70710677 0.70710677]]

 [[0.70710677 0.70710677]
  [0.70710677 0.70710677]
  [0.70710677 0.70710677]]

 [[0.70710677 0.70710677]
  [0.70710677 0.70710677]
  [0.70710677 0.70710677]]]


In [131]:
aweights = tf.placeholder(tf.float32, shape=(10,4))
bweights = tf.placeholder(tf.float32, shape=(10,3))

dist, grads = ground_distance_tf(a,b,return_gradients=True)
match = sinkhorn_knopp_tf(aweights,bweights,dist,0.1)
tf.matmul

In [139]:
dist

<tf.Tensor 'norm_119/Squeeze:0' shape=(10, 4, 3) dtype=float32>

In [135]:
tf.transpose(grads,[0,3,1,2])

<tf.Tensor 'transpose_1:0' shape=(10, 2, 4, 3) dtype=float32>

In [133]:
dist

<tf.Tensor 'norm_119/Squeeze:0' shape=(10, 4, 3) dtype=float32>

In [58]:
a.shape

TensorShape([Dimension(10), Dimension(4)])

In [2]:
([None] + [100,1])[:-1]

[None, 100]

In [74]:
tf.constant(1,shape=(a.shape[0]))

TypeError: 'Dimension' object is not iterable

In [9]:
aval = np.ones((10,4,2))/4
bval = np.ones((10,3,2))/3

In [12]:
aval = np.ones((10,4))/4
bval = np.ones((10,3))/3
Mval = np.random.rand(10,4,3)
with tf.Session() as sess:
    
    a = tf.placeholder(tf.float32, shape=(10,4))
    b = tf.placeholder(tf.float32, shape=(10,3))
    M = tf.placeholder(tf.float32, shape=(10,4,3))
    reg = tf.placeholder(tf.float32, shape=())
    mysink = sinkhorn_knopp_tf(a,b,M,reg,verbose=True)

    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    this = sess.run(mysink,feed_dict={a:aval,b:bval,M:Mval,reg:0.1})
    print(this)
    print(np.sum(this,-1))
    print(np.sum(this,-2))

Reached machine precision
[[[7.63723180e-02 1.26231074e-01 4.73966151e-02]
  [2.97494363e-02 7.21025467e-02 1.48148015e-01]
  [9.12804157e-02 1.51159942e-01 7.55964732e-03]
  [1.84667572e-01 5.87658361e-02 6.56660274e-03]]

 [[6.57330547e-03 3.74736525e-02 2.05953062e-01]
  [4.06882973e-05 2.49647543e-01 3.11787007e-04]
  [2.23503739e-04 2.17766151e-01 3.20103467e-02]
  [1.34169832e-01 1.15114786e-02 1.04318686e-01]]

 [[6.24328069e-02 6.01601265e-02 1.27407059e-01]
  [4.94208466e-03 2.41968453e-01 3.08947102e-03]
  [2.40276188e-01 9.66886710e-03 5.49369943e-05]
  [1.98334143e-01 5.08492328e-02 8.16620421e-04]]

 [[8.30035657e-02 1.63842797e-01 3.15364660e-03]
  [5.62989386e-04 1.00513995e-01 1.48923010e-01]
  [7.02500492e-02 4.50111367e-02 1.34738803e-01]
  [1.34918287e-01 1.27353165e-02 1.02346413e-01]]

 [[3.72903980e-02 2.07487777e-01 5.22180554e-03]
  [3.73499431e-02 1.65102929e-01 4.75471132e-02]
  [1.11016199e-01 3.85640934e-02 1.00419715e-01]
  [1.84401567e-03 5.24154678e-02 1.

In [183]:
plan = ot.sinkhorn(aval[0],bval[0],Mval[0],0.1,numItermax=10)
print(plan)
print(np.sum(plan, axis=1))
print(np.sum(plan, axis=0))

[[1.16736461e-03 2.13563522e-01 3.52691129e-02]
 [2.76347778e-03 1.70987247e-02 2.30137798e-01]
 [6.26421689e-02 1.09084370e-01 7.82734609e-02]
 [2.48035199e-01 1.91196078e-03 5.28406159e-05]]
[0.25 0.25 0.25 0.25]
[0.31460821 0.34165858 0.34373321]


In [None]:
this

In [93]:
this.shape

(10, 4, 3)

In [36]:
np.random.rand(3,3)

array([[0.73399129, 0.8018214 , 0.10227172],
       [0.43301417, 0.62922426, 0.25735927],
       [0.05989077, 0.96986245, 0.37620816]])