Skip to content

[MRG] Sliced wasserstein #203

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
Oct 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
9f51c14
example for log treatment in bregman.py
AdrienCorenflos May 8, 2020
a07330c
Improve doc
AdrienCorenflos Jul 14, 2020
d3292a8
Merge remote-tracking branch 'origin/master'
AdrienCorenflos Jul 14, 2020
dfa2c9d
Revert "example for log treatment in bregman.py"
AdrienCorenflos Jul 14, 2020
36377cc
Add comments by Flamary
AdrienCorenflos Jul 20, 2020
110f382
Delete repetitive description
AdrienCorenflos Jul 20, 2020
cbf6bf5
Added raw string to avoid pbs with backslashes
AdrienCorenflos Jul 20, 2020
22e7f6b
Implements sliced wasserstein
AdrienCorenflos Jul 20, 2020
7beac55
Merge branch 'master' into sliced_wasserstein
rflamary Jul 20, 2020
ba04ed6
Changed formatting of string for py3.5 support
AdrienCorenflos Jul 20, 2020
391df18
Merge remote-tracking branch 'origin/sliced_wasserstein' into sliced_…
AdrienCorenflos Jul 20, 2020
ca8364c
Docstest, expected 0.0 and not 0.
AdrienCorenflos Jul 20, 2020
2d893f2
Adressed comments by @rflamary
AdrienCorenflos Aug 4, 2020
7d9b920
No 3d plot here
AdrienCorenflos Aug 4, 2020
b68e2c2
add sliced to the docs
AdrienCorenflos Aug 4, 2020
a1309da
Merge branch 'master' into sliced_wasserstein
rflamary Aug 25, 2020
5c5c589
Merge branch 'master' into sliced_wasserstein
rflamary Aug 31, 2020
abeba45
Merge remote-tracking branch 'upstream/master' into sliced_wasserstein
AdrienCorenflos Aug 31, 2020
9a8edb5
Incorporate comments by @rflamary
AdrienCorenflos Aug 31, 2020
64fc3e1
Merge remote-tracking branch 'origin/sliced_wasserstein' into sliced_…
AdrienCorenflos Aug 31, 2020
5590a79
add link to pdf
rflamary Sep 4, 2020
1a718b2
Merge branch 'master' into sliced_wasserstein
rflamary Oct 22, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ POT provides the following generic OT solvers (links to examples):
* [Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25].
* [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3]
formulations).
* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32].

POT provides the following Machine Learning related solvers:

