<u>Theme</u>

Several implementations of the Sinkhorn iteration

S.Ukai, Oct. 20, 2018

In [None]:
import numpy as np
import tensorflow as tf
import matplotlib.pylab as plt
import itertools
import sys

%matplotlib inline

# 1.  Simple implementation of sinkhorn iteration

## 1-1. The simplest way:

In [None]:
def baby_sinkhorn_iteration(Mxy, p, q, eps, L):
    """Run ``L`` plain (non-log-domain) Sinkhorn scaling iterations.

    The kernel ``exp(-Mxy / eps)`` is formed directly, with no stabilisation,
    so for small ``eps`` its entries may underflow to zero.

    Parameters
    ----------
    Mxy : ndarray, shape (Nx, Ny)
        Cost matrix.
    p : ndarray, expected shape (Nx, 1)
        Target row marginal.
    q : ndarray, expected shape (Ny, 1)
        Target column marginal.
    eps : float
        Regularisation strength used in the kernel ``exp(-Mxy / eps)``.
    L : int
        Number of scaling iterations. With ``L <= 0`` both scalings stay at
        one and the returned plan is the bare kernel.

    Returns
    -------
    P : ndarray, shape (Nx, Ny)
        Transport plan, equal to ``diag(u) @ K @ diag(v)``.
    u : ndarray, shape (Nx, 1)
        Row scaling vector.
    v : ndarray, shape (Ny, 1)
        Column scaling vector.
    """
    Kxy = np.exp(-Mxy / eps)
    Nx, Ny = Mxy.shape
    # Initialise both scalings so that u is defined even when the loop body
    # never runs (L == 0).
    u = np.ones((Nx, 1))
    v = np.ones((Ny, 1))
    for _ in range(L):
        u = p / np.dot(Kxy, v)
        v = q / np.dot(Kxy.T, u)

    # Broadcasting scales row i by u[i] and column j by v[j]; this yields the
    # same entries as diag(u) @ K @ diag(v) without building two dense
    # diagonal matrices and multiplying them in O(N^3).
    P = u * (Kxy * v.T)
    return P, u, v

In [None]:
# Two independent standard-normal point clouds, N points each in 8 dimensions
N = 1 << 10
X = np.random.randn(N, 8)
Y = np.random.randn(N, 8)
# Cost matrix: L1 distance between every X[i] and every Y[j]
Mxy = np.abs(X[:, None, :] - Y[None, :, :]).sum(axis=2)

In [None]:
# Uniform marginals as (N, 1) column vectors, each entry 1/N
p = np.full((N, 1), 1.0 / N)
q = np.full((N, 1), 1.0 / N)

In [None]:
# Iteration count and regularisation strength for the baby Sinkhorn run
L, eps = 1 << 7, 1

In [None]:
# Solve with the plain scaling iteration; u and v are the final scaling vectors
P, u, v = baby_sinkhorn_iteration(Mxy=Mxy, p=p, q=q, eps=eps, L=L)

In [None]:
# Marginals actually realised by the plan, both as column vectors
phat = P.sum(axis=1, keepdims=True)
qhat = P.sum(axis=0, keepdims=True).T
# Total L1 deviation of both marginals from their targets
err = np.abs(phat - p).sum() + np.abs(qhat - q).sum()
print('err = ', err)

## 1-2. Another simple implementation:

In [None]:
def log_sinkhorn_iteration(Mxy, p, q, eps, L):
    """Run ``L`` Sinkhorn iterations on the potentials (log domain).

    The plan is parameterised as
    ``P = exp(-(Mxy - alpha - beta) / eps - 1)``; each iteration solves the
    column-marginal condition for ``beta`` and then the row-marginal
    condition for ``alpha``. The log-sum-exp is evaluated without shifting,
    so it is not protected against underflow for small ``eps``.

    The L1 column-marginal error is printed after every iteration.

    Parameters
    ----------
    Mxy : ndarray, shape (Nx, Ny)
        Cost matrix.
    p : ndarray, expected shape (Nx, 1)
        Target row marginal.
    q : ndarray, expected shape (Ny, 1)
        Target column marginal (transposed internally to a row).
    eps : float
        Regularisation strength.
    L : int
        Number of iterations. With ``L <= 0`` the plan is built from the
        initial potentials (row and column means of ``Mxy``).

    Returns
    -------
    tuple
        ``(P, None, None)``; the two ``None`` entries keep the three-value
        shape that ``baby_sinkhorn_iteration`` returns.
    """
    q = q.T

    alpha = np.mean(Mxy, axis=1, keepdims=True)
    beta = np.mean(Mxy, axis=0, keepdims=True)

    # These terms do not change between iterations, so compute them once.
    beta_const = eps + eps * np.log(q)
    alpha_const = eps + eps * np.log(p)

    P = None
    for _ in range(L):
        beta = beta_const \
            - eps * np.log(np.sum(np.exp(-(Mxy - alpha) / eps), axis=0, keepdims=True))
        alpha = alpha_const \
            - eps * np.log(np.sum(np.exp(-(Mxy - beta) / eps), axis=1, keepdims=True))
        P = np.exp(-(Mxy - alpha - beta) / eps - 1)

        # Rows are matched by the alpha update just made, so only the
        # column marginal is reported.
        qhat = np.sum(P, axis=0, keepdims=True)
        err = np.sum(np.abs(qhat - q))
        print(err)

    if P is None:
        # The loop never ran (L <= 0): return the plan of the initial potentials
        # rather than failing on an unbound name.
        P = np.exp(-(Mxy - alpha - beta) / eps - 1)

    return P, None, None

