Skip to content

Commit

Permalink
Adapt JLArray to changes.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Oct 30, 2018
1 parent e710b47 commit 424a29b
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 23 deletions.
75 changes: 58 additions & 17 deletions src/array.jl
Original file line number Diff line number Diff line change
@@ -1,37 +1,78 @@
# Very simple Julia back-end which is just for testing the implementation and can be used as
# a reference implementation


## construction

struct JLArray{T, N} <: GPUArray{T, N}
data::Array{T, N}
size::Dims{N}
dims::Dims{N}

function JLArray{T,N}(data::Array{T, N}, size::NTuple{N, Int}) where {T,N}
new(data, size)
function JLArray{T,N}(data::Array{T, N}, dims::Dims{N}) where {T,N}
new(data, dims)
end
end

JLArray(data::AbstractArray{T, N}, size::Dims{N}) where {T,N} = JLArray{T,N}(data, size)
Base.pointer(x::JLArray) = pointer(x.data)

(::Type{<: JLArray{T}})(x::AbstractArray) where T = JLArray(convert(Array{T}, x), size(x))

function JLArray{T, N}(size::NTuple{N, Integer}) where {T, N}
JLArray{T, N}(Array{T, N}(undef, size), size)
end
## construction

struct JLBackend <: GPUBackend end
backend(::Type{<:JLArray}) = JLBackend()
# type and dimensionality specified, accepting dims as tuples of Ints
JLArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N} =
JLArray{T,N}(Array{T, N}(undef, dims), dims)

## getters
# type and dimensionality specified, accepting dims as series of Ints
JLArray{T,N}(::UndefInitializer, dims::Integer...) where {T,N} = JLArray{T,N}(undef, dims)

Base.size(x::JLArray) = x.size
# type but not dimensionality specified
JLArray{T}(::UndefInitializer, dims::Dims{N}) where {T,N} = JLArray{T,N}(undef, dims)
JLArray{T}(::UndefInitializer, dims::Integer...) where {T} =
JLArray{T}(undef, convert(Tuple{Vararg{Int}}, dims))

# empty vector constructor
JLArray{T,1}() where {T} = JLArray{T,1}(undef, 0)


Base.similar(a::JLArray{T,N}) where {T,N} = JLArray{T,N}(undef, size(a))
Base.similar(a::JLArray{T}, dims::Base.Dims{N}) where {T,N} = JLArray{T,N}(undef, dims)
Base.similar(a::JLArray, ::Type{T}, dims::Base.Dims{N}) where {T,N} = JLArray{T,N}(undef, dims)


## array interface

Base.size(x::JLArray) = x.dims

Base.pointer(x::JLArray) = pointer(x.data)

## interop with other arrays

## other
JLArray{T,N}(x::AbstractArray{S,N}) where {T,N,S} =
JLArray{T,N}(convert(Array{T}, x), size(x))

# underspecified constructors
JLArray{T}(xs::AbstractArray{S,N}) where {T,N,S} = JLArray{T,N}(xs)
(::Type{JLArray{T,N} where T})(x::AbstractArray{S,N}) where {S,N} = JLArray{S,N}(x)
JLArray(A::AbstractArray{T,N}) where {T,N} = JLArray{T,N}(A)

# idempotency
JLArray{T,N}(xs::JLArray{T,N}) where {T,N} = xs


## conversions

Base.convert(::Type{T}, x::T) where T <: JLArray = x


## broadcast

BroadcastStyle(::Type{<:JLArray}) = ArrayStyle{JLArray}()

function Base.similar(bc::Broadcasted{ArrayStyle{JLArray}}, ::Type{T}) where T
similar(JLArray{T}, axes(bc))
end


## gpuarray interface

struct JLBackend <: GPUBackend end
backend(::Type{<:JLArray}) = JLBackend()

"""
Thread group local memory
Expand Down
6 changes: 0 additions & 6 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,6 @@ const GPUDestArray = Union{GPUArray,
LinearAlgebra.Adjoint{<:Any,<:GPUArray},
SubArray{<:Any,<:Any,<:GPUArray}}

# This method is responsible for selection the output type of broadcast
function Base.similar(bc::Broadcasted{<:ArrayStyle{GPU}}, ::Type{ElType}) where
{GPU <: GPUArray, ElType}
similar(GPU, ElType, axes(bc))
end

# We purposefully only specialize `copyto!`, dependent packages need to make sure that they
# can handle:
# - `bc::Broadcast.Broadcasted{Style}`
Expand Down

0 comments on commit 424a29b

Please sign in to comment.