Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.6.26"
version = "0.6.27"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@ struct ChainRulesPullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep
end

function DI.prepare_pullback(
f, ::AutoReverseChainRules, x, ty::NTuple, contexts::Vararg{DI.Constant,C}
f,
::AutoReverseChainRules,
x,
ty::NTuple,
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
) where {C}
return DI.NoPullbackPrep()
end
Expand All @@ -17,7 +21,7 @@ function DI.prepare_pullback_same_point(
backend::AutoReverseChainRules,
x,
ty::NTuple,
contexts::Vararg{DI.Constant,C},
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
) where {C}
rc = ruleconfig(backend)
y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...)
Expand All @@ -30,7 +34,7 @@ function DI.value_and_pullback(
backend::AutoReverseChainRules,
x,
ty::NTuple,
contexts::Vararg{DI.Constant,C},
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
) where {C}
rc = ruleconfig(backend)
y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...)
Expand All @@ -46,7 +50,7 @@ function DI.value_and_pullback(
::AutoReverseChainRules,
x,
ty::NTuple,
contexts::Vararg{DI.Constant,C},
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
) where {C}
(; y, pb) = prep
tx = map(ty) do dy
Expand All @@ -61,7 +65,7 @@ function DI.pullback(
::AutoReverseChainRules,
x,
ty::NTuple,
contexts::Vararg{DI.Constant,C},
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
) where {C}
(; pb) = prep
tx = map(ty) do dy
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using EnzymeCore:
Active,
Annotation,
BatchDuplicated,
BatchDuplicatedNoNeed,
BatchMixedDuplicated,
Combined,
Const,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
tx::NTuple{1},
contexts::Vararg{DI.Context,C},
) where {F,C}
f_and_df = get_f_and_df(f, backend)
mode = forward_withprimal(backend)
f_and_df = get_f_and_df(f, backend, mode)

Check warning on line 22 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L21-L22

Added lines #L21 - L22 were not covered by tests
dx_sametype = convert(typeof(x), only(tx))
x_and_dx = Duplicated(x, dx_sametype)
dy, y = autodiff(
forward_withprimal(backend), f_and_df, x_and_dx, map(translate, contexts)...
)
annotated_contexts = translate(backend, mode, Val(1), contexts...)
dy, y = autodiff(mode, f_and_df, x_and_dx, annotated_contexts...)

Check warning on line 26 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L25-L26

Added lines #L25 - L26 were not covered by tests
return y, (dy,)
end

Expand All @@ -35,12 +35,12 @@
tx::NTuple{B},
contexts::Vararg{DI.Context,C},
) where {F,B,C}
f_and_df = get_f_and_df(f, backend, Val(B))
mode = forward_withprimal(backend)
f_and_df = get_f_and_df(f, backend, mode, Val(B))

Check warning on line 39 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L38-L39

Added lines #L38 - L39 were not covered by tests
tx_sametype = map(Fix1(convert, typeof(x)), tx)
x_and_tx = BatchDuplicated(x, tx_sametype)
ty, y = autodiff(
forward_withprimal(backend), f_and_df, x_and_tx, map(translate, contexts)...
)
annotated_contexts = translate(backend, mode, Val(B), contexts...)
ty, y = autodiff(mode, f_and_df, x_and_tx, annotated_contexts...)

Check warning on line 43 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L42-L43

Added lines #L42 - L43 were not covered by tests
return y, values(ty)
end

Expand All @@ -52,12 +52,12 @@
tx::NTuple{1},
contexts::Vararg{DI.Context,C},
) where {F,C}
f_and_df = get_f_and_df(f, backend)
mode = forward_noprimal(backend)
f_and_df = get_f_and_df(f, backend, mode)

Check warning on line 56 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L55-L56

Added lines #L55 - L56 were not covered by tests
dx_sametype = convert(typeof(x), only(tx))
x_and_dx = Duplicated(x, dx_sametype)
dy = only(
autodiff(forward_noprimal(backend), f_and_df, x_and_dx, map(translate, contexts)...)
)
annotated_contexts = translate(backend, mode, Val(1), contexts...)
dy = only(autodiff(mode, f_and_df, x_and_dx, annotated_contexts...))

Check warning on line 60 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L59-L60

Added lines #L59 - L60 were not covered by tests
return (dy,)
end

Expand All @@ -69,12 +69,12 @@
tx::NTuple{B},
contexts::Vararg{DI.Context,C},
) where {F,B,C}
f_and_df = get_f_and_df(f, backend, Val(B))
mode = forward_noprimal(backend)
f_and_df = get_f_and_df(f, backend, mode, Val(B))