Expand Down Expand Up @@ -180,6 +181,7 @@ The contributors to this library are
* [Romain Tavenard](https://rtavenar.github.io/) (1d Wasserstein)
* [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn)
* [Ievgen Redko](https://ievred.github.io/) (Laplacian DA, JCPOT)
* [Adrien Corenflos](https://adriencorenflos.github.io/) (Sliced Wasserstein Distance)

This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):

Expand Down Expand Up @@ -263,3 +265,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[29] Chapel, L., Alaya, M., Gasso, G. (2019). [Partial Gromov-Wasserstein with Applications on Positive-Unlabeled Learning](https://arxiv.org/abs/2002.08276), arXiv preprint arXiv:2002.08276.

[30] Flamary R., Courty N., Tuia D., Rakotomamonjy A. (2014). [Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching](https://remi.flamary.com/biblio/flamary2014optlaplace.pdf), NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014.

[31] Bonneel, Nicolas, et al. [Sliced and radon wasserstein barycenters of measures](https://perso.liris.cnrs.fr/nicolas.bonneel/WassersteinSliced-JMIV.pdf), Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
1 change: 1 addition & 0 deletions docs/source/all.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ API and modules
stochastic
unbalanced
partial
sliced

.. autosummary::
:toctree: ../modules/generated/
Expand Down
4 changes: 4 additions & 0 deletions examples/sliced-wasserstein/README.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@


Sliced Wasserstein Distance
---------------------------
84 changes: 84 additions & 0 deletions examples/sliced-wasserstein/plot_variance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# -*- coding: utf-8 -*-
"""
==============================
2D Sliced Wasserstein Distance
==============================

This example illustrates the computation of the sliced Wasserstein Distance as proposed in [31].

[31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45

"""

# Author: Adrien Corenflos <adrien.corenflos@aalto.fi>
#
# License: MIT License

import matplotlib.pylab as pl
import numpy as np

import ot

##############################################################################
# Generate data
# -------------

# %% parameters and data generation

n = 500 # nb samples

mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])

mu_t = np.array([4, 4])
cov_t = np.array([[1, -.8], [-.8, 1]])

xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)

a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples

##############################################################################
# Plot data
# ---------

# %% plot samples

pl.figure(1)
pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.legend(loc=0)
pl.title('Source and target distributions')

###################################################################################
# Compute Sliced Wasserstein distance for different seeds and number of projections
# -----------

n_seed = 50
n_projections_arr = np.logspace(0, 3, 25, dtype=int)
res = np.empty((n_seed, 25))

# %% Compute statistics
for seed in range(n_seed):
for i, n_projections in enumerate(n_projections_arr):
res[seed, i] = ot.sliced_wasserstein_distance(xs, xt, a, b, n_projections, seed)

res_mean = np.mean(res, axis=0)
res_std = np.std(res, axis=0)

###################################################################################
# Plot Sliced Wasserstein Distance
# -----------

pl.figure(2)
pl.plot(n_projections_arr, res_mean, label="SWD")
pl.fill_between(n_projections_arr, res_mean - 2 * res_std, res_mean + 2 * res_std, alpha=0.5)

pl.legend()
pl.xscale('log')

pl.xlabel("Number of projections")
pl.ylabel("Distance")
pl.title('Sliced Wasserstein Distance with 95% confidence inverval')

pl.show()
3 changes: 2 additions & 1 deletion ot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from .bregman import sinkhorn, sinkhorn2, barycenter
from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2
from .da import sinkhorn_lpl1_mm
from .sliced import sliced_wasserstein_distance

# utils functions
from .utils import dist, unif, tic, toc, toq
Expand All @@ -50,4 +51,4 @@
'emd_1d', 'emd2_1d', 'wasserstein_1d',
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
'sinkhorn_unbalanced', 'barycenter_unbalanced',
'sinkhorn_unbalanced2']
'sinkhorn_unbalanced2', 'sliced_wasserstein_distance']
144 changes: 144 additions & 0 deletions ot/sliced.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
"""
Sliced Wasserstein Distance.

"""

# Author: Adrien Corenflos <adrien.corenflos@aalto.fi>
#
# License: MIT License


import numpy as np


def get_random_projections(n_projections, d, seed=None):
r"""
Generates n_projections samples from the uniform on the unit sphere of dimension d-1: :math:`\mathcal{U}(\mathcal{S}^{d-1})`

Parameters
----------
n_projections : int
number of samples requested
d : int
dimension of the space
seed: int or RandomState, optional
Seed used for numpy random number generator

Returns
-------
out: ndarray, shape (n_projections, d)
The uniform unit vectors on the sphere

Examples
--------
>>> n_projections = 100
>>> d = 5
>>> projs = get_random_projections(n_projections, d)
>>> np.allclose(np.sum(np.square(projs), 1), 1.) # doctest: +NORMALIZE_WHITESPACE
True

"""

if not isinstance(seed, np.random.RandomState):
random_state = np.random.RandomState(seed)
else:
random_state = seed

projections = random_state.normal(0., 1., [n_projections, d])
norm = np.linalg.norm(projections, ord=2, axis=1, keepdims=True)
projections = projections / norm
return projections


