-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #30 from SymbolicML/constant-optimization
Overload `Optim.optimize` for `::Node`
- Loading branch information
Showing 4 changed files with 224 additions and 1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
module DynamicExpressionsOptimExt | ||
|
||
using DynamicExpressions: AbstractExpressionNode, eval_tree_array | ||
using Compat: @inline | ||
|
||
import Optim: Optim, OptimizationResults, NLSolversBase | ||
|
||
#! format: off
"""
    ExpressionOptimizationResults{R<:OptimizationResults,N<:AbstractExpressionNode}

Result of optimizing the constants of an expression tree.

Holds the results object that Optim produced for the flat vector of constants,
together with the final expression tree.
"""
struct ExpressionOptimizationResults{R<:OptimizationResults,N<:AbstractExpressionNode} <: OptimizationResults
    _results::R  # Raw results object returned by Optim.
    tree::N      # Final expression tree.
end
#! format: on
# Expose the stored tree under both `tree` and `minimizer`; every other property
# is looked up on the wrapped results object.
function Base.getproperty(r::ExpressionOptimizationResults, s::Symbol)
    s in (:tree, :minimizer) && return getfield(r, :tree)
    inner = getfield(r, :_results)
    return getproperty(inner, s)
end
# List `tree` first, followed by every property name of the wrapped results.
function Base.propertynames(r::ExpressionOptimizationResults)
    inner_names = propertynames(getfield(r, :_results))
    return (:tree, inner_names...)
end
# The minimizer of an expression optimization is the stored expression tree.
Optim.minimizer(r::ExpressionOptimizationResults) = getfield(r, :tree)
|
||
# Write the values of `x` into the `val` field of the matching entries of
# `constant_nodes`, pairing them up in iteration order. Each value must already
# be of the node element type `T`. Returns `nothing`.
function set_constant_nodes!(
    constant_nodes::AbstractArray{N}, x
) where {T,N<:AbstractExpressionNode{T}}
    for (node, value) in zip(constant_nodes, x)
        node.val::T = value::T
    end
    return nothing
end
|
||
"""
    wrap_func(f, tree, constant_nodes)

Return a closure around `f` whose last argument is a vector of constants. The
closure writes those values into `constant_nodes` and then calls `f` with `tree`
substituted for that vector; any leading arguments are passed through unchanged.
"""
function wrap_func(
    f::F, tree::N, constant_nodes::AbstractArray{N}
) where {F<:Function,T,N<:AbstractExpressionNode{T}}
    function with_constants(args::Vararg{Any,M}) where {M}
        constants = last(args)
        leading = args[1:(end - 1)]
        set_constant_nodes!(constant_nodes, constants)
        return @inline(f(leading..., tree))
    end
    return with_constants
end
# A missing function (`nothing`) has nothing to wrap and is passed through as-is.
wrap_func(::Nothing, ::N, ::AbstractArray{N}) where {N<:AbstractExpressionNode} = nothing
# Objectives such as `Optim.only_fg!(fg!)` are `InplaceObjective` containers rather
# than plain functions. Each stored function is wrapped individually; entries that
# are `nothing` are passed through unchanged.
function wrap_func(
    f::NLSolversBase.InplaceObjective, tree::N, constant_nodes::AbstractArray{N}
) where {N<:AbstractExpressionNode}
    # Guard against the field layout of `InplaceObjective` changing underneath us,
    # since the constructor call below is positional.
    @assert fieldnames(NLSolversBase.InplaceObjective) == (:df, :fdf, :fgh, :hv, :fghv)
    rewrap(g) = wrap_func(g, tree, constant_nodes)
    return NLSolversBase.InplaceObjective(
        rewrap(f.df), rewrap(f.fdf), rewrap(f.fgh), rewrap(f.hv), rewrap(f.fghv)
    )
end
|
||
"""
    optimize(f, [g!, [h!,]] tree, args...; make_copy=true, kwargs...)

Optimize the constants of the expression `tree`.

The objective `f` (and the gradient `g!`, when given) receive the expression tree
in the position where Optim would normally pass the parameter vector. Remaining
`args` and `kwargs` are forwarded to `Optim.optimize`.

With `make_copy=true` (the default) the optimization runs on a copy of `tree`;
otherwise the constants of `tree` itself are overwritten. Passing a Hessian `h!`
is not supported and throws an `ArgumentError`.

Returns an `ExpressionOptimizationResults` object wrapping the base optimization
results on the vector of constants. Use `res.minimizer` to get the optimized
expression tree.
"""
function Optim.optimize(f::F, tree::AbstractExpressionNode, args...; kwargs...) where {F}
    gradient = nothing
    return Optim.optimize(f, gradient, tree, args...; kwargs...)
end
# No Hessian supplied: forward to the `(f, g!, h!, tree)` method with `h! = nothing`.
function Optim.optimize(
    f::F, g!::G, tree::AbstractExpressionNode, args...; kwargs...
) where {F,G}
    hessian = nothing
    return Optim.optimize(f, g!, hessian, tree, args...; kwargs...)
