Quadratic specialized active set #466

Merged on Jun 3, 2024 (47 commits)

Commits
64679f7  Edit newlines (sebastiendesignolle, Apr 19, 2024)
ca14b66  Add ActiveSetQuadratic (sebastiendesignolle, Apr 19, 2024)
381820a  Vector{Bool} -> BitVector (matbesancon, Apr 19, 2024)
1087773  signature change for pairwise (matbesancon, Apr 19, 2024)
8680687  deprecation (matbesancon, Apr 19, 2024)
3bdf922  Add constructors and basic functions (sebastiendesignolle, Apr 19, 2024)
2bb3f94  Add active set functions (sebastiendesignolle, Apr 19, 2024)
6029f69  Fix undefined H and argmin overwriting (sebastiendesignolle, Apr 19, 2024)
b1f11f8  Create separate file for ActiveSetQuadratic (sebastiendesignolle, Apr 19, 2024)
75ff95f  Convert Inf -> typemax(eltype(direction)) (sebastiendesignolle, Apr 19, 2024)
7a75e0b  Fix syntax issues, add @inbounds in setindex! (sebastiendesignolle, Apr 19, 2024)
6ed64d7  Remove debug prints and add FIXME (sebastiendesignolle, Apr 19, 2024)
ba49f2d  Change active_set -> s.active_set (sebastiendesignolle, Apr 19, 2024)
b5b3238  Change filter! into deleteat! (sebastiendesignolle, Apr 19, 2024)
d6cce45  Change type signature of deleteat! (only for Int) (sebastiendesignolle, Apr 19, 2024)
fde1c1b  Add small example of ActiveSetQuadratic (speed x3) (sebastiendesignolle, Apr 19, 2024)
6329f68  Remove print (sebastiendesignolle, Apr 19, 2024)
b078ca2  non-allocating version (matbesancon, Apr 22, 2024)
1ed5fc0  Merge branch '461-quadratic-specialized-activeset' of github.com:ZIB-… (sebastiendesignolle, Apr 22, 2024)
ce1ae0c  Remove setindex! (sebastiendesignolle, Apr 22, 2024)
e6a5e88  Add commented @assert for future tests (sebastiendesignolle, Apr 22, 2024)
70fa6ec  Add error for idxm=-1 in argmin (sebastiendesignolle, Apr 22, 2024)
8d2b5d5  Set modified to true in push! and argminmax (sebastiendesignolle, Apr 22, 2024)
af13dc6  Revert modifications of iterate_pairwise signature (sebastiendesignolle, Apr 22, 2024)
03f0288  Add validate for ActiveSetQuadratic (sebastiendesignolle, Apr 22, 2024)
d86fd07  Fix deleteat! and Generator (sebastiendesignolle, Apr 23, 2024)
243716b  Change weights into weights_prev in deleteat! for afw (sebastiendesignolle, Apr 23, 2024)
4958c2d  Comment out validate for ActiveSetQuadratic (sebastiendesignolle, Apr 23, 2024)
0d49033  Add lazy=false and afw (sebastiendesignolle, Apr 23, 2024)
573abf7  Fix initialize! bug by avoiding end (sebastiendesignolle, Apr 26, 2024)
4f0be96  Add detection of A and b (sebastiendesignolle, Apr 26, 2024)
f8cb10f  Bug fix dot(a, a) -> dot(Aa, a) (sebastiendesignolle, Apr 26, 2024)
ad3b72c  Add warning (sebastiendesignolle, Apr 26, 2024)
514883c  Add linear regression example with ActiveSetQuadratic (sebastiendesignolle, Apr 26, 2024)
08e0399  Update number of iterations (sebastiendesignolle, May 15, 2024)
14fc88c  Add quadratic active set (commented out) (sebastiendesignolle, May 15, 2024)
7343270  Add identity hessian support (sebastiendesignolle, May 15, 2024)
4866644  Factorise arguments (sebastiendesignolle, May 15, 2024)
6439050  Add kwarg commentary (sebastiendesignolle, May 15, 2024)
e10adc6  Add doc for quadratic active set (sebastiendesignolle, May 15, 2024)
671c6c1  Add unit test (sebastiendesignolle, May 15, 2024)
322c136  Comment a bit (sebastiendesignolle, May 15, 2024)
18b77bf  Remove randomisation (sebastiendesignolle, May 15, 2024)
5580860  Refactor e -> Identity and add lambda (sebastiendesignolle, May 15, 2024)
d4a9b03  Update docs/src/advanced.md (sebastiendesignolle, May 16, 2024)
518d52b  Remove commented function (sebastiendesignolle, Jun 3, 2024)
b7e7f1e  Change TODO (sebastiendesignolle, Jun 3, 2024)
13 changes: 13 additions & 0 deletions docs/src/advanced.md
@@ -51,6 +51,19 @@ extra_vertex_storage=vertex_storage,
determines whether vertices from the storage are used in the algorithm.
See [Extra-lazification](@ref) for a complete example.

## Specialized active set for quadratic functions

If the objective function is quadratic, a considerable speedup can be obtained by using the structure `ActiveSetQuadratic`.
It stores various scalar products so that the best (and, for `blended_pairwise_conditional_gradient`, the worst) atom in the active set can be determined without recomputing many scalar products at each iteration.
The user should provide the Hessian matrix `A` together with the linear term `b` of the gradient, such that:
```math
\nabla f(x)=Ax+b.
```
If the Hessian matrix `A` is simply a scaled identity (as for a squared-distance function, for instance), `LinearAlgebra.I` or any `LinearAlgebra.UniformScaling` can be passed instead.
These parameters can also be detected automatically, but the precision of this detection (which essentially amounts to solving a linear system) quickly becomes insufficient in practice as the dimension grows.
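
As a minimal sketch, assume `x0` is a vertex obtained from the LMO and `p` is a fixed target point, so that `f(x) = dot(x - p, x - p) / 2` gives `A = I` and `b = -p` (the names `x0`, `p`, and `grad!` are placeholders here):

```julia
using FrankWolfe
using LinearAlgebra

# explicit quadratic structure: ∇f(x) = I * x - p
active_set = FrankWolfe.ActiveSetQuadratic([(1.0, x0)], I, -p)

# or let A and b be detected automatically from the gradient oracle grad!,
# subject to the precision caveat above
active_set = FrankWolfe.ActiveSetQuadratic([(1.0, x0)], grad!)
```

The resulting active set is then passed to the solver in place of a plain `FrankWolfe.ActiveSet`.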

See the examples `quadratic.jl` and `quadratic_A.jl` for the exact syntax.

## Miscellaneous

- Emphasis: All solvers support emphasis (parameter `Emphasis`) to either exploit vectorized linear algebra or be memory efficient, e.g., for large-scale instances
107 changes: 107 additions & 0 deletions examples/quadratic.jl
@@ -0,0 +1,107 @@
using LinearAlgebra
using FrankWolfe
using Random

# Example of speedup using the symmetry reduction
# See arxiv.org/abs/2302.04721 for the context
# and arxiv.org/abs/2310.20677 for further symmetrisation
# The symmetry exploited is the invariance of a tensor
# under exchange of its dimensions

struct BellCorrelationsLMO{T} <: FrankWolfe.LinearMinimizationOracle
m::Int # number of inputs
tmp::Vector{T} # used to compute scalar products
end

function FrankWolfe.compute_extreme_point(
lmo::BellCorrelationsLMO{T},
A::Array{T, 2};
kwargs...,
) where {T <: Number}
ax = [ones(T, lmo.m) for n in 1:2]
axm = [zeros(Int, lmo.m) for n in 1:2]
scm = typemax(T)
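# heuristic: random restarts of an alternating minimisation over the two ±1 vectors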
for i in 1:100
rand!(ax[2], [-1, 1])
sc1 = zero(T)
sc2 = one(T)
while sc1 < sc2
sc2 = sc1
mul!(lmo.tmp, A', ax[1])
for x2 in 1:length(ax[2])
ax[2][x2] = lmo.tmp[x2] > zero(T) ? -one(T) : one(T)
end
mul!(lmo.tmp, A, ax[2])
for x2 in 1:length(ax[1])
ax[1][x2] = lmo.tmp[x2] > zero(T) ? -one(T) : one(T)
end
sc1 = dot(ax[1], lmo.tmp)
end
if sc1 < scm
scm = sc1
for n in 1:2
axm[n] .= ax[n]
end
end
end
# returning a full tensor is naturally naive, but this is only a toy example
return [axm[1][x1]*axm[2][x2] for x1 in 1:lmo.m, x2 in 1:lmo.m]
end

function correlation_tensor_GHZ_polygon(N::Int, m::Int; type=Float64)
res = zeros(type, m*ones(Int, N)...)
tab_cos = [cos(x*type(pi)/m) for x in 0:N*m]
tab_cos[abs.(tab_cos) .< Base.rtoldefault(type)] .= zero(type)
for ci in CartesianIndices(res)
res[ci] = tab_cos[sum(ci.I)-N+1]
end
return res
end

function benchmark_Bell(p::Array{T, 2}, quadratic::Bool; fw_method=FrankWolfe.blended_pairwise_conditional_gradient, kwargs...) where {T <: Number}
normp2 = dot(p, p) / 2
# let block to capture p and normp2 as locals so that the closures are type-stable
f = let p = p, normp2 = normp2
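# equivalent to f(x) = ‖x - p‖² / 2, expanded to avoid allocating x - p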
x -> normp2 + dot(x, x) / 2 - dot(p, x)
end
grad! = let p = p
(storage, xit) -> begin
for x in eachindex(xit)
storage[x] = xit[x] - p[x]
end
end
end
function reynolds_permutedims(atom::Array{Int, 2}, lmo::BellCorrelationsLMO{T}) where {T <: Number}
res = zeros(T, size(atom))
for per in [[1, 2], [2, 1]]
res .+= permutedims(atom, per)
end
res ./= 2
return res
end
function reynolds_adjoint(gradient::Array{T, 2}, lmo::BellCorrelationsLMO{T}) where {T <: Number}
return gradient # no need to symmetrise the gradient, since it remains symmetric throughout the algorithm
end
lmo = BellCorrelationsLMO{T}(size(p, 1), zeros(T, size(p, 1)))
x0 = FrankWolfe.compute_extreme_point(lmo, -p)
if quadratic
active_set = FrankWolfe.ActiveSetQuadratic([(one(T), x0)], I, -p)
else
active_set = FrankWolfe.ActiveSet([(one(T), x0)])
end
return fw_method(f, grad!, lmo, active_set; line_search=FrankWolfe.Shortstep(one(T)), kwargs...)
end

p = correlation_tensor_GHZ_polygon(2, 100)
max_iteration = 10^3 # the speedup becomes more pronounced as the number of iterations grows
verbose = false
# the implicit keyword-argument shorthand below may break on older Julia versions
@time benchmark_Bell(p, false; verbose, max_iteration, lazy=false, fw_method=FrankWolfe.blended_pairwise_conditional_gradient) # 2.4s
@time benchmark_Bell(p, true; verbose, max_iteration, lazy=false, fw_method=FrankWolfe.blended_pairwise_conditional_gradient) # 0.8s
@time benchmark_Bell(p, false; verbose, max_iteration, lazy=true, fw_method=FrankWolfe.blended_pairwise_conditional_gradient) # 2.1s
@time benchmark_Bell(p, true; verbose, max_iteration, lazy=true, fw_method=FrankWolfe.blended_pairwise_conditional_gradient) # 0.4s
@time benchmark_Bell(p, false; verbose, max_iteration, lazy=false, fw_method=FrankWolfe.away_frank_wolfe) # 5.7s
@time benchmark_Bell(p, true; verbose, max_iteration, lazy=false, fw_method=FrankWolfe.away_frank_wolfe) # 2.3s
@time benchmark_Bell(p, false; verbose, max_iteration, lazy=true, fw_method=FrankWolfe.away_frank_wolfe) # 3s
@time benchmark_Bell(p, true; verbose, max_iteration, lazy=true, fw_method=FrankWolfe.away_frank_wolfe) # 0.7s
println()
67 changes: 67 additions & 0 deletions examples/quadratic_A.jl
@@ -0,0 +1,67 @@
using FrankWolfe
using Random
using LinearAlgebra
Random.seed!(0)

n = 5 # number of dimensions
p = 10^3 # number of points
k = 10^4 # number of iterations
T = Float64

function simple_reg_loss(θ, data_point)
(xi, yi) = data_point
(a, b) = (θ[1:end-1], θ[end])
pred = a ⋅ xi + b
return (pred - yi)^2 / 2
end

function ∇simple_reg_loss(storage, θ, data_point)
(xi, yi) = data_point
(a, b) = (θ[1:end-1], θ[end])
pred = a ⋅ xi + b
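# accumulates into storage without zeroing it; the caller resets storage first (see gradf below)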
@. storage[1:end-1] += xi * (pred - yi)
storage[end] += pred - yi
return storage
end

xs = [10randn(T, n) for _ in 1:p]
bias = 4
params_perfect = [1:n; bias]

# similar example with noisy data, Gaussian noise around the linear estimate
data_noisy = [(x, x ⋅ (1:n) + bias + 0.5 * randn(T)) for x in xs]

f(x) = sum(simple_reg_loss(x, data_point) for data_point in data_noisy)

function gradf(storage, x)
storage .= 0
for dp in data_noisy
∇simple_reg_loss(storage, x, dp)
end
end

lmo = FrankWolfe.LpNormLMO{T, 2}(1.05 * norm(params_perfect))

x0 = FrankWolfe.compute_extreme_point(lmo, zeros(T, n+1))

# standard active set
# active_set = FrankWolfe.ActiveSet([(1.0, x0)])

# specialized active set, automatically detecting the parameters A and b of the quadratic function f
active_set = FrankWolfe.ActiveSetQuadratic([(one(T), x0)], gradf)

@time res = FrankWolfe.blended_pairwise_conditional_gradient(
# @time res = FrankWolfe.away_frank_wolfe(
f,
gradf,
lmo,
active_set;
verbose=true,
lazy=true,
line_search=FrankWolfe.Adaptive(L_est=10.0, relaxed_smoothness=true),
max_iteration=k,
print_iter=k / 10,
trajectory=true,
)

println()
3 changes: 2 additions & 1 deletion examples/reynolds.jl
@@ -91,11 +91,12 @@ function benchmark_Bell(p::Array{T, 3}, sym::Bool; kwargs...) where {T <: Number
x0 = FrankWolfe.compute_extreme_point(lmo, -p)
println("Output type of the LMO: ", typeof(x0))
active_set = FrankWolfe.ActiveSet([(one(T), x0)])
# active_set = FrankWolfe.ActiveSetQuadratic([(one(T), x0)], I, -p)
return FrankWolfe.blended_pairwise_conditional_gradient(f, grad!, lmo, active_set; lazy=true, line_search=FrankWolfe.Shortstep(one(T)), kwargs...)
end

p = 0.5correlation_tensor_GHZ_polygon(3, 8)
benchmark_Bell(p, true; verbose=true, max_iteration=10^6, print_iter=10^4) # 27_985 iterations and 89 atoms
benchmark_Bell(p, true; verbose=true, max_iteration=10^6, print_iter=10^4) # 24_914 iterations and 89 atoms
println()
benchmark_Bell(p, false; verbose=true, max_iteration=10^6, print_iter=10^4) # 107_647 iterations and 379 atoms
println()
1 change: 1 addition & 0 deletions src/FrankWolfe.jl
@@ -33,6 +33,7 @@ include("polytope_oracles.jl")
include("moi_oracle.jl")
include("function_gradient.jl")
include("active_set.jl")
include("active_set_quadratic.jl")

include("blended_cg.jl")
include("afw.jl")
78 changes: 39 additions & 39 deletions src/active_set.jl
@@ -66,15 +66,17 @@ function Base.push!(as::AbstractActiveSet, (λ, a))
end

function Base.deleteat!(as::AbstractActiveSet, idx)
deleteat!(as.weights, idx)
deleteat!(as.atoms, idx)
# WARNING: assumes that idx is sorted in increasing order
for (i, j) in enumerate(idx)
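# after i-1 earlier deletions, the entry originally at index j is now at j-i+1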
deleteat!(as, j-i+1)
end
return as
end

function Base.setindex!(as::AbstractActiveSet, tup::Tuple, idx)
as.weights[idx] = tup[1]
as.atoms[idx] = tup[2]
return tup
function Base.deleteat!(as::AbstractActiveSet, idx::Int)
deleteat!(as.atoms, idx)
deleteat!(as.weights, idx)
return as
end

function Base.empty!(as::AbstractActiveSet)
@@ -108,10 +110,8 @@ function active_set_update!(active_set::AbstractActiveSet, lambda, atom, renorm=
if idx === nothing
idx = find_atom(active_set, atom)
end
updating = false
if idx > 0
@inbounds active_set.weights[idx] = active_set.weights[idx] + lambda
updating = true
@inbounds active_set.weights[idx] += lambda
else
push!(active_set, (lambda, atom))
end
@@ -138,14 +138,14 @@ function active_set_update_scale!(x::IT, lambda, atom::SparseArrays.SparseVector
@. x *= (1 - lambda)
nzvals = SparseArrays.nonzeros(atom)
nzinds = SparseArrays.nonzeroinds(atom)
for idx in eachindex(nzvals)
@inbounds for idx in eachindex(nzvals)
x[nzinds[idx]] += lambda * nzvals[idx]
end
return x
end

"""
active_set_update_iterate_pairwise!(x, lambda, fw_atom, away_atom)
active_set_update_iterate_pairwise!(active_set, x, lambda, fw_atom, away_atom)

Operates `x ← x + λ a_fw - λ a_aw`.
"""
Expand All @@ -155,7 +155,7 @@ function active_set_update_iterate_pairwise!(x::IT, lambda::Real, fw_atom::A, aw
end

function active_set_validate(active_set::AbstractActiveSet)
return sum(active_set.weights) ≈ 1.0 && all(>=(0), active_set.weights)
return sum(active_set.weights) ≈ 1.0 && all(≥(0), active_set.weights)
end

function active_set_renormalize!(active_set::AbstractActiveSet)
@@ -223,13 +223,13 @@ end
function active_set_cleanup!(active_set; weight_purge_threshold=1e-12, update=true, add_dropped_vertices=false, vertex_storage=nothing)
if add_dropped_vertices && vertex_storage !== nothing
for (weight, v) in zip(active_set.weights, active_set.atoms)
if weight <= weight_purge_threshold
if weight ≤ weight_purge_threshold
push!(vertex_storage, v)
end
end
end

filter!(e -> e[1] > weight_purge_threshold, active_set)
# the indices must be collected into a vector first, since deleteat! modifies active_set in place
deleteat!(active_set, [idx for idx in eachindex(active_set) if active_set.weights[idx] ≤ weight_purge_threshold])
if update
compute_active_set_iterate!(active_set)
end
@@ -252,18 +252,19 @@ Computes the linear minimizer in the direction on the active set.
Returns `(λ_i, a_i, i)`
"""
function active_set_argmin(active_set::AbstractActiveSet, direction)
val = dot(active_set.atoms[1], direction)
idx = 1
temp = 0
for i in 2:length(active_set)
temp = fast_dot(active_set.atoms[i], direction)
if temp < val
val = temp
idx = i
valm = typemax(eltype(direction))
idxm = -1
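# sentinel values: any finite scalar product will replace them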
@inbounds for i in eachindex(active_set)
val = fast_dot(active_set.atoms[i], direction)
if val < valm
valm = val
idxm = i
end
end
# return lambda, vertex, index
return (active_set[idx]..., idx)
if idxm == -1
error("Infinite minimum $valm in the active set. Does the gradient contain invalid (NaN / Inf) entries?")
end
return (active_set[idxm]..., idxm)
end

"""
@@ -273,28 +274,27 @@ Computes the linear minimizer in the direction on the active set.
Returns `(λ_min, a_min, i_min, val_min, λ_max, a_max, i_max, val_max, val_max-val_min ≥ Φ)`
"""
function active_set_argminmax(active_set::AbstractActiveSet, direction; Φ=0.5)
val = Inf
valM = -Inf
idx = -1
valm = typemax(eltype(direction))
valM = typemin(eltype(direction))
idxm = -1
idxM = -1
for i in eachindex(active_set)
temp_val = fast_dot(active_set.atoms[i], direction)
if temp_val < val
val = temp_val
idx = i
@inbounds for i in eachindex(active_set)
val = fast_dot(active_set.atoms[i], direction)
if val < valm
valm = val
idxm = i
end
if valM < temp_val
valM = temp_val
if valM < val
valM = val
idxM = i
end
end
if idx == -1 || idxM == -1
error("Infinite minimum $val or maximum $valM in the active set. Does the gradient contain invalid (NaN / Inf) entries?")
if idxm == -1 || idxM == -1
error("Infinite minimum $valm or maximum $valM in the active set. Does the gradient contain invalid (NaN / Inf) entries?")
end
return (active_set[idx]..., idx, val, active_set[idxM]..., idxM, valM, valM - val ≥ Φ)
return (active_set[idxm]..., idxm, valm, active_set[idxM]..., idxM, valM, valM - valm ≥ Φ)
end


"""
active_set_initialize!(as, v)
