Skip to content

Commit

Permalink
Merge branch 'master' into finitediscretemeasure
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored Jun 7, 2021
2 parents d5d57c7 + 3674ceb commit 7f7cce0
Show file tree
Hide file tree
Showing 8 changed files with 451 additions and 202 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "OptimalTransport"
uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33"
authors = ["zsteve <stephenz@student.unimelb.edu.au>"]
version = "0.3.10"
version = "0.3.11"

[deps]
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Expand Down
3 changes: 2 additions & 1 deletion src/OptimalTransport.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ include("distances/bures.jl")
include("utils.jl")
include("exact.jl")
include("wasserstein.jl")
include("entropic.jl")
include("entropic/sinkhorn.jl")
include("entropic/sinkhorn_stabilized.jl")
include("quadratic.jl")

end
130 changes: 0 additions & 130 deletions src/entropic.jl → src/entropic/sinkhorn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -461,136 +461,6 @@ function sinkhorn_unbalanced2(
return dot(γ, C)
end

"""
sinkhorn_stabilized_epsscaling(μ, ν, C, ε; lambda = 0.5, k = 5, kwargs...)
Compute the optimal transport plan for the entropically regularized optimal transport problem
with source and target marginals `μ` and `ν`, cost matrix `C` of size `(length(μ), length(ν))`, and entropic regularisation parameter `ε`. Employs the log-domain stabilized algorithm of Schmitzer et al. [^S19] with ε-scaling.
`k` ε-scaling steps are used with scaling factor `lambda`, i.e. sequentially solve Sinkhorn using `sinkhorn_stabilized` with regularisation parameters
``ε_i \\in [λ^{1-k}, \\ldots, λ^{-1}, 1] \\times ε``.
See also: [`sinkhorn_stabilized`](@ref), [`sinkhorn`](@ref)
"""
function sinkhorn_stabilized_epsscaling(μ, ν, C, ε; lambda=0.5, k=5, kwargs...)
α = zero(μ)
β = zero(ν)
for ε_i in* lambda^(1 - j) for j in k:-1:1)
@debug "Epsilon-scaling Sinkhorn algorithm: ε = $ε_i"
α, β = sinkhorn_stabilized(
μ, ν, C, ε_i; alpha=α, beta=β, return_duals=true, kwargs...
)
end
gamma = similar(C)
getK!(gamma, C, α, β, ε, μ, ν)
return gamma
end

function getK!(K, C, α, β, ε, μ, ν)
@. K = exp(-(C - α - β') / ε) * μ * ν'
return K
end

"""
sinkhorn_stabilized(μ, ν, C, ε; absorb_tol = 1e3, alpha_0 = zero(μ), beta = zero(ν), maxiter = 1_000, atol = tol, rtol=nothing, return_duals = false)
Compute the optimal transport plan for the entropically regularized optimal transport problem
with source and target marginals `μ` and `ν`, cost matrix `C` of size `(length(μ), length(ν))`, and entropic regularisation parameter `ε`. Employs the log-domain stabilized algorithm of Schmitzer et al. [^S19]
`alpha` and `beta` are initial scalings for the stabilized Gibbs kernel. If not specified, `alpha` and `beta` are initialised to zero.
If `return_duals = true`, then the optimal dual variables `(u, v)` corresponding to `(μ, ν)` are returned. Otherwise, the coupling `γ` is returned.
[^S19]: Schmitzer, B., 2019. Stabilized sparse scaling algorithms for entropy regularized transport problems. SIAM Journal on Scientific Computing, 41(3), pp.A1443-A1481.
See also: [`sinkhorn`](@ref)
"""
function sinkhorn_stabilized(
μ,
ν,
C,
ε;
absorb_tol=1e3,
maxiter=1_000,
tol=nothing,
atol=tol,
rtol=nothing,
check_convergence=10,
alpha=zero(μ),
beta=zero(ν),
return_duals=false,
)
if tol !== nothing
Base.depwarn(
"keyword argument `tol` is deprecated, please use `atol` and `rtol`",
:sinkhorn_stabilized,
)
end
sum(μ) sum(ν) ||
throw(ArgumentError("source and target marginals must have the same mass"))

T = float(Base.promote_eltype(μ, ν, C))
_atol = atol === nothing ? 0 : atol
_rtol = rtol === nothing ? (_atol > zero(_atol) ? zero(T) : sqrt(eps(T))) : rtol

norm_μ = sum(abs, μ)
isconverged = false

K = similar(C)
gamma = similar(C)

getK!(K, C, alpha, beta, ε, μ, ν)
u = μ ./ sum(K; dims=2)
v = ν ./ (K' * u)
tmp_u = similar(u)
for iter in 0:maxiter
if (max(norm(u, Inf), norm(v, Inf)) > absorb_tol)
@debug "Absorbing (u, v) into (alpha, beta)"
# absorb into α, β
alpha += ε * log.(u)
beta += ε * log.(v)
u .= 1
v .= 1
getK!(K, C, alpha, beta, ε, μ, ν)
end
if iter % check_convergence == 0
# check marginal
getK!(gamma, C, alpha, beta, ε, μ, ν)
@. gamma *= u * v'
norm_diff = sum(abs, gamma * ones(size(ν)) - μ)
norm_uKv = sum(abs, gamma)
@debug "Stabilized Sinkhorn algorithm (" *
string(iter) *
"/" *
string(maxiter) *
": error of source marginal = " *
string(norm_diff)

if norm_diff < max(_atol, _rtol * max(norm_μ, norm_uKv))
@debug "Stabilized Sinkhorn algorithm ($iter/$maxiter): converged"
isconverged = true
break
end
end
mul!(tmp_u, K, v)
u = μ ./ tmp_u
mul!(v, K', u)
v = ν ./ v
end

if !isconverged
@warn "Stabilized Sinkhorn algorithm ($maxiter/$maxiter): not converged"
end

alpha = alpha + ε * log.(u)
beta = beta + ε * log.(v)
if return_duals
return alpha, beta
end
getK!(gamma, C, alpha, beta, ε, μ, ν)
return gamma
end

"""
sinkhorn_barycenter(μ, C, ε, w; tol=1e-9, check_marginal_step=10, max_iter=1000)
Expand Down
Loading

0 comments on commit 7f7cce0

Please sign in to comment.