Skip to content
This repository has been archived by the owner on May 27, 2021. It is now read-only.

Commit

Permalink
Move code around.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Mar 20, 2019
1 parent 5af0b32 commit 7267c79
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 84 deletions.
51 changes: 51 additions & 0 deletions src/device/cuda/libcudadevrt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,57 @@
# the CUDA API for execution on the device, such as device synchronization primitives,
# dynamic kernel APIs, etc.

import CUDAdrv: CuDim3, CuStream_t

# C-compatible aliases used in the `ccall` signatures of the device runtime calls below.
# `cudaError_t` is the declared return type of the `cudaLaunchDeviceV2` call in `launch`.
const cudaError_t = Cint
# `cudaStream_t` is the declared type of the stream argument of that same call; it is
# simply CUDAdrv's stream handle type.
const cudaStream_t = CUDAdrv.CuStream_t

# Device-side counterpart of CUDAdrv.launch.
#
# Normalises the launch configuration to `CuDim3`, fills a parameter buffer for `f` with
# `args` (see `parameter_buffer`) and hands that buffer to `cudaLaunchDeviceV2` on
# `stream`. Always returns `nothing`.
@inline function launch(f::Ptr{Cvoid}, blocks::CuDim, threads::CuDim,
                        shmem::Int, stream::CuStream,
                        args...)
    # use fresh locals for the converted dimensions instead of rebinding the arguments
    grid = CuDim3(blocks)
    block = CuDim3(threads)

    params = parameter_buffer(f, grid, block, shmem, args...)

    # the `cudaError_t` status returned by the call is not inspected
    ccall("extern cudaLaunchDeviceV2", llvmcall, cudaError_t,
          (Ptr{Cvoid}, cudaStream_t),
          params, stream)

    return nothing
end

# Generate code that requests a parameter buffer for launching `f` with the given
# configuration (via `cudaGetParameterBufferV2`) and stores every element of `args` into
# it. The generated code returns the buffer pointer.
#
# Inside the generator `args` holds the argument *types*; the slot of each argument is
# therefore computed once, at code-generation time, and spliced into the stores as a
# constant index.
@generated function parameter_buffer(f::Ptr{Cvoid}, blocks::CuDim3, threads::CuDim3,
                                     shmem::Int, args...)
    # allocate a buffer
    ex = quote
        buf = ccall("extern cudaGetParameterBufferV2", llvmcall, Ptr{Cvoid},
                    (Ptr{Cvoid}, CuDim3, CuDim3, Cuint),
                    f, blocks, threads, shmem)
    end

    # store the parameters
    #
    # > Each individual parameter placed in the parameter buffer is required to be aligned.
    # > That is, each parameter must be placed at the n-th byte in the parameter buffer,
    # > where n is the smallest multiple of the parameter size that is greater than the
    # > offset of the last byte taken by the preceding parameter. The maximum size of the
    # > parameter buffer is 4KB.
    #
    # NOTE: the 4KB limit quoted above is not checked here.
    offset = 0      # first free byte after the previously stored parameter
    for (i, T) in enumerate(args)
        sz = sizeof(T)

        # a zero-sized argument has no bytes to store, and would make the slot
        # computation below divide by zero; skip it
        sz == 0 && continue

        # 1-based slot, in units of `sz`, of the first multiple of `sz` at or past
        # `offset`; integer `cld` avoids a round-trip through floating point
        buf_index = cld(offset, sz) + 1
        offset = buf_index * sz

        push!(ex.args, :(
            unsafe_store!(Base.unsafe_convert(Ptr{$T}, buf), args[$i], $buf_index)
        ))
    end

    push!(ex.args, :(return buf))

    return ex
end

