Closed
45 commits
3306f40
GPU changes:
toto6 Sep 13, 2017
93376d5
Fix indent
toto6 Sep 13, 2017
3bf4d09
Remove debug
toto6 Sep 13, 2017
bbe9a92
Remove debug
toto6 Sep 14, 2017
05c1bd6
Small speedup LpL1
toto6 Sep 14, 2017
111d3ff
Merge branch 'master' into master
toto6 Sep 14, 2017
abf6fe0
Add utils functions for CPU/GPU merge
toto6 Sep 19, 2017
d7d25b3
Use function pairwiseEuclidean for dist computation
toto6 Sep 19, 2017
e99328b
Merge https://github.com/rflamary/POT
toto6 Sep 19, 2017
7874570
Initial implementation GPU for sinkhorn
toto6 Sep 19, 2017
6c7012f
Little speedup sinkhorn:
toto6 Sep 19, 2017
eae9d2a
Little speedup sinkhorn:
toto6 Sep 19, 2017
a30f3aa
Move test file
toto6 Sep 19, 2017
23869f8
Fix mistake
toto6 Sep 20, 2017
dcfae03
Replace xp by np
toto6 Sep 20, 2017
9d9d996
Compute sinkhorn using float type of inputs
toto6 Sep 28, 2017
d6157db
Add decorator from Rémi
toto6 Oct 9, 2017
131b711
gpu changes:
toto6 Dec 3, 2017
39275aa
fix some warnings
toto6 Dec 3, 2017
7ad1571
fix mistake
toto6 Dec 3, 2017
0b40b5f
Update gpu test file
toto6 Dec 3, 2017
cb739f6
add linear mapping function
rflamary Mar 20, 2018
8fc9fce
add class LinearTransport
rflamary Mar 20, 2018
c104623
passing tests
rflamary Mar 20, 2018
4fc9ccc
better example+test
rflamary Mar 20, 2018
88a81c3
makefile update
rflamary Mar 20, 2018
287c659
update example
rflamary Mar 20, 2018
6fdf5de
add linear mapping test + autopep8
rflamary Mar 21, 2018
5efdf00
add test linear mapping class
rflamary Mar 21, 2018
fc9923d
add tests for ot.uils
rflamary Mar 21, 2018
927395b
add externals for function signature
rflamary Mar 21, 2018
55aaf78
add test gromov + debug sklearn Basestimator
rflamary Mar 21, 2018
64ef33d
aupdate gromov + autopep8 externals
rflamary Mar 21, 2018
7095e03
gtomov barycenter tests
rflamary Mar 21, 2018
63fd11e
add entropic gromov test for 90+% corerage
rflamary Mar 21, 2018
1262563
update readme + doc
rflamary Mar 21, 2018
0ce1a5e
update doc
rflamary Mar 21, 2018
83c706c
pep cleanup
rflamary Mar 21, 2018
69c7d1c
pep8 unused variable
rflamary Mar 21, 2018
7681db5
update reame
rflamary Mar 21, 2018
aa12256
add automatically documentation to gu_fun
rflamary May 2, 2018
aa24c1d
pep8 util function
rflamary May 2, 2018
6f964e4
better decorate
rflamary May 2, 2018
0bec096
awesome merge
rflamary May 2, 2018
90636c4
working decorator for gpu
rflamary May 2, 2018
9 changes: 7 additions & 2 deletions Makefile
@@ -1,6 +1,6 @@


PYTHON=python
PYTHON=python3

help :
@echo "The following make targets are available:"
@@ -41,7 +41,7 @@ pep8 :
flake8 examples/ ot/ test/

test : FORCE pep8
python -m py.test -v test/ --cov=ot --cov-report html:cov_html
$(PYTHON) -m pytest -v test/ --cov=ot --cov-report html:cov_html

pytest : FORCE
python -m py.test -v test/ --cov=ot
@@ -56,6 +56,11 @@ rdoc :

notebook :
ipython notebook --matplotlib=inline --notebook-dir=notebooks/

autopep8 :
autopep8 -ir test ot examples

aautopep8 :
autopep8 -air test ot examples

FORCE :
16 changes: 10 additions & 6 deletions README.md
@@ -14,13 +14,13 @@ This open source Python library provide several solvers for optimization problem
It provides the following solvers:

* OT solver for the linear program/ Earth Movers Distance [1].
* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10] with optional GPU implementation (required cudamat).
* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10] with optional GPU implementation (required cupy).
* Bregman projections for Wasserstein barycenter [3] and unmixing [4].
* Optimal transport for domain adaptation with group lasso regularization [5]
* Conditional gradient [6] and Generalized conditional gradient for regularized OT [7].
* Joint OT matrix and mapping estimation [8].
* Linear OT [14] and Joint OT matrix and mapping estimation [8].
* Wasserstein Discriminant Analysis [11] (requires autograd + pymanopt).
* Gromov-Wasserstein distances and barycenters [12]
* Gromov-Wasserstein distances and barycenters ([13] and regularized [12])

Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder.

@@ -67,10 +67,10 @@ Some sub-modules require additional dependences which are discussed below
```
pip install pymanopt autograd
```
* **ot.gpu** (GPU accelerated OT) depends on cudamat that have to be installed with:
* **ot.gpu** (GPU accelerated OT) depends on cupy that have to be installed with:
```
git clone https://github.com/cudamat/cudamat.git
cd cudamat
git clone https://github.com/cupy/cupy
cd cupy
python setup.py install --user # for user install (no root)
```

@@ -206,3 +206,7 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, [Gromov-Wasserstein averaging of kernel and distance matrices](http://proceedings.mlr.press/v48/peyre16.html) International Conference on Machine Learning (ICML). 2016.

[13] Mémoli, Facundo. [Gromov–Wasserstein distances and the metric approach to object matching](https://media.adelaide.edu.au/acvt/Publications/2011/2011-Gromov%E2%80%93Wasserstein%20Distances%20and%20the%20Metric%20Approach%20to%20Object%20Matching.pdf). Foundations of computational mathematics 11.4 (2011): 417-487.

[14] Knott, M. and Smith, C. S. [On the optimal mapping of distributions](https://link.springer.com/article/10.1007/BF00934745), Journal of Optimization Theory and Applications Vol 43, 1984.

[15] Peyré, G., & Cuturi, M. (2017). [Computational Optimal Transport](https://arxiv.org/pdf/1803.00567.pdf) , 2018.
25 changes: 16 additions & 9 deletions docs/source/readme.rst
@@ -1,8 +1,8 @@
POT: Python Optimal Transport
=============================

|PyPI version| |Build Status| |Documentation Status| |Anaconda Cloud|
|License| |Anaconda downloads|
|PyPI version| |Anaconda Cloud| |Build Status| |Documentation Status|
|Anaconda downloads| |License|

This open source Python library provide several solvers for optimization
problems related to Optimal Transport for signal, image processing and
@@ -13,7 +13,7 @@ It provides the following solvers:
- OT solver for the linear program/ Earth Movers Distance [1].
- Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2]
and stabilized version [9][10] with optional GPU implementation
(required cudamat).
(required cupy).
- Bregman projections for Wasserstein barycenter [3] and unmixing [4].
- Optimal transport for domain adaptation with group lasso
regularization [5]
@@ -90,13 +90,13 @@ below

pip install pymanopt autograd

- **ot.gpu** (GPU accelerated OT) depends on cudamat that have to be
- **ot.gpu** (GPU accelerated OT) depends on cupy that have to be
installed with:

::

git clone https://github.com/cudamat/cudamat.git
cd cudamat
git clone https://github.com/cupy/cupy
cd cupy
python setup.py install --user # for user install (no root)

obviously you need CUDA installed and a compatible GPU.
@@ -311,15 +311,22 @@ approach to object
matching <https://media.adelaide.edu.au/acvt/Publications/2011/2011-Gromov%E2%80%93Wasserstein%20Distances%20and%20the%20Metric%20Approach%20to%20Object%20Matching.pdf>`__.
Foundations of computational mathematics 11.4 (2011): 417-487.

[14] Knott, M. and Smith, C. S. `On the optimal mapping of
distributions <https://link.springer.com/article/10.1007/BF00934745>`__,
Journal of Optimization Theory and Applications Vol 43, 1984.

[15] Peyré, G., & Cuturi, M. (2017). `Computational Optimal
Transport <https://arxiv.org/pdf/1803.00567.pdf>`__ , 2018.

.. |PyPI version| image:: https://badge.fury.io/py/POT.svg
:target: https://badge.fury.io/py/POT
.. |Anaconda Cloud| image:: https://anaconda.org/conda-forge/pot/badges/version.svg
:target: https://anaconda.org/conda-forge/pot
.. |Build Status| image:: https://travis-ci.org/rflamary/POT.svg?branch=master
:target: https://travis-ci.org/rflamary/POT
.. |Documentation Status| image:: https://readthedocs.org/projects/pot/badge/?version=latest
:target: http://pot.readthedocs.io/en/latest/?badge=latest
.. |Anaconda Cloud| image:: https://anaconda.org/conda-forge/pot/badges/version.svg
.. |Anaconda downloads| image:: https://anaconda.org/conda-forge/pot/badges/downloads.svg
:target: https://anaconda.org/conda-forge/pot
.. |License| image:: https://anaconda.org/conda-forge/pot/badges/license.svg
:target: https://github.com/rflamary/POT/blob/master/LICENSE
.. |Anaconda downloads| image:: https://anaconda.org/conda-forge/pot/badges/downloads.svg
:target: https://anaconda.org/conda-forge/pot
138 changes: 138 additions & 0 deletions examples/plot_otda_linear_mapping.py
@@ -0,0 +1,138 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Mar 20 14:31:15 2018

@author: rflamary
"""

import numpy as np
import pylab as pl
import ot

##############################################################################
# Generate data
# -------------

n = 1000
d = 2
sigma = .1

# source samples
angles = np.random.rand(n, 1) * 2 * np.pi
xs = np.concatenate((np.sin(angles), np.cos(angles)),
                    axis=1) + sigma * np.random.randn(n, 2)
xs[:n // 2, 1] += 2


# target samples
anglet = np.random.rand(n, 1) * 2 * np.pi
xt = np.concatenate((np.sin(anglet), np.cos(anglet)),
                    axis=1) + sigma * np.random.randn(n, 2)
xt[:n // 2, 1] += 2


A = np.array([[1.5, .7], [.7, 1.5]])
b = np.array([[4, 2]])
xt = xt.dot(A) + b

##############################################################################
# Plot data
# ---------

pl.figure(1, (5, 5))
pl.plot(xs[:, 0], xs[:, 1], '+')
pl.plot(xt[:, 0], xt[:, 1], 'o')


##############################################################################
# Estimate linear mapping and transport
# -------------------------------------

Ae, be = ot.da.OT_mapping_linear(xs, xt)

xst = xs.dot(Ae) + be


##############################################################################
# Plot transported samples
# ------------------------

pl.figure(1, (5, 5))
pl.clf()
pl.plot(xs[:, 0], xs[:, 1], '+')
pl.plot(xt[:, 0], xt[:, 1], 'o')
pl.plot(xst[:, 0], xst[:, 1], '+')

pl.show()

##############################################################################
# Load image data
# ---------------


def im2mat(I):
    """Converts an image to a matrix (one pixel per line)"""
    return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))


def mat2im(X, shape):
    """Converts a matrix back to an image"""
    return X.reshape(shape)


def minmax(I):
    """Clips values to the [0, 1] range"""
    return np.clip(I, 0, 1)


# Loading images
I1 = pl.imread('../data/ocean_day.jpg').astype(np.float64) / 256
I2 = pl.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256


X1 = im2mat(I1)
X2 = im2mat(I2)

##############################################################################
# Estimate mapping and adapt
# ----------------------------

mapping = ot.da.LinearTransport()

mapping.fit(Xs=X1, Xt=X2)


xst = mapping.transform(Xs=X1)
xts = mapping.inverse_transform(Xt=X2)

I1t = minmax(mat2im(xst, I1.shape))
I2t = minmax(mat2im(xts, I2.shape))

# %%


##############################################################################
# Plot transformed images
# -----------------------

pl.figure(2, figsize=(10, 7))

pl.subplot(2, 2, 1)
pl.imshow(I1)
pl.axis('off')
pl.title('Im. 1')

pl.subplot(2, 2, 2)
pl.imshow(I2)
pl.axis('off')
pl.title('Im. 2')

pl.subplot(2, 2, 3)
pl.imshow(I1t)
pl.axis('off')
pl.title('Mapping Im. 1')

pl.subplot(2, 2, 4)
pl.imshow(I2t)
pl.axis('off')
pl.title('Inverse mapping Im. 2')
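
Note on the example above: the body of `ot.da.OT_mapping_linear` is not part of the files shown here, so the following is only a minimal sketch of the closed-form linear map between two Gaussian approximations described in Knott and Smith [14]. The function names `sqrtm_psd_sketch` and `linear_mapping_sketch` and the `reg` parameter are hypothetical and chosen for illustration; the row-vector convention `xs.dot(A) + b` is taken from the example script.

```python
import numpy as np


def sqrtm_psd_sketch(C):
    """Square root of a symmetric positive semi-definite matrix (hypothetical helper)."""
    w, V = np.linalg.eigh(C)
    w = np.clip(w, 0, None)  # guard against tiny negative eigenvalues
    return (V * np.sqrt(w)).dot(V.T)


def linear_mapping_sketch(xs, xt, reg=1e-8):
    """Estimate A, b such that xs.dot(A) + b matches the mean and covariance of xt.

    Assumption: both sample sets are summarized by their empirical mean and
    covariance; reg adds a small ridge so the source covariance is invertible.
    """
    d = xs.shape[1]
    mu_s = xs.mean(axis=0, keepdims=True)
    mu_t = xt.mean(axis=0, keepdims=True)
    Cs = np.cov(xs, rowvar=False) + reg * np.eye(d)
    Ct = np.cov(xt, rowvar=False) + reg * np.eye(d)

    Cs12 = sqrtm_psd_sketch(Cs)
    Cs12inv = np.linalg.inv(Cs12)

    # A = Cs^{-1/2} (Cs^{1/2} Ct Cs^{1/2})^{1/2} Cs^{-1/2}, symmetric by construction
    A = Cs12inv.dot(sqrtm_psd_sketch(Cs12.dot(Ct).dot(Cs12))).dot(Cs12inv)
    b = mu_t - mu_s.dot(A)
    return A, b
```

With the synthetic data from the script, where the target is built by applying a symmetric positive definite matrix and an offset, such an estimator would be expected to return values close to that matrix and offset when the source and target samples follow the same base distribution.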
68 changes: 68 additions & 0 deletions examples/test_gpu.py
@@ -0,0 +1,68 @@
# -*- coding: utf-8 -*-

import time

import numpy as np
import cupy as cp
from scipy.spatial.distance import cdist
import ot


def benchDistance(a, b):
    # First compare computation time for computing pairwise euclidean matrix
    time1 = time.time()
    M1 = cdist(a, b, metric="sqeuclidean")
    time2 = time.time()
    M2 = ot.utils.pairwiseEuclidean(a, b, gpu=False, squared=True)
    time3 = time.time()
    M3 = ot.utils.pairwiseEuclidean(a, b, gpu=True, squared=True)
    time4 = time.time()

    np.testing.assert_allclose(M1, M2, rtol=1e-5, atol=1e-5)
    np.testing.assert_allclose(M2, cp.asnumpy(M3), rtol=1e-5, atol=1e-5)
    print("        scipy's cdist, time: {:6.2f} sec ".format(time2 - time1))
    print("pairwiseEuclidean CPU, time: {:6.2f} sec ".format(time3 - time2))
    print("pairwiseEuclidean GPU, time: {:6.2f} sec ".format(time4 - time3))


def benchSinkhorn(a, labels_a, b):
    # Then compare computation time for computing optimal sinkhorn coupling
    ot1 = ot.da.SinkhornTransport(gpu=False)
    ot2 = ot.da.SinkhornTransport(gpu=True)

    time1 = time.time()
    ot1.fit(Xs=a, Xt=b)
    g1 = ot1.coupling_
    time2 = time.time()
    ot2.fit(Xs=a, Xt=b)
    g2 = ot2.coupling_
    time3 = time.time()

    print("Sinkhorn CPU, time: {:6.2f} sec ".format(time2 - time1))
    print("Sinkhorn GPU, time: {:6.2f} sec ".format(time3 - time2))
    np.testing.assert_allclose(g1, cp.asnumpy(g2), rtol=1e-5, atol=1e-5)

    otlpl1 = ot.da.SinkhornLpl1Transport(gpu=False)
    otlpl2 = ot.da.SinkhornLpl1Transport(gpu=True)
    time1 = time.time()
    otlpl1.fit(Xs=a, ys=labels_a, Xt=b)
    g1 = otlpl1.coupling_
    time2 = time.time()
    otlpl2.fit(Xs=a, ys=labels_a, Xt=b)
    g2 = otlpl2.coupling_
    time3 = time.time()

    print("Sinkhorn LpL1 CPU, time: {:6.2f} sec ".format(time2 - time1))
    print("Sinkhorn LpL1 GPU, time: {:6.2f} sec ".format(time3 - time2))
    np.testing.assert_allclose(g1, cp.asnumpy(g2), rtol=1e-5, atol=1e-5)


for tp in [np.float32, np.float64]:
    print("Using " + str(tp))
    n = 5000
    d = 100
    a = np.random.rand(n, d).astype(tp)
    labels_a = (np.random.rand(n, 1) * 2).astype(int).ravel()
    b = np.random.rand(n, d).astype(tp)
    benchDistance(a, b)
    benchSinkhorn(a, labels_a, b)
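
Note on the benchmark above: the implementation of `ot.utils.pairwiseEuclidean` is not included in the files shown here. A common way to compute a pairwise squared Euclidean matrix with only matrix products, which tends to suit both NumPy and GPU array libraries, is the expansion ||a - b||² = ||a||² + ||b||² - 2 a·b. The sketch below illustrates that expansion in plain NumPy; the name `pairwise_sq_euclidean_sketch` is hypothetical and this is not claimed to be the code under review.

```python
import numpy as np


def pairwise_sq_euclidean_sketch(a, b):
    """Squared Euclidean distances between rows of a (n, d) and rows of b (m, d).

    Illustrative only: uses ||a||^2 + ||b||^2 - 2 a.b, so the heavy work is a
    single matrix product.
    """
    a2 = np.sum(a ** 2, axis=1)[:, None]  # shape (n, 1)
    b2 = np.sum(b ** 2, axis=1)[None, :]  # shape (1, m)
    M = a2 + b2 - 2.0 * a.dot(b.T)
    # Rounding can produce tiny negative entries; clip them to zero.
    return np.maximum(M, 0)
```

Because the expansion subtracts large, nearly equal terms, results in `float32` can differ slightly from a direct difference-based computation, which may be one reason a benchmark like the one above compares with a tolerance rather than exact equality.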