Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions ot/lp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@

import multiprocessing
import sys

import numpy as np
from scipy.sparse import coo_matrix

from .import cvx

from . import cvx
from .cvx import barycenter
# import compiled emd
from .emd_wrap import emd_c, check_result, emd_1d_sorted
from ..utils import parmap
from .cvx import barycenter
from ..utils import dist
from ..utils import parmap

__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
'emd_1d', 'emd2_1d', 'wasserstein_1d']
Expand Down Expand Up @@ -458,7 +458,8 @@ def f(b):
return res


def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=None):
def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100,
stopThr=1e-7, verbose=False, log=None):
"""
Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance)

Expand Down Expand Up @@ -525,8 +526,8 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None

T_sum = np.zeros((k, d))

for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()):

for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights,
weights.tolist()):
M_i = dist(X, measure_locations_i)
T_i = emd(b, measure_weights_i, M_i)
T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i)
Expand Down Expand Up @@ -651,12 +652,12 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
if b.ndim == 0 or len(b) == 0:
b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0]

x_a_1d = x_a.reshape((-1, ))
x_b_1d = x_b.reshape((-1, ))
x_a_1d = x_a.reshape((-1,))
x_b_1d = x_b.reshape((-1,))
perm_a = np.argsort(x_a_1d)
perm_b = np.argsort(x_b_1d)

G_sorted, indices, cost = emd_1d_sorted(a, b,
G_sorted, indices, cost = emd_1d_sorted(a[perm_a], b[perm_b],
x_a_1d[perm_a], x_b_1d[perm_b],
metric=metric, p=p)
G = coo_matrix((G_sorted, (perm_a[indices[:, 0]], perm_b[indices[:, 1]])),
Expand Down
48 changes: 40 additions & 8 deletions test/test_ot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
import warnings

import numpy as np
import pytest
from scipy.stats import wasserstein_distance

import ot
from ot.datasets import make_1D_gauss as gauss
import pytest


def test_emd_dimension_mismatch():
Expand Down Expand Up @@ -75,12 +75,12 @@ def test_emd_1d_emd2_1d():
np.testing.assert_allclose(wass, wass1d_emd2)

# check loss is similar to scipy's implementation for Euclidean metric
wass_sp = wasserstein_distance(u.reshape((-1, )), v.reshape((-1, )))
wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)))
np.testing.assert_allclose(wass_sp, wass1d_euc)

# check constraints
np.testing.assert_allclose(np.ones((n, )) / n, G.sum(1))
np.testing.assert_allclose(np.ones((m, )) / m, G.sum(0))
np.testing.assert_allclose(np.ones((n,)) / n, G.sum(1))
np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0))

# check G is similar
np.testing.assert_allclose(G, G_1d)
Expand All @@ -92,6 +92,42 @@ def test_emd_1d_emd2_1d():
ot.emd_1d(u, v, [], [])


def test_emd_1d_emd2_1d_with_weights():
# test emd1d gives similar results as emd
n = 20
m = 30
rng = np.random.RandomState(0)
u = rng.randn(n, 1)
v = rng.randn(m, 1)

w_u = rng.uniform(0., 1., n)
w_u = w_u / w_u.sum()

w_v = rng.uniform(0., 1., m)
w_v = w_v / w_v.sum()

M = ot.dist(u, v, metric='sqeuclidean')

G, log = ot.emd(w_u, w_v, M, log=True)
wass = log["cost"]
G_1d, log = ot.emd_1d(u, v, w_u, w_v, metric='sqeuclidean', log=True)
wass1d = log["cost"]
wass1d_emd2 = ot.emd2_1d(u, v, w_u, w_v, metric='sqeuclidean', log=False)
wass1d_euc = ot.emd2_1d(u, v, w_u, w_v, metric='euclidean', log=False)

# check loss is similar
np.testing.assert_allclose(wass, wass1d)
np.testing.assert_allclose(wass, wass1d_emd2)

# check loss is similar to scipy's implementation for Euclidean metric
wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)), w_u, w_v)
np.testing.assert_allclose(wass_sp, wass1d_euc)

# check constraints
np.testing.assert_allclose(w_u, G.sum(1))
np.testing.assert_allclose(w_v, G.sum(0))


def test_wass_1d():
# test emd1d gives similar results as emd
n = 20
Expand Down Expand Up @@ -135,7 +171,6 @@ def test_emd_empty():


def test_emd_sparse():

n = 100
rng = np.random.RandomState(0)

Expand Down Expand Up @@ -211,7 +246,6 @@ def test_emd2_multi():


def test_lp_barycenter():

a1 = np.array([1.0, 0, 0])[:, None]
a2 = np.array([0, 0, 1.0])[:, None]

Expand All @@ -228,7 +262,6 @@ def test_lp_barycenter():


def test_free_support_barycenter():

measures_locations = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))]
measures_weights = [np.array([1.]), np.array([1.])]

Expand All @@ -244,7 +277,6 @@ def test_free_support_barycenter():

@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available")
def test_lp_barycenter_cvxopt():

a1 = np.array([1.0, 0, 0])[:, None]
a2 = np.array([0, 0, 1.0])[:, None]

Expand Down