Skip to content

Commit

Permalink
gromov-wasserstein (#19)
Browse files Browse the repository at this point in the history
* add gromov-wasserstein

* add docstrings for gromov-wasserstein

* update to LTS

* bump CI Julia version
  • Loading branch information
zsteve committed Jan 25, 2023
1 parent 9afa6af commit bbf2b79
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.0'
- '1.8'
- '1'
os:
- ubuntu-latest
Expand Down
41 changes: 41 additions & 0 deletions src/lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -475,3 +475,44 @@ See also: [`barycenter`](@ref)
function barycenter_unbalanced(A, C, ε, λ; kwargs...)
return pot.barycenter_unbalanced(A, C, ε, λ; kwargs...)
end

"""
gromov_wasserstein(μ, ν, Cμ, Cν, loss = "square_loss"; kwargs...)
Compute the exact Gromov-Wasserstein transport plan between `(μ, Cμ)` and `(ν, Cν)`.
The Gromov-Wasserstein transport problem seeks to find a minimizer of
```math
\\inf_{\\gamma \\in \\Pi(\\mu, \\nu)} \\sum_{i, j, k, l} L((C_μ)_{ik}, (C_ν)_{jl}) \\gamma_{ij} \\gamma_{kl},
```
where ``L`` is quadratic (`loss = "square_loss"`) or the Kullback-Leibler divergence (`loss = "kl_loss"`).
This function is a wrapper of the function
[`gromov_wasserstein`](https://pythonot.github.io/gen_modules/ot.gromov.html#ot.gromov.gromov_wasserstein) in the
Python Optimal Transport package. Keyword arguments are listed in the documentation of the
Python function.
"""
function gromov_wasserstein(μ, ν, Cμ, Cν, loss="square_loss"; kwargs...)
return pot.gromov.gromov_wasserstein(Cμ, Cν, μ, ν, loss; kwargs...)
end

"""
entropic_gromov_wasserstein(μ, ν, Cμ, Cν, ε, loss = "square_loss"; kwargs...)
Compute the entropy-regularized Gromov-Wasserstein transport plan between `(μ, Cμ)` and `(ν, Cν)` with parameter `ε`.
The entropy-regularized Gromov-Wasserstein transport problem seeks to find a minimizer of
```math
\\inf_{\\gamma \\in \\Pi(\\mu, \\nu)} \\sum_{i, j, k, l} L((C_μ)_{ik}, (C_ν)_{jl}) \\gamma_{ij} \\gamma_{kl} + ε \\Omega(\\gamma),
```
where ``L`` is quadratic (`loss = "square_loss"`) or the Kullback-Leibler divergence (`loss = "kl_loss"`)
and ``\\Omega(\\gamma) = \\sum_{ij} \\gamma_{ij} \\log(\\gamma_{ij})`` is the entropic regularization term.
This function is a wrapper of the function
[`entropic_gromov_wasserstein`](https://pythonot.github.io/gen_modules/ot.gromov.html#ot.gromov.entropic_gromov_wasserstein) in the
Python Optimal Transport package. Keyword arguments are listed in the documentation of the
Python function.
"""
function entropic_gromov_wasserstein(μ, ν, Cμ, Cν, ε, loss="square_loss"; kwargs...)
return pot.gromov.entropic_gromov_wasserstein(Cμ, Cν, μ, ν, loss, ε; kwargs...)
end

0 comments on commit bbf2b79

Please sign in to comment.