Check warning on line 73 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L72-L73

Added lines #L72 - L73 were not covered by tests
tx_sametype = map(Fix1(convert, typeof(x)), tx)
x_and_tx = BatchDuplicated(x, tx_sametype)
ty = only(
autodiff(forward_noprimal(backend), f_and_df, x_and_tx, map(translate, contexts)...)
)
annotated_contexts = translate(backend, mode, Val(B), contexts...)
ty = only(autodiff(mode, f_and_df, x_and_tx, annotated_contexts...))

Check warning on line 77 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L76-L77

Added lines #L76 - L77 were not covered by tests
return values(ty)
end

Expand Down Expand Up @@ -132,10 +132,9 @@
backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
x,
) where {F,B}
f_and_df = get_f_and_df(f, backend)
derivs = gradient(
forward_noprimal(backend), f_and_df, x; chunk=Val(B), shadows=prep.shadows
)
mode = forward_noprimal(backend)
f_and_df = get_f_and_df(f, backend, mode)
derivs = gradient(mode, f_and_df, x; chunk=Val(B), shadows=prep.shadows)

Check warning on line 137 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L135-L137

Added lines #L135 - L137 were not covered by tests
return only(derivs)
end

Expand All @@ -145,10 +144,9 @@
backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
x,
) where {F,B}
f_and_df = get_f_and_df(f, backend)
(; derivs, val) = gradient(
forward_withprimal(backend), f_and_df, x; chunk=Val(B), shadows=prep.shadows
)
mode = forward_withprimal(backend)
f_and_df = get_f_and_df(f, backend, mode)
(; derivs, val) = gradient(mode, f_and_df, x; chunk=Val(B), shadows=prep.shadows)

Check warning on line 149 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L147-L149

Added lines #L147 - L149 were not covered by tests
return val, only(derivs)
end

Expand Down Expand Up @@ -201,10 +199,9 @@
backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}},
x,
) where {F,B}
f_and_df = get_f_and_df(f, backend)
derivs = jacobian(
forward_noprimal(backend), f_and_df, x; chunk=Val(B), shadows=prep.shadows
)
mode = forward_noprimal(backend)
f_and_df = get_f_and_df(f, backend, mode)
derivs = jacobian(mode, f_and_df, x; chunk=Val(B), shadows=prep.shadows)

Check warning on line 204 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L202-L204

Added lines #L202 - L204 were not covered by tests
jac_tensor = only(derivs)
return maybe_reshape(jac_tensor, prep.output_length, length(x))
end
Expand All @@ -215,10 +212,9 @@
backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}},
x,
) where {F,B}
f_and_df = get_f_and_df(f, backend)
(; derivs, val) = jacobian(
forward_withprimal(backend), f_and_df, x; chunk=Val(B), shadows=prep.shadows
)
mode = forward_withprimal(backend)
f_and_df = get_f_and_df(f, backend, mode)
(; derivs, val) = jacobian(mode, f_and_df, x; chunk=Val(B), shadows=prep.shadows)

Check warning on line 217 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L215-L217

Added lines #L215 - L217 were not covered by tests
jac_tensor = only(derivs)
return val, maybe_reshape(jac_tensor, prep.output_length, length(x))
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,14 @@
tx::NTuple{1},
contexts::Vararg{DI.Context,C},
) where {F,C}
f!_and_df! = get_f_and_df(f!, backend)
mode = forward_noprimal(backend)
f!_and_df! = get_f_and_df(f!, backend, mode)

Check warning on line 24 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl#L23-L24

Added lines #L23 - L24 were not covered by tests
dx_sametype = convert(typeof(x), only(tx))
dy_sametype = make_zero(y)
x_and_dx = Duplicated(x, dx_sametype)
y_and_dy = Duplicated(y, dy_sametype)
autodiff(
forward_noprimal(backend),
f!_and_df!,
Const,
y_and_dy,
x_and_dx,
map(translate, contexts)...,
)
annotated_contexts = translate(backend, mode, Val(1), contexts...)
autodiff(mode, f!_and_df!, Const, y_and_dy, x_and_dx, annotated_contexts...)

Check warning on line 30 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl#L29-L30

Added lines #L29 - L30 were not covered by tests
return y, (dy_sametype,)
end

