Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 30 additions & 9 deletions ot/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
# The corresponding scipy function does not work for matrices


def line_search_armijo(f, xk, pk, gfk, old_fval,
args=(), c1=1e-4, alpha0=0.99):
def line_search_armijo(
f, xk, pk, gfk, old_fval, args=(), c1=1e-4,
alpha0=0.99, alpha_min=None, alpha_max=None
):
r"""
Armijo linesearch function that works with matrices

Expand All @@ -44,6 +46,10 @@ def line_search_armijo(f, xk, pk, gfk, old_fval,
:math:`c_1` const in armijo rule (>0)
alpha0 : float, optional
initial step (>0)
alpha_min : float, optional
minimum value for alpha
alpha_max : float, optional
maximum value for alpha

Returns
-------
Expand Down Expand Up @@ -80,13 +86,15 @@ def phi(alpha1):
if alpha is None:
return 0., fc[0], phi0
else:
# scalar_search_armijo can return alpha > 1
alpha = min(1, alpha)
if alpha_min is not None or alpha_max is not None:
alpha = np.clip(alpha, alpha_min, alpha_max)
return alpha, fc[0], phi1


def solve_linesearch(cost, G, deltaG, Mi, f_val,
armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None):
def solve_linesearch(
cost, G, deltaG, Mi, f_val, armijo=True, C1=None, C2=None,
reg=None, Gc=None, constC=None, M=None, alpha_min=None, alpha_max=None
):
"""
Solve the linesearch in the FW iterations

Expand Down Expand Up @@ -117,6 +125,10 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val,
Constant for the gromov cost. See :ref:`[24] <references-solve-linesearch>`. Only used and necessary when armijo=False
M : array-like (ns,nt), optional
Cost matrix between the features. Only used and necessary when armijo=False
alpha_min : float, optional
Minimum value for alpha
alpha_max : float, optional
Maximum value for alpha

Returns
-------
Expand All @@ -136,7 +148,9 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val,
International Conference on Machine Learning (ICML). 2019.
"""
if armijo:
alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val)
alpha, fc, f_val = line_search_armijo(
cost, G, deltaG, Mi, f_val, alpha_min=alpha_min, alpha_max=alpha_max
)
else: # requires symetric matrices
G, deltaG, C1, C2, constC, M = list_to_array(G, deltaG, C1, C2, constC, M)
if isinstance(M, int) or isinstance(M, float):
Expand All @@ -150,6 +164,8 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val,
c = cost(G)

alpha = solve_1d_linesearch_quad(a, b, c)
if alpha_min is not None or alpha_max is not None:
alpha = np.clip(alpha, alpha_min, alpha_max)
fc = None
f_val = cost(G + alpha * deltaG)

Expand Down Expand Up @@ -274,7 +290,10 @@ def cost(G):
deltaG = Gc - G

# line search
alpha, fc, f_val = solve_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs)
alpha, fc, f_val = solve_linesearch(
cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc,
alpha_min=0., alpha_max=1., **kwargs
)

G = G + alpha * deltaG

Expand Down Expand Up @@ -420,7 +439,9 @@ def cost(G):

# line search
dcost = Mi + reg1 * (1 + nx.log(G)) # ??
alpha, fc, f_val = line_search_armijo(cost, G, deltaG, dcost, f_val)
alpha, fc, f_val = line_search_armijo(
cost, G, deltaG, dcost, f_val, alpha_min=0., alpha_max=1.
)

G = G + alpha * deltaG

Expand Down