Merged
38 commits
f91daac
WIP GMM
eloitanguy Jun 27, 2024
38822a1
wip test
eloitanguy Jun 27, 2024
19481f4
basic tests
eloitanguy Jun 27, 2024
29c314a
wip gmm flow example
eloitanguy Jun 27, 2024
fdd526d
Merge branch 'master' into master
rflamary Jul 10, 2024
161faa9
batchable proj_SDP
eloitanguy Jul 16, 2024
49531f5
linting
eloitanguy Jul 16, 2024
5ea4d11
fetched latest POT changes
eloitanguy Jul 16, 2024
26bb200
linting
eloitanguy Jul 16, 2024
190b729
gaussian + gmm pdfs, faster GMMOT plans and barycentric mappings
eloitanguy Jul 16, 2024
fbe2a81
Merge branch 'master' into master
cedricvincentcuaz Jul 16, 2024
915c9a9
Trand + docs and test
eloitanguy Jul 17, 2024
804d46f
Merge remote-tracking branch 'upstream/master'
eloitanguy Jul 17, 2024
f27ca06
Merge branch 'master' of github.com:eloitanguy/POT
eloitanguy Jul 17, 2024
2ac5c96
updated references, docs, RELEASES.md, CONTRIBUTORS.md and the README…
eloitanguy Jul 17, 2024
528fb6c
fixed mean computations in ot.gaussian
eloitanguy Jul 17, 2024
8ac2c11
all dosctrings as raw strings
eloitanguy Jul 17, 2024
42ca0d9
linting
eloitanguy Jul 17, 2024
7a6fbb2
fixed PR number
eloitanguy Jul 18, 2024
57217c5
nx.det WIP
eloitanguy Jul 18, 2024
3c6f4c9
gmm flow example
judelo Jul 18, 2024
5cf6bad
plot_GMMOT_plan unfinished
judelo Jul 18, 2024
62be757
ot.plot.plot1D_mat update + finished GMM example
eloitanguy Jul 19, 2024
3949973
Merge branch 'master' into master
cedricvincentcuaz Jul 19, 2024
807211a
finished GMM plan example + nx.det test + ot.plot.plot1D_mat update a…
eloitanguy Jul 22, 2024
b2cce0f
fetch upstream
eloitanguy Jul 22, 2024
db530c3
updated RELEASES.md
eloitanguy Jul 22, 2024
1a0f06d
linting
eloitanguy Jul 22, 2024
8ca9604
minor PR suggestions + jax det fix
eloitanguy Jul 22, 2024
060bbc8
batched nx.sqrtm
eloitanguy Jul 23, 2024
72bd04d
any-broadcastable sqrtm
eloitanguy Jul 23, 2024
e2681b2
einsum and optimisations in ot.gmm
eloitanguy Jul 23, 2024
edbf561
doc tweak
eloitanguy Jul 23, 2024
68d4324
linting fix
eloitanguy Jul 23, 2024
6a77050
gmm ot plan density test
eloitanguy Jul 23, 2024
d6b6474
plot1D_mat tweaks for similarity with old version by default + gmm_de…
eloitanguy Jul 30, 2024
3846143
gmm_density mask without slicing
eloitanguy Jul 30, 2024
98a7fba
fixed erroneuous simplex projection in GMM_flow example
eloitanguy Jul 30, 2024
4 changes: 3 additions & 1 deletion CONTRIBUTORS.md
@@ -42,7 +42,8 @@ The contributors to this library are:
 * [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance)
 * [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends)
 * [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning, FGW, semi-relaxed FGW, quantized FGW)
-* [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein Barycenters)
+* [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein
+Barycenters, GMMOT)
 * [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug)
 * [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter)
 * [Theo Gnassounou](https://github.com/tgnassou) (OT between Gaussian distributions)
@@ -51,6 +52,7 @@ The contributors to this library are:
 * [Xizheng Yu](https://github.com/x12hengyu) (Efficient Discrete Multi Marginal Optimal Transport Regularization)
 * [Sonia Mazelet](https://github.com/SoniaMaz8) (Template based GNN layers)
 * [Laurène David](https://github.com/laudavid) (Low rank sinkhorn, Low rank Gromov-Wasserstein samples)
+* [Julie Delon](https://judelo.github.io/) (GMM OT)

 ## Acknowledgments

5 changes: 4 additions & 1 deletion README.md
@@ -50,7 +50,8 @@ POT provides the following generic OT solvers (links to examples):
 * [Quantized (Fused) Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_quantized_gromov_wasserstein.html) [68].
 * [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://pythonot.github.io/auto_examples/others/plot_demd_gradient_minimize.html) [50].
 * [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays.
-* Smooth Strongly Convex Nearest Brenier Potentials [58], with an extension to bounding potentials using [59].
+* [Smooth Strongly Convex Nearest Brenier Potentials](https://pythonot.github.io/auto_examples/others/plot_SSNB.html#sphx-glr-auto-examples-others-plot-ssnb-py) [58], with an extension to bounding potentials using [59].
+* Gaussian Mixture Model OT [69]

 POT provides the following Machine Learning related solvers:

@@ -368,3 +369,5 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil
 [67] Scetbon, M., Peyré, G. & Cuturi, M. (2022). [Linear-Time Gromov-Wasserstein Distances using Low Rank Couplings and Costs](https://proceedings.mlr.press/v162/scetbon22b/scetbon22b.pdf). In International Conference on Machine Learning (ICML), 2022.

 [68] Chowdhury, S., Miller, D., & Needham, T. (2021). [Quantized gromov-wasserstein](https://link.springer.com/chapter/10.1007/978-3-030-86523-8_49). ECML PKDD 2021. Springer International Publishing.
+
+[69] Delon, J., & Desolneux, A. (2020). [A Wasserstein-type distance in the space of Gaussian mixture models](https://epubs.siam.org/doi/abs/10.1137/19M1301047). SIAM Journal on Imaging Sciences, 13(2), 936-970.
9 changes: 8 additions & 1 deletion RELEASES.md
@@ -4,9 +4,16 @@

 #### New features
 - Add feature `mass=True` for `nx.kl_div` (PR #654)
-- Add feature `semirelaxed_fgw_barycenters` and generic FGW-related barycenter updates `update_barycenter_structure` and `update_barycenter_feature` (PR #659)
+- Gaussian Mixture Model OT `ot.gmm` (PR #649)
+- Add feature `semirelaxed_fgw_barycenters` and generic FGW-related barycenter
+updates `update_barycenter_structure` and `update_barycenter_feature` (PR
+#659)
+- Improved `ot.plot.plot1D_mat` (PR #649)
+- Added `nx.det` (PR #649)
+- `nx.sqrtm` is now broadcastable (takes ..., d, d) inputs (PR #649)

 #### Closed issues
+- Fixed `ot.gaussian` ignoring weights when computing means (PR #649, Issue #648)

 ## 0.9.4
 *June 2024*
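The closed issue listed in this changelog concerns sample weights being ignored when means are computed. As a minimal sketch of the distinction, assuming plain Python lists rather than backend arrays, an empirical mean that honours the weights could look like the following. The helper name `weighted_mean` and the sample values are illustrative assumptions and are not taken from `ot.gaussian`.

```python
def weighted_mean(samples, weights=None):
    """Empirical mean of d-dimensional samples.

    samples: list of n points, each a list of d floats.
    weights: optional list of n non-negative weights; uniform if omitted.
    """
    n = len(samples)
    d = len(samples[0])
    if weights is None:
        # Uniform weights reproduce the ordinary arithmetic mean.
        weights = [1.0 / n] * n
    total = sum(weights)
    # Normalising by the total is an assumption of this sketch, so that
    # unnormalised weights are also accepted.
    return [sum(w * x[j] for w, x in zip(weights, samples)) / total
            for j in range(d)]


points = [[0.0, 0.0], [2.0, 4.0]]
print(weighted_mean(points))                # uniform weights
print(weighted_mean(points, [0.75, 0.25]))  # pulled towards the first point
```

With non-uniform weights the two results differ, which is the kind of discrepancy the changelog entry describes.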
1 change: 1 addition & 0 deletions docs/source/all.rst
@@ -22,6 +22,7 @@ API and modules
 dr
 factored
 gaussian
+gmm
 gnn
 gromov
 lowrank
88 changes: 88 additions & 0 deletions examples/others/plot_GMMOT_plan.py
@@ -0,0 +1,88 @@
# %%
# -*- coding: utf-8 -*-
r"""
====================================================
GMM Plan 1D
====================================================

Illustration of the GMM plan for
the Mixture Wasserstein distance between two GMMs in 1D,
as well as the two maps T_mean and T_rand.
T_mean is the barycentric projection of the GMM coupling,
and T_rand sends each point through a randomly chosen
Gaussian-to-Gaussian map between two components,
drawn according to the coupling and the GMMs.
See [69] for details.

.. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970.

"""

# Author: Eloi Tanguy <eloi.tanguy@u-paris>
# Remi Flamary <remi.flamary@polytechnique.edu>
# Julie Delon <julie.delon@math.cnrs.fr>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 1

import numpy as np
from ot.plot import plot1D_mat, rescale_for_imshow_plot
from ot.gmm import gmm_ot_plan_density, gmm_pdf, gmm_ot_apply_map
import matplotlib.pyplot as plt

##############################################################################
# Generate GMMOT plan and plot it
# -------------------------------
ks = 2
kt = 3
d = 1
eps = 0.1
m_s = np.array([[1], [2]])
m_t = np.array([[3], [4.2], [5]])
C_s = np.array([[[.05]], [[.06]]])
C_t = np.array([[[.03]], [[.07]], [[.04]]])
w_s = np.array([.4, .6])
w_t = np.array([.4, .2, .4])

n = 500
a_x, b_x = 0, 3
x = np.linspace(a_x, b_x, n)
a_y, b_y = 2, 6
y = np.linspace(a_y, b_y, n)
plan_density = gmm_ot_plan_density(x[:, None], y[:, None],
                                   m_s, m_t, C_s, C_t, w_s, w_t,
                                   plan=None, atol=2e-2)

a = gmm_pdf(x[:, None], m_s, C_s, w_s)
b = gmm_pdf(y[:, None], m_t, C_t, w_t)
plt.figure(figsize=(8, 8))
plot1D_mat(a, b, plan_density, title='GMM OT plan', plot_style='xy',
           a_label='Source distribution', b_label='Target distribution')


##############################################################################
# Generate GMMOT maps and plot them over plan
# -------------------------------------------
plt.figure(figsize=(8, 8))
ax_s, ax_t, ax_M = plot1D_mat(a, b, plan_density, plot_style='xy',
                              title='GMM OT plan with T_mean and T_rand maps',
                              a_label='Source distribution',
                              b_label='Target distribution')
T_mean = gmm_ot_apply_map(x[:, None], m_s, m_t, C_s, C_t,
                          w_s, w_t, method='bary')[:, 0]
x_rescaled, T_mean_rescaled = rescale_for_imshow_plot(x, T_mean, n,
                                                      a_y=a_y, b_y=b_y)

ax_M.plot(x_rescaled, T_mean_rescaled, label='T_mean', alpha=.5,
          linewidth=5, color='aqua')

T_rand = gmm_ot_apply_map(x[:, None], m_s, m_t, C_s, C_t,
                          w_s, w_t, method='rand', seed=0)[:, 0]
x_rescaled, T_rand_rescaled = rescale_for_imshow_plot(x, T_rand, n,
                                                      a_y=a_y, b_y=b_y)

ax_M.scatter(x_rescaled, T_rand_rescaled, label='T_rand', alpha=.5,
             s=20, color='orange')

ax_M.legend(loc='upper left', fontsize=13)

# %%
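The example above relies on `ot.gmm` for the plan and the maps. To make the construction from [69] concrete, here is a minimal 1D sketch in plain Python. It rests on two facts from that paper: the ground cost of the discrete problem between component weights is the squared 2-Wasserstein distance between Gaussian components, and T_mean averages the affine component-to-component maps, weighted by the coupling and by the source component densities. The function names and the hand-written coupling `plan` are assumptions made for illustration; this is not the `ot.gmm` implementation, and no OT solver is called.

```python
import math


def gauss_pdf(x, mean, var):
    """Density of the 1D Gaussian N(mean, var) evaluated at x."""
    return math.exp(-0.5 * (x - mean) ** 2 / var) / math.sqrt(2 * math.pi * var)


def w2_sq_1d(m0, v0, m1, v1):
    """Squared 2-Wasserstein distance between N(m0, v0) and N(m1, v1)."""
    return (m0 - m1) ** 2 + (math.sqrt(v0) - math.sqrt(v1)) ** 2


def component_costs(m_s, v_s, m_t, v_t):
    """Cost matrix of the discrete problem: W2^2 for every pair of components."""
    return [[w2_sq_1d(ms, vs, mt, vt) for mt, vt in zip(m_t, v_t)]
            for ms, vs in zip(m_s, v_s)]


def t_mean_1d(x, m_s, v_s, w_s, m_t, v_t, plan):
    """Barycentric map at x for a given coupling plan[k][l] of the components."""
    num = 0.0
    for k in range(len(m_s)):
        g_k = gauss_pdf(x, m_s[k], v_s[k])
        for l in range(len(m_t)):
            # Affine optimal map between source component k and target component l.
            t_kl = m_t[l] + math.sqrt(v_t[l] / v_s[k]) * (x - m_s[k])
            num += plan[k][l] * g_k * t_kl
    # Normalise by the source GMM density at x.
    den = sum(w * gauss_pdf(x, m, v) for w, m, v in zip(w_s, m_s, v_s))
    return num / den


# Same means, variances and weights as the example above.
m_s, v_s, w_s = [1.0, 2.0], [0.05, 0.06], [0.4, 0.6]
m_t, v_t, w_t = [3.0, 4.2, 5.0], [0.03, 0.07, 0.04], [0.4, 0.2, 0.4]
# Hand-written feasible coupling (rows sum to w_s, columns to w_t);
# an assumption for illustration, not the output of a solver.
plan = [[0.4, 0.0, 0.0], [0.0, 0.2, 0.4]]

costs = component_costs(m_s, v_s, m_t, v_t)
total = sum(plan[k][l] * costs[k][l] for k in range(2) for l in range(3))
print(total, t_mean_1d(1.5, m_s, v_s, w_s, m_t, v_t, plan))
```

The sketch evaluates the transport cost of one candidate coupling only. Finding the optimal coupling requires solving the small discrete OT problem, which is what the library does.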
191 changes: 191 additions & 0 deletions examples/others/plot_GMM_flow.py
@@ -0,0 +1,191 @@
# -*- coding: utf-8 -*-
r"""
====================================================
GMM Flow
====================================================

Illustration of the gradient flow of a Gaussian Mixture
Model (GMM) that minimises its GMM-OT distance to a
fixed target GMM.

"""

# Author: Eloi Tanguy <eloi.tanguy@u-paris>
# Remi Flamary <remi.flamary@polytechnique.edu>
# Julie Delon <julie.delon@math.cnrs.fr>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 4

import numpy as np
import matplotlib.pylab as pl
from matplotlib import colormaps as cm
import ot
import ot.plot
from ot.utils import proj_SDP, proj_simplex
from ot.gmm import gmm_ot_loss
import torch
from torch.optim import Adam
from matplotlib.patches import Ellipse


##############################################################################
# Generate data and plot it
# -------------------------
torch.manual_seed(3)
ks = 3
kt = 2
d = 2
eps = 0.1
m_s = torch.randn(ks, d)
m_s.requires_grad_()
m_t = torch.randn(kt, d)
C_s = torch.randn(ks, d, d)
C_s = torch.matmul(C_s, torch.transpose(C_s, 2, 1))
C_s += eps * torch.eye(d)[None, :, :] * torch.ones(ks, 1, 1)
C_s.requires_grad_()
C_t = torch.randn(kt, d, d)
C_t = torch.matmul(C_t, torch.transpose(C_t, 2, 1))
C_t += eps * torch.eye(d)[None, :, :] * torch.ones(kt, 1, 1)
w_s = torch.randn(ks)
w_s = proj_simplex(w_s)
w_s.requires_grad_()
w_t = torch.tensor(ot.unif(kt))


def draw_cov(mu, C, color=None, label=None, nstd=1, alpha=.5):

    def eigsorted(cov):
        vals, vecs = np.linalg.eigh(cov)
        order = vals.argsort()[::-1]
        return vals[order], vecs[:, order]

    vals, vecs = eigsorted(C)
    theta = np.degrees(np.arctan2(*vecs[:, 0][::-1]))
    w, h = 2 * nstd * np.sqrt(vals)
    ell = Ellipse(xy=(mu[0], mu[1]),
                  width=w, height=h, alpha=alpha,
                  angle=theta, facecolor=color, edgecolor=color, label=label, fill=True)
    pl.gca().add_artist(ell)


def draw_gmm(ms, Cs, ws, color=None, nstd=.5, alpha=1):
    for k in range(ms.shape[0]):
        draw_cov(ms[k], Cs[k], color, None, nstd,
                 alpha * ws[k])


axis = [-3, 3, -3, 3]
pl.figure(1, (20, 10))
pl.clf()

pl.subplot(1, 2, 1)
pl.scatter(m_s[:, 0].detach(), m_s[:, 1].detach(), color='C0')
draw_gmm(m_s.detach(), C_s.detach(),
         torch.softmax(w_s, 0).detach().numpy(),
         color='C0')
pl.axis(axis)
pl.title('Source GMM')

pl.subplot(1, 2, 2)
pl.scatter(m_t[:, 0].detach(), m_t[:, 1].detach(), color='C1')
draw_gmm(m_t.detach(), C_t.detach(), w_t.numpy(), color='C1')
pl.axis(axis)
pl.title('Target GMM')

##############################################################################
# Gradient descent loop
# ------------------------

n_gd_its = 100
lr = 3e-2
opt = Adam([{'params': m_s, 'lr': 2 * lr},
            {'params': C_s, 'lr': lr},
            {'params': w_s, 'lr': lr}])
m_list = [m_s.data.numpy().copy()]
C_list = [C_s.data.numpy().copy()]
w_list = [torch.softmax(w_s, 0).data.numpy().copy()]
loss_list = []

for _ in range(n_gd_its):
    opt.zero_grad()
    loss = gmm_ot_loss(m_s, m_t, C_s, C_t,
                       torch.softmax(w_s, 0), w_t)
    loss.backward()
    opt.step()
    with torch.no_grad():
        C_s.data = proj_SDP(C_s.data, vmin=1e-6)
        m_list.append(m_s.data.numpy().copy())
        C_list.append(C_s.data.numpy().copy())
        w_list.append(torch.softmax(w_s, 0).data.numpy().copy())
        loss_list.append(loss.item())

pl.figure(2)
pl.clf()
pl.plot(loss_list)
pl.title('Loss')
pl.xlabel('its')
pl.ylabel('loss')


##############################################################################
# Last step visualisation
# ------------------------

axis = [-3, 3, -3, 3]
pl.figure(3, (10, 10))
pl.clf()
pl.title('GMM flow, last step')
pl.scatter(m_list[0][:, 0], m_list[0][:, 1], color='C0', label='Source')
draw_gmm(m_list[0], C_list[0], w_list[0], color='C0')
pl.axis(axis)

pl.scatter(m_t[:, 0].detach(), m_t[:, 1].detach(), color='C1', label='Target')
draw_gmm(m_t.detach(), C_t.detach(), w_t.numpy(), color='C1')
pl.axis(axis)

k = -1
pl.scatter(m_list[k][:, 0], m_list[k][:, 1], color='C2', alpha=1, label='Last step')
draw_gmm(m_list[k], C_list[k], w_list[0], color='C2', alpha=1)

pl.axis(axis)
pl.legend(fontsize=15)


##############################################################################
# Steps visualisation
# ------------------------
def index_to_color(i):
    return int(i**0.5)


n_steps_visu = 100
pl.figure(4, (10, 10))
pl.clf()
pl.title('GMM flow, all steps')

its_to_show = [int(x) for x in np.linspace(1, n_gd_its - 1, n_steps_visu)]
cmp = cm['plasma'].resampled(index_to_color(n_steps_visu))

pl.scatter(m_list[0][:, 0], m_list[0][:, 1],
           color=cmp(index_to_color(0)), label='Source')
draw_gmm(m_list[0], C_list[0], w_list[0],
         color=cmp(index_to_color(0)))

pl.scatter(m_t[:, 0].detach(), m_t[:, 1].detach(),
           color=cmp(index_to_color(n_steps_visu - 1)), label='Target')
draw_gmm(m_t.detach(), C_t.detach(), w_t.numpy(),
         color=cmp(index_to_color(n_steps_visu - 1)))


for k in its_to_show:
    pl.scatter(m_list[k][:, 0], m_list[k][:, 1],
               color=cmp(index_to_color(k)), alpha=0.8)
    draw_gmm(m_list[k], C_list[k], w_list[0],
             color=cmp(index_to_color(k)), alpha=0.04)

pl.axis(axis)
pl.legend(fontsize=15)

# %%
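The gradient loop in this example keeps the parameters feasible in two ways: the weights are passed through a softmax before entering the loss, and the covariances are projected back with `proj_SDP` after each optimiser step. The sketch below illustrates both ideas in plain Python for the 2x2 case. Clipping eigenvalues from below is a common way to realise such a projection; whether `ot.utils.proj_SDP` is implemented exactly this way is an assumption, and the names `softmax` and `clip_eigenvalues_2x2` are hypothetical helpers written for this illustration.

```python
import math


def softmax(logits):
    """Map unconstrained reals to positive weights that sum to one."""
    shift = max(logits)  # subtracting the max avoids overflow in exp
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def clip_eigenvalues_2x2(a, b, c, vmin=0.0):
    """Raise the eigenvalues of the symmetric matrix [[a, b], [b, c]] to at least vmin.

    Returns (a2, b2, c2), the entries of the resulting matrix [[a2, b2], [b2, c2]].
    """
    # Rotation angle that diagonalises a symmetric 2x2 matrix.
    theta = 0.5 * math.atan2(2.0 * b, a - c)
    cs, sn = math.cos(theta), math.sin(theta)
    # Eigenvalues along the rotated axes (cs, sn) and (-sn, cs).
    lam1 = a * cs * cs + 2.0 * b * cs * sn + c * sn * sn
    lam2 = a * sn * sn - 2.0 * b * cs * sn + c * cs * cs
    lam1, lam2 = max(lam1, vmin), max(lam2, vmin)
    # Rebuild the matrix from the clipped eigenvalues and the same eigenvectors.
    return (lam1 * cs * cs + lam2 * sn * sn,
            (lam1 - lam2) * cs * sn,
            lam1 * sn * sn + lam2 * cs * cs)


print(softmax([0.3, -1.2, 2.0]))
print(clip_eigenvalues_2x2(1.0, 2.0, 1.0, vmin=0.0))
```

The softmax route lets the optimiser take unconstrained steps on the weight logits, while the covariance step is a projected-gradient update. Both choices avoid having to constrain the optimiser itself.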
1 change: 1 addition & 0 deletions examples/plot_OT_1D.py
@@ -1,3 +1,4 @@
+# %%
 # -*- coding: utf-8 -*-
 """
 ======================================
1 change: 1 addition & 0 deletions ot/__init__.py
@@ -36,6 +36,7 @@
 from . import solvers
 from . import gaussian
 from . import lowrank
+from . import gmm


 # OT functions