end
# Core method: optimize the constant leaves of `tree`.
#
# The values of the constant nodes are collected into a vector `x0`, the objective
# (and gradient, if any) are wrapped so that they receive the tree instead of the
# vector, and the flat problem is handed to `Optim.optimize`.
#
# Keywords:
# - `make_copy=true`: optimize a copy of `tree`; when `false`, the constants of the
#   given tree are overwritten in place.
# - all other keywords are forwarded to `Optim.optimize`.
#
# Throws `ArgumentError` when a Hessian `h!` is supplied (not supported yet).
# Returns an `ExpressionOptimizationResults` wrapping the base results and the tree.
function Optim.optimize(
    f::F, g!::G, h!::H, tree::AbstractExpressionNode{T}, args...; make_copy=true, kwargs...
) where {F,G,H,T}
    # Validate first, so an unsupported call fails before the tree is copied or
    # traversed.
    if !isnothing(h!)
        throw(
            ArgumentError(
                "Optim.optimize does not yet support Hessians on `AbstractExpressionNode`. " *
                "Please raise an issue at github.com/SymbolicML/DynamicExpressions.jl.",
            ),
        )
    end
    if make_copy
        tree = copy(tree)
    end
    # Constant leaves are the degree-0 nodes flagged as constants.
    constant_nodes = filter(t -> t.degree == 0 && t.constant, tree)
    x0 = T[t.val::T for t in constant_nodes]
    base_res = if isnothing(g!)
        Optim.optimize(wrap_func(f, tree, constant_nodes), x0, args...; kwargs...)
    else
        Optim.optimize(
            wrap_func(f, tree, constant_nodes),
            wrap_func(g!, tree, constant_nodes),
            x0,
            args...;
            kwargs...,
        )
    end
    # The last objective evaluation may have left other values in the nodes, so
    # write the minimizer into the tree before returning it.
    set_constant_nodes!(constant_nodes, Optim.minimizer(base_res))
    return ExpressionOptimizationResults(base_res, tree)
end
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
using DynamicExpressions, Optim, Zygote | ||
using Random: MersenneTwister as RNG | ||
|
||
# Operator set shared by every test below. It is created before any expression is
# written with these operators.
operators = OperatorEnum(; binary_operators=(+, -, *, /), unary_operators=(exp,))

# Feature nodes for the two input variables.
x1 = Node(Float64; feature=1)
x2 = Node(Float64; feature=2)

# Reproducible inputs, and targets generated from the same formula as `target_tree`.
X = rand(RNG(0), Float64, 2, 100)
y = exp.(X[1, :] .* 2.1 .- 0.9) .+ X[2, :] .* -0.9

# Starting point for the optimizer, and the expression it should converge to.
original_tree = exp(x1 * 0.8 - 0.0) + 5.2 * x2
target_tree = exp(x1 * 2.1 - 0.9) + -0.9 * x2

# Sum of squared errors of `tree` evaluated on `X`, compared against `y`.
function f(tree)
    prediction = tree(X, operators)
    return sum(abs2, prediction .- y)
end
|
||
@testset "Basic optimization" begin
    input_tree = copy(original_tree)
    result = optimize(f, input_tree)

    # Package extensions can only be looked up on Julia 1.9 and later.
    if VERSION >= v"1.9"
        ext = Base.get_extension(DynamicExpressions, :DynamicExpressionsOptimExt)
        @test result isa ext.ExpressionOptimizationResults
    end

    # The input tree should be unchanged by default
    @test input_tree == original_tree
    @test get_constants(result.minimizer) ≈ get_constants(target_tree) atol = 0.01
end
|
||
@testset "With gradients" begin
    tree = copy(original_tree)
    did_i_run = Ref(false)

    # Gradient of the squared-error loss with respect to the constants, written
    # into `G` using the derivatives returned by `eval_grad_tree_array`.
    function g!(G, tree)
        pred, dpred_dconsts, _ = eval_grad_tree_array(
            tree, X, operators; variable=false
        )
        dloss_dpred = @. 2 * (pred - y)
        samples = eachindex(axes(dpred_dconsts, 2), axes(dloss_dpred, 1))
        for i in eachindex(G)
            G[i] = sum(j -> dloss_dpred[j] * dpred_dconsts[i, j], samples)
        end
        did_i_run[] = true
        return nothing
    end

    res = optimize(f, g!, tree, BFGS())
    @test did_i_run[]
    @test res.f_calls > 0
    @test get_constants(res.minimizer) ≈ get_constants(target_tree) atol = 0.01
    @test Optim.minimizer(res) === res.minimizer
    @test propertynames(res) == (:tree, propertynames(getfield(res, :_results))...)

    @testset "Hessians not implemented" begin
        @test_throws ArgumentError optimize(f, g!, t -> t, tree, BFGS())
        # Matching on the error message requires Julia 1.9 or later here.
        if VERSION >= v"1.9"
            @test_throws(
                "Optim.optimize does not yet support Hessians on `AbstractExpressionNode`",
                optimize(f, g!, t -> t, tree, BFGS())
            )
        end
    end
end
|
||
# Objective and gradient computed together through `Optim.only_fg!`
@testset "Combined evaluation with gradient" begin
    tree = copy(original_tree)
    did_i_run_2 = Ref(false)

    # `F` and `G` are `nothing` when the value or the gradient is not requested.
    function fg!(F, G, tree)
        if G === nothing
            # Only the objective value may have been requested.
            F === nothing && return nothing
            return sum(abs2, tree(X, operators) .- y)
        end

        pred, dpred_dconsts, _ = eval_grad_tree_array(
            tree, X, operators; variable=false
        )
        dloss_dpred = @. 2 * (pred - y)
        samples = eachindex(axes(dpred_dconsts, 2), axes(dloss_dpred, 1))
        for i in eachindex(G)
            G[i] = sum(j -> dloss_dpred[j] * dpred_dconsts[i, j], samples)
        end

        F === nothing && return nothing
        # Record that value and gradient were requested in the same call.
        did_i_run_2[] = true
        return sum(abs2, pred .- y)
    end

    res = optimize(Optim.only_fg!(fg!), tree, BFGS())

    @test did_i_run_2[]
    @test get_constants(res.minimizer) ≈ get_constants(target_tree) atol = 0.01
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters