diff --git a/src/array.jl b/src/array.jl index c76045480..e204a8594 100644 --- a/src/array.jl +++ b/src/array.jl @@ -3,6 +3,16 @@ mutable struct ROCArray{T, N, B} <: AbstractGPUArray{T, N} dims::Dims{N} offset::Int # Offset is in number of elements (not bytes). + function ROCArray{T, N, B}( + ::UndefInitializer, dims::Dims{N}, + ) where {T, N, B <: Union{Mem.HIPBuffer, Mem.HostBuffer}} + @assert isbitstype(T) "ROCArray only supports bits types" + buf = B(prod(dims) * sizeof(T); stream=stream()) + xs = new{T, N, B}(DataRef(_free_buf, buf), dims, 0) + finalizer(unsafe_finalize!, xs) + return xs + end + function ROCArray{T, N}( buf::DataRef{B}, dims::Dims{N}; offset::Integer = 0, ) where {T, N, B <: Union{Mem.HIPBuffer, Mem.HostBuffer}} @@ -40,8 +50,9 @@ const DenseROCMatrix{T} = DenseROCArray{T,2} const DenseROCVecOrMat{T} = Union{DenseROCVector{T}, DenseROCMatrix{T}} # strided arrays -const StridedSubROCArray{T,N,I<:Tuple{Vararg{Union{Base.RangeIndex, Base.ReshapedUnitRange, - Base.AbstractCartesianIndex}}}} = SubArray{T,N,<:ROCArray,I} +const StridedSubROCArray{T,N,I<:Tuple{Vararg{Union{ + Base.RangeIndex, Base.ReshapedUnitRange, Base.AbstractCartesianIndex, +}}}} = SubArray{T,N,<:ROCArray,I} const StridedROCArray{T,N} = Union{ROCArray{T,N}, StridedSubROCArray{T,N}} const StridedROCVector{T} = StridedROCArray{T,1} const StridedROCMatrix{T} = StridedROCArray{T,2} @@ -61,6 +72,12 @@ function ROCArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N} ROCArray{T, N}(DataRef(_free_buf, buf), dims) end +# buffer, type and dimensionality specified +ROCArray{T,N,B}(::UndefInitializer, dims::NTuple{N, Integer}) where {T,N,B} = + ROCArray{T,N,B}(undef, convert(Tuple{Vararg{Int}}, dims)) +ROCArray{T,N,B}(::UndefInitializer, dims::Vararg{Integer, N}) where {T,N,B} = + ROCArray{T,N,B}(undef, convert(Tuple{Vararg{Int}}, dims)) + # type and dimensionality specified ROCArray{T,N}(::UndefInitializer, dims::NTuple{N, Integer}) where {T,N} = ROCArray{T,N}(undef, convert(Tuple{Vararg{Int}}, dims)) diff --git a/test/rocarray/base.jl b/test/rocarray/base.jl index 1cead284c..f5124aa1a 100644 --- a/test/rocarray/base.jl +++ b/test/rocarray/base.jl @@ -1,5 +1,15 @@ @testset "Base" begin +@testset "Specifying buffer type" begin + B = AMDGPU.Runtime.Mem.HIPBuffer + x = ROCArray{Float32, 2, B}(undef, 16, 12) + @test size(x) == (16, 12) + @test x.buf[] isa B + x = ROCArray{Float32, 2, B}(undef, (16, 12)) + @test size(x) == (16, 12) + @test x.buf[] isa B +end + @testset "ones/zeros" begin x = @inferred AMDGPU.ones(4, 3) @test x isa ROCArray