diff --git a/ext/OptimizationEnzymeExt.jl b/ext/OptimizationEnzymeExt.jl
index d53b449..a66d945 100644
--- a/ext/OptimizationEnzymeExt.jl
+++ b/ext/OptimizationEnzymeExt.jl
@@ -85,7 +85,7 @@ function set_runtime_activity2(
     Enzyme.set_runtime_activity(a, RTA)
 end
 function_annotation(::Nothing) = Nothing
-function_annotation(::AutoEnzyme{<:Any, A}) where A = A
+function_annotation(::AutoEnzyme{<:Any, A}) where {A} = A
 function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
         adtype::AutoEnzyme, p, num_cons = 0;
         g = false, h = false, hv = false, fg = false, fgh = false,
@@ -225,9 +225,12 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
         if func_annot <: Enzyme.Const
             basefunc = Enzyme.Const(basefunc)
         elseif func_annot <: Enzyme.Duplicated || func_annot <: Enzyme.BatchDuplicated
-            basefunc = Enzyme.BatchDuplicated(basefunc, Tuple(make_zero(basefunc) for i in 1:length(x)))
-        elseif func_annot <: Enzyme.DuplicatedNoNeed || func_annot <: Enzyme.BatchDuplicatedNoNeed
-            basefunc = Enzyme.BatchDuplicatedNoNeed(basefunc, Tuple(make_zero(basefunc) for i in 1:length(x)))
+            basefunc = Enzyme.BatchDuplicated(basefunc, Tuple(make_zero(basefunc)
+                for i in 1:length(x)))
+        elseif func_annot <: Enzyme.DuplicatedNoNeed ||
+               func_annot <: Enzyme.BatchDuplicatedNoNeed
+            basefunc = Enzyme.BatchDuplicatedNoNeed(basefunc, Tuple(make_zero(basefunc)
+                for i in 1:length(x)))
         end
         # else
         #     seeds = Enzyme.onehot(zeros(eltype(x), num_cons))
@@ -241,12 +244,14 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
                 Enzyme.make_zero!(jc)
             end
             Enzyme.make_zero!(y)
-            if func_annot <: Enzyme.Duplicated || func_annot <: Enzyme.BatchDuplicated || func_annot <: Enzyme.DuplicatedNoNeed || func_annot <: Enzyme.BatchDuplicatedNoNeed
+            if func_annot <: Enzyme.Duplicated || func_annot <: Enzyme.BatchDuplicated ||
+               func_annot <: Enzyme.DuplicatedNoNeed ||
+               func_annot <: Enzyme.BatchDuplicatedNoNeed
                 for bf in basefunc.dval
                     Enzyme.make_zero!(bf)
                 end
             end
-            Enzyme.autodiff(fmode, basefunc , BatchDuplicated(y, Jaccache),
+            Enzyme.autodiff(fmode, basefunc, BatchDuplicated(y, Jaccache),
                 BatchDuplicated(θ, seeds), Const(p))
             for i in eachindex(θ)
                 if J isa Vector
@@ -575,7 +580,8 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
             for i in eachindex(Jaccache)
                 Enzyme.make_zero!(Jaccache[i])
             end
-            Jaccache, y = Enzyme.autodiff(WithPrimal(fmode), f.cons, Duplicated,
+            Jaccache,
+            y = Enzyme.autodiff(WithPrimal(fmode), f.cons, Duplicated,
                 BatchDuplicated(θ, seeds), Const(p))
             if size(y, 1) == 1
                 return reduce(vcat, Jaccache)
diff --git a/ext/OptimizationZygoteExt.jl b/ext/OptimizationZygoteExt.jl
index 7574af7..b0ecc0d 100644
--- a/ext/OptimizationZygoteExt.jl
+++ b/ext/OptimizationZygoteExt.jl
@@ -30,7 +30,7 @@ function OptimizationBase.instantiate_function(
     adtype, soadtype = OptimizationBase.generate_adtype(adtype)
 
     if g == true && f.grad === nothing
-        prep_grad = prepare_gradient(f.f, adtype, x, Constant(p), strict=Val(false))
+        prep_grad = prepare_gradient(f.f, adtype, x, Constant(p), strict = Val(false))
         function grad(res, θ)
             gradient!(f.f, res, prep_grad, adtype, θ, Constant(p))
         end
@@ -47,7 +47,7 @@ function OptimizationBase.instantiate_function(
     if fg == true && f.fg === nothing
         if g == false
-            prep_grad = prepare_gradient(f.f, adtype, x, Constant(p), strict=Val(false))
+            prep_grad = prepare_gradient(f.f, adtype, x, Constant(p), strict = Val(false))
         end
         function fg!(res, θ)
             (y, _) = value_and_gradient!(f.f, res, prep_grad,
                 adtype, θ, Constant(p))
@@ -68,7 +68,7 @@ function OptimizationBase.instantiate_function(
     hess_sparsity = f.hess_prototype
     hess_colors = f.hess_colorvec
     if h == true && f.hess === nothing
-        prep_hess = prepare_hessian(f.f, soadtype, x, Constant(p), strict=Val(false))
+        prep_hess = prepare_hessian(f.f, soadtype, x, Constant(p), strict = Val(false))
         function hess(res, θ)
             hessian!(f.f, res, prep_hess, soadtype, θ, Constant(p))
         end
@@ -85,13 +85,17 @@ function OptimizationBase.instantiate_function(
     if fgh == true && f.fgh === nothing
         function fgh!(G, H, θ)
-            (y, _, _) = value_derivative_and_second_derivative!(
+            (y,
+                _,
+                _) = value_derivative_and_second_derivative!(
                 f.f, G, H, prep_hess, soadtype, θ, Constant(p))
             return y
         end
 
         if p !== SciMLBase.NullParameters() && p !== nothing
             function fgh!(G, H, θ, p)
-                (y, _, _) = value_derivative_and_second_derivative!(
+                (y,
+                    _,
+                    _) = value_derivative_and_second_derivative!(
                     f.f, G, H, prep_hess, soadtype, θ, Constant(p))
                 return y
             end
@@ -143,7 +147,7 @@ function OptimizationBase.instantiate_function(
     cons_jac_prototype = f.cons_jac_prototype
     cons_jac_colorvec = f.cons_jac_colorvec
     if cons !== nothing && cons_j == true && f.cons_j === nothing
-        prep_jac = prepare_jacobian(cons_oop, adtype, x, strict=Val(false))
+        prep_jac = prepare_jacobian(cons_oop, adtype, x, strict = Val(false))
         function cons_j!(J, θ)
             jacobian!(cons_oop, J, prep_jac, adtype, θ)
             if size(J, 1) == 1
@@ -157,7 +161,8 @@ function OptimizationBase.instantiate_function(
     end
 
     if f.cons_vjp === nothing && cons_vjp == true && cons !== nothing
-        prep_pullback = prepare_pullback(cons_oop, adtype, x, (ones(eltype(x), num_cons),), strict=Val(false))
+        prep_pullback = prepare_pullback(
+            cons_oop, adtype, x, (ones(eltype(x), num_cons),), strict = Val(false))
         function cons_vjp!(J, θ, v)
             pullback!(cons_oop, (J,), prep_pullback, adtype, θ, (v,))
         end
@@ -169,7 +174,7 @@ function OptimizationBase.instantiate_function(
 
     if cons !== nothing && f.cons_jvp === nothing && cons_jvp == true
         prep_pushforward = prepare_pushforward(
-            cons_oop, adtype, x, (ones(eltype(x), length(x)),), strict=Val(false))
+            cons_oop, adtype, x, (ones(eltype(x), length(x)),), strict = Val(false))
         function cons_jvp!(J, θ, v)
             pushforward!(cons_oop, (J,), prep_pushforward, adtype, θ, (v,))
         end
@@ -182,7 +187,8 @@ function OptimizationBase.instantiate_function(
     conshess_sparsity = f.cons_hess_prototype
     conshess_colors = f.cons_hess_colorvec
     if cons !== nothing && cons_h == true && f.cons_h === nothing
-        prep_cons_hess = [prepare_hessian(cons_oop, soadtype, x, Constant(i), strict=Val(false))
+        prep_cons_hess = [prepare_hessian(
+                              cons_oop, soadtype, x, Constant(i), strict = Val(false))
                           for i in 1:num_cons]
 
         function cons_h!(H, θ)
@@ -201,7 +207,7 @@ function OptimizationBase.instantiate_function(
     if f.lag_h === nothing && cons !== nothing && lag_h == true
         lag_extras = prepare_hessian(
             lagrangian, soadtype, x, Constant(one(eltype(x))),
-            Constant(ones(eltype(x), num_cons)), Constant(p), strict=Val(false))
+            Constant(ones(eltype(x), num_cons)), Constant(p), strict = Val(false))
         lag_hess_prototype = zeros(Bool, num_cons, length(x))
 
         function lag_h!(H::AbstractMatrix, θ, σ, λ)
@@ -294,7 +300,8 @@ function OptimizationBase.instantiate_function(
     adtype, soadtype = OptimizationBase.generate_sparse_adtype(adtype)
 
     if g == true && f.grad === nothing
-        extras_grad = prepare_gradient(f.f, adtype.dense_ad, x, Constant(p), strict=Val(false))
+        extras_grad = prepare_gradient(
+            f.f, adtype.dense_ad, x, Constant(p), strict = Val(false))
         function grad(res, θ)
             gradient!(f.f, res, extras_grad, adtype.dense_ad, θ, Constant(p))
         end
@@ -311,16 +318,19 @@ function OptimizationBase.instantiate_function(
     if fg == true && f.fg === nothing
         if g == false
-            extras_grad = prepare_gradient(f.f, adtype.dense_ad, x, Constant(p), strict=Val(false))
+            extras_grad = prepare_gradient(
+                f.f, adtype.dense_ad, x, Constant(p), strict = Val(false))
         end
         function fg!(res, θ)
-            (y, _) = value_and_gradient!(
+            (y,
+                _) = value_and_gradient!(
                 f.f, res, extras_grad, adtype.dense_ad, θ, Constant(p))
             return y
         end
 
         if p !== SciMLBase.NullParameters() && p !== nothing
            function fg!(res, θ, p)
-                (y, _) = value_and_gradient!(
+                (y,
+                    _) = value_and_gradient!(
                     f.f, res, extras_grad, adtype.dense_ad, θ, Constant(p))
                 return y
             end
@@ -334,7 +344,7 @@ function OptimizationBase.instantiate_function(
     hess_sparsity = f.hess_prototype
     hess_colors = f.hess_colorvec
     if h == true && f.hess === nothing
-        prep_hess = prepare_hessian(f.f, soadtype, x, Constant(p), strict=Val(false))
+        prep_hess = prepare_hessian(f.f, soadtype, x, Constant(p), strict = Val(false))
         function hess(res, θ)
             hessian!(f.f, res, prep_hess, soadtype, θ, Constant(p))
         end
@@ -354,14 +364,18 @@ function OptimizationBase.instantiate_function(
     if fgh == true && f.fgh === nothing
         function fgh!(G, H, θ)
-            (y, _, _) = value_derivative_and_second_derivative!(
+            (y,
+                _,
+                _) = value_derivative_and_second_derivative!(
                 f.f, G, H, θ, prep_hess, soadtype, Constant(p))
             return y
         end
 
         if p !== SciMLBase.NullParameters() && p !== nothing
             function fgh!(G, H, θ, p)
-                (y, _, _) = value_derivative_and_second_derivative!(
+                (y,
+                    _,
+                    _) = value_derivative_and_second_derivative!(
                     f.f, G, H, θ, prep_hess, soadtype, Constant(p))
                 return y
             end
 
@@ -458,7 +472,8 @@ function OptimizationBase.instantiate_function(
     conshess_sparsity = f.cons_hess_prototype
     conshess_colors = f.cons_hess_colorvec
     if cons !== nothing && f.cons_h === nothing && cons_h == true
-        prep_cons_hess = [prepare_hessian(cons_oop, soadtype, x, Constant(i), strict=Val(false))
+        prep_cons_hess = [prepare_hessian(
+                              cons_oop, soadtype, x, Constant(i), strict = Val(false))
                           for i in 1:num_cons]
         colores = getfield.(prep_cons_hess, :coloring_result)
         conshess_sparsity = getfield.(colores, :A)
@@ -479,7 +494,7 @@ function OptimizationBase.instantiate_function(
     if cons !== nothing && f.lag_h === nothing && lag_h == true
         lag_extras = prepare_hessian(
             lagrangian, soadtype, x, Constant(one(eltype(x))),
-            Constant(ones(eltype(x), num_cons)), Constant(p), strict=Val(false))
+            Constant(ones(eltype(x), num_cons)), Constant(p), strict = Val(false))
         lag_hess_prototype = lag_extras.coloring_result.A
         lag_hess_colors = lag_extras.coloring_result.color
 
diff --git a/src/OptimizationDIExt.jl b/src/OptimizationDIExt.jl
index 5acc1f7..9b81424 100644
--- a/src/OptimizationDIExt.jl
+++ b/src/OptimizationDIExt.jl
@@ -78,13 +78,17 @@ function instantiate_function(
     if fgh == true && f.fgh === nothing
         function fgh!(G, H, θ)
-            (y, _, _) = value_derivative_and_second_derivative!(
+            (y,
+                _,
+                _) = value_derivative_and_second_derivative!(
                 f.f, G, H, prep_hess, soadtype, θ, Constant(p))
             return y
         end
 
         if p !== SciMLBase.NullParameters() && p !== nothing
             function fgh!(G, H, θ, p)
-                (y, _, _) = value_derivative_and_second_derivative!(
+                (y,
+                    _,
+                    _) = value_derivative_and_second_derivative!(
                     f.f, G, H, prep_hess, soadtype, θ, Constant(p))
                 return y
             end
@@ -338,13 +342,17 @@ function instantiate_function(
     if fgh == true && f.fgh === nothing
         function fgh!(θ)
-            (y, G, H) = value_derivative_and_second_derivative(
+            (y,
+                G,
+                H) = value_derivative_and_second_derivative(
                 f.f, prep_hess, adtype, θ, Constant(p))
             return y, G, H
         end
 
         if p !== SciMLBase.NullParameters() && p !== nothing
             function fgh!(θ, p)
-                (y, G, H) = value_derivative_and_second_derivative(
+                (y,
+                    G,
+                    H) = value_derivative_and_second_derivative(
                     f.f, prep_hess, adtype, θ, Constant(p))
                 return y, G, H
             end
diff --git a/src/OptimizationDISparseExt.jl b/src/OptimizationDISparseExt.jl
index fa9a7e5..e135339 100644
--- a/src/OptimizationDISparseExt.jl
+++ b/src/OptimizationDISparseExt.jl
@@ -41,13 +41,15 @@ function instantiate_function(
             prep_grad = prepare_gradient(f.f, adtype.dense_ad, x, Constant(p))
         end
         function fg!(res, θ)
-            (y, _) = value_and_gradient!(
+            (y,
+                _) = value_and_gradient!(
                 f.f, res, prep_grad, adtype.dense_ad, θ, Constant(p))
             return y
         end
         if p !== SciMLBase.NullParameters()
             function fg!(res, θ, p)
-                (y, _) = value_and_gradient!(
+                (y,
+                    _) = value_and_gradient!(
                     f.f, res, prep_grad, adtype.dense_ad, θ, Constant(p))
                 return y
             end
@@ -81,13 +83,17 @@ function instantiate_function(
     if fgh == true && f.fgh === nothing
         function fgh!(G, H, θ)
-            (y, _, _) = value_derivative_and_second_derivative!(
+            (y,
+                _,
+                _) = value_derivative_and_second_derivative!(
                 f.f, G, H, prep_hess, soadtype.dense_ad, θ, Constant(p))
             return y
         end
 
         if p !== SciMLBase.NullParameters() && p !== nothing
             function fgh!(G, H, θ, p)
-                (y, _, _) = value_derivative_and_second_derivative!(
+                (y,
+                    _,
+                    _) = value_derivative_and_second_derivative!(
                     f.f, G, H, prep_hess, soadtype.dense_ad, θ, Constant(p))
                 return y
             end
@@ -336,14 +342,18 @@ function instantiate_function(
     if fgh == true && f.fgh === nothing
         function fgh!(θ)
-            (y, G, H) = value_derivative_and_second_derivative(
+            (y,
+                G,
+                H) = value_derivative_and_second_derivative(
                 f.f, prep_hess, soadtype, θ, Constant(p))
             return y, G, H
         end
 
         if p !== SciMLBase.NullParameters() && p !== nothing
             function fgh!(θ, p)
-                (y, G, H) = value_derivative_and_second_derivative(
+                (y,
+                    G,
+                    H) = value_derivative_and_second_derivative(
                     f.f, prep_hess, soadtype, θ, Constant(p))
                 return y, G, H
             end
 
diff --git a/test/matrixvalued.jl b/test/matrixvalued.jl
index 7a60823..71f51f1 100644
--- a/test/matrixvalued.jl
+++ b/test/matrixvalued.jl
@@ -13,7 +13,8 @@ using Test, ReverseDiff
         # 1. Matrix Factorization
        @show adtype
         function matrix_factorization_objective(X, A)
-            U, V = @view(X[1:size(A, 1), 1:Int(size(A, 2) / 2)]),
+            U,
+            V = @view(X[1:size(A, 1), 1:Int(size(A, 2) / 2)]),
             @view(X[1:size(A, 1), (Int(size(A, 2) / 2) + 1):size(A, 2)])
             return norm(A - U * V')
         end