Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion ot/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ class Backend():

__name__ = None
__type__ = None
__type_list__ = None

rng_ = None

Expand Down Expand Up @@ -663,6 +664,8 @@ class NumpyBackend(Backend):

__name__ = 'numpy'
__type__ = np.ndarray
__type_list__ = [np.array(1, dtype=np.float32),
np.array(1, dtype=np.float64)]

rng_ = np.random.RandomState()

Expand Down Expand Up @@ -888,20 +891,25 @@ class JaxBackend(Backend):

__name__ = 'jax'
__type__ = jax_type
__type_list__ = None

rng_ = None

def __init__(self):
self.rng_ = jax.random.PRNGKey(42)

for d in jax.devices():
self.__type_list__ = [jax.device_put(jnp.array(1, dtype=np.float32), d),
jax.device_put(jnp.array(1, dtype=np.float64), d)]

def to_numpy(self, a):
return np.array(a)

def from_numpy(self, a, type_as=None):
if type_as is None:
return jnp.array(a)
else:
return jnp.array(a).astype(type_as.dtype)
return jax.device_put(jnp.array(a).astype(type_as.dtype), type_as.device_buffer.device())

def set_gradients(self, val, inputs, grads):
from jax.flatten_util import ravel_pytree
Expand Down Expand Up @@ -1130,6 +1138,7 @@ class TorchBackend(Backend):

__name__ = 'torch'
__type__ = torch_type
__type_list__ = None

rng_ = None

Expand All @@ -1138,6 +1147,13 @@ def __init__(self):
self.rng_ = torch.Generator()
self.rng_.seed()

self.__type_list__ = [torch.tensor(1, dtype=torch.float32),
torch.tensor(1, dtype=torch.float64)]

if torch.cuda.is_available():
self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device='cuda'))
self.__type_list__.append(torch.tensor(1, dtype=torch.float64, device='cuda'))

from torch.autograd import Function

# define a function that takes inputs val and grads
Expand All @@ -1160,6 +1176,8 @@ def to_numpy(self, a):
return a.cpu().detach().numpy()

def from_numpy(self, a, type_as=None):
if isinstance(a, float):
a = np.array(a)
if type_as is None:
return torch.from_numpy(a)
else:
Expand Down
14 changes: 7 additions & 7 deletions ot/lp/solver_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,8 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,

# ensure that same mass
np.testing.assert_almost_equal(
nx.sum(a, axis=0),
nx.sum(b, axis=0),
nx.to_numpy(nx.sum(a, axis=0)),
nx.to_numpy(nx.sum(b, axis=0)),
err_msg='a and b vector must have the same sum'
)
b = b * nx.sum(a) / nx.sum(b)
Expand All @@ -247,10 +247,10 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
perm_b = nx.argsort(x_b_1d)

G_sorted, indices, cost = emd_1d_sorted(
nx.to_numpy(a[perm_a]),
nx.to_numpy(b[perm_b]),
nx.to_numpy(x_a_1d[perm_a]),
nx.to_numpy(x_b_1d[perm_b]),
nx.to_numpy(a[perm_a]).astype(np.float64),
nx.to_numpy(b[perm_b]).astype(np.float64),
nx.to_numpy(x_a_1d[perm_a]).astype(np.float64),
nx.to_numpy(x_b_1d[perm_b]).astype(np.float64),
metric=metric, p=p
)

Expand All @@ -266,7 +266,7 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
elif str(nx) == "jax":
warnings.warn("JAX does not support sparse matrices, converting to dense")
if log:
log = {'cost': cost}
log = {'cost': nx.from_numpy(cost, type_as=x_a)}
return G, log
return G

Expand Down
93 changes: 93 additions & 0 deletions test/test_1d_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,96 @@ def test_wasserstein_1d(nx):
Xb = nx.from_numpy(X)
res = wasserstein_1d(Xb, Xb, rho_ub, rho_vb, p=2)
np.testing.assert_almost_equal(100 * res[0], res[1], decimal=4)


@pytest.mark.parametrize('nx', backend_list)
def test_wasserstein_1d_type_devices(nx):

rng = np.random.RandomState(0)

n = 10
x = np.linspace(0, 5, n)
rho_u = np.abs(rng.randn(n))
rho_u /= rho_u.sum()
rho_v = np.abs(rng.randn(n))
rho_v /= rho_v.sum()

for tp in nx.__type_list__:

print(tp.dtype)

