Skip to content

Commit

Permalink
change CUDA device
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Jul 7, 2020
1 parent 9787de4 commit 6e420a6
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions src/DiffEqGPU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ end
function generate_problem(prob::ODEProblem,u0,p,jac_prototype,colorvec)
_f = let f=prob.f.f
function (du,u,p,t)
version = u isa CuArray ? CUDA() : CPU()
version = u isa CuArray ? CUDADevice() : CPU()
wait(version, gpu_kernel(version)(f,du,u,p,t;ndrange=size(u,2),dependencies=Event(version)))
end
end
Expand All @@ -226,15 +226,15 @@ function generate_problem(prob::ODEProblem,u0,p,jac_prototype,colorvec)
_Wfact! = let jac=prob.f.jac
function (W,u,p,gamma,t)
iscuda = u isa CuArray
version = iscuda ? CUDA() : CPU()
version = iscuda ? CUDADevice() : CPU()
wait(version, W_kernel(version)(jac, W, u, p, gamma, t; ndrange=size(u,2),dependencies=Event(version)))
iscuda ? cuda_lufact!(W) : cpu_lufact!(W)
end
end
_Wfact!_t = let jac=prob.f.jac
function (W,u,p,gamma,t)
iscuda = u isa CuArray
version = iscuda ? CUDA() : CPU()
version = iscuda ? CUDADevice() : CPU()
wait(version, Wt_kernel(version)(jac, W, u, p, gamma, t; ndrange=size(u,2),dependencies=Event(version)))
iscuda ? cuda_lufact!(W) : cpu_lufact!(W)
end
Expand All @@ -247,7 +247,7 @@ function generate_problem(prob::ODEProblem,u0,p,jac_prototype,colorvec)
if DiffEqBase.has_tgrad(prob.f)
_tgrad = let tgrad=prob.f.tgrad
function (J,u,p,t)
version = u isa CuArray ? CUDA() : CPU()
version = u isa CuArray ? CUDADevice() : CPU()
wait(version, gpu_kernel(version)(tgrad,J,u,p,t;ndrange=size(u,2),dependencies=Event(version)))
end
end
Expand All @@ -267,14 +267,14 @@ end
function generate_problem(prob::SDEProblem,u0,p,jac_prototype,colorvec)
_f = let f=prob.f.f
function (du,u,p,t)
version = u isa CuArray ? CUDA() : CPU()
version = u isa CuArray ? CUDADevice() : CPU()
wait(version, gpu_kernel(version)(f,du,u,p,t;ndrange=size(u,2),dependencies=Event(version)))
end
end

_g = let f=prob.f.g
function (du,u,p,t)
version = u isa CuArray ? CUDA() : CPU()
version = u isa CuArray ? CUDADevice() : CPU()
wait(version, gpu_kernel(version)(f,du,u,p,t;ndrange=size(u,2),dependencies=Event(version)))
end
end
Expand All @@ -283,15 +283,15 @@ function generate_problem(prob::SDEProblem,u0,p,jac_prototype,colorvec)
_Wfact! = let jac=prob.f.jac
function (W,u,p,gamma,t)
iscuda = u isa CuArray
version = iscuda ? CUDA() : CPU()
version = iscuda ? CUDADevice() : CPU()
wait(version, W_kernel(version)(jac, W, u, p, gamma, t; ndrange=size(u,2),dependencies=Event(version)))
iscuda ? cuda_lufact!(W) : cpu_lufact!(W)
end
end
_Wfact!_t = let jac=prob.f.jac
function (W,u,p,gamma,t)
iscuda = u isa CuArray
version = iscuda ? CUDA() : CPU()
version = iscuda ? CUDADevice() : CPU()
wait(version, Wt_kernel(version)(jac, W, u, p, gamma, t; ndrange=size(u,2),dependencies=Event(version)))
iscuda ? cuda_lufact!(W) : cpu_lufact!(W)
end
Expand All @@ -304,7 +304,7 @@ function generate_problem(prob::SDEProblem,u0,p,jac_prototype,colorvec)
if DiffEqBase.has_tgrad(prob.f)
_tgrad = let tgrad=prob.f.tgrad
function (J,u,p,t)
version = u isa CuArray ? CUDA() : CPU()
version = u isa CuArray ? CUDADevice() : CPU()
wait(version, gpu_kernel(version)(tgrad,J,u,p,t;ndrange=size(u,2),dependencies=Event(version)))
end
end
Expand Down Expand Up @@ -334,13 +334,13 @@ function generate_callback(prob,I,ensemblealg)
_affect! = prob.kwargs[:callback].affect!

condition = function (u,t,integrator)
version = u isa CuArray ? CUDA() : CPU()
version = u isa CuArray ? CUDADevice() : CPU()
wait(version, discrete_condition_kernel(version)(_condition,cur,u,t,integrator.p;ndrange=size(u,2),dependencies=Event(version)))
any(cur)
end

affect! = function (integrator)
version = integrator.u isa CuArray ? CUDA() : CPU()
version = integrator.u isa CuArray ? CUDADevice() : CPU()
wait(version, discrete_affect!_kernel(version)(_affect!,cur,integrator.u,integrator.t,integrator.p;ndrange=size(integrator.u,2),dependencies=Event(version)))
end

Expand All @@ -351,18 +351,18 @@ function generate_callback(prob,I,ensemblealg)
_affect_neg! = prob.kwargs[:callback].affect_neg!

condition = function (out,u,t,integrator)
version = u isa CuArray ? CUDA() : CPU()
version = u isa CuArray ? CUDADevice() : CPU()
wait(version, continuous_condition_kernel(version)(_condition,out,u,t,integrator.p;ndrange=size(u,2),dependencies=Event(version)))
nothing
end

affect! = function (integrator,event_idx)
version = integrator.u isa CuArray ? CUDA() : CPU()
version = integrator.u isa CuArray ? CUDADevice() : CPU()
wait(version, continuous_affect!_kernel(version)(_affect!,event_idx,integrator.u,integrator.t,integrator.p;ndrange=size(integrator.u,2),dependencies=Event(version)))
end

affect_neg! = function (integrator,event_idx)
version = integrator.u isa CuArray ? CUDA() : CPU()
version = integrator.u isa CuArray ? CUDADevice() : CPU()
wait(version, continuous_affect!_kernel(version)(_affect_neg!,event_idx,integrator.u,integrator.t,integrator.p;ndrange=size(integrator.u,2),dependencies=Event(version)))
end

Expand All @@ -380,7 +380,7 @@ end
LinSolveGPUSplitFactorize() = LinSolveGPUSplitFactorize(0, 0)

function (p::LinSolveGPUSplitFactorize)(x,A,b,update_matrix=false;kwargs...)
version = b isa CuArray ? CUDA() : CPU()
version = b isa CuArray ? CUDADevice() : CPU()
copyto!(x, b)
wait(version, ldiv!_kernel(version)(A,x,p.len,p.nfacts;ndrange=p.nfacts,dependencies=Event(version)))
return nothing
Expand Down

0 comments on commit 6e420a6

Please sign in to comment.