This repository has been archived by the owner on May 27, 2021. It is now read-only.

Use Adapt.jl for device conversions.
maleadt committed Oct 23, 2018
1 parent 040e5c3 commit 9130f59
Showing 6 changed files with 26 additions and 28 deletions.
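In short: downstream code that previously extended `CUDAnative.cudaconvert` now extends Adapt.jl's `adapt_storage` (or `adapt_structure` for wrapper types), dispatching on the new `CUDAnative.Adaptor` tag type. A minimal sketch of the migration, using hypothetical `HostHandle`/`DeviceHandle` types that are not part of this commit:

    using Adapt, CUDAnative

    struct HostHandle      # hypothetical host-side type
        id::Int
    end
    struct DeviceHandle    # hypothetical GPU-compatible counterpart
        id::Int
    end

    # Before this commit: CUDAnative.cudaconvert(h::HostHandle) = DeviceHandle(h.id)
    # After this commit, hook into Adapt.jl instead:
    Adapt.adapt_storage(::CUDAnative.Adaptor, h::HostHandle) = DeviceHandle(h.id)

    adapt(CUDAnative.Adaptor(), HostHandle(1))  # returns DeviceHandle(1)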
1 change: 1 addition & 0 deletions Project.toml
@@ -2,6 +2,7 @@ name = "CUDAnative"
 uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"

 [deps]
+Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 CUDAapi = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
 CUDAdrv = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
3 changes: 3 additions & 0 deletions src/CUDAnative.jl
@@ -5,6 +5,9 @@ using CUDAdrv
 using LLVM
 using LLVM.Interop

+using Adapt
+struct Adaptor end
+
 using Pkg
 using Libdl
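`Adaptor` is an empty singleton: it carries no data and exists purely as a dispatch tag for Adapt.jl. Adapt's generic fallbacks then subsume the `cudaconvert` methods deleted below in src/execution.jl: leaf values with no `adapt_storage` method pass through unchanged, and tuples and named tuples recurse element-wise via `adapt_structure` methods that ship with Adapt.jl itself. A sketch of that behavior, assuming this commit's code:

    using Adapt, CUDAnative

    # Unknown leaf values pass through, like the old `cudaconvert(x) = x`:
    adapt(CUDAnative.Adaptor(), 42)              # returns 42

    # Tuples/NamedTuples recurse element-wise, replacing the hand-written
    # `cudaconvert(::Tuple)` and `cudaconvert(::NamedTuple)` methods:
    adapt(CUDAnative.Adaptor(), (1, (a=2.0,)))   # returns (1, (a = 2.0,))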
3 changes: 2 additions & 1 deletion src/device/array.jl
@@ -67,7 +67,8 @@ function Base.convert(::Type{CuDeviceArray{T,N,AS.Global}}, a::CuArray{T,N}) where {T,N}
     ptr = Base.unsafe_convert(Ptr{T}, Base.cconvert(Ptr{T}, a))
     CuDeviceArray{T,N,AS.Global}(a.shape, DevicePtr{T,AS.Global}(ptr))
 end
-cudaconvert(a::CuArray{T,N}) where {T,N} = convert(CuDeviceArray{T,N,AS.Global}, a)
+Adapt.adapt_storage(::CUDAnative.Adaptor, a::CuArray{T,N}) where {T,N} =
+    convert(CuDeviceArray{T,N,AS.Global}, a)


 ## indexing
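With this method in place, adapting a `CUDAdrv.CuArray` yields a `CuDeviceArray` in global memory. A sketch, assuming CUDAdrv's `CuArray{T}(dims)` constructor:

    using Adapt, CUDAdrv, CUDAnative

    a = CuArray{Float32}((16,))        # a 16-element device buffer
    GC.@preserve a begin               # keep `a` rooted while its pointer is in use
        d = adapt(CUDAnative.Adaptor(), a)
        @assert d isa CuDeviceArray{Float32,1,CUDAnative.AS.Global}
    end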
42 changes: 17 additions & 25 deletions src/execution.jl
@@ -1,6 +1,6 @@
 # Native execution support

-export @cuda, cudaconvert, cufunction, nearest_warpsize
+export @cuda, cufunction, nearest_warpsize


 ## kernel object and query functions
@@ -124,6 +124,16 @@ function method_age(f, tt)::UInt
 end

+## adaptors
+
+# Base.RefValue isn't GPU compatible, so provide a compatible alternative
+struct CuRefValue{T} <: Ref{T}
+    x::T
+end
+Base.getindex(r::CuRefValue) = r.x
+Adapt.adapt_structure(to::Adaptor, r::Base.RefValue) = CuRefValue(adapt(to, r[]))
+
+
 ## high-level @cuda interface

 """
@@ -132,8 +142,8 @@ end
 High-level interface for executing code on a GPU. The `@cuda` macro should prefix a call,
 with `func` a callable function or object that should return nothing. It will be compiled to
 a CUDA function upon first use, and to a certain extent arguments will be converted and
-managed automatically (see [`cudaconvert`](@ref)). Finally, a call to `CUDAdrv.cudacall` is
-performed, scheduling a kernel launch on the current CUDA context.
+managed automatically using Adapt.jl. Finally, a call to `CUDAdrv.cudacall` is performed,
+scheduling a kernel launch on the current CUDA context.

 Several keyword arguments are supported that influence kernel compilation and execution. For
 more information, refer to the documentation of respectively [`cufunction`](@ref) and
@@ -143,9 +153,11 @@ The underlying operations (argument conversion, kernel compilation, kernel call) can be
 performed explicitly when more control is needed, e.g. to reflect on the resource usage of a
 kernel to determine the launch configuration:

+    using Adapt
+
     args = ...
     GC.@preserve args begin
-        kernel_args = cudaconvert.(args)
+        kernel_args = Tuple(adapt(CUDAnative.Adaptor(), arg) for arg in args)
         kernel_tt = Tuple{Core.Typeof.(kernel_args)...}
         kernel = CUDAnative.cufunction(f, kernel_tt; compilation_kwargs)
         kernel(kernel_args...; launch_kwargs)
@@ -176,7 +188,7 @@ macro cuda(ex...)
     push!(code.args,
         quote
             GC.@preserve $(vars...) begin
-                $kernel_args = cudaconvert.(($(var_exprs...),))
+                $kernel_args = Tuple(adapt(Adaptor(), var) for var in ($(var_exprs...),))
                 $kernel_tt = Tuple{Core.Typeof.($kernel_args)...}
                 $kernel = cufunction($(esc(f)), $kernel_tt; $(map(esc, compiler_kwargs)...))
                 $kernel($kernel_args...; $(map(esc, call_kwargs)...))
@@ -188,26 +200,6 @@ end
 ## APIs for manual compilation

-"""
-    cudaconvert(x)
-
-Low-level interface to convert values to a representation that is GPU compatible.
-For a higher-level interface, use [`@cuda`](@ref).
-
-By default, CUDAnative only provides a minimal set of conversions for elementary types
-such as tuples. If you need your type to convert before execution on a GPU, be sure to add
-methods to this function.
-
-For the time being, conversions for `CUDAdrv.CuArray` objects are also provided, returning a
-corresponding `CuDeviceArray` object in global memory. This will be deprecated in favor of
-functionality from the CuArrays.jl package.
-"""
-cudaconvert(x) = x
-cudaconvert(x::Tuple) = cudaconvert.(x)
-@generated function cudaconvert(x::NamedTuple)
-    Expr(:tuple, (:($f=cudaconvert(x.$f)) for f in fieldnames(x))...)
-end
-
 const agecache = Dict{UInt, UInt}()
 const compilecache = Dict{UInt, Kernel}()
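The new `CuRefValue` adaptor covers a case the old `cudaconvert` methods did not: `Base.RefValue` is not GPU compatible, so `adapt_structure` rewraps the referenced value after adapting it in turn. A sketch:

    using Adapt, CUDAnative

    r = adapt(CUDAnative.Adaptor(), Ref(42))
    @assert r isa CUDAnative.CuRefValue{Int}
    @assert r[] == 42    # getindex forwards to the wrapped value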
2 changes: 1 addition & 1 deletion test/device/execution.jl
@@ -364,7 +364,7 @@ end
 @eval struct Host end
 @eval struct Device end

-CUDAnative.cudaconvert(a::Host) = Device()
+Adapt.adapt_storage(::CUDAnative.Adaptor, a::Host) = Device()

 Base.convert(::Type{Int}, ::Host) = 1
 Base.convert(::Type{Int}, ::Device) = 2
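Given that test hook, a `Host()` argument is adapted to `Device()` before launch, which the two `convert(Int, ...)` methods make observable from inside a kernel. A sketch of the adaptation step itself, assuming the test's `Host`/`Device` types are in scope:

    using Adapt

    @assert adapt(CUDAnative.Adaptor(), Host()) === Device()
    # so a kernel invoked as `@cuda kernel(Host())` receives `Device()`,
    # and `convert(Int, x)` inside it returns 2 rather than 1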
3 changes: 2 additions & 1 deletion test/util.jl
@@ -100,7 +100,8 @@ function Base.Array(src::CuTestArray{T,N}) where {T,N}
     return dst
 end
 ## conversions
-function CUDAnative.cudaconvert(a::CuTestArray{T,N}) where {T,N}
+using Adapt
+function Adapt.adapt_storage(::CUDAnative.Adaptor, a::CuTestArray{T,N}) where {T,N}
     ptr = Base.unsafe_convert(Ptr{T}, a.buf)
     devptr = CUDAnative.DevicePtr{T,AS.Global}(ptr)
     CuDeviceArray{T,N,AS.Global}(a.shape, devptr)
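As with the CuArray conversion above, this takes a raw pointer from the test array's buffer, so the host object must stay rooted while the device view is in use, mirroring the `GC.@preserve` pattern in the `@cuda` docstring. A sketch, assuming test/util.jl's `CuTestArray(::Array)` constructor:

    using Adapt

    a = CuTestArray(rand(Float32, 16))
    GC.@preserve a begin
        da = adapt(CUDAnative.Adaptor(), a)
        # ... use `da` in a kernel launch ...
    end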
