From 67b33e713faa2b5e9f609203da8bc48ea745314c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Thu, 20 Mar 2025 18:06:39 +0100 Subject: [PATCH 01/16] rrme to user guide and add quicksrat script --- docs/source/index.rst | 5 ++-- .../source/{quickstart.rst => user_guide.rst} | 4 +-- examples/plot_quickstart_guide.py | 25 +++++++++++++++++++ 3 files changed, 30 insertions(+), 4 deletions(-) rename docs/source/{quickstart.rst => user_guide.rst} (99%) create mode 100644 examples/plot_quickstart_guide.py diff --git a/docs/source/index.rst b/docs/source/index.rst index 0f04738d9..250cd52b9 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -17,9 +17,10 @@ Contents :maxdepth: 1 self - quickstart - all + auto_examples/plot_quickstart_guide auto_examples/index + user_guide + all releases contributors contributing diff --git a/docs/source/quickstart.rst b/docs/source/user_guide.rst similarity index 99% rename from docs/source/quickstart.rst rename to docs/source/user_guide.rst index 1f1c69398..25e9d4213 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/user_guide.rst @@ -1,6 +1,6 @@ -Quick start guide -================= +User guide +========== In the following we provide some pointers about which functions and classes to use for different problems related to optimal transport (OT) and machine diff --git a/examples/plot_quickstart_guide.py b/examples/plot_quickstart_guide.py new file mode 100644 index 000000000..f661ac908 --- /dev/null +++ b/examples/plot_quickstart_guide.py @@ -0,0 +1,25 @@ +# coding: utf-8 +""" +============================================= +Quickstart Guide +============================================= + + +This is a quickstart guide to the Python Optimal Transport (POT) toolbox. 
+ +""" + +# Author: Remi Flamary +# +# License: MIT License +# sphinx_gallery_thumbnail_number = 1 + +# %% +# Simple example +# -------------- +# + +import numpy as np # always need it +import pylab as pl # for the plots + +import ot From 3c7ca56b61e6edf132055f417dddac52c55e7e53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Fri, 21 Mar 2025 09:35:14 +0100 Subject: [PATCH 02/16] remoe import of deprecated module --- RELEASES.md | 1 + ot/lp/__init__.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/RELEASES.md b/RELEASES.md index 3fb6c1f14..62240fa77 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -15,6 +15,7 @@ - Added `ot.gaussian.bures_wasserstein_distance` (PR #680) - `ot.gaussian.bures_wasserstein_distance` can be batched (PR #680) - Backend implementation of `ot.dist` for (PR #701) +- Updated documentation Quickstart guide and User guide with new API (PR #726) #### Closed issues - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index e3cfce0fd..932b261df 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -8,7 +8,6 @@ # # License: MIT License -from . 
import cvx from .dmmot import dmmot_monge_1dgrid_loss, dmmot_monge_1dgrid_optimize from ._network_simplex import emd, emd2 from ._barycenter_solvers import ( From 3749263fdd7a01dac3ba6fd22a554cb49f5d9713 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Fri, 21 Mar 2025 12:24:52 +0100 Subject: [PATCH 03/16] premier jet quckstart guide --- examples/plot_quickstart_guide.py | 211 +++++++++++++++++++++++++++++- 1 file changed, 206 insertions(+), 5 deletions(-) diff --git a/examples/plot_quickstart_guide.py b/examples/plot_quickstart_guide.py index f661ac908..23c87c492 100644 --- a/examples/plot_quickstart_guide.py +++ b/examples/plot_quickstart_guide.py @@ -5,7 +5,11 @@ ============================================= -This is a quickstart guide to the Python Optimal Transport (POT) toolbox. +This is a quickstart guide to the Python Optimal Transport (POT) toolbox. We use +here the new API of POT which is more flexible and allows to solve a wider range +of problems with just a few functions. The old API is still available (the new +one is a convenient wrapper around the old one) and we provide pointers to the +old API when needed. 
""" @@ -14,12 +18,209 @@ # License: MIT License # sphinx_gallery_thumbnail_number = 1 +# Import necessary libraries + +import numpy as np +import pylab as pl + +import ot + + # %% -# Simple example +# Example data # -------------- # +# Data generation +# ~~~~~~~~~~~~~~~ -import numpy as np # always need it -import pylab as pl # for the plots +# Problem size +n1 = 25 +n2 = 50 -import ot +# Generate random data +np.random.seed(0) +a = ot.utils.unif(n1) # weights of points in the source domain +b = ot.utils.unif(n2) # weights of points in the target domain + +x1 = np.random.randn(n1, 2) +x1 /= ( + np.sqrt(np.sum(x1**2, 1, keepdims=True)) / 2 +) # project on the unit circle and scale +x2 = np.random.randn(n2, 2) +x2 /= ( + np.sqrt(np.sum(x2**2, 1, keepdims=True)) / 4 +) # project on the unit circle and scale + +# %% +# Plot data +# ~~~~~~~~~ + +style = {"markeredgecolor": "k"} + +pl.figure(1, (4, 4)) +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) +pl.legend(loc=0) +pl.title("Source and target distributions") +pl.show() + +# %% +# Solving exact Optimal Transport +# ------------------------------- +# Solve the Optimal Transport problem between the samples +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# The :func:`ot.solve_sample` function can be used to solve the Optimal Transport problem +# between two sets of samples. The function takes as its two first arguments the +# positions of the source and target samples, and returns an :class:`ot.utils.OTResult` object. + +# Solve the OT problem +sol = ot.solve_sample(x1, x2, a, b) + +# get the OT plan +P = sol.plan + +# get the OT loss +loss = sol.value + +# get the dual potentials +alpha, beta = sol.potentials + +print(f"OT loss = {loss:1.3f}") + +# %% +# We provide +# the weights of the samples in the source and target domains :code:`a` and +# :code:`b`. If not provided, the weights are assumed to be uniform. 
+# +# The :class:`ot.utils.OTResult` object contains the following attributes: +# +# - :code:`value`: the value of the OT problem +# - :code:`plan`: the OT matrix +# - :code:`potentials`: Dual potentials of the OT problem +# - :code:`log`: log dictionary of the solver +# +# The OT matrix :math:`P` is a matrix of size :code:`(n1, n2)` where +# :code:`P[i,j]` is the amount of mass +# transported from :code:`x1[i]` to :code:`x2[j]`. +# +# The OT loss is the sum of the element-wise product of the OT matrix and the +# cost matrix taken by default as the Squared Euclidean distance. +# + +# %% +# Plot the OT plan and dual potentials +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# + +from ot.plot import plot2D_samples_mat + +pl.figure(1, (8, 4)) + +pl.subplot(1, 2, 1) +plot2D_samples_mat(x1, x2, P) +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) +pl.title("OT plan P loss={:.3f}".format(loss)) + +pl.subplot(1, 2, 2) +pl.scatter(x1[:, 0], x1[:, 1], c=alpha, cmap="viridis", edgecolors="k") +pl.scatter(x2[:, 0], x2[:, 1], c=beta, cmap="plasma", edgecolors="k") +pl.title("Dual potentials") +pl.show() + + +pl.figure(2, (3, 1.7)) +pl.imshow(P, cmap="Greys") +pl.title("OT plan") +pl.show() + +# %% +# Solve the Optimal Transport problem with a custom cost matrix +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# The cost matrix can be customized by passing it to the more general +# :func:`ot.solve` function. The cost matrix should be a matrix of size +# :code:`(n1, n2)` where :code:`C[i,j]` is the cost of transporting mass from +# :code:`x1[i]` to :code:`x2[j]`. +# +# In this example, we use the Citybloc distance as the cost matrix. 
+ +# Compute the cost matrix +C = ot.dist(x1, x2, metric="cityblock") + +# Solve the OT problem with the custom cost matrix +P_city = ot.solve(C).plan + +# Compute the OT loss (equivalent to ot.solve(C).value) +loss_city = np.sum(P_city * C) + +# %% +# Note that we show here how to sole the OT problem with a custom cost matrix +# with the more general :func:`ot.solve` function. +# But the same can be done with the :func:`ot.solve_sample` function by passing +# :code:`metric='cityblock'` as argument. +# +# .. note:: +# The examples above use the new API of POT. The old API is still available +# and and OT plan and loss can be computed with the :func:`ot.emd` and +# the :func:`ot.emd2` functions as below: +# +# .. code-block:: python +# +# P = ot.emd(a, b, C) +# loss = ot.emd2(a, b, C) # same as np.sum(P*C) but differentiable wrt a/b +# +# Plot the OT plan and dual potentials for other loss +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# + +pl.figure(1, (3, 3)) +plot2D_samples_mat(x1, x2, P) +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) +pl.title("OT plan (Citybloc) loss={:.3f}".format(loss_city)) + +pl.figure(2, (3, 1.7)) +pl.imshow(P_city, cmap="Greys") +pl.title("OT plan (Citybloc)") +pl.show() + +# %% +# Sinkhorn and Regularized OT +# --------------------------- +# +# Solve Entropic Regularized OT with Sinkhorn algorithm +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# + +# Solve the Sinkhorn problem (just add reg parameter value) +sol = ot.solve_sample(x1, x2, a, b, reg=1e-1) + +# get the OT plan and loss +P_sink = sol.plan +loss_sink = sol.value # objective value of the Sinkhorn problem (incl. entropy) +loss_sink_linear = sol.value_linear # np.sum(P_sink * C) linear part of loss + +# %% +# The Sinkhorn algorithm solves the Entropic Regularized OT problem. The +# regularization strength can be controlled with the :code:`reg` parameter. 
+# The Sinkhorn algorithm can be faster than the exact OT solver for large +# regularization strength but the solution is only an approximation of the +# exact OT problem and the OT plan is not sparse. +# +# Plot the OT plan and dual potentials for Sinkhorn +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# + +pl.figure(1, (3, 3)) +plot2D_samples_mat(x1, x2, P_sink) +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) +pl.title("Sinkhorn OT plan loss={:.3f}".format(loss_sink)) +pl.show() + +pl.figure(2, (3, 1.7)) +pl.imshow(P_sink, cmap="Greys") +pl.title("Sinkhorn OT plan") +pl.show() From aa03aaffa82c0427752b8e72b87a374d76a79e41 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Fri, 21 Mar 2025 14:39:51 +0100 Subject: [PATCH 04/16] working on the guide --- docs/source/conf.py | 2 +- examples/plot_OT_2D_samples.py | 2 +- examples/plot_quickstart_guide.py | 168 +++++++++++++++++++----------- 3 files changed, 107 insertions(+), 65 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index e1b9a85b7..1850ff040 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -347,7 +347,7 @@ def __getattr__(cls, name): } sphinx_gallery_conf = { - "examples_dirs": ["../../examples", "../../examples/da"], + "examples_dirs": ["../../examples"], "gallery_dirs": "auto_examples", "filename_pattern": "plot_", # (?!barycenter_fgw) "nested_sections": False, diff --git a/examples/plot_OT_2D_samples.py b/examples/plot_OT_2D_samples.py index e51ce1285..c37e78ca2 100644 --- a/examples/plot_OT_2D_samples.py +++ b/examples/plot_OT_2D_samples.py @@ -65,7 +65,7 @@ # %% EMD -G0 = ot.emd(a, b, M) +G0 = ot.solve(M, a, b).plan pl.figure(3) pl.imshow(G0, interpolation="nearest") diff --git a/examples/plot_quickstart_guide.py b/examples/plot_quickstart_guide.py index 23c87c492..67291aca3 100644 --- a/examples/plot_quickstart_guide.py +++ b/examples/plot_quickstart_guide.py 
@@ -5,18 +5,22 @@ ============================================= -This is a quickstart guide to the Python Optimal Transport (POT) toolbox. We use -here the new API of POT which is more flexible and allows to solve a wider range -of problems with just a few functions. The old API is still available (the new -one is a convenient wrapper around the old one) and we provide pointers to the -old API when needed. +Quickstart guide to the POT toolbox. + +For better readability, only the use of POT is provided and the plotting code +with matplotlib is hidden (but is available in the source file of the example). + +.. note:: + We use here the new API of POT which is more flexible and allows to solve a wider range of problems with just a few functions. The old API is still available (the new + one is a convenient wrapper around the old one) and we provide pointers to the + old API when needed. """ # Author: Remi Flamary # # License: MIT License -# sphinx_gallery_thumbnail_number = 1 +# sphinx_gallery_thumbnail_number = 4 # Import necessary libraries @@ -43,18 +47,12 @@ b = ot.utils.unif(n2) # weights of points in the target domain x1 = np.random.randn(n1, 2) -x1 /= ( - np.sqrt(np.sum(x1**2, 1, keepdims=True)) / 2 -) # project on the unit circle and scale -x2 = np.random.randn(n2, 2) -x2 /= ( - np.sqrt(np.sum(x2**2, 1, keepdims=True)) / 4 -) # project on the unit circle and scale +x1 /= np.sqrt(np.sum(x1**2, 1, keepdims=True)) / 2 -# %% -# Plot data -# ~~~~~~~~~ +x2 = np.random.randn(n2, 2) +x2 /= np.sqrt(np.sum(x2**2, 1, keepdims=True)) / 4 +# sphinx_gallery_start_ignore style = {"markeredgecolor": "k"} pl.figure(1, (4, 4)) @@ -63,8 +61,13 @@ pl.legend(loc=0) pl.title("Source and target distributions") pl.show() +# sphinx_gallery_end_ignore # %% +# We illustrate above the simple example of two 2D distributions with 25 and 50 +# samples respectively located on circles. The weights of the samples are +# uniform. 
+# # Solving exact Optimal Transport # ------------------------------- # Solve the Optimal Transport problem between the samples @@ -88,31 +91,7 @@ print(f"OT loss = {loss:1.3f}") -# %% -# We provide -# the weights of the samples in the source and target domains :code:`a` and -# :code:`b`. If not provided, the weights are assumed to be uniform. -# -# The :class:`ot.utils.OTResult` object contains the following attributes: -# -# - :code:`value`: the value of the OT problem -# - :code:`plan`: the OT matrix -# - :code:`potentials`: Dual potentials of the OT problem -# - :code:`log`: log dictionary of the solver -# -# The OT matrix :math:`P` is a matrix of size :code:`(n1, n2)` where -# :code:`P[i,j]` is the amount of mass -# transported from :code:`x1[i]` to :code:`x2[j]`. -# -# The OT loss is the sum of the element-wise product of the OT matrix and the -# cost matrix taken by default as the Squared Euclidean distance. -# - -# %% -# Plot the OT plan and dual potentials -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# - +# sphinx_gallery_start_ignore from ot.plot import plot2D_samples_mat pl.figure(1, (8, 4)) @@ -134,6 +113,32 @@ pl.imshow(P, cmap="Greys") pl.title("OT plan") pl.show() +# sphinx_gallery_end_ignore + +# %% +# The figure above shows the Optimal Transport plan between the source and target +# samples. The color intensity represents the amount of mass transported +# between the samples. The dual potentials of the OT problem are also shown. +# +# The weights of the samples in the source and target domains :code:`a` and +# :code:`b` are given to the function. If not provided, the weights are assumed +# to be uniform See :func:`ot.solve_sample` for more details. 
+# +# The :class:`ot.utils.OTResult` object contains the following attributes: +# +# - :code:`value`: the value of the OT problem +# - :code:`plan`: the OT matrix +# - :code:`potentials`: Dual potentials of the OT problem +# - :code:`log`: log dictionary of the solver +# +# The OT matrix :math:`P` is a matrix of size :code:`(n1, n2)` where +# :code:`P[i,j]` is the amount of mass +# transported from :code:`x1[i]` to :code:`x2[j]`. +# +# The OT loss is the sum of the element-wise product of the OT matrix and the +# cost matrix taken by default as the Squared Euclidean distance. +# + # %% # Solve the Optimal Transport problem with a custom cost matrix @@ -155,8 +160,21 @@ # Compute the OT loss (equivalent to ot.solve(C).value) loss_city = np.sum(P_city * C) +# sphinx_gallery_start_ignore +pl.figure(1, (3, 3)) +plot2D_samples_mat(x1, x2, P) +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) +pl.title("OT plan (Citybloc) loss={:.3f}".format(loss_city)) + +pl.figure(2, (3, 1.7)) +pl.imshow(P_city, cmap="Greys") +pl.title("OT plan (Citybloc)") +pl.show() +# sphinx_gallery_end_ignore + # %% -# Note that we show here how to sole the OT problem with a custom cost matrix +# Note that we show here how to solve the OT problem with a custom cost matrix # with the more general :func:`ot.solve` function. # But the same can be done with the :func:`ot.solve_sample` function by passing # :code:`metric='cityblock'` as argument. @@ -171,20 +189,9 @@ # P = ot.emd(a, b, C) # loss = ot.emd2(a, b, C) # same as np.sum(P*C) but differentiable wrt a/b # -# Plot the OT plan and dual potentials for other loss -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# .. 
minigallery:: ot.emd2 ot.emd ot.solve ot.solve_sample # -pl.figure(1, (3, 3)) -plot2D_samples_mat(x1, x2, P) -pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) -pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) -pl.title("OT plan (Citybloc) loss={:.3f}".format(loss_city)) - -pl.figure(2, (3, 1.7)) -pl.imshow(P_city, cmap="Greys") -pl.title("OT plan (Citybloc)") -pl.show() # %% # Sinkhorn and Regularized OT @@ -202,25 +209,60 @@ loss_sink = sol.value # objective value of the Sinkhorn problem (incl. entropy) loss_sink_linear = sol.value_linear # np.sum(P_sink * C) linear part of loss +# sphinx_gallery_start_ignore +pl.figure(1, (3, 3)) +plot2D_samples_mat(x1, x2, P_sink) +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) +pl.title("Sinkhorn OT plan loss={:.3f}".format(loss_sink)) +pl.show() + +pl.figure(2, (3, 1.7)) +pl.imshow(P_sink, cmap="Greys") +pl.title("Sinkhorn OT plan") +pl.show() +# sphinx_gallery_end_ignore # %% # The Sinkhorn algorithm solves the Entropic Regularized OT problem. The # regularization strength can be controlled with the :code:`reg` parameter. # The Sinkhorn algorithm can be faster than the exact OT solver for large # regularization strength but the solution is only an approximation of the # exact OT problem and the OT plan is not sparse. 
-# -# Plot the OT plan and dual potentials for Sinkhorn -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +# %% +# Solve the Regularized OT problem with other regularizations +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -pl.figure(1, (3, 3)) +# Use quadratic regularization +P_quad = ot.solve_sample(x1, x2, a, b, reg=3, reg_type="L2").plan + +loss_quad = ot.solve_sample(x1, x2, a, b, reg=3, reg_type="L2").value + +# sphinx_gallery_start_ignore +pl.figure(1, (9, 3)) + +pl.subplot(1, 3, 1) +plot2D_samples_mat(x1, x2, P) +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) +pl.title("OT plan loss={:.3f}".format(loss)) + +pl.subplot(1, 3, 2) plot2D_samples_mat(x1, x2, P_sink) pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) -pl.title("Sinkhorn OT plan loss={:.3f}".format(loss_sink)) -pl.show() +pl.title("Sinkhorn plan loss={:.3f}".format(loss_sink)) -pl.figure(2, (3, 1.7)) -pl.imshow(P_sink, cmap="Greys") -pl.title("Sinkhorn OT plan") +pl.subplot(1, 3, 3) +plot2D_samples_mat(x1, x2, P_quad) +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) +pl.title("Quadratic plan loss={:.3f}".format(loss_quad)) pl.show() +# sphinx_gallery_end_ignore +# %% +# We plot above the OT plans obtained with different regularizations. The +# quadratic regularization is another common choice for regularized OT and +# preserves the sparsity of the OT plan. 
+# From dd885bce6c8909801d1c2c38760dc0559eeb2971 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Fri, 21 Mar 2025 17:12:08 +0100 Subject: [PATCH 05/16] add stuff --- examples/plot_quickstart_guide.py | 274 ++++++++++++++++++++++++++++-- 1 file changed, 263 insertions(+), 11 deletions(-) diff --git a/examples/plot_quickstart_guide.py b/examples/plot_quickstart_guide.py index 67291aca3..f038ea5c6 100644 --- a/examples/plot_quickstart_guide.py +++ b/examples/plot_quickstart_guide.py @@ -31,11 +31,12 @@ # %% -# Example data +# Data generation # -------------- # -# Data generation -# ~~~~~~~~~~~~~~~ +# We first generate two sets of samples in 2D that 25 and 50 +# samples respectively located on circles. The weights of the samples are +# uniform. # Problem size n1 = 25 @@ -64,9 +65,6 @@ # sphinx_gallery_end_ignore # %% -# We illustrate above the simple example of two 2D distributions with 25 and 50 -# samples respectively located on circles. The weights of the samples are -# uniform. # # Solving exact Optimal Transport # ------------------------------- @@ -189,15 +187,13 @@ # P = ot.emd(a, b, C) # loss = ot.emd2(a, b, C) # same as np.sum(P*C) but differentiable wrt a/b # -# .. minigallery:: ot.emd2 ot.emd ot.solve ot.solve_sample -# # %% # Sinkhorn and Regularized OT # --------------------------- # -# Solve Entropic Regularized OT with Sinkhorn algorithm +# Entropic OT with Sinkhorn algorithm # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # @@ -230,8 +226,8 @@ # exact OT problem and the OT plan is not sparse. # %% -# Solve the Regularized OT problem with other regularizations -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Quadratic Regularized OT +# ~~~~~~~~~~~~~~~~~~~~~~~~~ # # Use quadratic regularization @@ -266,3 +262,259 @@ # quadratic regularization is another common choice for regularized OT and # preserves the sparsity of the OT plan. 
# +# Solve the Regularized OT problem with user-defined regularization +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# + + +# Define a custom regularization function +def f(G): + return 0.5 * np.sum(G**2) + + +def df(G): + return G + + +P_reg = ot.solve_sample(x1, x2, a, b, reg=1e2, reg_type=(f, df)).plan + + +# sphinx_gallery_start_ignore +pl.figure(1, (3, 3)) +plot2D_samples_mat(x1, x2, P_reg) +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) +pl.title("Custom reg plan") +pl.show() +# sphinx_gallery_end_ignore +# %% +# +# .. note:: +# The examples above use the new API of POT. The old API is still available +# and and the entropic OT plan and loss can be computed with the +# :func:`ot.sinkhorn` # and :func:`ot.sinkhorn2` functions as below: +# +# .. code-block:: python +# +# Gs = ot.sinkhorn(a, b, C, reg=1e-1) +# loss_sink = ot.sinkhorn2(a, b, C, reg=1e-1) +# +# For quadratic regularization, the :func:`ot.smooth.smooth_ot_dual` function +# can be used to compute the solution of the regularized OT problem. For +# user-defined regularization, the :func:`ot.optim.cg` function can be used +# to solve the regularized OT problem with Conditional Gradient algorithm. +# +# Unbalanced and Partial Optimal Transport +# ---------------------------- +# +# Solve the Unbalanced OT problem +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Unbalanced OT relaxes the marginal constraints and allows for the source and +# target total weights to be different. The :func:`ot.solve_sample` function can be +# used to solve the unbalanced OT problem by setting the marginal penalization +# :code:`unbalanced` parameter to a positive value. 
+# + +# Solve the unbalanced OT problem with KL penalization +P_unb_kl = ot.solve_sample(x1, x2, a, b, unbalanced=5e-2).plan + +# Unbalanced with KL penalization and KL regularization +P_unb_kl_reg = ot.solve_sample( + x1, x2, a, b, unbalanced=5e-2, reg=1e-1 +).plan # also regularized + +# Unbalanced with L2 penalization +P_unb_l2 = ot.solve_sample(x1, x2, a, b, unbalanced=7e1, unbalanced_type="L2").plan + +# sphinx_gallery_start_ignore +pl.figure(1, (9, 3)) + +pl.subplot(1, 3, 1) +plot2D_samples_mat(x1, x2, P_unb_kl) +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) +pl.title("Unbalanced KL plan") + +pl.subplot(1, 3, 2) +plot2D_samples_mat(x1, x2, P_unb_kl_reg) +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) +pl.title("Unbalanced KL + reg plan") + +pl.subplot(1, 3, 3) +plot2D_samples_mat(x1, x2, P_unb_l2) +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) +pl.title("Unbalanced L2 plan") + +pl.show() +# sphinx_gallery_end_ignore +# %% +# .. note:: +# Solving the unbalanced OT problem with the old API can be done with the +# :func:`ot.unbalanced.sinkhorn_unbalanced` function as below: +# +# .. 
code-block:: python +# +# G_unb_kl = ot.unbalanced.sinkhorn_unbalanced(a, b, C, reg=reg, reg_m=unbalanced) +# +# Partial Optimal Transport +# ~~~~~~~~~~~~~~~~~~~~~~~~~ +# + +# Solve the Unbalanced OT problem with TV penalization (equivalent) +P_part_pen = ot.solve_sample(x1, x2, a, b, unbalanced=3, unbalanced_type="TV").plan + +# Solve the Partial OT problem with mass constraints (only old API) +P_part_const = ot.partial.partial_wasserstein(a, b, C, m=0.5) # 50% mass transported + +# sphinx_gallery_start_ignore +pl.figure(1, (6, 3)) + +pl.subplot(1, 2, 1) +plot2D_samples_mat(x1, x2, P_part_pen) +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) +pl.title("Partial (Unb. TV) plan") + +pl.subplot(1, 2, 2) +plot2D_samples_mat(x1, x2, P_part_const) +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) +pl.title("Partial 50% mass plan") +pl.show() + +# sphinx_gallery_end_ignore +# %% +# +# Gromov-Wasserstein and Fused Gromov-Wasserstein +# ----------------------------------------------- +# +# Solve the Gromov-Wasserstein problem +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# The Gromov-Wasserstein distance is a similarity measure between metric +# measure spaces. So it does not require the samples to be in the same space. 
+# + +# Define the metric cost matrices in each spaces + +C1 = ot.dist(x1, x1, metric="sqeuclidean") +C2 = ot.dist(x2, x2, metric="sqeuclidean") + +C1 /= C1.max() +C2 /= C2.max() + +# Solve the Gromov-Wasserstein problem +sol_gw = ot.solve_gromov(C1, C2, a=a, b=b) +P_gw = sol_gw.plan +loss_gw = sol_gw.value +loss_gw_linear = sol_gw.value_linear # linear part of loss +loss_gw_quad = sol_gw.value_quad # quadratic part of loss + +# Solve the Entropic Gromov-Wasserstein problem +P_egw = ot.solve_gromov(C1, C2, a=a, b=b, reg=1e-2).plan + +# sphinx_gallery_start_ignore +pl.figure(1, (6, 3)) + +pl.subplot(1, 2, 1) +plot2D_samples_mat(x1, x2, P_gw) +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) +pl.title("GW plan") + +pl.subplot(1, 2, 2) +plot2D_samples_mat(x1, x2, P_egw) +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) +pl.title("Entropic GW plan") +pl.show() +# sphinx_gallery_end_ignore +# %% +# .. note:: +# The Gromov-Wasserstein problem can be solved with the old API using the +# :func:`ot.gromov.gromov_wasserstein` function and the Entropic +# Gromov-Wasserstein problem can be solved with the +# :func:`ot.gromov.entropic_gromov_wasserstein` function. +# +# .. 
code-block:: python +# +# P_gw = ot.gromov.gromov_wasserstein(C1, C2, a, b) +# P_egw = ot.gromov.entropic_gromov_wasserstein(C1, C2, a, b, epsilon=reg) +# +# loss_gw = ot.gromov.gromov_wasserstein2(C1, C2, a, b) +# loss_egw = ot.gromov.entropic_gromov_wasserstein2(C1, C2, a, b, epsilon=reg) +# +# Fused Gromov-Wasserstein +# ~~~~~~~~~~~~~~~~~~~~~~~~~ +# + +# Cost matrix +M = C / np.max(C) + +# Solve FGW problem with alpha=0.1 +P_fgw = ot.solve_gromov(C1, C2, M, a=a, b=b, alpha=0.1).plan # M is cost across spaces + +# Solve entropic FGW problem with alpha=0.1 +P_efgw = ot.solve_gromov(C1, C2, M, a=a, b=b, alpha=0.1, reg=1e-3).plan + +# sphinx_gallery_start_ignore +pl.figure(1, (6, 3)) + +pl.subplot(1, 2, 1) +plot2D_samples_mat(x1, x2, P_fgw) +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) +pl.title("FGW plan") + +pl.subplot(1, 2, 2) +plot2D_samples_mat(x1, x2, P_efgw) +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) +pl.title("Entropic FGW plan") +pl.show() + +# sphinx_gallery_end_ignore +# %% +# .. note:: +# The Fused Gromov-Wasserstein problem can be solved with the old API using +# the :func:`ot.gromov.fused_gromov_wasserstein` function and the Entropic +# Fused Gromov-Wasserstein problem can be solved with the +# :func:`ot.gromov.entropic_fused_gromov_wasserstein` function. +# +# .. 
code-block:: python +# +# P_fgw = ot.gromov.fused_gromov_wasserstein(C1, C2, M, a, b, alpha=0.1) +# P_efgw = ot.gromov.entropic_fused_gromov_wasserstein(C1, C2, M, a, b, alpha=0.1, epsilon=reg) +# +# loss_fgw = ot.gromov.fused_gromov_wasserstein2(C1, C2, M, a, b, alpha=0.1) +# loss_efgw = ot.gromov.entropic_fused_gromov_wasserstein2(C1, C2, M, a, b, alpha=0.1, epsilon=reg) +# +# Unbalanced Gromov-Wasserstein +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# + +# # Solve the Unbalanced Gromov-Wasserstein problem +# P_gw_unb = ot.solve_gromov(C1, C2, a=a, b=b, unbalanced=1e-2).plan + +# # Solve the Unbalanced Entropic Gromov-Wasserstein problem +# P_egw_unb = ot.solve_gromov(C1, C2, a=a, b=b, reg=1e-2, reg_type='KL', unbalanced=1e-2).plan + +# # sphinx_gallery_start_ignore +# pl.figure(1, (6, 3)) + +# pl.subplot(1, 2, 1) +# plot2D_samples_mat(x1, x2, P_gw_unb) +# pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) +# pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) +# pl.title("Unbalanced GW plan") + +# pl.subplot(1, 2, 2) +# plot2D_samples_mat(x1, x2, P_egw_unb) +# pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) +# pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) +# pl.title("Unbalanced Entropic GW plan") +# pl.show() +# # sphinx_gallery_end_ignore From 723778eb432e8f0ad23b7f324c8564ba5c42dc0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Mon, 24 Mar 2025 13:31:58 +0100 Subject: [PATCH 06/16] comment unbalanced Gromov --- examples/plot_quickstart_guide.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/examples/plot_quickstart_guide.py b/examples/plot_quickstart_guide.py index f038ea5c6..dbc5477f3 100644 --- a/examples/plot_quickstart_guide.py +++ b/examples/plot_quickstart_guide.py @@ -492,25 +492,26 @@ def df(G): # loss_fgw = ot.gromov.fused_gromov_wasserstein2(C1, C2, M, a, b, alpha=0.1) # loss_efgw = ot.gromov.entropic_fused_gromov_wasserstein2(C1, 
C2, M, a, b, alpha=0.1, epsilon=reg) # -# Unbalanced Gromov-Wasserstein -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# +# # Unbalanced Gromov-Wasserstein +# # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# # +# # # Solve the Unbalanced Gromov-Wasserstein problem # P_gw_unb = ot.solve_gromov(C1, C2, a=a, b=b, unbalanced=1e-2).plan - +# # # Solve the Unbalanced Entropic Gromov-Wasserstein problem # P_egw_unb = ot.solve_gromov(C1, C2, a=a, b=b, reg=1e-2, reg_type='KL', unbalanced=1e-2).plan - +# # # sphinx_gallery_start_ignore # pl.figure(1, (6, 3)) - +# # pl.subplot(1, 2, 1) # plot2D_samples_mat(x1, x2, P_gw_unb) # pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) # pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) # pl.title("Unbalanced GW plan") - +# # pl.subplot(1, 2, 2) # plot2D_samples_mat(x1, x2, P_egw_unb) # pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) From 04ffaa4d6bc8e9430d1c424655531f92c844be26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Mon, 24 Mar 2025 13:33:16 +0100 Subject: [PATCH 07/16] rename section --- examples/plot_quickstart_guide.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/plot_quickstart_guide.py b/examples/plot_quickstart_guide.py index dbc5477f3..34152601f 100644 --- a/examples/plot_quickstart_guide.py +++ b/examples/plot_quickstart_guide.py @@ -304,7 +304,7 @@ def df(G): # user-defined regularization, the :func:`ot.optim.cg` function can be used # to solve the regularized OT problem with Conditional Gradient algorithm. 
# -# Unbalanced and Partial Optimal Transport +# Unbalanced and Partial OT # ---------------------------- # # Solve the Unbalanced OT problem @@ -388,8 +388,8 @@ def df(G): # sphinx_gallery_end_ignore # %% # -# Gromov-Wasserstein and Fused Gromov-Wasserstein -# ----------------------------------------------- +# Gromov-Wasserstein (GW) and Fused GW +# ------------------------------------- # # Solve the Gromov-Wasserstein problem # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ From db9942f6e59ab584728338dfdf92b509149f8759 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Mon, 24 Mar 2025 17:26:18 +0100 Subject: [PATCH 08/16] first shot done --- examples/plot_quickstart_guide.py | 125 +++++++++++++++++++++++++++++- ot/bregman/_empirical.py | 4 + 2 files changed, 128 insertions(+), 1 deletion(-) diff --git a/examples/plot_quickstart_guide.py b/examples/plot_quickstart_guide.py index 34152601f..2cf52ea32 100644 --- a/examples/plot_quickstart_guide.py +++ b/examples/plot_quickstart_guide.py @@ -154,6 +154,7 @@ # Solve the OT problem with the custom cost matrix P_city = ot.solve(C).plan +# the parameters a and b are not provided so uniform weights are assumed # Compute the OT loss (equivalent to ot.solve(C).value) loss_city = np.sum(P_city * C) @@ -177,6 +178,10 @@ # But the same can be done with the :func:`ot.solve_sample` function by passing # :code:`metric='cityblock'` as argument. # +# The cost matrix can be computed with the :func:`ot.dist` function which +# computes the pairwise distance between two sets of samples or can be provided +# directly as a matrix by the user when no samples are available. +# # .. note:: # The examples above use the new API of POT. 
The old API is still available # and and OT plan and loss can be computed with the :func:`ot.emd` and @@ -388,7 +393,7 @@ def df(G): # sphinx_gallery_end_ignore # %% # -# Gromov-Wasserstein (GW) and Fused GW +# Gromov-Wasserstein and Fused GW # ------------------------------------- # # Solve the Gromov-Wasserstein problem @@ -519,3 +524,121 @@ def df(G): # pl.title("Unbalanced Entropic GW plan") # pl.show() # # sphinx_gallery_end_ignore +# %% +# +# Large scale OT +# -------------- +# +# We discuss here strategies to solve large scale OT problems using approximations +# of the exact OT problem. +# +# Large scale Sinkhorn +# ~~~~~~~~~~~~~~~~~~~~ +# +# When having samples with a large number of points, the Sinkhorn algorithm can +# be implemented in a Lazy version which is more memory efficient and avoids +# the computation of the :math:`n \times m` cost matrix. +# +# POT provides two implementation of the lazy Sinkhorn algorithm that return their +# results in a lazy form of type :class:`ot.utils.LazyTensor`. This object can be +# used to compute the loss or the OT plan in a lazy way or to recover its values +# in a dense form. 
+# + +# Solve the Sinkhorn problem in a lazy way +sol = ot.solve_sample(x1, x2, a, b, reg=1e-1, lazy=True) + +# Solve the sinkhoorn in a lazy way with geomloss +sol_geo = ot.solve_sample(x1, x2, a, b, reg=1e-1, method="geomloss", lazy=True) + +# get the OT lazy plan and loss +P_sink_lazy = sol.lazy_plan + +# recover values for Lazy plan +P12 = P_sink_lazy[1, 2] +P1dots = P_sink_lazy[1, :] +P_sink_lazy_dense = P_sink_lazy[ + : +] # convert to dense matrix !!warning this can be memory consuming + +# sphinx_gallery_start_ignore +pl.figure(1, (3, 3)) +plot2D_samples_mat(x1, x2, P_sink_lazy_dense) +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) +pl.title("Lazy Sinkhorn OT plan") +pl.show() + +pl.figure(2, (3, 1.7)) +pl.imshow(P_sink_lazy_dense, cmap="Greys") +pl.title("Lazy Sinkhorn OT plan") +pl.show() + +# sphinx_gallery_end_ignore +# +# %% +# +# the first example shows how to solve the Sinkhorn problem in a lazy way with +# the default POT implementation. The second example shows how to solve the +# Sinkhorn problem in a lazy way with the PyKeops/Geomloss implementation that provides +# a very efficient way to solve large scale problems on low dimensionality +# samples. +# +# Factored and Low rank OT +# ------------------------ +# +# The Sinkhorn algorithm can be implemented in a low rank version that +# approximates the OT plan with a low rank matrix. This can be useful to +# accelerate the computation of the OT plan for large scale problems. +# A similar non-regularized version of low rank factorization is also available. 
+# + +# Solve the Factored OT problem (use lazy=True for large scale) +P_fact = ot.solve_sample(x1, x2, a, b, method="factored", rank=8).plan + +P_lowrank = ot.solve_sample(x1, x2, a, b, reg=0.1, method="lowrank", rank=8).plan + +# sphinx_gallery_start_ignore +pl.figure(1, (6, 3)) + +pl.subplot(1, 2, 1) +plot2D_samples_mat(x1, x2, P_fact) +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) +pl.title("Factored OT plan") + +pl.subplot(1, 2, 2) +plot2D_samples_mat(x1, x2, P_lowrank) +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) +pl.title("Low rank OT plan") +pl.show() + +pl.figure(2, (6, 1.7)) + +pl.subplot(1, 2, 1) +pl.imshow(P_fact, cmap="Greys") +pl.title("Factored OT plan") + +pl.subplot(1, 2, 2) +pl.imshow(P_lowrank, cmap="Greys") +pl.title("Low rank OT plan") +pl.show() + +# sphinx_gallery_end_ignore + +# %% +# +# Gaussian OT with Bures-Wasserstein +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# The Gaussian Wasserstein or Bures-Wasserstein distance is the Wasserstein distance +# between Gaussian distributions. It can be used as an approximation of the +# Wasserstein distance between empirical distributions by estimating the +# covariance matrices of the samples. 
+# + +# Compute the Bures-Wasserstein distance +bw_value = ot.solve_sample(x1, x2, a, b, method="gaussian").value + +print(f"Bures-Wasserstein distance = {bw_value:1.3f}") diff --git a/ot/bregman/_empirical.py b/ot/bregman/_empirical.py index 055a07ef3..e010aa0c7 100644 --- a/ot/bregman/_empirical.py +++ b/ot/bregman/_empirical.py @@ -53,6 +53,10 @@ def get_sinkhorn_lazytensor(X_a, X_b, f, g, metric="sqeuclidean", reg=1e-1, nx=N shape = (X_a.shape[0], X_b.shape[0]) def func(i, j, X_a, X_b, f, g, metric, reg): + if isinstance(i, int): + i = slice(i, i + 1) + if isinstance(j, int): + j = slice(j, j + 1) C = dist(X_a[i], X_b[j], metric=metric) return nx.exp(f[i, None] + g[None, j] - C / reg) From 1d74e71a6ca5f1babeb27daab08bcea621eebb88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 25 Mar 2025 12:43:35 +0100 Subject: [PATCH 09/16] better verison quickstart guide --- examples/plot_quickstart_guide.py | 217 ++++++++++++++++++++++-------- 1 file changed, 159 insertions(+), 58 deletions(-) diff --git a/examples/plot_quickstart_guide.py b/examples/plot_quickstart_guide.py index 2cf52ea32..14a9916c0 100644 --- a/examples/plot_quickstart_guide.py +++ b/examples/plot_quickstart_guide.py @@ -31,8 +31,8 @@ # %% -# Data generation -# -------------- +# 2D data example +# --------------- # # We first generate two sets of samples in 2D that 25 and 50 # samples respectively located on circles. 
The weights of the samples are @@ -53,15 +53,34 @@ x2 = np.random.randn(n2, 2) x2 /= np.sqrt(np.sum(x2**2, 1, keepdims=True)) / 4 +# Compute the cost matrix +C = ot.dist(x1, x2) # Squared Euclidean cost matrix by default + # sphinx_gallery_start_ignore style = {"markeredgecolor": "k"} + +def plot_plan(P=None, title="", axis=True): + plot2D_samples_mat(x1, x2, P) + pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) + pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) + if not axis: + pl.axis("off") + pl.title(title) + + pl.figure(1, (4, 4)) pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) pl.legend(loc=0) pl.title("Source and target distributions") pl.show() + +pl.figure(2, (3.5, 1.7)) +pl.imshow(C) +pl.colorbar() +pl.title("Cost matrix C") + # sphinx_gallery_end_ignore # %% @@ -139,8 +158,8 @@ # %% -# Solve the Optimal Transport problem with a custom cost matrix -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Optimal Transport problem with a custom cost matrix +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # The cost matrix can be customized by passing it to the more general # :func:`ot.solve` function. The cost matrix should be a matrix of size @@ -150,14 +169,17 @@ # In this example, we use the Citybloc distance as the cost matrix. 
# Compute the cost matrix -C = ot.dist(x1, x2, metric="cityblock") +C_city = ot.dist(x1, x2, metric="cityblock") # Solve the OT problem with the custom cost matrix -P_city = ot.solve(C).plan +sol = ot.solve(C_city) # the parameters a and b are not provided so uniform weights are assumed +P_city = sol.plan +# on empirical data the same can be done with ot.solve_sample : +# sol = ot.solve_sample(x1, x2, metric='cityblock') # Compute the OT loss (equivalent to ot.solve(C).value) -loss_city = np.sum(P_city * C) +loss_city = sol.value # same as np.sum(P_city * C) # sphinx_gallery_start_ignore pl.figure(1, (3, 3)) @@ -192,9 +214,7 @@ # P = ot.emd(a, b, C) # loss = ot.emd2(a, b, C) # same as np.sum(P*C) but differentiable wrt a/b # - - -# %% +# # Sinkhorn and Regularized OT # --------------------------- # @@ -229,8 +249,7 @@ # The Sinkhorn algorithm can be faster than the exact OT solver for large # regularization strength but the solution is only an approximation of the # exact OT problem and the OT plan is not sparse. 
- -# %% +# # Quadratic Regularized OT # ~~~~~~~~~~~~~~~~~~~~~~~~~ # @@ -281,8 +300,7 @@ def df(G): return G -P_reg = ot.solve_sample(x1, x2, a, b, reg=1e2, reg_type=(f, df)).plan - +P_reg = ot.solve_sample(x1, x2, a, b, reg=3, reg_type=(f, df)).plan # sphinx_gallery_start_ignore pl.figure(1, (3, 3)) @@ -312,7 +330,7 @@ def df(G): # Unbalanced and Partial OT # ---------------------------- # -# Solve the Unbalanced OT problem +# Unbalanced Optimal Transport # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # Unbalanced OT relaxes the marginal constraints and allows for the source and @@ -393,10 +411,10 @@ def df(G): # sphinx_gallery_end_ignore # %% # -# Gromov-Wasserstein and Fused GW +# Gromov-Wasserstein and Fused Gromov-Wasserstein # ------------------------------------- # -# Solve the Gromov-Wasserstein problem +# Gromov-Wasserstein and Entropic GW # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # The Gromov-Wasserstein distance is a similarity measure between metric @@ -414,8 +432,7 @@ def df(G): # Solve the Gromov-Wasserstein problem sol_gw = ot.solve_gromov(C1, C2, a=a, b=b) P_gw = sol_gw.plan -loss_gw = sol_gw.value -loss_gw_linear = sol_gw.value_linear # linear part of loss +loss_gw = sol_gw.value # quadratic + reg if reg>0 loss_gw_quad = sol_gw.value_quad # quadratic part of loss # Solve the Entropic Gromov-Wasserstein problem @@ -460,9 +477,13 @@ def df(G): M = C / np.max(C) # Solve FGW problem with alpha=0.1 -P_fgw = ot.solve_gromov(C1, C2, M, a=a, b=b, alpha=0.1).plan # C is cost across spaces +sol = ot.solve_gromov(C1, C2, M, a=a, b=b, alpha=0.1) +P_fgw = sol.plan +loss_fgw = sol.value +loss_fgw_linear = sol.value_linear # linear part of loss (wrt M) +loss_fgw_quad = sol.value_quad # quadratic part of loss (wrt C1 and C2) -# SOlve entropic FGW problem with alpha=0.1 +# Solve entropic FGW problem with alpha=0.1 P_efgw = ot.solve_gromov(C1, C2, M, a=a, b=b, alpha=0.1, reg=1e-3).plan # sphinx_gallery_start_ignore @@ -497,35 +518,6 @@ def df(G): # loss_fgw = 
ot.gromov.fused_gromov_wasserstein2(C1, C2, M, a, b, alpha=0.1) # loss_efgw = ot.gromov.entropic_fused_gromov_wasserstein2(C1, C2, M, a, b, alpha=0.1, epsilon=reg) # - -# # Unbalanced Gromov-Wasserstein -# # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# # -# -# # Solve the Unbalanced Gromov-Wasserstein problem -# P_gw_unb = ot.solve_gromov(C1, C2, a=a, b=b, unbalanced=1e-2).plan -# -# # Solve the Unbalanced Entropic Gromov-Wasserstein problem -# P_egw_unb = ot.solve_gromov(C1, C2, a=a, b=b, reg=1e-2, reg_type='KL', unbalanced=1e-2).plan -# -# # sphinx_gallery_start_ignore -# pl.figure(1, (6, 3)) -# -# pl.subplot(1, 2, 1) -# plot2D_samples_mat(x1, x2, P_gw_unb) -# pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) -# pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) -# pl.title("Unbalanced GW plan") -# -# pl.subplot(1, 2, 2) -# plot2D_samples_mat(x1, x2, P_egw_unb) -# pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) -# pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) -# pl.title("Unbalanced Entropic GW plan") -# pl.show() -# # sphinx_gallery_end_ignore -# %% -# # Large scale OT # -------------- # @@ -557,9 +549,8 @@ def df(G): # recover values for Lazy plan P12 = P_sink_lazy[1, 2] P1dots = P_sink_lazy[1, :] -P_sink_lazy_dense = P_sink_lazy[ - : -] # convert to dense matrix !!warning this can be memory consuming +# convert to dense matrix !!warning this can be memory consuming +P_sink_lazy_dense = P_sink_lazy[:] # sphinx_gallery_start_ignore pl.figure(1, (3, 3)) @@ -575,8 +566,13 @@ def df(G): pl.show() # sphinx_gallery_end_ignore -# # %% +# .. note:: +# The lazy Sinkhorn algorithm can be found in the old API with the +# :func:`ot.bregman.empirical_sinkhorn` function with parameter +# :code:`lazy=True`. Similarly the geoloss implementation is available +# with the :func:`ot.bregman.empirical_sinkhorn2_geomloss`. 
+# # # the first example shows how to solve the Sinkhorn problem in a lazy way with # the default POT implementation. The second example shows how to solve the @@ -585,7 +581,7 @@ def df(G): # samples. # # Factored and Low rank OT -# ------------------------ +# ~~~~~~~~~~~~~~~~~~~~~~~~ # # The Sinkhorn algorithm can be implemented in a low rank version that # approximates the OT plan with a low rank matrix. This can be useful to @@ -594,9 +590,9 @@ def df(G): # # Solve the Factored OT problem (use lazy=True for large scale) -P_fact = ot.solve_sample(x1, x2, a, b, method="factored", rank=8).plan +P_fact = ot.solve_sample(x1, x2, a, b, method="factored", rank=15).plan -P_lowrank = ot.solve_sample(x1, x2, a, b, reg=0.1, method="lowrank", rank=8).plan +P_lowrank = ot.solve_sample(x1, x2, a, b, reg=0.1, method="lowrank", rank=10).plan # sphinx_gallery_start_ignore pl.figure(1, (6, 3)) @@ -626,8 +622,11 @@ def df(G): pl.show() # sphinx_gallery_end_ignore - # %% +# .. note:: +# The factored OT problem can be solved with the old API using the +# :func:`ot.factored.factored_optimal_transport` function and the low rank +# OT problem can be solved with the :func:`ot.lowrank.lowrank_sinkhorn` function. # # Gaussian OT with Bures-Wasserstein # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -641,4 +640,106 @@ def df(G): # Compute the Bures-Wasserstein distance bw_value = ot.solve_sample(x1, x2, a, b, method="gaussian").value +print(f"Exact OT loss = {loss:1.3f}") print(f"Bures-Wasserstein distance = {bw_value:1.3f}") + +# %% +# .. note:: +# The Gaussian Wasserstein problem can be solved with the old API using the +# :func:`ot.gaussian.empirical_bures_wasserstein_distance` function. +# +# All OT plans +# ------------ +# +# The figure below shows all the OT plans computed in this example. +# The color intensity represents the amount of mass transported +# between the samples. 
+# + +# sphinx_gallery_start_ignore +pl.figure(1, (9, 13)) + + +pl.subplot(4, 3, 1) +plot2D_samples_mat(x1, x2, P) +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) +pl.axis("off") +pl.title("OT plan") + +pl.subplot(4, 3, 2) +plot2D_samples_mat(x1, x2, P_sink) +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) +pl.axis("off") +pl.title("Sinkhorn plan") + +pl.subplot(4, 3, 3) +plot2D_samples_mat(x1, x2, P_quad) +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) +pl.axis("off") +pl.title("Quadratic reg. plan") + +pl.subplot(4, 3, 4) +plot2D_samples_mat(x1, x2, P_unb_kl) +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) +pl.axis("off") +pl.title("Unbalanced KL plan") + +pl.subplot(4, 3, 5) +plot2D_samples_mat(x1, x2, P_unb_kl_reg) +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) +pl.axis("off") +pl.title("Unbalanced KL + reg plan") + +pl.subplot(4, 3, 6) +plot2D_samples_mat(x1, x2, P_unb_l2) +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) +pl.axis("off") +pl.title("Unbalanced L2 plan") + +pl.subplot(4, 3, 7) +plot2D_samples_mat(x1, x2, P_part_const) +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) +pl.axis("off") +pl.title("Partial 50% mass plan") + +pl.subplot(4, 3, 8) +plot2D_samples_mat(x1, x2, P_fact) +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) +pl.axis("off") +pl.title("Factored 
OT plan") + +pl.subplot(4, 3, 9) +plot2D_samples_mat(x1, x2, P_lowrank) +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) +pl.axis("off") +pl.title("Low rank OT plan") + +pl.subplot(4, 3, 10) +plot2D_samples_mat(x1, x2, P_gw) +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) +pl.axis("off") +pl.title("GW plan") + +pl.subplot(4, 3, 11) +plot2D_samples_mat(x1, x2, P_egw) +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) +pl.axis("off") +pl.title("Entropic GW plan") + +pl.subplot(4, 3, 12) +plot2D_samples_mat(x1, x2, P_fgw) +pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) +pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) +pl.axis("off") +pl.title("Fused GW plan") From 8b1f8beeb5361f6f14282003e13a96e6258ac6f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 25 Mar 2025 12:44:19 +0100 Subject: [PATCH 10/16] better verison quickstart guide --- examples/plot_quickstart_guide.py | 178 ++++++------------------------ 1 file changed, 36 insertions(+), 142 deletions(-) diff --git a/examples/plot_quickstart_guide.py b/examples/plot_quickstart_guide.py index 14a9916c0..e31ca1807 100644 --- a/examples/plot_quickstart_guide.py +++ b/examples/plot_quickstart_guide.py @@ -61,7 +61,8 @@ def plot_plan(P=None, title="", axis=True): - plot2D_samples_mat(x1, x2, P) + if P is not None: + plot2D_samples_mat(x1, x2, P) pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) if not axis: @@ -70,10 +71,8 @@ def plot_plan(P=None, title="", axis=True): pl.figure(1, (4, 4)) -pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) -pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", 
**style) +plot_plan(title="Source and target distributions") pl.legend(loc=0) -pl.title("Source and target distributions") pl.show() pl.figure(2, (3.5, 1.7)) @@ -114,10 +113,7 @@ def plot_plan(P=None, title="", axis=True): pl.figure(1, (8, 4)) pl.subplot(1, 2, 1) -plot2D_samples_mat(x1, x2, P) -pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) -pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) -pl.title("OT plan P loss={:.3f}".format(loss)) +plot_plan(P, "OT plan P loss={:.3f}".format(loss)) pl.subplot(1, 2, 2) pl.scatter(x1[:, 0], x1[:, 1], c=alpha, cmap="viridis", edgecolors="k") @@ -183,10 +179,7 @@ def plot_plan(P=None, title="", axis=True): # sphinx_gallery_start_ignore pl.figure(1, (3, 3)) -plot2D_samples_mat(x1, x2, P) -pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) -pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) -pl.title("OT plan (Citybloc) loss={:.3f}".format(loss_city)) +plot_plan(P_city, "OT plan (Citybloc) loss={:.3f}".format(loss_city)) pl.figure(2, (3, 1.7)) pl.imshow(P_city, cmap="Greys") @@ -232,10 +225,7 @@ def plot_plan(P=None, title="", axis=True): # sphinx_gallery_start_ignore pl.figure(1, (3, 3)) -plot2D_samples_mat(x1, x2, P_sink) -pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) -pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) -pl.title("Sinkhorn OT plan loss={:.3f}".format(loss_sink)) +plot_plan(P_sink, "Sinkhorn OT plan loss={:.3f}".format(loss_sink)) pl.show() pl.figure(2, (3, 1.7)) @@ -263,22 +253,13 @@ def plot_plan(P=None, title="", axis=True): pl.figure(1, (9, 3)) pl.subplot(1, 3, 1) -plot2D_samples_mat(x1, x2, P) -pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) -pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) -pl.title("OT plan loss={:.3f}".format(loss)) +plot_plan(P, "OT plan loss={:.3f}".format(loss)) pl.subplot(1, 3, 2) -plot2D_samples_mat(x1, x2, P_sink) -pl.plot(x1[:, 0], x1[:, 1], 
"ob", label="Source samples", **style) -pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) -pl.title("Sinkhorn plan loss={:.3f}".format(loss_sink)) +plot_plan(P_sink, "Sinkhorn plan loss={:.3f}".format(loss_sink)) pl.subplot(1, 3, 3) -plot2D_samples_mat(x1, x2, P_quad) -pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) -pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) -pl.title("Quadratic plan loss={:.3f}".format(loss_quad)) +plot_plan(P_quad, "Quadratic reg plan loss={:.3f}".format(loss_quad)) pl.show() # sphinx_gallery_end_ignore # %% @@ -304,10 +285,7 @@ def df(G): # sphinx_gallery_start_ignore pl.figure(1, (3, 3)) -plot2D_samples_mat(x1, x2, P_reg) -pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) -pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) -pl.title("Custom reg plan") +plot_plan(P_reg, "User-defined reg plan") pl.show() # sphinx_gallery_end_ignore # %% @@ -354,23 +332,13 @@ def df(G): pl.figure(1, (9, 3)) pl.subplot(1, 3, 1) -plot2D_samples_mat(x1, x2, P_unb_kl) -pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) -pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) -pl.title("Unbalanced KL plan") +plot_plan(P_unb_kl, "Unbalanced KL plan") pl.subplot(1, 3, 2) -plot2D_samples_mat(x1, x2, P_unb_kl_reg) -pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) -pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) -pl.title("Unbalanced KL + reg plan") +plot_plan(P_unb_kl_reg, "Unbalanced KL + reg plan") pl.subplot(1, 3, 3) -plot2D_samples_mat(x1, x2, P_unb_l2) -pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) -pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) -pl.title("Unbalanced L2 plan") - +plot_plan(P_unb_l2, "Unbalanced L2 plan") pl.show() # sphinx_gallery_end_ignore # %% @@ -396,16 +364,10 @@ def df(G): pl.figure(1, (6, 3)) pl.subplot(1, 2, 1) -plot2D_samples_mat(x1, x2, 
P_part_pen) -pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) -pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) -pl.title("Partial (Unb. TV) plan") +plot_plan(P_part_pen, "Partial TV plan") pl.subplot(1, 2, 2) -plot2D_samples_mat(x1, x2, P_part_const) -pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) -pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) -pl.title("Partial 50% mass plan") +plot_plan(P_part_const, "Partial 50% mass plan") pl.show() # sphinx_gallery_end_ignore @@ -442,16 +404,10 @@ def df(G): pl.figure(1, (6, 3)) pl.subplot(1, 2, 1) -plot2D_samples_mat(x1, x2, P_gw) -pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) -pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) -pl.title("GW plan") +plot_plan(P_gw, "GW plan") pl.subplot(1, 2, 2) -plot2D_samples_mat(x1, x2, P_egw) -pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) -pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) -pl.title("Entropic GW plan") +plot_plan(P_egw, "Entropic GW plan") pl.show() # sphinx_gallery_end_ignore # %% @@ -490,16 +446,10 @@ def df(G): pl.figure(1, (6, 3)) pl.subplot(1, 2, 1) -plot2D_samples_mat(x1, x2, P_fgw) -pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) -pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) -pl.title("FGW plan") +plot_plan(P_fgw, "FGW plan") pl.subplot(1, 2, 2) -plot2D_samples_mat(x1, x2, P_efgw) -pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) -pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) -pl.title("Entropic FGW plan") +plot_plan(P_efgw, "Entropic FGW plan") pl.show() # sphinx_gallery_end_ignore @@ -554,10 +504,7 @@ def df(G): # sphinx_gallery_start_ignore pl.figure(1, (3, 3)) -plot2D_samples_mat(x1, x2, P_sink_lazy_dense) -pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) -pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", 
**style) -pl.title("Lazy Sinkhorn OT plan") +plot_plan(P_sink_lazy_dense, "Lazy Sinkhorn OT plan") pl.show() pl.figure(2, (3, 1.7)) @@ -598,16 +545,10 @@ def df(G): pl.figure(1, (6, 3)) pl.subplot(1, 2, 1) -plot2D_samples_mat(x1, x2, P_fact) -pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) -pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) -pl.title("Factored OT plan") +plot_plan(P_fact, "Factored OT plan") pl.subplot(1, 2, 2) -plot2D_samples_mat(x1, x2, P_lowrank) -pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) -pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) -pl.title("Low rank OT plan") +plot_plan(P_lowrank, "Low rank OT plan") pl.show() pl.figure(2, (6, 1.7)) @@ -659,87 +600,40 @@ def df(G): # sphinx_gallery_start_ignore pl.figure(1, (9, 13)) - pl.subplot(4, 3, 1) -plot2D_samples_mat(x1, x2, P) -pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) -pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) -pl.axis("off") -pl.title("OT plan") +plot_plan(P, "OT plan", axis=False) pl.subplot(4, 3, 2) -plot2D_samples_mat(x1, x2, P_sink) -pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) -pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) -pl.axis("off") -pl.title("Sinkhorn plan") +plot_plan(P_sink, "Sinkhorn plan", axis=False) pl.subplot(4, 3, 3) -plot2D_samples_mat(x1, x2, P_quad) -pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) -pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) -pl.axis("off") -pl.title("Quadratic reg. plan") +plot_plan(P_quad, "Quadratic reg. 
plan", axis=False) pl.subplot(4, 3, 4) -plot2D_samples_mat(x1, x2, P_unb_kl) -pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) -pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) -pl.axis("off") -pl.title("Unbalanced KL plan") +plot_plan(P_unb_kl, "Unbalanced KL plan", axis=False) pl.subplot(4, 3, 5) -plot2D_samples_mat(x1, x2, P_unb_kl_reg) -pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) -pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) -pl.axis("off") -pl.title("Unbalanced KL + reg plan") +plot_plan(P_unb_kl_reg, "Unbalanced KL + reg plan", axis=False) pl.subplot(4, 3, 6) -plot2D_samples_mat(x1, x2, P_unb_l2) -pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) -pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) -pl.axis("off") -pl.title("Unbalanced L2 plan") +plot_plan(P_unb_l2, "Unbalanced L2 plan", axis=False) pl.subplot(4, 3, 7) -plot2D_samples_mat(x1, x2, P_part_const) -pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) -pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) -pl.axis("off") -pl.title("Partial 50% mass plan") +plot_plan(P_part_const, "Partial 50% mass plan", axis=False) pl.subplot(4, 3, 8) -plot2D_samples_mat(x1, x2, P_fact) -pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) -pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) -pl.axis("off") -pl.title("Factored OT plan") +plot_plan(P_fact, "Factored OT plan", axis=False) pl.subplot(4, 3, 9) -plot2D_samples_mat(x1, x2, P_lowrank) -pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) -pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) -pl.axis("off") -pl.title("Low rank OT plan") +plot_plan(P_lowrank, "Low rank OT plan", axis=False) pl.subplot(4, 3, 10) -plot2D_samples_mat(x1, x2, P_gw) -pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) -pl.plot(x2[:, 0], x2[:, 1], "or", label="Target 
samples", **style) -pl.axis("off") -pl.title("GW plan") +plot_plan(P_gw, "GW plan", axis=False) pl.subplot(4, 3, 11) -plot2D_samples_mat(x1, x2, P_egw) -pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) -pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) -pl.axis("off") -pl.title("Entropic GW plan") +plot_plan(P_egw, "Entropic GW plan", axis=False) pl.subplot(4, 3, 12) -plot2D_samples_mat(x1, x2, P_fgw) -pl.plot(x1[:, 0], x1[:, 1], "ob", label="Source samples", **style) -pl.plot(x2[:, 0], x2[:, 1], "or", label="Target samples", **style) -pl.axis("off") -pl.title("Fused GW plan") +plot_plan(P_fgw, "Fused GW plan", axis=False) + +pl.show() From 43a6f6326b9bb36c7e0124d2c9a8bae75a6a1aff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 25 Mar 2025 13:03:07 +0100 Subject: [PATCH 11/16] fix doc --- docs/source/user_guide.rst | 66 +++++++++++++++---------------- examples/plot_quickstart_guide.py | 7 ++-- 2 files changed, 37 insertions(+), 36 deletions(-) diff --git a/docs/source/user_guide.rst b/docs/source/user_guide.rst index 25e9d4213..bceea3b6f 100644 --- a/docs/source/user_guide.rst +++ b/docs/source/user_guide.rst @@ -136,12 +136,12 @@ instance the memory cost for an OT problem is always :math:`\mathcal{O}(n^2)` in memory because the cost matrix has to be computed. The exact solver in of time complexity :math:`\mathcal{O}(n^3\log(n))` and the Sinkhorn solver has been proven to be nearly :math:`\mathcal{O}(n^2)` which is still too complex for very -large scale solvers. +large scale solvers. For all the generic solvers we need to compute the cost +matrix and the OT matrix of memory size :math:`\mathcal{O}(n^2)` which can be +prohibitive for very large scale problems. - -If you need to solve OT with large number of samples, we recommend to use -entropic regularization and memory efficient implementation of Sinkhorn as -proposed in `GeomLoss `_. 
This +If you need to solve OT with large number of samples, we provide "lazy" memory efficient implementation of Sinkhorn in pure +python and using `GeomLoss `_. This implementation is compatible with Pytorch and can handle large number of samples. Another approach to estimate the Wasserstein distance for very large number of sample is to use the trick from `Wasserstein GAN @@ -193,6 +193,11 @@ that will return the optimal transport matrix :math:`\gamma^*`: # a and b are 1D histograms (sum to 1 and positive) # M is the ground cost matrix + + # unified API + T = ot.solve(M, a, b).plan # exact linear program + + # old API T = ot.emd(a, b, M) # exact linear program The method implemented for solving the OT problem is the network simplex. It is @@ -200,8 +205,7 @@ implemented in C from [1]_. It has a complexity of :math:`O(n^3)` but the solver is quite efficient and uses sparsity of the solution. - -.. minigallery:: ot.emd +.. minigallery:: ot.emd, ot.solve :add-heading: Examples of use for :any:`ot.emd` :heading-level: " @@ -226,7 +230,12 @@ It can computed from an already estimated OT matrix with # a and b are 1D histograms (sum to 1 and positive) # M is the ground cost matrix - W = ot.emd2(a, b, M) # Wasserstein distance / EMD value + + # Wasserstein distance / EMD value with unified API + W = ot.solve(M, a, b, return_matrix=False).value + + # with old API + W = ot.emd2(a, b, M) Note that the well known `Wasserstein distance `_ between distributions a and @@ -246,7 +255,7 @@ the :math:`W_1` Wasserstein distance can be done directly with :any:`ot.emd2` when providing :code:`M = ot.dist(xs, xt, metric='euclidean')` to use the Euclidean distance. -.. minigallery:: ot.emd2 +.. minigallery:: ot.emd2, ot.solve :add-heading: Examples of use for :any:`ot.emd2` :heading-level: " @@ -274,6 +283,10 @@ distributions. 
In the case when the finite sample dataset is supposed Gaussian, we provide :any:`ot.gaussian.bures_wasserstein_mapping` that returns the parameters for the Monge mapping. +All those special cases are accessible with the unified API of POT through the +function :any:`ot.solve_sample` with the parameter :code:`method` that allows to +choose the method used to solve the problem (with :code:`method='1D'` or :code:`method='gaussian'`). + Regularized Optimal Transport ----------------------------- @@ -330,13 +343,15 @@ The Sinkhorn-Knopp algorithm is implemented in :any:`ot.sinkhorn` and linear term. Note that the regularization parameter :math:`\lambda` in the equation above is given to those functions with the parameter :code:`reg`. - >>> import ot - >>> a = [.5, .5] - >>> b = [.5, .5] - >>> M = [[0., 1.], [1., 0.]] - >>> ot.sinkhorn(a, b, M, 1) - array([[ 0.36552929, 0.13447071], - [ 0.13447071, 0.36552929]]) +.. code:: python + + # unified API + P = ot.solve(M, a, b, reg=1).plan # OT Sinkhorn matrix + loss = ot.solve(M, a, b, reg=1).value # OT Sinkhorn value + + # old API + P = ot.sinkhorn(a, b, M, reg=1) # OT Sinkhorn matrix + loss = ot.sinkhorn2(a, b, M, reg=1) # OT Sinkhorn value More details about the algorithms used are given in the following note. @@ -406,13 +421,10 @@ implementations are not optimized for speed but provide a robust implementation of algorithms in [18]_ [19]_. -.. minigallery:: ot.sinkhorn - :add-heading: Examples of use for :any:`ot.sinkhorn` +.. minigallery:: ot.sinkhorn ot.sinkhorn2 + :add-heading: Examples of use for Sinkhorn algorithm :heading-level: " -.. minigallery:: ot.sinkhorn2 - :add-heading: Examples of use for :any:`ot.sinkhorn2` - :heading-level: " Other regularizations @@ -969,18 +981,6 @@ For instance, to disable TensorFlow, set `export POT_BACKEND_DISABLE_TENSORFLOW= It's important to note that the `numpy` backend cannot be disabled. 
-List of compatible modules -^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -This list will get longer for new releases and will hopefully disappear when POT -become fully implemented with the backend. - -- :any:`ot.bregman` -- :any:`ot.gromov` (some functions use CPU only solvers with copy overhead) -- :any:`ot.optim` (some functions use CPU only solvers with copy overhead) -- :any:`ot.sliced` -- :any:`ot.utils` (partial) - FAQ --- diff --git a/examples/plot_quickstart_guide.py b/examples/plot_quickstart_guide.py index e31ca1807..2f3a4aa59 100644 --- a/examples/plot_quickstart_guide.py +++ b/examples/plot_quickstart_guide.py @@ -11,7 +11,7 @@ with matplotlib is hidden (but is available in the source file of the example). .. note:: - We use here the new API of POT which is more flexible and allows to solve a wider range of problems with just a few functions. The old API is still available (the new + We use here the unified API of POT which is more flexible and allows to solve a wider range of problems with just a few functions. The old API is still available (the unified API one is a convenient wrapper around the old one) and we provide pointers to the old API when needed. @@ -198,7 +198,7 @@ def plot_plan(P=None, title="", axis=True): # directly as a matrix by the user when no samples are available. # # .. note:: -# The examples above use the new API of POT. The old API is still available +# The examples above use the unified API of POT. The old API is still available # and and OT plan and loss can be computed with the :func:`ot.emd` and # the :func:`ot.emd2` functions as below: # @@ -291,7 +291,7 @@ def df(G): # %% # # .. note:: -# The examples above use the new API of POT. The old API is still available +# The examples above use the unified API of POT. 
The old API is still available # and and the entropic OT plan and loss can be computed with the # :func:`ot.sinkhorn` # and :func:`ot.sinkhorn2` functions as below: # @@ -637,3 +637,4 @@ def df(G): plot_plan(P_fgw, "Fused GW plan", axis=False) pl.show() +# sphinx_gallery_end_ignore From c1e00dd5928f18c744900456268da226840ac4a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 25 Mar 2025 14:07:25 +0100 Subject: [PATCH 12/16] cleanup exmaple --- docs/source/user_guide.rst | 6 +++--- examples/plot_quickstart_guide.py | 16 +++++++++------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/docs/source/user_guide.rst b/docs/source/user_guide.rst index bceea3b6f..69d48528b 100644 --- a/docs/source/user_guide.rst +++ b/docs/source/user_guide.rst @@ -197,7 +197,7 @@ that will return the optimal transport matrix :math:`\gamma^*`: # unified API T = ot.solve(M, a, b).plan # exact linear program - # old API + # classical API T = ot.emd(a, b, M) # exact linear program The method implemented for solving the OT problem is the network simplex. It is @@ -234,7 +234,7 @@ It can computed from an already estimated OT matrix with # Wasserstein distance / EMD value with unified API W = ot.solve(M, a, b, return_matrix=False).value - # with old API + # with classical API W = ot.emd2(a, b, M) Note that the well known `Wasserstein distance @@ -349,7 +349,7 @@ equation above is given to those functions with the parameter :code:`reg`. 
P = ot.solve(M, a, b, reg=1).plan # OT Sinkhorn matrix loss = ot.solve(M, a, b, reg=1).value # OT Sinkhorn value - # old API + # classical API P = ot.sinkhorn(a, b, M, reg=1) # OT Sinkhorn matrix loss = ot.sinkhorn2(a, b, M, reg=1) # OT Sinkhorn value diff --git a/examples/plot_quickstart_guide.py b/examples/plot_quickstart_guide.py index 2f3a4aa59..08ecb2077 100644 --- a/examples/plot_quickstart_guide.py +++ b/examples/plot_quickstart_guide.py @@ -11,9 +11,9 @@ with matplotlib is hidden (but is available in the source file of the example). .. note:: - We use here the unified API of POT which is more flexible and allows to solve a wider range of problems with just a few functions. The old API is still available (the unified API - one is a convenient wrapper around the old one) and we provide pointers to the - old API when needed. + We use here the unified API of POT which is more flexible and allows to solve a wider range of problems with just a few functions. The classical API is still available (the unified API + one is a convenient wrapper around the classical one) and we provide pointers to the + classical API when needed. """ @@ -589,14 +589,13 @@ def df(G): # The Gaussian Wasserstein problem can be solved with the old API using the # :func:`ot.gaussian.empirical_bures_wasserstein_distance` function. # -# All OT plans -# ------------ +# Comparing all OT plans +# ---------------------- # # The figure below shows all the OT plans computed in this example. # The color intensity represents the amount of mass transported # between the samples. 
# - # sphinx_gallery_start_ignore pl.figure(1, (9, 13)) @@ -635,6 +634,9 @@ def df(G): pl.subplot(4, 3, 12) plot_plan(P_fgw, "Fused GW plan", axis=False) - pl.show() + # sphinx_gallery_end_ignore +# %% +# +# For more details on the unified and classical API of POT From 4f66eac4b58e1be7a3b2f1bdd90d5f580bc81db0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 25 Mar 2025 14:23:33 +0100 Subject: [PATCH 13/16] fix? --- examples/plot_quickstart_guide.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/plot_quickstart_guide.py b/examples/plot_quickstart_guide.py index 08ecb2077..39e5d369c 100644 --- a/examples/plot_quickstart_guide.py +++ b/examples/plot_quickstart_guide.py @@ -596,6 +596,9 @@ def df(G): # The color intensity represents the amount of mass transported # between the samples. # + +# plot all plans + # sphinx_gallery_start_ignore pl.figure(1, (9, 13)) From f7f09a9a3da40eea0f512ac82fbb5a8dadf8ec44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 25 Mar 2025 14:39:18 +0100 Subject: [PATCH 14/16] add test --- test/test_bregman.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/test_bregman.py b/test/test_bregman.py index 6c0c0e8f2..17b400306 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -1253,6 +1253,11 @@ def test_lazy_empirical_sinkhorn(nx): X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=False ) + # test lazy plan + np.testing.assert_allclose( + G_sqe[1, 1], G_lazy[1, 1], atol=1e-03 + ) # metric sqeuclidean + # check constraints np.testing.assert_allclose( sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05 From 3aa6184457f5e2ae5cfb06f5b67ad01b7724ab93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 25 Mar 2025 15:10:16 +0100 Subject: [PATCH 15/16] remove sentence --- examples/plot_quickstart_guide.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/plot_quickstart_guide.py b/examples/plot_quickstart_guide.py index 
39e5d369c..9b24e6a89 100644 --- a/examples/plot_quickstart_guide.py +++ b/examples/plot_quickstart_guide.py @@ -642,4 +642,3 @@ def df(G): # sphinx_gallery_end_ignore # %% # -# For more details on the unified and classical API of POT From e68f54e5569fad3b983dfb73c46c5f3a6be86581 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 25 Mar 2025 15:12:27 +0100 Subject: [PATCH 16/16] call it the unified vs classic API --- examples/plot_quickstart_guide.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/plot_quickstart_guide.py b/examples/plot_quickstart_guide.py index 9b24e6a89..fb86ac5bd 100644 --- a/examples/plot_quickstart_guide.py +++ b/examples/plot_quickstart_guide.py @@ -198,7 +198,7 @@ def plot_plan(P=None, title="", axis=True): # directly as a matrix by the user when no samples are available. # # .. note:: -# The examples above use the unified API of POT. The old API is still available +# The examples above use the unified API of POT. The classic API is still available # and and OT plan and loss can be computed with the :func:`ot.emd` and # the :func:`ot.emd2` functions as below: # @@ -291,7 +291,7 @@ def df(G): # %% # # .. note:: -# The examples above use the unified API of POT. The old API is still available +# The examples above use the unified API of POT. The classic API is still available # and and the entropic OT plan and loss can be computed with the # :func:`ot.sinkhorn` # and :func:`ot.sinkhorn2` functions as below: # @@ -343,7 +343,7 @@ def df(G): # sphinx_gallery_end_ignore # %% # .. note:: -# Solving the unbalanced OT problem with the old API can be done with the +# Solving the unbalanced OT problem with the classic API can be done with the # :func:`ot.unbalanced.sinkhorn_unbalanced` function as below: # # .. 
code-block:: python @@ -357,7 +357,7 @@ def df(G): # Solve the Unbalanced OT problem with TV penalization (equivalent) P_part_pen = ot.solve_sample(x1, x2, a, b, unbalanced=3, unbalanced_type="TV").plan -# Solve the Partial OT problem with mass constraints (only old API) +# Solve the Partial OT problem with mass constraints (only classic API) P_part_const = ot.partial.partial_wasserstein(a, b, C, m=0.5) # 50% mass transported # sphinx_gallery_start_ignore @@ -412,7 +412,7 @@ def df(G): # sphinx_gallery_end_ignore # %% # .. note:: -# The Gromov-Wasserstein problem can be solved with the old API using the +# The Gromov-Wasserstein problem can be solved with the classic API using the # :func:`ot.gromov.gromov_wasserstein` function and the Entropic # Gromov-Wasserstein problem can be solved with the # :func:`ot.gromov.entropic_gromov_wasserstein` function. @@ -455,7 +455,7 @@ def df(G): # sphinx_gallery_end_ignore # %% # .. note:: -# The Fused Gromov-Wasserstein problem can be solved with the old API using +# The Fused Gromov-Wasserstein problem can be solved with the classic API using # the :func:`ot.gromov.fused_gromov_wasserstein` function and the Entropic # Fused Gromov-Wasserstein problem can be solved with the # :func:`ot.gromov.entropic_fused_gromov_wasserstein` function. @@ -515,7 +515,7 @@ def df(G): # sphinx_gallery_end_ignore # %% # .. note:: -# The lazy Sinkhorn algorithm can be found in the old API with the +# The lazy Sinkhorn algorithm can be found in the classic API with the # :func:`ot.bregman.empirical_sinkhorn` function with parameter # :code:`lazy=True`. Similarly the geoloss implementation is available # with the :func:`ot.bregman.empirical_sinkhorn2_geomloss`. @@ -565,7 +565,7 @@ def df(G): # sphinx_gallery_end_ignore # %% # .. 
note:: -# The factored OT problem can be solved with the old API using the +# The factored OT problem can be solved with the classic API using the # :func:`ot.factored.factored_optimal_transport` function and the low rank # OT problem can be solved with the :func:`ot.lowrank.lowrank_sinkhorn` function. # @@ -586,7 +586,7 @@ def df(G): # %% # .. note:: -# The Gaussian Wasserstein problem can be solved with the old API using the +# The Gaussian Wasserstein problem can be solved with the classic API using the # :func:`ot.gaussian.empirical_bures_wasserstein_distance` function. # # Comparing all OT plans