"""
synchronize()
Expand Down
117 changes: 33 additions & 84 deletions src/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,10 +217,6 @@ macro cuda(ex...)

if dynamic
# dynamic, device-side kernel launch
#
# WIP
# TODO: GC.@preserve?
# TODO: error on, or support kwargs
push!(code.args,
quote
# we're in kernel land already, so no need to convert arguments
Expand Down Expand Up @@ -248,86 +244,6 @@ macro cuda(ex...)
return code
end

import CUDAdrv: CuDim3, CuStream_t

# C-compatible aliases used in the `ccall` signatures of the device runtime calls below.
# `cudaError_t` is the declared return type of the `cudaLaunchDeviceV2` call in `launch`.
const cudaError_t = Cint
# `cudaStream_t` is the declared type of the stream argument of that same call; it is
# simply CUDAdrv's stream handle type.
const cudaStream_t = CUDAdrv.CuStream_t

# Ask the external `cudanativeCompileKernel` symbol for a pointer corresponding to
# function `f` with argument tuple type `tt` (presumably a handle to the compiled
# kernel, to be passed on to `dynamic_cudacall` -- confirm against the symbol's
# provider).
function dynamic_cufunction(f::Core.Function, tt::Type=Tuple{})
    return ccall("extern cudanativeCompileKernel", llvmcall, Ptr{Cvoid},
                 (Any, Any), f, tt)
end

# Generate the code for a device-side kernel call: every argument is passed through
# `Base.cconvert` / `Base.unsafe_convert` (mimicking `lower-ccall` in julia-syntax.scm)
# and the results are handed, packed in a single tuple, to `launch` together with the
# launch configuration given as keyword arguments.
#
# NOTE(review): `tt` is not referenced in this body; the conversions target each
# argument's own type rather than a type taken from `tt`. Confirm whether the kernel
# signature in `tt` was meant to be used here.
@generated function dynamic_cudacall(f::Ptr{Cvoid}, tt::Type, args...;
                                     blocks::CuDim=1, threads::CuDim=1, shmem::Integer=0,
                                     stream::CuStream=CuDefaultStream())
    ex = quote
        Base.@_inline_meta
    end

    # one pair of fresh names per argument: the `cconvert`ed value and what
    # `unsafe_convert` makes of it (inside the generator, `args` holds the types)
    converted = Symbol[]
    pointers = Symbol[]
    for (i, T) in enumerate(args)
        conv = gensym()
        ptr = gensym()
        push!(converted, conv)
        push!(pointers, ptr)

        push!(ex.args, :($conv = Base.cconvert($T, args[$i])))
        push!(ex.args, :($ptr = Base.unsafe_convert($T, $conv)))
    end

    # TODO: wrap the launch in `GC.@preserve $(converted...) begin ... end`
    push!(ex.args, :(launch(f, blocks, threads, shmem, stream, ($(pointers...),))))

    return ex
end

# Device-side kernel launch.
#
# Normalises the launch configuration to `CuDim3`, fills a parameter buffer for `f` with
# `args` (see `parameter_buffer`) and hands that buffer to `cudaLaunchDeviceV2` on
# `stream`. Always returns `nothing`.
@inline function launch(f::Ptr{Cvoid}, blocks::CuDim, threads::CuDim,
                        shmem::Int, stream::CuStream,
                        args...)
    # use fresh locals for the converted dimensions instead of rebinding the arguments
    grid = CuDim3(blocks)
    block = CuDim3(threads)

    params = parameter_buffer(f, grid, block, shmem, args...)

    # the `cudaError_t` status returned by the call is not inspected
    ccall("extern cudaLaunchDeviceV2", llvmcall, cudaError_t,
          (Ptr{Cvoid}, cudaStream_t),
          params, stream)

    return nothing
end

# Generate code that requests a parameter buffer for launching `f` with the given
# configuration (via `cudaGetParameterBufferV2`) and stores every element of `args` into
# it. The generated code returns the buffer pointer.
#
# Inside the generator `args` holds the argument *types*; the slot of each argument is
# therefore computed once, at code-generation time, and spliced into the stores as a
# constant index.
@generated function parameter_buffer(f::Ptr{Cvoid}, blocks::CuDim3, threads::CuDim3,
                                     shmem::Int, args...)
    # allocate a buffer
    ex = quote
        buf = ccall("extern cudaGetParameterBufferV2", llvmcall, Ptr{Cvoid},
                    (Ptr{Cvoid}, CuDim3, CuDim3, Cuint),
                    f, blocks, threads, shmem)
    end

    # store the parameters
    #
    # > Each individual parameter placed in the parameter buffer is required to be aligned.
    # > That is, each parameter must be placed at the n-th byte in the parameter buffer,
    # > where n is the smallest multiple of the parameter size that is greater than the
    # > offset of the last byte taken by the preceding parameter. The maximum size of the
    # > parameter buffer is 4KB.
    #
    # NOTE: the 4KB limit quoted above is not checked here.
    offset = 0      # first free byte after the previously stored parameter
    for (i, T) in enumerate(args)
        sz = sizeof(T)

        # a zero-sized argument has no bytes to store, and would make the slot
        # computation below divide by zero; skip it
        sz == 0 && continue

        # 1-based slot, in units of `sz`, of the first multiple of `sz` at or past
        # `offset`; integer `cld` avoids a round-trip through floating point
        buf_index = cld(offset, sz) + 1
        offset = buf_index * sz

        push!(ex.args, :(
            unsafe_store!(Base.unsafe_convert(Ptr{$T}, buf), args[$i], $buf_index)
        ))
    end

    push!(ex.args, :(return buf))

    return ex
end


## APIs for manual compilation

Expand Down Expand Up @@ -445,6 +361,39 @@ The following keyword arguments are supported:
Kernel


## dynamic parallelism

# Ask the external `cudanativeCompileKernel` symbol for a pointer corresponding to
# function `f` with argument tuple type `tt` (presumably a handle to the compiled
# kernel, to be passed on to `dynamic_cudacall` -- confirm against the symbol's
# provider).
function dynamic_cufunction(f::Core.Function, tt::Type=Tuple{})
    return ccall("extern cudanativeCompileKernel", llvmcall, Ptr{Cvoid},
                 (Any, Any), f, tt)
end

# Generate the code for a device-side kernel call: every argument is passed through
# `Base.cconvert` / `Base.unsafe_convert` (mimicking `lower-ccall` in julia-syntax.scm)
# and the results are handed, packed in a single tuple, to `launch` together with the
# launch configuration given as keyword arguments.
#
# NOTE(review): `tt` is not referenced in this body; the conversions target each
# argument's own type rather than a type taken from `tt`. Confirm whether the kernel
# signature in `tt` was meant to be used here.
@generated function dynamic_cudacall(f::Ptr{Cvoid}, tt::Type, args...;
                                     blocks::CuDim=1, threads::CuDim=1, shmem::Integer=0,
                                     stream::CuStream=CuDefaultStream())
    ex = quote
        Base.@_inline_meta
    end

    # one pair of fresh names per argument: the `cconvert`ed value and what
    # `unsafe_convert` makes of it (inside the generator, `args` holds the types)
    converted = Symbol[]
    pointers = Symbol[]
    for (i, T) in enumerate(args)
        conv = gensym()
        ptr = gensym()
        push!(converted, conv)
        push!(pointers, ptr)

        push!(ex.args, :($conv = Base.cconvert($T, args[$i])))
        push!(ex.args, :($ptr = Base.unsafe_convert($T, $conv)))
    end

    # TODO: wrap the launch in `GC.@preserve $(converted...) begin ... end`
    push!(ex.args, :(launch(f, blocks, threads, shmem, stream, ($(pointers...),))))

    return ex
end


## other

"""
Expand Down

0 comments on commit 7267c79

Please sign in to comment.