Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 24 additions & 24 deletions ot/gromov.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,8 +433,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,

where :
- M is the (ns,nt) metric cost matrix
- :math:`f` is the regularization term ( and df is its gradient)
- a and b are source and target weights (sum to 1)
- p and q are source and target weights (sum to 1)
- L is a loss function to account for the misfit between the similarity matrices

The algorithm used for solving the problem is conditional gradient as discussed in [24]_
Expand All @@ -453,17 +452,13 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
Distribution in the target space
loss_fun : str, optional
Loss function used for the solver
max_iter : int, optional
Max number of iterations
tol : float, optional
Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
alpha : float, optional
Trade-off parameter (0 < alpha < 1)
armijo : bool, optional
If True, the step of the line-search is found via an Armijo search. Otherwise the closed form is used.
If there are convergence issues, use False.
log : bool, optional
record log if True
**kwargs : dict
parameters can be directly passed to the ot.optim.cg solver

Expand Down Expand Up @@ -493,11 +488,11 @@ def df(G):
return gwggrad(constC, hC1, hC2, G)

if log:
res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
log['fgw_dist'] = log['loss'][::-1][0]
return res, log
else:
return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
return cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)


def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
Expand All @@ -515,8 +510,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5

where :
- M is the (ns,nt) metric cost matrix
- :math:`f` is the regularization term ( and df is its gradient)
- a and b are source and target weights (sum to 1)
- p and q are source and target weights (sum to 1)
- L is a loss function to account for the misfit between the similarity matrices
The algorithm used for solving the problem is conditional gradient as discussed in [1]_

Expand All @@ -534,17 +528,13 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
Distribution in the target space.
loss_fun : str, optional
Loss function used for the solver.
max_iter : int, optional
Max number of iterations
tol : float, optional
Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
Record log if True.
alpha : float, optional
Trade-off parameter (0 < alpha < 1)
armijo : bool, optional
If True, the step of the line-search is found via an Armijo search.
Otherwise the closed form is used. If there are convergence issues, use False.
log : bool, optional
Record log if True.
**kwargs : dict
Parameters can be directly passed to the ot.optim.cg solver.

Expand Down Expand Up @@ -573,7 +563,7 @@ def f(G):
def df(G):
return gwggrad(constC, hC1, hC2, G)

res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
if log:
log['fgw_dist'] = log['loss'][::-1][0]
log['T'] = res
Expand Down Expand Up @@ -994,6 +984,16 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
Whether to fix the structure of the barycenter during the updates
fixed_features : bool
Whether to fix the feature of the barycenter during the updates
loss_fun : str
Loss function used for the solver, either 'square_loss' or 'kl_loss'.
max_iter : int, optional
Max number of iterations
tol : float, optional
Stop threshold on error (>0).
verbose : bool, optional
Print information along iterations.
log : bool, optional
Record log if True.
init_C : ndarray, shape (N,N), optional
Initialization for the barycenters' structure matrix. If not set
a random init is used.
Expand Down Expand Up @@ -1082,7 +1082,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
T_temp = [t.T for t in T]
C = update_sructure_matrix(p, lambdas, T_temp, Cs)

T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha,
T = [fused_gromov_wasserstein(Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha,
numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]

# T is N,ns
Expand Down