Skip to content

Commit

Permalink
Merge 22b4f92 into d33f1cf
Browse files Browse the repository at this point in the history
  • Loading branch information
zsteve authored Aug 28, 2021
2 parents d33f1cf + 22b4f92 commit 09ab6b3
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 8 deletions.
24 changes: 16 additions & 8 deletions src/quadratic_newton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,19 @@ struct QuadraticOTNewton{T<:Real,K<:Real,D<:Real} <: QuadraticOT
θ::T
κ::K
δ::D
armijo_max::Int
end

"""
QuadraticOTNewton(θ = 0.1, κ = 0.5, δ = 1e-5)
QuadraticOTNewton(;θ = 0.1, κ = 0.5, δ = 1e-5, armijo_max = 50)
Semi-smooth Newton method (Algorithm 2 of Lorenz et al. 2019) with Armijo parameters `θ` and `κ`, and conjugate gradient regularisation `δ`.
`armijo_max` sets the maximum number of Armijo step size trials.
See also: [`QuadraticOTNewton`](@ref), [`quadreg`](@ref)
"""
function QuadraticOTNewton(; θ=0.1, κ=0.5, δ=1e-5)
return QuadraticOTNewton(θ, κ, δ)
function QuadraticOTNewton(; θ=0.1, κ=0.5, δ=1e-5, armijo_max=50)
    # Forward the keyword values, in field order, to the positional constructor.
    alg = QuadraticOTNewton(θ, κ, δ, armijo_max)
    return alg
end

# Display the algorithm by its human-readable name.
function Base.show(io::IO, ::QuadraticOTNewton)
    return print(io, "Semi-smooth Newton algorithm")
end
Expand Down Expand Up @@ -53,16 +55,18 @@ function build_cache(
v = similar(ν, T, size(ν, 1))
fill!(u, zero(T))
fill!(v, zero(T))
δu = zeros(T, size(u))
δv = zeros(T, size(v))
δu = similar(u, T)
δv = similar(v, T)
# intermediate variables (don't need to be initialised)
σ = similar(C, T)
γ = similar(C, T)
M = size(μ, 1)
N = size(ν, 1)
G = zeros(T, M + N, M + N)
G = similar(u, T, M + N, M + N)
fill!(G, zero(T))
# initial guess for conjugate gradient
x = zeros(T, M + N)
x = similar(u, T, M + N)
fill!(x, zero(T))
return QuadraticOTNewtonCache(u, v, δu, δv, σ, γ, G, x, M, N)
end

Expand Down Expand Up @@ -144,6 +148,8 @@ function descent_step!(solver::QuadraticOTSolver{<:QuadraticOTNewton})
# Armijo parameters
θ = solver.alg.θ
κ = solver.alg.κ
armijo_max = solver.alg.armijo_max
armijo_counter = 0

# dual objective
function Φ(u, v, μ, ν, C, ε)
Expand All @@ -154,8 +160,10 @@ function descent_step!(solver::QuadraticOTSolver{<:QuadraticOTNewton})
d = -eps * (dot(δu, μ) + dot(δv, ν)) + eps * dot(γ, δu .+ δv')
t = 1
Φ0 = Φ(u, v, μ, ν, C, eps)
while Φ(u + t * δu, v + t * δv, μ, ν, C, eps) ≥ Φ0 + t * θ * d
while (armijo_counter < armijo_max) &&
(Φ(u + t * δu, v + t * δv, μ, ν, C, eps) ≥ Φ0 + t * θ * d)
t = κ * t
armijo_counter += 1
end
u .= u + t * δu
return v .= v + t * δv
Expand Down
10 changes: 10 additions & 0 deletions test/gpu/simple_gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,5 +93,15 @@ Random.seed!(100)
@test convert(Array, γ) ≈ sinkhorn_unbalanced(μscaled, ν, C, λ1, λ2, ε)
@test c ≈ sinkhorn_unbalanced2(μscaled, ν, C, λ1, λ2, ε)
end

@testset "quadreg" begin
    # use a different reg parameter
    ε_quad = 1.0f0
    # Build the solver configuration once so the GPU and CPU runs are guaranteed
    # to use identical Armijo parameters (θ, κ), CG regularisation δ and armijo_max.
    alg = QuadraticOTNewton(0.1f0, 0.5f0, 1f-5, 50)
    γ = quadreg(cu_μ, cu_ν, cu_C, ε_quad, alg)
    # compare with results on the CPU: the plan computed from the `cu_*` inputs,
    # converted to an `Array`, must approximately match the plan from `μ`, `ν`, `C`
    γ_cpu = quadreg(μ, ν, C, ε_quad, alg)
    @test convert(Array, γ) ≈ γ_cpu atol=1f-4 rtol=1f-4
end
end
end

0 comments on commit 09ab6b3

Please sign in to comment.