Preserve the buffer type when broadcasting. #383

Merged · 1 commit · Dec 19, 2023

4 changes: 2 additions & 2 deletions Project.toml
@@ -25,10 +25,10 @@ oneAPI_Level_Zero_Loader_jll = "13eca655-d68d-5b81-8367-6d99d727ab01"
 oneAPI_Support_jll = "b049733a-a71d-5ed3-8eba-7d323ac00b36"
 
 [compat]
-Adapt = "2.0, 3.0"
+Adapt = "4"
 CEnum = "0.4, 0.5"
 ExprTools = "0.1"
-GPUArrays = "9"
+GPUArrays = "10"
 GPUCompiler = "0.23, 0.24, 0.25"
 KernelAbstractions = "0.9.1"
 LLVM = "6"

86 changes: 72 additions & 14 deletions src/array.jl
@@ -1,4 +1,5 @@
-export oneArray, oneVector, oneMatrix, oneVecOrMat
+export oneArray, oneVector, oneMatrix, oneVecOrMat,
+       is_device, is_shared, is_host
 
 
 ## array type
@@ -168,6 +169,12 @@ function device(A::oneArray)
     return oneL0.device(A.data[])
 end
 
+buftype(x::oneArray) = buftype(typeof(x))
+buftype(::Type{<:oneArray{<:Any,<:Any,B}}) where {B} = @isdefined(B) ? B : Any
+
+is_device(a::oneArray) = isa(a.data[], oneL0.DeviceBuffer)
+is_shared(a::oneArray) = isa(a.data[], oneL0.SharedBuffer)
+is_host(a::oneArray) = isa(a.data[], oneL0.HostBuffer)
 
 ## derived types
 
@@ -195,9 +202,15 @@ const oneStridedVector{T} = oneStridedArray{T,1}
 const oneStridedMatrix{T} = oneStridedArray{T,2}
 const oneStridedVecOrMat{T} = Union{oneStridedVector{T}, oneStridedMatrix{T}}
 
-Base.pointer(x::oneStridedArray{T}) where {T} = Base.unsafe_convert(ZePtr{T}, x)
-@inline function Base.pointer(x::oneStridedArray{T}, i::Integer) where T
-    Base.unsafe_convert(ZePtr{T}, x) + Base._memory_offset(x, i)
+@inline function Base.pointer(x::oneStridedArray{T}, i::Integer=1; type=oneL0.DeviceBuffer) where T
+    PT = if type == oneL0.DeviceBuffer
+        ZePtr{T}
+    elseif type == oneL0.HostBuffer
+        Ptr{T}
+    else
+        error("unknown memory type")
+    end
+    Base.unsafe_convert(PT, x) + Base._memory_offset(x, i)
 end
 
 # anything that's (secretly) backed by a oneArray
@@ -241,12 +254,20 @@ oneL0.ZeRef{T}() where {T} = oneL0.ZeRefArray(oneArray{T}(undef, 1))
 Base.convert(::Type{T}, x::T) where T <: oneArray = x
 
 
-## interop with C libraries
+## interop with libraries
 
-Base.unsafe_convert(::Type{Ptr{T}}, x::oneArray{T}) where {T} =
-    throw(ArgumentError("cannot take the host address of a $(typeof(x))"))
-Base.unsafe_convert(::Type{ZePtr{T}}, x::oneArray{T}) where {T} =
+function Base.unsafe_convert(::Type{Ptr{T}}, x::oneArray{T}) where {T}
+    buf = x.data[]
+    if is_device(x)
+        throw(ArgumentError("cannot take the CPU address of a $(typeof(x))"))
+    end
+    convert(Ptr{T}, x.data[]) + x.offset*Base.elsize(x)
+end
+
+function Base.unsafe_convert(::Type{ZePtr{T}}, x::oneArray{T}) where {T}
     convert(ZePtr{T}, x.data[]) + x.offset*Base.elsize(x)
+end
 
 
 ## interop with GPU arrays
@@ -256,9 +277,6 @@ function Base.unsafe_convert(::Type{oneDeviceArray{T,N,AS.Global}}, a::oneArray{
                          a.maxsize - a.offset*Base.elsize(a))
 end
 
-Adapt.adapt_storage(::KernelAdaptor, xs::oneArray{T,N}) where {T,N} =
-    Base.unsafe_convert(oneDeviceArray{T,N,AS.Global}, xs)
-
 
 ## memory copying
 
@@ -310,7 +328,7 @@ Base.copyto!(dest::oneDenseArray{T}, src::oneDenseArray{T}) where {T} =
     copyto!(dest, 1, src, 1, length(src))
 
 function Base.unsafe_copyto!(ctx::ZeContext, dev::ZeDevice,
-                             dest::oneDenseArray{T}, doffs, src::Array{T}, soffs, n) where T
+                             dest::oneDenseArray{T,<:Any,oneL0.DeviceBuffer}, doffs, src::Array{T}, soffs, n) where T
     GC.@preserve src dest unsafe_copyto!(ctx, dev, pointer(dest, doffs), pointer(src, soffs), n)
     if Base.isbitsunion(T)
         # copy selector bytes
@@ -320,7 +338,7 @@ function Base.unsafe_copyto!(ctx::ZeContext, dev::ZeDevice,
 end
 
 function Base.unsafe_copyto!(ctx::ZeContext, dev::ZeDevice,
-                             dest::Array{T}, doffs, src::oneDenseArray{T}, soffs, n) where T
+                             dest::Array{T}, doffs, src::oneDenseArray{T,<:Any,oneL0.DeviceBuffer}, soffs, n) where T
     GC.@preserve src dest unsafe_copyto!(ctx, dev, pointer(dest, doffs), pointer(src, soffs), n)
     if Base.isbitsunion(T)
         # copy selector bytes
@@ -343,6 +361,46 @@ function Base.unsafe_copyto!(ctx::ZeContext, dev::ZeDevice,
     return dest
 end
 
+# between Array and host-accessible oneArray
+
+function Base.unsafe_copyto!(ctx::ZeContext, dev,
+                             dest::oneDenseArray{T,<:Any,<:Union{oneL0.SharedBuffer,oneL0.HostBuffer}}, doffs, src::Array{T}, soffs, n) where T
+    if Base.isbitsunion(T)
+        # copy selector bytes
+        error("oneArray does not yet support isbits-union arrays")
+    end
+    # XXX: maintain queue-ordered semantics? HostBuffers don't have a device...
+    GC.@preserve src dest begin
+        ptr = pointer(dest, doffs)
+        unsafe_copyto!(pointer(dest, doffs; type=oneL0.HostBuffer), pointer(src, soffs), n)
+        if Base.isbitsunion(T)
+            # copy selector bytes
+            error("oneArray does not yet support isbits-union arrays")
+        end
+    end
+
+    return dest
+end
+
+function Base.unsafe_copyto!(ctx::ZeContext, dev,
+                             dest::Array{T}, doffs, src::oneDenseArray{T,<:Any,<:Union{oneL0.SharedBuffer,oneL0.HostBuffer}}, soffs, n) where T
+    if Base.isbitsunion(T)
+        # copy selector bytes
+        error("oneArray does not yet support isbits-union arrays")
+    end
+    # XXX: maintain queue-ordered semantics? HostBuffers don't have a device...
+    GC.@preserve src dest begin
+        ptr = pointer(dest, doffs)
+        unsafe_copyto!(pointer(dest, doffs), pointer(src, soffs; type=oneL0.HostBuffer), n)
+        if Base.isbitsunion(T)
+            # copy selector bytes
+            error("oneArray does not yet support isbits-union arrays")
+        end
+    end
+
+    return dest
+end
+
 
 ## gpu array adaptor
 
@@ -375,7 +433,7 @@ end
 
 ## derived arrays
 
-function GPUArrays.derive(::Type{T}, N::Int, a::oneArray, dims::Dims, offset::Int) where {T}
+function GPUArrays.derive(::Type{T}, a::oneArray, dims::Dims{N}, offset::Int) where {T,N}
     offset = (a.offset * Base.elsize(a)) ÷ sizeof(T) + offset
     oneArray{T,N}(a.data, dims; a.maxsize, offset)
 end
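
As context for reviewers, a sketch of how the new src/array.jl API behaves. This is illustrative only: it assumes a working oneAPI device, and the commented results follow from the definitions above rather than from a test run.

    using oneAPI, oneAPI.oneL0

    a = oneVector{Float32,oneL0.SharedBuffer}(undef, 4)
    oneAPI.buftype(a)  # oneL0.SharedBuffer (buftype itself stays unexported)
    is_shared(a)       # true; is_device(a) and is_host(a) are false

    # pointer() now takes the target memory space as a keyword argument
    pointer(a)                            # ZePtr{Float32}, the device-side address
    pointer(a, 1; type=oneL0.HostBuffer)  # Ptr{Float32}, legal for host-accessible buffers

    # device-only buffers still refuse to hand out a CPU address
    d = oneVector{Float32,oneL0.DeviceBuffer}(undef, 4)
    Base.unsafe_convert(Ptr{Float32}, d)  # throws ArgumentError
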
26 changes: 14 additions & 12 deletions src/broadcast.jl
@@ -2,18 +2,20 @@
 
 using Base.Broadcast: BroadcastStyle, Broadcasted
 
-struct oneArrayStyle{N} <: AbstractGPUArrayStyle{N} end
-oneArrayStyle(::Val{N}) where N = oneArrayStyle{N}()
-oneArrayStyle{M}(::Val{N}) where {N,M} = oneArrayStyle{N}()
+struct oneArrayStyle{N,B} <: AbstractGPUArrayStyle{N} end
+oneArrayStyle{M,B}(::Val{N}) where {N,M,B} = oneArrayStyle{N,B}()
 
-BroadcastStyle(::Type{<:oneArray{T,N}}) where {T,N} = oneArrayStyle{N}()
+# identify the broadcast style of a (wrapped) oneArray
+BroadcastStyle(::Type{<:oneArray{T,N,B}}) where {T,N,B} = oneArrayStyle{N,B}()
+BroadcastStyle(W::Type{<:oneWrappedArray{T,N}}) where {T,N} =
+    oneArrayStyle{N, buftype(Adapt.unwrap_type(W))}()
 
-Base.similar(bc::Broadcasted{oneArrayStyle{N}}, ::Type{T}) where {N,T} =
-    similar(oneArray{T}, axes(bc))
+# when we are dealing with different buffer styles, we cannot know
+# which one is better, so use shared memory
+BroadcastStyle(::oneArrayStyle{N, B1},
+               ::oneArrayStyle{N, B2}) where {N,B1,B2} =
+    oneArrayStyle{N, oneL0.SharedBuffer}()
 
-Base.similar(bc::Broadcasted{oneArrayStyle{N}}, ::Type{T}, dims...) where {N,T} =
-    oneArray{T}(undef, dims...)
-
-# broadcasting type ctors isn't GPU compatible
-Broadcast.broadcasted(::oneArrayStyle{N}, f::Type{T}, args...) where {N, T} =
-    Broadcasted{oneArrayStyle{N}}((x...) -> T(x...), args, nothing)
+# allocation of output arrays
+Base.similar(bc::Broadcasted{oneArrayStyle{N,B}}, ::Type{T}, dims) where {T,N,B} =
+    similar(oneArray{T,length(dims),B}, dims)
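
The net effect of the new style, sketched below (assumes a oneAPI runtime; this mirrors the semantics pinned down by the new testset in test/array.jl):

    using oneAPI, oneAPI.oneL0

    s = oneVector{Int,oneL0.SharedBuffer}([1, 2])
    d = oneVector{Int,oneL0.DeviceBuffer}([3, 4])

    # matching buffer types are preserved in the output
    oneAPI.buftype(s .+ s)  # oneL0.SharedBuffer
    oneAPI.buftype(d .+ d)  # oneL0.DeviceBuffer

    # mixed buffer types promote to shared memory, per the style rule above
    oneAPI.buftype(s .+ d)  # oneL0.SharedBuffer
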
14 changes: 12 additions & 2 deletions src/compiler/execution.jl
@@ -83,10 +83,15 @@ end
 
 struct KernelAdaptor end
 
-# convert oneL0 host pointers to device pointers
+# convert oneAPI host pointers to device pointers
 Adapt.adapt_storage(to::KernelAdaptor, p::ZePtr{T}) where {T} = reinterpret(Ptr{T}, p)
 
-# Base.RefValue isn't GPU compatible, so provide a compatible alternative
+# convert oneAPI host arrays to device arrays
+Adapt.adapt_storage(::KernelAdaptor, xs::oneArray{T,N}) where {T,N} =
+    Base.unsafe_convert(oneDeviceArray{T,N,AS.Global}, xs)
+
+# Base.RefValue isn't GPU compatible, so provide a compatible alternative.
+# TODO: port improvements from CUDA.jl
 struct ZeRefValue{T} <: Ref{T}
     x::T
 end
@@ -100,6 +105,11 @@ Base.getindex(r::oneRefType{T}) where T = T
 Adapt.adapt_structure(to::KernelAdaptor, r::Base.RefValue{<:Union{DataType,Type}}) =
     oneRefType{r[]}()
 
+# case where type is the function being broadcasted
+Adapt.adapt_structure(to::KernelAdaptor,
+                      bc::Broadcast.Broadcasted{Style, <:Any, Type{T}}) where {Style, T} =
+    Broadcast.Broadcasted{Style}((x...) -> T(x...), adapt(to, bc.args), bc.axes)
+
 """
     kernel_convert(x)
 
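
This replaces the eager Broadcast.broadcasted rewrite deleted from src/broadcast.jl: a broadcasted type constructor now reaches the KernelAdaptor intact and is only turned into an anonymous closure at kernel-conversion time. A minimal sketch (assumes a oneAPI device):

    using oneAPI

    a = oneArray([1.0, 2.0, 3.0])

    # Type objects aren't valid GPU kernel arguments; the adapt_structure
    # method above rewrites the Broadcasted object to call (x...) -> Float32(x...)
    b = Float32.(a)
    eltype(b)  # Float32
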
10 changes: 5 additions & 5 deletions src/oneAPI.jl
@@ -44,16 +44,16 @@ include("device/quirks.jl")
 
 # essential stuff
 include("context.jl")
 
-# compiler implementation
-include("compiler/compilation.jl")
-include("compiler/execution.jl")
-include("compiler/reflection.jl")
-
 # array abstraction
 include("memory.jl")
 include("pool.jl")
 include("array.jl")
 
+# compiler implementation
+include("compiler/compilation.jl")
+include("compiler/execution.jl")
+include("compiler/reflection.jl")
+
 # array libraries
 include("../lib/mkl/oneMKL.jl")
 export oneMKL
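
(The reorder matters because the oneArray-to-oneDeviceArray conversion moved from src/array.jl into src/compiler/execution.jl above: the compiler files now reference the array types, so the array abstraction has to be loaded first.)
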
13 changes: 10 additions & 3 deletions src/pool.jl
@@ -18,13 +18,20 @@ function allocate(::Type{oneL0.SharedBuffer}, ctx, dev, bytes::Int, alignment::I
     return buf
 end
 
+function allocate(::Type{oneL0.HostBuffer}, ctx, dev, bytes::Int, alignment::Int)
+    bytes == 0 && return oneL0.HostBuffer(ZE_NULL, bytes, ctx)
+    host_alloc(ctx, bytes, alignment)
+end
+
 function release(buf::oneL0.AbstractBuffer)
     sizeof(buf) == 0 && return
 
-    ctx = oneL0.context(buf)
-    dev = oneL0.device(buf)
+    if buf isa oneL0.DeviceBuffer || buf isa oneL0.SharedBuffer
+        ctx = oneL0.context(buf)
+        dev = oneL0.device(buf)
+        evict(ctx, dev, buf)
+    end
 
-    evict(ctx, dev, buf)
     free(buf; policy=oneL0.ZE_DRIVER_MEMORY_FREE_POLICY_EXT_FLAG_BLOCKING_FREE)
 
     # TODO: queue-ordered free from non-finalizer tasks once we have
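
A sketch of what the new allocation path enables (assumes a oneAPI runtime; the copyto! call is meant to exercise the new host-accessible copy methods from src/array.jl):

    using oneAPI, oneAPI.oneL0

    # backed by the new allocate(::Type{oneL0.HostBuffer}, ...) method above
    h = oneVector{Float32,oneL0.HostBuffer}(undef, 16)
    is_host(h)  # true

    # goes through the new Array <-> host-accessible oneArray copy path
    copyto!(h, rand(Float32, 16))

    # when h is finalized, release() skips evict(), since HostBuffers
    # have no device, and calls free() directly
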
16 changes: 16 additions & 0 deletions test/array.jl
@@ -58,3 +58,19 @@ end
     oneAPI.@sync copyto!(a, 2, [200], 1, 1)
     @test b == [100, 200]
 end
+
+# https://github.com/JuliaGPU/CUDA.jl/issues/2191
+@testset "preserving buffer types" begin
+    a = oneVector{Int,oneL0.SharedBuffer}([1])
+    @test oneAPI.buftype(a) == oneL0.SharedBuffer
+
+    # unified-ness should be preserved
+    b = a .+ 1
+    @test oneAPI.buftype(b) == oneL0.SharedBuffer
+
+    # when there's a conflict, we should defer to unified memory
+    c = oneVector{Int,oneL0.HostBuffer}([1])
+    d = oneVector{Int,oneL0.DeviceBuffer}([1])
+    e = c .+ d
+    @test oneAPI.buftype(e) == oneL0.SharedBuffer
+end