Skip to content

Commit

Permalink
Merge c2e20c7 into 849beb7
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Jul 28, 2020
2 parents 849beb7 + c2e20c7 commit 894068e
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions src/DiffEqGPU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,33 @@ using CUDA: CuPtr, CU_NULL, Mem, CuDefaultStream
using CUDA: CUBLAS
using ForwardDiff

@kernel function gpu_kernel(f,du,u,p,t)
@kernel function gpu_kernel(@Const(f),du,@Const(u),@Const(p),@Const(t))
i = @index(Global, Linear)
@views @inbounds f(du[:,i],u[:,i],p[:,i],t)
end

@kernel function jac_kernel(f,J,u,p,t)
@kernel function jac_kernel(@Const(f),J,@Const(u),@Const(p),@Const(t))
i = @index(Global, Linear)-1
section = 1 + (i*size(u,1)) : ((i+1)*size(u,1))
@views @inbounds f(J[section,section],u[:,i+1],p[:,i+1],t)
end

@kernel function discrete_condition_kernel(condition,cur,u,t,p)
@kernel function discrete_condition_kernel(@Const(condition),cur,@Const(u),@Const(t),@Const(p))
i = @index(Global, Linear)
@views @inbounds cur[i] = condition(u[:,i],t,FakeIntegrator(u[:,i],t,p[:,i]))
end

@kernel function discrete_affect!_kernel(affect!,cur,u,t,p)
@kernel function discrete_affect!_kernel(@Const(affect!),cur,u,t,p)
i = @index(Global, Linear)
@views @inbounds cur[i] && affect!(FakeIntegrator(u[:,i],t,p[:,i]))
end

@kernel function continuous_condition_kernel(condition,out,u,t,p)
@kernel function continuous_condition_kernel(@Const(condition),out,@Const(u),@Const(t),@Const(p))
i = @index(Global, Linear)
@views @inbounds out[i] = condition(u[:,i],t,FakeIntegrator(u[:,i],t,p[:,i]))
end

@kernel function continuous_affect!_kernel(affect!,event_idx,u,t,p)
@kernel function continuous_affect!_kernel(@Const(affect!),@Const(event_idx),u,t,p)
i = @index(Global, Linear)
@views @inbounds affect!(FakeIntegrator(u[:,i],t,p[:,i]))
end
Expand All @@ -43,7 +43,7 @@ function workgroupsize(backend, n)
min(maxthreads(backend),n)
end

@kernel function W_kernel(jac, W, u, p, gamma, t)
@kernel function W_kernel(@Const(jac), W, @Const(u), @Const(p), @Const(gamma), @Const(t))
i = @index(Global, Linear)
len = size(u,1)
_W = @inbounds @view(W[:, :, i])
Expand All @@ -57,7 +57,7 @@ end
end
end

@kernel function Wt_kernel(jac, W, u, p, gamma, t)
@kernel function Wt_kernel(@Const(jac), W, @Const(u), @Const(p), @Const(gamma), @Const(t))
i = @index(Global, Linear)
len = size(u,1)
_W = @inbounds @view(W[:, :, i])
Expand Down Expand Up @@ -434,7 +434,7 @@ function (p::LinSolveGPUSplitFactorize)(::Type{Val{:init}},f,u0_prototype)
LinSolveGPUSplitFactorize(size(u0_prototype)...)
end

@kernel function ldiv!_kernel(W,u,len,nfacts)
@kernel function ldiv!_kernel(W,@Const(u),@Const(len),@Const(nfacts))
i = @index(Global, Linear)
section = 1 + ((i-1)*len) : (i*len)
_W = @inbounds @view(W[:, :, i])
Expand Down

0 comments on commit 894068e

Please sign in to comment.