From 06f2996b5d7f095d13c10fb4fcdca67a7c31f264 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 6 Dec 2024 09:54:25 +0100 Subject: [PATCH 01/10] fix: handle function contexts differently from constant contexts --- .../reverse_onearg.jl | 14 +- .../forward_onearg.jl | 16 +-- .../forward_twoarg.jl | 6 +- .../reverse_onearg.jl | 34 +++-- .../reverse_twoarg.jl | 12 +- .../utils.jl | 19 ++- .../onearg.jl | 129 ++++++++++++------ .../secondorder.jl | 20 +-- .../twoarg.jl | 80 ++++++++--- .../utils.jl | 12 +- .../DifferentiationInterfaceTrackerExt.jl | 40 ++++-- .../DifferentiationInterfaceZygoteExt.jl | 119 ++++++++++++---- .../src/second_order/hvp.jl | 72 +++++++--- .../src/second_order/second_derivative.jl | 20 ++- DifferentiationInterface/src/utils/context.jl | 27 +++- 15 files changed, 443 insertions(+), 177 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl index d9ca885b5..f6b08e144 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl @@ -6,7 +6,11 @@ struct ChainRulesPullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep end function DI.prepare_pullback( - f, ::AutoReverseChainRules, x, ty::NTuple, contexts::Vararg{DI.Constant,C} + f, + ::AutoReverseChainRules, + x, + ty::NTuple, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return DI.NoPullbackPrep() end @@ -17,7 +21,7 @@ function DI.prepare_pullback_same_point( backend::AutoReverseChainRules, x, ty::NTuple, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} rc = ruleconfig(backend) y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...) 
@@ -30,7 +34,7 @@ function DI.value_and_pullback( backend::AutoReverseChainRules, x, ty::NTuple, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} rc = ruleconfig(backend) y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...) @@ -46,7 +50,7 @@ function DI.value_and_pullback( ::AutoReverseChainRules, x, ty::NTuple, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} (; y, pb) = prep tx = map(ty) do dy @@ -61,7 +65,7 @@ function DI.pullback( ::AutoReverseChainRules, x, ty::NTuple, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} (; pb) = prep tx = map(ty) do dy diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl index 433769eba..4c02c4a51 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl @@ -21,9 +21,8 @@ function DI.value_and_pushforward( f_and_df = get_f_and_df(f, backend) dx_sametype = convert(typeof(x), only(tx)) x_and_dx = Duplicated(x, dx_sametype) - dy, y = autodiff( - forward_withprimal(backend), f_and_df, x_and_dx, map(translate, contexts)... - ) + annotated_contexts = translate(backend, Val(1), contexts...) + dy, y = autodiff(forward_withprimal(backend), f_and_df, x_and_dx, annotated_contexts...) return y, (dy,) end @@ -38,9 +37,8 @@ function DI.value_and_pushforward( f_and_df = get_f_and_df(f, backend, Val(B)) tx_sametype = map(Fix1(convert, typeof(x)), tx) x_and_tx = BatchDuplicated(x, tx_sametype) - ty, y = autodiff( - forward_withprimal(backend), f_and_df, x_and_tx, map(translate, contexts)... - ) + annotated_contexts = translate(backend, Val(B), contexts...) 
+ ty, y = autodiff(forward_withprimal(backend), f_and_df, x_and_tx, annotated_contexts...) return y, values(ty) end @@ -55,8 +53,9 @@ function DI.pushforward( f_and_df = get_f_and_df(f, backend) dx_sametype = convert(typeof(x), only(tx)) x_and_dx = Duplicated(x, dx_sametype) + annotated_contexts = translate(backend, Val(1), contexts...) dy = only( - autodiff(forward_noprimal(backend), f_and_df, x_and_dx, map(translate, contexts)...) + autodiff(forward_noprimal(backend), f_and_df, x_and_dx, annotated_contexts...) ) return (dy,) end @@ -72,8 +71,9 @@ function DI.pushforward( f_and_df = get_f_and_df(f, backend, Val(B)) tx_sametype = map(Fix1(convert, typeof(x)), tx) x_and_tx = BatchDuplicated(x, tx_sametype) + annotated_contexts = translate(backend, Val(B), contexts...) ty = only( - autodiff(forward_noprimal(backend), f_and_df, x_and_tx, map(translate, contexts)...) + autodiff(forward_noprimal(backend), f_and_df, x_and_tx, annotated_contexts...) ) return values(ty) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl index 58e77d25f..90736a8cf 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl @@ -25,13 +25,14 @@ function DI.value_and_pushforward( dy_sametype = make_zero(y) x_and_dx = Duplicated(x, dx_sametype) y_and_dy = Duplicated(y, dy_sametype) + annotated_contexts = translate(backend, Val(1), contexts...) 
autodiff( forward_noprimal(backend), f!_and_df!, Const, y_and_dy, x_and_dx, - map(translate, contexts)..., + annotated_contexts..., ) return y, (dy_sametype,) end @@ -50,13 +51,14 @@ function DI.value_and_pushforward( ty_sametype = ntuple(_ -> make_zero(y), Val(B)) x_and_tx = BatchDuplicated(x, tx_sametype) y_and_ty = BatchDuplicated(y, ty_sametype) + annotated_contexts = translate(backend, Val(B), contexts...) autodiff( forward_noprimal(backend), f!_and_df!, Const, y_and_ty, x_and_tx, - map(translate, contexts)..., + annotated_contexts..., ) return y, ty_sametype end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl index 660d40520..3f31ed4b7 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl @@ -74,8 +74,9 @@ function DI.value_and_pullback( IA = guess_activity(typeof(x), mode) RA = guess_activity(eltype(ty), mode) dx = make_zero(x) + annotated_contexts = translate(backend, Val(1), contexts...) dinputs, result = seeded_autodiff_thunk( - mode, only(ty), f_and_df, RA, annotate(IA, x, dx), map(translate, contexts)... + mode, only(ty), f_and_df, RA, annotate(IA, x, dx), annotated_contexts... ) new_dx = first(dinputs) if isnothing(new_dx) @@ -98,8 +99,9 @@ function DI.value_and_pullback( IA = batchify_activity(guess_activity(typeof(x), mode), Val(B)) RA = batchify_activity(guess_activity(eltype(ty), mode), Val(B)) tx = ntuple(_ -> make_zero(x), Val(B)) + annotated_contexts = translate(backend, Val(B), contexts...) dinputs, result = batch_seeded_autodiff_thunk( - mode, ty, f_and_df, RA, annotate(IA, x, tx), map(translate, contexts)... + mode, ty, f_and_df, RA, annotate(IA, x, tx), annotated_contexts... 
) new_tx = values(first(dinputs)) if isnothing(new_tx) @@ -136,13 +138,9 @@ function DI.value_and_pullback!( RA = guess_activity(eltype(ty), mode) dx_righttype = convert(typeof(x), only(tx)) make_zero!(dx_righttype) + annotated_contexts = translate(backend, Val(1), contexts...) _, result = seeded_autodiff_thunk( - mode, - only(ty), - f_and_df, - RA, - Duplicated(x, dx_righttype), - map(translate, contexts)..., + mode, only(ty), f_and_df, RA, Duplicated(x, dx_righttype), annotated_contexts... ) only(tx) === dx_righttype || copyto!(only(tx), dx_righttype) return result, tx @@ -162,13 +160,9 @@ function DI.value_and_pullback!( RA = batchify_activity(guess_activity(eltype(ty), mode), Val(B)) tx_righttype = map(Fix1(convert, typeof(x)), tx) make_zero!(tx_righttype) + annotated_contexts = translate(backend, Val(B), contexts...) _, result = batch_seeded_autodiff_thunk( - mode, - ty, - f_and_df, - RA, - BatchDuplicated(x, tx_righttype), - map(translate, contexts)..., + mode, ty, f_and_df, RA, BatchDuplicated(x, tx_righttype), annotated_contexts... ) foreach(copyto!, tx, tx_righttype) return result, tx @@ -200,8 +194,9 @@ function DI.gradient( mode = reverse_noprimal(backend) IA = guess_activity(typeof(x), mode) grad = make_zero(x) + annotated_contexts = translate(backend, Val(1), contexts...) dinputs = only( - autodiff(mode, f_and_df, Active, annotate(IA, x, grad), map(translate, contexts)...) + autodiff(mode, f_and_df, Active, annotate(IA, x, grad), annotated_contexts...) ) new_grad = first(dinputs) if isnothing(new_grad) @@ -221,8 +216,9 @@ function DI.value_and_gradient( mode = reverse_withprimal(backend) IA = guess_activity(typeof(x), mode) grad = make_zero(x) + annotated_contexts = translate(backend, Val(1), contexts...) dinputs, result = autodiff( - mode, f_and_df, Active, annotate(IA, x, grad), map(translate, contexts)... + mode, f_and_df, Active, annotate(IA, x, grad), annotated_contexts... 
) new_grad = first(dinputs) if isnothing(new_grad) @@ -266,12 +262,13 @@ function DI.gradient!( f_and_df = get_f_and_df(f, backend) grad_righttype = grad isa typeof(x) ? grad : prep.grad_righttype make_zero!(grad_righttype) + annotated_contexts = translate(backend, Val(1), contexts...) autodiff( reverse_noprimal(backend), f_and_df, Active, Duplicated(x, grad_righttype), - map(translate, contexts)..., + annotated_contexts..., ) grad === grad_righttype || copyto!(grad, grad_righttype) return grad @@ -298,12 +295,13 @@ function DI.value_and_gradient!( f_and_df = get_f_and_df(f, backend) grad_righttype = grad isa typeof(x) ? grad : prep.grad_righttype make_zero!(grad_righttype) + annotated_contexts = translate(backend, Val(1), contexts...) _, y = autodiff( reverse_withprimal(backend), f_and_df, Active, Duplicated(x, grad_righttype), - map(translate, contexts)..., + annotated_contexts..., ) grad === grad_righttype || copyto!(grad, grad_righttype) return y, grad diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl index 4da3ece8c..40a33f3a6 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl @@ -23,6 +23,7 @@ function DI.value_and_pullback( f!_and_df! = get_f_and_df(f!, backend) dy_sametype = convert(typeof(y), copy(only(ty))) y_and_dy = Duplicated(y, dy_sametype) + annotated_contexts = translate(backend, Val(1), contexts...) dinputs = only( autodiff( reverse_noprimal(backend), @@ -30,7 +31,7 @@ function DI.value_and_pullback( Const, y_and_dy, Active(x), - map(translate, contexts)..., + annotated_contexts..., ), ) dx = dinputs[2] @@ -49,6 +50,7 @@ function DI.value_and_pullback( f!_and_df! 
= get_f_and_df(f!, backend, Val(B)) ty_sametype = map(Fix1(convert, typeof(y)), copy.(ty)) y_and_ty = BatchDuplicated(y, ty_sametype) + annotated_contexts = translate(backend, Val(B), contexts...) dinputs = only( autodiff( reverse_noprimal(backend), @@ -56,7 +58,7 @@ function DI.value_and_pullback( Const, y_and_ty, Active(x), - map(translate, contexts)..., + annotated_contexts..., ), ) tx = values(dinputs[2]) @@ -77,13 +79,14 @@ function DI.value_and_pullback( dy_sametype = convert(typeof(y), copy(only(ty))) x_and_dx = Duplicated(x, dx_sametype) y_and_dy = Duplicated(y, dy_sametype) + annotated_contexts = translate(backend, Val(1), contexts...) autodiff( reverse_noprimal(backend), f!_and_df!, Const, y_and_dy, x_and_dx, - map(translate, contexts)..., + annotated_contexts..., ) return y, (dx_sametype,) end @@ -102,13 +105,14 @@ function DI.value_and_pullback( ty_sametype = map(Fix1(convert, typeof(y)), copy.(ty)) x_and_tx = BatchDuplicated(x, tx_sametype) y_and_ty = BatchDuplicated(y, ty_sametype) + annotated_contexts = translate(backend, Val(B), contexts...) 
autodiff( reverse_noprimal(backend), f!_and_df!, Const, y_and_ty, x_and_tx, - map(translate, contexts)..., + annotated_contexts..., ) return y, tx_sametype end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl index 6c847faf9..af6a2d9a8 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl @@ -42,7 +42,24 @@ end force_annotation(f::F) where {F<:Annotation} = f force_annotation(f::F) where {F} = Const(f) -translate(c::DI.Constant) = Const(DI.unwrap(c)) +function _translate( + ::AutoEnzyme, ::Val{B}, c::Union{DI.Constant,DI.BackendContext} +) where {B} + return Const(DI.unwrap(c)) +end + +function _translate(backend::AutoEnzyme, ::Val{B}, c::DI.FunctionContext) where {B} + return get_f_and_df(DI.unwrap(c), backend, Val(B)) +end + +function translate( + backend::AutoEnzyme, ::Val{B}, contexts::Vararg{DI.Context,C} +) where {B,C} + new_contexts = map(contexts) do c + _translate(backend, Val(B), c) + end + return new_contexts +end ## Modes diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl index 90372a9e2..4920ca267 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl @@ -170,7 +170,7 @@ struct ForwardDiffOneArgDerivativePrep{E} <: DI.DerivativePrep end function DI.prepare_derivative( - f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Constant,C} + f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} ) where {F,C} pushforward_prep = DI.prepare_pushforward(f, backend, x, (one(x),), contexts...) 
return ForwardDiffOneArgDerivativePrep(pushforward_prep) @@ -181,7 +181,7 @@ function DI.value_and_derivative( prep::ForwardDiffOneArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} y, ty = DI.value_and_pushforward( f, prep.pushforward_prep, backend, x, (one(x),), contexts... @@ -195,7 +195,7 @@ function DI.value_and_derivative!( prep::ForwardDiffOneArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} y, _ = DI.value_and_pushforward!( f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts... @@ -208,7 +208,7 @@ function DI.derivative( prep::ForwardDiffOneArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} return only( DI.pushforward(f, prep.pushforward_prep, backend, x, (one(x),), contexts...) @@ -221,7 +221,7 @@ function DI.derivative!( prep::ForwardDiffOneArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} DI.pushforward!(f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...) return der @@ -232,7 +232,11 @@ end ### Unprepared, only when chunk size and tag are not specified function DI.value_and_gradient!( - f::F, grad, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f::F, + grad, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) 
@@ -248,7 +252,10 @@ function DI.value_and_gradient!( end function DI.value_and_gradient( - f::F, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f::F, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -262,7 +269,11 @@ function DI.value_and_gradient( end function DI.gradient!( - f::F, grad, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f::F, + grad, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -274,7 +285,10 @@ function DI.gradient!( end function DI.gradient( - f::F, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f::F, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -292,7 +306,10 @@ struct ForwardDiffGradientPrep{C} <: DI.GradientPrep end function DI.prepare_gradient( - f::F, backend::AutoForwardDiff, x::AbstractArray, contexts::Vararg{DI.Constant,C} + f::F, + backend::AutoForwardDiff, + x::AbstractArray, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) chunk = choose_chunk(backend, x) @@ -307,7 +324,7 @@ function DI.value_and_gradient!( prep::ForwardDiffGradientPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) 
result = DiffResult(zero(eltype(x)), (grad,)) @@ -323,7 +340,7 @@ function DI.value_and_gradient( prep::ForwardDiffGradientPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) result = GradientResult(x) @@ -338,7 +355,7 @@ function DI.gradient!( prep::ForwardDiffGradientPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing @@ -350,7 +367,7 @@ function DI.gradient( prep::ForwardDiffGradientPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing @@ -362,7 +379,11 @@ end ### Unprepared, only when chunk size and tag are not specified function DI.value_and_jacobian!( - f::F, jac, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f::F, + jac, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -379,7 +400,10 @@ function DI.value_and_jacobian!( end function DI.value_and_jacobian( - f::F, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f::F, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) 
@@ -391,7 +415,11 @@ function DI.value_and_jacobian( end function DI.jacobian!( - f::F, jac, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f::F, + jac, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -403,7 +431,10 @@ function DI.jacobian!( end function DI.jacobian( - f::F, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f::F, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -421,7 +452,7 @@ struct ForwardDiffOneArgJacobianPrep{C} <: DI.JacobianPrep end function DI.prepare_jacobian( - f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Constant,C} + f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} ) where {F,C} fc = DI.with_contexts(f, contexts...) chunk = choose_chunk(backend, x) @@ -436,7 +467,7 @@ function DI.value_and_jacobian!( prep::ForwardDiffOneArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) y = fc(x) @@ -453,7 +484,7 @@ function DI.value_and_jacobian( prep::ForwardDiffOneArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing @@ -466,7 +497,7 @@ function DI.jacobian!( prep::ForwardDiffOneArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) 
CHK = tag_type(backend) === Nothing @@ -478,7 +509,7 @@ function DI.jacobian( prep::ForwardDiffOneArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing @@ -488,7 +519,7 @@ end ## Second derivative function DI.prepare_second_derivative( - f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Constant,C} + f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} ) where {F,C} return DI.NoSecondDerivativePrep() end @@ -498,7 +529,7 @@ function DI.second_derivative( ::DI.NoSecondDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} T = tag_type(f, backend, x) xdual = make_dual(T, x, one(x)) @@ -513,7 +544,7 @@ function DI.second_derivative!( ::DI.NoSecondDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} T = tag_type(f, backend, x) xdual = make_dual(T, x, one(x)) @@ -527,7 +558,7 @@ function DI.value_derivative_and_second_derivative( ::DI.NoSecondDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} T = tag_type(f, backend, x) xdual = make_dual(T, x, one(x)) @@ -546,7 +577,7 @@ function DI.value_derivative_and_second_derivative!( ::DI.NoSecondDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} T = tag_type(f, backend, x) xdual = make_dual(T, x, one(x)) @@ -561,7 +592,11 @@ end ## HVP function DI.prepare_hvp( - f::F, backend::AutoForwardDiff, x, tx::NTuple, contexts::Vararg{DI.Constant,C} + f::F, + backend::AutoForwardDiff, + x, + tx::NTuple, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) 
where {F,C} return DI.prepare_hvp(f, DI.SecondOrder(backend, backend), x, tx, contexts...) end @@ -572,7 +607,7 @@ function DI.hvp( backend::AutoForwardDiff, x, tx::NTuple, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} return DI.hvp(f, prep, DI.SecondOrder(backend, backend), x, tx, contexts...) end @@ -584,7 +619,7 @@ function DI.hvp!( backend::AutoForwardDiff, x, tx::NTuple, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} return DI.hvp!(f, tg, prep, DI.SecondOrder(backend, backend), x, tx, contexts...) end @@ -595,7 +630,7 @@ function DI.gradient_and_hvp( backend::AutoForwardDiff, x, tx::NTuple, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} return DI.gradient_and_hvp( f, prep, DI.SecondOrder(backend, backend), x, tx, contexts... @@ -610,7 +645,7 @@ function DI.gradient_and_hvp!( backend::AutoForwardDiff, x, tx::NTuple, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} return DI.gradient_and_hvp!( f, grad, tg, prep, DI.SecondOrder(backend, backend), x, tx, contexts... @@ -622,7 +657,11 @@ end ### Unprepared, only when chunk size and tag are not specified function DI.hessian!( - f::F, hess, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f::F, + hess, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -634,7 +673,10 @@ function DI.hessian!( end function DI.hessian( - f::F, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f::F, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) 
@@ -651,7 +693,7 @@ function DI.value_gradient_and_hessian!( hess, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -668,7 +710,10 @@ function DI.value_gradient_and_hessian!( end function DI.value_gradient_and_hessian( - f::F, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f::F, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -689,7 +734,7 @@ struct ForwardDiffHessianPrep{C1,C2} <: DI.HessianPrep end function DI.prepare_hessian( - f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Constant,C} + f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} ) where {F,C} fc = DI.with_contexts(f, contexts...) chunk = choose_chunk(backend, x) @@ -706,7 +751,7 @@ function DI.hessian!( prep::ForwardDiffHessianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing @@ -718,7 +763,7 @@ function DI.hessian( prep::ForwardDiffHessianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing @@ -732,7 +777,7 @@ function DI.value_gradient_and_hessian!( prep::ForwardDiffHessianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) 
result = DiffResult(one(eltype(x)), (grad, hess)) @@ -749,7 +794,7 @@ function DI.value_gradient_and_hessian( prep::ForwardDiffHessianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) result = HessianResult(x) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl index 35e743d89..b03465e82 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl @@ -16,9 +16,9 @@ function DI.prepare_hvp( inner_gradient_prep = DI.prepare_gradient(f, DI.inner(backend), xdual, contexts...) rewrap = DI.Rewrap(contexts...) new_contexts = ( - DI.Constant(f), + DI.FunctionContext(f), DI.PrepContext(inner_gradient_prep), - DI.Constant(DI.inner(backend)), + DI.BackendContext(DI.inner(backend)), DI.Constant(rewrap), contexts..., ) @@ -39,9 +39,9 @@ function DI.hvp( (; inner_gradient_prep, outer_pushforward_prep) = prep rewrap = DI.Rewrap(contexts...) new_contexts = ( - DI.Constant(f), + DI.FunctionContext(f), DI.PrepContext(inner_gradient_prep), - DI.Constant(DI.inner(backend)), + DI.BackendContext(DI.inner(backend)), DI.Constant(rewrap), contexts..., ) @@ -67,9 +67,9 @@ function DI.hvp!( (; inner_gradient_prep, outer_pushforward_prep) = prep rewrap = DI.Rewrap(contexts...) new_contexts = ( - DI.Constant(f), + DI.FunctionContext(f), DI.PrepContext(inner_gradient_prep), - DI.Constant(DI.inner(backend)), + DI.BackendContext(DI.inner(backend)), DI.Constant(rewrap), contexts..., ) @@ -96,9 +96,9 @@ function DI.gradient_and_hvp( (; inner_gradient_prep, outer_pushforward_prep) = prep rewrap = DI.Rewrap(contexts...) 
new_contexts = ( - DI.Constant(f), + DI.FunctionContext(f), DI.PrepContext(inner_gradient_prep), - DI.Constant(DI.inner(backend)), + DI.BackendContext(DI.inner(backend)), DI.Constant(rewrap), contexts..., ) @@ -125,9 +125,9 @@ function DI.gradient_and_hvp!( (; inner_gradient_prep, outer_pushforward_prep) = prep rewrap = DI.Rewrap(contexts...) new_contexts = ( - DI.Constant(f), + DI.FunctionContext(f), DI.PrepContext(inner_gradient_prep), - DI.Constant(DI.inner(backend)), + DI.BackendContext(DI.inner(backend)), DI.Constant(rewrap), contexts..., ) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl index 2fdbc8839..91ffe753b 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl @@ -111,7 +111,11 @@ end ### Unprepared, only when tag is not specified function DI.value_and_derivative( - f!::F, y, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f!::F, + y, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if T === Nothing fc! = DI.with_contexts(f!, contexts...) @@ -125,7 +129,12 @@ function DI.value_and_derivative( end function DI.value_and_derivative!( - f!::F, y, der, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f!::F, + y, + der, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if T === Nothing fc! = DI.with_contexts(f!, contexts...) 
@@ -139,7 +148,11 @@ function DI.value_and_derivative!( end function DI.derivative( - f!::F, y, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f!::F, + y, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if T === Nothing fc! = DI.with_contexts(f!, contexts...) @@ -151,7 +164,12 @@ function DI.derivative( end function DI.derivative!( - f!::F, y, der, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f!::F, + y, + der, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if T === Nothing fc! = DI.with_contexts(f!, contexts...) @@ -169,7 +187,11 @@ struct ForwardDiffTwoArgDerivativePrep{C} <: DI.DerivativePrep end function DI.prepare_derivative( - f!::F, y, backend::AutoForwardDiff, x, contexts::Vararg{DI.Constant,C} + f!::F, + y, + backend::AutoForwardDiff, + x, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) tag = get_tag(fc!, backend, x) @@ -183,7 +205,7 @@ function DI.value_and_derivative( prep::ForwardDiffTwoArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) result = MutableDiffResult(y, (similar(y),)) @@ -199,7 +221,7 @@ function DI.value_and_derivative!( prep::ForwardDiffTwoArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) result = MutableDiffResult(y, (der,)) @@ -214,7 +236,7 @@ function DI.derivative( prep::ForwardDiffTwoArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) 
CHK = tag_type(backend) === Nothing @@ -228,7 +250,7 @@ function DI.derivative!( prep::ForwardDiffTwoArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) CHK = tag_type(backend) === Nothing @@ -240,7 +262,11 @@ end ### Unprepared, only when chunk size and tag are not specified function DI.value_and_jacobian( - f!::F, y, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f!::F, + y, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc! = DI.with_contexts(f!, contexts...) @@ -255,7 +281,12 @@ function DI.value_and_jacobian( end function DI.value_and_jacobian!( - f!::F, y, jac, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f!::F, + y, + jac, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc! = DI.with_contexts(f!, contexts...) @@ -269,7 +300,11 @@ function DI.value_and_jacobian!( end function DI.jacobian( - f!::F, y, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f!::F, + y, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc! = DI.with_contexts(f!, contexts...) @@ -281,7 +316,12 @@ function DI.jacobian( end function DI.jacobian!( - f!::F, y, jac, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Constant,C} + f!::F, + y, + jac, + backend::AutoForwardDiff{chunksize,T}, + x, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc! = DI.with_contexts(f!, contexts...) 
@@ -299,7 +339,11 @@ struct ForwardDiffTwoArgJacobianPrep{C} <: DI.JacobianPrep end function DI.prepare_jacobian( - f!::F, y, backend::AutoForwardDiff, x, contexts::Vararg{DI.Constant,C} + f!::F, + y, + backend::AutoForwardDiff, + x, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) chunk = choose_chunk(backend, x) @@ -314,7 +358,7 @@ function DI.value_and_jacobian( prep::ForwardDiffTwoArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) jac = similar(y, length(y), length(x)) @@ -331,7 +375,7 @@ function DI.value_and_jacobian!( prep::ForwardDiffTwoArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) result = MutableDiffResult(y, (jac,)) @@ -346,7 +390,7 @@ function DI.jacobian( prep::ForwardDiffTwoArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) CHK = tag_type(backend) === Nothing @@ -360,7 +404,7 @@ function DI.jacobian!( prep::ForwardDiffTwoArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Vararg{DI.Constant,C}, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) 
CHK = tag_type(backend) === Nothing diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl index aa2546812..3073597d1 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl @@ -77,7 +77,17 @@ function mypartials!(::Type{T}, ty::NTuple{B}, ydual) where {T,B} return ty end -_translate(::Type{T}, ::Val{B}, c::DI.Constant) where {T,B} = DI.unwrap(c) +# store preparation result with the right input eltype +struct PrepContext{T<:DI.Prep} <: DI.Context + data::T +end + +prepcontext_maker(c) = PrepContext(c) +DI.maker(::PrepContext) = prepcontext_maker + +function _translate(::Type{T}, ::Val{B}, c::DI.ConstantOrFunctionOrBackend) where {T,B} + return DI.unwrap(c) +end _translate(::Type{T}, ::Val{B}, c::DI.PrepContext) where {T,B} = DI.unwrap(c) function _translate(::Type{T}, ::Val{B}, c::DI.Cache) where {T,B} diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl index 0f84df11f..ea565a918 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl @@ -15,20 +15,30 @@ struct TrackerPullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep end function DI.prepare_pullback( - f, ::AutoTracker, x, ty::NTuple, contexts::Vararg{DI.Constant,C} + f, ::AutoTracker, x, ty::NTuple, contexts::Varard{DI.ConstantOrFunctionOrBackend,C} ) where {C} return DI.NoPullbackPrep() end function DI.prepare_pullback_same_point( - f, ::DI.NoPullbackPrep, ::AutoTracker, x, ty::NTuple, contexts::Vararg{DI.Constant,C} + f, + ::DI.NoPullbackPrep, + ::AutoTracker, 
+ x, + ty::NTuple, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {C} y, pb = forward(f, x, map(DI.unwrap, contexts)...) return TrackerPullbackPrepSamePoint(y, pb) end function DI.value_and_pullback( - f, ::DI.NoPullbackPrep, ::AutoTracker, x, ty::NTuple, contexts::Vararg{DI.Constant,C} + f, + ::DI.NoPullbackPrep, + ::AutoTracker, + x, + ty::NTuple, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {C} y, pb = forward(f, x, map(DI.unwrap, contexts)...) tx = map(ty) do dy @@ -43,7 +53,7 @@ function DI.value_and_pullback( ::AutoTracker, x, ty::NTuple, - contexts::Vararg{DI.Constant,C}, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {C} (; y, pb) = prep tx = map(ty) do dy @@ -58,7 +68,7 @@ function DI.pullback( ::AutoTracker, x, ty::NTuple, - contexts::Vararg{DI.Constant,C}, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {C} (; pb) = prep tx = map(ty) do dy @@ -69,19 +79,29 @@ end ## Gradient -function DI.prepare_gradient(f, ::AutoTracker, x, contexts::Vararg{DI.Constant,C}) where {C} +function DI.prepare_gradient( + f, ::AutoTracker, x, contexts::Varard{DI.ConstantOrFunctionOrBackend,C} +) where {C} return DI.NoGradientPrep() end function DI.value_and_gradient( - f, ::DI.NoGradientPrep, ::AutoTracker, x, contexts::Vararg{DI.Constant,C} + f, + ::DI.NoGradientPrep, + ::AutoTracker, + x, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {C} (; val, grad) = withgradient(f, x, map(DI.unwrap, contexts)...) return val, data(first(grad)) end function DI.gradient( - f, ::DI.NoGradientPrep, ::AutoTracker, x, contexts::Vararg{DI.Constant,C} + f, + ::DI.NoGradientPrep, + ::AutoTracker, + x, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {C} (; grad) = withgradient(f, x, map(DI.unwrap, contexts)...) 
return data(first(grad)) @@ -93,7 +113,7 @@ function DI.value_and_gradient!( prep::DI.NoGradientPrep, backend::AutoTracker, x, - contexts::Vararg{DI.Constant,C}, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {C} y, new_grad = DI.value_and_gradient(f, prep, backend, x, contexts...) return y, copyto!(grad, new_grad) @@ -105,7 +125,7 @@ function DI.gradient!( prep::DI.NoGradientPrep, backend::AutoTracker, x, - contexts::Vararg{DI.Constant,C}, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return copyto!(grad, DI.gradient(f, prep, backend, x, contexts...)) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl index b2f23cd65..7c6deb5cd 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl @@ -17,20 +17,30 @@ struct ZygotePullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep end function DI.prepare_pullback( - f, ::AutoZygote, x, ty::NTuple, contexts::Vararg{DI.Constant,C} + f, ::AutoZygote, x, ty::NTuple, contexts::Varard{DI.ConstantOrFunctionOrBackend,C} ) where {C} return DI.NoPullbackPrep() end function DI.prepare_pullback_same_point( - f, ::DI.NoPullbackPrep, ::AutoZygote, x, ty::NTuple, contexts::Vararg{DI.Constant,C} + f, + ::DI.NoPullbackPrep, + ::AutoZygote, + x, + ty::NTuple, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {C} y, pb = pullback(f, x, map(DI.unwrap, contexts)...) 
return ZygotePullbackPrepSamePoint(y, pb) end function DI.value_and_pullback( - f, ::DI.NoPullbackPrep, ::AutoZygote, x, ty::NTuple, contexts::Vararg{DI.Constant,C} + f, + ::DI.NoPullbackPrep, + ::AutoZygote, + x, + ty::NTuple, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {C} y, pb = pullback(f, x, map(DI.unwrap, contexts)...) tx = map(ty) do dy @@ -45,7 +55,7 @@ function DI.value_and_pullback( ::AutoZygote, x, ty::NTuple, - contexts::Vararg{DI.Constant,C}, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {C} (; y, pb) = prep tx = map(ty) do dy @@ -60,7 +70,7 @@ function DI.pullback( ::AutoZygote, x, ty::NTuple, - contexts::Vararg{DI.Constant,C}, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {C} (; pb) = prep tx = map(ty) do dy @@ -71,19 +81,29 @@ end ## Gradient -function DI.prepare_gradient(f, ::AutoZygote, x, contexts::Vararg{DI.Constant,C}) where {C} +function DI.prepare_gradient( + f, ::AutoZygote, x, contexts::Varard{DI.ConstantOrFunctionOrBackend,C} +) where {C} return DI.NoGradientPrep() end function DI.value_and_gradient( - f, ::DI.NoGradientPrep, ::AutoZygote, x, contexts::Vararg{DI.Constant,C} + f, + ::DI.NoGradientPrep, + ::AutoZygote, + x, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {C} (; val, grad) = withgradient(f, x, map(DI.unwrap, contexts)...) return val, first(grad) end function DI.gradient( - f, ::DI.NoGradientPrep, ::AutoZygote, x, contexts::Vararg{DI.Constant,C} + f, + ::DI.NoGradientPrep, + ::AutoZygote, + x, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return first(gradient(f, x, map(DI.unwrap, contexts)...)) end @@ -94,7 +114,7 @@ function DI.value_and_gradient!( prep::DI.NoGradientPrep, backend::AutoZygote, x, - contexts::Vararg{DI.Constant,C}, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {C} y, new_grad = DI.value_and_gradient(f, prep, backend, x, contexts...) 
return y, copyto!(grad, new_grad) @@ -106,39 +126,59 @@ function DI.gradient!( prep::DI.NoGradientPrep, backend::AutoZygote, x, - contexts::Vararg{DI.Constant,C}, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return copyto!(grad, DI.gradient(f, prep, backend, x, contexts...)) end ## Jacobian -function DI.prepare_jacobian(f, ::AutoZygote, x, contexts::Vararg{DI.Constant,C}) where {C} +function DI.prepare_jacobian( + f, ::AutoZygote, x, contexts::Varard{DI.ConstantOrFunctionOrBackend,C} +) where {C} return DI.NoJacobianPrep() end function DI.value_and_jacobian( - f, ::DI.NoJacobianPrep, ::AutoZygote, x, contexts::Vararg{DI.Constant,C} + f, + ::DI.NoJacobianPrep, + ::AutoZygote, + x, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return f(x, map(DI.unwrap, contexts)...), first(jacobian(f, x, map(DI.unwrap, contexts)...)) # https://github.com/FluxML/Zygote.jl/issues/1506 end function DI.jacobian( - f, ::DI.NoJacobianPrep, ::AutoZygote, x, contexts::Vararg{DI.Constant,C} + f, + ::DI.NoJacobianPrep, + ::AutoZygote, + x, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return first(jacobian(f, x, map(DI.unwrap, contexts)...)) end function DI.value_and_jacobian!( - f, jac, prep::DI.NoJacobianPrep, backend::AutoZygote, x, contexts::Vararg{DI.Constant,C} + f, + jac, + prep::DI.NoJacobianPrep, + backend::AutoZygote, + x, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {C} y, new_jac = DI.value_and_jacobian(f, prep, backend, x, contexts...) 
return y, copyto!(jac, new_jac) end function DI.jacobian!( - f, jac, prep::DI.NoJacobianPrep, backend::AutoZygote, x, contexts::Vararg{DI.Constant,C} + f, + jac, + prep::DI.NoJacobianPrep, + backend::AutoZygote, + x, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return copyto!(jac, DI.jacobian(f, prep, backend, x, contexts...)) end @@ -148,13 +188,22 @@ end # Beware, this uses ForwardDiff for the inner differentiation function DI.prepare_hvp( - f, backend::AutoZygote, x, tx::NTuple, contexts::Vararg{DI.Constant,C} + f, + backend::AutoZygote, + x, + tx::NTuple, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return DI.prepare_hvp(f, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...) end function DI.hvp( - f, prep::DI.HVPPrep, backend::AutoZygote, x, tx::NTuple, contexts::Vararg{DI.Constant,C} + f, + prep::DI.HVPPrep, + backend::AutoZygote, + x, + tx::NTuple, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return DI.hvp(f, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...) end @@ -166,7 +215,7 @@ function DI.hvp!( backend::AutoZygote, x, tx::NTuple, - contexts::Vararg{DI.Constant,C}, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return DI.hvp!( f, tg, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... @@ -174,7 +223,12 @@ function DI.hvp!( end function DI.gradient_and_hvp( - f, prep::DI.HVPPrep, backend::AutoZygote, x, tx::NTuple, contexts::Vararg{DI.Constant,C} + f, + prep::DI.HVPPrep, + backend::AutoZygote, + x, + tx::NTuple, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return DI.gradient_and_hvp( f, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... 
@@ -189,7 +243,7 @@ function DI.gradient_and_hvp!( backend::AutoZygote, x, tx::NTuple, - contexts::Vararg{DI.Constant,C}, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return DI.gradient_and_hvp!( f, grad, tg, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... @@ -198,25 +252,40 @@ end ## Hessian -function DI.prepare_hessian(f, ::AutoZygote, x, contexts::Vararg{DI.Constant,C}) where {C} +function DI.prepare_hessian( + f, ::AutoZygote, x, contexts::Varard{DI.ConstantOrFunctionOrBackend,C} +) where {C} return DI.NoHessianPrep() end function DI.hessian( - f, ::DI.NoHessianPrep, ::AutoZygote, x, contexts::Vararg{DI.Constant,C} + f, + ::DI.NoHessianPrep, + ::AutoZygote, + x, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {C} fc = DI.with_contexts(f, contexts...) return hessian(fc, x) end function DI.hessian!( - f, hess, prep::DI.NoHessianPrep, backend::AutoZygote, x, contexts::Vararg{DI.Constant,C} + f, + hess, + prep::DI.NoHessianPrep, + backend::AutoZygote, + x, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return copyto!(hess, DI.hessian(f, prep, backend, x, contexts...)) end function DI.value_gradient_and_hessian( - f, prep::DI.NoHessianPrep, backend::AutoZygote, x, contexts::Vararg{DI.Constant,C} + f, + prep::DI.NoHessianPrep, + backend::AutoZygote, + x, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {C} y, grad = DI.value_and_gradient(f, DI.NoGradientPrep(), backend, x, contexts...) hess = DI.hessian(f, prep, backend, x, contexts...) @@ -230,7 +299,7 @@ function DI.value_gradient_and_hessian!( prep::DI.NoHessianPrep, backend::AutoZygote, x, - contexts::Vararg{DI.Constant,C}, + contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, ) where {C} y, _ = DI.value_and_gradient!(f, grad, DI.NoGradientPrep(), backend, x, contexts...) DI.hessian!(f, hess, prep, backend, x, contexts...) 
diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index 4f53d81d5..cea59e7ef 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -90,7 +90,9 @@ function _prepare_hvp_aux( contexts::Vararg{Context,C}, ) where {F,C} rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) outer_pushforward_prep = prepare_pushforward( shuffled_gradient, outer(backend), x, tx, new_contexts... ) @@ -107,7 +109,9 @@ function hvp( ) where {F,C} (; outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) return pushforward( shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... ) @@ -124,7 +128,9 @@ function hvp!( ) where {F,C} (; outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) return pushforward!( shuffled_gradient, tg, @@ -146,7 +152,9 @@ function gradient_and_hvp( ) where {F,C} (; outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) return value_and_pushforward( shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... ) @@ -164,7 +172,9 @@ function gradient_and_hvp!( ) where {F,C} (; outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) 
- new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) new_grad, _ = value_and_pushforward!( shuffled_gradient, tg, @@ -193,7 +203,9 @@ function _prepare_hvp_aux( contexts::Vararg{Context,C}, ) where {F,C} rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) outer_pushforward_prep = prepare_pushforward( shuffled_gradient, outer(backend), x, tx, new_contexts... ) @@ -210,7 +222,9 @@ function hvp( ) where {F,C} (; outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) return pushforward( shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... ) @@ -227,7 +241,9 @@ function hvp!( ) where {F,C} (; outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) return pushforward!( shuffled_gradient, tg, @@ -249,7 +265,9 @@ function gradient_and_hvp( ) where {F,C} (; outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) return value_and_pushforward( shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... ) @@ -267,7 +285,9 @@ function gradient_and_hvp!( ) where {F,C} (; outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) 
- new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) new_grad, _ = value_and_pushforward!( shuffled_gradient, tg, @@ -298,8 +318,8 @@ function _prepare_hvp_aux( ) where {F,C} rewrap = Rewrap(contexts...) new_contexts = ( - Constant(f), - Constant(inner(backend)), + FunctionContext(f), + BackendContext(inner(backend)), Constant(first(tx)), Constant(rewrap), contexts..., @@ -327,8 +347,8 @@ function hvp( outer_gradient_prep, outer(backend), x, - Constant(f), - Constant(inner(backend)), + FunctionContext(f), + BackendContext(inner(backend)), Constant(dx), Constant(rewrap), contexts..., @@ -355,8 +375,8 @@ function hvp!( outer_gradient_prep, outer(backend), x, - Constant(f), - Constant(inner(backend)), + FunctionContext(f), + BackendContext(inner(backend)), Constant(tx[b]), Constant(rewrap), contexts..., @@ -409,7 +429,9 @@ function _prepare_hvp_aux( contexts::Vararg{Context,C}, ) where {F,C} rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) outer_pullback_prep = prepare_pullback( shuffled_gradient, outer(backend), x, tx, new_contexts... ) @@ -426,7 +448,9 @@ function hvp( ) where {F,C} (; outer_pullback_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) return pullback( shuffled_gradient, outer_pullback_prep, outer(backend), x, tx, new_contexts... ) @@ -443,7 +467,9 @@ function hvp!( ) where {F,C} (; outer_pullback_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) 
+ new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) return pullback!( shuffled_gradient, tg, outer_pullback_prep, outer(backend), x, tx, new_contexts... ) @@ -459,7 +485,9 @@ function gradient_and_hvp( ) where {F,C} (; outer_pullback_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) return value_and_pullback( shuffled_gradient, outer_pullback_prep, outer(backend), x, tx, new_contexts... ) @@ -477,7 +505,9 @@ function gradient_and_hvp!( ) where {F,C} (; outer_pullback_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) new_grad, _ = value_and_pullback!( shuffled_gradient, tg, outer_pullback_prep, outer(backend), x, tx, new_contexts... ) diff --git a/DifferentiationInterface/src/second_order/second_derivative.jl b/DifferentiationInterface/src/second_order/second_derivative.jl index 02c0628af..c39c65cf4 100644 --- a/DifferentiationInterface/src/second_order/second_derivative.jl +++ b/DifferentiationInterface/src/second_order/second_derivative.jl @@ -56,7 +56,9 @@ function prepare_second_derivative( f::F, backend::AbstractADType, x, contexts::Vararg{Context,C} ) where {F,C} rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) outer_derivative_prep = prepare_derivative( shuffled_derivative, outer(backend), x, new_contexts... ) @@ -74,7 +76,9 @@ function second_derivative( ) where {F,C} (; outer_derivative_prep) = prep rewrap = Rewrap(contexts...) 
- new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) return derivative( shuffled_derivative, outer_derivative_prep, outer(backend), x, new_contexts... ) @@ -89,7 +93,9 @@ function value_derivative_and_second_derivative( ) where {F,C} (; outer_derivative_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) y = f(x, map(unwrap, contexts)...) der, der2 = value_and_derivative( shuffled_derivative, outer_derivative_prep, outer(backend), x, new_contexts... @@ -107,7 +113,9 @@ function second_derivative!( ) where {F,C} (; outer_derivative_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) return derivative!( shuffled_derivative, der2, outer_derivative_prep, outer(backend), x, new_contexts... ) @@ -124,7 +132,9 @@ function value_derivative_and_second_derivative!( ) where {F,C} (; outer_derivative_prep) = prep rewrap = Rewrap(contexts...) - new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_contexts = ( + FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + ) y = f(x, map(unwrap, contexts)...) new_der, _ = value_and_derivative!( shuffled_derivative, der2, outer_derivative_prep, outer(backend), x, new_contexts... 
diff --git a/DifferentiationInterface/src/utils/context.jl b/DifferentiationInterface/src/utils/context.jl index 310017490..51cb06a8d 100644 --- a/DifferentiationInterface/src/utils/context.jl +++ b/DifferentiationInterface/src/utils/context.jl @@ -19,6 +19,11 @@ Abstract supertype for additional context arguments, which can be passed to diff """ abstract type Context end +unwrap(c::Context) = c.data +Base.:(==)(c1::Context, c2::Context) = unwrap(c1) == unwrap(c2) + +## Public contexts + """ Constant @@ -53,9 +58,6 @@ end constant_maker(c) = Constant(c) maker(::Constant) = constant_maker -unwrap(c::Constant) = c.data - -Base.:(==)(c1::Constant, c2::Constant) = c1.data == c2.data """ Cache @@ -70,15 +72,26 @@ end cache_maker(c) = Cache(c) maker(::Cache) = cache_maker -unwrap(c::Cache) = c.data -Base.:(==)(c1::Cache, c2::Cache) = c1.data == c2.data +## Internal contexts for passing stuff around -struct PrepContext{T<:Prep} <: Context +struct FunctionContext{T} <: Context data::T end -unwrap(c::PrepContext) = c.data +functioncontext_maker(c) = FunctionContext(c) +maker(::FunctionContext) = functioncontext_maker + +struct BackendContext{T} <: Context + data::T +end + +backendcontext_maker(c) = BackendContext(c) +maker(::BackendContext) = backendcontext_maker + +const ConstantOrFunctionOrBackend = Union{Constant,FunctionContext,BackendContext} + +## Context manipulation struct Rewrap{C,T} context_makers::T From e199c9fbf175dc342cff97b606774510c7b130f5 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 6 Dec 2024 13:00:08 +0100 Subject: [PATCH 02/10] Typos --- .../reverse_onearg.jl | 10 +-- .../onearg.jl | 84 +++++++++---------- .../twoarg.jl | 36 ++++---- .../utils.jl | 2 +- .../DifferentiationInterfaceTrackerExt.jl | 20 ++--- .../DifferentiationInterfaceZygoteExt.jl | 50 +++++------ 6 files changed, 101 insertions(+), 101 deletions(-) diff --git 
a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl index f6b08e144..578622e65 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl @@ -10,7 +10,7 @@ function DI.prepare_pullback( ::AutoReverseChainRules, x, ty::NTuple, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return DI.NoPullbackPrep() end @@ -21,7 +21,7 @@ function DI.prepare_pullback_same_point( backend::AutoReverseChainRules, x, ty::NTuple, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} rc = ruleconfig(backend) y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...) @@ -34,7 +34,7 @@ function DI.value_and_pullback( backend::AutoReverseChainRules, x, ty::NTuple, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} rc = ruleconfig(backend) y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...) 
@@ -50,7 +50,7 @@ function DI.value_and_pullback( ::AutoReverseChainRules, x, ty::NTuple, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} (; y, pb) = prep tx = map(ty) do dy @@ -65,7 +65,7 @@ function DI.pullback( ::AutoReverseChainRules, x, ty::NTuple, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} (; pb) = prep tx = map(ty) do dy diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl index 4920ca267..f50f030a7 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl @@ -170,7 +170,7 @@ struct ForwardDiffOneArgDerivativePrep{E} <: DI.DerivativePrep end function DI.prepare_derivative( - f::F, backend::AutoForwardDiff, x, contexts::Varard{DI.ConstantOrFunctionOrBackend,C} + f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} ) where {F,C} pushforward_prep = DI.prepare_pushforward(f, backend, x, (one(x),), contexts...) return ForwardDiffOneArgDerivativePrep(pushforward_prep) @@ -181,7 +181,7 @@ function DI.value_and_derivative( prep::ForwardDiffOneArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} y, ty = DI.value_and_pushforward( f, prep.pushforward_prep, backend, x, (one(x),), contexts... @@ -195,7 +195,7 @@ function DI.value_and_derivative!( prep::ForwardDiffOneArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} y, _ = DI.value_and_pushforward!( f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts... 
@@ -208,7 +208,7 @@ function DI.derivative( prep::ForwardDiffOneArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} return only( DI.pushforward(f, prep.pushforward_prep, backend, x, (one(x),), contexts...) @@ -221,7 +221,7 @@ function DI.derivative!( prep::ForwardDiffOneArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} DI.pushforward!(f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...) return der @@ -236,7 +236,7 @@ function DI.value_and_gradient!( grad, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -255,7 +255,7 @@ function DI.value_and_gradient( f::F, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -273,7 +273,7 @@ function DI.gradient!( grad, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -288,7 +288,7 @@ function DI.gradient( f::F, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) 
@@ -309,7 +309,7 @@ function DI.prepare_gradient( f::F, backend::AutoForwardDiff, x::AbstractArray, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) chunk = choose_chunk(backend, x) @@ -324,7 +324,7 @@ function DI.value_and_gradient!( prep::ForwardDiffGradientPrep, backend::AutoForwardDiff, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) result = DiffResult(zero(eltype(x)), (grad,)) @@ -340,7 +340,7 @@ function DI.value_and_gradient( prep::ForwardDiffGradientPrep, backend::AutoForwardDiff, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) result = GradientResult(x) @@ -355,7 +355,7 @@ function DI.gradient!( prep::ForwardDiffGradientPrep, backend::AutoForwardDiff, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing @@ -367,7 +367,7 @@ function DI.gradient( prep::ForwardDiffGradientPrep, backend::AutoForwardDiff, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing @@ -383,7 +383,7 @@ function DI.value_and_jacobian!( jac, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) 
@@ -403,7 +403,7 @@ function DI.value_and_jacobian( f::F, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -419,7 +419,7 @@ function DI.jacobian!( jac, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -434,7 +434,7 @@ function DI.jacobian( f::F, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -452,7 +452,7 @@ struct ForwardDiffOneArgJacobianPrep{C} <: DI.JacobianPrep end function DI.prepare_jacobian( - f::F, backend::AutoForwardDiff, x, contexts::Varard{DI.ConstantOrFunctionOrBackend,C} + f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} ) where {F,C} fc = DI.with_contexts(f, contexts...) chunk = choose_chunk(backend, x) @@ -467,7 +467,7 @@ function DI.value_and_jacobian!( prep::ForwardDiffOneArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) y = fc(x) @@ -484,7 +484,7 @@ function DI.value_and_jacobian( prep::ForwardDiffOneArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) 
CHK = tag_type(backend) === Nothing @@ -497,7 +497,7 @@ function DI.jacobian!( prep::ForwardDiffOneArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing @@ -509,7 +509,7 @@ function DI.jacobian( prep::ForwardDiffOneArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing @@ -519,7 +519,7 @@ end ## Second derivative function DI.prepare_second_derivative( - f::F, backend::AutoForwardDiff, x, contexts::Varard{DI.ConstantOrFunctionOrBackend,C} + f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} ) where {F,C} return DI.NoSecondDerivativePrep() end @@ -529,7 +529,7 @@ function DI.second_derivative( ::DI.NoSecondDerivativePrep, backend::AutoForwardDiff, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} T = tag_type(f, backend, x) xdual = make_dual(T, x, one(x)) @@ -544,7 +544,7 @@ function DI.second_derivative!( ::DI.NoSecondDerivativePrep, backend::AutoForwardDiff, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} T = tag_type(f, backend, x) xdual = make_dual(T, x, one(x)) @@ -558,7 +558,7 @@ function DI.value_derivative_and_second_derivative( ::DI.NoSecondDerivativePrep, backend::AutoForwardDiff, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} T = tag_type(f, backend, x) xdual = make_dual(T, x, one(x)) @@ -577,7 +577,7 @@ function DI.value_derivative_and_second_derivative!( ::DI.NoSecondDerivativePrep, backend::AutoForwardDiff, x, - 
contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} T = tag_type(f, backend, x) xdual = make_dual(T, x, one(x)) @@ -596,7 +596,7 @@ function DI.prepare_hvp( backend::AutoForwardDiff, x, tx::NTuple, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} return DI.prepare_hvp(f, DI.SecondOrder(backend, backend), x, tx, contexts...) end @@ -607,7 +607,7 @@ function DI.hvp( backend::AutoForwardDiff, x, tx::NTuple, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} return DI.hvp(f, prep, DI.SecondOrder(backend, backend), x, tx, contexts...) end @@ -619,7 +619,7 @@ function DI.hvp!( backend::AutoForwardDiff, x, tx::NTuple, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} return DI.hvp!(f, tg, prep, DI.SecondOrder(backend, backend), x, tx, contexts...) end @@ -630,7 +630,7 @@ function DI.gradient_and_hvp( backend::AutoForwardDiff, x, tx::NTuple, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} return DI.gradient_and_hvp( f, prep, DI.SecondOrder(backend, backend), x, tx, contexts... @@ -645,7 +645,7 @@ function DI.gradient_and_hvp!( backend::AutoForwardDiff, x, tx::NTuple, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} return DI.gradient_and_hvp!( f, grad, tg, prep, DI.SecondOrder(backend, backend), x, tx, contexts... @@ -661,7 +661,7 @@ function DI.hessian!( hess, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) 
@@ -676,7 +676,7 @@ function DI.hessian( f::F, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -693,7 +693,7 @@ function DI.value_gradient_and_hessian!( hess, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -713,7 +713,7 @@ function DI.value_gradient_and_hessian( f::F, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc = DI.with_contexts(f, contexts...) @@ -734,7 +734,7 @@ struct ForwardDiffHessianPrep{C1,C2} <: DI.HessianPrep end function DI.prepare_hessian( - f::F, backend::AutoForwardDiff, x, contexts::Varard{DI.ConstantOrFunctionOrBackend,C} + f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} ) where {F,C} fc = DI.with_contexts(f, contexts...) chunk = choose_chunk(backend, x) @@ -751,7 +751,7 @@ function DI.hessian!( prep::ForwardDiffHessianPrep, backend::AutoForwardDiff, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing @@ -763,7 +763,7 @@ function DI.hessian( prep::ForwardDiffHessianPrep, backend::AutoForwardDiff, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) 
CHK = tag_type(backend) === Nothing @@ -777,7 +777,7 @@ function DI.value_gradient_and_hessian!( prep::ForwardDiffHessianPrep, backend::AutoForwardDiff, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) result = DiffResult(one(eltype(x)), (grad, hess)) @@ -794,7 +794,7 @@ function DI.value_gradient_and_hessian( prep::ForwardDiffHessianPrep, backend::AutoForwardDiff, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc = DI.with_contexts(f, contexts...) result = HessianResult(x) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl index 91ffe753b..5ffbf8412 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl @@ -115,7 +115,7 @@ function DI.value_and_derivative( y, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if T === Nothing fc! = DI.with_contexts(f!, contexts...) @@ -134,7 +134,7 @@ function DI.value_and_derivative!( der, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if T === Nothing fc! = DI.with_contexts(f!, contexts...) @@ -152,7 +152,7 @@ function DI.derivative( y, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if T === Nothing fc! = DI.with_contexts(f!, contexts...) 
@@ -169,7 +169,7 @@ function DI.derivative!( der, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if T === Nothing fc! = DI.with_contexts(f!, contexts...) @@ -191,7 +191,7 @@ function DI.prepare_derivative( y, backend::AutoForwardDiff, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) tag = get_tag(fc!, backend, x) @@ -205,7 +205,7 @@ function DI.value_and_derivative( prep::ForwardDiffTwoArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) result = MutableDiffResult(y, (similar(y),)) @@ -221,7 +221,7 @@ function DI.value_and_derivative!( prep::ForwardDiffTwoArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) result = MutableDiffResult(y, (der,)) @@ -236,7 +236,7 @@ function DI.derivative( prep::ForwardDiffTwoArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) CHK = tag_type(backend) === Nothing @@ -250,7 +250,7 @@ function DI.derivative!( prep::ForwardDiffTwoArgDerivativePrep, backend::AutoForwardDiff, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) 
CHK = tag_type(backend) === Nothing @@ -266,7 +266,7 @@ function DI.value_and_jacobian( y, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc! = DI.with_contexts(f!, contexts...) @@ -286,7 +286,7 @@ function DI.value_and_jacobian!( jac, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc! = DI.with_contexts(f!, contexts...) @@ -304,7 +304,7 @@ function DI.jacobian( y, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc! = DI.with_contexts(f!, contexts...) @@ -321,7 +321,7 @@ function DI.jacobian!( jac, backend::AutoForwardDiff{chunksize,T}, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C,chunksize,T} if isnothing(chunksize) && T === Nothing fc! = DI.with_contexts(f!, contexts...) @@ -343,7 +343,7 @@ function DI.prepare_jacobian( y, backend::AutoForwardDiff, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) chunk = choose_chunk(backend, x) @@ -358,7 +358,7 @@ function DI.value_and_jacobian( prep::ForwardDiffTwoArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) 
jac = similar(y, length(y), length(x)) @@ -375,7 +375,7 @@ function DI.value_and_jacobian!( prep::ForwardDiffTwoArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) result = MutableDiffResult(y, (jac,)) @@ -390,7 +390,7 @@ function DI.jacobian( prep::ForwardDiffTwoArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) CHK = tag_type(backend) === Nothing @@ -404,7 +404,7 @@ function DI.jacobian!( prep::ForwardDiffTwoArgJacobianPrep, backend::AutoForwardDiff, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) CHK = tag_type(backend) === Nothing diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl index 3073597d1..4d30b916e 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl @@ -88,7 +88,7 @@ DI.maker(::PrepContext) = prepcontext_maker function _translate(::Type{T}, ::Val{B}, c::DI.ConstantOrFunctionOrBackend) where {T,B} return DI.unwrap(c) end -_translate(::Type{T}, ::Val{B}, c::DI.PrepContext) where {T,B} = DI.unwrap(c) +_translate(::Type{T}, ::Val{B}, c::PrepContext) where {T,B} = DI.unwrap(c) function _translate(::Type{T}, ::Val{B}, c::DI.Cache) where {T,B} c0 = DI.unwrap(c) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl index ea565a918..fb5da6b76 
100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl @@ -15,7 +15,7 @@ struct TrackerPullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep end function DI.prepare_pullback( - f, ::AutoTracker, x, ty::NTuple, contexts::Varard{DI.ConstantOrFunctionOrBackend,C} + f, ::AutoTracker, x, ty::NTuple, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} ) where {C} return DI.NoPullbackPrep() end @@ -26,7 +26,7 @@ function DI.prepare_pullback_same_point( ::AutoTracker, x, ty::NTuple, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} y, pb = forward(f, x, map(DI.unwrap, contexts)...) return TrackerPullbackPrepSamePoint(y, pb) @@ -38,7 +38,7 @@ function DI.value_and_pullback( ::AutoTracker, x, ty::NTuple, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} y, pb = forward(f, x, map(DI.unwrap, contexts)...) 
tx = map(ty) do dy @@ -53,7 +53,7 @@ function DI.value_and_pullback( ::AutoTracker, x, ty::NTuple, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} (; y, pb) = prep tx = map(ty) do dy @@ -68,7 +68,7 @@ function DI.pullback( ::AutoTracker, x, ty::NTuple, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} (; pb) = prep tx = map(ty) do dy @@ -80,7 +80,7 @@ end ## Gradient function DI.prepare_gradient( - f, ::AutoTracker, x, contexts::Varard{DI.ConstantOrFunctionOrBackend,C} + f, ::AutoTracker, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} ) where {C} return DI.NoGradientPrep() end @@ -90,7 +90,7 @@ function DI.value_and_gradient( ::DI.NoGradientPrep, ::AutoTracker, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} (; val, grad) = withgradient(f, x, map(DI.unwrap, contexts)...) return val, data(first(grad)) @@ -101,7 +101,7 @@ function DI.gradient( ::DI.NoGradientPrep, ::AutoTracker, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} (; grad) = withgradient(f, x, map(DI.unwrap, contexts)...) return data(first(grad)) @@ -113,7 +113,7 @@ function DI.value_and_gradient!( prep::DI.NoGradientPrep, backend::AutoTracker, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} y, new_grad = DI.value_and_gradient(f, prep, backend, x, contexts...) 
return y, copyto!(grad, new_grad) @@ -125,7 +125,7 @@ function DI.gradient!( prep::DI.NoGradientPrep, backend::AutoTracker, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return copyto!(grad, DI.gradient(f, prep, backend, x, contexts...)) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl index 7c6deb5cd..0f681c9f9 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl @@ -17,7 +17,7 @@ struct ZygotePullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep end function DI.prepare_pullback( - f, ::AutoZygote, x, ty::NTuple, contexts::Varard{DI.ConstantOrFunctionOrBackend,C} + f, ::AutoZygote, x, ty::NTuple, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} ) where {C} return DI.NoPullbackPrep() end @@ -28,7 +28,7 @@ function DI.prepare_pullback_same_point( ::AutoZygote, x, ty::NTuple, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} y, pb = pullback(f, x, map(DI.unwrap, contexts)...) return ZygotePullbackPrepSamePoint(y, pb) @@ -40,7 +40,7 @@ function DI.value_and_pullback( ::AutoZygote, x, ty::NTuple, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} y, pb = pullback(f, x, map(DI.unwrap, contexts)...) 
tx = map(ty) do dy @@ -55,7 +55,7 @@ function DI.value_and_pullback( ::AutoZygote, x, ty::NTuple, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} (; y, pb) = prep tx = map(ty) do dy @@ -70,7 +70,7 @@ function DI.pullback( ::AutoZygote, x, ty::NTuple, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} (; pb) = prep tx = map(ty) do dy @@ -82,7 +82,7 @@ end ## Gradient function DI.prepare_gradient( - f, ::AutoZygote, x, contexts::Varard{DI.ConstantOrFunctionOrBackend,C} + f, ::AutoZygote, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} ) where {C} return DI.NoGradientPrep() end @@ -92,7 +92,7 @@ function DI.value_and_gradient( ::DI.NoGradientPrep, ::AutoZygote, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} (; val, grad) = withgradient(f, x, map(DI.unwrap, contexts)...) return val, first(grad) @@ -103,7 +103,7 @@ function DI.gradient( ::DI.NoGradientPrep, ::AutoZygote, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return first(gradient(f, x, map(DI.unwrap, contexts)...)) end @@ -114,7 +114,7 @@ function DI.value_and_gradient!( prep::DI.NoGradientPrep, backend::AutoZygote, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} y, new_grad = DI.value_and_gradient(f, prep, backend, x, contexts...) 
return y, copyto!(grad, new_grad) @@ -126,7 +126,7 @@ function DI.gradient!( prep::DI.NoGradientPrep, backend::AutoZygote, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return copyto!(grad, DI.gradient(f, prep, backend, x, contexts...)) end @@ -134,7 +134,7 @@ end ## Jacobian function DI.prepare_jacobian( - f, ::AutoZygote, x, contexts::Varard{DI.ConstantOrFunctionOrBackend,C} + f, ::AutoZygote, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} ) where {C} return DI.NoJacobianPrep() end @@ -144,7 +144,7 @@ function DI.value_and_jacobian( ::DI.NoJacobianPrep, ::AutoZygote, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return f(x, map(DI.unwrap, contexts)...), first(jacobian(f, x, map(DI.unwrap, contexts)...)) # https://github.com/FluxML/Zygote.jl/issues/1506 @@ -155,7 +155,7 @@ function DI.jacobian( ::DI.NoJacobianPrep, ::AutoZygote, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return first(jacobian(f, x, map(DI.unwrap, contexts)...)) end @@ -166,7 +166,7 @@ function DI.value_and_jacobian!( prep::DI.NoJacobianPrep, backend::AutoZygote, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} y, new_jac = DI.value_and_jacobian(f, prep, backend, x, contexts...) 
return y, copyto!(jac, new_jac) @@ -178,7 +178,7 @@ function DI.jacobian!( prep::DI.NoJacobianPrep, backend::AutoZygote, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return copyto!(jac, DI.jacobian(f, prep, backend, x, contexts...)) end @@ -192,7 +192,7 @@ function DI.prepare_hvp( backend::AutoZygote, x, tx::NTuple, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return DI.prepare_hvp(f, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...) end @@ -203,7 +203,7 @@ function DI.hvp( backend::AutoZygote, x, tx::NTuple, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return DI.hvp(f, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...) end @@ -215,7 +215,7 @@ function DI.hvp!( backend::AutoZygote, x, tx::NTuple, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return DI.hvp!( f, tg, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... @@ -228,7 +228,7 @@ function DI.gradient_and_hvp( backend::AutoZygote, x, tx::NTuple, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return DI.gradient_and_hvp( f, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... @@ -243,7 +243,7 @@ function DI.gradient_and_hvp!( backend::AutoZygote, x, tx::NTuple, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return DI.gradient_and_hvp!( f, grad, tg, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... 
@@ -253,7 +253,7 @@ end ## Hessian function DI.prepare_hessian( - f, ::AutoZygote, x, contexts::Varard{DI.ConstantOrFunctionOrBackend,C} + f, ::AutoZygote, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C} ) where {C} return DI.NoHessianPrep() end @@ -263,7 +263,7 @@ function DI.hessian( ::DI.NoHessianPrep, ::AutoZygote, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} fc = DI.with_contexts(f, contexts...) return hessian(fc, x) @@ -275,7 +275,7 @@ function DI.hessian!( prep::DI.NoHessianPrep, backend::AutoZygote, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} return copyto!(hess, DI.hessian(f, prep, backend, x, contexts...)) end @@ -285,7 +285,7 @@ function DI.value_gradient_and_hessian( prep::DI.NoHessianPrep, backend::AutoZygote, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} y, grad = DI.value_and_gradient(f, DI.NoGradientPrep(), backend, x, contexts...) hess = DI.hessian(f, prep, backend, x, contexts...) @@ -299,7 +299,7 @@ function DI.value_gradient_and_hessian!( prep::DI.NoHessianPrep, backend::AutoZygote, x, - contexts::Varard{DI.ConstantOrFunctionOrBackend,C}, + contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} y, _ = DI.value_and_gradient!(f, grad, DI.NoGradientPrep(), backend, x, contexts...) DI.hessian!(f, hess, prep, backend, x, contexts...) 
From cdbd221c279fb57e196c8d1898b67323374d9ef4 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 6 Dec 2024 13:13:59 +0100 Subject: [PATCH 03/10] Typo --- .../secondorder.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl index b03465e82..d01efd074 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl @@ -17,7 +17,7 @@ function DI.prepare_hvp( rewrap = DI.Rewrap(contexts...) new_contexts = ( DI.FunctionContext(f), - DI.PrepContext(inner_gradient_prep), + PrepContext(inner_gradient_prep), DI.BackendContext(DI.inner(backend)), DI.Constant(rewrap), contexts..., @@ -40,7 +40,7 @@ function DI.hvp( rewrap = DI.Rewrap(contexts...) new_contexts = ( DI.FunctionContext(f), - DI.PrepContext(inner_gradient_prep), + PrepContext(inner_gradient_prep), DI.BackendContext(DI.inner(backend)), DI.Constant(rewrap), contexts..., @@ -68,7 +68,7 @@ function DI.hvp!( rewrap = DI.Rewrap(contexts...) new_contexts = ( DI.FunctionContext(f), - DI.PrepContext(inner_gradient_prep), + PrepContext(inner_gradient_prep), DI.BackendContext(DI.inner(backend)), DI.Constant(rewrap), contexts..., @@ -97,7 +97,7 @@ function DI.gradient_and_hvp( rewrap = DI.Rewrap(contexts...) new_contexts = ( DI.FunctionContext(f), - DI.PrepContext(inner_gradient_prep), + PrepContext(inner_gradient_prep), DI.BackendContext(DI.inner(backend)), DI.Constant(rewrap), contexts..., @@ -126,7 +126,7 @@ function DI.gradient_and_hvp!( rewrap = DI.Rewrap(contexts...) 
new_contexts = ( DI.FunctionContext(f), - DI.PrepContext(inner_gradient_prep), + PrepContext(inner_gradient_prep), DI.BackendContext(DI.inner(backend)), DI.Constant(rewrap), contexts..., From 3f762231868ef745dfd7d07b1bbdff0e684d5a8d Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 6 Dec 2024 13:47:07 +0100 Subject: [PATCH 04/10] Fix Enzyme translation --- .../ext/DifferentiationInterfaceEnzymeExt/utils.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl index af6a2d9a8..9e42d7429 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl @@ -42,17 +42,17 @@ end force_annotation(f::F) where {F<:Annotation} = f force_annotation(f::F) where {F} = Const(f) -function translate( +@inline function _translate( ::AutoEnzyme, ::Val{B}, c::Union{DI.Constant,DI.BackendContext} ) where {B} return Const(DI.unwrap(c)) end -function translate(backend::AutoEnzyme, ::Val{B}, c::DI.FunctionContext) where {B} +@inline function _translate(backend::AutoEnzyme, ::Val{B}, c::DI.FunctionContext) where {B} return get_f_and_df(unwrap(c), backend, Val(B)) end -function translate( +@inline function translate( backend::AutoEnzyme, ::Val{B}, contexts::Vararg{DI.Context,C} ) where {B,C} new_contexts = map(contexts) do c From ae0789ea535575eccea4b6fe93fb02a273935277 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 6 Dec 2024 14:22:10 +0100 Subject: [PATCH 05/10] Typo --- .../ext/DifferentiationInterfaceEnzymeExt/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl 
b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl index 9e42d7429..522f887ce 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl @@ -49,7 +49,7 @@ force_annotation(f::F) where {F} = Const(f) end @inline function _translate(backend::AutoEnzyme, ::Val{B}, c::DI.FunctionContext) where {B} - return get_f_and_df(unwrap(c), backend, Val(B)) + return get_f_and_df(DI.unwrap(c), backend, Val(B)) end @inline function translate( From 915b4ff08b314ddf844e6cb4d851dbccaa249074 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 6 Dec 2024 14:44:55 +0100 Subject: [PATCH 06/10] Forc eannot --- .../ext/DifferentiationInterfaceEnzymeExt/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl index 522f887ce..614ca92c8 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl @@ -49,7 +49,7 @@ force_annotation(f::F) where {F} = Const(f) end @inline function _translate(backend::AutoEnzyme, ::Val{B}, c::DI.FunctionContext) where {B} - return get_f_and_df(DI.unwrap(c), backend, Val(B)) + return force_annotation(get_f_and_df(DI.unwrap(c), backend, Val(B))) end @inline function translate( From 629ee230d929901e5f4fe59cc39f55d71aac822a Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 6 Dec 2024 15:50:03 +0100 Subject: [PATCH 07/10] Coverage --- DifferentiationInterface/Project.toml | 2 +- .../ext/DifferentiationInterfaceForwardDiffExt/utils.jl | 3 --- DifferentiationInterface/src/utils/context.jl | 6 ------ 3 files changed, 1 insertion(+), 10 deletions(-) diff --git 
a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index e5a8d7cf6..df1dd8601 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.6.26" +version = "0.6.27" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl index 4d30b916e..b4893ef36 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl @@ -82,9 +82,6 @@ struct PrepContext{T<:DI.Prep} <: DI.Context data::T end -prepcontext_maker(c) = PrepContext(c) -DI.maker(::PrepContext) = prepcontext_maker - function _translate(::Type{T}, ::Val{B}, c::DI.ConstantOrFunctionOrBackend) where {T,B} return DI.unwrap(c) end diff --git a/DifferentiationInterface/src/utils/context.jl b/DifferentiationInterface/src/utils/context.jl index 51cb06a8d..1abdbdd2e 100644 --- a/DifferentiationInterface/src/utils/context.jl +++ b/DifferentiationInterface/src/utils/context.jl @@ -79,16 +79,10 @@ struct FunctionContext{T} <: Context data::T end -functioncontext_maker(c) = FunctionContext(c) -maker(::FunctionContext) = functioncontext_maker - struct BackendContext{T} <: Context data::T end -backendcontext_maker(c) = BackendContext(c) -maker(::BackendContext) = backendcontext_maker - const ConstantOrFunctionOrBackend = Union{Constant,FunctionContext,BackendContext} ## Context manipulation From eef9e42e42a5dbeb704a17dd436377f0f2a37e91 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 6 Dec 2024 17:18:07 +0100 Subject: [PATCH 08/10] Pass mode object to translator in Enzyme --- 
.../DifferentiationInterfaceEnzymeExt.jl | 1 + .../forward_onearg.jl | 50 ++++++++--------- .../forward_twoarg.jl | 28 +++------- .../reverse_onearg.jl | 48 +++++++--------- .../reverse_twoarg.jl | 56 ++++++------------- .../utils.jl | 28 ++++++---- 6 files changed, 87 insertions(+), 124 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl index 2ef5364ae..328bffaf3 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl @@ -7,6 +7,7 @@ using EnzymeCore: Active, Annotation, BatchDuplicated, + BatchDuplicatedNoNeed, BatchMixedDuplicated, Combined, Const, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl index 4c02c4a51..5335ed472 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl @@ -18,11 +18,12 @@ function DI.value_and_pushforward( tx::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,C} - f_and_df = get_f_and_df(f, backend) + mode = forward_withprimal(backend) + f_and_df = get_f_and_df(f, backend, mode) dx_sametype = convert(typeof(x), only(tx)) x_and_dx = Duplicated(x, dx_sametype) - annotated_contexts = translate(backend, Val(1), contexts...) - dy, y = autodiff(forward_withprimal(backend), f_and_df, x_and_dx, annotated_contexts...) + annotated_contexts = translate(backend, mode, Val(1), contexts...) + dy, y = autodiff(mode, f_and_df, x_and_dx, annotated_contexts...) 
return y, (dy,) end @@ -34,11 +35,12 @@ function DI.value_and_pushforward( tx::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} - f_and_df = get_f_and_df(f, backend, Val(B)) + mode = forward_withprimal(backend) + f_and_df = get_f_and_df(f, backend, mode, Val(B)) tx_sametype = map(Fix1(convert, typeof(x)), tx) x_and_tx = BatchDuplicated(x, tx_sametype) - annotated_contexts = translate(backend, Val(B), contexts...) - ty, y = autodiff(forward_withprimal(backend), f_and_df, x_and_tx, annotated_contexts...) + annotated_contexts = translate(backend, mode, Val(B), contexts...) + ty, y = autodiff(mode, f_and_df, x_and_tx, annotated_contexts...) return y, values(ty) end @@ -50,13 +52,12 @@ function DI.pushforward( tx::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,C} - f_and_df = get_f_and_df(f, backend) + mode = forward_noprimal(backend) + f_and_df = get_f_and_df(f, backend, mode) dx_sametype = convert(typeof(x), only(tx)) x_and_dx = Duplicated(x, dx_sametype) - annotated_contexts = translate(backend, Val(1), contexts...) - dy = only( - autodiff(forward_noprimal(backend), f_and_df, x_and_dx, annotated_contexts...) - ) + annotated_contexts = translate(backend, mode, Val(1), contexts...) + dy = only(autodiff(mode, f_and_df, x_and_dx, annotated_contexts...)) return (dy,) end @@ -68,13 +69,12 @@ function DI.pushforward( tx::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} - f_and_df = get_f_and_df(f, backend, Val(B)) + mode = forward_noprimal(backend) + f_and_df = get_f_and_df(f, backend, mode, Val(B)) tx_sametype = map(Fix1(convert, typeof(x)), tx) x_and_tx = BatchDuplicated(x, tx_sametype) - annotated_contexts = translate(backend, Val(B), contexts...) - ty = only( - autodiff(forward_noprimal(backend), f_and_df, x_and_tx, annotated_contexts...) - ) + annotated_contexts = translate(backend, mode, Val(B), contexts...) 
+ ty = only(autodiff(mode, f_and_df, x_and_tx, annotated_contexts...)) return values(ty) end @@ -132,10 +132,9 @@ function DI.gradient( backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, x, ) where {F,B} - f_and_df = get_f_and_df(f, backend) - derivs = gradient( - forward_noprimal(backend), f_and_df, x; chunk=Val(B), shadows=prep.shadows - ) + mode = forward_noprimal(backend) + f_and_df = get_f_and_df(f, backend, mode) + derivs = gradient(mode, f_and_df, x; chunk=Val(B), shadows=prep.shadows) return only(derivs) end @@ -145,10 +144,9 @@ function DI.value_and_gradient( backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, x, ) where {F,B} - f_and_df = get_f_and_df(f, backend) - (; derivs, val) = gradient( - forward_withprimal(backend), f_and_df, x; chunk=Val(B), shadows=prep.shadows - ) + mode = forward_withprimal(backend) + f_and_df = get_f_and_df(f, backend, mode) + (; derivs, val) = gradient(mode, f_and_df, x; chunk=Val(B), shadows=prep.shadows) return val, only(derivs) end @@ -201,7 +199,7 @@ function DI.jacobian( backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, x, ) where {F,B} - f_and_df = get_f_and_df(f, backend) + f_and_df = get_f_and_df(f, backend, mode) derivs = jacobian( forward_noprimal(backend), f_and_df, x; chunk=Val(B), shadows=prep.shadows ) @@ -215,7 +213,7 @@ function DI.value_and_jacobian( backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, x, ) where {F,B} - f_and_df = get_f_and_df(f, backend) + f_and_df = get_f_and_df(f, backend, mode) (; derivs, val) = jacobian( forward_withprimal(backend), f_and_df, x; chunk=Val(B), shadows=prep.shadows ) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl index 90736a8cf..b185b4a16 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl +++ 
b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl @@ -20,20 +20,14 @@ function DI.value_and_pushforward( tx::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,C} - f!_and_df! = get_f_and_df(f!, backend) + mode = forward_noprimal(backend) + f!_and_df! = get_f_and_df(f!, backend, mode) dx_sametype = convert(typeof(x), only(tx)) dy_sametype = make_zero(y) x_and_dx = Duplicated(x, dx_sametype) y_and_dy = Duplicated(y, dy_sametype) - annotated_contexts = translate(backend, Val(1), contexts...) - autodiff( - forward_noprimal(backend), - f!_and_df!, - Const, - y_and_dy, - x_and_dx, - annotated_contexts..., - ) + annotated_contexts = translate(backend, mode, Val(1), contexts...) + autodiff(mode, f!_and_df!, Const, y_and_dy, x_and_dx, annotated_contexts...) return y, (dy_sametype,) end @@ -46,20 +40,14 @@ function DI.value_and_pushforward( tx::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} - f!_and_df! = get_f_and_df(f!, backend, Val(B)) + mode = forward_noprimal(backend) + f!_and_df! = get_f_and_df(f!, backend, mode, Val(B)) tx_sametype = map(Fix1(convert, typeof(x)), tx) ty_sametype = ntuple(_ -> make_zero(y), Val(B)) x_and_tx = BatchDuplicated(x, tx_sametype) y_and_ty = BatchDuplicated(y, ty_sametype) - annotated_contexts = translate(backend, Val(B), contexts...) - autodiff( - forward_noprimal(backend), - f!_and_df!, - Const, - y_and_ty, - x_and_tx, - annotated_contexts..., - ) + annotated_contexts = translate(backend, mode, Val(B), contexts...) + autodiff(mode, f!_and_df!, Const, y_and_ty, x_and_tx, annotated_contexts...) 
return y, ty_sametype end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl index 3f31ed4b7..5585f717e 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl @@ -69,12 +69,12 @@ function DI.value_and_pullback( ty::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,C} - f_and_df = force_annotation(get_f_and_df(f, backend)) mode = reverse_split_withprimal(backend) + f_and_df = force_annotation(get_f_and_df(f, backend, mode)) IA = guess_activity(typeof(x), mode) RA = guess_activity(eltype(ty), mode) dx = make_zero(x) - annotated_contexts = translate(backend, Val(1), contexts...) + annotated_contexts = translate(backend, mode, Val(1), contexts...) dinputs, result = seeded_autodiff_thunk( mode, only(ty), f_and_df, RA, annotate(IA, x, dx), annotated_contexts... ) @@ -94,12 +94,12 @@ function DI.value_and_pullback( ty::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} - f_and_df = force_annotation(get_f_and_df(f, backend, Val(B))) mode = reverse_split_withprimal(backend) + f_and_df = force_annotation(get_f_and_df(f, backend, mode, Val(B))) IA = batchify_activity(guess_activity(typeof(x), mode), Val(B)) RA = batchify_activity(guess_activity(eltype(ty), mode), Val(B)) tx = ntuple(_ -> make_zero(x), Val(B)) - annotated_contexts = translate(backend, Val(B), contexts...) + annotated_contexts = translate(backend, mode, Val(B), contexts...) dinputs, result = batch_seeded_autodiff_thunk( mode, ty, f_and_df, RA, annotate(IA, x, tx), annotated_contexts... 
) @@ -133,12 +133,12 @@ function DI.value_and_pullback!( ty::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,C} - f_and_df = force_annotation(get_f_and_df(f, backend)) mode = reverse_split_withprimal(backend) + f_and_df = force_annotation(get_f_and_df(f, backend, mode)) RA = guess_activity(eltype(ty), mode) dx_righttype = convert(typeof(x), only(tx)) make_zero!(dx_righttype) - annotated_contexts = translate(backend, Val(1), contexts...) + annotated_contexts = translate(backend, mode, Val(1), contexts...) _, result = seeded_autodiff_thunk( mode, only(ty), f_and_df, RA, Duplicated(x, dx_righttype), annotated_contexts... ) @@ -155,12 +155,12 @@ function DI.value_and_pullback!( ty::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} - f_and_df = force_annotation(get_f_and_df(f, backend, Val(B))) mode = reverse_split_withprimal(backend) + f_and_df = force_annotation(get_f_and_df(f, backend, mode, Val(B))) RA = batchify_activity(guess_activity(eltype(ty), mode), Val(B)) tx_righttype = map(Fix1(convert, typeof(x)), tx) make_zero!(tx_righttype) - annotated_contexts = translate(backend, Val(B), contexts...) + annotated_contexts = translate(backend, mode, Val(B), contexts...) _, result = batch_seeded_autodiff_thunk( mode, ty, f_and_df, RA, BatchDuplicated(x, tx_righttype), annotated_contexts... ) @@ -190,11 +190,11 @@ function DI.gradient( x, contexts::Vararg{DI.Context,C}, ) where {F,C} - f_and_df = get_f_and_df(f, backend) mode = reverse_noprimal(backend) + f_and_df = get_f_and_df(f, backend, mode) IA = guess_activity(typeof(x), mode) grad = make_zero(x) - annotated_contexts = translate(backend, Val(1), contexts...) + annotated_contexts = translate(backend, mode, Val(1), contexts...) dinputs = only( autodiff(mode, f_and_df, Active, annotate(IA, x, grad), annotated_contexts...) 
) @@ -212,11 +212,11 @@ function DI.value_and_gradient( x, contexts::Vararg{DI.Context,C}, ) where {F,C} - f_and_df = get_f_and_df(f, backend) mode = reverse_withprimal(backend) + f_and_df = get_f_and_df(f, backend, mode) IA = guess_activity(typeof(x), mode) grad = make_zero(x) - annotated_contexts = translate(backend, Val(1), contexts...) + annotated_contexts = translate(backend, mode, Val(1), contexts...) dinputs, result = autodiff( mode, f_and_df, Active, annotate(IA, x, grad), annotated_contexts... ) @@ -259,17 +259,12 @@ function DI.gradient!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} - f_and_df = get_f_and_df(f, backend) + mode = reverse_noprimal(backend) + f_and_df = get_f_and_df(f, backend, mode) grad_righttype = grad isa typeof(x) ? grad : prep.grad_righttype make_zero!(grad_righttype) - annotated_contexts = translate(backend, Val(1), contexts...) - autodiff( - reverse_noprimal(backend), - f_and_df, - Active, - Duplicated(x, grad_righttype), - annotated_contexts..., - ) + annotated_contexts = translate(backend, mode, Val(1), contexts...) + autodiff(mode, f_and_df, Active, Duplicated(x, grad_righttype), annotated_contexts...) grad === grad_righttype || copyto!(grad, grad_righttype) return grad end @@ -292,16 +287,13 @@ function DI.value_and_gradient!( x, contexts::Vararg{DI.Context,C}, ) where {F,C} - f_and_df = get_f_and_df(f, backend) + mode = reverse_withprimal(backend) + f_and_df = get_f_and_df(f, backend, mode) grad_righttype = grad isa typeof(x) ? grad : prep.grad_righttype make_zero!(grad_righttype) - annotated_contexts = translate(backend, Val(1), contexts...) + annotated_contexts = translate(backend, mode, Val(1), contexts...) _, y = autodiff( - reverse_withprimal(backend), - f_and_df, - Active, - Duplicated(x, grad_righttype), - annotated_contexts..., + mode, f_and_df, Active, Duplicated(x, grad_righttype), annotated_contexts... 
) grad === grad_righttype || copyto!(grad, grad_righttype) return y, grad diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl index 40a33f3a6..c93d04dab 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl @@ -20,19 +20,13 @@ function DI.value_and_pullback( ty::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,C} - f!_and_df! = get_f_and_df(f!, backend) + mode = reverse_noprimal(backend) + f!_and_df! = get_f_and_df(f!, backend, mode) dy_sametype = convert(typeof(y), copy(only(ty))) y_and_dy = Duplicated(y, dy_sametype) - annotated_contexts = translate(backend, Val(1), contexts...) + annotated_contexts = translate(backend, mode, Val(1), contexts...) dinputs = only( - autodiff( - reverse_noprimal(backend), - f!_and_df!, - Const, - y_and_dy, - Active(x), - annotated_contexts..., - ), + autodiff(mode, f!_and_df!, Const, y_and_dy, Active(x), annotated_contexts...) ) dx = dinputs[2] return y, (dx,) @@ -47,19 +41,13 @@ function DI.value_and_pullback( ty::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} - f!_and_df! = get_f_and_df(f!, backend, Val(B)) + mode = reverse_noprimal(backend) + f!_and_df! = get_f_and_df(f!, backend, mode, Val(B)) ty_sametype = map(Fix1(convert, typeof(y)), copy.(ty)) y_and_ty = BatchDuplicated(y, ty_sametype) - annotated_contexts = translate(backend, Val(B), contexts...) + annotated_contexts = translate(backend, mode, Val(B), contexts...) dinputs = only( - autodiff( - reverse_noprimal(backend), - f!_and_df!, - Const, - y_and_ty, - Active(x), - annotated_contexts..., - ), + autodiff(mode, f!_and_df!, Const, y_and_ty, Active(x), annotated_contexts...) 
) tx = values(dinputs[2]) return y, tx @@ -74,20 +62,14 @@ function DI.value_and_pullback( ty::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,C} - f!_and_df! = get_f_and_df(f!, backend) + mode = reverse_noprimal(backend) + f!_and_df! = get_f_and_df(f!, backend, mode) dx_sametype = make_zero(x) dy_sametype = convert(typeof(y), copy(only(ty))) x_and_dx = Duplicated(x, dx_sametype) y_and_dy = Duplicated(y, dy_sametype) - annotated_contexts = translate(backend, Val(1), contexts...) - autodiff( - reverse_noprimal(backend), - f!_and_df!, - Const, - y_and_dy, - x_and_dx, - annotated_contexts..., - ) + annotated_contexts = translate(backend, mode, Val(1), contexts...) + autodiff(mode, f!_and_df!, Const, y_and_dy, x_and_dx, annotated_contexts...) return y, (dx_sametype,) end @@ -100,19 +82,13 @@ function DI.value_and_pullback( ty::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} - f!_and_df! = get_f_and_df(f!, backend, Val(B)) + mode = reverse_noprimal(backend) + f!_and_df! = get_f_and_df(f!, backend, mode, Val(B)) tx_sametype = ntuple(_ -> make_zero(x), Val(B)) ty_sametype = map(Fix1(convert, typeof(y)), copy.(ty)) x_and_tx = BatchDuplicated(x, tx_sametype) y_and_ty = BatchDuplicated(y, ty_sametype) - annotated_contexts = translate(backend, Val(B), contexts...) - autodiff( - reverse_noprimal(backend), - f!_and_df!, - Const, - y_and_ty, - x_and_tx, - annotated_contexts..., - ) + annotated_contexts = translate(backend, mode, Val(B), contexts...) + autodiff(mode, f!_and_df!, Const, y_and_ty, x_and_tx, annotated_contexts...) 
return y, tx_sametype end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl index 614ca92c8..a476489cb 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl @@ -8,15 +8,19 @@ to_val(::DI.BatchSizeSettings{B}) where {B} = Val(B) ## Annotations -function get_f_and_df(f::F, ::AutoEnzyme{M,Nothing}, ::Val{B}=Val(1)) where {F,M,B} +@inline function get_f_and_df( + f::F, ::AutoEnzyme{M,Nothing}, mode::Mode, ::Val{B}=Val(1) +) where {F,M,B} return f end -function get_f_and_df(f::F, ::AutoEnzyme{M,<:Const}, ::Val{B}=Val(1)) where {F,M,B} +@inline function get_f_and_df( + f::F, ::AutoEnzyme{M,<:Const}, mode::Mode, ::Val{B}=Val(1) +) where {F,M,B} return Const(f) end -function get_f_and_df( +@inline function get_f_and_df( f::F, ::AutoEnzyme{ M, @@ -25,13 +29,15 @@ function get_f_and_df( MixedDuplicated, BatchDuplicated, BatchMixedDuplicated, - EnzymeCore.DuplicatedNoNeed, - EnzymeCore.BatchDuplicatedNoNeed, + DuplicatedNoNeed, + BatchDuplicatedNoNeed, }, }, + mode::Mode, ::Val{B}=Val(1), ) where {F,M,B} # TODO: needs more sophistication for mixed activities + @assert !(guess_activity(F, mode) <: Const) if B == 1 return Duplicated(f, make_zero(f)) else @@ -43,20 +49,22 @@ force_annotation(f::F) where {F<:Annotation} = f force_annotation(f::F) where {F} = Const(f) @inline function _translate( - ::AutoEnzyme, ::Val{B}, c::Union{DI.Constant,DI.BackendContext} + ::AutoEnzyme, ::Mode, ::Val{B}, c::Union{DI.Constant,DI.BackendContext} ) where {B} return Const(DI.unwrap(c)) end -@inline function _translate(backend::AutoEnzyme, ::Val{B}, c::DI.FunctionContext) where {B} - return force_annotation(get_f_and_df(DI.unwrap(c), backend, Val(B))) +@inline function _translate( + backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.FunctionContext +) where {B} + return 
force_annotation(get_f_and_df(DI.unwrap(c), backend, mode, Val(B))) end @inline function translate( - backend::AutoEnzyme, ::Val{B}, contexts::Vararg{DI.Context,C} + backend::AutoEnzyme, mode::Mode, ::Val{B}, contexts::Vararg{DI.Context,C} ) where {B,C} new_contexts = map(contexts) do c - _translate(backend, Val(B), c) + _translate(backend, mode, Val(B), c) end return new_contexts end From f47306ceb49e285bf0c3ad0ddc14ea5549f2a9ba Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 6 Dec 2024 17:40:11 +0100 Subject: [PATCH 09/10] Typo --- .../forward_onearg.jl | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl index 5335ed472..00b66808a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl @@ -199,10 +199,9 @@ function DI.jacobian( backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, x, ) where {F,B} + mode = forward_noprimal(backend) f_and_df = get_f_and_df(f, backend, mode) - derivs = jacobian( - forward_noprimal(backend), f_and_df, x; chunk=Val(B), shadows=prep.shadows - ) + derivs = jacobian(mode, f_and_df, x; chunk=Val(B), shadows=prep.shadows) jac_tensor = only(derivs) return maybe_reshape(jac_tensor, prep.output_length, length(x)) end @@ -213,10 +212,9 @@ function DI.value_and_jacobian( backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, x, ) where {F,B} + mode = forward_withprimal(backend) f_and_df = get_f_and_df(f, backend, mode) - (; derivs, val) = jacobian( - forward_withprimal(backend), f_and_df, x; chunk=Val(B), shadows=prep.shadows - ) + (; derivs, val) = jacobian(mode, f_and_df, x; chunk=Val(B), shadows=prep.shadows) jac_tensor = only(derivs) return 
val, maybe_reshape(jac_tensor, prep.output_length, length(x)) end From 2186d82eda7bb1d7d9a50f868b8447dd91be386d Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 6 Dec 2024 18:56:26 +0100 Subject: [PATCH 10/10] Cleaner error --- .../ext/DifferentiationInterfaceEnzymeExt/utils.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl index a476489cb..dd8221768 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl @@ -37,7 +37,6 @@ end ::Val{B}=Val(1), ) where {F,M,B} # TODO: needs more sophistication for mixed activities - @assert !(guess_activity(F, mode) <: Const) if B == 1 return Duplicated(f, make_zero(f)) else