Skip to content

Commit

Permalink
allow out of place function definitions
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Aug 22, 2020
1 parent 13886fc commit 5f178dc
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 4 deletions.
19 changes: 15 additions & 4 deletions src/DiffEqGPU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,23 @@ using ForwardDiff
@views @inbounds f(du[:,i],u[:,i],p[:,i],t)
end

@kernel function gpu_kernel_oop(@Const(f),du,@Const(u),@Const(p),@Const(t))
i = @index(Global, Linear)
@views @inbounds du[:,i] = f(u[:,i],p[:,i],t)
end

@kernel function jac_kernel(@Const(f),J,@Const(u),@Const(p),@Const(t))
i = @index(Global, Linear)-1
section = 1 + (i*size(u,1)) : ((i+1)*size(u,1))
@views @inbounds f(J[section,section],u[:,i+1],p[:,i+1],t)
end

@kernel function jac_kernel_oop(@Const(f),J,@Const(u),@Const(p),@Const(t))
i = @index(Global, Linear)-1
section = 1 + (i*size(u,1)) : ((i+1)*size(u,1))
@views @inbounds J[section,section] = f(u[:,i+1],p[:,i+1],t)
end

@kernel function discrete_condition_kernel(@Const(condition),cur,@Const(u),@Const(t),@Const(p))
i = @index(Global, Linear)
@views @inbounds cur[i] = condition(u[:,i],t,FakeIntegrator(u[:,i],t,p[:,i]))
Expand Down Expand Up @@ -207,11 +218,11 @@ function batch_solve(ensembleprob,alg,ensemblealg,I;kwargs...)
end

function generate_problem(prob::ODEProblem,u0,p,jac_prototype,colorvec)
_f = let f=prob.f.f
_f = let f=prob.f.f, kernel = DiffEqBase.isinplace(prob) ? gpu_kernel : gpu_kernel_oop
function (du,u,p,t)
version = u isa CuArray ? CUDADevice() : CPU()
wgs = workgroupsize(version,size(u,2))
wait(version, gpu_kernel(version)(f,du,u,p,t;ndrange=size(u,2),
wait(version, kernel(version)(f,du,u,p,t;ndrange=size(u,2),
dependencies=Event(version),
workgroupsize=wgs))
end
Expand Down Expand Up @@ -272,11 +283,11 @@ function generate_problem(prob::ODEProblem,u0,p,jac_prototype,colorvec)
end

function generate_problem(prob::SDEProblem,u0,p,jac_prototype,colorvec)
_f = let f=prob.f.f
_f = let f=prob.f.f, kernel = DiffEqBase.isinplace(prob) ? gpu_kernel : gpu_kernel_oop
function (du,u,p,t)
version = u isa CuArray ? CUDADevice() : CPU()
wgs = workgroupsize(version,size(u,2))
wait(version, gpu_kernel(version)(f,du,u,p,t;
wait(version, kernel(version)(f,du,u,p,t;
ndrange=size(u,2),
dependencies=Event(version),
workgroupsize=wgs))
Expand Down
18 changes: 18 additions & 0 deletions test/ensemblegpuarray_oop.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
using DiffEqGPU, OrdinaryDiffEq, StaticArrays

function lorenz(u,p,t)
@inbounds begin
du1 = p[1]*(u[2]-u[1])
du2 = u[1]*(p[2]-u[3]) - u[2]
du3 = u[1]*u[2] - p[3]*u[3]
SA[du1,du2,du3]
end
end

u0 = SA[1f0;0f0;0f0]
tspan = (0.0f0,100.0f0)
p = SA[10.0f0,28.0f0,8/3f0]
prob = ODEProblem(lorenz,u0,tspan,p)
prob_func = (prob,i,repeat) -> remake(prob,p=rand(Float32,3).*p)
monteprob = EnsembleProblem(prob, prob_func = prob_func, safetycopy=false)
@time sol = solve(monteprob,Tsit5(),EnsembleGPUArray(),trajectories=10_000,saveat=1.0f0)
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ end
using SafeTestsets, Test

@time @safetestset "EnsembleGPUArray" begin include("ensemblegpuarray.jl") end
@time @safetestset "EnsembleGPUArray OOP" begin include("ensemblegpuarray_oop.jl") end
@time @safetestset "EnsembleGPUArray SDE" begin include("ensemblegpuarray_sde.jl") end
@time @safetestset "EnsembleGPUArray Input Types" begin include("ensemblegpuarray_inputtypes.jl") end
@time @safetestset "Reduction" begin include("reduction.jl") end
Expand Down

0 comments on commit 5f178dc

Please sign in to comment.