diff --git a/ot/gromov.py b/ot/gromov.py index 986934195..43780a4f6 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -433,8 +433,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, where : - M is the (ns,nt) metric cost matrix - - :math:`f` is the regularization term ( and df is its gradient) - - a and b are source and target weights (sum to 1) + - p and q are source and target weights (sum to 1) - L is a loss function to account for the misfit between the similarity matrices The algorithm used for solving the problem is conditional gradient as discussed in [24]_ @@ -453,17 +452,13 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, Distribution in the target space loss_fun : str, optional Loss function used for the solver - max_iter : int, optional - Max number of iterations - tol : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True + alpha : float, optional + Trade-off parameter (0 < alpha < 1) armijo : bool, optional If True the steps of the line-search is found via an armijo research. Else closed form is used. If there is convergence issues use False. + log : bool, optional + record log if True **kwargs : dict parameters can be directly passed to the ot.optim.cg solver @@ -493,11 +488,11 @@ def df(G): return gwggrad(constC, hC1, hC2, G) if log: - res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) + res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) log['fgw_dist'] = log['loss'][::-1][0] return res, log else: - return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) + return cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): @@ -515,8 +510,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 where : - M is the (ns,nt) metric cost matrix - - :math:`f` is the regularization term ( and df is its gradient) - - a and b are source and target weights (sum to 1) + - p and q are source and target weights (sum to 1) - L is a loss function to account for the misfit between the similarity matrices The algorithm used for solving the problem is conditional gradient as discussed in [1]_ @@ -534,17 +528,13 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 Distribution in the target space. loss_fun : str, optional Loss function used for the solver. - max_iter : int, optional - Max number of iterations - tol : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - Record log if True. + alpha : float, optional + Trade-off parameter (0 < alpha < 1) armijo : bool, optional If True the steps of the line-search is found via an armijo research. Else closed form is used. If there is convergence issues use False. + log : bool, optional + Record log if True. **kwargs : dict Parameters can be directly pased to the ot.optim.cg solver. @@ -573,7 +563,7 @@ def f(G): def df(G): return gwggrad(constC, hC1, hC2, G) - res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) + res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) if log: log['fgw_dist'] = log['loss'][::-1][0] log['T'] = res @@ -994,6 +984,16 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ Whether to fix the structure of the barycenter during the updates fixed_features : bool Whether to fix the feature of the barycenter during the updates + loss_fun : str + Loss function used for the solver either 'square_loss' or 'kl_loss' + max_iter : int, optional + Max number of iterations + tol : float, optional + Stop threshol on error (>0). + verbose : bool, optional + Print information along iterations. + log : bool, optional + Record log if True. init_C : ndarray, shape (N,N), optional Initialization for the barycenters' structure matrix. If not set a random init is used. @@ -1082,7 +1082,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ T_temp = [t.T for t in T] C = update_sructure_matrix(p, lambdas, T_temp, Cs) - T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, + T = [fused_gromov_wasserstein(Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)] # T is N,ns