Expand All @@ -45,19 +40,14 @@
tx::NTuple{B},
contexts::Vararg{DI.Context,C},
) where {F,B,C}
f!_and_df! = get_f_and_df(f!, backend, Val(B))
mode = forward_noprimal(backend)
f!_and_df! = get_f_and_df(f!, backend, mode, Val(B))

Check warning on line 44 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl#L43-L44

Added lines #L43 - L44 were not covered by tests
tx_sametype = map(Fix1(convert, typeof(x)), tx)
ty_sametype = ntuple(_ -> make_zero(y), Val(B))
x_and_tx = BatchDuplicated(x, tx_sametype)
y_and_ty = BatchDuplicated(y, ty_sametype)
autodiff(
forward_noprimal(backend),
f!_and_df!,
Const,
y_and_ty,
x_and_tx,
map(translate, contexts)...,
)
annotated_contexts = translate(backend, mode, Val(B), contexts...)
autodiff(mode, f!_and_df!, Const, y_and_ty, x_and_tx, annotated_contexts...)

Check warning on line 50 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl#L49-L50

Added lines #L49 - L50 were not covered by tests
return y, ty_sametype
end

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,14 @@
ty::NTuple{1},
contexts::Vararg{DI.Context,C},
) where {F,C}
f_and_df = force_annotation(get_f_and_df(f, backend))
mode = reverse_split_withprimal(backend)
f_and_df = force_annotation(get_f_and_df(f, backend, mode))

Check warning on line 73 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl#L73

Added line #L73 was not covered by tests
IA = guess_activity(typeof(x), mode)
RA = guess_activity(eltype(ty), mode)
dx = make_zero(x)
annotated_contexts = translate(backend, mode, Val(1), contexts...)

Check warning on line 77 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl#L77

Added line #L77 was not covered by tests
dinputs, result = seeded_autodiff_thunk(
mode, only(ty), f_and_df, RA, annotate(IA, x, dx), map(translate, contexts)...
mode, only(ty), f_and_df, RA, annotate(IA, x, dx), annotated_contexts...
)
new_dx = first(dinputs)
if isnothing(new_dx)
Expand All @@ -93,13 +94,14 @@
ty::NTuple{B},
contexts::Vararg{DI.Context,C},
) where {F,B,C}
f_and_df = force_annotation(get_f_and_df(f, backend, Val(B)))
mode = reverse_split_withprimal(backend)
f_and_df = force_annotation(get_f_and_df(f, backend, mode, Val(B)))

Check warning on line 98 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl#L98

Added line #L98 was not covered by tests
IA = batchify_activity(guess_activity(typeof(x), mode), Val(B))
RA = batchify_activity(guess_activity(eltype(ty), mode), Val(B))
tx = ntuple(_ -> make_zero(x), Val(B))
annotated_contexts = translate(backend, mode, Val(B), contexts...)

Check warning on line 102 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl#L102

Added line #L102 was not covered by tests
dinputs, result = batch_seeded_autodiff_thunk(
mode, ty, f_and_df, RA, annotate(IA, x, tx), map(translate, contexts)...
mode, ty, f_and_df, RA, annotate(IA, x, tx), annotated_contexts...
)
new_tx = values(first(dinputs))
if isnothing(new_tx)
Expand Down Expand Up @@ -131,18 +133,14 @@
ty::NTuple{1},
contexts::Vararg{DI.Context,C},
) where {F,C}
f_and_df = force_annotation(get_f_and_df(f, backend))
mode = reverse_split_withprimal(backend)
f_and_df = force_annotation(get_f_and_df(f, backend, mode))

Check warning on line 137 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl#L137

Added line #L137 was not covered by tests
RA = guess_activity(eltype(ty), mode)
dx_righttype = convert(typeof(x), only(tx))
make_zero!(dx_righttype)
annotated_contexts = translate(backend, mode, Val(1), contexts...)

Check warning on line 141 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl#L141

Added line #L141 was not covered by tests
_, result = seeded_autodiff_thunk(
mode,
only(ty),
f_and_df,
RA,
Duplicated(x, dx_righttype),
map(translate, contexts)...,
mode, only(ty), f_and_df, RA, Duplicated(x, dx_righttype), annotated_contexts...
)
only(tx) === dx_righttype || copyto!(only(tx), dx_righttype)
return result, tx
Expand All @@ -157,18 +155,14 @@
ty::NTuple{B},
contexts::Vararg{DI.Context,C},
) where {F,B,C}
f_and_df = force_annotation(get_f_and_df(f, backend, Val(B)))
mode = reverse_split_withprimal(backend)
f_and_df = force_annotation(get_f_and_df(f, backend, mode, Val(B)))

