From 7b0aae6cbf86027a26d4e74f277b3c58ebcc9023 Mon Sep 17 00:00:00 2001 From: Nathan Cassereau Date: Wed, 17 Nov 2021 14:44:26 +0100 Subject: [PATCH 1/3] Now limiting alpha to a minimum value as well as a max value --- ot/optim.py | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/ot/optim.py b/ot/optim.py index cacec538f..fa03f6ad3 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -18,8 +18,10 @@ # The corresponding scipy function does not work for matrices -def line_search_armijo(f, xk, pk, gfk, old_fval, - args=(), c1=1e-4, alpha0=0.99): +def line_search_armijo( + f, xk, pk, gfk, old_fval, args=(), c1=1e-4, + alpha0=0.99, alpha_min=None, alpha_max=None +): r""" Armijo linesearch function that works with matrices @@ -80,13 +82,15 @@ def phi(alpha1): if alpha is None: return 0., fc[0], phi0 else: - # scalar_search_armijo can return alpha > 1 - alpha = min(1, alpha) + if alpha_min is not None or alpha_max is not None: + alpha = np.clip(alpha, alpha_min, alpha_max) return alpha, fc[0], phi1 -def solve_linesearch(cost, G, deltaG, Mi, f_val, - armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None): +def solve_linesearch( + cost, G, deltaG, Mi, f_val, armijo=True, C1=None, C2=None, + reg=None, Gc=None, constC=None, M=None, alpha_min=None, alpha_max=None +): """ Solve the linesearch in the FW iterations @@ -136,7 +140,9 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val, International Conference on Machine Learning (ICML). 2019. """ if armijo: - alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val) + alpha, fc, f_val = line_search_armijo( + cost, G, deltaG, Mi, f_val, alpha_min=alpha_min, alpha_max=alpha_max + ) else: # requires symetric matrices G, deltaG, C1, C2, constC, M = list_to_array(G, deltaG, C1, C2, constC, M) if isinstance(M, int) or isinstance(M, float): @@ -150,6 +156,8 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val, c = cost(G) alpha = solve_1d_linesearch_quad(a, b, c) + if alpha_min is not None or alpha_max is not None: + alpha = np.clip(alpha, alpha_min, alpha_max) fc = None f_val = cost(G + alpha * deltaG) @@ -274,7 +282,10 @@ def cost(G): deltaG = Gc - G # line search - alpha, fc, f_val = solve_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs) + alpha, fc, f_val = solve_linesearch( + cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, + alpha_min=0., alpha_max=1., **kwargs + ) G = G + alpha * deltaG @@ -420,7 +431,9 @@ def cost(G): # line search dcost = Mi + reg1 * (1 + nx.log(G)) # ?? - alpha, fc, f_val = line_search_armijo(cost, G, deltaG, dcost, f_val) + alpha, fc, f_val = line_search_armijo( + cost, G, deltaG, dcost, f_val, alpha_min=0., alpha_max=1. + ) G = G + alpha * deltaG From e8c9c4bf884445efc89af9aa47112079d0a414f2 Mon Sep 17 00:00:00 2001 From: Nathan Cassereau Date: Wed, 17 Nov 2021 14:49:05 +0100 Subject: [PATCH 2/3] Docs --- ot/optim.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/ot/optim.py b/ot/optim.py index fa03f6ad3..12229d731 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -45,7 +45,11 @@ def line_search_armijo( c1 : float, optional :math:`c_1` const in armijo rule (>0) alpha0 : float, optional - initial step (>0) + initial step (>0), + alpha_min : float, optional + minimum value for alpha + alpha_max : float, optional + maximum value for alpha Returns ------- @@ -121,6 +125,10 @@ def solve_linesearch( Constant for the gromov cost. See :ref:`[24] `. Only used and necessary when armijo=False M : array-like (ns,nt), optional Cost matrix between the features. Only used and necessary when armijo=False + alpha_min : float, optional + Minimum value for alpha + alpha_max : float, optional + Maximum value for alpha Returns ------- From 9b940748c2cec81d062b3c77f0e6bbeff1ad5775 Mon Sep 17 00:00:00 2001 From: Nathan Cassereau Date: Wed, 17 Nov 2021 14:49:42 +0100 Subject: [PATCH 3/3] typo --- ot/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ot/optim.py b/ot/optim.py index 12229d731..9b8a8f799 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -45,7 +45,7 @@ def line_search_armijo( c1 : float, optional :math:`c_1` const in armijo rule (>0) alpha0 : float, optional - initial step (>0), + initial step (>0) alpha_min : float, optional minimum value for alpha alpha_max : float, optional