In [None]:
# Tiny problem (4 x 4) so the per-iteration error printout stays readable
N = 4
X, Y = (np.random.randn(N, 8) for _ in range(2))
# Pairwise L1 distances between the rows of X and the rows of Y
Mxy = np.abs(np.expand_dims(X, 1) - np.expand_dims(Y, 0)).sum(axis=2)

In [None]:
# Uniform (N, 1) marginals; two separate arrays with identical contents
p, q = (np.ones((N, 1)) / N for _ in range(2))

In [None]:
# Eight iterations with unit regularisation for the log-domain demo
L, eps = 2**3, 1.0

In [None]:
# The log-domain solver returns (P, None, None), so u and v end up as None here
P, u, v = log_sinkhorn_iteration(Mxy=Mxy, p=p, q=q, eps=eps, L=L)

# 2.  Robust implementation of sinkhorn iteration

# 2-1. @NUMPY

In [None]:
def robust_sinkhorn_iteration(Mxy, p, q, eps, L, tol=1e-8, verbose=True):
    """Run log-domain Sinkhorn iterations with a min-shifted log-sum-exp.

    Before exponentiating, each update subtracts the column-wise (for
    ``beta``) or row-wise (for ``alpha``) minimum of the shifted cost. The
    largest term inside every log-sum-exp is therefore ``exp(0) = 1``, so the
    sum cannot underflow to zero for small ``eps``.

    Parameters
    ----------
    Mxy : ndarray, shape (Nx, Ny)
        Cost matrix.
    p : ndarray, expected shape (Nx, 1)
        Target row marginal.
    q : ndarray, expected shape (Ny, 1)
        Target column marginal (transposed internally to a row).
    eps : float
        Regularisation strength.
    L : int
        Maximum number of iterations.
    tol : float, optional
        Stop as soon as the L1 gap between the plan's column sums and ``q``
        falls below this value.
    verbose : bool, optional
        When true (the default), print the mean of both potentials at the
        start of every iteration.

    Returns
    -------
    tuple
        ``(P, None, None, Nitr)`` where ``P`` is the transport plan and
        ``Nitr`` the number of iterations actually run (0 when ``L <= 0``,
        in which case ``P`` is built from the initial potentials).
    """
    q = q.T

    alpha = np.mean(Mxy, axis=1, keepdims=True)
    beta = np.mean(Mxy, axis=0, keepdims=True)

    # These terms do not change between iterations, so compute them once.
    beta_const = eps + eps * np.log(q)
    alpha_const = eps + eps * np.log(p)

    P = None
    Nitr = 0
    for k1 in range(L):
        Nitr = k1 + 1

        if verbose:
            print(np.mean(alpha), np.mean(beta))

        # beta update: shift every column so that its smallest entry is 0
        shifted = Mxy - alpha
        delta_row = np.min(shifted, axis=0, keepdims=True)
        shifted = shifted - delta_row
        diff = np.min(shifted, axis=0)
        # Sanity check on the shift; fails if the costs contain NaN
        assert np.all(diff == 0), diff
        beta = beta_const + delta_row \
            - eps * np.log(np.sum(np.exp(-shifted / eps), axis=0, keepdims=True))

        # alpha update: same idea along the rows
        shifted = Mxy - beta
        delta_col = np.min(shifted, axis=1, keepdims=True)
        shifted = shifted - delta_col
        diff = np.min(shifted, axis=1)
        assert np.all(diff == 0), diff
        alpha = alpha_const + delta_col \
            - eps * np.log(np.sum(np.exp(-shifted / eps), axis=1, keepdims=True))

        # Rows are matched by the alpha update just made, so convergence is
        # judged on the column marginal only.
        P = np.exp(-(Mxy - alpha - beta) / eps - 1)
        qhat = np.sum(P, axis=0, keepdims=True)
        err = np.sum(np.abs(qhat - q))

        if err < tol:
            break

    if P is None:
        # The loop never ran (L <= 0): return the plan of the initial potentials
        # rather than failing on an unbound name.
        P = np.exp(-(Mxy - alpha - beta) / eps - 1)

    return P, None, None, Nitr


