Skip to content

Commit

Permalink
Merge pull request #81 from aml5600/andrew/callbackset
Browse files Browse the repository at this point in the history
Expand callbacks to include/use CallbackSet
  • Loading branch information
ChrisRackauckas authored Sep 15, 2020
2 parents 0a88484 + 82cdce0 commit 0f4946f
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 64 deletions.
144 changes: 82 additions & 62 deletions src/DiffEqGPU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ function batch_solve(ensembleprob,alg,ensemblealg::EnsembleArrayAlgorithm,I;kwar
colorvec = nothing
end

_callback = generate_callback(probs[1],length(I),ensemblealg)
_callback = generate_callback(probs[1],length(I),ensemblealg; kwargs...)
prob = generate_problem(probs[1],u0,p,jac_prototype,colorvec)

if hasproperty(alg, :linsolve)
Expand All @@ -235,9 +235,8 @@ function batch_solve(ensembleprob,alg,ensemblealg::EnsembleArrayAlgorithm,I;kwar
_alg = alg
end

sol = solve(prob,_alg; callback = _callback,merge_callbacks = false,
internalnorm=diffeqgpunorm,
kwargs...)
sol = solve(prob,_alg; kwargs..., callback = _callback,merge_callbacks = false,
internalnorm=diffeqgpunorm)

us = Array.(sol.u)
solus = [[@view(us[i][:,j]) for i in 1:length(us)] for j in 1:length(probs)]
Expand Down Expand Up @@ -386,74 +385,95 @@ function generate_problem(prob::SDEProblem,u0,p,jac_prototype,colorvec)
prob.kwargs...)
end

function generate_callback(prob,I,ensemblealg)
if :callback keys(prob.kwargs)
_callback = nothing
elseif prob.kwargs[:callback] isa DiscreteCallback
if ensemblealg isa EnsembleGPUArray
cur = CuArray([false for i in 1:I])
else
cur = [false for i in 1:I]
end
_condition = prob.kwargs[:callback].condition
_affect! = prob.kwargs[:callback].affect!
function generate_callback(callback::DiscreteCallback,I,ensemblealg)
if ensemblealg isa EnsembleGPUArray
cur = CuArray([false for i in 1:I])
else
cur = [false for i in 1:I]
end
_condition = callback.condition
_affect! = callback.affect!

condition = function (u,t,integrator)
version = u isa CuArray ? CUDADevice() : CPU()
wgs = workgroupsize(version,size(u,2))
wait(version, discrete_condition_kernel(version)(_condition,cur,u,t,integrator.p;
ndrange=size(u,2),
dependencies=Event(version),
workgroupsize=wgs))
any(cur)
end

condition = function (u,t,integrator)
version = u isa CuArray ? CUDADevice() : CPU()
wgs = workgroupsize(version,size(u,2))
wait(version, discrete_condition_kernel(version)(_condition,cur,u,t,integrator.p;
ndrange=size(u,2),
dependencies=Event(version),
workgroupsize=wgs))
any(cur)
end
affect! = function (integrator)
version = integrator.u isa CuArray ? CUDADevice() : CPU()
wgs = workgroupsize(version,size(integrator.u,2))
wait(version, discrete_affect!_kernel(version)(_affect!,cur,integrator.u,integrator.t,integrator.p;
ndrange=size(integrator.u,2),
dependencies=Event(version),
workgroupsize=wgs))
end

affect! = function (integrator)
version = integrator.u isa CuArray ? CUDADevice() : CPU()
wgs = workgroupsize(version,size(integrator.u,2))
wait(version, discrete_affect!_kernel(version)(_affect!,cur,integrator.u,integrator.t,integrator.p;
ndrange=size(integrator.u,2),
return DiscreteCallback(condition,affect!,save_positions=callback.save_positions)
end

function generate_callback(callback::ContinuousCallback,I,ensemblealg)
_condition = callback.condition
_affect! = callback.affect!
_affect_neg! = callback.affect_neg!

condition = function (out,u,t,integrator)
version = u isa CuArray ? CUDADevice() : CPU()
wgs = workgroupsize(version,size(u,2))
wait(version, continuous_condition_kernel(version)(_condition,out,u,t,integrator.p;
ndrange=size(u,2),
dependencies=Event(version),
workgroupsize=wgs))
end
nothing
end

_callback = DiscreteCallback(condition,affect!,save_positions=prob.kwargs[:callback].save_positions)
elseif prob.kwargs[:callback] isa ContinuousCallback
_condition = prob.kwargs[:callback].condition
_affect! = prob.kwargs[:callback].affect!
_affect_neg! = prob.kwargs[:callback].affect_neg!
affect! = function (integrator,event_idx)
version = integrator.u isa CuArray ? CUDADevice() : CPU()
wgs = workgroupsize(version,size(integrator.u,2))
wait(version, continuous_affect!_kernel(version)(_affect!,event_idx,integrator.u,integrator.t,integrator.p;
ndrange=size(integrator.u,2),
dependencies=Event(version),
workgroupsize=wgs))
end

condition = function (out,u,t,integrator)
version = u isa CuArray ? CUDADevice() : CPU()
wgs = workgroupsize(version,size(u,2))
wait(version, continuous_condition_kernel(version)(_condition,out,u,t,integrator.p;
ndrange=size(u,2),
dependencies=Event(version),
workgroupsize=wgs))
nothing
end
affect_neg! = function (integrator,event_idx)
version = integrator.u isa CuArray ? CUDADevice() : CPU()
wgs = workgroupsize(version,size(integrator.u,2))
wait(version, continuous_affect!_kernel(version)(_affect_neg!,event_idx,integrator.u,integrator.t,integrator.p;
ndrange=size(integrator.u,2),
dependencies=Event(version),
workgroupsize=wgs))
end

affect! = function (integrator,event_idx)
version = integrator.u isa CuArray ? CUDADevice() : CPU()
wgs = workgroupsize(version,size(integrator.u,2))
wait(version, continuous_affect!_kernel(version)(_affect!,event_idx,integrator.u,integrator.t,integrator.p;
ndrange=size(integrator.u,2),
dependencies=Event(version),
workgroupsize=wgs))
end
return VectorContinuousCallback(condition,affect!,affect_neg!,I,save_positions=callback.save_positions)
end

affect_neg! = function (integrator,event_idx)
version = integrator.u isa CuArray ? CUDADevice() : CPU()
wgs = workgroupsize(version,size(integrator.u,2))
wait(version, continuous_affect!_kernel(version)(_affect_neg!,event_idx,integrator.u,integrator.t,integrator.p;
ndrange=size(integrator.u,2),
dependencies=Event(version),
workgroupsize=wgs))
end
function generate_callback(callback::CallbackSet,I,ensemblealg)
return CallbackSet(map(cb->generate_callback(cb,I,ensemblealg),
(callback.continuous_callbacks..., callback.discrete_callbacks...))...)
end

_callback = VectorContinuousCallback(condition,affect!,affect_neg!,I,save_positions=prob.kwargs[:callback].save_positions)
generate_callback(::Tuple{},I,ensemblealg) = nothing

function generate_callback(x)
# will catch any VectorContinuousCallbacks
error("Callback unsupported")
end

function generate_callback(prob,I,ensemblealg; kwargs...)
prob_cb = get(prob.kwargs, :callback, ())
kwarg_cb = get(kwargs, :merge_callbacks, false) ? get(kwargs, :callback, ()) : ()

if isempty(prob_cb) && isempty(kwarg_cb)
return nothing
else
return CallbackSet(generate_callback(prob_cb,I,ensemblealg),
generate_callback(kwarg_cb,I,ensemblealg))
end
_callback
end

### GPU Factorization
Expand Down
19 changes: 17 additions & 2 deletions test/ensemblegpuarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ affect! = function (integrator)
@inbounds integrator.u[1] = -4
end

callback_prob = ODEProblem(lorenz,u0,tspan,p,callback=DiscreteCallback(condition,affect!,save_positions=(false,false)))
# test discrete
discrete_callback = DiscreteCallback(condition,affect!,save_positions=(false,false))
callback_prob = ODEProblem(lorenz,u0,tspan,p,callback=discrete_callback)
callback_monteprob = EnsembleProblem(callback_prob, prob_func = prob_func)
@time solve(callback_monteprob,Tsit5(),EnsembleGPUArray(),trajectories=10,saveat=1.0f0)

Expand All @@ -101,10 +103,23 @@ c_affect! = function (integrator)
@inbounds integrator.u[1] += 20
end

callback_prob = ODEProblem(lorenz,u0,tspan,p,callback=ContinuousCallback(c_condition,c_affect!,save_positions=(false,false)))
# test continuous
continuous_callback = ContinuousCallback(c_condition,c_affect!,save_positions=(false,false))
callback_prob = ODEProblem(lorenz,u0,tspan,p,callback=continuous_callback)
callback_monteprob = EnsembleProblem(callback_prob, prob_func = prob_func)
solve(callback_monteprob,Tsit5(),EnsembleGPUArray(),trajectories=2,saveat=1.0f0)

# test callback set
callback_set = CallbackSet(discrete_callback, continuous_callback)
callback_prob = ODEProblem(lorenz,u0,tspan,p,callback=callback_set)
callback_monteprob = EnsembleProblem(callback_prob, prob_func = prob_func)
solve(callback_monteprob,Tsit5(),EnsembleGPUArray(),trajectories=2,saveat=1.0f0)

# test merge
callback_prob = ODEProblem(lorenz,u0,tspan,p,callback=discrete_callback)
callback_monteprob = EnsembleProblem(callback_prob, prob_func = prob_func)
solve(callback_monteprob,Tsit5(),EnsembleGPUArray(),trajectories=2,saveat=1.0f0, callback=continuous_callback)

@info "ROBER"

#=
Expand Down

0 comments on commit 0f4946f

Please sign in to comment.