def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, seed=None, log=False):
r"""
Computes a Monte-Carlo approximation of the 2-Sliced Wasserstein distance

.. math::
\mathcal{SWD}_2(\mu, \nu) = \underset{\theta \sim \mathcal{U}(\mathbb{S}^{d-1})}{\mathbb{E}}[\mathcal{W}_2^2(\theta_\# \mu, \theta_\# \nu)]^{\frac{1}{2}}

where :

- :math:`\theta_\# \mu` stands for the pushforwars of the projection :math:`\mathbb{R}^d \ni X \mapsto \langle \theta, X \rangle`


Parameters
----------
X_s : ndarray, shape (n_samples_a, dim)
samples in the source domain
X_t : ndarray, shape (n_samples_b, dim)
samples in the target domain
a : ndarray, shape (n_samples_a,), optional
samples weights in the source domain
b : ndarray, shape (n_samples_b,), optional
samples weights in the target domain
n_projections : int, optional
Number of projections used for the Monte-Carlo approximation
seed: int or RandomState or None, optional
Seed used for numpy random number generator
log: bool, optional
if True, sliced_wasserstein_distance returns the projections used and their associated EMD.

Returns
-------
cost: float
Sliced Wasserstein Cost
log : dict, optional
log dictionary return only if log==True in parameters

Examples
--------

>>> n_samples_a = 20
>>> reg = 0.1
>>> X = np.random.normal(0., 1., (n_samples_a, 5))
>>> sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE
0.0

References
----------

.. [31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
"""
from .lp import emd2_1d

X_s = np.asanyarray(X_s)
X_t = np.asanyarray(X_t)

n = X_s.shape[0]
m = X_t.shape[0]

if X_s.shape[1] != X_t.shape[1]:
raise ValueError(
"X_s and X_t must have the same number of dimensions {} and {} respectively given".format(X_s.shape[1],
X_t.shape[1]))

if a is None:
a = np.full(n, 1 / n)
if b is None:
b = np.full(m, 1 / m)

d = X_s.shape[1]

projections = get_random_projections(n_projections, d, seed)

X_s_projections = np.dot(projections, X_s.T)
X_t_projections = np.dot(projections, X_t.T)

if log:
projected_emd = np.empty(n_projections)
else:
projected_emd = None

res = 0.

for i, (X_s_proj, X_t_proj) in enumerate(zip(X_s_projections, X_t_projections)):
emd = emd2_1d(X_s_proj, X_t_proj, a, b, log=False, dense=False)
if projected_emd is not None:
projected_emd[i] = emd
res += emd

res = (res / n_projections) ** 0.5
if log:
return res, {"projections": projections, "projected_emds": projected_emd}
return res
85 changes: 85 additions & 0 deletions test/test_sliced.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""Tests for module sliced"""

# Author: Adrien Corenflos <adrien.corenflos@aalto.fi>
#
# License: MIT License

import numpy as np
import pytest

import ot
from ot.sliced import get_random_projections


def test_get_random_projections():
rng = np.random.RandomState(0)
projections = get_random_projections(1000, 50, rng)
np.testing.assert_almost_equal(np.sum(projections ** 2, 1), 1.)


def test_sliced_same_dist():
n = 100
rng = np.random.RandomState(0)

x = rng.randn(n, 2)
u = ot.utils.unif(n)

res = ot.sliced_wasserstein_distance(x, x, u, u, 10, seed=rng)
np.testing.assert_almost_equal(res, 0.)


def test_sliced_bad_shapes():
n = 100
rng = np.random.RandomState(0)

x = rng.randn(n, 2)
y = rng.randn(n, 4)
u = ot.utils.unif(n)

with pytest.raises(ValueError):
_ = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng)


def test_sliced_log():
n = 100
rng = np.random.RandomState(0)

x = rng.randn(n, 4)
y = rng.randn(n, 4)
u = ot.utils.unif(n)

res, log = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng, log=True)
assert len(log) == 2
projections = log["projections"]
projected_emds = log["projected_emds"]

assert len(projections) == len(projected_emds) == 10
for emd in projected_emds:
assert emd > 0


def test_sliced_different_dists():
n = 100
rng = np.random.RandomState(0)

x = rng.randn(n, 2)
u = ot.utils.unif(n)
y = rng.randn(n, 2)

res = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng)
assert res > 0.


def test_1d_sliced_equals_emd():
n = 100
m = 120
rng = np.random.RandomState(0)

x = rng.randn(n, 1)
a = rng.uniform(0, 1, n)
a /= a.sum()
y = rng.randn(m, 1)
u = ot.utils.unif(m)
res = ot.sliced_wasserstein_distance(x, y, a, u, 10, seed=42)
expected = ot.emd2_1d(x.squeeze(), y.squeeze(), a, u)
np.testing.assert_almost_equal(res ** 2, expected)