Skip to content

Commit

Permalink
Merge branch 'master' into continuous
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Aug 22, 2020
2 parents 48a8b4e + fd60da8 commit 10de5ca
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 18 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DiffEqGPU"
uuid = "071ae1c0-96b5-11e9-1965-c90190d839ea"
authors = ["Chris Rackauckas", "JuliaDiffEq"]
version = "1.5.0"
version = "1.6.0"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# DiffEqGPU

[![GitlabCI](https://gitlab.com/JuliaGPU/DiffEqGPU-jl/badges/master/pipeline.svg)](https://gitlab.com/JuliaGPU/DiffEqGPU-jl/pipelines)
[![GitlabCI](https://gitlab.com/JuliaGPU/DiffEqGPU.jl/badges/master/pipeline.svg)](https://gitlab.com/JuliaGPU/DiffEqGPU.jl/pipelines)

This library is a component package of the DifferentialEquations.jl ecosystem. It includes functionality for making
use of GPUs in the differential equation solvers.
Expand Down Expand Up @@ -52,7 +52,7 @@ tspan = (0.0f0,100.0f0)
p = [10.0f0,28.0f0,8/3f0]
prob = ODEProblem(lorenz,u0,tspan,p)
prob_func = (prob,i,repeat) -> remake(prob,p=rand(Float32,3).*p)
monteprob = EnsembleProblem(prob, prob_func = prob_func)
monteprob = EnsembleProblem(prob, prob_func = prob_func, safetycopy=false)
@time sol = solve(monteprob,Tsit5(),EnsembleGPUArray(),trajectories=10_000,saveat=1.0f0)
```

Expand All @@ -70,7 +70,7 @@ Not everything is supported yet, but most of the standard features have support,
#### Current Limitations

If you receive a compilation error, it is likely because something is not allowed by the automated
kernel building of [GPUifyLoops.jl](https://github.com/vchuravy/GPUifyLoops.jl). The most common issues are:
kernel building of [KernelAbstractions.jl](https://github.com/JuliaGPU/KernelAbstractions.jl). The most common issues are:

- Bounds checking is not allowed
- Return values are not allowed
Expand Down
34 changes: 20 additions & 14 deletions src/DiffEqGPU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,28 @@ using CUDA: CuPtr, CU_NULL, Mem, CuDefaultStream
using CUDA: CUBLAS
using ForwardDiff

@kernel function gpu_kernel(f,du,u,p,t)
@kernel function gpu_kernel(@Const(f),du,@Const(u),@Const(p),@Const(t))
i = @index(Global, Linear)
@views @inbounds f(du[:,i],u[:,i],p[:,i],t)
end

@kernel function jac_kernel(f,J,u,p,t)
@kernel function jac_kernel(@Const(f),J,@Const(u),@Const(p),@Const(t))
i = @index(Global, Linear)-1
section = 1 + (i*size(u,1)) : ((i+1)*size(u,1))
@views @inbounds f(J[section,section],u[:,i+1],p[:,i+1],t)
end

@kernel function discrete_condition_kernel(condition,cur,u,t,p)
@kernel function discrete_condition_kernel(@Const(condition),cur,@Const(u),@Const(t),@Const(p))
i = @index(Global, Linear)
@views @inbounds cur[i] = condition(u[:,i],t,FakeIntegrator(u[:,i],t,p[:,i]))
end

@kernel function discrete_affect!_kernel(affect!,cur,u,t,p)
@kernel function discrete_affect!_kernel(@Const(affect!),cur,u,t,p)
i = @index(Global, Linear)
@views @inbounds cur[i] && affect!(FakeIntegrator(u[:,i],t,p[:,i]))
end

@kernel function continuous_condition_kernel(condition,out,u,t,p)
@kernel function continuous_condition_kernel(@Const(condition),out,@Const(u),@Const(t),@Const(p))
i = @index(Global, Linear)
@views @inbounds out[i] = condition(u[:,i],t,FakeIntegrator(u[:,i],t,p[:,i]))
end
Expand All @@ -44,7 +44,7 @@ function workgroupsize(backend, n)
min(maxthreads(backend),n)
end

@kernel function W_kernel(jac, W, u, p, gamma, t)
@kernel function W_kernel(@Const(jac), W, @Const(u), @Const(p), @Const(gamma), @Const(t))
i = @index(Global, Linear)
len = size(u,1)
_W = @inbounds @view(W[:, :, i])
Expand All @@ -58,7 +58,7 @@ end
end
end

@kernel function Wt_kernel(jac, W, u, p, gamma, t)
@kernel function Wt_kernel(@Const(jac), W, @Const(u), @Const(p), @Const(gamma), @Const(t))
i = @index(Global, Linear)
len = size(u,1)
_W = @inbounds @view(W[:, :, i])
Expand Down Expand Up @@ -95,18 +95,20 @@ struct EnsembleGPUArray <: EnsembleArrayAlgorithm end
function DiffEqBase.__solve(ensembleprob::DiffEqBase.AbstractEnsembleProblem,
alg::Union{DiffEqBase.DEAlgorithm,Nothing},
ensemblealg::EnsembleArrayAlgorithm;
trajectories, batch_size = trajectories, kwargs...)
trajectories, batch_size = trajectories,
unstable_check = (dt,u,p,t)->false,
kwargs...)

num_batches = trajectories ÷ batch_size
num_batches * batch_size != trajectories && (num_batches += 1)

if num_batches == 1 && ensembleprob.reduction === DiffEqBase.DEFAULT_REDUCTION
time = @elapsed sol = batch_solve(ensembleprob,alg,ensemblealg,1:trajectories;kwargs...)
time = @elapsed sol = batch_solve(ensembleprob,alg,ensemblealg,1:trajectories;unstable_check=unstable_check,kwargs...)
return DiffEqBase.EnsembleSolution(sol,time,true)
end

converged::Bool = false
u = ensembleprob.u_init === nothing ? similar(batch_solve(ensembleprob,alg,ensemblealg,1:batch_size;kwargs...), 0) : ensembleprob.u_init
u = ensembleprob.u_init === nothing ? similar(batch_solve(ensembleprob,alg,ensemblealg,1:batch_size;unstable_check=unstable_check,kwargs...), 0) : ensembleprob.u_init

if nprocs() == 1
# While pmap works, this makes much better error messages.
Expand All @@ -117,7 +119,7 @@ function DiffEqBase.__solve(ensembleprob::DiffEqBase.AbstractEnsembleProblem,
else
I = (batch_size*(i-1)+1):batch_size*i
end
batch_data = batch_solve(ensembleprob,alg,ensemblealg,I;kwargs...)
batch_data = batch_solve(ensembleprob,alg,ensemblealg,I;unstable_check=unstable_check,kwargs...)
if ensembleprob.reduction !== DiffEqBase.DEFAULT_REDUCTION
u, _ = ensembleprob.reduction(u,batch_data,I)
return u
Expand All @@ -134,7 +136,7 @@ function DiffEqBase.__solve(ensembleprob::DiffEqBase.AbstractEnsembleProblem,
else
I = (batch_size*(i-1)+1):batch_size*i
end
x = batch_solve(ensembleprob,alg,ensemblealg,I;kwargs...)
x = batch_solve(ensembleprob,alg,ensemblealg,I;unstable_check=unstable_check,kwargs...)
yield()
if ensembleprob.reduction !== DiffEqBase.DEFAULT_REDUCTION
u, _ = ensembleprob.reduction(u,x,I)
Expand All @@ -159,7 +161,11 @@ diffeqgpunorm(u::AbstractArray{<:ForwardDiff.Dual},t) = sqrt(sum(abs2∘ForwardD
diffeqgpunorm(u::ForwardDiff.Dual,t) = abs(ForwardDiff.value(u))

function batch_solve(ensembleprob,alg,ensemblealg,I;kwargs...)
probs = [ensembleprob.prob_func(deepcopy(ensembleprob.prob),i,1) for i in I]
if ensembleprob.safetycopy
probs = [ensembleprob.prob_func(deepcopy(ensembleprob.prob),i,1) for i in I]
else
probs = [ensembleprob.prob_func(ensembleprob.prob,i,1) for i in I]
end
@assert all(p->p.tspan == probs[1].tspan,probs)
@assert !isempty(I)
#@assert all(p->p.f === probs[1].f,probs)
Expand Down Expand Up @@ -435,7 +441,7 @@ function (p::LinSolveGPUSplitFactorize)(::Type{Val{:init}},f,u0_prototype)
LinSolveGPUSplitFactorize(size(u0_prototype)...)
end

@kernel function ldiv!_kernel(W,u,len,nfacts)
@kernel function ldiv!_kernel(W,u,@Const(len),@Const(nfacts))
i = @index(Global, Linear)
section = 1 + ((i-1)*len) : (i*len)
_W = @inbounds @view(W[:, :, i])
Expand Down

0 comments on commit 10de5ca

Please sign in to comment.