Skip to content

Commit

Permalink
Don't assume host pointers are directly usable on the device. (#1342)
Browse files Browse the repository at this point in the history
Since CUDA.jl 3.4, we now always reconstruct a buffer object when
creating an array from a pointer. In the case of a HostBuffer,
that means we want the host pointer in the buffer object,
which is converted to a device pointer on request.

However, unsafe_wrap is invoked with a device pointer. Generally,
those pointers are identical, but on GPUs where
CU_DEVICE_ATTRIBUTE_CAN_USE_HOST_POINTER_FOR_REGISTERED_MEM==0,
that is not the case.

Instead, recover the host pointer from the device pointer passed to
unsafe_wrap. This is going back and forth between both pointer
representations, but those calls are reasonably cheap.
  • Loading branch information
maleadt authored Jan 26, 2022
1 parent 0c9d886 commit 430787d
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 5 deletions.
18 changes: 18 additions & 0 deletions lib/cudadrv/memory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,24 @@ memory_type(x) = CUmemorytype(attribute(Cuint, x, POINTER_ATTRIBUTE_MEMORY_TYPE)

is_managed(x) = convert(Bool, attribute(Cuint, x, POINTER_ATTRIBUTE_IS_MANAGED))

"""
host_pointer(ptr::CuPtr)
Returns the host pointer value through which `ptr`` may be accessed by by the
host program.
"""
host_pointer(x::CuPtr{T}) where {T} =
attribute(Ptr{T}, x, POINTER_ATTRIBUTE_HOST_POINTER)

"""
device_pointer(ptr::Ptr)
Returns the device pointer value through which `ptr` may be accessed by kernels
running in the current context.
"""
device_pointer(x::Ptr{T}) where {T} =
attribute(CuPtr{T}, x, POINTER_ATTRIBUTE_HOST_POINTER)

function is_pinned(ptr::Ptr)
# unpinned memory makes cuPointerGetAttribute return ERROR_INVALID_VALUE; but instead of
# calling `memory_type` with an expensive try/catch we perform low-level API calls.
Expand Down
10 changes: 5 additions & 5 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ function Base.unsafe_wrap(::Union{Type{CuArray},Type{CuArray{T}},Type{CuArray{T,
# TODO: can we identify whether this pointer was allocated asynchronously?
Mem.DeviceBuffer(ptr, sz, false)
elseif typ == CU_MEMORYTYPE_HOST
Mem.HostBuffer(reinterpret(Ptr{T}, ptr), sz)
Mem.HostBuffer(host_pointer(ptr), sz)
else
error("Unknown memory type; please file an issue.")
end
Expand Down Expand Up @@ -444,10 +444,10 @@ function Base.unsafe_copyto!(dest::DenseCuArray{T,<:Any,<:Union{Mem.UnifiedBuffe

GC.@preserve src dest begin
cpu_ptr = pointer(src, soffs)
unsafe_copyto!(reinterpret(typeof(cpu_ptr), pointer(dest, doffs)), cpu_ptr, n)
unsafe_copyto!(host_pointer(pointer(dest, doffs)), cpu_ptr, n)
if Base.isbitsunion(T)
cpu_ptr = typetagdata(src, soffs)
unsafe_copyto!(reinterpret(typeof(cpu_ptr), typetagdata(dest, doffs)), cpu_ptr, n)
unsafe_copyto!(host_pointer(typetagdata(dest, doffs)), cpu_ptr, n)
end
end
return dest
Expand All @@ -460,10 +460,10 @@ function Base.unsafe_copyto!(dest::Array{T}, doffs,

GC.@preserve src dest begin
cpu_ptr = pointer(dest, doffs)
unsafe_copyto!(cpu_ptr, reinterpret(typeof(cpu_ptr), pointer(src, soffs)), n)
unsafe_copyto!(cpu_ptr, host_pointer(pointer(src, soffs)), n)
if Base.isbitsunion(T)
cpu_ptr = typetagdata(dest, doffs)
unsafe_copyto!(cpu_ptr, reinterpret(typeof(cpu_ptr), typetagdata(src, soffs)), n)
unsafe_copyto!(cpu_ptr, host_pointer(typetagdata(src, soffs)), n)
end
end

Expand Down
9 changes: 9 additions & 0 deletions test/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,15 @@ end
end
end

@testset "issue: invalid handling of device pointers" begin
# failed when DEVICE_ATTRIBUTE_CAN_USE_HOST_POINTER_FOR_REGISTERED_MEM == 0
cpu = rand(2,2)
buf = Mem.register(Mem.Host, pointer(cpu), sizeof(cpu), Mem.HOSTREGISTER_DEVICEMAP)
gpu_ptr = convert(CuPtr{eltype(cpu)}, buf)
gpu = unsafe_wrap(CuArray, gpu_ptr, size(cpu))
@test Array(gpu) == cpu
end

if length(devices()) > 1
@testset "multigpu" begin
dev = device()
Expand Down

0 comments on commit 430787d

Please sign in to comment.