In [None]:
# 128 points per cloud, 8 features each, drawn from a standard normal
N = 128
X = np.random.randn(N, 8)
Y = np.random.randn(N, 8)
# Mxy[i, j] = sum_k |X[i, k] - Y[j, k]|
Mxy = np.abs(X[:, None] - Y[None, :]).sum(axis=2)

In [None]:
# Average L1 distance between index-matched rows X[i] and Y[i]; used as the
# length scale for choosing eps
mut_dist = np.abs(X - Y).sum(axis=1).mean()

In [None]:
# Uniform marginals: every entry is 1/N, stored as (N, 1) columns
p = np.full((N, 1), 1.0 / N)
q = np.full((N, 1), 1.0 / N)

In [None]:
# Iteration cap, and a regularisation strength set to 1% of the typical distance
L = 1 << 13
eps = 0.01 * mut_dist

In [None]:
# Stabilised solve; Nitr receives the number of iterations actually run
P, _, _, Nitr = robust_sinkhorn_iteration(Mxy=Mxy, p=p, q=q, eps=eps, L=L, tol=1e-4)

In [None]:
# Display the iteration count returned by the robust_sinkhorn_iteration call above
Nitr

In [None]:
# Row sums stay a column; column sums are transposed into a column so both
# can be compared with the (N, 1) targets
phat = P.sum(axis=1, keepdims=True)
qhat = P.sum(axis=0, keepdims=True).T
# Combined L1 marginal violation
err = np.sum(np.abs(phat - p)) + np.sum(np.abs(qhat - q))
print('err = ', err)

In [None]:
# Sweep eps over ten log-spaced values from 1% to 100% of mut_dist and record
# how many iterations each one needs to reach the tolerance
list_eps = mut_dist * np.logspace(-2, 0, 10)
Nitr = [0] * len(list_eps)
for k1, eps in enumerate(list_eps):
    P, _, _, Nitr[k1] = robust_sinkhorn_iteration(
        Mxy=Mxy, p=p, q=q, eps=eps, L=L, tol=1e-4)

In [None]:
# Iterations needed versus eps (expressed as a multiple of mut_dist),
# drawn with dashed lines and x markers on log-log axes
plt.plot(list_eps/mut_dist, Nitr, '--x')
plt.xscale('log')
plt.yscale('log')

# 2-2. @Tensorflow

In [None]:
# No session exists yet; the "session begin" cell closes `sess` only when it is not None
sess = None

### define a network

