In [3]:
import numpy as np
import scipy.optimize
import matplotlib.pyplot as plt
import functools
import riot # riot / ot homemade library

# Sinkhorn vs. log-stabilized Sinkhorn

Problem size: `n` is the number of points in the source distribution and `m` the number of points in the target distribution.

In [4]:
n = 400  # number of points in the source distribution
m = 500  # number of points in the target distribution

## Consistency tests

We check that both algorithms, `sinkhorn` and `sinkhorn_stab`, lead to the same results (transport plan, scalings/potentials and regularized cost) on many randomly generated transport problems, one per seed.

In [6]:
seed_iterator = range(100)
for seed in seed_iterator:
    # One independent, seeded generator per problem so any failure is reproducible.
    rand_gen = np.random.RandomState(seed)

    # Source weights mu (n, 1) and target weights nu (m, 1), rescaled to sum to 1.
    mu, nu = rand_gen.rand(n, 1), rand_gen.rand(m, 1)
    mu /= mu.sum()
    nu /= nu.sum()

    # Random points in the unit square; the cost is the pairwise Euclidean
    # distance, broadcast from (n, 1, 2) - (1, m, 2) to an (n, m) matrix.
    x1, x2 = rand_gen.rand(n, 1, 2), rand_gen.rand(1, m, 2)
    C = ((x1 - x2) ** 2).sum(axis=2) ** 0.5

    # Solve the identical problem with the plain and the log-domain solver.
    Pi, u, v, regul_cost, _ = riot.sinkhorn(C, mu, nu)
    Pi_s, f, g, regul_cost_s, _ = riot.sinkhorn_stab(C, mu, nu)

    # Distances between the two solutions. f, g are compared to u, v through
    # exp(. / 1); the 1 is presumably the regularisation strength used by riot
    # (NOTE(review): confirm against riot's default).
    dPi, du, dv, dC = (
        np.linalg.norm(Pi - Pi_s),
        np.linalg.norm(np.exp(f / 1) - u),
        np.linalg.norm(np.exp(g / 1) - v),
        np.linalg.norm(regul_cost - regul_cost_s),
    )
    assert dPi + du + dv + dC < 1e-8, f'failed for seed {seed}'
print(f' no problem encountered for the {len(seed_iterator)} seeds')

 no problem encountered for the 100 seeds


## Timing test

In [7]:
def _random_problem(rand_gen):
    """Draw one random transport problem of size (n, m) from ``rand_gen``.

    Returns ``(C, mu, nu)``:
      C  -- (n, m) matrix of Euclidean distances between n random source points
            and m random target points in the unit square,
      mu -- (n, 1) source weights normalised to sum to 1,
      nu -- (m, 1) target weights normalised to sum to 1.
    Relies on the notebook-level sizes ``n`` and ``m``.
    """
    mu = rand_gen.rand(n, 1)
    nu = rand_gen.rand(m, 1)
    mu /= mu.sum()
    nu /= nu.sum()
    x1 = rand_gen.rand(n, 1, 2)
    x2 = rand_gen.rand(1, m, 2)
    # (n, 1, 2) - (1, m, 2) broadcasts to (n, m, 2); summing the last axis gives (n, m).
    C = ((x1 - x2) ** 2).sum(2) ** 0.5
    return C, mu, nu


def sinkhorn_run(seed=None):
    """Solve one random problem with ``riot.sinkhorn`` and return the solver output.

    seed -- optional seed for the problem generator. With the default ``None``,
            a fresh problem is drawn on every call, as before.
    Note: when timed, the measurement includes generating the problem, not only
    the solver itself.
    """
    C, mu, nu = _random_problem(np.random.RandomState(seed))
    return riot.sinkhorn(C, mu, nu)


def sinkhorn_stab_run(seed=None):
    """Solve one random problem with ``riot.sinkhorn_stab`` and return the solver output.

    seed -- optional seed for the problem generator. With the default ``None``,
            a fresh problem is drawn on every call, as before.
    Note: when timed, the measurement includes generating the problem, not only
    the solver itself.
    """
    C, mu, nu = _random_problem(np.random.RandomState(seed))
    return riot.sinkhorn_stab(C, mu, nu)

In [8]:
print('for the normal sinkhorn, the time taken for a run is :')
res_sink = %timeit -o sinkhorn_run
print('\nfor the stabilized sinkhorn, the time taken for a run is :')
res_stab = %timeit -o sinkhorn_stab_run
print()

r = res_sink.average / res_stab.average
if r > 1 :
    print(f'we see an amelioration of {int(100 * r - 100)}% in computation time')
else :
    print(f'we see a loss of {int(100 * r-100)}% in computation time')

for the normal sinkhorn, the time taken for a run is :
28.4 ns ± 5.21 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)

for the stabilized sinkhorn, the time taken for a run is :
24.5 ns ± 1.89 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)

we see an amelioration of 16% in computation time
