Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
7ab9037
Gromov-Wasserstein distance
Aug 28, 2017
0a68bf4
gromov:flake8 and other
Aug 28, 2017
3007f1d
Minor corrections suggested by @agramfort + new barycenter example + …
Aug 31, 2017
bc68cc3
minor corrections
Aug 31, 2017
986f46d
Merge branch 'master' into gromov
ncourty Aug 31, 2017
e89f09d
remove linewidth error message
Slasnista Jul 28, 2017
f469205
first proposal for OT wrappers
Slasnista Jul 28, 2017
fa36e77
small modifs according to NG proposals
Slasnista Jul 28, 2017
aa19b6a
integrate AG comments
Slasnista Jul 28, 2017
5ab5035
own BaseEstimator class written + rflamary comments addressed
Slasnista Jul 31, 2017
c7eaaf4
update SinkhornTransport class + added test for class
Slasnista Aug 1, 2017
d5c6cc1
added EMDTransport Class from NG's code + added dedicated test
Slasnista Aug 1, 2017
cd4fa72
added test for fit_transform + correction of fit_transform bug (missi…
Slasnista Aug 4, 2017
0659abe
added new class SinkhornLpl1Transport() + dedicated test
Slasnista Aug 4, 2017
2005a09
added new class SinkhornL1l2Transport() + dedicated test
Slasnista Aug 4, 2017
4e562a1
semi supervised mode supported
Slasnista Aug 4, 2017
62b40a9
correction of semi supervised mode
Slasnista Aug 4, 2017
266abb6
reformat doc strings + remove useless log / verbose parameters for emd
Slasnista Aug 4, 2017
b8672f6
out of samples by Ferradans supported for transform and inverse_trans…
Slasnista Aug 4, 2017
117cd33
added new class MappingTransport to support linear and kernel mapping…
Slasnista Aug 4, 2017
d20a067
make doc strings compliant with numpy / modif according to AG review
Slasnista Aug 23, 2017
8d19d36
out of samples transform and inverse transform by batch
Slasnista Aug 23, 2017
c8ae584
test functions for MappingTransport Class
Slasnista Aug 23, 2017
fc58f39
added deprecation warning on old classes
Slasnista Aug 23, 2017
6167f34
solving log issues to avoid errors and adding further tests
Slasnista Aug 25, 2017
181fcd3
refactoring examples according to new DA classes
Slasnista Aug 25, 2017
e1a3984
small corrections for examples
Slasnista Aug 25, 2017
4f802cf
set properly path of data
Slasnista Aug 25, 2017
e1606c1
move no da objects into utils.py
Slasnista Aug 28, 2017
f79f483
handling input arguments in fit, transform... methods + remove old ex…
Slasnista Aug 28, 2017
84e56a0
check input parameters with helper functions
Slasnista Aug 28, 2017
5964001
update readme
Slasnista Aug 28, 2017
24362ec
Gromov-Wasserstein distance
Aug 28, 2017
f8744a3
gromov:flake8 and other
Aug 28, 2017
3730779
addressed AG comments + adding random seed
Slasnista Aug 29, 2017
5a9795f
pass on examples | introduced RandomState
Slasnista Aug 29, 2017
6ae3ad7
Changes to LP solver:
toto6 Aug 29, 2017
b562927
Fix param order
toto6 Aug 29, 2017
0f7cd92
Type print
toto6 Aug 29, 2017
ceeb063
Changes:
toto6 Aug 30, 2017
8875f65
Rename for emd and emd2
toto6 Aug 30, 2017
5076131
Fix name error
toto6 Aug 30, 2017
6d60230
Move normalize function in utils.py
toto6 Aug 30, 2017
93dee55
Move norm out of fit to init for deprecated OTDA
toto6 Aug 30, 2017
8c52517
Minor corrections suggested by @agramfort + new barycenter example + …
Aug 31, 2017
4ec5b33
minor corrections
Aug 31, 2017
ab6ed1d
docstrings and naming
Sep 1, 2017
64a5d3c
docstrings and naming
Sep 1, 2017
46fc12a
solving conflicts :/
Sep 1, 2017
f12322c
add barycenters to Readme.md
Sep 1, 2017
53e1115
docstrings + naming
Sep 1, 2017
8ea74ad
docstrings + naming
Sep 1, 2017
36bf599
Corrections on Gromov
Sep 12, 2017
24784ed
Corrections on Gromov
Sep 12, 2017
84c2723
Corrections on Gromov
Sep 12, 2017
55db350
Corrections on Gromov
Sep 12, 2017
5a2ebfa
Corrections on Gromov
ncourty Sep 13, 2017
7e5df4c
Merge branch 'gromov' of https://github.com/rflamary/POT into gromov
ncourty Sep 13, 2017
c86cc4f
Merge branch 'master' into gromov
ncourty Sep 13, 2017
c7eef9d
Merge branch 'master' into gromov
ncourty Sep 13, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ It provides the following solvers:
* Conditional gradient [6] and Generalized conditional gradient for regularized OT [7].
* Joint OT matrix and mapping estimation [8].
* Wasserstein Discriminant Analysis [11] (requires autograd + pymanopt).

* Gromov-Wasserstein distances and barycenters [12]

Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder.

Expand Down Expand Up @@ -184,3 +184,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). [Scaling algorithms for unbalanced transport problems](https://arxiv.org/pdf/1607.05816.pdf). arXiv preprint arXiv:1607.05816.

[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). [Wasserstein Discriminant Analysis](https://arxiv.org/pdf/1608.08063.pdf). arXiv preprint arXiv:1608.08063.

[12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, [Gromov-Wasserstein averaging of kernel and distance matrices](http://proceedings.mlr.press/v48/peyre16.html) International Conference on Machine Learning (ICML). 2016.
Binary file added data/cross.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/square.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/star.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/triangle.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
90 changes: 90 additions & 0 deletions examples/plot_gromov.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# -*- coding: utf-8 -*-
"""
==========================
Gromov-Wasserstein example
==========================
This example is designed to show how to use the Gromov-Wassertsein distance
computation in POT.
"""

# Author: Erwan Vautier <erwan.vautier@gmail.com>
# Nicolas Courty <ncourty@irisa.fr>
#
# License: MIT License

import scipy as sp
import numpy as np
import matplotlib.pylab as pl

import ot


"""
Sample two Gaussian distributions (2D and 3D)
=============================================
The Gromov-Wasserstein distance allows to compute distances with samples that
do not belong to the same metric space. For demonstration purpose, we sample
two Gaussian distributions in 2- and 3-dimensional spaces.
"""

n_samples = 30 # nb samples

mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])

mu_t = np.array([4, 4, 4])
cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])