In [None]:
def _sinkhorn_iteration(_Mxy, _p, _q, _eps, _L, _tol):
    """Build the TensorFlow graph for the min-shifted log-domain Sinkhorn loop.

    The potentials start at zero and are updated inside a ``tf.while_loop``
    that keeps running while the iteration counter is below ``_L`` and the
    L1 gap between the plan's column sums and ``_q`` exceeds ``_tol``.

    Args:
        _Mxy: cost-matrix tensor.
        _p: row-marginal tensor (a column; it is broadcast against the rows).
        _q: column-marginal tensor (a column; it is transposed here).
        _eps: scalar regularisation strength.
        _L: scalar iteration cap.
        _tol: scalar convergence tolerance.

    Returns:
        ``(_P, _Nitr, _dist)``: the plan tensor, the loop counter on exit, and
        ``sum(_p * alpha) + sum(_q * beta) - _eps`` evaluated at the final
        potentials.
    """
    _q_row = tf.transpose(_q)

    # Zero potentials that inherit their shapes from the row / column means
    _alpha0 = tf.reduce_mean(_Mxy, axis=1, keepdims=True) * 0
    _beta0 = tf.reduce_mean(_Mxy, axis=0, keepdims=True) * 0

    def _plan(_a, _b):
        # Plan implied by the potentials (_a, _b)
        return tf.exp(-(_Mxy - _a - _b) / _eps - 1)

    def _keep_going(_k, _a, _b):
        # Loop condition: under the iteration cap AND column marginal not yet matched
        _col_sums = tf.reduce_sum(_plan(_a, _b), axis=0, keepdims=True)
        _gap = tf.reduce_sum(tf.abs(_col_sums - _q_row))
        _under_cap = tf.less(_k, _L)
        _unconverged = tf.greater(_gap, _tol)
        return tf.logical_and(_under_cap, _unconverged)

    def _step(_k, _a, _b):
        # Column potential: shift by the column-wise minimum before exponentiating
        _shift_row = tf.reduce_min(_Mxy - _a, axis=0, keepdims=True)
        _lse_row = tf.log(tf.reduce_sum(
            tf.exp(-(_Mxy - _a - _shift_row) / _eps), axis=0, keepdims=True))
        _b = _eps + _eps * tf.log(_q_row) + _shift_row - _eps * _lse_row

        # Row potential: the same with the row-wise minimum
        _shift_col = tf.reduce_min(_Mxy - _b, axis=1, keepdims=True)
        _lse_col = tf.log(tf.reduce_sum(
            tf.exp(-(_Mxy - _b - _shift_col) / _eps), axis=1, keepdims=True))
        _a = _eps + _eps * tf.log(_p) + _shift_col - _eps * _lse_col

        return _k + 1, _a, _b

    _Nitr, _alpha, _beta = tf.while_loop(
        _keep_going,
        _step,
        [0, _alpha0, _beta0])

    _dist = tf.reduce_sum(_p * _alpha) + tf.reduce_sum(_q_row * _beta) - _eps
    _P = _plan(_alpha, _beta)
    return _P, _Nitr, _dist

### session begin

In [None]:
# Close any session left over from an earlier run of this cell BEFORE
# resetting the graph: TensorFlow documents calling tf.reset_default_graph()
# while a Session / InteractiveSession is active as undefined behaviour.
if sess is not None:
    sess.close()
tf.reset_default_graph()
sess = tf.InteractiveSession()

In [None]:
# Graph inputs: float64 cost matrix and (n, 1) marginals of unknown size,
# plus scalar regularisation strength, iteration cap and tolerance
_Mxy = tf.placeholder(dtype=tf.float64, shape=(None, None))
_p = tf.placeholder(dtype=tf.float64, shape=(None, 1))
_q = tf.placeholder(dtype=tf.float64, shape=(None, 1))
_eps = tf.placeholder(dtype=tf.float64, shape=())
_L = tf.placeholder(dtype=tf.int32, shape=())
_tol = tf.placeholder(dtype=tf.float64, shape=())

In [None]:
# Wire the placeholders into the Sinkhorn graph; nothing is evaluated yet
_P, _Nitr, _dist = _sinkhorn_iteration(
    _Mxy=_Mxy, _p=_p, _q=_q, _eps=_eps, _L=_L, _tol=_tol)

### generate a dataset for testing

In [None]:
# Test data for the TensorFlow graph: two 128 x 8 standard-normal clouds
N = 1 << 7
X, Y = (np.random.randn(N, 1 << 3) for _ in range(2))
# L1 cost between every row of X and every row of Y
Mxy = np.sum(np.abs(X[:, None, :] - Y[None, :, :]), axis=-1)

In [None]:
# Mean over i of the L1 distance between X[i] and Y[i]
mut_dist = np.mean(np.abs(X - Y).sum(axis=1))

In [None]:
# Uniform (N, 1) marginals fed to the _p and _q placeholders
p, q = (np.ones((N, 1)) / N for _ in range(2))

In [None]:
# Iteration cap, regularisation at 0.5% of the typical distance, and tolerance
L = 1 << 13
eps = 0.005 * mut_dist
tol = 1e-4

In [None]:
# Bind each placeholder to its NumPy / Python value
feed_dict = {
    _Mxy: Mxy,
    _p: p,
    _q: q,
    _eps: eps,
    _L: L,
    _tol: tol,
}

### evaluate

In [None]:
# Evaluate plan, iteration count and _dist in a single graph run
P, Nitr, dist = sess.run(fetches=[_P, _Nitr, _dist], feed_dict=feed_dict)

In [None]:
# Check the plan fetched from TensorFlow against both target marginals
phat = np.sum(P, axis=1, keepdims=True)
qhat = np.transpose(np.sum(P, axis=0, keepdims=True))
row_gap = np.sum(np.abs(phat - p))
col_gap = np.sum(np.abs(qhat - q))
err = row_gap + col_gap
print('err = ', err)

In [None]:
# Display the loop counter fetched from the TensorFlow graph by sess.run above
Nitr