xb = nx.from_numpy(x, type_as=tp)
rho_ub = nx.from_numpy(rho_u, type_as=tp)
rho_vb = nx.from_numpy(rho_v, type_as=tp)

res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)

if not str(nx) == 'numpy':
assert res.dtype == xb.dtype


def test_emd_1d_emd2_1d():
# test emd1d gives similar results as emd
n = 20
m = 30
rng = np.random.RandomState(0)
u = rng.randn(n, 1)
v = rng.randn(m, 1)

M = ot.dist(u, v, metric='sqeuclidean')

G, log = ot.emd([], [], M, log=True)
wass = log["cost"]
G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True)
wass1d = log["cost"]
wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False)
wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False)

# check loss is similar
np.testing.assert_allclose(wass, wass1d)
np.testing.assert_allclose(wass, wass1d_emd2)

# check loss is similar to scipy's implementation for Euclidean metric
wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)))
np.testing.assert_allclose(wass_sp, wass1d_euc)

# check constraints
np.testing.assert_allclose(np.ones((n,)) / n, G.sum(1))
np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0))

# check G is similar
np.testing.assert_allclose(G, G_1d, atol=1e-15)

# check AssertionError is raised if called on non 1d arrays
u = np.random.randn(n, 2)
v = np.random.randn(m, 2)
with pytest.raises(AssertionError):
ot.emd_1d(u, v, [], [])


def test_emd1d_type_devices(nx):

rng = np.random.RandomState(0)

n = 10
x = np.linspace(0, 5, n)
rho_u = np.abs(rng.randn(n))
rho_u /= rho_u.sum()
rho_v = np.abs(rng.randn(n))
rho_v /= rho_v.sum()

for tp in nx.__type_list__:

print(tp.dtype)

xb = nx.from_numpy(x, type_as=tp)
rho_ub = nx.from_numpy(rho_u, type_as=tp)
rho_vb = nx.from_numpy(rho_v, type_as=tp)

emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)

emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)

assert emd.dtype == xb.dtype
if not str(nx) == 'numpy':
assert emd2.dtype == xb.dtype
67 changes: 27 additions & 40 deletions test/test_ot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import ot
from ot.datasets import make_1D_gauss as gauss
from ot.backend import torch
from scipy.stats import wasserstein_distance


def test_emd_dimension_and_mass_mismatch():
Expand Down Expand Up @@ -77,6 +76,33 @@ def test_emd2_backends(nx):
np.allclose(val, nx.to_numpy(valb))


def test_emd_emd2_types_devices(nx):
n_samples = 100
n_features = 2
rng = np.random.RandomState(0)

x = rng.randn(n_samples, n_features)
y = rng.randn(n_samples, n_features)
a = ot.utils.unif(n_samples)

M = ot.dist(x, y)

for tp in nx.__type_list__:

print(tp.dtype)

ab = nx.from_numpy(a, type_as=tp)
Mb = nx.from_numpy(M, type_as=tp)

Gb = ot.emd(ab, ab, Mb)

w = ot.emd2(ab, ab, Mb)

assert Gb.dtype == Mb.dtype
if not str(nx) == 'numpy':
assert w.dtype == Mb.dtype


def test_emd2_gradients():
n_samples = 100
n_features = 2
Expand Down Expand Up @@ -126,45 +152,6 @@ def test_emd_emd2():
np.testing.assert_allclose(w, 0)


def test_emd_1d_emd2_1d():
# test emd1d gives similar results as emd
n = 20
m = 30
rng = np.random.RandomState(0)
u = rng.randn(n, 1)
v = rng.randn(m, 1)

M = ot.dist(u, v, metric='sqeuclidean')

G, log = ot.emd([], [], M, log=True)
wass = log["cost"]
G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True)
wass1d = log["cost"]
wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False)
wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False)

# check loss is similar
np.testing.assert_allclose(wass, wass1d)
np.testing.assert_allclose(wass, wass1d_emd2)

# check loss is similar to scipy's implementation for Euclidean metric
wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)))
np.testing.assert_allclose(wass_sp, wass1d_euc)

# check constraints
np.testing.assert_allclose(np.ones((n,)) / n, G.sum(1))
np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0))

# check G is similar
np.testing.assert_allclose(G, G_1d, atol=1e-15)

# check AssertionError is raised if called on non 1d arrays
u = np.random.randn(n, 2)
v = np.random.randn(m, 2)
with pytest.raises(AssertionError):
ot.emd_1d(u, v, [], [])


def test_emd_empty():
# test emd and emd2 for simple identity
n = 100
Expand Down