Commit 17f04ad: Merge pull request #30 from SymbolicML/constant-optimization

Overload `Optim.optimize` for `::Node`

MilesCranmer committed Jan 28, 2024
2 parents: 160c30a + 80c370f
Showing 4 changed files with 224 additions and 1 deletion.
6 changes: 5 additions & 1 deletion Project.toml
@@ -16,10 +16,12 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"

[weakdeps]
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
DynamicExpressionsOptimExt = "Optim"
DynamicExpressionsSymbolicUtilsExt = "SymbolicUtils"
DynamicExpressionsZygoteExt = "Zygote"

@@ -28,6 +30,7 @@ Aqua = "0.7"
Compat = "3.37, 4"
Enzyme = "^0.11.12"
LoopVectorization = "0.12"
Optim = "0.19, 1"
MacroTools = "0.4, 0.5"
PackageExtensionCompat = "1"
PrecompileTools = "1"
@@ -40,6 +43,7 @@ julia = "1.6"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
@@ -48,4 +52,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "SafeTestsets", "Aqua", "Enzyme", "ForwardDiff", "SpecialFunctions", "StaticArrays", "SymbolicUtils", "Zygote"]
test = ["Test", "SafeTestsets", "Aqua", "Enzyme", "Optim", "ForwardDiff", "SpecialFunctions", "StaticArrays", "SymbolicUtils", "Zygote"]
123 changes: 123 additions & 0 deletions ext/DynamicExpressionsOptimExt.jl
@@ -0,0 +1,123 @@
module DynamicExpressionsOptimExt

using DynamicExpressions: AbstractExpressionNode, eval_tree_array
using Compat: @inline

import Optim: Optim, OptimizationResults, NLSolversBase

#! format: off
"""
    ExpressionOptimizationResults{R,N<:AbstractExpressionNode}

Optimization results for an expression, wrapping the base optimization results
on a vector of constants.
"""
struct ExpressionOptimizationResults{R<:OptimizationResults,N<:AbstractExpressionNode} <: OptimizationResults
    _results::R  # The raw results from Optim.
    tree::N      # The final expression tree.
end
#! format: on
function Base.getproperty(r::ExpressionOptimizationResults, s::Symbol)
    if s == :tree || s == :minimizer
        return getfield(r, :tree)
    else
        return getproperty(getfield(r, :_results), s)
    end
end
function Base.propertynames(r::ExpressionOptimizationResults)
    return (:tree, propertynames(getfield(r, :_results))...)
end
function Optim.minimizer(r::ExpressionOptimizationResults)
    return r.tree
end

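"""Write the parameter vector `x` into the tree's constant nodes, in the order used to build `x0` in `optimize` below."""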
function set_constant_nodes!(
    constant_nodes::AbstractArray{N}, x
) where {T,N<:AbstractExpressionNode{T}}
    for (ci, xi) in zip(constant_nodes, x)
        ci.val = xi::T
    end
    return nothing
end

"""Wrap function or objective with insertion of values of the constant nodes."""
function wrap_func(
    f::F, tree::N, constant_nodes::AbstractArray{N}
) where {F<:Function,T,N<:AbstractExpressionNode{T}}
    function wrapped_f(args::Vararg{Any,M}) where {M}
        first_args = args[1:(end - 1)]
        x = last(args)
        set_constant_nodes!(constant_nodes, x)
        return @inline(f(first_args..., tree))
    end
    return wrapped_f
end
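
# Call-shape sketch (hypothetical, for illustration): Optim only ever sees the
# flat constant vector as the last argument, e.g.
#
#     wrapped_f(x)      # writes `x` into the tree, then calls f(tree)
#     wrapped_g!(G, x)  # writes `x` into the tree, then calls g!(G, tree)
#
# so user-supplied objectives receive the tree, not a vector, as their final
# argument.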
function wrap_func(
    ::Nothing, tree::N, constant_nodes::AbstractArray{N}
) where {N<:AbstractExpressionNode}
    return nothing
end
function wrap_func(
    f::NLSolversBase.InplaceObjective, tree::N, constant_nodes::AbstractArray{N}
) where {N<:AbstractExpressionNode}
    # Some objectives, like `Optim.only_fg!(fg!)`, are not functions but instead
    # `InplaceObjective`. These contain multiple functions, each of which needs to be
    # wrapped. Some functions are `nothing`; those can be left as-is.
    @assert fieldnames(NLSolversBase.InplaceObjective) == (:df, :fdf, :fgh, :hv, :fghv)
    return NLSolversBase.InplaceObjective(
        wrap_func(f.df, tree, constant_nodes),
        wrap_func(f.fdf, tree, constant_nodes),
        wrap_func(f.fgh, tree, constant_nodes),
        wrap_func(f.hv, tree, constant_nodes),
        wrap_func(f.fghv, tree, constant_nodes),
    )
end
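
# For example (illustrative): `Optim.only_fg!(fg!)` stores the combined function
# in the `fdf` field and leaves the others as `nothing`, so only `fdf` gets a
# real wrapper here; the rest pass through the `::Nothing` method above.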

"""
    optimize(f, [g!, [h!,]] tree, args...; kwargs...)

Optimize the constants in an expression tree with respect to the objective `f`.
Returns an `ExpressionOptimizationResults` object, which wraps the base
optimization results on a vector of constants. Use `res.minimizer`
to view the optimized expression tree.
"""
function Optim.optimize(f::F, tree::AbstractExpressionNode, args...; kwargs...) where {F}
    return Optim.optimize(f, nothing, tree, args...; kwargs...)
end
function Optim.optimize(
    f::F, g!::G, tree::AbstractExpressionNode, args...; kwargs...
) where {F,G}
    return Optim.optimize(f, g!, nothing, tree, args...; kwargs...)
end
function Optim.optimize(
    f::F, g!::G, h!::H, tree::AbstractExpressionNode{T}, args...; make_copy=true, kwargs...
) where {F,G,H,T}
    if make_copy
        tree = copy(tree)
    end
    constant_nodes = filter(t -> t.degree == 0 && t.constant, tree)
    x0 = T[t.val::T for t in constant_nodes]
    if !isnothing(h!)
        throw(
            ArgumentError(
                "Optim.optimize does not yet support Hessians on `AbstractExpressionNode`. " *
                "Please raise an issue at github.com/SymbolicML/DynamicExpressions.jl.",
            ),
        )
    end
    base_res = if isnothing(g!)
        Optim.optimize(wrap_func(f, tree, constant_nodes), x0, args...; kwargs...)
    else
        Optim.optimize(
            wrap_func(f, tree, constant_nodes),
            wrap_func(g!, tree, constant_nodes),
            x0,
            args...;
            kwargs...,
        )
    end
    set_constant_nodes!(constant_nodes, Optim.minimizer(base_res))
    return ExpressionOptimizationResults(base_res, tree)
end

end
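
A minimal usage sketch of the new overload (hypothetical example; `operators`, `x1`, `f`, and the data are illustrative, not part of this diff):

    using DynamicExpressions, Optim

    operators = OperatorEnum(; binary_operators=(+, -, *), unary_operators=(exp,))
    x1 = Node(Float64; feature=1)
    tree = exp(x1 * 0.5) - 1.0  # the constants 0.5 and -1.0 will be tuned

    X = randn(1, 32)
    y = @. exp(X[1, :] * 0.9) - 0.2
    f(tree) = sum(abs2, tree(X, operators) .- y)

    res = Optim.optimize(f, tree, BFGS())  # finite-difference gradients by default
    res.minimizer  # the tree with optimized constants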
92 changes: 92 additions & 0 deletions test/test_optim.jl
@@ -0,0 +1,92 @@
using DynamicExpressions, Optim, Zygote
using Random: MersenneTwister as RNG

operators = OperatorEnum(; binary_operators=(+, -, *, /), unary_operators=(exp,))
x1, x2 = (i -> Node(Float64; feature=i)).(1:2)

X = rand(RNG(0), Float64, 2, 100)
y = @. exp(X[1, :] * 2.1 - 0.9) + X[2, :] * -0.9

original_tree = exp(x1 * 0.8 - 0.0) + 5.2 * x2
target_tree = exp(x1 * 2.1 - 0.9) + -0.9 * x2

f(tree) = sum(abs2, tree(X, operators) .- y)

@testset "Basic optimization" begin
tree = copy(original_tree)
res = optimize(f, tree)

# Should be unchanged by default
if VERSION >= v"1.9"
ext = Base.get_extension(DynamicExpressions, :DynamicExpressionsOptimExt)
@test res isa ext.ExpressionOptimizationResults
end
@test tree == original_tree
@test isapprox(get_constants(res.minimizer), get_constants(target_tree); atol=0.01)
end

@testset "With gradients" begin
tree = copy(original_tree)
did_i_run = Ref(false)
# Now, try with gradients too (via Zygote and our hand-rolled forward-mode AD)
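    # Chain rule used below: with L = Σⱼ (ŷⱼ - yⱼ)², the gradient w.r.t. the i-th
    # constant is ∂L/∂cᵢ = Σⱼ 2 (ŷⱼ - yⱼ) ∂ŷⱼ/∂cᵢ, where ∂ŷⱼ/∂cᵢ is supplied by
    # `eval_grad_tree_array(...; variable=false)`.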
    g!(G, tree) =
        let
            ŷ, dŷ_dconstants, _ = eval_grad_tree_array(tree, X, operators; variable=false)
            dresult_dŷ = @. 2 * (ŷ - y)
            for i in eachindex(G)
                G[i] = sum(
                    j -> dresult_dŷ[j] * dŷ_dconstants[i, j],
                    eachindex(axes(dŷ_dconstants, 2), axes(dresult_dŷ, 1)),
                )
            end
            did_i_run[] = true
            return nothing
        end

    res = optimize(f, g!, tree, BFGS())
    @test did_i_run[]
    @test res.f_calls > 0
    @test isapprox(get_constants(res.minimizer), get_constants(target_tree); atol=0.01)
    @test Optim.minimizer(res) === res.minimizer
    @test propertynames(res) == (:tree, propertynames(getfield(res, :_results))...)

    @testset "Hessians not implemented" begin
        @test_throws ArgumentError optimize(f, g!, t -> t, tree, BFGS())
        VERSION >= v"1.9" && @test_throws(
            "Optim.optimize does not yet support Hessians on `AbstractExpressionNode`",
            optimize(f, g!, t -> t, tree, BFGS())
        )
    end
end

# Now, try combined
@testset "Combined evaluation with gradient" begin
tree = copy(original_tree)
did_i_run_2 = Ref(false)
fg!(F, G, tree) =
let
if G !== nothing
ŷ, dŷ_dconstants, _ = eval_grad_tree_array(
tree, X, operators; variable=false
)
dresult_dŷ = @. 2 * (ŷ - y)
for i in eachindex(G)
G[i] = sum(
j -> dresult_dŷ[j] * dŷ_dconstants[i, j],
eachindex(axes(dŷ_dconstants, 2), axes(dresult_dŷ, 1)),
)
end
if F !== nothing
did_i_run_2[] = true
return sum(abs2, ŷ .- y)
end
elseif F !== nothing
# Only f
return sum(abs2, tree(X, operators) .- y)
end
end
res = optimize(Optim.only_fg!(fg!), tree, BFGS())

@test did_i_run_2[]
@test isapprox(get_constants(res.minimizer), get_constants(target_tree); atol=0.01)
end
4 changes: 4 additions & 0 deletions test/unittest.jl
@@ -12,6 +12,10 @@ end
include("test_deprecations.jl")
end

@safetestset "Test Optim.jl" begin
include("test_optim.jl")
end

@safetestset "Test tree construction and scoring" begin
include("test_tree_construction.jl")
end