Check warning on line 159 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl#L159

Added line #L159 was not covered by tests
RA = batchify_activity(guess_activity(eltype(ty), mode), Val(B))
tx_righttype = map(Fix1(convert, typeof(x)), tx)
make_zero!(tx_righttype)
annotated_contexts = translate(backend, mode, Val(B), contexts...)

Check warning on line 163 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl#L163

Added line #L163 was not covered by tests
_, result = batch_seeded_autodiff_thunk(
mode,
ty,
f_and_df,
RA,
BatchDuplicated(x, tx_righttype),
map(translate, contexts)...,
mode, ty, f_and_df, RA, BatchDuplicated(x, tx_righttype), annotated_contexts...
)
foreach(copyto!, tx, tx_righttype)
return result, tx
Expand Down Expand Up @@ -196,12 +190,13 @@
x,
contexts::Vararg{DI.Context,C},
) where {F,C}
f_and_df = get_f_and_df(f, backend)
mode = reverse_noprimal(backend)
f_and_df = get_f_and_df(f, backend, mode)

Check warning on line 194 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl#L194

Added line #L194 was not covered by tests
IA = guess_activity(typeof(x), mode)
grad = make_zero(x)
annotated_contexts = translate(backend, mode, Val(1), contexts...)

Check warning on line 197 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl#L197

Added line #L197 was not covered by tests
dinputs = only(
autodiff(mode, f_and_df, Active, annotate(IA, x, grad), map(translate, contexts)...)
autodiff(mode, f_and_df, Active, annotate(IA, x, grad), annotated_contexts...)
)
new_grad = first(dinputs)
if isnothing(new_grad)
Expand All @@ -217,12 +212,13 @@
x,
contexts::Vararg{DI.Context,C},
) where {F,C}
f_and_df = get_f_and_df(f, backend)
mode = reverse_withprimal(backend)
f_and_df = get_f_and_df(f, backend, mode)

Check warning on line 216 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl#L216

Added line #L216 was not covered by tests
IA = guess_activity(typeof(x), mode)
grad = make_zero(x)
annotated_contexts = translate(backend, mode, Val(1), contexts...)

Check warning on line 219 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl#L219

Added line #L219 was not covered by tests
dinputs, result = autodiff(
mode, f_and_df, Active, annotate(IA, x, grad), map(translate, contexts)...
mode, f_and_df, Active, annotate(IA, x, grad), annotated_contexts...
)
new_grad = first(dinputs)
if isnothing(new_grad)
Expand Down Expand Up @@ -263,16 +259,12 @@
x,
contexts::Vararg{DI.Context,C},
) where {F,C}
f_and_df = get_f_and_df(f, backend)
mode = reverse_noprimal(backend)
f_and_df = get_f_and_df(f, backend, mode)

Check warning on line 263 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl#L262-L263

Added lines #L262 - L263 were not covered by tests
grad_righttype = grad isa typeof(x) ? grad : prep.grad_righttype
make_zero!(grad_righttype)
autodiff(
reverse_noprimal(backend),
f_and_df,
Active,
Duplicated(x, grad_righttype),
map(translate, contexts)...,
)
annotated_contexts = translate(backend, mode, Val(1), contexts...)
autodiff(mode, f_and_df, Active, Duplicated(x, grad_righttype), annotated_contexts...)

Check warning on line 267 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl#L266-L267

Added lines #L266 - L267 were not covered by tests
grad === grad_righttype || copyto!(grad, grad_righttype)
return grad
end
Expand All @@ -295,15 +287,13 @@
x,
contexts::Vararg{DI.Context,C},
) where {F,C}
f_and_df = get_f_and_df(f, backend)
mode = reverse_withprimal(backend)
f_and_df = get_f_and_df(f, backend, mode)

Check warning on line 291 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl#L290-L291

Added lines #L290 - L291 were not covered by tests
grad_righttype = grad isa typeof(x) ? grad : prep.grad_righttype
make_zero!(grad_righttype)
annotated_contexts = translate(backend, mode, Val(1), contexts...)

Check warning on line 294 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl#L294

Added line #L294 was not covered by tests
_, y = autodiff(
reverse_withprimal(backend),
f_and_df,
Active,
Duplicated(x, grad_righttype),
map(translate, contexts)...,
mode, f_and_df, Active, Duplicated(x, grad_righttype), annotated_contexts...
)
grad === grad_righttype || copyto!(grad, grad_righttype)
return y, grad
Expand Down
Loading
Loading