Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PythonOT"
uuid = "3c485715-4278-42b2-9b5f-8f00e43c12ef"
authors = ["David Widmann"]
version = "0.1.5"
version = "0.1.6"

[deps]
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
Expand Down
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,5 @@ PythonOT.Smooth.smooth_ot_dual
sinkhorn_unbalanced
sinkhorn_unbalanced2
barycenter_unbalanced
mm_unbalanced
```
3 changes: 2 additions & 1 deletion src/PythonOT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ export emd,
barycenter_unbalanced,
sinkhorn_unbalanced,
sinkhorn_unbalanced2,
empirical_sinkhorn_divergence
empirical_sinkhorn_divergence,
mm_unbalanced

const pot = PyCall.PyNULL()

Expand Down
76 changes: 63 additions & 13 deletions src/lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -312,11 +312,11 @@ julia> ν = [0.0, 1.0];

julia> C = [0.0 1.0; 2.0 0.0; 0.5 1.5];

julia> sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000)
julia> round.(sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000); sigdigits=4)
3×2 Matrix{Float64}:
0.0 0.499964
0.0 0.200188
0.0 0.29983
0.0 0.5
0.0 0.2002
0.0 0.2998
```

It is possible to provide multiple target marginals as columns of a matrix. In this case the
Expand All @@ -325,10 +325,10 @@ optimal transport costs are returned:
```jldoctest sinkhorn_unbalanced
julia> ν = [0.0 0.5; 1.0 0.5];

julia> round.(sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000); sigdigits=6)
julia> round.(sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000); sigdigits=4)
2-element Vector{Float64}:
0.949709
0.449411
0.9497
0.4494
```

See also: [`sinkhorn_unbalanced2`](@ref)
Expand Down Expand Up @@ -371,20 +371,19 @@ julia> ν = [0.0, 1.0];

julia> C = [0.0 1.0; 2.0 0.0; 0.5 1.5];

julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=6)
1-element Vector{Float64}:
0.949709
julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=4)
0.9497
```

It is possible to provide multiple target marginals as columns of a matrix:

```jldoctest sinkhorn_unbalanced2
julia> ν = [0.0 0.5; 1.0 0.5];

julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=6)
julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=4)
2-element Vector{Float64}:
0.949709
0.449411
0.9497
0.4494
```

See also: [`sinkhorn_unbalanced`](@ref)
Expand Down Expand Up @@ -516,3 +515,54 @@ Python function.
function entropic_gromov_wasserstein(μ, ν, Cμ, Cν, ε, loss="square_loss"; kwargs...)
return pot.gromov.entropic_gromov_wasserstein(Cμ, Cν, μ, ν, loss, ε; kwargs...)
end

"""
mm_unbalanced(a, b, M, reg_m; reg=0, c=a*b', kwargs...)

Solve the unbalanced optimal transport problem and return the OT plan.
The function solves the following optimization problem:

```math
W = \\min_{\\gamma \\geq 0} \\langle \\gamma, M \\rangle_F +
\\mathrm{reg_{m1}} \\cdot \\operatorname{div}(\\gamma \\mathbf{1}, a) +
\\mathrm{reg_{m2}} \\cdot \\operatorname{div}(\\gamma^\\mathsf{T} \\mathbf{1}, b) +
\\mathrm{reg} \\cdot \\operatorname{div}(\\gamma, c)
```

where

- `M` is the metric cost matrix,
- `a` and `b` are source and target unbalanced distributions,
- `c` is a reference distribution for the regularization,
- `reg_m` is the marginal relaxation term (if it is a scalar or an indexable object of length 1, then the same term is applied to both marginal relaxations), and
- `reg` is a regularization term.

This function is a wrapper of the function
[`mm_unbalanced`](https://pythonot.github.io/gen_modules/ot.unbalanced.html#ot.unbalanced.mm_unbalanced) in the
Python Optimal Transport package. Keyword arguments are listed in the documentation of the
Python function.

# Examples

```jldoctest
julia> a=[.5, .5];

julia> b=[.5, .5];

julia> M=[1. 36.; 9. 4.];

julia> round.(mm_unbalanced(a, b, M, 5, div="kl"), digits=2)
2×2 Matrix{Float64}:
0.45 0.0
0.0 0.34

julia> round.(mm_unbalanced(a, b, M, 5, div="l2"), digits=2)
2×2 Matrix{Float64}:
0.4 0.0
0.0 0.1
```

"""
function mm_unbalanced(a, b, M, reg_m; kwargs...)
return pot.unbalanced.mm_unbalanced(a, b, M, reg_m; kwargs...)
end