From f98bb6c1eb497df53a83df119b4d30030d796223 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Thu, 4 Nov 2021 13:40:55 +0100 Subject: [PATCH 01/13] new test gpu --- ot/backend.py | 15 +++++++++++++++ test/test_ot.py | 25 +++++++++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/ot/backend.py b/ot/backend.py index d3df44c1e..df3690340 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -102,6 +102,7 @@ class Backend(): __name__ = None __type__ = None + __type_list__ = None rng_ = None @@ -663,6 +664,8 @@ class NumpyBackend(Backend): __name__ = 'numpy' __type__ = np.ndarray + __type_list__ = [np.array(1,dtype=np.float32), + np.array(1,dtype=np.float64)] rng_ = np.random.RandomState() @@ -888,12 +891,16 @@ class JaxBackend(Backend): __name__ = 'jax' __type__ = jax_type + __type_list__ = None rng_ = None def __init__(self): self.rng_ = jax.random.PRNGKey(42) + self.__type_list__= [jnp.array(1,dtype=np.float32), + jnp.array(1,dtype=np.float64)] + def to_numpy(self, a): return np.array(a) @@ -1130,6 +1137,7 @@ class TorchBackend(Backend): __name__ = 'torch' __type__ = torch_type + __type_list__ = None rng_ = None @@ -1138,6 +1146,13 @@ def __init__(self): self.rng_ = torch.Generator() self.rng_.seed() + self.__type_list__ = [torch.tensor(1,dtype=torch.float32), + torch.tensor(1,dtype=torch.float64)] + + if torch.cuda.is_available(): + self.__type_list_.append(torch.tensor(1,dtype=torch.float32, device='cuda')) + self.__type_list_.append(torch.tensor(1,dtype=torch.float64, device='cuda')) + from torch.autograd import Function # define a function that takes inputs val and grads diff --git a/test/test_ot.py b/test/test_ot.py index 5bfde1df2..9ac8ff9a6 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -77,6 +77,31 @@ def test_emd2_backends(nx): np.allclose(val, nx.to_numpy(valb)) +def test_emd_emd2_types_devices(nx): + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + + M = ot.dist(x, y) + + for tp in nx.__type_list__: + + ab = nx.from_numpy(a, type_as=tp) + Mb = nx.from_numpy(M, type_as=tp) + + Gb = ot.emd(ab, ab, Mb) + + w = ot.emd2(ab, ab, Mb) + + assert Gb.dtype == Mb.dtype + if not str(nx)=='numpy': + assert w.dtype == Mb.dtype + + def test_emd2_gradients(): n_samples = 100 n_features = 2 From e620a5fe72efee267128dbc602a2c8c8694cb172 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Thu, 4 Nov 2021 13:43:09 +0100 Subject: [PATCH 02/13] pep 8 of couse --- ot/backend.py | 16 ++++++++-------- test/test_ot.py | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index df3690340..3cbd0f926 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -664,8 +664,8 @@ class NumpyBackend(Backend): __name__ = 'numpy' __type__ = np.ndarray - __type_list__ = [np.array(1,dtype=np.float32), - np.array(1,dtype=np.float64)] + __type_list__ = [np.array(1, dtype=np.float32), + np.array(1, dtype=np.float64)] rng_ = np.random.RandomState() @@ -898,8 +898,8 @@ class JaxBackend(Backend): def __init__(self): self.rng_ = jax.random.PRNGKey(42) - self.__type_list__= [jnp.array(1,dtype=np.float32), - jnp.array(1,dtype=np.float64)] + self.__type_list__ = [jnp.array(1, dtype=np.float32), + jnp.array(1, dtype=np.float64)] def to_numpy(self, a): return np.array(a) @@ -1146,12 +1146,12 @@ def __init__(self): self.rng_ = torch.Generator() self.rng_.seed() - self.__type_list__ = [torch.tensor(1,dtype=torch.float32), - torch.tensor(1,dtype=torch.float64)] + self.__type_list__ = [torch.tensor(1, dtype=torch.float32), + torch.tensor(1, dtype=torch.float64)] if torch.cuda.is_available(): - self.__type_list_.append(torch.tensor(1,dtype=torch.float32, device='cuda')) - self.__type_list_.append(torch.tensor(1,dtype=torch.float64, device='cuda')) + self.__type_list_.append(torch.tensor(1, dtype=torch.float32, device='cuda')) + self.__type_list_.append(torch.tensor(1, dtype=torch.float64, device='cuda')) from torch.autograd import Function diff --git a/test/test_ot.py b/test/test_ot.py index 9ac8ff9a6..677f41ffd 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -98,7 +98,7 @@ def test_emd_emd2_types_devices(nx): w = ot.emd2(ab, ab, Mb) assert Gb.dtype == Mb.dtype - if not str(nx)=='numpy': + if not str(nx) == 'numpy': assert w.dtype == Mb.dtype From 637ef6574506fbfc4d7cc11b5eb011d4927a3919 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Thu, 4 Nov 2021 13:51:01 +0100 Subject: [PATCH 03/13] debug torch --- ot/backend.py | 4 ++-- test/test_ot.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index 3cbd0f926..144f4fe9d 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1150,8 +1150,8 @@ def __init__(self): torch.tensor(1, dtype=torch.float64)] if torch.cuda.is_available(): - self.__type_list_.append(torch.tensor(1, dtype=torch.float32, device='cuda')) - self.__type_list_.append(torch.tensor(1, dtype=torch.float64, device='cuda')) + self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device='cuda')) + self.__type_list__.append(torch.tensor(1, dtype=torch.float64, device='cuda')) from torch.autograd import Function diff --git a/test/test_ot.py b/test/test_ot.py index 677f41ffd..7ec1bb304 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -90,6 +90,8 @@ def test_emd_emd2_types_devices(nx): for tp in nx.__type_list__: + print(tp.dtype) + ab = nx.from_numpy(a, type_as=tp) Mb = nx.from_numpy(M, type_as=tp) From 337a44a18eb89c9db194bef2c17ff269797e584f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Thu, 4 Nov 2021 14:03:59 +0100 Subject: [PATCH 04/13] jax with gpu --- ot/backend.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index 144f4fe9d..6e055fbf7 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -678,7 +678,7 @@ def from_numpy(self, a, type_as=None): elif isinstance(a, float): return a else: - return a.astype(type_as.dtype) + return jax.put_device(a.astype(type_as.dtype),type_as.device_buffer.device()) def set_gradients(self, val, inputs, grads): # No gradients for numpy @@ -898,8 +898,11 @@ class JaxBackend(Backend): def __init__(self): self.rng_ = jax.random.PRNGKey(42) - self.__type_list__ = [jnp.array(1, dtype=np.float32), - jnp.array(1, dtype=np.float64)] + for d in jax.devices(): + self.__type_list__ = [jax.device_put(jnp.array(1, dtype=np.float32),d), + jax.device_put(jnp.array(1, dtype=np.float64))] + + def to_numpy(self, a): return np.array(a) From ceed6395f1a544432972b3e1b1e9bacaf661d9dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Thu, 4 Nov 2021 14:06:19 +0100 Subject: [PATCH 05/13] device put --- ot/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ot/backend.py b/ot/backend.py index 6e055fbf7..e7a10a00c 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -678,7 +678,7 @@ def from_numpy(self, a, type_as=None): elif isinstance(a, float): return a else: - return jax.put_device(a.astype(type_as.dtype),type_as.device_buffer.device()) + return jax.device-put(a.astype(type_as.dtype),type_as.device_buffer.device()) def set_gradients(self, val, inputs, grads): # No gradients for numpy From 81d572cee6a997052183bae1b42ce810dbb25bac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Thu, 4 Nov 2021 14:06:29 +0100 Subject: [PATCH 06/13] device put --- ot/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ot/backend.py b/ot/backend.py index e7a10a00c..2bf95be41 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -678,7 +678,7 @@ def from_numpy(self, a, type_as=None): elif isinstance(a, float): return a else: - return jax.device-put(a.astype(type_as.dtype),type_as.device_buffer.device()) + return jax.device_put(a.astype(type_as.dtype),type_as.device_buffer.device()) def set_gradients(self, val, inputs, grads): # No gradients for numpy From b4f9a860abe2839569c69e0c45fc7f75a7829281 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Thu, 4 Nov 2021 14:09:52 +0100 Subject: [PATCH 07/13] it works --- ot/backend.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index 2bf95be41..79f23e2f9 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -678,7 +678,7 @@ def from_numpy(self, a, type_as=None): elif isinstance(a, float): return a else: - return jax.device_put(a.astype(type_as.dtype),type_as.device_buffer.device()) + return a.astype(type_as.dtype) def set_gradients(self, val, inputs, grads): # No gradients for numpy @@ -900,7 +900,7 @@ def __init__(self): for d in jax.devices(): self.__type_list__ = [jax.device_put(jnp.array(1, dtype=np.float32),d), - jax.device_put(jnp.array(1, dtype=np.float64))] + jax.device_put(jnp.array(1, dtype=np.float64),d)] @@ -911,7 +911,7 @@ def from_numpy(self, a, type_as=None): if type_as is None: return jnp.array(a) else: - return jnp.array(a).astype(type_as.dtype) + return jax.device_put(jnp.array(a).astype(type_as.dtype),type_as.device_buffer.device()) def set_gradients(self, val, inputs, grads): from jax.flatten_util import ravel_pytree From e0619272ea017a3f9d7a6c7a0d9f7896e2839c4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Thu, 4 Nov 2021 14:46:36 +0100 Subject: [PATCH 08/13] emd1d and emd2_1d working --- ot/backend.py | 10 +++++----- ot/lp/solver_1d.py | 10 +++++----- test/test_1d_solver.py | 26 ++++++++++++++++++++++++++ test/test_ot.py | 28 ++++++++++++++++++++++++++++ 4 files changed, 64 insertions(+), 10 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index 79f23e2f9..55e10d373 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -899,10 +899,8 @@ def __init__(self): self.rng_ = jax.random.PRNGKey(42) for d in jax.devices(): - self.__type_list__ = [jax.device_put(jnp.array(1, dtype=np.float32),d), - jax.device_put(jnp.array(1, dtype=np.float64),d)] - - + self.__type_list__ = [jax.device_put(jnp.array(1, dtype=np.float32), d), + jax.device_put(jnp.array(1, dtype=np.float64), d)] def to_numpy(self, a): return np.array(a) @@ -911,7 +909,7 @@ def from_numpy(self, a, type_as=None): if type_as is None: return jnp.array(a) else: - return jax.device_put(jnp.array(a).astype(type_as.dtype),type_as.device_buffer.device()) + return jax.device_put(jnp.array(a).astype(type_as.dtype), type_as.device_buffer.device()) def set_gradients(self, val, inputs, grads): from jax.flatten_util import ravel_pytree @@ -1178,6 +1176,8 @@ def to_numpy(self, a): return a.cpu().detach().numpy() def from_numpy(self, a, type_as=None): + if isinstance(a, float): + a = np.array(a) if type_as is None: return torch.from_numpy(a) else: diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index 42554aa3f..422549056 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -247,10 +247,10 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, perm_b = nx.argsort(x_b_1d) G_sorted, indices, cost = emd_1d_sorted( - nx.to_numpy(a[perm_a]), - nx.to_numpy(b[perm_b]), - nx.to_numpy(x_a_1d[perm_a]), - nx.to_numpy(x_b_1d[perm_b]), + nx.to_numpy(a[perm_a]).astype(np.float64), + nx.to_numpy(b[perm_b]).astype(np.float64), + nx.to_numpy(x_a_1d[perm_a]).astype(np.float64), + nx.to_numpy(x_b_1d[perm_b]).astype(np.float64), metric=metric, p=p ) @@ -266,7 +266,7 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, elif str(nx) == "jax": warnings.warn("JAX does not support sparse matrices, converting to dense") if log: - log = {'cost': cost} + log = {'cost': nx.from_numpy(cost, type_as=x_a)} return G, log return G diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py index 2c470c298..55e5f5999 100644 --- a/test/test_1d_solver.py +++ b/test/test_1d_solver.py @@ -83,3 +83,29 @@ def test_wasserstein_1d(nx): Xb = nx.from_numpy(X) res = wasserstein_1d(Xb, Xb, rho_ub, rho_vb, p=2) np.testing.assert_almost_equal(100 * res[0], res[1], decimal=4) + + +@pytest.mark.parametrize('nx', backend_list) +def test_wasserstein_1d_type_devices(nx): + + rng = np.random.RandomState(0) + + n = 10 + x = np.linspace(0, 5, n) + rho_u = np.abs(rng.randn(n)) + rho_u /= rho_u.sum() + rho_v = np.abs(rng.randn(n)) + rho_v /= rho_v.sum() + + for tp in nx.__type_list__: + + print(tp.dtype) + + xb = nx.from_numpy(x, type_as=tp) + rho_ub = nx.from_numpy(rho_u, type_as=tp) + rho_vb = nx.from_numpy(rho_v, type_as=tp) + + res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1) + + if not str(nx) == 'numpy': + assert res.dtype == xb.dtype diff --git a/test/test_ot.py b/test/test_ot.py index 7ec1bb304..c646ac67c 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -192,6 +192,34 @@ def test_emd_1d_emd2_1d(): ot.emd_1d(u, v, [], []) +def test_emd1d_type_devices(nx): + + rng = np.random.RandomState(0) + + n = 10 + x = np.linspace(0, 5, n) + rho_u = np.abs(rng.randn(n)) + rho_u /= rho_u.sum() + rho_v = np.abs(rng.randn(n)) + rho_v /= rho_v.sum() + + for tp in nx.__type_list__: + + print(tp.dtype) + + xb = nx.from_numpy(x, type_as=tp) + rho_ub = nx.from_numpy(rho_u, type_as=tp) + rho_vb = nx.from_numpy(rho_v, type_as=tp) + + emd = ot.emd_1d(xb, xb, rho_ub, rho_vb) + + emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb) + + assert emd.dtype == xb.dtype + if not str(nx) == 'numpy': + assert emd2.dtype == xb.dtype + + def test_emd_empty(): # test emd and emd2 for simple identity n = 100 From 8cd7e9fe9c7986dc322ba2b4f2671ec5bc9951cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Thu, 4 Nov 2021 14:48:58 +0100 Subject: [PATCH 09/13] emd_1d and emd2_1d done --- test/test_1d_solver.py | 67 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py index 55e5f5999..77b123401 100644 --- a/test/test_1d_solver.py +++ b/test/test_1d_solver.py @@ -109,3 +109,70 @@ def test_wasserstein_1d_type_devices(nx): if not str(nx) == 'numpy': assert res.dtype == xb.dtype + + +def test_emd_1d_emd2_1d(): + # test emd1d gives similar results as emd + n = 20 + m = 30 + rng = np.random.RandomState(0) + u = rng.randn(n, 1) + v = rng.randn(m, 1) + + M = ot.dist(u, v, metric='sqeuclidean') + + G, log = ot.emd([], [], M, log=True) + wass = log["cost"] + G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True) + wass1d = log["cost"] + wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False) + wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False) + + # check loss is similar + np.testing.assert_allclose(wass, wass1d) + np.testing.assert_allclose(wass, wass1d_emd2) + + # check loss is similar to scipy's implementation for Euclidean metric + wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,))) + np.testing.assert_allclose(wass_sp, wass1d_euc) + + # check constraints + np.testing.assert_allclose(np.ones((n,)) / n, G.sum(1)) + np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0)) + + # check G is similar + np.testing.assert_allclose(G, G_1d, atol=1e-15) + + # check AssertionError is raised if called on non 1d arrays + u = np.random.randn(n, 2) + v = np.random.randn(m, 2) + with pytest.raises(AssertionError): + ot.emd_1d(u, v, [], []) + + +def test_emd1d_type_devices(nx): + + rng = np.random.RandomState(0) + + n = 10 + x = np.linspace(0, 5, n) + rho_u = np.abs(rng.randn(n)) + rho_u /= rho_u.sum() + rho_v = np.abs(rng.randn(n)) + rho_v /= rho_v.sum() + + for tp in nx.__type_list__: + + print(tp.dtype) + + xb = nx.from_numpy(x, type_as=tp) + rho_ub = nx.from_numpy(rho_u, type_as=tp) + rho_vb = nx.from_numpy(rho_v, type_as=tp) + + emd = ot.emd_1d(xb, xb, rho_ub, rho_vb) + + emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb) + + assert emd.dtype == xb.dtype + if not str(nx) == 'numpy': + assert emd2.dtype == xb.dtype From cb2aa8ee674a42275657c79c1d9c4cd7a8af7190 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Thu, 4 Nov 2021 14:55:48 +0100 Subject: [PATCH 10/13] cleanup --- test/test_ot.py | 66 ------------------------------------------------- 1 file changed, 66 deletions(-) diff --git a/test/test_ot.py b/test/test_ot.py index c646ac67c..b8ab06cd5 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -153,72 +153,6 @@ def test_emd_emd2(): np.testing.assert_allclose(w, 0) -def test_emd_1d_emd2_1d(): - # test emd1d gives similar results as emd - n = 20 - m = 30 - rng = np.random.RandomState(0) - u = rng.randn(n, 1) - v = rng.randn(m, 1) - - M = ot.dist(u, v, metric='sqeuclidean') - - G, log = ot.emd([], [], M, log=True) - wass = log["cost"] - G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True) - wass1d = log["cost"] - wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False) - wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False) - - # check loss is similar - np.testing.assert_allclose(wass, wass1d) - np.testing.assert_allclose(wass, wass1d_emd2) - - # check loss is similar to scipy's implementation for Euclidean metric - wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,))) - np.testing.assert_allclose(wass_sp, wass1d_euc) - - # check constraints - np.testing.assert_allclose(np.ones((n,)) / n, G.sum(1)) - np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0)) - - # check G is similar - np.testing.assert_allclose(G, G_1d, atol=1e-15) - - # check AssertionError is raised if called on non 1d arrays - u = np.random.randn(n, 2) - v = np.random.randn(m, 2) - with pytest.raises(AssertionError): - ot.emd_1d(u, v, [], []) - - -def test_emd1d_type_devices(nx): - - rng = np.random.RandomState(0) - - n = 10 - x = np.linspace(0, 5, n) - rho_u = np.abs(rng.randn(n)) - rho_u /= rho_u.sum() - rho_v = np.abs(rng.randn(n)) - rho_v /= rho_v.sum() - - for tp in nx.__type_list__: - - print(tp.dtype) - - xb = nx.from_numpy(x, type_as=tp) - rho_ub = nx.from_numpy(rho_u, type_as=tp) - rho_vb = nx.from_numpy(rho_v, type_as=tp) - - emd = ot.emd_1d(xb, xb, rho_ub, rho_vb) - - emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb) - - assert emd.dtype == xb.dtype - if not str(nx) == 'numpy': - assert emd2.dtype == xb.dtype - def test_emd_empty(): # test emd and emd2 for simple identity From b4f1422a380d518e16297648a13c21ea5a971226 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Thu, 4 Nov 2021 14:56:06 +0100 Subject: [PATCH 11/13] of course --- test/test_ot.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_ot.py b/test/test_ot.py index b8ab06cd5..c0f7e2714 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -153,7 +153,6 @@ def test_emd_emd2(): np.testing.assert_allclose(w, 0) - def test_emd_empty(): # test emd and emd2 for simple identity n = 100 From 41f75859ba20bfbd94012d04ee86151e6785befb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Thu, 4 Nov 2021 14:58:16 +0100 Subject: [PATCH 12/13] should work on gpu now --- ot/lp/solver_1d.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index 422549056..8b4d0c3be 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -235,8 +235,8 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, # ensure that same mass np.testing.assert_almost_equal( - nx.sum(a, axis=0), - nx.sum(b, axis=0), + nx.to_numpy(nx.sum(a, axis=0)), + nx.to_numpy(nx.sum(b, axis=0)), err_msg='a and b vector must have the same sum' ) b = b * nx.sum(a) / nx.sum(b) From ead6527c9fbb8ad7c1005b40a150df3368cf32c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Thu, 4 Nov 2021 15:01:00 +0100 Subject: [PATCH 13/13] tests done+ pep8 --- test/test_ot.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_ot.py b/test/test_ot.py index c0f7e2714..dc3930af8 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -12,7 +12,6 @@ import ot from ot.datasets import make_1D_gauss as gauss from ot.backend import torch -from scipy.stats import wasserstein_distance def test_emd_dimension_and_mass_mismatch():