xs = ot.datasets.get_2D_samples_gauss(n_samples, mu_s, cov_s)
P = sp.linalg.sqrtm(cov_t)
xt = np.random.randn(n_samples, 3).dot(P) + mu_t


"""
Plotting the distributions
==========================
"""
fig = pl.figure()
ax1 = fig.add_subplot(121)
ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
ax2 = fig.add_subplot(122, projection='3d')
ax2.scatter(xt[:, 0], xt[:, 1], xt[:, 2], color='r')
pl.show()


"""
Compute distance kernels, normalize them and then display
=========================================================
"""

C1 = sp.spatial.distance.cdist(xs, xs)
C2 = sp.spatial.distance.cdist(xt, xt)

C1 /= C1.max()
C2 /= C2.max()

pl.figure()
pl.subplot(121)
pl.imshow(C1)
pl.subplot(122)
pl.imshow(C2)
pl.show()

"""
Compute Gromov-Wasserstein plans and distance
=============================================
"""

p = ot.unif(n_samples)
q = ot.unif(n_samples)

gw = ot.gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4)
gw_dist = ot.gromov_wasserstein2(C1, C2, p, q, 'square_loss', epsilon=5e-4)

print('Gromov-Wasserstein distances between the distribution: ' + str(gw_dist))

pl.figure()
pl.imshow(gw, cmap='jet')
pl.colorbar()
pl.show()
248 changes: 248 additions & 0 deletions examples/plot_gromov_barycenter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
# -*- coding: utf-8 -*-
"""
=====================================
Gromov-Wasserstein Barycenter example
=====================================
This example is designed to show how to use the Gromov-Wasserstein distance
computation in POT.
"""

# Author: Erwan Vautier <erwan.vautier@gmail.com>
# Nicolas Courty <ncourty@irisa.fr>
#
# License: MIT License


