Skip to content

Commit

Permalink
Merge 6e14eda into 5354fb6
Browse files Browse the repository at this point in the history
  • Loading branch information
zsteve authored Aug 24, 2021
2 parents 5354fb6 + 6e14eda commit 72a3285
Show file tree
Hide file tree
Showing 7 changed files with 202 additions and 45 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ IterativeSolvers = "0.8.4, 0.9"
LogExpFunctions = "0.2, 0.3"
MathOptInterface = "0.9"
NNlib = "0.6, 0.7"
PDMats = "0.11"
PDMats = "0.10, 0.11"
QuadGK = "2"
StatsBase = "0.33.8"
julia = "1"
Expand Down
2 changes: 1 addition & 1 deletion examples/basic/script.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ mu = hcat(mu1, mu2)
C = pairwise(SqEuclidean(), support'; dims=2)
for λ1 in (0.25, 0.5, 0.75)
λ2 = 1 - λ1
a = sinkhorn_barycenter(mu, C, 0.01, [λ1, λ2]; max_iter=1000)
a = sinkhorn_barycenter(mu, C, 0.01, [λ1, λ2], SinkhornGibbs())
plot!(plt, support, a; label="\$\\mu \\quad (\\lambda_1 = $λ1)\$")
end
plt
2 changes: 2 additions & 0 deletions src/OptimalTransport.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ using NNlib: NNlib
using StatsBase: StatsBase

export SinkhornGibbs, SinkhornStabilized, SinkhornEpsilonScaling
export SinkhornBarycenterGibbs

export sinkhorn, sinkhorn2
export emd, emd2
Expand All @@ -37,6 +38,7 @@ include("entropic/sinkhorn_stabilized.jl")
include("entropic/sinkhorn_epsscaling.jl")
include("entropic/sinkhorn_unbalanced.jl")
include("entropic/sinkhorn_barycenter.jl")
include("entropic/sinkhorn_barycenter_gibbs.jl")

include("quadratic.jl")

Expand Down
148 changes: 118 additions & 30 deletions src/entropic/sinkhorn_barycenter.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,118 @@
# Barycenter solver

struct SinkhornBarycenterSolver{A<:Sinkhorn,M,CT,W,E<:Real,T<:Real,R<:Real,C1,C2}
source::M
C::CT
eps::E
w::W
alg::A
atol::T
rtol::R
maxiter::Int
check_convergence::Int
cache::C1
convergence_cache::C2
end

function build_solver(
μ::AbstractMatrix,
C::AbstractMatrix,
ε::Real,
w::AbstractVector,
alg::Sinkhorn;
atol=nothing,
rtol=nothing,
check_convergence=10,
maxiter::Int=1_000,
)
# check that input marginals are balanced
checkbalanced(μ)

size2 = (size(μ, 2),)

# compute type
T = float(Base.promote_eltype(μ, one(eltype(C)) / ε))

# build caches using SinkhornGibbsCache struct (since there is no dependence on ν)
cache = build_cache(T, alg, size2, μ, C, ε)
convergence_cache = build_convergence_cache(T, size2, μ)

# set tolerances
_atol = atol === nothing ? 0 : atol
_rtol = rtol === nothing ? (_atol > zero(_atol) ? zero(T) : sqrt(eps(T))) : rtol

# create solver
solver = SinkhornBarycenterSolver(
μ, C, ε, w, alg, _atol, _rtol, maxiter, check_convergence, cache, convergence_cache
)
return solver
end

function solve!(solver::SinkhornBarycenterSolver)
# unpack solver
μ = solver.source
w = solver.w
atol = solver.atol
rtol = solver.rtol

maxiter = solver.maxiter
check_convergence = solver.check_convergence
cache = solver.cache
convergence_cache = solver.convergence_cache

# unpack cache
u = cache.u
v = cache.v
K = cache.K
Kv = cache.Kv
a = cache.a

isconverged = false
to_check_step = check_convergence
A_batched_mul_B!(Kv, K, v)
for iter in 1:maxiter
# prestep if needed (not used for SinkhornBarycenterSolver{SinkhornGibbs})
prestep!(solver, iter)

# Sinkhorn iteration
a .= prod(Kv' .^ w; dims=1)' # TODO: optimise
u .= a ./ Kv
At_batched_mul_B!(v, K, u)
v .= μ ./ v
A_batched_mul_B!(Kv, K, v)

# decrement check marginal step
to_check_step -= 1
# check convergence
if to_check_step == 0 || iter == maxiter
# reset counter
to_check_step = check_convergence

isconverged, abserror = OptimalTransport.check_convergence(
a, u, Kv, convergence_cache, atol, rtol
)
@debug string(solver.alg) *
" (" *
string(iter) *
"/" *
string(maxiter) *
": absolute error of source marginal = " *
string(maximum(abserror))

if isconverged
@debug "$(solver.alg) ($iter/$maxiter): converged"
break
end
end
end
if !isconverged
@warn "$(solver.alg) ($maxiter/$maxiter): not converged"
end
return nothing
end

"""
sinkhorn_barycenter(μ, C, ε, w; tol=1e-9, check_marginal_step=10, max_iter=1000)
sinkhorn_barycenter(μ, C, ε, w, alg = SinkhornGibbs(); kwargs...)
Compute the Sinkhorn barycenter for a collection of `N` histograms contained in the columns of `μ`, for a cost matrix `C` of size `(size(μ, 1), size(μ, 1))`, relative weights `w` of size `N`, and entropic regularisation parameter `ε`.
Returns the entropically regularised barycenter of the `μ`, i.e. the histogram `ρ` of length `size(μ, 1)` that solves
Expand All @@ -11,33 +124,8 @@ Returns the entropically regularised barycenter of the `μ`, i.e. the histogram
where ``\\operatorname{OT}_{ε}(\\mu, \\nu) = \\inf_{\\gamma \\Pi(\\mu, \\nu)} \\langle \\gamma, C \\rangle + \\varepsilon \\Omega(\\gamma)``
is the entropic optimal transport loss with cost ``C`` and regularisation ``\\epsilon``.
"""
function sinkhorn_barycenter(μ, C, ε, w; tol=1e-9, check_marginal_step=10, max_iter=1000)
sums = sum(μ; dims=1)
if !isapprox(extrema(sums)...)
throw(ArgumentError("Error: marginals are unbalanced"))
end
K = exp.(-C / ε)
converged = false
v = ones(size(μ))
u = ones(size(μ))
N = size(μ, 2)
for n in 1:max_iter
v = μ ./ (K' * u)
a = ones(size(u, 1))
a = prod((K * v)' .^ w; dims=1)'
u = a ./ (K * v)
if n % check_marginal_step == 0
# check marginal errors
err = maximum(abs.(μ .- v .* (K' * u)))
@debug "Sinkhorn algorithm: iteration $n" err
if err < tol
converged = true
break
end
end
end
if !converged
@warn "Sinkhorn did not converge"
end
return u[:, 1] .* (K * v[:, 1])
function sinkhorn_barycenter(μ, C, ε, w, alg::Sinkhorn; kwargs...)
solver = build_solver(μ, C, ε, w, alg; kwargs...)
solve!(solver)
return solution(solver)
end
38 changes: 38 additions & 0 deletions src/entropic/sinkhorn_barycenter_gibbs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# cache

struct SinkhornBarycenterGibbsCache{U,V,KT,A}
u::U
v::V
K::KT
Kv::U
a::A
end

# solver cache
function build_cache(
::Type{T}, ::SinkhornGibbs, size2::Tuple, μ::AbstractMatrix, C::AbstractMatrix, ε::Real
) where {T}
# compute Gibbs kernel (has to be mutable for ε-scaling algorithm)
K = similar(C, T)
@. K = exp(-C / ε)

# create and initialize dual potentials
u = similar(μ, T, size(μ, 1), size2...)
v = similar(μ, T, size(μ, 1), size2...)
a = similar(μ, T, size(μ, 1), 1)
fill!(u, one(T))
fill!(v, one(T))
fill!(a, one(T))

# cache for next iterate of `u`
Kv = similar(u)

return SinkhornBarycenterGibbsCache(u, v, K, Kv, a)
end

prestep!(::SinkhornBarycenterSolver{SinkhornGibbs}, ::Int) = nothing

function solution(solver::SinkhornBarycenterSolver{SinkhornGibbs})
cache = solver.cache
return cache.u[:, 1] .* cache.Kv[:, 1]
end
14 changes: 14 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,15 @@ function checksize2(μ::AbstractVecOrMat, ν::AbstractVecOrMat)
return (max(size_μ_2, size_ν_2),)
end

"""
isallapprox(x::AbstractVector)
Check that all entries of `x` are approximately equal
"""
function isallapprox(x::AbstractVecOrMat)
return all(y -> isapprox(y, x[1]), x[2:end])
end

"""
checkbalanced(μ::AbstractVecOrMat, ν::AbstractVecOrMat)
Expand All @@ -67,6 +76,11 @@ function checkbalanced(x::AbstractVecOrMat, y::AbstractVecOrMat)
throw(ArgumentError("source and target marginals are not balanced"))
return nothing
end
function checkbalanced(x::AbstractMatrix)
isallapprox(sum(x; dims=1)) ||
throw(ArgumentError("source and target marginals are not balanced"))
return nothing
end

"""
A_batched_mul_B!(c::AbstractVector, A::AbstractMatrix, b::AbstractVector)
Expand Down
41 changes: 28 additions & 13 deletions test/entropic/sinkhorn_barycenter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,38 @@ const POT = PythonOT
Random.seed!(100)

@testset "sinkhorn_barycenter.jl" begin
@testset "example" begin
# set up support
support = range(-1; stop=1, length=250)
μ1 = normalize!(exp.(-(support .+ 0.5) .^ 2 ./ 0.1^2), 1)
μ2 = normalize!(exp.(-(support .- 0.5) .^ 2 ./ 0.1^2), 1)
μ_all = hcat(μ1, μ2)
# set up support
support = range(-1; stop=1, length=500)
N = 10
μ = hcat([normalize!(exp.(-(support .+ rand()) .^ 2 ./ 0.1^2), 1) for _ in 1:N]...)

# create cost matrix
C = pairwise(SqEuclidean(), support'; dims=2)

# create cost matrix
C = pairwise(SqEuclidean(), support'; dims=2)
# regularisation parameter
ε = 0.05

# compute Sinkhorn barycenter
eps = 0.01
μ_interp = sinkhorn_barycenter(μ_all, C, eps, [0.5, 0.5])
# weights
w = ones(N) / N

@testset "example" begin
α = sinkhorn_barycenter(μ, C, ε, w, SinkhornGibbs())

# compare with POT
# need to use a larger tolerance here because of a quirk with the POT solver
μ_interp_pot = POT.barycenter(μ_all, C, eps; weights=[0.5, 0.5], stopThr=1e-9)
@test μ_interp μ_interp_pot rtol = 1e-6
α_pot = POT.barycenter(μ, C, ε; weights=w, stopThr=1e-9)
@test α α_pot rtol = 1e-6
end

# different element type
@testset "Float32" begin
μ32 = map(Float32, μ)
ε32 = map(Float32, ε)
C32 = map(Float32, C)
w32 = map(Float32, w)
α = sinkhorn_barycenter(μ32, C32, ε32, w32, SinkhornGibbs())

α_pot = POT.barycenter(μ32, C32, ε32; weights=w32, stopThr=1e-9)
@test α α_pot rtol = 1e-6
end
end

0 comments on commit 72a3285

Please sign in to comment.