Skip to content

Commit f13d503

Browse files
authored
Merge 8d7b1c6 into 7c72cf8
2 parents 7c72cf8 + 8d7b1c6 commit f13d503

File tree

4 files changed

+83
-36
lines changed

4 files changed

+83
-36
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "PythonOT"
22
uuid = "3c485715-4278-42b2-9b5f-8f00e43c12ef"
33
authors = ["David Widmann"]
4-
version = "0.1.6"
4+
version = "0.1.9"
55

66
[deps]
77
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"

docs/src/api.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,9 @@ sinkhorn_unbalanced2
3737
barycenter_unbalanced
3838
mm_unbalanced
3939
```
40+
41+
## Partial optimal transport
42+
43+
```@docs
44+
entropic_partial_wasserstein
45+
```

src/PythonOT.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@ export emd,
1313
sinkhorn_unbalanced,
1414
sinkhorn_unbalanced2,
1515
empirical_sinkhorn_divergence,
16-
mm_unbalanced
16+
mm_unbalanced,
17+
entropic_partial_wasserstein,
18+
entropic_partial_gromov_wasserstein,
19+
entropic_partial_gromov_wasserstein2,
20+
partial_wasserstein,
21+
partial_wasserstein2
1722

1823
const pot = PyCall.PyNULL()
1924

src/lib.jl

Lines changed: 70 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -306,29 +306,18 @@ Python function.
306306
# Examples
307307
308308
```jldoctest sinkhorn_unbalanced
309-
julia> μ = [0.5, 0.2, 0.3];
309+
julia> μ = [0.5, 0.5];
310310
311-
julia> ν = [0.0, 1.0];
311+
julia> ν = [0.5, 0.5];
312312
313-
julia> C = [0.0 1.0; 2.0 0.0; 0.5 1.5];
313+
julia> C = [0.0 1.0; 1.0 0.0];
314314
315-
julia> round.(sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000); sigdigits=4)
316-
3×2 Matrix{Float64}:
317-
0.0 0.5
318-
0.0 0.2002
319-
0.0 0.2998
315+
julia> round.(sinkhorn_unbalanced(μ, ν, C, 1, 1); sigdigits=7)
316+
2×2 Matrix{Float64}:
317+
0.322054 0.118477
318+
0.118477 0.322054
320319
```
321320
322-
It is possible to provide multiple target marginals as columns of a matrix. In this case the
323-
optimal transport costs are returned:
324-
325-
```jldoctest sinkhorn_unbalanced
326-
julia> ν = [0.0 0.5; 1.0 0.5];
327-
328-
julia> round.(sinkhorn_unbalanced(μ, ν, C, 0.01, 1_000); sigdigits=4)
329-
2-element Vector{Float64}:
330-
0.9497
331-
0.4494
332321
```
333322
334323
See also: [`sinkhorn_unbalanced2`](@ref)
@@ -365,25 +354,14 @@ Python function.
365354
# Examples
366355
367356
```jldoctest sinkhorn_unbalanced2
368-
julia> μ = [0.5, 0.2, 0.3];
357+
julia> μ = [0.5, 0.1];
369358
370-
julia> ν = [0.0, 1.0];
371-
372-
julia> C = [0.0 1.0; 2.0 0.0; 0.5 1.5];
359+
julia> ν = [0.5, 0.5];
373360
374-
julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=4)
375-
0.9497
376-
```
377-
378-
It is possible to provide multiple target marginals as columns of a matrix:
379-
380-
```jldoctest sinkhorn_unbalanced2
381-
julia> ν = [0.0 0.5; 1.0 0.5];
361+
julia> C = [0.0 1.0; 1.0 0.0];
382362
383-
julia> round.(sinkhorn_unbalanced2(μ, ν, C, 0.01, 1_000); sigdigits=4)
384-
2-element Vector{Float64}:
385-
0.9497
386-
0.4494
363+
julia> round.(sinkhorn_unbalanced2(μ, ν, C, 1., 1.); sigdigits=8)
364+
0.19600125
387365
```
388366
389367
See also: [`sinkhorn_unbalanced`](@ref)
@@ -566,3 +544,61 @@ julia> round.(mm_unbalanced(a, b, M, 5, div="l2"), digits=2)
566544
function mm_unbalanced(a, b, M, reg_m; kwargs...)
567545
return pot.unbalanced.mm_unbalanced(a, b, M, reg_m; kwargs...)
568546
end
547+
548+
549+
"""
550+
entropic_partial_wasserstein(a, b, M, reg; kwargs...)
551+
552+
Solves the partial optimal transport problem and returns the OT plan
553+
The function considers the following problem:
554+
555+
```math
556+
\\gamma = \\mathop{\\arg \\min}_\\gamma \\quad \\langle \\gamma,
557+
\\mathbf{M} \\rangle_F + \\mathrm{reg} \\cdot\\Omega(\\gamma)
558+
```
559+
560+
- `a` and `b` are the sample weights
561+
- `M` is the metric cost matrix
562+
- `reg` is a regularization term > 0
563+
564+
This function is a wrapper of the function
565+
[`entropic_partial_wasserstein`](https://pythonot.github.io/gen_modules/ot.partial.html#ot.partial.entropic_partial_wasserstein) in the
566+
Python Optimal Transport package. Keyword arguments are listed in the documentation of the
567+
Python function.
568+
569+
570+
# Examples
571+
572+
```jldoctest
573+
julia> a = [.1, .2];
574+
575+
julia> b = [.1, .1];
576+
577+
julia> M = [0. 1.; 2. 3.];
578+
579+
julia> round.(entropic_partial_wasserstein(a, b, M, 1, m=0.1), digits=2)
580+
2×2 Matrix{Float64}:
581+
0.06 0.02
582+
0.01 0.0
583+
```
584+
585+
"""
586+
function entropic_partial_wasserstein(a, b, M, reg; kwargs...)
587+
return pot.partial.entropic_partial_wasserstein(a, b, M, reg; kwargs...)
588+
end
589+
590+
function partial_wasserstein(a, b, M; kwargs...)
591+
return pot.partial.partial_wasserstein2(a, b, M; kwargs...)
592+
end
593+
594+
function partial_wasserstein2(a, b, M; kwargs...)
595+
return pot.partial.partial_wasserstein2(a, b, M; kwargs...)
596+
end
597+
598+
function entropic_partial_gromov_wasserstein(C1, C2, p, q, reg; kwargs...)
599+
return pot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, reg; kwargs...)
600+
end
601+
602+
function entropic_partial_gromov_wasserstein2(C1, C2, p, q; kwargs...)
603+
return pot.partial.entropic_partial_gromov_wasserstein2(C1, C2, p, q; kwargs...)
604+
end

0 commit comments

Comments
 (0)