Skip to content

Commit 70b0820

Browse files
authored Jan 30, 2023
Merge pull request #226 from utkarsh530/u/gpusde
Move SDE solvers to KernelAbstractions.jl
2 parents 6007f06 + 9c80aa0 commit 70b0820

File tree

4 files changed

+18
-32
lines changed

4 files changed

+18
-32
lines changed
 

‎src/perform_step/gpu_em_perform_step.jl

+3 -6
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
function em_kernel(probs, _us, _ts, dt,
2-
saveat, ::Val{save_everystep}) where {save_everystep}
3-
i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
4-
i > length(probs) && return
1+
@kernel function em_kernel(@Const(probs), _us, _ts, dt,
2+
saveat, ::Val{save_everystep}) where {save_everystep}
3+
i = @index(Global, Linear)
54

65
# get the actual problem for this thread
76
prob = @inbounds probs[i]
@@ -73,6 +72,4 @@ function em_kernel(probs, _us, _ts, dt,
7372
@inbounds us[2] = u
7473
@inbounds ts[2] = t
7574
end
76-
77-
return nothing
7875
end

‎src/perform_step/gpu_siea_perform_step.jl

+3 -6
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,9 @@ function SIEAConstantCache(::Type{T}, ::Type{T2}) where {T, T2}
6161
β2, β3, δ2, δ3)
6262
end
6363

64-
function siea_kernel(probs, _us, _ts, dt,
65-
saveat, ::Val{save_everystep}) where {save_everystep}
66-
i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
67-
i > length(probs) && return
64+
@kernel function siea_kernel(@Const(probs), _us, _ts, dt,
65+
saveat, ::Val{save_everystep}) where {save_everystep}
66+
i = @index(Global, Linear)
6867

6968
# get the actual problem for this thread
7069
prob = @inbounds probs[i]
@@ -155,6 +154,4 @@ function siea_kernel(probs, _us, _ts, dt,
155154
@inbounds us[2] = u
156155
@inbounds ts[2] = t
157156
end
158-
159-
return nothing
160157
end

‎src/perform_step/gpu_vern9_perform_step.jl

+1 -1
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@
102102
end
103103

104104
@kernel function vern9_kernel(probs, _us, _ts, dt, callback, tstops, nsteps,
105-
saveat, ::Val{save_everystep}) where {save_everystep}
105+
saveat, ::Val{save_everystep}) where {save_everystep}
106106
i = @index(Global, Linear)
107107

108108
# get the actual problem for this thread

‎src/solve.jl

+11 -19
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ function vectorized_solve(probs, prob::ODEProblem, alg;
5252

5353
# Handle tstops
5454
tstops = cu(tstops)
55-
dev = CUDADevice{#=prefer_blocks=#true}()
55+
dev = CUDADevice{true}() #=prefer_blocks=#
5656
if alg isa GPUTsit5
5757
kernel = tsit5_kernel(dev)
5858
elseif alg isa GPUVern7
@@ -61,7 +61,7 @@ function vectorized_solve(probs, prob::ODEProblem, alg;
6161
kernel = vern9_kernel(dev)
6262
end
6363
event = kernel(probs, us, ts, dt, callback, tstops, nsteps, saveat, Val(save_everystep);
64-
ndrange=length(probs), dependencies=Event(dev))
64+
ndrange = length(probs), dependencies = Event(dev))
6565
wait(dev, event)
6666

6767
# we build the actual solution object on the CPU because the GPU would create one
@@ -94,27 +94,19 @@ function vectorized_solve(probs, prob::SDEProblem, alg;
9494
us = CuMatrix{typeof(prob.u0)}(undef, (length(saveat), length(probs)))
9595
end
9696

97+
dev = CUDADevice{true}() #=prefer_blocks=#
98+
9799
if alg isa GPUEM
98-
kernel = @cuda launch=false em_kernel(probs, us, ts, dt,
99-
saveat, Val(save_everystep))
100+
kernel = em_kernel(dev)
100101
elseif alg isa Union{GPUSIEA}
101102
SciMLBase.is_diagonal_noise(prob) ? nothing :
102103
error("The algorithm is not compatible with the chosen noise type. Please see the documentation on the solver methods")
103-
kernel = @cuda launch=false siea_kernel(probs, us, ts, dt,
104-
saveat, Val(save_everystep))
105-
end
106-
if debug
107-
@show CUDA.registers(kernel)
108-
@show CUDA.memory(kernel)
104+
kernel = siea_kernel(dev)
109105
end
110106

111-
config = launch_configuration(kernel.fun)
112-
threads = min(length(probs), config.threads)
113-
# XXX: this kernel performs much better with all blocks active
114-
blocks = max(cld(length(probs), threads), config.blocks)
115-
threads = cld(length(probs), blocks)
116-
117-
kernel(probs, us, ts, dt, saveat; threads, blocks)
107+
event = kernel(probs, us, ts, dt, saveat, Val(save_everystep);
108+
ndrange = length(probs), dependencies = Event(dev))
109+
wait(dev, event)
118110

119111
ts, us
120112
end
@@ -147,7 +139,7 @@ function vectorized_asolve(probs, prob::ODEProblem, alg;
147139
end
148140

149141
tstops = cu(tstops)
150-
dev = CUDADevice{#=prefer_blocks=#true}()
142+
dev = CUDADevice{true}() #=prefer_blocks=#
151143
if alg isa GPUTsit5
152144
kernel = atsit5_kernel(dev)
153145
elseif alg isa GPUVern7
@@ -157,7 +149,7 @@ function vectorized_asolve(probs, prob::ODEProblem, alg;
157149
end
158150
event = kernel(probs, us, ts, dt, callback, tstops,
159151
abstol, reltol, saveat, Val(save_everystep);
160-
ndrange=length(probs), dependencies=Event(dev))
152+
ndrange = length(probs), dependencies = Event(dev))
161153
wait(dev, event)
162154

163155
# we build the actual solution object on the CPU because the GPU would create one

0 commit comments

Comments
 (0)
Failed to load comments.