Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding mutable checkpointing using Enzyme #8

Merged
merged 6 commits into from
Mar 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
name = "Checkpointing"
uuid = "eb46d486-4f9c-4c3d-b445-a617f2a2f1ca"
authors = ["Michel Schanen <mschanen@anl.gov>", "Sri Hari Krishna Narayanan <snarayan@mcs.anl.gov>"]
version = "0.1.0"
version = "0.2.0"

[deps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[compat]
julia = "1.6"
Enzyme = "0.9"

[extras]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Expand Down
86 changes: 86 additions & 0 deletions examples/mutable/optcontrol.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# This is a Julia version of the solution of the optimal control problem,
# based on code written by Andrea Walther. See:
# Walther, Andrea, and Narayanan, Sri Hari Krishna. Extending the Binomial Checkpointing
# Technique for Resilience. United States: N. p., 2016. https://www.osti.gov/biblio/1364654.

using Checkpointing


include("optcontrolfunc.jl")

# Print a two-part banner to stdout: the statement of the optimal control
# problem with its adjoint equations, followed by the closed-form optimal
# solution. Takes no arguments and returns nothing.
function header()
    println("**************************************************************************")
    println("* Solution of the optimal control problem *")
    println("* *")
    println("* J(y) = y_2(1) -> min *")
    println("* s.t. dy_1/dt = 0.5*y_1(t) + u(t), y_1(0)=1 *")
    println("* dy_2/dt = y_1(t)^2 + 0.5*u(t)^2 y_2(0)=0 *")
    println("* *")
    println("* the adjoints equations fulfill *")
    println("* *")
    println("* dl_1/dt = -0.5*l_1(t) - 2*y_1(t)*l_2(t) l_1(1)=0 *")
    println("* dl_2/dt = 0 l_2(1)=1 *")
    println("* *")
    println("* with Revolve for Online and (Multi-Stage) Offline Checkpointing *")
    println("* *")
    println("**************************************************************************")

    println("**************************************************************************")
    println("* The solution of the optimal control problem above is *")
    println("* *")
    println("* y_1*(t) = (2*e^(3t)+e^3)/(e^(3t/2)*(2+e^3)) *")
    println("* y_2*(t) = (2*e^(3t)-e^(6-3t)-2+e^6)/((2+e^3)^2) *")
    # The numerator carries 2*e^3 (not e^3): this is what func_U evaluates,
    # and it equals -l_1*(t) as printed on the next line.
    println("* u*(t) = (2*e^(3t)-2*e^3)/(e^(3t/2)*(2+e^3)) *")
    println("* l_1*(t) = (2*e^(3-t)-2*e^(2t))/(e^(t/2)*(2+e^3)) *")
    println("* l_2*(t) = 1 *")
    println("* *")
    println("**************************************************************************")

    return
end

# Run the optimal-control example on a mutable `Model` with checkpointing and
# print the computed state/adjoint next to the closed-form values.
#
# Arguments:
#   scheme : checkpointing scheme object handed to @checkpoint_mutable
#   steps  : number of time steps; the step size is 1.0/steps
#
# Returns the tuple (model, shadowmodel, F_opt, L_opt), where model.F holds the
# final state, shadowmodel.F the computed adjoint, and F_opt / L_opt the
# closed-form values from opt_sol(., 1.0) and opt_lambda(., 0.0).
function muoptcontrol(scheme, steps)
    # Usage banner only; nothing is read from stdin in this function.
    println( "\n STEPS -> number of time steps to perform")
    println("SNAPS -> number of checkpoints")
    println("INFO = 1 -> calculate only approximate solution")
    println("INFO = 2 -> calculate approximate solution + takeshots")
    println("INFO = 3 -> calculate approximate solution + all information ")
    println(" ENTER: STEPS, SNAPS, INFO \n")


    # F : output
    # F_H : input
    # L : seed the output adjoint
    # L_H : set input adjoint to 0
    F = [1.0, 0.0]
    F_H = [0.0, 0.0]
    L = [0.0, 1.0]
    L_H = [0.0, 0.0]
    t = 0.0
    h = 1.0/steps
    model = Model(F, F_H, t, h)
    # t and h are not active so, set their adjoints to zero.
    shadowmodel = Model(L, L_H, 0.0, 0.0)

    # `adtool` is not a variable here: it reaches the macro as a bare symbol,
    # and the @checkpoint_mutable expansion never evaluates it.
    @checkpoint_mutable scheme adtool model shadowmodel for i in 1:steps
        model.F_H .= model.F
        advance(model)
        model.t += h
    end

    # The macro rebinds `model` to fresh copies, so re-read the vectors
    # instead of relying on the arrays created above.
    F = model.F
    L = shadowmodel.F

    F_opt = Array{Float64, 1}(undef, 2)
    L_opt = Array{Float64, 1}(undef, 2)
    opt_sol(F_opt,1.0)
    opt_lambda(L_opt,0.0)
    println("\n\n")
    println("y_1*(1) = " , F_opt[1] , " y_2*(1) = " , F_opt[2])
    println("y_1 (1) = " , F[1] , " y_2 (1) = " , F[2] , " \n\n")
    println("l_1*(0) = " , L_opt[1] , " l_2*(0) = " , L_opt[2])
    # Label fixed from "sl_2" to "l_2" to match the three lines above.
    println("l_1 (0) = " , L[1] , " l_2 (0) = " , L[2] , " ")
    return model, shadowmodel, F_opt, L_opt
end

# main(10,3,3)
46 changes: 46 additions & 0 deletions examples/mutable/optcontrolfunc.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
using Enzyme
# Mutable container for the integrator state used by `advance`.
#   F   : current state vector (overwritten in place by `advance`)
#   F_H : state at the start of the step (read by `advance`)
#   t   : current time
#   h   : time-step size
mutable struct Model
    F   :: Vector{Float64}
    F_H :: Vector{Float64}
    t   :: Float64
    h   :: Float64
end

# Evaluate the control u(t) = 2*(e^(3t) - e^3) / (e^(3t/2) * (2 + e^3)).
function func_U(t)
    e = exp(1)
    e_cubed = e^3
    numerator = 2.0 * (e^(3.0 * t) - e_cubed)
    denominator = e^(3.0 * t / 2.0) * (2.0 + e_cubed)
    return numerator / denominator
end

# Right-hand side of the ODE system, written into F:
#   F[1] = 0.5*X[1] + u(t)
#   F[2] = X[1]^2 + 0.5*u(t)^2
# `advance` calls this with F and X being the same array, so X[1] is read
# before anything is written and F[2] is still stored before F[1].
function func(F, X,t)
    x1 = X[1]
    u = func_U(t)
    F[2] = x1 * x1 + 0.5 * (u * u)
    F[1] = 0.5 * x1 + u
    return nothing
end

# Advance model.F by one step of size model.h starting from model.F_H:
# evaluate `func` at (F_H, t), form the half-step state, evaluate `func`
# again at t + h/2, then combine with the full step. Only the first two
# components are touched; model.F is updated in place and nothing is returned.
function advance(model)
    prev = model.F_H
    state = model.F
    t0 = model.t
    dt = model.h
    half = dt / 2.0

    # Slope at the start of the step, stored in `state`.
    func(state, prev, t0)
    for k in 1:2
        state[k] = prev[k] + half * state[k]
    end

    # Slope at the half-step state; input and output alias on purpose.
    func(state, state, t0 + half)
    for k in 1:2
        state[k] = prev[k] + dt * state[k]
    end
    return nothing
end

# Write the closed-form optimal state at time t into Y:
#   Y[1] = (2*e^(3t) + e^3) / (e^(3t/2) * (2 + e^3))
#   Y[2] = (2*e^(3t) - e^(6-3t) - 2 + e^6) / (2 + e^3)^2
# Y is modified in place; nothing is returned.
function opt_sol(Y,t)
    e = exp(1)
    e_cubed = e^3
    growth = e^(3.0 * t)
    Y[1] = (2.0 * growth + e_cubed) / (e^(3.0 * t / 2.0) * (2.0 + e_cubed))
    Y[2] = (2.0 * growth - e^(6.0 - 3.0 * t) - 2.0 + e^6) / ((2.0 + e_cubed)^2)
    return nothing
end

# Write the closed-form optimal adjoint at time t into L:
#   L[1] = (2*e^(3-t) - 2*e^(2t)) / (e^(t/2) * (2 + e^3))
#   L[2] = 1
# L is modified in place; nothing is returned.
function opt_lambda(L,t)
    e = exp(1)
    numerator = 2.0 * e^(3 - t) - 2.0 * e^(2.0 * t)
    denominator = e^(t / 2.0) * (2 + e^3)
    L[1] = numerator / denominator
    L[2] = 1.0
    return nothing
end
78 changes: 73 additions & 5 deletions src/Checkpointing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ function jacobian(tobedifferentiated, F_H, ::AbstractADTool)
error("No AD tool interface implemented")
end

export AbstractADTool, jacobian

export AbstractADTool, jacobian, @checkpoint, @checkpoint_mutable

include("Schemes/Revolve.jl")
include("Schemes/Periodic.jl")
Expand Down Expand Up @@ -78,8 +77,6 @@ macro checkpoint(alg, adtool, forloop)
L .= [0, 1]
t = 1.0-h
L_H .= L
lF = length(F)
lF_H = length(F_H)
L = Checkpointing.jacobian(tobedifferentiated, F_H, $adtool)[2,:]
elseif (next_action.actionflag == Checkpointing.uturn)
L_H .= L
Expand Down Expand Up @@ -138,6 +135,77 @@ macro checkpoint(alg, adtool, forloop)
esc(ex)
end

export @checkpoint
# Expand a `for` loop acting on a mutable `model` into a checkpointed forward
# run followed by a reverse sweep that calls Enzyme.autodiff once per step.
#
# Arguments (all are expressions spliced into the expansion):
#   alg         : scheme object; only `Revolve` and `Periodic` are handled
#   adtool      : accepted for signature parity with @checkpoint, but never
#                 referenced in the expansion below
#   model       : variable holding the primal state; it is rebound to deep
#                 copies during the sweep and ends as a copy of the final state
#   shadowmodel : variable holding the adjoint seed; passed to Enzyme as the
#                 shadow of `model`
#   forloop     : the loop expression; only its body (forloop.args[2]) is used,
#                 the iteration range is ignored and step counts come from `alg`
#
# The whole expansion is `esc`aped, so every name assigned here
# (tobedifferentiated, storemap, check, MT, model_check, model_final, ...)
# lands in the caller's scope.
macro checkpoint_mutable(alg, adtool, model, shadowmodel, forloop)
ex = quote
# One loop iteration wrapped as a function of the model, for Enzyme.
function tobedifferentiated($model)
$(forloop.args[2])
return nothing
end
# NOTE(review): no `else` branch, so a scheme of any other type makes the
# expansion a silent no-op -- confirm that this is intended.
if isa($alg, Revolve)
# storemap: key (iteration - 1) -> slot index in model_check.
storemap = Dict{Int32,Int32}()
# check: number of live checkpoints; also the slot used by the next store.
check = 0
MT = typeof($model)
model_check = Array{MT}(undef, $alg.acp)
model_final = deepcopy($model)
# Execute the actions returned by next_action! until `done`.
while true
next_action = next_action!($alg)
if (next_action.actionflag == Checkpointing.store)
check = check+1
storemap[next_action.iteration-1]=check
model_check[check] = deepcopy($model)
elseif (next_action.actionflag == Checkpointing.forward)
# Plain forward steps; `j` only counts, the body does not receive it.
for j= next_action.startiteration:(next_action.iteration - 1)
$(forloop.args[2])
end
elseif (next_action.actionflag == Checkpointing.firstuturn)
# Run the step, remember the final primal state, then differentiate.
# NOTE(review): autodiff is invoked on the model *after* the body has
# already advanced it, so the differentiated step appears to start from
# the post-step state (the Periodic branch restores the pre-step copy
# first). Confirm that this linearization point is intended.
$(forloop.args[2])
model_final = deepcopy($model)
Enzyme.autodiff(tobedifferentiated, Duplicated($model,$shadowmodel))
elseif (next_action.actionflag == Checkpointing.uturn)
Enzyme.autodiff(tobedifferentiated, Duplicated($model,$shadowmodel))
# Release the checkpoint registered under key (iteration - 2), if any.
if haskey(storemap,next_action.iteration-1-1)
delete!(storemap,next_action.iteration-1-1)
check=check-1
end
elseif (next_action.actionflag == Checkpointing.restore)
# Rebind `model` to a fresh copy so the stored checkpoint stays intact.
$model = deepcopy(model_check[storemap[next_action.iteration-1]])
elseif next_action.actionflag == Checkpointing.done
if haskey(storemap,next_action.iteration-1-1)
delete!(storemap,next_action.iteration-1-1)
check=check-1
end
break
end
end
# Leave the caller with the final primal state, not a restored checkpoint.
$model = deepcopy(model_final)
elseif isa($alg, Periodic)
MT = typeof($model)
model_check_outer = Array{MT}(undef, $alg.acp)
model_check_inner = Array{MT}(undef, $alg.period)
model_final = deepcopy($model)
# `check` is assigned but not read anywhere in this branch.
check = 0
# Forward sweep: one outer checkpoint per block of `period` steps.
for i = 1:$alg.acp
model_check_outer[i] = deepcopy($model)
for j= (i-1)*$alg.period: (i)*$alg.period-1
$(forloop.args[2])
end
end
model_final = deepcopy($model)
# Reverse sweep: blocks from last to first; each block is recomputed while
# storing every step's input, then differentiated step by step backwards.
for i = $alg.acp:-1:1
$model = deepcopy(model_check_outer[i])
for j= 1:$alg.period
model_check_inner[j] = deepcopy($model)
$(forloop.args[2])
end
for j= $alg.period:-1:1
$model = deepcopy(model_check_inner[j])
Enzyme.autodiff(tobedifferentiated, Duplicated($model,$shadowmodel))
end
end
$model = deepcopy(model_final)
end
end
esc(ex)
end

end
14 changes: 11 additions & 3 deletions src/Schemes/Periodic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,19 @@ mutable struct Periodic <: Scheme
acp::Int
period::Int
verbose::Int
fstore::Function
frestore::Function
fstore::Union{Function,Nothing}
frestore::Union{Function,Nothing}
end

function Periodic(steps::Int, checkpoints::Int, fstore::Function, frestore::Function; anActionInstance::Union{Nothing,Action} = nothing, bundle_::Union{Nothing,Int} = nothing, verbose::Int = 0)
function Periodic(
steps::Int,
checkpoints::Int,
fstore::Union{Function,Nothing} = nothing,
frestore::Union{Function,Nothing} = nothing;
anActionInstance::Union{Nothing,Action} = nothing,
bundle_::Union{Nothing,Int} = nothing,
verbose::Int = 0
)
if !isa(anActionInstance, Nothing)
# same as default init above
anActionInstance.actionflag = 0
Expand Down
14 changes: 11 additions & 3 deletions src/Schemes/Revolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,19 @@ mutable struct Revolve <: Scheme
firstuturned::Bool
stepof::Vector{Int}
verbose::Int
fstore::Function
frestore::Function
fstore::Union{Function,Nothing}
frestore::Union{Function,Nothing}
end

function Revolve(steps::Int, checkpoints::Int, fstore::Function, frestore::Function; anActionInstance::Union{Nothing,Action} = nothing, bundle_::Union{Nothing,Int} = nothing, verbose::Int = 0)
function Revolve(
steps::Int,
checkpoints::Int,
fstore::Union{Function,Nothing} = nothing,
frestore::Union{Function,Nothing} = nothing;
anActionInstance::Union{Nothing,Action} = nothing,
bundle_::Union{Nothing,Int} = nothing,
verbose::Int = 0
)
if !isa(anActionInstance, Nothing)
# same as default init above
anActionInstance.actionflag = 0
Expand Down
19 changes: 19 additions & 0 deletions test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using Checkpointing
using LinearAlgebra
using Enzyme
using Test

include("examples/mutable/optcontrol.jl")

# Driver for the mutable checkpointing example: set the global run parameters,
# build a Revolve scheme, run muoptcontrol, and return the final state vector
# and the adjoint vector (model.F, shadowmodel.F).
function chkmutable()
    global steps, snaps, info
    steps = 100
    snaps = 3
    info = 0

    scheme = Revolve(steps, snaps; verbose=info)

    # muoptcontrol returns four values; only the two models are needed here.
    result = muoptcontrol(scheme, steps)
    return result[1].F, result[2].F
end

mF, mL = chkmutable()
Loading