Skip to content

Commit

Permalink
fixes #44
Browse files Browse the repository at this point in the history
  • Loading branch information
michel2323 committed May 2, 2024
1 parent 5eeaeb7 commit 0bf73fd
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 6 deletions.
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ end
function sumheat(heat::Heat, chkpscheme::Scheme, tsteps::Int64)
# AD: Create shadow copy for derivatives
@checkpoint_struct chkpscheme heat for i in 1:tsteps
# checkpoint_struct_for(advance, heat)
heat.Tlast .= heat.Tnext
advance(heat)
end
Expand All @@ -87,7 +86,7 @@ function heat(scheme::Scheme, tsteps::Int)
heat.Tnext[end] = 0

# Compute gradient
autodiff(Enzyme.ReverseWithPrimal, sumheat, Duplicated(heat, dheat), scheme, tsteps)
autodiff(Enzyme.ReverseWithPrimal, sumheat, Duplicated(heat, dheat), Const(scheme), Const(tsteps))

return heat.Tnext, dheat.Tnext[2:end-1]
end
Expand Down
4 changes: 2 additions & 2 deletions examples/optcontrol.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# Technique for Resilience. United States: N. p., 2016. https://www.osti.gov/biblio/1364654.

using Checkpointing
using Zygote
using Enzyme


include("optcontrolfunc.jl")
Expand Down Expand Up @@ -69,7 +69,7 @@ function muoptcontrol(scheme, steps, ::EnzymeTool)
end
return model.F[2]
end
autodiff(Enzyme.ReverseWithPrimal, foo, Duplicated(model, bmodel))
autodiff(Enzyme.Reverse, foo, Duplicated(model, bmodel))

F = model.F
L = bmodel.F
Expand Down
9 changes: 7 additions & 2 deletions src/Rules/EnzymeRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ function augmented_primal(
model,
range,
)
primal = func.val(body.val, alg.val, deepcopy(model.val), range.val)
if needs_primal(config)
primal = func.val(body.val, alg.val, model.val, range.val)
return AugmentedReturn(primal, nothing, (model.val,))
else
return AugmentedReturn(nothing, nothing, (model.val,))
Expand Down Expand Up @@ -50,8 +50,13 @@ function augmented_primal(
model,
condition,
)
primal = func.val(body.val, alg.val, deepcopy(model.val), condition.val)
if needs_primal(config)
return AugmentedReturn(func.val(body.val, alg.val, model.val, condition.val), nothing, (model.val,))
return AugmentedReturn(
primal,
nothing,
(model.val,),
)
else
return AugmentedReturn(nothing, nothing, (model.val,))
end
Expand Down

0 comments on commit 0bf73fd

Please sign in to comment.