From 32101c764bbc298e18b7ffe93721d4a991b114b9 Mon Sep 17 00:00:00 2001 From: theo Date: Tue, 20 Dec 2022 08:36:54 +0100 Subject: [PATCH 1/2] fix stopThr hardcoded in some places --- ot/bregman.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ot/bregman.py b/ot/bregman.py index 4e1a25c1d..38b38ec32 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1278,7 +1278,7 @@ def get_reg(n): # exponential decreasing regi = get_reg(ii) G, logi = sinkhorn_stabilized(a, b, M, regi, - numItermax=numInnerItermax, stopThr=1e-9, + numItermax=numInnerItermax, stopThr=stopThr, warmstart=(alpha, beta), verbose=False, print_period=20, tau=tau, log=True) @@ -3293,17 +3293,17 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli if log: sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, - stopThr=1e-9, verbose=verbose, + stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, numIterMax=numIterMax, - stopThr=1e-9, verbose=verbose, + stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, numIterMax=numIterMax, - stopThr=1e-9, verbose=verbose, + stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b) @@ -3320,17 +3320,17 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli else: sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, - numIterMax=numIterMax, stopThr=1e-9, + numIterMax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, - numIterMax=numIterMax, stopThr=1e-9, + numIterMax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, - numIterMax=numIterMax, stopThr=1e-9, + numIterMax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs) From ab49b3d6b5721780a75ab6e915a0368e0efed71d Mon Sep 17 00:00:00 2001 From: theo Date: Tue, 20 Dec 2022 08:57:02 +0100 Subject: [PATCH 2/2] added fix documentation in RELEASES.Md --- RELEASES.md | 1 + 1 file changed, 1 insertion(+) diff --git a/RELEASES.md b/RELEASES.md index 9cfdd352a..5e5a1db61 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -32,6 +32,7 @@ roughly 2^31) (PR #381) - Fixed weak optimal transport docstring (Issue #404, PR #410) - Fixed error whith parameter `log=True`for `SinkhornLpl1Transport` (Issue #412, PR #413) +- Fix an issue where the parameter `stopThr` in `empirical_sinkhorn_divergence` was rendered useless by subcalls that explicitly specified `stopThr=1e-9` (Issue #421). ## 0.8.2