Skip to content

[MRG] Tiny fix in SSNB #535

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 31 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
f2c65a8
added functions to a new mapping module
eloitanguy Sep 20, 2023
284c004
simplify ssnb function structure
eloitanguy Sep 20, 2023
a2d9ae8
RELEASES.md conflict fix
eloitanguy Sep 20, 2023
cd7be18
SSNB example
eloitanguy Sep 20, 2023
c82e0e6
removed numpy saves from example for prod
eloitanguy Sep 20, 2023
79a99ef
tests apart from the import exception catch
eloitanguy Sep 20, 2023
801c280
tests apart from the import exception catch
eloitanguy Sep 20, 2023
ff46975
da class and tests
eloitanguy Sep 21, 2023
60944f5
guessed PR number
eloitanguy Sep 21, 2023
d0d42be
Merge remote-tracking branch 'origin/master' into contrib_ssnb
eloitanguy Sep 21, 2023
7bc3213
removed unused import
eloitanguy Sep 21, 2023
55f0e09
PEP8 tab errors fix
eloitanguy Sep 21, 2023
9dfd82e
skip ssnb test if no cvxpy
eloitanguy Sep 21, 2023
0489392
test and doc fixes
eloitanguy Sep 21, 2023
2adcab3
doc dependency + minor comment in ot __init__.py
eloitanguy Sep 21, 2023
945554e
fetch ot main diffsh
eloitanguy Sep 21, 2023
3e2e5b8
PEP8 fixes
eloitanguy Sep 21, 2023
596edd4
test typo fix
eloitanguy Sep 21, 2023
80fa0b9
ssnb da backend test fix
eloitanguy Sep 21, 2023
0a349ce
moved joint ot mappings to the mapping module
eloitanguy Sep 21, 2023
66e484b
merge with pythonot master
eloitanguy Sep 26, 2023
8eb1542
better ssnb example + ssnb initilisation + small joint_ot_mapping tests
eloitanguy Sep 27, 2023
f29dcff
better ssnb example + ssnb initilisation + small joint_ot_mapping tests
eloitanguy Sep 27, 2023
7a5e6d7
removed unused dependency in example
eloitanguy Sep 27, 2023
55a9b28
no longer import mapping in __init__ + example thumbnail fix + made q…
eloitanguy Oct 11, 2023
2ed1f02
Merge remote-tracking branch 'origin/master' into contrib_ssnb
eloitanguy Oct 18, 2023
d7b3c51
merge with POT main
eloitanguy Oct 18, 2023
b8a0774
fix barycentric projection factor omission in SSNB solver init
eloitanguy Oct 18, 2023
fe1163a
added modif in RELEASES.md
eloitanguy Oct 18, 2023
2236ce4
fix PR number in RELEASES.md
eloitanguy Oct 18, 2023
b3578f0
broadcast fix
eloitanguy Oct 18, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
## 0.9.2dev

#### New features
+ Added support for [Nearest Brenier Potentials (SSNB)](http://proceedings.mlr.press/v108/paty20a/paty20a.pdf) (PR #526)
+ Added support for [Nearest Brenier Potentials (SSNB)](http://proceedings.mlr.press/v108/paty20a/paty20a.pdf) (PR #526) + minor fix (PR #535)
+ Tweaked `get_backend` to ignore `None` inputs (PR #525)
+ Callbacks for generalized conditional gradient in `ot.da.sinkhorn_l1l2_gl` are now vectorized to improve performance (PR #507)
+ The `linspace` method of the backends now has the `type_as` argument to convert to the same dtype and device. (PR #533)
Expand Down
16 changes: 1 addition & 15 deletions examples/others/plot_SSNB.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
# Author: Eloi Tanguy <eloi.tanguy@u-paris.fr>
# License: MIT License

# sphinx_gallery_thumbnail_number = 4
# sphinx_gallery_thumbnail_number = 3

import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -63,20 +63,6 @@
plt.legend(loc='upper right')
plt.show()

# %%
# Plotting image of barycentric projection (SSNB initialisation values)
plt.clf()
pi = ot.emd(ot.unif(n_fitting_samples), ot.unif(n_fitting_samples), ot.dist(Xs, Xt))
plt.scatter(Xs[:, 0], Xs[:, 1], c='dodgerblue', label='source')
plt.scatter(Xt[:, 0], Xt[:, 1], c='red', label='target')
bar_img = pi @ Xt
for i in range(n_fitting_samples):
plt.plot([Xs[i, 0], bar_img[i, 0]], [Xs[i, 1], bar_img[i, 1]], color='black', alpha=.5)
plt.title('Images of in-data source samples by the barycentric map')
plt.legend(loc='upper right')
plt.axis('equal')
plt.show()

# %%
# Fitting the Nearest Brenier Potential
L = 3 # need L > 2 to allow the 2*y term, default is 1.4
Expand Down
2 changes: 1 addition & 1 deletion ot/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def nearest_brenier_potential_fit(X, V, X_classes=None, a=None, b=None, strongly
if init_method == 'target':
G_val = V
else: # Init G_val with barycentric projection
G_val = emd(a, b, dist(X, V)) @ V
G_val = emd(a, b, dist(X, V)) @ V / a.reshape(n, 1)
phi_val = None
log_dict = {
'G_list': [],
Expand Down