import numpy as np
import scipy as sp

import scipy.ndimage as spi
import matplotlib.pylab as pl
from sklearn import manifold
from sklearn.decomposition import PCA

import ot

"""

Smacof MDS
==========
This function allows to find an embedding of points given a dissimilarity matrix
that will be given by the output of the algorithm
"""


def smacof_mds(C, dim, max_iter=3000, eps=1e-9):
"""
Returns an interpolated point cloud following the dissimilarity matrix C
using SMACOF multidimensional scaling (MDS) in specific dimensionned
target space

Parameters
----------
C : ndarray, shape (ns, ns)
dissimilarity matrix
dim : int
dimension of the targeted space
max_iter : int
Maximum number of iterations of the SMACOF algorithm for a single run
eps : float
relative tolerance w.r.t stress to declare converge

Returns
-------
npos : ndarray, shape (R, dim)
Embedded coordinates of the interpolated point cloud (defined with
one isometry)
"""

rng = np.random.RandomState(seed=3)

mds = manifold.MDS(
dim,
max_iter=max_iter,
eps=1e-9,
dissimilarity='precomputed',
n_init=1)
pos = mds.fit(C).embedding_

nmds = manifold.MDS(
2,
max_iter=max_iter,
eps=1e-9,
dissimilarity="precomputed",
random_state=rng,
n_init=1)
npos = nmds.fit_transform(C, init=pos)

return npos


"""
Data preparation
================
The four distributions are constructed from 4 simple images
"""


def im2mat(I):
"""Converts and image to matrix (one pixel per line)"""
return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))


square = spi.imread('../data/square.png').astype(np.float64)[:, :, 2] / 256
cross = spi.imread('../data/cross.png').astype(np.float64)[:, :, 2] / 256
triangle = spi.imread('../data/triangle.png').astype(np.float64)[:, :, 2] / 256
star = spi.imread('../data/star.png').astype(np.float64)[:, :, 2] / 256

shapes = [square, cross, triangle, star]

S = 4
xs = [[] for i in range(S)]


for nb in range(4):
for i in range(8):
for j in range(8):
if shapes[nb][i, j] < 0.95:
xs[nb].append([j, 8 - i])

xs = np.array([np.array(xs[0]), np.array(xs[1]),
np.array(xs[2]), np.array(xs[3])])


"""
Barycenter computation
======================
The four distributions are constructed from 4 simple images
"""
ns = [len(xs[s]) for s in range(S)]
n_samples = 30

"""Compute all distances matrices for the four shapes"""
Cs = [sp.spatial.distance.cdist(xs[s], xs[s]) for s in range(S)]
Cs = [cs / cs.max() for cs in Cs]

ps = [ot.unif(ns[s]) for s in range(S)]
p = ot.unif(n_samples)


lambdast = [[float(i) / 3, float(3 - i) / 3] for i in [1, 2]]

Ct01 = [0 for i in range(2)]
for i in range(2):
Ct01[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[1]],
[ps[0], ps[1]
], p, lambdast[i], 'square_loss', 5e-4,
max_iter=100, stopThr=1e-3)

Ct02 = [0 for i in range(2)]
for i in range(2):
Ct02[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[2]],
[ps[0], ps[2]
], p, lambdast[i], 'square_loss', 5e-4,
max_iter=100, stopThr=1e-3)

Ct13 = [0 for i in range(2)]
for i in range(2):
Ct13[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[1], Cs[3]],
[ps[1], ps[3]
], p, lambdast[i], 'square_loss', 5e-4,
max_iter=100, stopThr=1e-3)

Ct23 = [0 for i in range(2)]
for i in range(2):
Ct23[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[2], Cs[3]],
[ps[2], ps[3]
], p, lambdast[i], 'square_loss', 5e-4,
max_iter=100, stopThr=1e-3)

"""
Visualization
=============
"""

"""The PCA helps in getting consistency between the rotations"""

clf = PCA(n_components=2)
npos = [0, 0, 0, 0]
npos = [smacof_mds(Cs[s], 2) for s in range(S)]

npost01 = [0, 0]
npost01 = [smacof_mds(Ct01[s], 2) for s in range(2)]
npost01 = [clf.fit_transform(npost01[s]) for s in range(2)]

