Gromov-Wasserstein distance #23
README.md
Outdated
@@ -16,7 +16,7 @@ It provides the following solvers:
* Conditional gradient [6] and Generalized conditional gradient for regularized OT [7].
* Joint OT matrix and mapping estimation [8].
* Wasserstein Discriminant Analysis [11] (requires autograd + pymanopt).

* Gromov-Wasserstein distances [12]
and barycenters
README.md
Outdated
@@ -182,3 +182,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). [Scaling algorithms for unbalanced transport problems](https://arxiv.org/pdf/1607.05816.pdf). arXiv preprint arXiv:1607.05816.

[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). [Wasserstein Discriminant Analysis](https://arxiv.org/pdf/1608.08063.pdf). arXiv preprint arXiv:1608.08063.

[12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, [Gromov-Wasserstein averaging of kernel and distance matrices](http://proceedings.mlr.press/v48/peyre16.html) International Conference on Machine Learning (ICML). 2016.
Gabriel Peyré to be consistent
examples/plot_gromov.py
Outdated
"""
====================
Gromov-Wasserstein example
====================
not enough ===
examples/plot_gromov.py
Outdated
import numpy as np

import ot
import matplotlib.pylab as pl
import pl before ot
examples/plot_gromov.py
Outdated
"""
Sample two Gaussian distributions (2D and 3D)
====================
not enough ===
it won't render well in sphinx
examples/plot_gromov.py
Outdated
For demonstration purpose, we sample two Gaussian distributions in 2- and 3-dimensional spaces.
"""

n = 30 # nb samples
n -> n_samples
you won't need to write # nb samples :)
ot/gromov.py
Outdated
Returns the value of L(a,b)=(1/2)*|a-b|^2
"""

return (1 / 2) * (a - b)**2
1 / 2 will be 0 in Python 2 (integer division)
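A minimal sketch of one way to address this, assuming the loss is a standalone function (the name `square_loss` is a placeholder, not necessarily the name used in the PR). Writing the factor as the float literal `0.5` avoids relying on the interpreter's division semantics:

```python
def square_loss(a, b):
    """Return L(a, b) = (1/2) * |a - b|^2."""
    # 0.5 is a float literal, so the result is the same on Python 2 and 3.
    # On Python 2, (1 / 2) is integer division and evaluates to 0, which
    # would silently turn the whole loss into 0.
    return 0.5 * (a - b) ** 2
```

Another option would be `from __future__ import division` at the top of the module, which makes `/` behave as true division on Python 2 as well.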
ot/gromov.py
Outdated
return b

tens = -np.dot(h1(C1), T).dot(h2(C2).T)
tens = tens - tens.min()
tens -= tens.min()
ot/gromov.py
Outdated
Parameters
----------
C1 : np.ndarray(ns,ns)
C1 : ndarray, shape (ns, ns)
is the standard of numpydoc
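For reference, a short sketch of a numpydoc-style docstring with the `name : type, shape (...)` convention. The function `pairwise_sq_dist` and its parameter names are made up for illustration and are not part of the PR:

```python
import numpy as np


def pairwise_sq_dist(X):
    """Compute squared Euclidean distances between all rows of X.

    Parameters
    ----------
    X : ndarray, shape (ns, d)
        Samples, one per row.

    Returns
    -------
    C : ndarray, shape (ns, ns)
        Symmetric matrix of squared distances.
    """
    # Broadcast to shape (ns, ns, d), then sum over the feature axis.
    diff = X[:, None, :] - X[None, :, :]
    return (diff ** 2).sum(axis=-1)
```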
ot/gromov.py
Outdated
cpt = 0
err = 1

while (err > stopThr and cpt < numItermax):
avoid while loops. Use for with break. It's much safer to avoid infinite loops
you can use for else syntax to capture the absence of a break
OK, for this one I will stay consistent with the rest of the optimization methods (especially those in the Bregman module)
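For readers unfamiliar with the `for`/`else` pattern suggested above, here is a generic sketch on a toy fixed-point iteration. The function `fixed_point` and its arguments `stop_thr` and `max_iter` are illustrative only and do not reflect the solver in `ot/gromov.py`:

```python
import math


def fixed_point(f, x0, stop_thr=1e-9, max_iter=1000):
    """Iterate x <- f(x) until the update is below stop_thr.

    Returns the last iterate and a flag telling whether the
    threshold was reached within max_iter iterations.
    """
    x = x0
    converged = True
    for _ in range(max_iter):  # bounded: cannot loop forever
        x_new = f(x)
        err = abs(x_new - x)
        x = x_new
        if err <= stop_thr:
            break
    else:
        # Runs only if the loop finished without hitting `break`.
        converged = False
    return x, converged


x_star, ok = fixed_point(math.cos, 1.0)
```

The `else` branch gives a natural place to emit a non-convergence warning, which a `while` condition mixing both tests does not provide directly.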
examples/plot_gromov_barycenter.py
Outdated
"""


def smacof_mds(C, dim, maxIter=3000, eps=1e-9):
maxIter -> max_iter
examples/plot_gromov_barycenter.py
Outdated
Parameters
----------
C : np.ndarray(ns,ns)
C : ndarray, shape (ns , ns)
examples/plot_gromov_barycenter.py
Outdated
----------
C : np.ndarray(ns,ns)
dissimilarity matrix
dim : Integer
Integer -> int
examples/plot_gromov_barycenter.py
Outdated
dissimilarity matrix
dim : Integer
dimension of the targeted space
maxIter : Maximum number of iterations of the SMACOF algorithm for a single run
max_iter : int
Maximum number of iterations of the SMACOF algorithm for a single run
examples/plot_gromov_barycenter.py
Outdated
Ct01 = [0 for i in range(2)]
for i in range(2):
    Ct01[i] = ot.gromov.gromov_barycenters(N, [Cs[0], Cs[1]], [
        ps[0], ps[1]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)
numItermax -> max_iter?
examples/plot_gromov_barycenter.py
Outdated
triangle = spi.imread('../data/triangle.png').astype(np.float64) / 256
fleche = spi.imread('../data/coeur.png').astype(np.float64) / 256

shapes = [carre, rond, triangle, fleche]
I guess you meant : square, circle, triangle and arrow :)
@ncourty please go over the full diff about docstrings and naming. If you're ok with me bugging you :) I'll do one more pass when you have done it.
ot/gromov.py
Outdated
'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(cpt, err))

cpt = cpt + 1
cpt += 1
examples/plot_gromov_barycenter.py
Outdated
square = spi.imread('../data/carre.png').astype(np.float64) / 256
circle = spi.imread('../data/rond.png').astype(np.float64) / 256
triangle = spi.imread('../data/triangle.png').astype(np.float64) / 256
arrow = spi.imread('../data/coeur.png').astype(np.float64) / 256
can you maybe rename the png files? also I see arrow = coeur ("coeur" is French for heart). Is this a bug?
test/test_gromov.py
Outdated
xs = ot.datasets.get_2D_samples_gauss(n_samples, mu_s, cov_s)

xt = xs[::-1]
I would have written:
xt = xs[::-1].copy()
and removed the array below
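A small sketch of the difference this suggestion relies on: in NumPy, `xs[::-1]` is a view sharing memory with `xs`, while `.copy()` produces an independent array. The toy data below is made up and is not the test's actual samples:

```python
import numpy as np

xs = np.arange(6, dtype=float).reshape(3, 2)

xt_view = xs[::-1]         # reversed view: shares memory with xs
xt_copy = xs[::-1].copy()  # independent array with the same values

# Modifying xs afterwards is visible through the view only.
xs[0, 0] = 100.0
print(xt_view[-1, 0], xt_copy[-1, 0])
```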
examples/plot_gromov_barycenter.py
Outdated
npos : ndarray, shape (R, dim)
Embedded coordinates of the interpolated point cloud (defined with one isometry)


remove unnecessary empty lines here and one before Returns
examples/plot_gromov.py
Outdated
"""
Sample two Gaussian distributions (2D and 3D)
=============================================
The Gromov-Wasserstein distance allows to compute distances with samples that do not belong to the same metric space.
line too long
ot/gromov.py
Outdated
tens : ndarray, shape (ns, nt)
\mathcal{L}(C1,C2) \otimes T tensor-matrix multiplication result


remove empty lines
examples/plot_gromov_barycenter.py
Outdated
=====================================
Gromov-Wasserstein Barycenter example
=====================================
This example is designed to show how to use the Gromov-Wassertsein distance
Wassertsein -> Wasserstein
examples/plot_gromov_barycenter.py
Outdated
def smacof_mds(C, dim, max_iter=3000, eps=1e-9):
"""
Returns an interpolated point cloud following the dissimilarity matrix C using SMACOF
line too long.
see pep257 https://www.python.org/dev/peps/pep-0257/
especially for multiline docstrings
examples/plot_gromov_barycenter.py
Outdated
Embedded coordinates of the interpolated point cloud (defined with one isometry)
"""

rng = np.random.RandomState(seed=3)
you should expose the random_state and use check_random_state like sklearn does.
wait it's an example maybe it's not necessary...
----------
p : ndarray, shape (N,)
weights in the targeted barycenter
lambdas : list of the S spaces' weights
bad format
ot/gromov.py
Outdated
sample weights in the S spaces
p : ndarray, shape(N,)
weights in the targeted barycenter
lambdas : list of the S spaces' weights
bad format
ot/gromov.py
Outdated
lambdas = np.asarray(lambdas, dtype=np.float64)

# Initialization of C : random SPD matrix
xalea = np.random.randn(N, 2)
expose random_state to make results deterministic if one wants
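One possible shape for this, as a sketch only: the helper `as_random_state` is a simplified stand-in for scikit-learn's `check_random_state`, and `init_points` is a hypothetical function, not the PR's code. The idea is that the caller can pass a seed or a `RandomState` instance to get reproducible initializations:

```python
import numpy as np


def as_random_state(seed=None):
    """Turn None, an int or a RandomState into a RandomState instance.

    Simplified stand-in: unlike scikit-learn's helper, None yields a
    fresh generator rather than NumPy's global one.
    """
    if isinstance(seed, np.random.RandomState):
        return seed
    return np.random.RandomState(seed)


def init_points(N, random_state=None):
    """Draw N random 2D points used to initialize the barycenter."""
    rng = as_random_state(random_state)
    return rng.randn(N, 2)


# Same seed, same initialization.
a = init_points(5, random_state=0)
b = init_points(5, random_state=0)
```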
Thanks for the careful reading @agramfort. And congrats for your NIPS paper :) See you in LA?
Hello @ncourty, I think we should merge shortly since it has converged, could you please update from master?
@ncourty thx :) yes see you in LA!
This looks good to me, you have taken into account all comments and the contribution is very nice for the toolbox.
I think we can merge.
Hi everyone,
This is a new implementation of the Gromov-Wasserstein distance, mostly programmed by Erwan Vautier and myself. In the next commit, I will add a new example on how to compute barycenters and also tests for this new functionality.