From f91daac1d73a7a85dee60d21bc81282f5d5a71c7 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Thu, 27 Jun 2024 12:06:57 +0200 Subject: [PATCH 01/30] WIP GMM --- ot/gaussian.py | 2 +- ot/gmm.py | 85 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 1 deletion(-) create mode 100644 ot/gmm.py diff --git a/ot/gaussian.py b/ot/gaussian.py index e9d475b52..a0e7a2ae9 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -788,7 +788,7 @@ def empirical_gaussian_gromov_wasserstein_mapping(xs, xt, ws=None, Cov_s = nx.dot((xs * ws).T, xs) / nx.sum(ws) Cov_t = nx.dot((xt * wt).T, xt) / nx.sum(wt) - # compte and sort eigenvalues/eigenvectors decreasingly + # compute and sort eigenvalues/eigenvectors decreasingly d_s, U_s = nx.eigh(Cov_s) id_s = nx.flip(nx.argsort(d_s)) d_s, U_s = d_s[id_s], U_s[:, id_s] diff --git a/ot/gmm.py b/ot/gmm.py new file mode 100644 index 000000000..34bf937c8 --- /dev/null +++ b/ot/gmm.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- +""" +Optimal transport for Gaussian Mixtures +""" + +# Author: Eloi Tanguy +# Remi Flamary +# Julie Delon +# +# License: MIT License + +import warnings + +from .backend import get_backend +from .utils import dots, is_all_finite, list_to_array +from .gaussian import bures_wasserstein_distance, bures_wasserstein_mapping, gaussian_pdf +from .lp import emd2, emd +from scipy.stats import multivariate_normal + + +def gaussian_pdf(x, m, C): + var = multivariate_normal(mean=m, cov=C) + return var.pdf(x) + + +def gmm_pdf(x, m, C, w): + nx = get_backend(x, m, C) + out = nx.zeros((x.shape[0])) + for k in range(m.shape[0]): + out = out + w[k] * gaussian_pdf(x, m[k], C[k]) + return out + + +def dist_bures(m_s, m_t, C_s, C_t): + r""" + """ + nx = get_backend(m_s, C_s, m_t, C_t) + k_s, k_t = m_s.shape[0], m_t.shape[0] + # TODO assert tailles + D = nx.zeros((k_s, k_t), type_as=m_s) + for i in range(k_s): + for j in range(k_t): + D[i, j] = bures_wasserstein_distance(m_s[i], m_t[j], C_s[i], C_t[j]) + return D + 
+ +def gmm_ot_loss(m_s, m_t, C_s, C_t, w_s, w_t): + r""" + Gaussian Mixture OT loss + """ + get_backend(m_s, C_s, w_s, m_t, C_t, w_t) # backed test + # TODO assert taille w + D = dist_bures(m_s, m_t, C_s, C_t) + return emd2(w_s, w_t, D) + + +def gmm_ot_plan(m_s, m_t, C_s, C_t, w_s, w_t): + r""" + Gaussian Mixture OT loss + """ + get_backend(m_s, C_s, w_s, m_t, C_t, w_t) # backed test + # TODO assert taille w + D = dist_bures(m_s, m_t, C_s, C_t) + return emd(w_s, w_t, D) + + +def gmm_ot_apply_map(x, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, + method='bary'): + r""" + Applies the barycentric or stochastic map associated to the GMM OT from the + source GMM to the target GMM + """ + if plan is None: + plan = gmm_ot_plan(m_s, m_t, C_s, C_t, w_s, w_t) + + # TODO asserts + normalisation = np.expand_dims(m0.pdf(x), -1) # from (...) to (..., 1) + out = nx.zeros_like(x) + for k0 in range(m0.n_components): + for k1 in range(m1.n_components): + g = gaussian_pdf(x, m_s[k0], C_s[k0])[:, None] + A, b = bures_wasserstein_mapping(m_s[k0], m_t[k1], C_s[k0], C_t[k1]) + Tk0k1x = A @ x + b + out = out + w[k0, k1] * g * Tk0k1x + return out / normalisation \ No newline at end of file From 38822a1d2502bd2a000f6e6082827e761b77f5c1 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Thu, 27 Jun 2024 14:27:35 +0200 Subject: [PATCH 02/30] wip test --- ot/__init__.py | 1 + ot/gmm.py | 2 +- test/test_gmm.py | 54 ++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 56 insertions(+), 1 deletion(-) create mode 100644 test/test_gmm.py diff --git a/ot/__init__.py b/ot/__init__.py index 5eb3977aa..3a5168c67 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -36,6 +36,7 @@ from . import solvers from . import gaussian from . import lowrank +from . 
import gmm # OT functions diff --git a/ot/gmm.py b/ot/gmm.py index 34bf937c8..9c4b03691 100644 --- a/ot/gmm.py +++ b/ot/gmm.py @@ -13,7 +13,7 @@ from .backend import get_backend from .utils import dots, is_all_finite, list_to_array -from .gaussian import bures_wasserstein_distance, bures_wasserstein_mapping, gaussian_pdf +from .gaussian import bures_wasserstein_distance, bures_wasserstein_mapping from .lp import emd2, emd from scipy.stats import multivariate_normal diff --git a/test/test_gmm.py b/test/test_gmm.py new file mode 100644 index 000000000..3440ce8b1 --- /dev/null +++ b/test/test_gmm.py @@ -0,0 +1,54 @@ +"""Tests for module gaussian""" + +# Author: Eloi Tanguy +# Remi Flamary +# Julie Delon +# +# License: MIT License + +import numpy as np +import pytest +import ot +from ot.utils import proj_simplex +from ot.gmm import gaussian_pdf, gmm_pdf, dist_bures + + +def get_gmms(): + rng = np.random.RandomState(seed=42) + ks = 3 + kt = 5 + d = 3 + m_s = rng.randn(ks, d) + m_t = rng.randn(kt, d) + C_s = rng.randn(ks, d, d) + C_s = np.matmul(C_s, np.transpose(C_s, (0, 2, 1))) + C_t = rng.randn(kt, d, d) + C_t = np.matmul(C_t, np.transpose(C_t, (0, 2, 1))) + w_s = proj_simplex(rng.rand(ks)) + w_t = proj_simplex(rng.rand(kt)) + return m_s, m_t, C_s, C_t, w_s, w_t + + +def test_gaussian_pdf(): + rng = np.random.RandomState(seed=42) + n = 7 + d = 3 + x = rng.randn(n, d) + m, _, C, _, _, _ = get_gmms() + p = gaussian_pdf(x, m[0], C[0]) + + +def test_gmm_pdf(): + rng = np.random.RandomState(seed=42) + n = 7 + d = 3 + x = rng.randn(n, d) + m_s, _, C_s, _, w_s, _ = get_gmms() + p = gmm_pdf(x, m_s, C_s, w_s) + +def test_dist_bures(): + m_s, m_t, C_s, C_t, _, _ = get_gmms() + D = dist_bures(m_s, m_t, C_s, C_t) + D0 = dist_bures(m_s, m_s, C_s, C_s) + print(D0) + assert np.allclose(np.diag(D0), 0) From 19481f483ebc51ad0c1ea961104b6ab53f057f6d Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Thu, 27 Jun 2024 15:07:48 +0200 Subject: [PATCH 03/30] basic tests --- ot/gmm.py | 33 
+++++++++++++++++++++------------ test/test_gmm.py | 38 +++++++++++++++++++++++++++++++++++--- 2 files changed, 56 insertions(+), 15 deletions(-) diff --git a/ot/gmm.py b/ot/gmm.py index 9c4b03691..62bf4fd9e 100644 --- a/ot/gmm.py +++ b/ot/gmm.py @@ -19,8 +19,8 @@ def gaussian_pdf(x, m, C): - var = multivariate_normal(mean=m, cov=C) - return var.pdf(x) + gauss = multivariate_normal(mean=m, cov=C) + return gauss.pdf(x) def gmm_pdf(x, m, C, w): @@ -70,16 +70,25 @@ def gmm_ot_apply_map(x, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, Applies the barycentric or stochastic map associated to the GMM OT from the source GMM to the target GMM """ + if plan is None: plan = gmm_ot_plan(m_s, m_t, C_s, C_t, w_s, w_t) + nx = get_backend(x, m_s, m_t, C_s, C_t, w_s, w_t) + else: + nx = get_backend(x, m_s, m_t, C_s, C_t, w_s, w_t, plan) + + if method == 'bary': + # TODO asserts + normalisation = gmm_pdf(x, m_s, C_s, w_s)[:, None] + out = nx.zeros(x.shape) + for k0 in range(m_s.shape[0]): + for k1 in range(m_t.shape[0]): + g = gaussian_pdf(x, m_s[k0], C_s[k0])[:, None] + A, b = bures_wasserstein_mapping( + m_s[k0], m_t[k1], C_s[k0], C_t[k1]) + Tk0k1x = x @ A + b + out = out + plan[k0, k1] * g * Tk0k1x + return out / normalisation - # TODO asserts - normalisation = np.expand_dims(m0.pdf(x), -1) # from (...) 
to (..., 1) - out = nx.zeros_like(x) - for k0 in range(m0.n_components): - for k1 in range(m1.n_components): - g = gaussian_pdf(x, m_s[k0], C_s[k0])[:, None] - A, b = bures_wasserstein_mapping(m_s[k0], m_t[k1], C_s[k0], C_t[k1]) - Tk0k1x = A @ x + b - out = out + w[k0, k1] * g * Tk0k1x - return out / normalisation \ No newline at end of file + else: # rand + raise NotImplementedError('Mapping {} not implemented'.format(method)) \ No newline at end of file diff --git a/test/test_gmm.py b/test/test_gmm.py index 3440ce8b1..057a00528 100644 --- a/test/test_gmm.py +++ b/test/test_gmm.py @@ -10,7 +10,7 @@ import pytest import ot from ot.utils import proj_simplex -from ot.gmm import gaussian_pdf, gmm_pdf, dist_bures +from ot.gmm import gaussian_pdf, gmm_pdf, dist_bures, gmm_ot_loss, gmm_ot_plan, gmm_ot_apply_map def get_gmms(): @@ -46,9 +46,41 @@ def test_gmm_pdf(): m_s, _, C_s, _, w_s, _ = get_gmms() p = gmm_pdf(x, m_s, C_s, w_s) + def test_dist_bures(): m_s, m_t, C_s, C_t, _, _ = get_gmms() D = dist_bures(m_s, m_t, C_s, C_t) D0 = dist_bures(m_s, m_s, C_s, C_s) - print(D0) - assert np.allclose(np.diag(D0), 0) + + assert np.allclose(np.diag(D0), 0, atol=1e-6) + + +def test_gmm_ot_loss(): + m_s, m_t, C_s, C_t, w_s, w_t = get_gmms() + loss = gmm_ot_loss(m_s, m_t, C_s, C_t, w_s, w_t) + + assert loss > 0 + + loss = gmm_ot_loss(m_s, m_s, C_s, C_s, w_s, w_s) + + assert np.allclose(loss, 0, atol=1e-6) + + +def test_gmm_ot_plan(): + m_s, m_t, C_s, C_t, w_s, w_t = get_gmms() + + plan = gmm_ot_plan(m_s, m_t, C_s, C_t, w_s, w_t) + + assert np.allclose(plan.sum(0), w_t, atol=1e-6) + assert np.allclose(plan.sum(1), w_s, atol=1e-6) + + plan = gmm_ot_plan(m_s, m_s + 1, C_s, C_s, w_s, w_s) + + assert np.allclose(plan, np.diag(w_s), atol=1e-6) + + +def test_gmm_apply_map(): + m_s, m_t, C_s, C_t, w_s, w_t = get_gmms() + rng = np.random.RandomState(seed=42) + x = rng.randn(7, 3) + gmm_ot_apply_map(x, m_s, m_t, C_s, C_t, w_s, w_t) From 29c314ab81224fe3cba15ad1ec76700503513e3d Mon Sep 17 
00:00:00 2001 From: eloitanguy Date: Thu, 27 Jun 2024 17:50:58 +0200 Subject: [PATCH 04/30] wip gmm flow example --- examples/others/plot_GMM_flow.py | 155 +++++++++++++++++++++++++++++++ ot/gmm.py | 4 +- ot/utils.py | 9 ++ test/test_gmm.py | 24 +++++ test/test_utils.py | 7 ++ 5 files changed, 197 insertions(+), 2 deletions(-) create mode 100644 examples/others/plot_GMM_flow.py diff --git a/examples/others/plot_GMM_flow.py b/examples/others/plot_GMM_flow.py new file mode 100644 index 000000000..50015e51d --- /dev/null +++ b/examples/others/plot_GMM_flow.py @@ -0,0 +1,155 @@ +# %% -*- coding: utf-8 -*- +""" +==================================================== +GMM Flow +==================================================== + +Illustration of the flow of a Gaussian Mixture with +respect to its GMM-OT distance with respect to a +fixed GMM. + +""" + +# Author: Eloi Tanguy +# Remi Flamary +# Julie Delon +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 4 + +import numpy as np +import matplotlib.pylab as pl +import ot +import ot.plot +from ot.utils import proj_simplex, proj_SDP +from ot.gmm import gmm_ot_loss +import torch +from torch.optim import SGD, Adam +from matplotlib.patches import Ellipse + +# %% +############################################################################## +# Generate data an plot it +# ------------------------ +torch.manual_seed(3) +ks = 5 +kt = 3 +d = 2 +m_s = torch.randn(ks, d) +m_s.requires_grad_() +m_t = torch.randn(kt, d) +# C_s = torch.eye(d)[None, :, :] * torch.ones(ks, 1, 1) +C_s = torch.randn(ks, d, d) +C_s = torch.matmul(C_s, torch.transpose(C_s, 2, 1)) + torch.eye(d)[None, :, :] * torch.ones(ks, 1, 1) +C_s.requires_grad_() +C_t = torch.randn(kt, d, d) +# C_t = torch.eye(d)[None, :, :] * torch.ones(kt, 1, 1) +C_t = torch.matmul(C_t, torch.transpose(C_t, 2, 1)) + torch.eye(d)[None, :, :] * torch.ones(kt, 1, 1) +# w_s = torch.tensor(ot.unif(ks)) +w_s = torch.randn(ks) +w_s.requires_grad_() +w_t = 
torch.tensor(ot.unif(kt)) +# w_t = proj_simplex(torch.rand(kt)) + +def draw_cov(mu, C, color=None, label=None, nstd=1, alpha=.5): + + def eigsorted(cov): + vals, vecs = np.linalg.eigh(cov) + order = vals.argsort()[::-1] + return vals[order], vecs[:, order] + + vals, vecs = eigsorted(C) + theta = np.degrees(np.arctan2(*vecs[:, 0][::-1])) + w, h = 2 * nstd * np.sqrt(vals) + ell = Ellipse(xy=(mu[0], mu[1]), + width=w, height=h, alpha=alpha, + angle=theta, facecolor=color, edgecolor=color, label=label, fill=True) + pl.gca().add_artist(ell) + +def draw_gmm(ms, Cs, ws, color=None, nstd=.5, alpha=1): + for k in range(ms.shape[0]): + draw_cov(ms[k], Cs[k], color, None, nstd, + alpha * ws[k]) + +axis = [-3, 3, -3, 3] +pl.figure(1, (20, 10)) +pl.clf() + +pl.subplot(1, 2, 1) +pl.scatter(m_s[:, 0].detach(), m_s[:, 1].detach(), color='C0') +draw_gmm(m_s.detach(), C_s.detach(), + torch.softmax(w_s, 0).detach().numpy(), + color='C0') +pl.axis(axis) +pl.title('Source GMM') + +pl.subplot(1, 2, 2) +pl.scatter(m_t[:, 0].detach(), m_t[:, 1].detach(), color='C1') +draw_gmm(m_t.detach(), C_t.detach(), w_t.numpy(), color='C1') +pl.axis(axis) +pl.title('Target GMM') + +# %% +############################################################################## +# Gradient descent loop +# ------------------------ +n_gd_its = 300 +lr = 3e-2 +opt = Adam([{'params': m_s, 'lr': 2 * lr}, + {'params': C_s, 'lr': lr}]) + # {'params': w_s, 'lr': lr}]) +m_list = [m_s.data.numpy().copy()] +C_list = [C_s.data.numpy().copy()] +w_list = [torch.softmax(w_s, 0).data.numpy().copy()] +loss_list = [] + +for _ in range(n_gd_its): + opt.zero_grad() + loss = gmm_ot_loss(m_s, m_t, C_s, C_t, + torch.softmax(w_s, 0), w_t) + loss.backward() + opt.step() + with torch.no_grad(): + # w_s.data = proj_simplex(w_s.data) + for k in range(m_s.shape[0]): + C_s.data[k] = proj_SDP(C_s.data[k], vmin=1e-6) + m_list.append(m_s.data.numpy().copy()) + C_list.append(C_s.data.numpy().copy()) + w_list.append(torch.softmax(w_s, 
0).data.numpy().copy()) + loss_list.append(loss.item()) + +pl.figure(2) +pl.clf() +pl.plot(loss_list) +pl.title('Loss') +pl.xlabel('its') +pl.ylabel('loss') + +# %% +############################################################################## +# Steps visualisation +# ------------------------ + +n_steps_visu = 30 +axis = [-3, 3, -3, 3] +pl.figure(3, (10, 10)) +pl.clf() + +pl.scatter(m_list[0][:, 0], m_list[0][:, 1], color='C0') +draw_gmm(m_list[0], C_list[0], w_list[0], color='C0') +pl.axis(axis) + +pl.scatter(m_t[:, 0].detach(), m_t[:, 1].detach(), color='C1') +draw_gmm(m_t.detach(), C_t.detach(), w_t.numpy(), color='C1') +pl.axis(axis) + +# [int(x) for x in np.linspace(1, n_gd_its - 1, n_steps_visu)] +its_to_show = [-1] + +for k in its_to_show: + pl.scatter(m_list[k][:, 0], m_list[k][:, 1], color='C2') + draw_gmm(m_list[k], C_list[k], w_list[0], color='C2', alpha=1) + +pl.axis(axis) +# %% diff --git a/ot/gmm.py b/ot/gmm.py index 62bf4fd9e..73a017c0b 100644 --- a/ot/gmm.py +++ b/ot/gmm.py @@ -40,7 +40,7 @@ def dist_bures(m_s, m_t, C_s, C_t): D = nx.zeros((k_s, k_t), type_as=m_s) for i in range(k_s): for j in range(k_t): - D[i, j] = bures_wasserstein_distance(m_s[i], m_t[j], C_s[i], C_t[j]) + D[i, j] = bures_wasserstein_distance(m_s[i], m_t[j], C_s[i], C_t[j]) ** 2 return D @@ -91,4 +91,4 @@ def gmm_ot_apply_map(x, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, return out / normalisation else: # rand - raise NotImplementedError('Mapping {} not implemented'.format(method)) \ No newline at end of file + raise NotImplementedError('Mapping {} not implemented'.format(method)) diff --git a/ot/utils.py b/ot/utils.py index 04c0e550e..60c49b583 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -1292,3 +1292,12 @@ def __getitem__(self, key): def __repr__(self): return "LazyTensor(shape={},attributes=({}))".format(self.shape, ','.join(self.kwargs.keys())) + + +def proj_SDP(S, nx=None, vmin=0.): + if nx is None: + nx = get_backend(S) + + w, P = nx.eigh(S) + w = nx.clip(w, vmin, None) 
+ return P @ nx.diag(w) @ P.T diff --git a/test/test_gmm.py b/test/test_gmm.py index 057a00528..323b19835 100644 --- a/test/test_gmm.py +++ b/test/test_gmm.py @@ -12,6 +12,11 @@ from ot.utils import proj_simplex from ot.gmm import gaussian_pdf, gmm_pdf, dist_bures, gmm_ot_loss, gmm_ot_plan, gmm_ot_apply_map +try: + import torch +except ImportError: + torch = False + def get_gmms(): rng = np.random.RandomState(seed=42) @@ -84,3 +89,22 @@ def test_gmm_apply_map(): rng = np.random.RandomState(seed=42) x = rng.randn(7, 3) gmm_ot_apply_map(x, m_s, m_t, C_s, C_t, w_s, w_t) + + +@pytest.mark.skipif(not torch, reason="No torch available") +def test_gradient_gmm_ot_loss_pytorch(): + m_s, m_t, C_s, C_t, w_s, w_t = get_gmms() + m_s = torch.tensor(m_s, requires_grad=True) + m_t = torch.tensor(m_t, requires_grad=True) + C_s = torch.tensor(C_s, requires_grad=True) + C_t = torch.tensor(C_t, requires_grad=True) + w_s = torch.tensor(w_s, requires_grad=True) + w_t = torch.tensor(w_t, requires_grad=True) + loss = gmm_ot_loss(m_s, m_t, C_s, C_t, w_s, w_t) + loss.backward() + grad_m_s = m_s.grad + grad_C_s = C_s.grad + grad_w_s = w_s.grad + assert (grad_m_s**2).sum().item() > 0 + assert (grad_C_s**2).sum().item() > 0 + assert (grad_w_s**2).sum().item() > 0 diff --git a/test/test_utils.py b/test/test_utils.py index 966cef989..f02682cf7 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -620,3 +620,10 @@ def test_label_normalization(nx): # labels are shifted but the shift if expected y_normalized_start = ot.utils.label_normalization(y, start=1) np.testing.assert_array_equal(y, y_normalized_start) + + +def test_proj_SDP(nx): + S = np.diag([1., -1., 2.]) + S = nx.from_numpy(S) + S1 = ot.utils.proj_SDP(S) + assert np.allclose(nx.to_numpy(S1), np.diag([1, 0, 2])) From 161faa9ccc034290fb94fa8d00fafe08b674c793 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Tue, 16 Jul 2024 14:12:01 +0200 Subject: [PATCH 05/30] batchable proj_SDP --- ot/utils.py | 35 
+++++++++++++++++++++++++++++++++-- test/test_utils.py | 26 ++++++++++++++++++++++---- 2 files changed, 55 insertions(+), 6 deletions(-) diff --git a/ot/utils.py b/ot/utils.py index 60c49b583..af4ef10af 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -1295,9 +1295,40 @@ def __repr__(self): def proj_SDP(S, nx=None, vmin=0.): + """ + Project a symmetric matrix onto the space of symmetric matrices with + eigenvalues larger or equal to `vmin`. + + Parameters: + ---------- + S : array_like (n, d, d) + The input symmetric matrix or matrices. + nx : module, optional + The numerical backend module to use. If not provided, the backend will + be fetched from the input matrix `S`. + vmin : float, optional + The minimum value for the eigenvalues. Eigenvalues below this value will + be clipped to vmin. + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. + + Returns: + ------- + P : ndarray (n, d, d) + The projected symmetric positive definite matrix. 
+ + """ if nx is None: nx = get_backend(S) - + w, P = nx.eigh(S) w = nx.clip(w, vmin, None) - return P @ nx.diag(w) @ P.T + + if len(S.shape) == 2: # input was (d, d) + return P @ nx.diag(w) @ P.T + + else: # input was (n, d, d): broadcasting + Q = nx.einsum('ijk,ik->ijk', P, w) # Q[i] = P[i] @ diag(w[i]) + # R[i] = Q[i] @ P[i].T + return nx.einsum('ijk,ikl->ijl', Q, nx.transpose(P, (0, 2, 1))) diff --git a/test/test_utils.py b/test/test_utils.py index f02682cf7..6c3d7c35c 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -623,7 +623,25 @@ def test_label_normalization(nx): def test_proj_SDP(nx): - S = np.diag([1., -1., 2.]) - S = nx.from_numpy(S) - S1 = ot.utils.proj_SDP(S) - assert np.allclose(nx.to_numpy(S1), np.diag([1, 0, 2])) + t = np.pi / 8 + U = np.array([[np.cos(t), -np.sin(t)], [np.sin(t), np.cos(t)]]) + w = np.array([1., -1.]) + S = np.stack([U @ np.diag(w) @ U.T]*2, axis=0) + S_nx = nx.from_numpy(S) + R = ot.utils.proj_SDP(S_nx) + + w_expected = np.array([1., 0.]) + S_expected = np.stack([U @ np.diag(w_expected) @ U.T]*2, axis=0) + assert np.allclose(nx.to_numpy(R), S_expected) + + R0 = ot.utils.proj_SDP(S_nx[0]) + assert np.allclose(nx.to_numpy(R[0]), S_expected[0]) + + +def test_laplacian(): + n = 100 + rng = np.random.RandomState(0) + x = rng.randn(n, 2) + M = ot.dist(x, x) + L = ot.utils.laplacian(M) + assert L.shape == (n, n) From 49531f5426d3219d4c7ce7afbc025612d32a1207 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Tue, 16 Jul 2024 14:12:54 +0200 Subject: [PATCH 06/30] linting --- ot/gmm.py | 7 ++----- test/test_gmm.py | 1 - 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/ot/gmm.py b/ot/gmm.py index 73a017c0b..cfabc4db5 100644 --- a/ot/gmm.py +++ b/ot/gmm.py @@ -9,10 +9,7 @@ # # License: MIT License -import warnings - from .backend import get_backend -from .utils import dots, is_all_finite, list_to_array from .gaussian import bures_wasserstein_distance, bures_wasserstein_mapping from .lp import emd2, emd from scipy.stats 
import multivariate_normal @@ -76,7 +73,7 @@ def gmm_ot_apply_map(x, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, nx = get_backend(x, m_s, m_t, C_s, C_t, w_s, w_t) else: nx = get_backend(x, m_s, m_t, C_s, C_t, w_s, w_t, plan) - + if method == 'bary': # TODO asserts normalisation = gmm_pdf(x, m_s, C_s, w_s)[:, None] @@ -89,6 +86,6 @@ def gmm_ot_apply_map(x, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, Tk0k1x = x @ A + b out = out + plan[k0, k1] * g * Tk0k1x return out / normalisation - + else: # rand raise NotImplementedError('Mapping {} not implemented'.format(method)) diff --git a/test/test_gmm.py b/test/test_gmm.py index 323b19835..3b86e3073 100644 --- a/test/test_gmm.py +++ b/test/test_gmm.py @@ -8,7 +8,6 @@ import numpy as np import pytest -import ot from ot.utils import proj_simplex from ot.gmm import gaussian_pdf, gmm_pdf, dist_bures, gmm_ot_loss, gmm_ot_plan, gmm_ot_apply_map From 26bb200698246e2848f8069815024fdead222c95 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Tue, 16 Jul 2024 14:40:28 +0200 Subject: [PATCH 07/30] linting --- ot/gmm.py | 8 ++++---- ot/utils.py | 2 +- test/test_gmm.py | 6 +++--- test/test_utils.py | 6 +++--- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/ot/gmm.py b/ot/gmm.py index cfabc4db5..f0821b462 100644 --- a/ot/gmm.py +++ b/ot/gmm.py @@ -29,15 +29,15 @@ def gmm_pdf(x, m, C, w): def dist_bures(m_s, m_t, C_s, C_t): - r""" - """ nx = get_backend(m_s, C_s, m_t, C_t) k_s, k_t = m_s.shape[0], m_t.shape[0] # TODO assert tailles D = nx.zeros((k_s, k_t), type_as=m_s) for i in range(k_s): for j in range(k_t): - D[i, j] = bures_wasserstein_distance(m_s[i], m_t[j], C_s[i], C_t[j]) ** 2 + D[i, j] = bures_wasserstein_distance(m_s[i], m_t[j], + C_s[i], C_t[j]) + D[i, j] = D[i, j] ** 2 return D @@ -67,7 +67,7 @@ def gmm_ot_apply_map(x, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, Applies the barycentric or stochastic map associated to the GMM OT from the source GMM to the target GMM """ - + if plan is None: plan = gmm_ot_plan(m_s, m_t, 
C_s, C_t, w_s, w_t) nx = get_backend(x, m_s, m_t, C_s, C_t, w_s, w_t) diff --git a/ot/utils.py b/ot/utils.py index af4ef10af..175af0c8f 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -1309,7 +1309,7 @@ def proj_SDP(S, nx=None, vmin=0.): vmin : float, optional The minimum value for the eigenvalues. Eigenvalues below this value will be clipped to vmin. - + .. note:: This function is backend-compatible and will work on arrays from all compatible backends. diff --git a/test/test_gmm.py b/test/test_gmm.py index 3b86e3073..d66a1c3e7 100644 --- a/test/test_gmm.py +++ b/test/test_gmm.py @@ -39,7 +39,7 @@ def test_gaussian_pdf(): d = 3 x = rng.randn(n, d) m, _, C, _, _, _ = get_gmms() - p = gaussian_pdf(x, m[0], C[0]) + gaussian_pdf(x, m[0], C[0]) def test_gmm_pdf(): @@ -48,12 +48,12 @@ def test_gmm_pdf(): d = 3 x = rng.randn(n, d) m_s, _, C_s, _, w_s, _ = get_gmms() - p = gmm_pdf(x, m_s, C_s, w_s) + gmm_pdf(x, m_s, C_s, w_s) def test_dist_bures(): m_s, m_t, C_s, C_t, _, _ = get_gmms() - D = dist_bures(m_s, m_t, C_s, C_t) + dist_bures(m_s, m_t, C_s, C_t) D0 = dist_bures(m_s, m_s, C_s, C_s) assert np.allclose(np.diag(D0), 0, atol=1e-6) diff --git a/test/test_utils.py b/test/test_utils.py index dbd5b276b..0801337cb 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -626,16 +626,16 @@ def test_proj_SDP(nx): t = np.pi / 8 U = np.array([[np.cos(t), -np.sin(t)], [np.sin(t), np.cos(t)]]) w = np.array([1., -1.]) - S = np.stack([U @ np.diag(w) @ U.T]*2, axis=0) + S = np.stack([U @ np.diag(w) @ U.T] * 2, axis=0) S_nx = nx.from_numpy(S) R = ot.utils.proj_SDP(S_nx) w_expected = np.array([1., 0.]) - S_expected = np.stack([U @ np.diag(w_expected) @ U.T]*2, axis=0) + S_expected = np.stack([U @ np.diag(w_expected) @ U.T] * 2, axis=0) assert np.allclose(nx.to_numpy(R), S_expected) R0 = ot.utils.proj_SDP(S_nx[0]) - assert np.allclose(nx.to_numpy(R[0]), S_expected[0]) + assert np.allclose(nx.to_numpy(R0), S_expected[0]) def test_laplacian(): From 
190b72914c65c59b0bfdf6c3ba691f14cf53d9bb Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Tue, 16 Jul 2024 15:54:57 +0200 Subject: [PATCH 08/30] gaussian + gmm pdfs, faster GMMOT plans and barycentric mappings --- examples/others/plot_GMM_flow.py | 17 +-- ot/gmm.py | 188 ++++++++++++++++++++++++++----- test/test_gmm.py | 36 +++++- 3 files changed, 203 insertions(+), 38 deletions(-) diff --git a/examples/others/plot_GMM_flow.py b/examples/others/plot_GMM_flow.py index 50015e51d..48cb3570a 100644 --- a/examples/others/plot_GMM_flow.py +++ b/examples/others/plot_GMM_flow.py @@ -4,8 +4,8 @@ GMM Flow ==================================================== -Illustration of the flow of a Gaussian Mixture with -respect to its GMM-OT distance with respect to a +Illustration of the flow of a Gaussian Mixture with +respect to its GMM-OT distance with respect to a fixed GMM. """ @@ -52,6 +52,7 @@ w_t = torch.tensor(ot.unif(kt)) # w_t = proj_simplex(torch.rand(kt)) + def draw_cov(mu, C, color=None, label=None, nstd=1, alpha=.5): def eigsorted(cov): @@ -67,19 +68,21 @@ def eigsorted(cov): angle=theta, facecolor=color, edgecolor=color, label=label, fill=True) pl.gca().add_artist(ell) + def draw_gmm(ms, Cs, ws, color=None, nstd=.5, alpha=1): for k in range(ms.shape[0]): - draw_cov(ms[k], Cs[k], color, None, nstd, + draw_cov(ms[k], Cs[k], color, None, nstd, alpha * ws[k]) + axis = [-3, 3, -3, 3] pl.figure(1, (20, 10)) pl.clf() pl.subplot(1, 2, 1) pl.scatter(m_s[:, 0].detach(), m_s[:, 1].detach(), color='C0') -draw_gmm(m_s.detach(), C_s.detach(), - torch.softmax(w_s, 0).detach().numpy(), +draw_gmm(m_s.detach(), C_s.detach(), + torch.softmax(w_s, 0).detach().numpy(), color='C0') pl.axis(axis) pl.title('Source GMM') @@ -98,7 +101,7 @@ def draw_gmm(ms, Cs, ws, color=None, nstd=.5, alpha=1): lr = 3e-2 opt = Adam([{'params': m_s, 'lr': 2 * lr}, {'params': C_s, 'lr': lr}]) - # {'params': w_s, 'lr': lr}]) +# {'params': w_s, 'lr': lr}]) m_list = [m_s.data.numpy().copy()] C_list = 
[C_s.data.numpy().copy()] w_list = [torch.softmax(w_s, 0).data.numpy().copy()] @@ -106,7 +109,7 @@ def draw_gmm(ms, Cs, ws, color=None, nstd=.5, alpha=1): for _ in range(n_gd_its): opt.zero_grad() - loss = gmm_ot_loss(m_s, m_t, C_s, C_t, + loss = gmm_ot_loss(m_s, m_t, C_s, C_t, torch.softmax(w_s, 0), w_t) loss.backward() opt.step() diff --git a/ot/gmm.py b/ot/gmm.py index f0821b462..c614c2dfd 100644 --- a/ot/gmm.py +++ b/ot/gmm.py @@ -10,54 +10,181 @@ # License: MIT License from .backend import get_backend -from .gaussian import bures_wasserstein_distance, bures_wasserstein_mapping from .lp import emd2, emd -from scipy.stats import multivariate_normal +import numpy as np +from .lp import dist def gaussian_pdf(x, m, C): - gauss = multivariate_normal(mean=m, cov=C) - return gauss.pdf(x) + r""" + Compute the probability density function of a multivariate Gaussian distribution. + + Parameters + ---------- + x : array-like, shape (n_samples, d) + The input samples. + m : array-like, shape (d,) + The mean vector of the Gaussian distribution. + C : array-like, shape (d, d) + The covariance matrix of the Gaussian distribution. + + Returns + ------- + pdf : array-like, shape (n_samples,) + The probability density function evaluated at each sample. + + """ + _, d = x.shape + z = (2 * np.pi) ** (-d / 2) * np.linalg.det(C) ** (-0.5) + exp = np.exp(-0.5 * np.sum((x - m) @ np.linalg.inv(C) * (x - m), axis=1)) + return z * exp def gmm_pdf(x, m, C, w): - nx = get_backend(x, m, C) - out = nx.zeros((x.shape[0])) + r""" + Compute the probability density function (PDF) of a Gaussian Mixture Model (GMM) at given points. + + Parameters: + ----------- + x : array-like, shape (n_samples, d) + The input samples. + m : array-like, shape (n_components, d) + The means of the Gaussian components. + C : array-like, shape (n_components, d, d) + The covariance matrices of the Gaussian components. + w : array-like, shape (n_components,) + The weights of the Gaussian components. 
+ + Returns: + -------- + out : array-like, shape (n_components,) + The PDF values at the given points. + + """ + out = np.zeros((x.shape[0])) for k in range(m.shape[0]): out = out + w[k] * gaussian_pdf(x, m[k], C[k]) return out -def dist_bures(m_s, m_t, C_s, C_t): +def dist_bures_squared(m_s, m_t, C_s, C_t): + """ + Compute the matrix of the squared Bures distances between the components of two Gaussian Mixture Models (GMMs). + + Parameters: + ----------- + m_s : array-like, shape (k_s, d) + Mean vectors of the source GMM. + m_t : array-like, shape (k_t, d) + Mean vectors of the target GMM. + C_s : array-like, shape (k_s, d, d) + Covariance matrices of the source GMM. + C_t : array-like, shape (k_t, d, d) + Covariance matrices of the target GMM. + + Returns: + -------- + dist : array-like, shape (k_s, k_t) + Matrix of squared Bures distances between the components of the source and target GMMs. + + """ nx = get_backend(m_s, C_s, m_t, C_t) k_s, k_t = m_s.shape[0], m_t.shape[0] - # TODO assert tailles - D = nx.zeros((k_s, k_t), type_as=m_s) + + assert m_s.shape[0] == C_s.shape[0], \ + "Source GMM has different amount of components" + + assert m_t.shape[0] == C_t.shape[0], \ + "Target GMM has different amount of components" + + assert m_s.shape[-1] == m_t.shape[-1] == C_s.shape[-1] == C_t.shape[-1], \ + "All GMMs must have the same dimension" + + D_means = dist(m_s, m_t, metric='sqeuclidean') + D_covs = nx.zeros((k_s, k_t), type_as=m_s) + for i in range(k_s): + Cs12 = nx.sqrtm(C_s[i]) # nx.sqrtm is not batchable for j in range(k_t): - D[i, j] = bures_wasserstein_distance(m_s[i], m_t[j], - C_s[i], C_t[j]) - D[i, j] = D[i, j] ** 2 - return D + C = nx.sqrtm(Cs12 @ C_t[j] @ Cs12) + D_covs[i, j] = nx.trace(C_s[i] + C_t[j] - 2 * C) + + return nx.maximum(D_means + D_covs, 0) def gmm_ot_loss(m_s, m_t, C_s, C_t, w_s, w_t): - r""" - Gaussian Mixture OT loss """ - get_backend(m_s, C_s, w_s, m_t, C_t, w_t) # backed test - # TODO assert taille w - D = dist_bures(m_s, m_t, C_s, 
C_t) + Compute the Gaussian Mixture Model (GMM) Optimal Transport distance between + two GMMs. + + Parameters: + ----------- + m_s : array-like, shape (k_s, d) + Mean vectors of the source GMM. + m_t : array-like, shape (k_t, d) + Mean vectors of the target GMM. + C_s : array-like, shape (k_s, d, d) + Covariance matrices of the source GMM. + C_t : array-like, shape (k_t, d, d) + Covariance matrices of the target GMM. + w_s : array-like, shape (k_s,) + Weights of the source GMM components. + w_t : array-like, shape (k_t,) + Weights of the target GMM components. + + Returns: + -------- + loss : float + The GMM-OT loss. + + """ + get_backend(m_s, C_s, w_s, m_t, C_t, w_t) + + assert m_s.shape[0] == w_s.shape[0], \ + "Source GMM has different amount of components" + + assert m_t.shape[0] == w_t.shape[0], \ + "Target GMM has different amount of components" + + D = dist_bures_squared(m_s, m_t, C_s, C_t) return emd2(w_s, w_t, D) def gmm_ot_plan(m_s, m_t, C_s, C_t, w_s, w_t): r""" - Gaussian Mixture OT loss + Compute the Gaussian Mixture Model (GMM) Optimal Transport plan between + two GMMs. + + Parameters: + ----------- + m_s : array-like, shape (k_s, d) + Mean vectors of the source GMM. + m_t : array-like, shape (k_t, d) + Mean vectors of the target GMM. + C_s : array-like, shape (k_s, d, d) + Covariance matrices of the source GMM. + C_t : array-like, shape (k_t, d, d) + Covariance matrices of the target GMM. + w_s : array-like, shape (k_s,) + Weights of the source GMM components. + w_t : array-like, shape (k_t,) + Weights of the target GMM components. + + Returns: + -------- + plan : array-like, shape (k_s, k_t) + The GMM-OT plan. 
+ """ - get_backend(m_s, C_s, w_s, m_t, C_t, w_t) # backed test - # TODO assert taille w - D = dist_bures(m_s, m_t, C_s, C_t) + get_backend(m_s, C_s, w_s, m_t, C_t, w_t) + + assert m_s.shape[0] == w_s.shape[0], \ + "Source GMM has different amount of components" + + assert m_t.shape[0] == w_t.shape[0], \ + "Target GMM has different amount of components" + + D = dist_bures_squared(m_s, m_t, C_s, C_t) return emd(w_s, w_t, D) @@ -76,16 +203,25 @@ def gmm_ot_apply_map(x, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, if method == 'bary': # TODO asserts - normalisation = gmm_pdf(x, m_s, C_s, w_s)[:, None] + normalization = gmm_pdf(x, m_s, C_s, w_s)[:, None] out = nx.zeros(x.shape) + for k0 in range(m_s.shape[0]): + Cs12 = nx.sqrtm(C_s[k0]) + Cs12inv = nx.inv(Cs12) + for k1 in range(m_t.shape[0]): g = gaussian_pdf(x, m_s[k0], C_s[k0])[:, None] - A, b = bures_wasserstein_mapping( - m_s[k0], m_t[k1], C_s[k0], C_t[k1]) + + M0 = nx.sqrtm(Cs12 @ C_t[k1] @ Cs12) + A = Cs12inv @ M0 @ Cs12inv + b = m_t[k1] - A @ m_s[k0] + + # gaussian mapping between components k0 and k1 applied to x Tk0k1x = x @ A + b out = out + plan[k0, k1] * g * Tk0k1x - return out / normalisation + + return out / normalization else: # rand raise NotImplementedError('Mapping {} not implemented'.format(method)) diff --git a/test/test_gmm.py b/test/test_gmm.py index d66a1c3e7..2061950bc 100644 --- a/test/test_gmm.py +++ b/test/test_gmm.py @@ -9,7 +9,7 @@ import numpy as np import pytest from ot.utils import proj_simplex -from ot.gmm import gaussian_pdf, gmm_pdf, dist_bures, gmm_ot_loss, gmm_ot_plan, gmm_ot_apply_map +from ot.gmm import gaussian_pdf, gmm_pdf, dist_bures_squared, gmm_ot_loss, gmm_ot_plan, gmm_ot_apply_map try: import torch @@ -39,7 +39,8 @@ def test_gaussian_pdf(): d = 3 x = rng.randn(n, d) m, _, C, _, _, _ = get_gmms() - gaussian_pdf(x, m[0], C[0]) + pdf = gaussian_pdf(x, m[0], C[0]) + assert pdf.shape == (n,) def test_gmm_pdf(): @@ -51,13 +52,26 @@ def test_gmm_pdf(): gmm_pdf(x, m_s, C_s, w_s) -def 
test_dist_bures(): +def test_dist_bures_squared(nx): m_s, m_t, C_s, C_t, _, _ = get_gmms() - dist_bures(m_s, m_t, C_s, C_t) - D0 = dist_bures(m_s, m_s, C_s, C_s) + m_s = nx.from_numpy(m_s) + m_t = nx.from_numpy(m_t) + C_s = nx.from_numpy(C_s) + C_t = nx.from_numpy(C_t) + dist_bures_squared(m_s, m_t, C_s, C_t) + D0 = dist_bures_squared(m_s, m_s, C_s, C_s) assert np.allclose(np.diag(D0), 0, atol=1e-6) + with pytest.raises(AssertionError): + dist_bures_squared(m_s[:, 1:], m_t, C_s, C_t) + + with pytest.raises(AssertionError): + dist_bures_squared(m_s[1:], m_t, C_s, C_t) + + with pytest.raises(AssertionError): + dist_bures_squared(m_s, m_t[1:], C_s, C_t) + def test_gmm_ot_loss(): m_s, m_t, C_s, C_t, w_s, w_t = get_gmms() @@ -69,6 +83,12 @@ def test_gmm_ot_loss(): assert np.allclose(loss, 0, atol=1e-6) + with pytest.raises(AssertionError): + gmm_ot_loss(m_s, m_t, C_s, C_t, w_s[1:], w_t) + + with pytest.raises(AssertionError): + gmm_ot_loss(m_s, m_t, C_s, C_t, w_s, w_t[1:]) + def test_gmm_ot_plan(): m_s, m_t, C_s, C_t, w_s, w_t = get_gmms() @@ -82,6 +102,12 @@ def test_gmm_ot_plan(): assert np.allclose(plan, np.diag(w_s), atol=1e-6) + with pytest.raises(AssertionError): + gmm_ot_loss(m_s, m_t, C_s, C_t, w_s[1:], w_t) + + with pytest.raises(AssertionError): + gmm_ot_loss(m_s, m_t, C_s, C_t, w_s, w_t[1:]) + def test_gmm_apply_map(): m_s, m_t, C_s, C_t, w_s, w_t = get_gmms() From 915c9a9e4d9025f180986b2d42da760e4641508c Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Wed, 17 Jul 2024 10:17:23 +0200 Subject: [PATCH 09/30] Trand + docs and test --- ot/gmm.py | 102 ++++++++++++++++++++++++++++++++++++++++------- test/test_gmm.py | 36 ++++++++++++----- 2 files changed, 113 insertions(+), 25 deletions(-) diff --git a/ot/gmm.py b/ot/gmm.py index c614c2dfd..40d3c1e2d 100644 --- a/ot/gmm.py +++ b/ot/gmm.py @@ -189,10 +189,49 @@ def gmm_ot_plan(m_s, m_t, C_s, C_t, w_s, w_t): def gmm_ot_apply_map(x, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, - method='bary'): - r""" - Applies the 
barycentric or stochastic map associated to the GMM OT from the - source GMM to the target GMM + method='bary', seed=None): + """ + Apply Gaussian Mixture Model (GMM) optimal transport (OT) mapping to input data. + + Parameters: + ----------- + x : array-like, shape (n_samples, d) + Input data points. + + m_s : array-like, shape (k_s, d) + Mean vectors of the source GMM components. + + m_t : array-like, shape (k_t, d) + Mean vectors of the target GMM components. + + C_s : array-like, shape (k_s, d, d) + Covariance matrices of the source GMM components. + + C_t : array-like, shape (k_t, d, d) + Covariance matrices of the target GMM components. + + w_s : array-like, shape (k_s,) + Weights of the source GMM components. + + w_t : array-like, shape (k_t,) + Weights of the target GMM components. + + plan : array-like, shape (k_s, k_t), optional + Optimal transport plan between the source and target GMM components. + If not provided, it will be computed internally. + + method : {'bary', 'rand'}, optional + Method for applying the GMM OT mapping. 'bary' uses barycentric mapping, + while 'rand' uses random sampling. Default is 'bary'. + + seed : int, optional + Seed for the random number generator. Only used when method='rand'. + + Returns: + -------- + out : array-like, shape (n_samples, d) + Output data points after applying the GMM OT mapping. 
+ """ if plan is None: @@ -201,27 +240,62 @@ def gmm_ot_apply_map(x, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, else: nx = get_backend(x, m_s, m_t, C_s, C_t, w_s, w_t, plan) + k_s, k_t = m_s.shape[0], m_t.shape[0] + d = m_s.shape[1] + n_samples = x.shape[0] + if method == 'bary': # TODO asserts normalization = gmm_pdf(x, m_s, C_s, w_s)[:, None] out = nx.zeros(x.shape) - for k0 in range(m_s.shape[0]): - Cs12 = nx.sqrtm(C_s[k0]) + for i in range(k_s): + Cs12 = nx.sqrtm(C_s[i]) Cs12inv = nx.inv(Cs12) - for k1 in range(m_t.shape[0]): - g = gaussian_pdf(x, m_s[k0], C_s[k0])[:, None] + for j in range(k_t): + g = gaussian_pdf(x, m_s[i], C_s[i])[:, None] - M0 = nx.sqrtm(Cs12 @ C_t[k1] @ Cs12) + M0 = nx.sqrtm(Cs12 @ C_t[j] @ Cs12) A = Cs12inv @ M0 @ Cs12inv - b = m_t[k1] - A @ m_s[k0] + b = m_t[j] - A @ m_s[i] - # gaussian mapping between components k0 and k1 applied to x - Tk0k1x = x @ A + b - out = out + plan[k0, k1] * g * Tk0k1x + # gaussian mapping between components i and j applied to x + T_ij_x = x @ A + b + out = out + plan[i, j] * g * T_ij_x return out / normalization else: # rand - raise NotImplementedError('Mapping {} not implemented'.format(method)) + # A[i, j] is the linear part of the gaussian mapping between components + # i and j, b[i, j] is the translation part + rng = np.random.RandomState(seed) + + A = nx.zeros((k_s, k_t, d, d)) + b = nx.zeros((k_s, k_t, d)) + + for i in range(k_s): + Cs12 = nx.sqrtm(C_s[i]) + Cs12inv = nx.inv(Cs12) + + for j in range(k_t): + M0 = nx.sqrtm(Cs12 @ C_t[j] @ Cs12) + A[i, j] = Cs12inv @ M0 @ Cs12inv + b[i, j] = m_t[j] - A[i, j] @ m_s[i] + + normalization = gmm_pdf(x, m_s, C_s, w_s) # (n_samples,) + gs = np.stack( + [gaussian_pdf(x, m_s[i], C_s[i]) for i in range(k_s)], axis=-1) + # (n_samples, k_s) + out = nx.zeros(x.shape) + + for i_sample in range(n_samples): + p_mat = plan * gs[i_sample][:, None] / normalization[i_sample] + p = p_mat.reshape(k_s * k_t) # stack line-by-line + # sample between 0 and k_s * k_t - 1 + ij_mat = 
rng.choice(k_s * k_t, p=p) + i = ij_mat // k_t + j = ij_mat % k_t + out[i_sample] = A[i, j] @ x[i_sample] + b[i, j] + + return out diff --git a/test/test_gmm.py b/test/test_gmm.py index 2061950bc..8486d7a42 100644 --- a/test/test_gmm.py +++ b/test/test_gmm.py @@ -17,7 +17,7 @@ torch = False -def get_gmms(): +def get_gmms(nx=None): rng = np.random.RandomState(seed=42) ks = 3 kt = 5 @@ -30,6 +30,13 @@ def get_gmms(): C_t = np.matmul(C_t, np.transpose(C_t, (0, 2, 1))) w_s = proj_simplex(rng.rand(ks)) w_t = proj_simplex(rng.rand(kt)) + if nx is not None: + m_s = nx.from_numpy(m_s) + m_t = nx.from_numpy(m_t) + C_s = nx.from_numpy(C_s) + C_t = nx.from_numpy(C_t) + w_s = nx.from_numpy(w_s) + w_t = nx.from_numpy(w_t) return m_s, m_t, C_s, C_t, w_s, w_t @@ -52,12 +59,10 @@ def test_gmm_pdf(): gmm_pdf(x, m_s, C_s, w_s) +@pytest.skip_backend('tf') # skips because of array assignment +@pytest.skip_backend("jax") def test_dist_bures_squared(nx): - m_s, m_t, C_s, C_t, _, _ = get_gmms() - m_s = nx.from_numpy(m_s) - m_t = nx.from_numpy(m_t) - C_s = nx.from_numpy(C_s) - C_t = nx.from_numpy(C_t) + m_s, m_t, C_s, C_t, _, _ = get_gmms(nx) dist_bures_squared(m_s, m_t, C_s, C_t) D0 = dist_bures_squared(m_s, m_s, C_s, C_s) @@ -73,8 +78,10 @@ def test_dist_bures_squared(nx): dist_bures_squared(m_s, m_t[1:], C_s, C_t) -def test_gmm_ot_loss(): - m_s, m_t, C_s, C_t, w_s, w_t = get_gmms() +@pytest.skip_backend('tf') # skips because of array assignment +@pytest.skip_backend("jax") +def test_gmm_ot_loss(nx): + m_s, m_t, C_s, C_t, w_s, w_t = get_gmms(nx) loss = gmm_ot_loss(m_s, m_t, C_s, C_t, w_s, w_t) assert loss > 0 @@ -90,8 +97,10 @@ def test_gmm_ot_loss(): gmm_ot_loss(m_s, m_t, C_s, C_t, w_s, w_t[1:]) -def test_gmm_ot_plan(): - m_s, m_t, C_s, C_t, w_s, w_t = get_gmms() +@pytest.skip_backend('tf') # skips because of array assignment +@pytest.skip_backend("jax") +def test_gmm_ot_plan(nx): + m_s, m_t, C_s, C_t, w_s, w_t = get_gmms(nx) plan = gmm_ot_plan(m_s, m_t, C_s, C_t, w_s, w_t) @@ -113,7 
+122,12 @@ def test_gmm_apply_map(): m_s, m_t, C_s, C_t, w_s, w_t = get_gmms() rng = np.random.RandomState(seed=42) x = rng.randn(7, 3) - gmm_ot_apply_map(x, m_s, m_t, C_s, C_t, w_s, w_t) + + for method in ['bary', 'rand']: + gmm_ot_apply_map(x, m_s, m_t, C_s, C_t, w_s, w_t, method=method) + + plan = gmm_ot_plan(m_s, m_t, C_s, C_t, w_s, w_t) + gmm_ot_apply_map(x, m_s, m_t, C_s, C_t, w_s, w_t, plan=plan) @pytest.mark.skipif(not torch, reason="No torch available") From 2ac5c9649042ff7881580a5459db4e87f8ec495f Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Wed, 17 Jul 2024 11:40:58 +0200 Subject: [PATCH 10/30] updated references, docs, RELEASES.md, CONTRIBUTORS.md and the README to reflect PR changes --- CONTRIBUTORS.md | 4 ++- README.md | 3 ++ RELEASES.md | 1 + docs/source/all.rst | 1 + ot/gaussian.py | 2 +- ot/gmm.py | 81 +++++++++++++++++++++++++++++++-------------- ot/utils.py | 4 +-- 7 files changed, 67 insertions(+), 29 deletions(-) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index e982cd5b6..093137e2b 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -42,7 +42,8 @@ The contributors to this library are: * [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance) * [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends) * [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning, FGW, semi-relaxed FGW, quantized FGW) -* [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein Barycenters) +* [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein + Barycenters, GMMOT) * [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug) * [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter) * [Theo Gnassounou](https://github.com/tgnassou) (OT between Gaussian distributions) @@ -51,6 +52,7 @@ The contributors to this library are: * [Xizheng Yu](https://github.com/x12hengyu) (Efficient 
Discrete Multi Marginal Optimal Transport Regularization) * [Sonia Mazelet](https://github.com/SoniaMaz8) (Template based GNN layers) * [Laurène David](https://github.com/laudavid) (Low rank sinkhorn, Low rank Gromov-Wasserstein samples) +* [Julie Delon](https://judelo.github.io/) (GMM OT) ## Acknowledgments diff --git a/README.md b/README.md index 7f2ce3ee3..da1622126 100644 --- a/README.md +++ b/README.md @@ -51,6 +51,7 @@ POT provides the following generic OT solvers (links to examples): * [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://pythonot.github.io/auto_examples/others/plot_demd_gradient_minimize.html) [50]. * [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays. * Smooth Strongly Convex Nearest Brenier Potentials [58], with an extension to bounding potentials using [59]. +* Gaussian Mixture Model OT [69] POT provides the following Machine Learning related solvers: @@ -368,3 +369,5 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil [67] Scetbon, M., Peyré, G. & Cuturi, M. (2022). [Linear-Time Gromov-Wasserstein Distances using Low Rank Couplings and Costs](https://proceedings.mlr.press/v162/scetbon22b/scetbon22b.pdf). In International Conference on Machine Learning (ICML), 2022. [68] Chowdhury, S., Miller, D., & Needham, T. (2021). [Quantized gromov-wasserstein](https://link.springer.com/chapter/10.1007/978-3-030-86523-8_49). ECML PKDD 2021. Springer International Publishing. + +[69] Delon, J., & Desolneux, A. (2020). [A Wasserstein-type distance in the space of Gaussian mixture models](https://epubs.siam.org/doi/abs/10.1137/19M1301047). SIAM Journal on Imaging Sciences, 13(2), 936-970. 
\ No newline at end of file diff --git a/RELEASES.md b/RELEASES.md index e5a8ac54b..1d423985a 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -4,6 +4,7 @@ #### New features - Add feature `mass=True` for `nx.kl_div` (PR #654) +- Gaussian Mixture Model OT `ot.gmm` (PR #655) #### Closed issues diff --git a/docs/source/all.rst b/docs/source/all.rst index 91bb36361..dca8377d8 100644 --- a/docs/source/all.rst +++ b/docs/source/all.rst @@ -22,6 +22,7 @@ API and modules dr factored gaussian + gmm gnn gromov lowrank diff --git a/ot/gaussian.py b/ot/gaussian.py index a0e7a2ae9..78cffcb42 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -727,7 +727,7 @@ def empirical_gaussian_gromov_wasserstein_mapping(xs, xt, ws=None, r"""Return Gaussian Gromov-Wasserstein mapping between samples. The function estimates the Gaussian Gromov-Wasserstein mapping between two - Gaussien distributions source :math:`\mu_s` and target :math:`\mu_t`, whose + Gaussian distributions source :math:`\mu_s` and target :math:`\mu_t`, whose parameters are estimated from the provided samples :math:`\mathcal{X}_s` and :math:`\mathcal{X}_t`. See [57] Theorem 4.1 for more details. diff --git a/ot/gmm.py b/ot/gmm.py index 40d3c1e2d..6d6ee001f 100644 --- a/ot/gmm.py +++ b/ot/gmm.py @@ -44,8 +44,8 @@ def gmm_pdf(x, m, C, w): r""" Compute the probability density function (PDF) of a Gaussian Mixture Model (GMM) at given points. - Parameters: - ----------- + Parameters + ---------- x : array-like, shape (n_samples, d) The input samples. m : array-like, shape (n_components, d) @@ -55,8 +55,8 @@ def gmm_pdf(x, m, C, w): w : array-like, shape (n_components,) The weights of the Gaussian components. - Returns: - -------- + Returns + ------- out : array-like, shape (n_components,) The PDF values at the given points. @@ -69,10 +69,12 @@ def gmm_pdf(x, m, C, w): def dist_bures_squared(m_s, m_t, C_s, C_t): """ - Compute the matrix of the squared Bures distances between the components of two Gaussian Mixture Models (GMMs). 
+ Compute the matrix of the squared Bures distances between the components of + two Gaussian Mixture Models (GMMs). Used to compute the GMM Optimal + Transport distance [69]. - Parameters: - ----------- + Parameters + ---------- m_s : array-like, shape (k_s, d) Mean vectors of the source GMM. m_t : array-like, shape (k_t, d) @@ -82,10 +84,17 @@ def dist_bures_squared(m_s, m_t, C_s, C_t): C_t : array-like, shape (k_t, d, d) Covariance matrices of the target GMM. - Returns: - -------- + Returns + ------- dist : array-like, shape (k_s, k_t) - Matrix of squared Bures distances between the components of the source and target GMMs. + Matrix of squared Bures distances between the components of the source + and target GMMs. + + References + ---------- + .. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in + the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, + 13(2), 936-970. """ nx = get_backend(m_s, C_s, m_t, C_t) @@ -115,10 +124,10 @@ def dist_bures_squared(m_s, m_t, C_s, C_t): def gmm_ot_loss(m_s, m_t, C_s, C_t, w_s, w_t): """ Compute the Gaussian Mixture Model (GMM) Optimal Transport distance between - two GMMs. + two GMMs introduced in [69]. - Parameters: - ----------- + Parameters + ---------- m_s : array-like, shape (k_s, d) Mean vectors of the source GMM. m_t : array-like, shape (k_t, d) @@ -132,11 +141,17 @@ def gmm_ot_loss(m_s, m_t, C_s, C_t, w_s, w_t): w_t : array-like, shape (k_t,) Weights of the target GMM components. - Returns: - -------- + Returns + ------- loss : float The GMM-OT loss. + References + ---------- + .. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in + the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, + 13(2), 936-970. 
+ """ get_backend(m_s, C_s, w_s, m_t, C_t, w_t) @@ -153,10 +168,10 @@ def gmm_ot_loss(m_s, m_t, C_s, C_t, w_s, w_t): def gmm_ot_plan(m_s, m_t, C_s, C_t, w_s, w_t): r""" Compute the Gaussian Mixture Model (GMM) Optimal Transport plan between - two GMMs. + two GMMs introduced in [69]. - Parameters: - ----------- + Parameters + ---------- m_s : array-like, shape (k_s, d) Mean vectors of the source GMM. m_t : array-like, shape (k_t, d) @@ -170,11 +185,17 @@ def gmm_ot_plan(m_s, m_t, C_s, C_t, w_s, w_t): w_t : array-like, shape (k_t,) Weights of the target GMM components. - Returns: - -------- + Returns + ------- plan : array-like, shape (k_s, k_t) The GMM-OT plan. + References + ---------- + .. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in + the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, + 13(2), 936-970. + """ get_backend(m_s, C_s, w_s, m_t, C_t, w_t) @@ -191,10 +212,14 @@ def gmm_ot_plan(m_s, m_t, C_s, C_t, w_s, w_t): def gmm_ot_apply_map(x, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, method='bary', seed=None): """ - Apply Gaussian Mixture Model (GMM) optimal transport (OT) mapping to input data. + Apply Gaussian Mixture Model (GMM) optimal transport (OT) mapping to input + data. The 'barycentric' mapping corresponds to the barycentric projection + of the GMM-OT plan, and is called T_bary in [69]. The 'random' mapping takes + for each input point a random pair (i,j) of components of the GMMs and + applied the Gaussian map, it is called T_rand in [69]. - Parameters: - ----------- + Parameters + ---------- x : array-like, shape (n_samples, d) Input data points. @@ -227,11 +252,17 @@ def gmm_ot_apply_map(x, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, seed : int, optional Seed for the random number generator. Only used when method='rand'. - Returns: - -------- + Returns + ------- out : array-like, shape (n_samples, d) Output data points after applying the GMM OT mapping. + References + ---------- + .. 
[69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in + the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, + 13(2), 936-970. + """ if plan is None: diff --git a/ot/utils.py b/ot/utils.py index 175af0c8f..1fb6313ca 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -1299,7 +1299,7 @@ def proj_SDP(S, nx=None, vmin=0.): Project a symmetric matrix onto the space of symmetric matrices with eigenvalues larger or equal to `vmin`. - Parameters: + Parameters ---------- S : array_like (n, d, d) The input symmetric matrix or matrices. @@ -1313,7 +1313,7 @@ def proj_SDP(S, nx=None, vmin=0.): .. note:: This function is backend-compatible and will work on arrays from all compatible backends. - Returns: + Returns ------- P : ndarray (n, d, d) The projected symmetric positive definite matrix. From 528fb6cf39b6880fe29b2d2ff8498ee47e7586cc Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Wed, 17 Jul 2024 12:01:07 +0200 Subject: [PATCH 11/30] fixed mean computations in ot.gaussian --- README.md | 2 +- RELEASES.md | 1 + ot/gaussian.py | 40 ++++++++++++++++++++-------------------- 3 files changed, 22 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index da1622126..3769080a8 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,7 @@ POT provides the following generic OT solvers (links to examples): * [Quantized (Fused) Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_quantized_gromov_wasserstein.html) [68]. * [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://pythonot.github.io/auto_examples/others/plot_demd_gradient_minimize.html) [50]. * [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays. 
-* Smooth Strongly Convex Nearest Brenier Potentials [58], with an extension to bounding potentials using [59]. +* [Smooth Strongly Convex Nearest Brenier Potentials](https://pythonot.github.io/auto_examples/others/plot_SSNB.html#sphx-glr-auto-examples-others-plot-ssnb-py) [58], with an extension to bounding potentials using [59]. * Gaussian Mixture Model OT [69] POT provides the following Machine Learning related solvers: diff --git a/RELEASES.md b/RELEASES.md index 1d423985a..3d438d31a 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -5,6 +5,7 @@ #### New features - Add feature `mass=True` for `nx.kl_div` (PR #654) - Gaussian Mixture Model OT `ot.gmm` (PR #655) +- Fixed `ot.gaussian` ignoring weights when computing means (PR #655, Issue #648) #### Closed issues diff --git a/ot/gaussian.py b/ot/gaussian.py index 78cffcb42..832d193da 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -160,9 +160,15 @@ def empirical_bures_wasserstein_mapping(xs, xt, reg=1e-6, ws=None, d = xs.shape[1] + if ws is None: + ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0] + + if wt is None: + wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0] + if bias: - mxs = nx.mean(xs, axis=0)[None, :] - mxt = nx.mean(xt, axis=0)[None, :] + mxs = nx.dot(ws.T, xs) / nx.sum(ws) + mxt = nx.dot(wt.T, xt) / nx.sum(wt) xs = xs - mxs xt = xt - mxt @@ -170,12 +176,6 @@ def empirical_bures_wasserstein_mapping(xs, xt, reg=1e-6, ws=None, mxs = nx.zeros((1, d), type_as=xs) mxt = nx.zeros((1, d), type_as=xs) - if ws is None: - ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0] - - if wt is None: - wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0] - Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + reg * nx.eye(d, type_as=xs) Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt) @@ -315,9 +315,15 @@ def empirical_bures_wasserstein_distance(xs, xt, reg=1e-6, ws=None, d = xs.shape[1] + if ws is None: + ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0] + + if wt is None: + 
wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0] + if bias: - mxs = nx.mean(xs, axis=0)[None, :] - mxt = nx.mean(xt, axis=0)[None, :] + mxs = nx.dot(ws.T, xs) / nx.sum(ws) + mxt = nx.dot(wt.T, xt) / nx.sum(wt) xs = xs - mxs xt = xt - mxt @@ -325,12 +331,6 @@ def empirical_bures_wasserstein_distance(xs, xt, reg=1e-6, ws=None, mxs = nx.zeros((1, d), type_as=xs) mxt = nx.zeros((1, d), type_as=xs) - if ws is None: - ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0] - - if wt is None: - wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0] - Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + reg * nx.eye(d, type_as=xs) Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt) @@ -503,15 +503,15 @@ def empirical_bures_wasserstein_barycenter( k = len(X) d = [X[i].shape[1] for i in range(k)] + if w is None: + w = [nx.ones((X[i].shape[0], 1), type_as=X[i]) / X[i].shape[0] for i in range(k)] + if bias: - m = [nx.mean(X[i], axis=0)[None, :] for i in range(k)] + m = [nx.dot(w[i].T, X[i]) / nx.sum(w[i]) for i in range(k)] X = [X[i] - m[i] for i in range(k)] else: m = [nx.zeros((1, d[i]), type_as=X[i]) for i in range(k)] - if w is None: - w = [nx.ones((X[i].shape[0], 1), type_as=X[i]) / X[i].shape[0] for i in range(k)] - C = [ nx.dot((X[i] * w[i]).T, X[i]) / nx.sum(w[i]) + reg * nx.eye(d[i], type_as=X[i]) for i in range(k) From 8ac2c1167edee38e0020f65b70f076f5b58b4bd3 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Wed, 17 Jul 2024 16:21:55 +0200 Subject: [PATCH 12/30] all dosctrings as raw strings --- ot/gmm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ot/gmm.py b/ot/gmm.py index 6d6ee001f..08d5f1306 100644 --- a/ot/gmm.py +++ b/ot/gmm.py @@ -68,7 +68,7 @@ def gmm_pdf(x, m, C, w): def dist_bures_squared(m_s, m_t, C_s, C_t): - """ + r""" Compute the matrix of the squared Bures distances between the components of two Gaussian Mixture Models (GMMs). Used to compute the GMM Optimal Transport distance [69]. 
@@ -122,7 +122,7 @@ def dist_bures_squared(m_s, m_t, C_s, C_t): def gmm_ot_loss(m_s, m_t, C_s, C_t, w_s, w_t): - """ + r""" Compute the Gaussian Mixture Model (GMM) Optimal Transport distance between two GMMs introduced in [69]. @@ -211,7 +211,7 @@ def gmm_ot_plan(m_s, m_t, C_s, C_t, w_s, w_t): def gmm_ot_apply_map(x, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, method='bary', seed=None): - """ + r""" Apply Gaussian Mixture Model (GMM) optimal transport (OT) mapping to input data. The 'barycentric' mapping corresponds to the barycentric projection of the GMM-OT plan, and is called T_bary in [69]. The 'random' mapping takes From 42ca0d947237008494dc72d66a570b86b175a7f1 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Wed, 17 Jul 2024 16:23:21 +0200 Subject: [PATCH 13/30] linting --- examples/others/plot_GMM_flow.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/others/plot_GMM_flow.py b/examples/others/plot_GMM_flow.py index 48cb3570a..065881326 100644 --- a/examples/others/plot_GMM_flow.py +++ b/examples/others/plot_GMM_flow.py @@ -1,5 +1,5 @@ # %% -*- coding: utf-8 -*- -""" +r""" ==================================================== GMM Flow ==================================================== @@ -22,10 +22,10 @@ import matplotlib.pylab as pl import ot import ot.plot -from ot.utils import proj_simplex, proj_SDP +from ot.utils import proj_SDP from ot.gmm import gmm_ot_loss import torch -from torch.optim import SGD, Adam +from torch.optim import Adam from matplotlib.patches import Ellipse # %% From 7a6fbb22e49a6f0c5535ab1af73fd5483d259653 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Thu, 18 Jul 2024 14:41:52 +0200 Subject: [PATCH 14/30] fixed PR number --- RELEASES.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index 3d438d31a..5ee094d77 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -4,8 +4,8 @@ #### New features - Add feature `mass=True` for `nx.kl_div` (PR #654) -- Gaussian Mixture 
Model OT `ot.gmm` (PR #655) -- Fixed `ot.gaussian` ignoring weights when computing means (PR #655, Issue #648) +- Gaussian Mixture Model OT `ot.gmm` (PR #649) +- Fixed `ot.gaussian` ignoring weights when computing means (PR #649, Issue #648) #### Closed issues From 57217c5ff10bc93510c17b7fb03a28043a0199f5 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Thu, 18 Jul 2024 15:03:09 +0200 Subject: [PATCH 15/30] nx.det WIP --- ot/backend.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/ot/backend.py b/ot/backend.py index 819b91db5..be2289a90 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1058,6 +1058,14 @@ def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None): """ raise NotImplementedError() + def det(self, a): + r""" + Compute the determinant of an array. + + See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.det.html + """ + raise NotImplementedError() + class NumpyBackend(Backend): """ @@ -1412,6 +1420,9 @@ def matmul(self, a, b): def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None): return np.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf) + def det(self, a): + return np.linalg.det(a) + _register_backend_implementation(NumpyBackend) @@ -1797,6 +1808,9 @@ def matmul(self, a, b): def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None): return jnp.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf) + def det(self, x): + return jnp.numpy.linalg.det(x) + if jax: # Only register jax backend if it is installed @@ -2290,6 +2304,9 @@ def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None): out = None if copy else x return torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf, out=out) + def det(self, x): + return torch.linalg.det(x) + if torch: # Only register torch backend if it is installed @@ -2691,6 +2708,9 @@ def matmul(self, a, b): def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None): return cp.nan_to_num(x, 
copy=copy, nan=nan, posinf=posinf, neginf=neginf) + def det(self, x): + return cp.linalg.det(x) + if cp: # Only register cp backend if it is installed @@ -3121,6 +3141,9 @@ def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None): x = np.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf) return self.from_numpy(x) + def det(self, x): + return tf.linalg.det(x) + if tf: # Only register tensorflow backend if it is installed From 3c6f4c9ce1d8c8a7ec3519ec1c0a65a3bfdac209 Mon Sep 17 00:00:00 2001 From: Julie Delon Date: Thu, 18 Jul 2024 17:11:26 +0200 Subject: [PATCH 16/30] gmm flow example --- examples/others/plot_GMM_flow.py | 94 +++++++++++++++++++++----------- ot/gmm.py | 19 ++----- 2 files changed, 67 insertions(+), 46 deletions(-) diff --git a/examples/others/plot_GMM_flow.py b/examples/others/plot_GMM_flow.py index 065881326..34cf6e28d 100644 --- a/examples/others/plot_GMM_flow.py +++ b/examples/others/plot_GMM_flow.py @@ -1,4 +1,4 @@ -# %% -*- coding: utf-8 -*- +# -*- coding: utf-8 -*- r""" ==================================================== GMM Flow @@ -20,37 +20,35 @@ import numpy as np import matplotlib.pylab as pl +from matplotlib import colormaps as cm import ot import ot.plot -from ot.utils import proj_SDP +from ot.utils import proj_SDP, proj_simplex from ot.gmm import gmm_ot_loss import torch from torch.optim import Adam from matplotlib.patches import Ellipse -# %% + ############################################################################## -# Generate data an plot it -# ------------------------ +# Generate data and plot it +# ------------------------- torch.manual_seed(3) -ks = 5 -kt = 3 +ks = 3 +kt = 2 d = 2 -m_s = torch.randn(ks, d) +eps = 0.1 +m_s = torch.randn(ks, d) m_s.requires_grad_() -m_t = torch.randn(kt, d) -# C_s = torch.eye(d)[None, :, :] * torch.ones(ks, 1, 1) +m_t = torch.randn(kt, d) C_s = torch.randn(ks, d, d) -C_s = torch.matmul(C_s, torch.transpose(C_s, 2, 1)) + torch.eye(d)[None, :, :] * torch.ones(ks, 1, 
1) +C_s = torch.matmul(C_s, torch.transpose(C_s, 2, 1)) + eps*torch.eye(d)[None, :, :] * torch.ones(ks, 1, 1) C_s.requires_grad_() C_t = torch.randn(kt, d, d) -# C_t = torch.eye(d)[None, :, :] * torch.ones(kt, 1, 1) -C_t = torch.matmul(C_t, torch.transpose(C_t, 2, 1)) + torch.eye(d)[None, :, :] * torch.ones(kt, 1, 1) -# w_s = torch.tensor(ot.unif(ks)) +C_t = torch.matmul(C_t, torch.transpose(C_t, 2, 1)) + eps*torch.eye(d)[None, :, :] * torch.ones(kt, 1, 1) w_s = torch.randn(ks) w_s.requires_grad_() w_t = torch.tensor(ot.unif(kt)) -# w_t = proj_simplex(torch.rand(kt)) def draw_cov(mu, C, color=None, label=None, nstd=1, alpha=.5): @@ -93,15 +91,15 @@ def draw_gmm(ms, Cs, ws, color=None, nstd=.5, alpha=1): pl.axis(axis) pl.title('Target GMM') -# %% ############################################################################## # Gradient descent loop # ------------------------ -n_gd_its = 300 + +n_gd_its = 100 lr = 3e-2 opt = Adam([{'params': m_s, 'lr': 2 * lr}, - {'params': C_s, 'lr': lr}]) -# {'params': w_s, 'lr': lr}]) + {'params': C_s, 'lr': lr}, + {'params': w_s, 'lr': lr}]) m_list = [m_s.data.numpy().copy()] C_list = [C_s.data.numpy().copy()] w_list = [torch.softmax(w_s, 0).data.numpy().copy()] @@ -114,9 +112,8 @@ def draw_gmm(ms, Cs, ws, color=None, nstd=.5, alpha=1): loss.backward() opt.step() with torch.no_grad(): - # w_s.data = proj_simplex(w_s.data) - for k in range(m_s.shape[0]): - C_s.data[k] = proj_SDP(C_s.data[k], vmin=1e-6) + w_s.data = proj_simplex(w_s.data) + C_s.data = proj_SDP(C_s.data, vmin=1e-6) m_list.append(m_s.data.numpy().copy()) C_list.append(C_s.data.numpy().copy()) w_list.append(torch.softmax(w_s, 0).data.numpy().copy()) @@ -129,30 +126,61 @@ def draw_gmm(ms, Cs, ws, color=None, nstd=.5, alpha=1): pl.xlabel('its') pl.ylabel('loss') -# %% + ############################################################################## -# Steps visualisation +# Last step visualisation # ------------------------ -n_steps_visu = 30 axis = [-3, 3, -3, 3] 
pl.figure(3, (10, 10)) pl.clf() - -pl.scatter(m_list[0][:, 0], m_list[0][:, 1], color='C0') +pl.title('GMM flow, last step') +pl.scatter(m_list[0][:, 0], m_list[0][:, 1], color='C0', label='Source') draw_gmm(m_list[0], C_list[0], w_list[0], color='C0') pl.axis(axis) -pl.scatter(m_t[:, 0].detach(), m_t[:, 1].detach(), color='C1') +pl.scatter(m_t[:, 0].detach(), m_t[:, 1].detach(), color='C1', label='Target') draw_gmm(m_t.detach(), C_t.detach(), w_t.numpy(), color='C1') pl.axis(axis) -# [int(x) for x in np.linspace(1, n_gd_its - 1, n_steps_visu)] -its_to_show = [-1] +k = -1 +pl.scatter(m_list[k][:, 0], m_list[k][:, 1], color='C2', alpha=1, label='Last step') +draw_gmm(m_list[k], C_list[k], w_list[0], color='C2', alpha=1) + +pl.axis(axis) +pl.legend(fontsize=15) + + +############################################################################## +# Steps visualisation +# ------------------------ +def index_to_color(i): + return int(i**0.5) + +n_steps_visu = 100 +pl.figure(3, (10, 10)) +pl.clf() +pl.title('GMM flow, all steps') + +its_to_show = [int(x) for x in np.linspace(1, n_gd_its - 1, n_steps_visu)] +cmp = cm['plasma'].resampled(index_to_color(n_steps_visu)) + +pl.scatter(m_list[0][:, 0], m_list[0][:, 1], + color=cmp(index_to_color(0)), label='Source') +draw_gmm(m_list[0], C_list[0], w_list[0], + color=cmp(index_to_color(0))) + +pl.scatter(m_t[:, 0].detach(), m_t[:, 1].detach(), + color=cmp(index_to_color(n_steps_visu - 1)), label='Target') +draw_gmm(m_t.detach(), C_t.detach(), w_t.numpy(), + color=cmp(index_to_color(n_steps_visu - 1))) + for k in its_to_show: - pl.scatter(m_list[k][:, 0], m_list[k][:, 1], color='C2') - draw_gmm(m_list[k], C_list[k], w_list[0], color='C2', alpha=1) + pl.scatter(m_list[k][:, 0], m_list[k][:, 1], + color=cmp(index_to_color(k)), alpha=0.8) + draw_gmm(m_list[k], C_list[k], w_list[0], + color=cmp(index_to_color(k)), alpha=0.04) pl.axis(axis) -# %% +pl.legend(fontsize=15) diff --git a/ot/gmm.py b/ot/gmm.py index 08d5f1306..dd0c0abc6 
100644 --- a/ot/gmm.py +++ b/ot/gmm.py @@ -42,7 +42,8 @@ def gaussian_pdf(x, m, C): def gmm_pdf(x, m, C, w): r""" - Compute the probability density function (PDF) of a Gaussian Mixture Model (GMM) at given points. + Compute the probability density function (PDF) of a Gaussian Mixture Model (GMM) + at given points. Parameters ---------- @@ -92,9 +93,7 @@ def dist_bures_squared(m_s, m_t, C_s, C_t): References ---------- - .. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in - the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, - 13(2), 936-970. + .. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970. """ nx = get_backend(m_s, C_s, m_t, C_t) @@ -148,9 +147,7 @@ def gmm_ot_loss(m_s, m_t, C_s, C_t, w_s, w_t): References ---------- - .. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in - the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, - 13(2), 936-970. + .. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970. """ get_backend(m_s, C_s, w_s, m_t, C_t, w_t) @@ -192,9 +189,7 @@ def gmm_ot_plan(m_s, m_t, C_s, C_t, w_s, w_t): References ---------- - .. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in - the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, - 13(2), 936-970. + .. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970. """ get_backend(m_s, C_s, w_s, m_t, C_t, w_t) @@ -259,9 +254,7 @@ def gmm_ot_apply_map(x, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, References ---------- - .. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in - the space of Gaussian mixture models. 
SIAM Journal on Imaging Sciences, - 13(2), 936-970. + .. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970. """ From 5cf6bada7d906cbbaa494d3abe5c5ab7ac094bc0 Mon Sep 17 00:00:00 2001 From: Julie Delon Date: Thu, 18 Jul 2024 18:22:57 +0200 Subject: [PATCH 17/30] plot_GMMOT_plan unfinished --- examples/others/plot_GMMOT_plan.py | 70 ++++++++++++++++++++++++++++++ examples/others/plot_GMM_flow.py | 1 + ot/gmm.py | 37 ++++++++++++++++ 3 files changed, 108 insertions(+) create mode 100644 examples/others/plot_GMMOT_plan.py diff --git a/examples/others/plot_GMMOT_plan.py b/examples/others/plot_GMMOT_plan.py new file mode 100644 index 000000000..da850ddf1 --- /dev/null +++ b/examples/others/plot_GMMOT_plan.py @@ -0,0 +1,70 @@ +# %% +# -*- coding: utf-8 -*- +r""" +==================================================== +GMM Plan 1D +==================================================== + +Illustration of the GMM plan for +the Mixture Wasserstein between two GMM in 1D. 
+ +""" + +# Author: Eloi Tanguy +# Remi Flamary +# Julie Delon +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 4 + +import numpy as np +import matplotlib.pylab as pl +from matplotlib import colormaps as cm +import ot +import ot.plot +from ot.utils import proj_SDP, proj_simplex +from ot.gmm import gmm_ot_loss, gmm_ot_plan_density, gmm_ot_plan, gmm_pdf + +# %% +############################################################################## +# Generate data and plot it +# ------------------------- +np.random.seed(3) +ks = 3 +kt = 2 +d = 1 +eps = 0.1 +m_s = np.random.rand(ks, d) +m_t = np.random.rand(kt, d) +C_s = np.random.randn(ks, d, d)*0.1 +C_s = np.matmul(C_s, np.transpose(C_s, (0, 2, 1))) +C_t = np.random.randn(kt, d, d)*0.1 +C_t = np.matmul(C_t, np.transpose(C_t, (0, 2, 1))) +w_s = ot.unif(ks) +w_t = ot.unif(kt) + +axis = [-3, 3, -3, 3] +pl.figure(1, (20, 10)) +pl.clf() + + +# %% +############################################################################## +# Compute plan +# ------------ + +n = 100 +x = np.linspace(0, 1, n) +y = np.linspace(0, 1, n) +xx, yy = np.meshgrid(x, y) +xx = xx.reshape((n**2, 1)) +yy = yy.reshape((n**2, 1)) +plan = gmm_ot_plan_density(xx, yy, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, atol=0.1) + +a = gmm_pdf(x[:,None], m_s, C_s, w_s) +b = gmm_pdf(y[:,None], m_t, C_t, w_t) +plan = plan.reshape((n,n)) +ot.plot.plot1D_mat(a, b, plan, title='Plan between two GMM') + +# %% diff --git a/examples/others/plot_GMM_flow.py b/examples/others/plot_GMM_flow.py index 34cf6e28d..09a5f2977 100644 --- a/examples/others/plot_GMM_flow.py +++ b/examples/others/plot_GMM_flow.py @@ -47,6 +47,7 @@ C_t = torch.randn(kt, d, d) C_t = torch.matmul(C_t, torch.transpose(C_t, 2, 1)) + eps*torch.eye(d)[None, :, :] * torch.ones(kt, 1, 1) w_s = torch.randn(ks) +w_s = proj_simplex(w_s) w_s.requires_grad_() w_t = torch.tensor(ot.unif(kt)) diff --git a/ot/gmm.py b/ot/gmm.py index dd0c0abc6..a1ea3a1e4 100644 --- a/ot/gmm.py +++ b/ot/gmm.py @@ -13,6 +13,7 
@@ from .lp import emd2, emd import numpy as np from .lp import dist +from .gaussian import bures_wasserstein_mapping def gaussian_pdf(x, m, C): @@ -323,3 +324,39 @@ def gmm_ot_apply_map(x, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, out[i_sample] = A[i, j] @ x[i_sample] + b[i, j] return out + + + + + +def gmm_ot_plan_density(x, y, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, atol=1e-8): + r""" + Args: + m0: gaussian mixture 0 + m1: gaussian mixture 1 + x: (..., d) array-like + y: (..., d) array-like (same shape as x) + atol: absolute tolerance for the condition T_kl(x) = y + + Returns: + density of the MW2 OT plan between m0 and m1 at (x, y) + """ + + if plan is None: + plan = gmm_ot_plan(m_s, m_t, C_s, C_t, w_s, w_t) + + def Tk0k1(k0, k1): + A, b = bures_wasserstein_mapping(m_s[k0], m_t[k1], C_s[k0], C_t[k1]) + Tx = x @ A + b + g = gaussian_pdf(x, m_s[k0], C_s[k0]) + out = plan[k0, k1] * g + norms = np.linalg.norm(Tx - y, axis=-1) + out[norms > atol] = 0 + return out + + mat = np.array( + [ + [Tk0k1(k0, k1) for k1 in range(m_t.shape[0])] + for k0 in range(m_s.shape[0]) + ]) + return np.sum(mat, axis=(0, 1)) \ No newline at end of file From 62be757bef6c3dc765dc8b399b94346e588e57fe Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Fri, 19 Jul 2024 14:37:33 +0200 Subject: [PATCH 18/30] ot.plot.plot1D_mat update + finished GMM example --- examples/others/plot_GMMOT_plan.py | 88 +++++++++++-------- ot/gmm.py | 41 +++++---- ot/plot.py | 134 ++++++++++++++++++++++++----- test/test_gmm.py | 29 +++++-- 4 files changed, 209 insertions(+), 83 deletions(-) diff --git a/examples/others/plot_GMMOT_plan.py b/examples/others/plot_GMMOT_plan.py index da850ddf1..3b432d032 100644 --- a/examples/others/plot_GMMOT_plan.py +++ b/examples/others/plot_GMMOT_plan.py @@ -5,8 +5,14 @@ GMM Plan 1D ==================================================== -Illustration of the GMM plan for -the Mixture Wasserstein between two GMM in 1D. 
+Illustration of the GMM plan for +the Mixture Wasserstein between two GMM in 1D, +as well as the two maps T_mean and T_rand. +T_mean is the barycentric projection of the GMM coupling, +and T_rand takes a random gaussian image between two components, +according to the coupling and the GMMs. +See [69] for details. +.. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970. """ @@ -20,51 +26,61 @@ import numpy as np import matplotlib.pylab as pl -from matplotlib import colormaps as cm +from matplotlib import colormaps as cm import ot -import ot.plot +from ot.plot import plot1D_mat, rescale_for_imshow_plot from ot.utils import proj_SDP, proj_simplex -from ot.gmm import gmm_ot_loss, gmm_ot_plan_density, gmm_ot_plan, gmm_pdf +from ot.gmm import gmm_ot_plan_density, gmm_pdf, gmm_ot_apply_map -# %% ############################################################################## -# Generate data and plot it +# Generate GMMOT plan plot it # ------------------------- -np.random.seed(3) -ks = 3 -kt = 2 +ks = 2 +kt = 3 d = 1 eps = 0.1 -m_s = np.random.rand(ks, d) -m_t = np.random.rand(kt, d) -C_s = np.random.randn(ks, d, d)*0.1 -C_s = np.matmul(C_s, np.transpose(C_s, (0, 2, 1))) -C_t = np.random.randn(kt, d, d)*0.1 -C_t = np.matmul(C_t, np.transpose(C_t, (0, 2, 1))) -w_s = ot.unif(ks) -w_t = ot.unif(kt) +m_s = np.array([[1], [2]]) +m_t = np.array([[3], [4.2], [5]]) +C_s = np.array([[[.05]], [[.06]]]) +C_t = np.array([[[.03]], [[.07]], [[.04]]]) +w_s = np.array([.4, .6]) +w_t = np.array([.4, .2, .4]) -axis = [-3, 3, -3, 3] -pl.figure(1, (20, 10)) -pl.clf() +n = 500 +a_x, b_x = 0, 3 +x = np.linspace(a_x, b_x, n) +a_y, b_y = 2, 6 +y = np.linspace(a_y, b_y, n) +xx, yy = np.meshgrid(x, y, indexing='ij') +plan_density = gmm_ot_plan_density(xx[:, :, None], yy[:, :, None], + m_s, m_t, C_s, C_t, w_s, w_t, + plan=None, atol=2e-2) + +a = gmm_pdf(x[:, None], m_s, C_s, w_s) +b = 
gmm_pdf(y[:, None], m_t, C_t, w_t) +plot1D_mat(a, b, plan_density, title='GMM OT plan') -# %% ############################################################################## -# Compute plan -# ------------ +# Generate GMMOT maps and plot them over plan +# ------------------------- +ax_s, ax_t, ax_M = plot1D_mat(a, b, plan_density, + title='GMM OT plan with T_mean and T_rand maps') +T_mean = gmm_ot_apply_map(x[:, None], m_s, m_t, C_s, C_t, + w_s, w_t, method='bary')[:, 0] +x_rescaled, T_mean_rescaled = rescale_for_imshow_plot(x, T_mean, n, + a_y=a_y, b_y=b_y) -n = 100 -x = np.linspace(0, 1, n) -y = np.linspace(0, 1, n) -xx, yy = np.meshgrid(x, y) -xx = xx.reshape((n**2, 1)) -yy = yy.reshape((n**2, 1)) -plan = gmm_ot_plan_density(xx, yy, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, atol=0.1) +ax_M.plot(x_rescaled, T_mean_rescaled, label='T_mean', alpha=.5, + linewidth=5, color='aqua') -a = gmm_pdf(x[:,None], m_s, C_s, w_s) -b = gmm_pdf(y[:,None], m_t, C_t, w_t) -plan = plan.reshape((n,n)) -ot.plot.plot1D_mat(a, b, plan, title='Plan between two GMM') +T_rand = gmm_ot_apply_map(x[:, None], m_s, m_t, C_s, C_t, + w_s, w_t, method='rand', seed=0)[:, 0] +x_rescaled, T_rand_rescaled = rescale_for_imshow_plot(x, T_rand, n, + a_y=a_y, b_y=b_y) + +ax_M.scatter(x_rescaled, T_rand_rescaled, label='T_rand', alpha=.5, + s=20, color='orange') + +ax_M.legend(loc='upper left', fontsize=13) -# %% diff --git a/ot/gmm.py b/ot/gmm.py index a1ea3a1e4..e9b31ff13 100644 --- a/ot/gmm.py +++ b/ot/gmm.py @@ -18,11 +18,12 @@ def gaussian_pdf(x, m, C): r""" - Compute the probability density function of a multivariate Gaussian distribution. + Compute the probability density function of a multivariate + Gaussian distribution. Parameters ---------- - x : array-like, shape (n_samples, d) + x : array-like, shape (..., d) The input samples. m : array-like, shape (d,) The mean vector of the Gaussian distribution. 
@@ -31,24 +32,27 @@ def gaussian_pdf(x, m, C): Returns ------- - pdf : array-like, shape (n_samples,) + pdf : array-like, shape (...,) The probability density function evaluated at each sample. """ - _, d = x.shape - z = (2 * np.pi) ** (-d / 2) * np.linalg.det(C) ** (-0.5) - exp = np.exp(-0.5 * np.sum((x - m) @ np.linalg.inv(C) * (x - m), axis=1)) + assert x.shape[-1] == m.shape[-1] == C.shape[-1] == C.shape[-2], \ + "Dimension mismatch" + nx = get_backend(x, m, C) + d = x.shape[-1] + z = (2 * np.pi) ** (-d / 2) * nx.det(C) ** (-0.5) + exp = nx.exp(-0.5 * nx.sum(((x - m) @ nx.inv(C)) * (x - m), axis=-1)) return z * exp def gmm_pdf(x, m, C, w): r""" - Compute the probability density function (PDF) of a Gaussian Mixture Model (GMM) - at given points. + Compute the probability density function (PDF) of a + Gaussian Mixture Model (GMM) at given points. Parameters ---------- - x : array-like, shape (n_samples, d) + x : array-like, shape (..., d) The input samples. m : array-like, shape (n_components, d) The means of the Gaussian components. @@ -59,11 +63,14 @@ def gmm_pdf(x, m, C, w): Returns ------- - out : array-like, shape (n_components,) + out : array-like, shape (...,) The PDF values at the given points. 
""" - out = np.zeros((x.shape[0])) + assert m.shape[0] == C.shape[0] == w.shape[0], \ + "All GMM parameters must have the same amount of components" + nx = get_backend(x, m, C, w) + out = nx.zeros((x.shape[:-1])) for k in range(m.shape[0]): out = out + w[k] * gaussian_pdf(x, m[k], C[k]) return out @@ -270,7 +277,6 @@ def gmm_ot_apply_map(x, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, n_samples = x.shape[0] if method == 'bary': - # TODO asserts normalization = gmm_pdf(x, m_s, C_s, w_s)[:, None] out = nx.zeros(x.shape) @@ -326,10 +332,8 @@ def gmm_ot_apply_map(x, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, return out - - - -def gmm_ot_plan_density(x, y, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, atol=1e-8): +def gmm_ot_plan_density(x, y, m_s, m_t, C_s, C_t, w_s, w_t, + plan=None, atol=1e-2): r""" Args: m0: gaussian mixture 0 @@ -341,13 +345,14 @@ def gmm_ot_plan_density(x, y, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, atol=1e-8 Returns: density of the MW2 OT plan between m0 and m1 at (x, y) """ - + if plan is None: plan = gmm_ot_plan(m_s, m_t, C_s, C_t, w_s, w_t) def Tk0k1(k0, k1): A, b = bures_wasserstein_mapping(m_s[k0], m_t[k1], C_s[k0], C_t[k1]) Tx = x @ A + b + print('Tx', Tx.shape) g = gaussian_pdf(x, m_s[k0], C_s[k0]) out = plan[k0, k1] * g norms = np.linalg.norm(Tx - y, axis=-1) @@ -359,4 +364,4 @@ def Tk0k1(k0, k1): [Tk0k1(k0, k1) for k1 in range(m_t.shape[0])] for k0 in range(m_s.shape[0]) ]) - return np.sum(mat, axis=(0, 1)) \ No newline at end of file + return np.sum(mat, axis=(0, 1)) diff --git a/ot/plot.py b/ot/plot.py index 4b1bfb128..555dfce7e 100644 --- a/ot/plot.py +++ b/ot/plot.py @@ -14,14 +14,16 @@ import numpy as np import matplotlib.pylab as pl -from matplotlib import gridspec -def plot1D_mat(a, b, M, title=''): - r""" Plot matrix :math:`\mathbf{M}` with the source and target 1D distribution +def plot1D_mat(a, b, M, title='', a_label='Source distribution', + b_label='Target distribution', color_source='#7ED321', + color_target='#4A90E2', 
coupling_cmap='gray'): + r""" Plot matrix :math:`\mathbf{M}` with the source and target 1D distributions. - Creates a subplot with the source distribution :math:`\mathbf{a}` on the left and - target distribution :math:`\mathbf{b}` on the top. The matrix :math:`\mathbf{M}` is shown in between. + Creates a subplot with the source distribution :math:`\mathbf{a}` on the + bottom and target distribution :math:`\mathbf{b}` on the left. + The matrix :math:`\mathbf{M}` is shown in between. Parameters @@ -32,32 +34,120 @@ def plot1D_mat(a, b, M, title=''): Target distribution M : ndarray, shape (na, nb) Matrix to plot + a_label: str, optional + Label for source distribution + b_label: str, optional + Label for target distribution + title: str, optional + Title of the plot + color_source: str, optional + Color of the source distribution + color_target: str, optional + Color of the target distribution + coupling_cmap: str, optional + Colormap for the coupling matrix + + Returns + ------- + ax1: source plot ax + ax2: target plot ax + ax3: coupling plot ax + + .. 
seealso:: + :func:`rescale_for_imshow_plot` """ na, nb = M.shape - gs = gridspec.GridSpec(3, 3) + fig = pl.figure(figsize=(8, 8)) + gs = fig.add_gridspec(3, 3, height_ratios=[1, 1, 1], + width_ratios=[1, 1, 1], + hspace=0, wspace=0) xa = np.arange(na) xb = np.arange(nb) - ax1 = pl.subplot(gs[0, 1:]) - pl.plot(xb, b, 'r', label='Target distribution') - pl.yticks(()) - pl.title(title) + # horizontal source on the bottom, flipped vertically + ax1 = fig.add_subplot(gs[2, 1:]) + ax1.plot(xa, np.max(a) - a, color=color_source, label=a_label, linewidth=2) + ax1.fill_between(xa, np.max(a) - a, np.max(a) * np.ones_like(a), + color=color_source, alpha=.5) + ax1.set_xticks(()) + ax1.set_yticks(()) + ax1.set_title(a_label, y=-.15) + ax1.spines['top'].set_visible(False) + ax1.spines['right'].set_visible(False) + ax1.spines['bottom'].set_visible(False) + ax1.spines['left'].set_visible(False) + + # vertical target on the left + ax2 = fig.add_subplot(gs[0:2, 0]) + ax2.plot(b, xb, color=color_target, label=b_label, linewidth=2) + ax2.fill_between(b, xb, color=color_target, alpha=.5) + ax2.invert_xaxis() + ax2.invert_yaxis() + ax2.set_xticks(()) + ax2.set_yticks(()) + ax2.set_title(b_label) + ax2.spines['top'].set_visible(False) + ax2.spines['right'].set_visible(False) + ax2.spines['bottom'].set_visible(False) + ax2.spines['left'].set_visible(False) + + # plan image, transposed since imshow is in "yx" coords + ax3 = fig.add_subplot(gs[0:2, 1:], sharey=ax2, sharex=ax1) + ax3.imshow(1 - M.T, interpolation='nearest', origin='lower', + cmap=coupling_cmap) + ax3.set_aspect('equal') + ax3.set_title(title) + + # Set spines visibility to True and customize if desired + ax3.spines['top'].set_visible(True) + ax3.spines['right'].set_visible(True) + ax3.spines['bottom'].set_visible(True) + ax3.spines['left'].set_visible(True) + + pl.subplots_adjust(hspace=0, wspace=0) + return ax1, ax2, ax3 + + +def rescale_for_imshow_plot(x, y, n, a_y=None, b_y=None): + r""" + Gives arrays xr, yr that can be 
plotted over an (n, n) + imshow plot (in 'xy' coordinates). If `a_y` or `b_y` is provided, + y is sliced over its indices such that y stays in [ay, by]. - ax2 = pl.subplot(gs[1:, 0]) - pl.plot(a, xa, 'b', label='Source distribution') - pl.gca().invert_xaxis() - pl.gca().invert_yaxis() - pl.xticks(()) - - pl.subplot(gs[1:, 1:], sharex=ax1, sharey=ax2) - pl.imshow(M, interpolation='nearest') - pl.axis('off') + Parameters + ---------- + x : ndarray, shape (nx,) + y : ndarray, shape (ny,) + n : int + Size of the imshow plot on which to plot (x, y) + a_y : float, optional + Lower bound for y + b_y : float, optional + Upper bound for y + + Returns + ------- + xr : ndarray, shape (nx,) + Rescaled x values + yr : ndarray, shape (ny,) + Rescaled y values (due to slicing, may have less elements than y) + + .. seealso:: + :func:`plot1D_mat` - pl.xlim((0, nb)) - pl.tight_layout() - pl.subplots_adjust(wspace=0., hspace=0.2) + """ + # slice over the y values that are in the y range + a_x, b_x = np.min(x), np.max(x) + if a_y is None: + a_y = np.min(y) + if b_y is None: + b_y = np.max(y) + idx = (y >= a_y) & (y <= b_y) + x_rescaled = (x[idx] - a_x) * (n - 1) / (b_x - a_x) + y_rescaled = (y[idx] - a_y) * (n - 1) / (b_y - a_y) + return x_rescaled, y_rescaled def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs): diff --git a/test/test_gmm.py b/test/test_gmm.py index 8486d7a42..f0cfa30dc 100644 --- a/test/test_gmm.py +++ b/test/test_gmm.py @@ -40,23 +40,38 @@ def get_gmms(nx=None): return m_s, m_t, C_s, C_t, w_s, w_t -def test_gaussian_pdf(): +def test_gaussian_pdf(nx): rng = np.random.RandomState(seed=42) n = 7 d = 3 - x = rng.randn(n, d) - m, _, C, _, _, _ = get_gmms() + x = nx.from_numpy(rng.randn(n, d)) + m, _, C, _, _, _ = get_gmms(nx) pdf = gaussian_pdf(x, m[0], C[0]) assert pdf.shape == (n,) + x = nx.from_numpy(rng.randn(n, n, d)) + pdf = gaussian_pdf(x, m[0], C[0]) + assert pdf.shape == (n, n,) + + with pytest.raises(AssertionError): + gaussian_pdf(x, m[0, :-1], C[0]) + -def 
test_gmm_pdf(): +def test_gmm_pdf(nx): rng = np.random.RandomState(seed=42) n = 7 d = 3 - x = rng.randn(n, d) - m_s, _, C_s, _, w_s, _ = get_gmms() - gmm_pdf(x, m_s, C_s, w_s) + x = nx.from_numpy(rng.randn(n, d)) + m, _, C, _, w, _ = get_gmms(nx) + pdf = gmm_pdf(x, m, C, w) + assert pdf.shape == (n,) + + x = nx.from_numpy(rng.randn(n, n, d)) + pdf = gmm_pdf(x, m, C, w) + assert pdf.shape == (n, n,) + + with pytest.raises(AssertionError): + gmm_pdf(x, m[:-1], C, w) @pytest.skip_backend('tf') # skips because of array assignment From 807211ac2b3a7cd11cddda0b52e06a3bd9da61d8 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Mon, 22 Jul 2024 11:30:29 +0200 Subject: [PATCH 19/30] finished GMM plan example + nx.det test + ot.plot.plot1D_mat update and test --- examples/others/plot_GMMOT_plan.py | 8 +------- test/test_backend.py | 6 ++++++ test/test_plot.py | 22 ++++++++++++++++++++++ 3 files changed, 29 insertions(+), 7 deletions(-) diff --git a/examples/others/plot_GMMOT_plan.py b/examples/others/plot_GMMOT_plan.py index 3b432d032..dc0307cae 100644 --- a/examples/others/plot_GMMOT_plan.py +++ b/examples/others/plot_GMMOT_plan.py @@ -1,4 +1,3 @@ -# %% # -*- coding: utf-8 -*- r""" ==================================================== @@ -22,14 +21,10 @@ # # License: MIT License -# sphinx_gallery_thumbnail_number = 4 +# sphinx_gallery_thumbnail_number = 1 import numpy as np -import matplotlib.pylab as pl -from matplotlib import colormaps as cm -import ot from ot.plot import plot1D_mat, rescale_for_imshow_plot -from ot.utils import proj_SDP, proj_simplex from ot.gmm import gmm_ot_plan_density, gmm_pdf, gmm_ot_apply_map ############################################################################## @@ -83,4 +78,3 @@ s=20, color='orange') ax_M.legend(loc='upper left', fontsize=13) - diff --git a/test/test_backend.py b/test/test_backend.py index da7293821..4be6673d9 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -274,6 +274,8 @@ def test_empty_backend(): 
nx.assert_same_dtype_device(M, M) with pytest.raises(NotImplementedError): nx.eigh(M) + with pytest.raises(NotImplementedError): + nx.det(M) def test_func_backends(nx): @@ -683,6 +685,10 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(vec)) lst_name.append("nan_to_num") + d = nx.det(M1b) + lst_b.append(nx.to_numpy(d)) + lst_name.append("det") + assert not nx.array_equal(Mb, vb), "array_equal (shape)" assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true" assert not nx.array_equal( diff --git a/test/test_plot.py b/test/test_plot.py index e07977952..eae832364 100644 --- a/test/test_plot.py +++ b/test/test_plot.py @@ -38,6 +38,27 @@ def test_plot1D_mat(): ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') +@pytest.mark.skipif(nogo, reason="Matplotlib not installed") +def test_rescale_for_imshow_plot(): + + import ot + import ot.plot + + n = 7 + a_x, b_x = -1, 3 + x = np.linspace(a_x, b_x, n) + a_y, b_y = 2, 6 + y = np.linspace(a_y, b_y, n) + + x_rescaled, y_rescaled = ot.plot.rescale_for_imshow_plot(x, y, n) + assert x_rescaled.shape == (n, ) + assert y_rescaled.shape == (n, ) + + x_rescaled, y_rescaled = ot.plot.rescale_for_imshow_plot(x, y, n, a_y=a_y + 1, b_y=b_y - 1) + assert x_rescaled.shape[0] <= n + assert y_rescaled.shape[0] <= n + + @pytest.mark.skipif(nogo, reason="Matplotlib not installed") def test_plot2D_samples_mat(): @@ -59,3 +80,4 @@ def test_plot2D_samples_mat(): G = 1.0 * (rng.rand(n_bins, n_bins) < 0.01) ot.plot.plot2D_samples_mat(xs, xt, G, thr=1e-5) + ot.plot.plot2D_samples_mat(xs, xt, G, thr=1e-5, alpha=0.5) From 1a0f06d065ee7b5077d09853e5e22aee2702f375 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Mon, 22 Jul 2024 11:41:54 +0200 Subject: [PATCH 20/30] linting --- examples/others/plot_GMM_flow.py | 25 ++++++++++++++----------- ot/gmm.py | 6 +++--- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/examples/others/plot_GMM_flow.py b/examples/others/plot_GMM_flow.py index 09a5f2977..d1852cc9e 100644 --- 
a/examples/others/plot_GMM_flow.py +++ b/examples/others/plot_GMM_flow.py @@ -20,7 +20,7 @@ import numpy as np import matplotlib.pylab as pl -from matplotlib import colormaps as cm +from matplotlib import colormaps as cm import ot import ot.plot from ot.utils import proj_SDP, proj_simplex @@ -38,14 +38,16 @@ kt = 2 d = 2 eps = 0.1 -m_s = torch.randn(ks, d) +m_s = torch.randn(ks, d) m_s.requires_grad_() -m_t = torch.randn(kt, d) +m_t = torch.randn(kt, d) C_s = torch.randn(ks, d, d) -C_s = torch.matmul(C_s, torch.transpose(C_s, 2, 1)) + eps*torch.eye(d)[None, :, :] * torch.ones(ks, 1, 1) +C_s = torch.matmul(C_s, torch.transpose(C_s, 2, 1)) +C_s += eps * torch.eye(d)[None, :, :] * torch.ones(ks, 1, 1) C_s.requires_grad_() C_t = torch.randn(kt, d, d) -C_t = torch.matmul(C_t, torch.transpose(C_t, 2, 1)) + eps*torch.eye(d)[None, :, :] * torch.ones(kt, 1, 1) +C_t = torch.matmul(C_t, torch.transpose(C_t, 2, 1)) +C_t += eps * torch.eye(d)[None, :, :] * torch.ones(kt, 1, 1) w_s = torch.randn(ks) w_s = proj_simplex(w_s) w_s.requires_grad_() @@ -158,6 +160,7 @@ def draw_gmm(ms, Cs, ws, color=None, nstd=.5, alpha=1): def index_to_color(i): return int(i**0.5) + n_steps_visu = 100 pl.figure(3, (10, 10)) pl.clf() @@ -166,21 +169,21 @@ def index_to_color(i): its_to_show = [int(x) for x in np.linspace(1, n_gd_its - 1, n_steps_visu)] cmp = cm['plasma'].resampled(index_to_color(n_steps_visu)) -pl.scatter(m_list[0][:, 0], m_list[0][:, 1], +pl.scatter(m_list[0][:, 0], m_list[0][:, 1], color=cmp(index_to_color(0)), label='Source') -draw_gmm(m_list[0], C_list[0], w_list[0], +draw_gmm(m_list[0], C_list[0], w_list[0], color=cmp(index_to_color(0))) -pl.scatter(m_t[:, 0].detach(), m_t[:, 1].detach(), +pl.scatter(m_t[:, 0].detach(), m_t[:, 1].detach(), color=cmp(index_to_color(n_steps_visu - 1)), label='Target') -draw_gmm(m_t.detach(), C_t.detach(), w_t.numpy(), +draw_gmm(m_t.detach(), C_t.detach(), w_t.numpy(), color=cmp(index_to_color(n_steps_visu - 1))) for k in its_to_show: - 
pl.scatter(m_list[k][:, 0], m_list[k][:, 1], + pl.scatter(m_list[k][:, 0], m_list[k][:, 1], color=cmp(index_to_color(k)), alpha=0.8) - draw_gmm(m_list[k], C_list[k], w_list[0], + draw_gmm(m_list[k], C_list[k], w_list[0], color=cmp(index_to_color(k)), alpha=0.04) pl.axis(axis) diff --git a/ot/gmm.py b/ot/gmm.py index e9b31ff13..74b297b2d 100644 --- a/ot/gmm.py +++ b/ot/gmm.py @@ -18,7 +18,7 @@ def gaussian_pdf(x, m, C): r""" - Compute the probability density function of a multivariate + Compute the probability density function of a multivariate Gaussian distribution. Parameters @@ -47,7 +47,7 @@ def gaussian_pdf(x, m, C): def gmm_pdf(x, m, C, w): r""" - Compute the probability density function (PDF) of a + Compute the probability density function (PDF) of a Gaussian Mixture Model (GMM) at given points. Parameters @@ -332,7 +332,7 @@ def gmm_ot_apply_map(x, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, return out -def gmm_ot_plan_density(x, y, m_s, m_t, C_s, C_t, w_s, w_t, +def gmm_ot_plan_density(x, y, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, atol=1e-2): r""" Args: From 8ca96044d7b6da18d08b8275935004eecda259e0 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Mon, 22 Jul 2024 12:00:56 +0200 Subject: [PATCH 21/30] minor PR suggestions + jax det fix --- ot/backend.py | 2 +- ot/gmm.py | 22 +++++++++++++++++----- ot/utils.py | 4 ++-- 3 files changed, 20 insertions(+), 8 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index be2289a90..464542947 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1809,7 +1809,7 @@ def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None): return jnp.nan_to_num(x, copy=copy, nan=nan, posinf=posinf, neginf=neginf) def det(self, x): - return jnp.numpy.linalg.det(x) + return jnp.linalg.det(x) if jax: diff --git a/ot/gmm.py b/ot/gmm.py index 74b297b2d..8ece24781 100644 --- a/ot/gmm.py +++ b/ot/gmm.py @@ -128,7 +128,7 @@ def dist_bures_squared(m_s, m_t, C_s, C_t): return nx.maximum(D_means + D_covs, 0) -def gmm_ot_loss(m_s, m_t, 
C_s, C_t, w_s, w_t): +def gmm_ot_loss(m_s, m_t, C_s, C_t, w_s, w_t, log=False): r""" Compute the Gaussian Mixture Model (GMM) Optimal Transport distance between two GMMs introduced in [69]. @@ -147,11 +147,17 @@ def gmm_ot_loss(m_s, m_t, C_s, C_t, w_s, w_t): Weights of the source GMM components. w_t : array-like, shape (k_t,) Weights of the target GMM components. + log: bool, optional (default=False) + If True, returns a dictionary containing the cost and dual variables. + Otherwise returns only the GMM optimal transportation cost. Returns ------- - loss : float + loss : float or array-like The GMM-OT loss. + log : dict, optional + If input log is true, a dictionary containing the + cost and dual variables and exit status References ---------- @@ -167,10 +173,10 @@ def gmm_ot_loss(m_s, m_t, C_s, C_t, w_s, w_t): "Target GMM has different amount of components" D = dist_bures_squared(m_s, m_t, C_s, C_t) - return emd2(w_s, w_t, D) + return emd2(w_s, w_t, D, log=log) -def gmm_ot_plan(m_s, m_t, C_s, C_t, w_s, w_t): +def gmm_ot_plan(m_s, m_t, C_s, C_t, w_s, w_t, log=False): r""" Compute the Gaussian Mixture Model (GMM) Optimal Transport plan between two GMMs introduced in [69]. @@ -189,11 +195,17 @@ def gmm_ot_plan(m_s, m_t, C_s, C_t, w_s, w_t): Weights of the source GMM components. w_t : array-like, shape (k_t,) Weights of the target GMM components. + log : bool, optional (default=False) + If True, returns a dictionary containing the cost and dual variables. + Otherwise returns only the GMM optimal transportation matrix. Returns ------- plan : array-like, shape (k_s, k_t) The GMM-OT plan. 
+ log : dict, optional + If input log is true, a dictionary containing the + cost and dual variables and exit status References ---------- @@ -209,7 +221,7 @@ def gmm_ot_plan(m_s, m_t, C_s, C_t, w_s, w_t): "Target GMM has different amount of components" D = dist_bures_squared(m_s, m_t, C_s, C_t) - return emd(w_s, w_t, D) + return emd(w_s, w_t, D, log=log) def gmm_ot_apply_map(x, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, diff --git a/ot/utils.py b/ot/utils.py index 1fb6313ca..2ba541ea2 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -1301,7 +1301,7 @@ def proj_SDP(S, nx=None, vmin=0.): Parameters ---------- - S : array_like (n, d, d) + S : array_like (n, d, d) or (d, d) The input symmetric matrix or matrices. nx : module, optional The numerical backend module to use. If not provided, the backend will @@ -1315,7 +1315,7 @@ def proj_SDP(S, nx=None, vmin=0.): Returns ------- - P : ndarray (n, d, d) + P : ndarray (n, d, d) or (d, d) The projected symmetric positive definite matrix. """ From 060bbc893c596e55cec92194ccba3180937de085 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Tue, 23 Jul 2024 11:17:37 +0200 Subject: [PATCH 22/30] batched nx.sqrtm --- ot/backend.py | 47 +++++++++++++++++++++++++++++++++++++------- test/test_backend.py | 4 ++++ 2 files changed, 44 insertions(+), 7 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index 464542947..0b065bfe1 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -926,9 +926,11 @@ def inv(self, a): def sqrtm(self, a): r""" - Computes the matrix square root. Requires input to be definite positive. + Computes the matrix square root. + Requires input symmetric positive semi-definite. - This function follows the api from :any:`scipy.linalg.sqrtm`. + This function follows the api from :any:`scipy.linalg.sqrtm`, + allowing batches. 
See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.sqrtm.html """ @@ -1356,7 +1358,13 @@ def inv(self, a): def sqrtm(self, a): L, V = np.linalg.eigh(a) - return (V * np.sqrt(L)[None, :]) @ V.T + L = np.sqrt(L) + if len(a.shape) == 2: # input was (d, d) + return (V * L[None, :]) @ V.T + else: # input was (n, d, d) + Q = np.einsum('ijk,ik->ijk', V, L) # Q[i] = V[i] @ diag(L[i]) + # R[i] = Q[i] @ V[i].T + return np.einsum('ijk,ikl->ijl', Q, np.transpose(V, (0, 2, 1))) def eigh(self, a): return np.linalg.eigh(a) @@ -1761,7 +1769,13 @@ def inv(self, a): def sqrtm(self, a): L, V = jnp.linalg.eigh(a) - return (V * jnp.sqrt(L)[None, :]) @ V.T + L = jnp.sqrt(L) + if len(a.shape) == 2: # input was (d, d) + return (V * L[None, :]) @ V.T + else: # input was (n, d, d) + Q = jnp.einsum('ijk,ik->ijk', V, L) # Q[i] = V[i] @ diag(L[i]) + # R[i] = Q[i] @ V[i].T + return jnp.einsum('ijk,ikl->ijl', Q, jnp.transpose(V, (0, 2, 1))) def eigh(self, a): return jnp.linalg.eigh(a) @@ -2254,7 +2268,14 @@ def inv(self, a): def sqrtm(self, a): L, V = torch.linalg.eigh(a) - return (V * torch.sqrt(L)[None, :]) @ V.T + L = torch.sqrt(L) + if len(a.shape) == 2: # input was (d, d) + return (V * L[None, :]) @ V.T + else: # input was (n, d, d) + Q = torch.einsum('ijk,ik->ijk', V, L) # Q[i] = V[i] @ diag(L[i]) + # R[i] = Q[i] @ V[i].T + return torch.einsum('ijk,ikl->ijl', Q, + torch.transpose(V, 1, 2)) def eigh(self, a): return torch.linalg.eigh(a) @@ -2661,7 +2682,13 @@ def inv(self, a): def sqrtm(self, a): L, V = cp.linalg.eigh(a) - return (V * cp.sqrt(L)[None, :]) @ V.T + L = cp.sqrt(L) + if len(a.shape) == 2: # input was (d, d) + return (V * L[None, :]) @ V.T + else: # input was (n, d, d) + Q = cp.einsum('ijk,ik->ijk', V, L) # Q[i] = V[i] @ diag(L[i]) + # R[i] = Q[i] @ V[i].T + return cp.einsum('ijk,ikl->ijl', Q, cp.transpose(V, perm=[0, 2, 1])) def eigh(self, a): return cp.linalg.eigh(a) @@ -3091,7 +3118,13 @@ def inv(self, a): def sqrtm(self, a): L, V = tf.linalg.eigh(a) - 
return (V * tf.sqrt(L)[None, :]) @ V.T + L = tf.sqrt(L) + if len(a.shape) == 2: # input was (d, d) + return (V * L[None, :]) @ V.T + else: # input was (n, d, d) + Q = tf.einsum('ijk,ik->ijk', V, L) # Q[i] = V[i] @ diag(L[i]) + # R[i] = Q[i] @ V[i].T + return tf.einsum('ijk,ikl->ijl', Q, tf.transpose(V, (0, 2, 1))) def eigh(self, a): return tf.linalg.eigh(a) diff --git a/test/test_backend.py b/test/test_backend.py index 4be6673d9..d34d56da4 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -622,6 +622,10 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append("matrix square root") + A = nx.sqrtm(nx.stack([SquareMb.T @ SquareMb] * 2, axis=0)) + lst_b.append(nx.to_numpy(A)) + lst_name.append("broadcast matrix square root") + D, U = nx.eigh(SquareMb.T @ SquareMb) lst_b.append(nx.to_numpy(nx.dot(U, nx.dot(nx.diag(D), U.T)))) lst_name.append("eigh ") From 72bd04d1d56bdb02f16a3b2d3136d772b9cdf8dc Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Tue, 23 Jul 2024 14:16:53 +0200 Subject: [PATCH 23/30] any-broadcastable sqrtm --- ot/backend.py | 54 +++++++++++++++++++------------------------- test/test_backend.py | 2 +- 2 files changed, 24 insertions(+), 32 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index 0b065bfe1..75cf450d8 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1359,12 +1359,10 @@ def inv(self, a): def sqrtm(self, a): L, V = np.linalg.eigh(a) L = np.sqrt(L) - if len(a.shape) == 2: # input was (d, d) - return (V * L[None, :]) @ V.T - else: # input was (n, d, d) - Q = np.einsum('ijk,ik->ijk', V, L) # Q[i] = V[i] @ diag(L[i]) - # R[i] = Q[i] @ V[i].T - return np.einsum('ijk,ikl->ijl', Q, np.transpose(V, (0, 2, 1))) + # Q[...] = V[...] @ diag(L[...]) + Q = np.einsum('...jk,...k->...jk', V, L) + # R[...] = Q[...] 
@ V[...].T + return np.einsum('...jk,...kl->...jl', Q, np.swapaxes(V, -1, -2)) def eigh(self, a): return np.linalg.eigh(a) @@ -1770,12 +1768,10 @@ def inv(self, a): def sqrtm(self, a): L, V = jnp.linalg.eigh(a) L = jnp.sqrt(L) - if len(a.shape) == 2: # input was (d, d) - return (V * L[None, :]) @ V.T - else: # input was (n, d, d) - Q = jnp.einsum('ijk,ik->ijk', V, L) # Q[i] = V[i] @ diag(L[i]) - # R[i] = Q[i] @ V[i].T - return jnp.einsum('ijk,ikl->ijl', Q, jnp.transpose(V, (0, 2, 1))) + # Q[...] = V[...] @ diag(L[...]) + Q = jnp.einsum('...jk,...k->...jk', V, L) + # R[...] = Q[...] @ V[...].T + return jnp.einsum('...jk,...kl->...jl', Q, jnp.swapaxes(V, -1, -2)) def eigh(self, a): return jnp.linalg.eigh(a) @@ -2269,13 +2265,11 @@ def inv(self, a): def sqrtm(self, a): L, V = torch.linalg.eigh(a) L = torch.sqrt(L) - if len(a.shape) == 2: # input was (d, d) - return (V * L[None, :]) @ V.T - else: # input was (n, d, d) - Q = torch.einsum('ijk,ik->ijk', V, L) # Q[i] = V[i] @ diag(L[i]) - # R[i] = Q[i] @ V[i].T - return torch.einsum('ijk,ikl->ijl', Q, - torch.transpose(V, 1, 2)) + # Q[...] = V[...] @ diag(L[...]) + Q = torch.einsum('...jk,...k->...jk', V, L) + # R[...] = Q[...] @ V[...].T + return torch.einsum('...jk,...kl->...jl', Q, + torch.transpose(V, -1, -2)) def eigh(self, a): return torch.linalg.eigh(a) @@ -2683,12 +2677,11 @@ def inv(self, a): def sqrtm(self, a): L, V = cp.linalg.eigh(a) L = cp.sqrt(L) - if len(a.shape) == 2: # input was (d, d) - return (V * L[None, :]) @ V.T - else: # input was (n, d, d) - Q = cp.einsum('ijk,ik->ijk', V, L) # Q[i] = V[i] @ diag(L[i]) - # R[i] = Q[i] @ V[i].T - return cp.einsum('ijk,ikl->ijl', Q, cp.transpose(V, perm=[0, 2, 1])) + # Q[...] = V[...] @ diag(L[...]) + Q = cp.einsum('...jk,...k->...jk', V, L) + # R[...] = Q[...] 
@ V[...].T + return cp.einsum('...jk,...kl->...jl', Q, + cp.swapaxes(V, -1, -2)) def eigh(self, a): return cp.linalg.eigh(a) @@ -3119,12 +3112,11 @@ def inv(self, a): def sqrtm(self, a): L, V = tf.linalg.eigh(a) L = tf.sqrt(L) - if len(a.shape) == 2: # input was (d, d) - return (V * L[None, :]) @ V.T - else: # input was (n, d, d) - Q = tf.einsum('ijk,ik->ijk', V, L) # Q[i] = V[i] @ diag(L[i]) - # R[i] = Q[i] @ V[i].T - return tf.einsum('ijk,ikl->ijl', Q, tf.transpose(V, (0, 2, 1))) + # Q[...] = V[...] @ diag(L[...]) + Q = tf.einsum('...jk,...k->...jk', V, L) + # R[...] = Q[...] @ V[...].T + return tf.einsum('...jk,...kl->...jl', Q, + tf.linalg.matrix_transpose(V, (0, 2, 1))) def eigh(self, a): return tf.linalg.eigh(a) diff --git a/test/test_backend.py b/test/test_backend.py index d34d56da4..95ec3293f 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -622,7 +622,7 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append("matrix square root") - A = nx.sqrtm(nx.stack([SquareMb.T @ SquareMb] * 2, axis=0)) + A = nx.sqrtm(nx.stack([SquareMb.T @ SquareMb] * 2, axis=0))[None, :] lst_b.append(nx.to_numpy(A)) lst_name.append("broadcast matrix square root") From e2681b2ea65f7c82004b51877a4230e0bd153f4a Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Tue, 23 Jul 2024 15:20:35 +0200 Subject: [PATCH 24/30] einsum and optimisations in ot.gmm --- RELEASES.md | 1 + ot/gmm.py | 102 ++++++++++++++++++++++++++++++---------------------- 2 files changed, 61 insertions(+), 42 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index a2221134b..cf95e489a 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -10,6 +10,7 @@ #659) - Improved `ot.plot.plot1D_mat` (PR #649) - Added `nx.det` (PR #649) +- `nx.sqrtm` is now broadcastable (takes ..., d, d) inputs (PR #649) #### Closed issues - Fixed `ot.gaussian` ignoring weights when computing means (PR #649, Issue #648) diff --git a/ot/gmm.py b/ot/gmm.py index 8ece24781..2ba990c74 100644 --- a/ot/gmm.py +++ b/ot/gmm.py 
@@ -105,7 +105,6 @@ def dist_bures_squared(m_s, m_t, C_s, C_t): """ nx = get_backend(m_s, C_s, m_t, C_t) - k_s, k_t = m_s.shape[0], m_t.shape[0] assert m_s.shape[0] == C_s.shape[0], \ "Source GMM has different amount of components" @@ -117,13 +116,17 @@ def dist_bures_squared(m_s, m_t, C_s, C_t): "All GMMs must have the same dimension" D_means = dist(m_s, m_t, metric='sqeuclidean') - D_covs = nx.zeros((k_s, k_t), type_as=m_s) - for i in range(k_s): - Cs12 = nx.sqrtm(C_s[i]) # nx.sqrtm is not batchable - for j in range(k_t): - C = nx.sqrtm(Cs12 @ C_t[j] @ Cs12) - D_covs[i, j] = nx.trace(C_s[i] + C_t[j] - 2 * C) + # C2[i, j] = Cs12[i] @ C_t[j] @ Cs12[i], shape (k_s, k_t, d, d) + Cs12 = nx.sqrtm(C_s) # broadcasts matrix sqrt over (k_s,) + C2 = nx.einsum('ikl,jlm,imn->ijkn', Cs12, C_t, Cs12) + C = nx.sqrtm(C2) # broadcasts matrix sqrt over (k_s, k_t) + + # D_covs[i,j] = trace(C_s[i] + C_t[j] - 2C[i,j]) + trace_C_s = nx.einsum('ikk->i', C_s)[:, None] # (k_s, 1) + trace_C_t = nx.einsum('ikk->i', C_t)[None, :] # (1, k_t) + D_covs = trace_C_s + trace_C_t # broadcasts to (k_s, k_t) + D_covs -= 2 * nx.einsum('ijkk->ij', C) return nx.maximum(D_means + D_covs, 0) @@ -237,33 +240,24 @@ def gmm_ot_apply_map(x, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, ---------- x : array-like, shape (n_samples, d) Input data points. - m_s : array-like, shape (k_s, d) Mean vectors of the source GMM components. - m_t : array-like, shape (k_t, d) Mean vectors of the target GMM components. - C_s : array-like, shape (k_s, d, d) Covariance matrices of the source GMM components. - C_t : array-like, shape (k_t, d, d) Covariance matrices of the target GMM components. - w_s : array-like, shape (k_s,) Weights of the source GMM components. - w_t : array-like, shape (k_t,) Weights of the target GMM components. - plan : array-like, shape (k_s, k_t), optional Optimal transport plan between the source and target GMM components. If not provided, it will be computed internally. 
- method : {'bary', 'rand'}, optional Method for applying the GMM OT mapping. 'bary' uses barycentric mapping, while 'rand' uses random sampling. Default is 'bary'. - seed : int, optional Seed for the random number generator. Only used when method='rand'. @@ -291,21 +285,21 @@ def gmm_ot_apply_map(x, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, if method == 'bary': normalization = gmm_pdf(x, m_s, C_s, w_s)[:, None] out = nx.zeros(x.shape) + print('where plan > 0', nx.where(plan > 0)) - for i in range(k_s): + # only need to compute for non-zero plan entries + for (i, j) in zip(*nx.where(plan > 0)): Cs12 = nx.sqrtm(C_s[i]) Cs12inv = nx.inv(Cs12) + g = gaussian_pdf(x, m_s[i], C_s[i])[:, None] - for j in range(k_t): - g = gaussian_pdf(x, m_s[i], C_s[i])[:, None] - - M0 = nx.sqrtm(Cs12 @ C_t[j] @ Cs12) - A = Cs12inv @ M0 @ Cs12inv - b = m_t[j] - A @ m_s[i] + M0 = nx.sqrtm(Cs12 @ C_t[j] @ Cs12) + A = Cs12inv @ M0 @ Cs12inv + b = m_t[j] - A @ m_s[i] - # gaussian mapping between components i and j applied to x - T_ij_x = x @ A + b - out = out + plan[i, j] * g * T_ij_x + # gaussian mapping between components i and j applied to x + T_ij_x = x @ A + b + out = out + plan[i, j] * g * T_ij_x return out / normalization @@ -317,14 +311,14 @@ def gmm_ot_apply_map(x, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, A = nx.zeros((k_s, k_t, d, d)) b = nx.zeros((k_s, k_t, d)) - for i in range(k_s): + # only need to compute for non-zero plan entries + for (i, j) in zip(*nx.where(plan > 0)): Cs12 = nx.sqrtm(C_s[i]) Cs12inv = nx.inv(Cs12) - for j in range(k_t): - M0 = nx.sqrtm(Cs12 @ C_t[j] @ Cs12) - A[i, j] = Cs12inv @ M0 @ Cs12inv - b[i, j] = m_t[j] - A[i, j] @ m_s[i] + M0 = nx.sqrtm(Cs12 @ C_t[j] @ Cs12) + A[i, j] = Cs12inv @ M0 @ Cs12inv + b[i, j] = m_t[j] - A[i, j] @ m_s[i] normalization = gmm_pdf(x, m_s, C_s, w_s) # (n_samples,) gs = np.stack( @@ -346,25 +340,49 @@ def gmm_ot_apply_map(x, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, def gmm_ot_plan_density(x, y, m_s, m_t, C_s, C_t, w_s, w_t, 
plan=None, atol=1e-2): - r""" - Args: - m0: gaussian mixture 0 - m1: gaussian mixture 1 - x: (..., d) array-like - y: (..., d) array-like (same shape as x) - atol: absolute tolerance for the condition T_kl(x) = y - - Returns: - density of the MW2 OT plan between m0 and m1 at (x, y) """ + Compute the density of the Gaussian Mixture Model - Optimal Transport + coupling between GMMS at given points. + Given two arrays of points x and y, the function computes the density at + each point `(x[i], y[i])` of the product space. + + Parameters: + ----------- + x : array-like, shape (n_samples, d) + Entry points in source space for density computation. + y : array-like, shape (n_samples, d) + Entry points in target space for density computation. + m_s : array-like, shape (k_s, d) + The means of the source GMM components. + m_t : array-like, shape (k_t, d) + The means of the target GMM components. + C_s : array-like, shape (k_s, d, d) + The covariance matrices of the source GMM components. + C_t : array-like, shape (k_t, d, d) + The covariance matrices of the target GMM components. + w_s : array-like, shape (k_s,) + The weights of the source GMM components. + w_t : array-like, shape (k_t,) + The weights of the target GMM components. + plan : array-like, shape (k_s, k_t), optional + The optimal transport plan between the source and target GMMs. + If not provided, it will be computed using `gmm_ot_plan`. + atol : float, optional + The absolute tolerance used to determine the support of the GMM-OT + coupling. + + Returns: + -------- + density : array-like, shape (n_samples,) + The density of the GMM-ot coupling between the two GMMs. 
+ """ if plan is None: plan = gmm_ot_plan(m_s, m_t, C_s, C_t, w_s, w_t) def Tk0k1(k0, k1): A, b = bures_wasserstein_mapping(m_s[k0], m_t[k1], C_s[k0], C_t[k1]) Tx = x @ A + b - print('Tx', Tx.shape) g = gaussian_pdf(x, m_s[k0], C_s[k0]) out = plan[k0, k1] * g norms = np.linalg.norm(Tx - y, axis=-1) From edbf561a1e0a632d34ebbaf48edd6884853d8fb7 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Tue, 23 Jul 2024 15:30:48 +0200 Subject: [PATCH 25/30] doc tweak --- ot/gmm.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/ot/gmm.py b/ot/gmm.py index 2ba990c74..f45b9d0b5 100644 --- a/ot/gmm.py +++ b/ot/gmm.py @@ -342,12 +342,12 @@ def gmm_ot_plan_density(x, y, m_s, m_t, C_s, C_t, w_s, w_t, plan=None, atol=1e-2): """ Compute the density of the Gaussian Mixture Model - Optimal Transport - coupling between GMMS at given points. + coupling between GMMS at given points, as introduced in [69]. Given two arrays of points x and y, the function computes the density at each point `(x[i], y[i])` of the product space. - Parameters: - ----------- + Parameters + ---------- x : array-like, shape (n_samples, d) Entry points in source space for density computation. y : array-like, shape (n_samples, d) @@ -371,11 +371,15 @@ def gmm_ot_plan_density(x, y, m_s, m_t, C_s, C_t, w_s, w_t, The absolute tolerance used to determine the support of the GMM-OT coupling. - Returns: - -------- + Returns + ------- density : array-like, shape (n_samples,) The density of the GMM-ot coupling between the two GMMs. + References + ---------- + .. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970. 
+ """ if plan is None: plan = gmm_ot_plan(m_s, m_t, C_s, C_t, w_s, w_t) From 68d4324cbc7522f16e25338f683e791c3514d0e8 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Tue, 23 Jul 2024 15:31:34 +0200 Subject: [PATCH 26/30] linting fix --- ot/gmm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ot/gmm.py b/ot/gmm.py index f45b9d0b5..cb411aef4 100644 --- a/ot/gmm.py +++ b/ot/gmm.py @@ -368,7 +368,7 @@ def gmm_ot_plan_density(x, y, m_s, m_t, C_s, C_t, w_s, w_t, The optimal transport plan between the source and target GMMs. If not provided, it will be computed using `gmm_ot_plan`. atol : float, optional - The absolute tolerance used to determine the support of the GMM-OT + The absolute tolerance used to determine the support of the GMM-OT coupling. Returns From 6a7705029b7741463bcb3fec29cb6428f17ddce1 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Tue, 23 Jul 2024 15:53:06 +0200 Subject: [PATCH 27/30] gmm ot plan density test --- test/test_gmm.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/test/test_gmm.py b/test/test_gmm.py index f0cfa30dc..d9db899a8 100644 --- a/test/test_gmm.py +++ b/test/test_gmm.py @@ -9,7 +9,7 @@ import numpy as np import pytest from ot.utils import proj_simplex -from ot.gmm import gaussian_pdf, gmm_pdf, dist_bures_squared, gmm_ot_loss, gmm_ot_plan, gmm_ot_apply_map +from ot.gmm import gaussian_pdf, gmm_pdf, dist_bures_squared, gmm_ot_loss, gmm_ot_plan, gmm_ot_apply_map, gmm_ot_plan_density try: import torch @@ -162,3 +162,16 @@ def test_gradient_gmm_ot_loss_pytorch(): assert (grad_m_s**2).sum().item() > 0 assert (grad_C_s**2).sum().item() > 0 assert (grad_w_s**2).sum().item() > 0 + + +def test_gmm_ot_plan_density(): + m_s, m_t, C_s, C_t, w_s, w_t = get_gmms() + rng = np.random.RandomState(seed=42) + n = 7 + x = rng.randn(n, 3) + + density = gmm_ot_plan_density(x, x, m_s, m_t, C_s, C_t, w_s, w_t) + assert density.shape == (n,) + + plan = gmm_ot_plan(m_s, m_t, C_s, C_t, w_s, w_t) + 
gmm_ot_plan_density(x, x, m_s, m_t, C_s, C_t, w_s, w_t, plan=plan) From d6b647480d13ce31ca73e5cc17b054744e404baa Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Tue, 30 Jul 2024 15:19:28 +0200 Subject: [PATCH 28/30] plot1D_mat tweaks for similarity with old version by default + gmm_density backend + gmm_density does the meshgrid stuff --- examples/others/plot_GMMOT_plan.py | 21 ++-- examples/plot_OT_1D.py | 1 + ot/gmm.py | 29 +++-- ot/plot.py | 181 ++++++++++++++++++----------- test/test_gmm.py | 14 ++- test/test_plot.py | 10 +- 6 files changed, 161 insertions(+), 95 deletions(-) diff --git a/examples/others/plot_GMMOT_plan.py b/examples/others/plot_GMMOT_plan.py index dc0307cae..6aa7db808 100644 --- a/examples/others/plot_GMMOT_plan.py +++ b/examples/others/plot_GMMOT_plan.py @@ -26,10 +26,11 @@ import numpy as np from ot.plot import plot1D_mat, rescale_for_imshow_plot from ot.gmm import gmm_ot_plan_density, gmm_pdf, gmm_ot_apply_map +import matplotlib.pyplot as plt ############################################################################## # Generate GMMOT plan plot it -# ------------------------- +# --------------------------- ks = 2 kt = 3 d = 1 @@ -46,21 +47,25 @@ x = np.linspace(a_x, b_x, n) a_y, b_y = 2, 6 y = np.linspace(a_y, b_y, n) -xx, yy = np.meshgrid(x, y, indexing='ij') -plan_density = gmm_ot_plan_density(xx[:, :, None], yy[:, :, None], +plan_density = gmm_ot_plan_density(x[:, None], y[:, None], m_s, m_t, C_s, C_t, w_s, w_t, plan=None, atol=2e-2) a = gmm_pdf(x[:, None], m_s, C_s, w_s) b = gmm_pdf(y[:, None], m_t, C_t, w_t) -plot1D_mat(a, b, plan_density, title='GMM OT plan') +plt.figure(figsize=(8, 8)) +plot1D_mat(a, b, plan_density, title='GMM OT plan', plot_style='xy', + a_label='Source distribution', b_label='Target distribution') ############################################################################## # Generate GMMOT maps and plot them over plan -# ------------------------- -ax_s, ax_t, ax_M = plot1D_mat(a, b, plan_density, - title='GMM 
OT plan with T_mean and T_rand maps') +# ------------------------------------------- +plt.figure(figsize=(8, 8)) +ax_s, ax_t, ax_M = plot1D_mat(a, b, plan_density, plot_style='xy', + title='GMM OT plan with T_mean and T_rand maps', + a_label='Source distribution', + b_label='Target distribution') T_mean = gmm_ot_apply_map(x[:, None], m_s, m_t, C_s, C_t, w_s, w_t, method='bary')[:, 0] x_rescaled, T_mean_rescaled = rescale_for_imshow_plot(x, T_mean, n, @@ -78,3 +83,5 @@ s=20, color='orange') ax_M.legend(loc='upper left', fontsize=13) + +# %% diff --git a/examples/plot_OT_1D.py b/examples/plot_OT_1D.py index 62f0b7dc1..22ddd14e5 100644 --- a/examples/plot_OT_1D.py +++ b/examples/plot_OT_1D.py @@ -1,3 +1,4 @@ +# %% # -*- coding: utf-8 -*- """ ====================================== diff --git a/ot/gmm.py b/ot/gmm.py index cb411aef4..e5c45ebf5 100644 --- a/ot/gmm.py +++ b/ot/gmm.py @@ -348,9 +348,9 @@ def gmm_ot_plan_density(x, y, m_s, m_t, C_s, C_t, w_s, w_t, Parameters ---------- - x : array-like, shape (n_samples, d) + x : array-like, shape (n, d) Entry points in source space for density computation. - y : array-like, shape (n_samples, d) + y : array-like, shape (m, d) Entry points in target space for density computation. m_s : array-like, shape (k_s, d) The means of the source GMM components. @@ -373,29 +373,38 @@ def gmm_ot_plan_density(x, y, m_s, m_t, C_s, C_t, w_s, w_t, Returns ------- - density : array-like, shape (n_samples,) - The density of the GMM-ot coupling between the two GMMs. + density : array-like, shape (n, m) + The density of the GMM-OT coupling between the two GMMs. References ---------- .. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970. 
""" + assert x.shape[-1] == y.shape[-1], \ + "x (n, d) and y (m, d) must have the same dimension d" + n, m = x.shape[0], y.shape[0] + nx = get_backend(x, y, m_s, m_t, C_s, C_t, w_s, w_t) + + # hand-made d-variate meshgrid in ij indexing + xx = x[:, None, :] * nx.ones((1, m, 1)) # shapes (n, m, d) + yy = y[None, :, :] * nx.ones((n, 1, 1)) # shapes (n, m, d) + if plan is None: plan = gmm_ot_plan(m_s, m_t, C_s, C_t, w_s, w_t) def Tk0k1(k0, k1): A, b = bures_wasserstein_mapping(m_s[k0], m_t[k1], C_s[k0], C_t[k1]) - Tx = x @ A + b - g = gaussian_pdf(x, m_s[k0], C_s[k0]) + Tx = xx @ A + b + g = gaussian_pdf(xx, m_s[k0], C_s[k0]) out = plan[k0, k1] * g - norms = np.linalg.norm(Tx - y, axis=-1) + norms = nx.norm(Tx - yy, axis=-1) out[norms > atol] = 0 return out - mat = np.array( + mat = nx.stack( [ - [Tk0k1(k0, k1) for k1 in range(m_t.shape[0])] + nx.stack([Tk0k1(k0, k1) for k1 in range(m_t.shape[0])]) for k0 in range(m_s.shape[0]) ]) - return np.sum(mat, axis=(0, 1)) + return nx.sum(mat, axis=(0, 1)) diff --git a/ot/plot.py b/ot/plot.py index 555dfce7e..70995633e 100644 --- a/ot/plot.py +++ b/ot/plot.py @@ -14,18 +14,21 @@ import numpy as np import matplotlib.pylab as pl +from matplotlib import gridspec -def plot1D_mat(a, b, M, title='', a_label='Source distribution', - b_label='Target distribution', color_source='#7ED321', - color_target='#4A90E2', coupling_cmap='gray'): +def plot1D_mat(a, b, M, title='', plot_style='yx', + a_label='', b_label='', color_source='b', + color_target='r', coupling_cmap='gray_r'): r""" Plot matrix :math:`\mathbf{M}` with the source and target 1D distributions. - Creates a subplot with the source distribution :math:`\mathbf{a}` on the - bottom and target distribution :math:`\mathbf{b}` on the left. + Creates a subplot with the source distribution :math:`\mathbf{a}` and target + distribution :math:`\mathbf{b}`t. 
+ In 'yx' mode (default), the source is on the left and + the target on the top, and in 'xy' mode, source on the bottom (upside + down) and the target on the left. The matrix :math:`\mathbf{M}` is shown in between. - Parameters ---------- a : ndarray, shape (na,) @@ -34,94 +37,127 @@ def plot1D_mat(a, b, M, title='', a_label='Source distribution', Target distribution M : ndarray, shape (na, nb) Matrix to plot - a_label: str, optional + title : str, optional + Title of the plot + plot_style : str, optional + Style of the plot, 'yx' or 'xy'. 'yx' places the source on the left and + the target on the top, 'xy' places the source on the bottom (upside + down) and the target on the left. + a_label : str, optional Label for source distribution - b_label: str, optional + b_label : str, optional Label for target distribution - title: str, optional - Title of the plot - color_source: str, optional + color_source : str, optional Color of the source distribution - color_target: str, optional + color_target : str, optional Color of the target distribution - coupling_cmap: str, optional + coupling_cmap : str, optional Colormap for the coupling matrix Returns ------- - ax1: source plot ax - ax2: target plot ax - ax3: coupling plot ax + ax1 : source plot ax + ax2 : target plot ax + ax3 : coupling plot ax .. 
seealso:: :func:`rescale_for_imshow_plot` """ + assert plot_style in ['yx', 'xy'], "plot_style should be 'yx' or 'xy'" na, nb = M.shape - fig = pl.figure(figsize=(8, 8)) - gs = fig.add_gridspec(3, 3, height_ratios=[1, 1, 1], - width_ratios=[1, 1, 1], - hspace=0, wspace=0) + gs = gridspec.GridSpec(3, 3, height_ratios=[1, 1, 1], + width_ratios=[1, 1, 1], + hspace=0, wspace=0) xa = np.arange(na) xb = np.arange(nb) - # horizontal source on the bottom, flipped vertically - ax1 = fig.add_subplot(gs[2, 1:]) - ax1.plot(xa, np.max(a) - a, color=color_source, label=a_label, linewidth=2) - ax1.fill_between(xa, np.max(a) - a, np.max(a) * np.ones_like(a), - color=color_source, alpha=.5) - ax1.set_xticks(()) - ax1.set_yticks(()) - ax1.set_title(a_label, y=-.15) - ax1.spines['top'].set_visible(False) - ax1.spines['right'].set_visible(False) - ax1.spines['bottom'].set_visible(False) - ax1.spines['left'].set_visible(False) - - # vertical target on the left - ax2 = fig.add_subplot(gs[0:2, 0]) - ax2.plot(b, xb, color=color_target, label=b_label, linewidth=2) - ax2.fill_between(b, xb, color=color_target, alpha=.5) - ax2.invert_xaxis() - ax2.invert_yaxis() - ax2.set_xticks(()) - ax2.set_yticks(()) - ax2.set_title(b_label) - ax2.spines['top'].set_visible(False) - ax2.spines['right'].set_visible(False) - ax2.spines['bottom'].set_visible(False) - ax2.spines['left'].set_visible(False) - - # plan image, transposed since imshow is in "yx" coords - ax3 = fig.add_subplot(gs[0:2, 1:], sharey=ax2, sharex=ax1) - ax3.imshow(1 - M.T, interpolation='nearest', origin='lower', - cmap=coupling_cmap) - ax3.set_aspect('equal') - ax3.set_title(title) - - # Set spines visibility to True and customize if desired - ax3.spines['top'].set_visible(True) - ax3.spines['right'].set_visible(True) - ax3.spines['bottom'].set_visible(True) - ax3.spines['left'].set_visible(True) - - pl.subplots_adjust(hspace=0, wspace=0) - return ax1, ax2, ax3 - - -def rescale_for_imshow_plot(x, y, n, a_y=None, b_y=None): + # helper 
function for code factorisation + def _set_ticks_and_spines(ax, empty_ticks=True, visible_spines=False): + if empty_ticks: + ax.set_xticks(()) + ax.set_yticks(()) + + ax.spines['top'].set_visible(visible_spines) + ax.spines['right'].set_visible(visible_spines) + ax.spines['bottom'].set_visible(visible_spines) + ax.spines['left'].set_visible(visible_spines) + + if plot_style == 'xy': + # horizontal source on the bottom, flipped vertically + ax1 = pl.subplot(gs[2, 1:]) + ax1.plot(xa, np.max(a) - a, color=color_source, linewidth=2) + ax1.fill(xa, np.max(a) - a, np.max(a) * np.ones_like(a), + color=color_source, alpha=.5) + ax1.set_title(a_label, y=-.15) + + # vertical target on the left + ax2 = pl.subplot(gs[0:2, 0]) + ax2.plot(b, xb, color=color_target, linewidth=2) + ax2.fill(b, xb, color=color_target, alpha=.5) + ax2.invert_xaxis() + ax2.invert_yaxis() + ax2.set_title(b_label) + + _set_ticks_and_spines(ax1, empty_ticks=True, visible_spines=False) + _set_ticks_and_spines(ax2, empty_ticks=True, visible_spines=False) + + # coupling matrix in the middle + ax3 = pl.subplot(gs[0:2, 1:], sharey=ax2, sharex=ax1) + ax3.imshow(M.T, interpolation='nearest', origin='lower', + cmap=coupling_cmap) + ax3.set_title(title) + _set_ticks_and_spines(ax3, empty_ticks=False, visible_spines=True) + + pl.subplots_adjust(hspace=0, wspace=0) + return ax1, ax2, ax3 + + else: # plot_style == 'yx' + # vertical source on the left + ax1 = pl.subplot(gs[1:, 0]) + ax1.plot(a, xa, color=color_source, linewidth=2) + ax1.fill(a, xa, color=color_source, alpha=.5) + ax1.invert_xaxis() + ax1.set_title(a_label) + + # horizontal target on the top + ax2 = pl.subplot(gs[0, 1:]) + ax2.plot(xb, b, color=color_target, linewidth=2) + ax2.fill(xb, b, color=color_target, alpha=.5) + ax2.set_title(b_label) + + _set_ticks_and_spines(ax1, empty_ticks=True, visible_spines=False) + _set_ticks_and_spines(ax2, empty_ticks=True, visible_spines=False) + + # coupling matrix in the middle + ax3 = pl.subplot(gs[1:, 1:], 
sharey=ax1, sharex=ax2) + ax3.imshow(M, interpolation='nearest', cmap=coupling_cmap) + # Set title below matrix plot + ax3.text(0.5, -0.025, title, + ha='center', va='top', + transform=ax3.transAxes, + fontsize='large') + _set_ticks_and_spines(ax3, empty_ticks=False, visible_spines=True) + + pl.subplots_adjust(hspace=0, wspace=0) + return ax1, ax2, ax3 + + +def rescale_for_imshow_plot(x, y, n, m=None, a_y=None, b_y=None): r""" - Gives arrays xr, yr that can be plotted over an (n, n) + Gives arrays xr, yr that can be plotted over an (n, m) imshow plot (in 'xy' coordinates). If `a_y` or `b_y` is provided, y is sliced over its indices such that y stays in [ay, by]. Parameters ---------- x : ndarray, shape (nx,) - y : ndarray, shape (ny,) + y : ndarray, shape (nx,) n : int - Size of the imshow plot on which to plot (x, y) + x-axis size of the imshow plot on which to plot (x, y) + m : int, optional + y-axis size of the imshow plot, defaults to n a_y : float, optional Lower bound for y b_y : float, optional @@ -130,8 +166,8 @@ def rescale_for_imshow_plot(x, y, n, a_y=None, b_y=None): Returns ------- xr : ndarray, shape (nx,) - Rescaled x values - yr : ndarray, shape (ny,) + Rescaled x values (due to slicing, may have less elements than x) + yr : ndarray, shape (nx,) Rescaled y values (due to slicing, may have less elements than y) .. 
seealso:: @@ -140,13 +176,16 @@ def rescale_for_imshow_plot(x, y, n, a_y=None, b_y=None): """ # slice over the y values that are in the y range a_x, b_x = np.min(x), np.max(x) + assert x.shape[0] == y.shape[0], "x and y arrays should have the same size" if a_y is None: a_y = np.min(y) if b_y is None: b_y = np.max(y) + if m is None: + m = n idx = (y >= a_y) & (y <= b_y) x_rescaled = (x[idx] - a_x) * (n - 1) / (b_x - a_x) - y_rescaled = (y[idx] - a_y) * (n - 1) / (b_y - a_y) + y_rescaled = (y[idx] - a_y) * (m - 1) / (b_y - a_y) return x_rescaled, y_rescaled diff --git a/test/test_gmm.py b/test/test_gmm.py index d9db899a8..5280b2c14 100644 --- a/test/test_gmm.py +++ b/test/test_gmm.py @@ -164,14 +164,18 @@ def test_gradient_gmm_ot_loss_pytorch(): assert (grad_w_s**2).sum().item() > 0 -def test_gmm_ot_plan_density(): - m_s, m_t, C_s, C_t, w_s, w_t = get_gmms() +def test_gmm_ot_plan_density(nx): + m_s, m_t, C_s, C_t, w_s, w_t = get_gmms(nx) rng = np.random.RandomState(seed=42) n = 7 - x = rng.randn(n, 3) + x = nx.from_numpy(rng.randn(n, 3)) + y = nx.from_numpy(rng.randn(n + 1, 3)) - density = gmm_ot_plan_density(x, x, m_s, m_t, C_s, C_t, w_s, w_t) - assert density.shape == (n,) + density = gmm_ot_plan_density(x, y, m_s, m_t, C_s, C_t, w_s, w_t) + assert density.shape == (n, n + 1) plan = gmm_ot_plan(m_s, m_t, C_s, C_t, w_s, w_t) gmm_ot_plan_density(x, x, m_s, m_t, C_s, C_t, w_s, w_t, plan=plan) + + with pytest.raises(AssertionError): + gmm_ot_plan_density(x[:, 1:], y, m_s, m_t, C_s, C_t, w_s, w_t) diff --git a/test/test_plot.py b/test/test_plot.py index eae832364..a3aade5f8 100644 --- a/test/test_plot.py +++ b/test/test_plot.py @@ -35,7 +35,11 @@ def test_plot1D_mat(): M = ot.dist(x.reshape((n_bins, 1)), x.reshape((n_bins, 1))) M /= M.max() - ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') + ot.plot.plot1D_mat(a, b, M) + ot.plot.plot1D_mat(a, b, M, plot_style='xy') + + with pytest.raises(AssertionError): + ot.plot.plot1D_mat(a, b, M, plot_style='NotAValidStyle') 
@pytest.mark.skipif(nogo, reason="Matplotlib not installed") @@ -54,9 +58,11 @@ def test_rescale_for_imshow_plot(): assert x_rescaled.shape == (n, ) assert y_rescaled.shape == (n, ) - x_rescaled, y_rescaled = ot.plot.rescale_for_imshow_plot(x, y, n, a_y=a_y + 1, b_y=b_y - 1) + x_rescaled, y_rescaled = ot.plot.rescale_for_imshow_plot(x, y, n, m=n, a_y=a_y + 1, b_y=b_y - 1) assert x_rescaled.shape[0] <= n assert y_rescaled.shape[0] <= n + with pytest.raises(AssertionError): + ot.plot.rescale_for_imshow_plot(x[3:], y, n) @pytest.mark.skipif(nogo, reason="Matplotlib not installed") From 3846143dbbbd08159e1924403673702a56673365 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Tue, 30 Jul 2024 15:40:05 +0200 Subject: [PATCH 29/30] gmm_density mask without slicing --- examples/others/plot_GMMOT_plan.py | 1 + ot/gmm.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/others/plot_GMMOT_plan.py b/examples/others/plot_GMMOT_plan.py index 6aa7db808..8b6db31ba 100644 --- a/examples/others/plot_GMMOT_plan.py +++ b/examples/others/plot_GMMOT_plan.py @@ -1,3 +1,4 @@ +# %% # -*- coding: utf-8 -*- r""" ==================================================== diff --git a/ot/gmm.py b/ot/gmm.py index e5c45ebf5..06caf90c7 100644 --- a/ot/gmm.py +++ b/ot/gmm.py @@ -399,7 +399,7 @@ def Tk0k1(k0, k1): g = gaussian_pdf(xx, m_s[k0], C_s[k0]) out = plan[k0, k1] * g norms = nx.norm(Tx - yy, axis=-1) - out[norms > atol] = 0 + out = out * ((norms < atol) * 1.) 
return out mat = nx.stack( From 98a7fba3a89fc922c95db74ff471b13864d80c47 Mon Sep 17 00:00:00 2001 From: eloitanguy Date: Tue, 30 Jul 2024 15:51:34 +0200 Subject: [PATCH 30/30] fixed erroneous simplex projection in GMM_flow example --- examples/others/plot_GMM_flow.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/others/plot_GMM_flow.py b/examples/others/plot_GMM_flow.py index d1852cc9e..8cff1cc42 100644 --- a/examples/others/plot_GMM_flow.py +++ b/examples/others/plot_GMM_flow.py @@ -115,7 +115,6 @@ def draw_gmm(ms, Cs, ws, color=None, nstd=.5, alpha=1): loss.backward() opt.step() with torch.no_grad(): - w_s.data = proj_simplex(w_s.data) C_s.data = proj_SDP(C_s.data, vmin=1e-6) m_list.append(m_s.data.numpy().copy()) C_list.append(C_s.data.numpy().copy()) @@ -188,3 +187,5 @@ def index_to_color(i): pl.axis(axis) pl.legend(fontsize=15) + +# %%