npost02 = [0, 0]
npost02 = [smacof_mds(Ct02[s], 2) for s in range(2)]
npost02 = [clf.fit_transform(npost02[s]) for s in range(2)]

npost13 = [0, 0]
npost13 = [smacof_mds(Ct13[s], 2) for s in range(2)]
npost13 = [clf.fit_transform(npost13[s]) for s in range(2)]

npost23 = [0, 0]
npost23 = [smacof_mds(Ct23[s], 2) for s in range(2)]
npost23 = [clf.fit_transform(npost23[s]) for s in range(2)]


fig = pl.figure(figsize=(10, 10))

ax1 = pl.subplot2grid((4, 4), (0, 0))
pl.xlim((-1, 1))
pl.ylim((-1, 1))
ax1.scatter(npos[0][:, 0], npos[0][:, 1], color='r')

ax2 = pl.subplot2grid((4, 4), (0, 1))
pl.xlim((-1, 1))
pl.ylim((-1, 1))
ax2.scatter(npost01[1][:, 0], npost01[1][:, 1], color='b')

ax3 = pl.subplot2grid((4, 4), (0, 2))
pl.xlim((-1, 1))
pl.ylim((-1, 1))
ax3.scatter(npost01[0][:, 0], npost01[0][:, 1], color='b')

ax4 = pl.subplot2grid((4, 4), (0, 3))
pl.xlim((-1, 1))
pl.ylim((-1, 1))
ax4.scatter(npos[1][:, 0], npos[1][:, 1], color='r')

ax5 = pl.subplot2grid((4, 4), (1, 0))
pl.xlim((-1, 1))
pl.ylim((-1, 1))
ax5.scatter(npost02[1][:, 0], npost02[1][:, 1], color='b')

ax6 = pl.subplot2grid((4, 4), (1, 3))
pl.xlim((-1, 1))
pl.ylim((-1, 1))
ax6.scatter(npost13[1][:, 0], npost13[1][:, 1], color='b')

ax7 = pl.subplot2grid((4, 4), (2, 0))
pl.xlim((-1, 1))
pl.ylim((-1, 1))
ax7.scatter(npost02[0][:, 0], npost02[0][:, 1], color='b')

ax8 = pl.subplot2grid((4, 4), (2, 3))
pl.xlim((-1, 1))
pl.ylim((-1, 1))
ax8.scatter(npost13[0][:, 0], npost13[0][:, 1], color='b')

ax9 = pl.subplot2grid((4, 4), (3, 0))
pl.xlim((-1, 1))
pl.ylim((-1, 1))
ax9.scatter(npos[2][:, 0], npos[2][:, 1], color='r')

ax10 = pl.subplot2grid((4, 4), (3, 1))
pl.xlim((-1, 1))
pl.ylim((-1, 1))
ax10.scatter(npost23[1][:, 0], npost23[1][:, 1], color='b')

ax11 = pl.subplot2grid((4, 4), (3, 2))
pl.xlim((-1, 1))
pl.ylim((-1, 1))
ax11.scatter(npost23[0][:, 0], npost23[0][:, 1], color='b')

ax12 = pl.subplot2grid((4, 4), (3, 3))
pl.xlim((-1, 1))
pl.ylim((-1, 1))
ax12.scatter(npos[3][:, 0], npos[3][:, 1], color='r')
6 changes: 5 additions & 1 deletion ot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

# Author: Remi Flamary <remi.flamary@unice.fr>
# Nicolas Courty <ncourty@irisa.fr>
#
# License: MIT License

Expand All @@ -17,11 +18,13 @@
from . import datasets
from . import plot
from . import da
from . import gromov

# OT functions
from .lp import emd, emd2
from .bregman import sinkhorn, sinkhorn2, barycenter
from .da import sinkhorn_lpl1_mm
from .gromov import gromov_wasserstein, gromov_wasserstein2

# utils functions
from .utils import dist, unif, tic, toc, toq
Expand All @@ -30,4 +33,5 @@

__all__ = ["emd", "emd2", "sinkhorn", "sinkhorn2", "utils", 'datasets',
'bregman', 'lp', 'plot', 'tic', 'toc', 'toq',
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim']
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
'gromov_wasserstein','gromov_wasserstein2']
Loading