2 changes: 2 additions & 0 deletions RELEASES.md
@@ -5,6 +5,8 @@
#### New features

- Better list of related examples in quick start guide with `minigallery` (PR #334)
- Add optional log-domain Sinkhorn implementation in WDA to support smaller values
  of the regularization parameter (PR #336)

#### Closed issues

44 changes: 42 additions & 2 deletions ot/dr.py
@@ -11,6 +11,7 @@

# Author: Remi Flamary <remi.flamary@unice.fr>
# Minhui Huang <mhhuang@ucdavis.edu>
# Jakub Zadrozny <jakub.r.zadrozny@gmail.com>
#
# License: MIT License

@@ -43,6 +44,28 @@ def sinkhorn(w1, w2, M, reg, k):
    return G


def logsumexp(M, axis):
    r"""Log-sum-exp reduction compatible with autograd (numpy provides no native implementation)
    """
    # shift by the max before exponentiating for numerical stability
    amax = np.amax(M, axis=axis, keepdims=True)
    return np.log(np.sum(np.exp(M - amax), axis=axis)) + np.squeeze(amax, axis=axis)

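A quick numerical sanity check of this reduction (an illustrative sketch, not part of the PR; it assumes this branch is installed along with `autograd` and `pymanopt` so that `ot.dr.logsumexp` is importable):

```python
import numpy as np
from ot.dr import logsumexp

M = np.array([[-1000., -1001.], [-2., -3.]])
# the max-shifted reduction stays finite where the naive one underflows
print(logsumexp(M, axis=1))           # [-999.6867..., -1.6867...]
print(np.log(np.exp(M).sum(axis=1)))  # naive version: first entry is -inf
```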

def sinkhorn_log(w1, w2, M, reg, k):
    r"""Sinkhorn algorithm in log-domain with a fixed number of iterations (autograd)
    """
    Mr = -M / reg
    ui = np.zeros((M.shape[0],))
    vi = np.zeros((M.shape[1],))
    log_w1 = np.log(w1)
    log_w2 = np.log(w2)
    # ui and vi are the log-domain scaling vectors: the kernel exp(-M / reg)
    # is never formed explicitly, so small reg values do not underflow
    for i in range(k):
        vi = log_w2 - logsumexp(Mr + ui[:, None], 0)
        ui = log_w1 - logsumexp(Mr + vi[None, :], 1)
    G = np.exp(ui[:, None] + Mr + vi[None, :])
    return G

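To see why the log-domain variant matters, here is a small sketch (illustrative only, not part of the diff; it assumes this branch of POT is installed with its `autograd`/`pymanopt` dependencies so that `ot.dr.sinkhorn` and `ot.dr.sinkhorn_log` are importable). With a cost matrix scaled so that `np.exp(-M / reg)` underflows to exactly zero, the standard iteration divides by zero while the log-domain one returns a finite plan with the right marginals:

```python
import numpy as np
from ot.dr import sinkhorn, sinkhorn_log

rng = np.random.RandomState(0)
M = rng.rand(5, 5) + 1.0   # toy cost matrix with entries in [1, 2)
w = np.ones(5) / 5         # uniform marginals
reg = 1e-3                 # -M / reg <= -1000, so np.exp(-M / reg) == 0

G_std = sinkhorn(w, w, M, reg, 50)        # divides by an all-zero kernel
print(np.isfinite(G_std).all())           # False: the plan degenerates to nan
G_log = sinkhorn_log(w, w, M, reg, 50)    # same fixed point, in log-domain
print(np.isfinite(G_log).all())           # True
print(np.allclose(G_log.sum(axis=1), w))  # True: row marginals recovered
```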

def split_classes(X, y):
    r"""split samples in :math:`\mathbf{X}` by classes in :math:`\mathbf{y}`
    """
@@ -110,7 +133,7 @@ def proj(X):
    return Popt, proj


def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, normalize=False):
def wda(X, y, p=2, reg=1, k=10, solver=None, sinkhorn_method='sinkhorn', maxiter=100, verbose=0, P0=None, normalize=False):
r"""
Wasserstein Discriminant Analysis :ref:`[11] <references-wda>`

@@ -126,6 +149,14 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, normalize=False):
    - :math:`W` is the entropic regularized Wasserstein distance
    - :math:`\mathbf{X}^i` are samples in the dataset corresponding to class i

    **Choosing a Sinkhorn solver**

    By default, and when using a regularization parameter that is not too
    small, the default sinkhorn solver should be enough. If you need a small
    regularization to get sparser transport matrices, you should use the
    :py:func:`ot.dr.sinkhorn_log` solver, which avoids numerical errors but
    can be slow in practice.

    Parameters
    ----------
    X : ndarray, shape (n, d)
@@ -139,6 +170,8 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, normalize=False):
    solver : None | str, optional
        None for steepest descent, 'TrustRegions' for the trust-region
        algorithm, otherwise a pymanopt solver object
    sinkhorn_method : str
        method used for the Sinkhorn solver, either 'sinkhorn' or 'sinkhorn_log'
    P0 : ndarray, shape (d, p)
        Initial starting point for projection.
    normalize : bool, optional
@@ -161,6 +194,13 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, normalize=False):
        Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063.
    """ # noqa

    if sinkhorn_method.lower() == 'sinkhorn':
        sinkhorn_solver = sinkhorn
    elif sinkhorn_method.lower() == 'sinkhorn_log':
        sinkhorn_solver = sinkhorn_log
    else:
        raise ValueError("Unknown Sinkhorn method '%s'." % sinkhorn_method)

    mx = np.mean(X)
    X -= mx.reshape((1, -1))

@@ -193,7 +233,7 @@ def cost(P):
            for j, xj in enumerate(xc[i:]):
                xj = np.dot(xj, P)
                M = dist(xi, xj)
                G = sinkhorn(wc[i], wc[j + i], M, reg * regmean[i, j], k)
                G = sinkhorn_solver(wc[i], wc[j + i], M, reg * regmean[i, j], k)
                if j == 0:
                    loss_w += np.sum(G * M)
                else:
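To make the docstring guidance concrete, here is a usage sketch mirroring the new test below (illustrative only; it assumes this branch of POT is installed):

```python
import numpy as np
import ot

np.random.seed(0)
xs, ys = ot.datasets.make_data_classif('gaussrot', 100)
xs = np.hstack((xs, np.random.randn(100, 8)))  # pad with noise features

# moderate regularization: the default solver is enough
Pwda, projwda = ot.dr.wda(xs, ys, p=2, reg=1.0, k=10, maxiter=10)

# small regularization: switch to the log-domain solver to avoid nan values
Pwda, projwda = ot.dr.wda(xs, ys, p=2, reg=0.01, k=10, maxiter=10,
                          sinkhorn_method='sinkhorn_log')
xsp = projwda(xs)  # samples projected onto the p-dimensional subspace
```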
22 changes: 22 additions & 0 deletions test/test_dr.py
@@ -60,6 +60,28 @@ def test_wda():
    np.testing.assert_allclose(np.sum(Pwda**2, 0), np.ones(p))


@pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)")
def test_wda_low_reg():

    n_samples = 100  # nb samples in the dataset
    np.random.seed(0)

    # generate a Gaussian dataset
    xs, ys = ot.datasets.make_data_classif('gaussrot', n_samples)

    n_features_noise = 8

    xs = np.hstack((xs, np.random.randn(n_samples, n_features_noise)))

    p = 2

    Pwda, projwda = ot.dr.wda(xs, ys, p, reg=0.01, maxiter=10, sinkhorn_method='sinkhorn_log')

    projwda(xs)

    np.testing.assert_allclose(np.sum(Pwda**2, 0), np.ones(p))


@pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)")
def test_wda_normalized():

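A complementary consistency check one could add (a sketch, not part of this PR): both solvers start from unit scalings and apply the same updates, so for a moderate regularization their transport plans should coincide.

```python
import numpy as np
from ot.dr import sinkhorn, sinkhorn_log

rng = np.random.RandomState(42)
M = rng.rand(6, 4)   # toy cost matrix
w1 = np.ones(6) / 6  # source marginals
w2 = np.ones(4) / 4  # target marginals

G_std = sinkhorn(w1, w2, M, 1.0, 100)
G_log = sinkhorn_log(w1, w2, M, 1.0, 100)
np.testing.assert_allclose(G_std, G_log, rtol=1e-6)
```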