-
Notifications
You must be signed in to change notification settings - Fork 3
/
FixedSizeArrays.jl
319 lines (260 loc) · 10.3 KB
/
FixedSizeArrays.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
module FixedSizeArrays
export FixedSizeArray, FixedSizeVector, FixedSizeMatrix
public collect_as
"""
Internal()
Implementation detail. Do not use.
"""
struct Internal end
struct FixedSizeArray{T,N} <: DenseArray{T,N}
mem::Memory{T}
size::NTuple{N,Int}
function FixedSizeArray{T,N}(::Internal, mem::Memory{T}, size::NTuple{N,Int}) where {T,N}
new{T,N}(mem, size)
end
end
const FixedSizeVector{T} = FixedSizeArray{T,1}
const FixedSizeMatrix{T} = FixedSizeArray{T,2}
# `undef` constructors. Every variant funnels into the first method, which
# validates the shape with `checked_dims` before allocating the backing memory.
FixedSizeArray{T,N}(::UndefInitializer, size::NTuple{N,Int}) where {T,N} =
    FixedSizeArray{T,N}(Internal(), Memory{T}(undef, checked_dims(size)), size)

# Non-`Int` integer dimensions are converted to `Int` first. The type assertion
# guarantees that the recursive call selects the `NTuple{N,Int}` method above
# instead of coming back here.
function FixedSizeArray{T,N}(::UndefInitializer, size::NTuple{N,Integer}) where {T,N}
    int_size = map(Int, size)::NTuple{N,Int}
    return FixedSizeArray{T,N}(undef, int_size)
end

# Dimensions passed as separate arguments, with `N` given explicitly.
FixedSizeArray{T,N}(::UndefInitializer, size::Vararg{Integer,N}) where {T,N} =
    FixedSizeArray{T,N}(undef, size)

# `N` inferred from the number of dimension arguments.
FixedSizeArray{T}(::UndefInitializer, size::Vararg{Integer,N}) where {T,N} =
    FixedSizeArray{T,N}(undef, size)

# `N` inferred from the length of the dimensions tuple.
FixedSizeArray{T}(::UndefInitializer, size::NTuple{N,Integer}) where {T,N} =
    FixedSizeArray{T,N}(undef, size)
# Core `AbstractArray` interface. Elements live in the flat `mem` field, so a
# single linear index addresses them directly.
Base.IndexStyle(::Type{<:FixedSizeArray}) = IndexLinear()

Base.@propagate_inbounds function Base.getindex(A::FixedSizeArray, i::Int)
    return A.mem[i]
end

# Returns the stored value `v`, which is what the assignment form evaluates to.
Base.@propagate_inbounds function Base.setindex!(A::FixedSizeArray, v, i::Int)
    A.mem[i] = v
    return v
end

function Base.size(a::FixedSizeArray)
    return a.size
end

# Fresh uninitialized `FixedSizeArray` with element type `S` and shape `size`.
function Base.similar(::FixedSizeArray, ::Type{S}, size::NTuple{N,Int}) where {S,N}
    return FixedSizeArray{S,N}(undef, size)
end

function Base.isassigned(a::FixedSizeArray, i::Int)
    return isassigned(a.mem, i)
end
"""
    checked_dims_impl(acc::Int, dims, overflowed::Bool) -> (product, overflowed)

Multiply `acc` by each element of `dims` from left to right using
`Base.Checked.mul_with_overflow`. Return the resulting (possibly wrapped)
product, and whether any step overflowed, OR-ed with the incoming `overflowed`.
"""
checked_dims_impl(a::Int, ::Tuple{}, have_overflow::Bool) = (a, have_overflow)
function checked_dims_impl(a::Int, t::Tuple{Int,Vararg{Int,N}}, have_overflow::Bool) where {N}
    (partial, step_overflowed) = Base.Checked.mul_with_overflow(a, first(t))
    remaining = Base.tail(t)::NTuple{N,Int}
    return checked_dims_impl(partial, remaining, have_overflow | step_overflowed)::Tuple{Int,Bool}
end

"""
    checked_dims(dims::Tuple{Vararg{Int}}) -> Int

Return the product of `dims`, i.e. the element count of an array with that
shape. The empty tuple gives 1. Throw an `ArgumentError` if any dimension is
negative, if any dimension equals `typemax(Int)`, or if the product overflows
`Int`. An overflow is tolerated when some dimension is zero, because the product
is then zero.
"""
checked_dims(::Tuple{}) = 1
function checked_dims(t::Tuple{Int,Vararg{Int,N}}) where {N}
    remaining = Base.tail(t)::NTuple{N,Int}
    (product, have_overflow) = checked_dims_impl(first(t), remaining, false)::Tuple{Int,Bool}
    if any(<(0), t)::Bool
        throw(ArgumentError("array dimension size can't be negative"))
    elseif any(==(typemax(Int)), t)::Bool
        throw(ArgumentError("array dimension size can't be the maximum representable value"))
    elseif have_overflow && !(any(iszero, t)::Bool)
        throw(ArgumentError("array dimensions too great, can't represent length"))
    end
    return product
end
# Broadcasting: `FixedSizeArray` has its own broadcast style, so destinations
# for broadcast results are allocated as `FixedSizeArray`s by the method below.
Base.BroadcastStyle(::Type{<:FixedSizeArray}) = Broadcast.ArrayStyle{FixedSizeArray}()

function Base.similar(
    bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{FixedSizeArray}},
    ::Type{E},
) where {E}
    dest_axes = axes(bc)
    return similar(FixedSizeArray{E}, dest_axes)
end
# helper functions

# Dimension count implied by an iterator's `IteratorSize` trait. Iterators
# without a shape count as one-dimensional.
dimension_count_of(::Union{Base.SizeUnknown,Base.HasLength}) = 1
dimension_count_of(::Base.HasShape{N}) where {N} = convert(Int, N)::Int

# Singleton tags: whether the iterator's length is known before iterating.
struct LengthIsUnknown end
struct LengthIsKnown end

length_status(::Base.SizeUnknown) = LengthIsUnknown()
length_status(::Union{Base.HasLength,Base.HasShape}) = LengthIsKnown()
"""
    check_count_value(n)

Validate a dimension count. Return `nothing` when `n` is a non-negative `Int`,
and throw an `ArgumentError` otherwise.
"""
function check_count_value(n::Int)
    n >= 0 || throw(ArgumentError("count can't be negative"))
    return nothing
end
check_count_value(::Any) = throw(ArgumentError("count must be an `Int`"))
# Type-level summary of a requested `FixedSizeArray` type: element type `T` and
# dimension count `N`. `nothing` in either slot means "not specified".
struct SpecFSA{T,N} end

# Map a fully or partially parameterized `FixedSizeArray` type to its spec,
# validating the dimension count whenever one is given.
fsa_spec_from_type(::Type{FixedSizeArray}) = SpecFSA{nothing,nothing}()

function fsa_spec_from_type(::Type{FixedSizeArray{<:Any,N}}) where {N}
    check_count_value(N)
    return SpecFSA{nothing,N}()
end

fsa_spec_from_type(::Type{FixedSizeArray{T}}) where {T} = SpecFSA{T::Type,nothing}()

function fsa_spec_from_type(::Type{FixedSizeArray{T,N}}) where {T,N}
    check_count_value(N)
    return SpecFSA{T::Type,N}()
end
# Backing-memory type of a `FixedSizeArray` type with element type `T`.
function parent_type(::Type{<:FixedSizeArray{T}}) where {T}
    return Memory{T}
end

# Flat storage of a value: a `FixedSizeArray`'s `mem` field, anything else as is.
underlying_storage(x) = x
underlying_storage(arr::FixedSizeArray) = arr.mem

# Whether every axis in `axes` starts at index 1. This is vacuously true when
# there are no axes.
axes_are_one_based(axes) = all(ax -> isone(first(ax)), axes)
# converting constructors for copying other array types
function FixedSizeArray{T,N}(src::AbstractArray{S,N}) where {T,N,S}
    src_axes = axes(src)
    axes_are_one_based(src_axes) ||
        throw(DimensionMismatch("source array has a non-one-based indexing axis"))
    # The shape is derived from the axes rather than from `Base.size` because,
    # according to its docstring, `size` is not available for every
    # `AbstractArray` type.
    dims = map(length, src_axes)
    dst = FixedSizeArray{T,N}(undef, dims)
    copyto!(dst.mem, src)
    return dst
end

# Shorthand forms that take the missing `T` and/or `N` from the source array.
function FixedSizeArray{T}(a::AbstractArray{<:Any,N}) where {T,N}
    return FixedSizeArray{T,N}(a)
end
function FixedSizeArray{<:Any,N}(a::AbstractArray{T,N}) where {T,N}
    return FixedSizeArray{T,N}(a)
end
function FixedSizeArray(a::AbstractArray{T,N}) where {T,N}
    return FixedSizeArray{T,N}(a)
end
# conversion: a value that already has the target type is returned unchanged;
# any other array goes through the converting constructor.
function Base.convert(::Type{T}, a::T) where {T<:FixedSizeArray}
    return a
end
function Base.convert(::Type{T}, a::AbstractArray) where {T<:FixedSizeArray}
    converted = T(a)
    return converted::T
end
# `copyto!` helpers: copy between the flat storages of the arguments (see
# `underlying_storage`), then return the destination itself, not its storage.
Base.@propagate_inbounds function copyto5!(dst, doff, src, soff, n)
    dst_storage = underlying_storage(dst)
    src_storage = underlying_storage(src)
    copyto!(dst_storage, doff, src_storage, soff, n)
    return dst
end
Base.@propagate_inbounds function copyto2!(dst, src)
    dst_storage = underlying_storage(dst)
    src_storage = underlying_storage(src)
    copyto!(dst_storage, src_storage)
    return dst
end
Base.@propagate_inbounds function Base.copyto!(
    dst::FixedSizeArray, doff::Integer, src::FixedSizeArray, soff::Integer, n::Integer,
)
    return copyto5!(dst, doff, src, soff, n)
end
Base.@propagate_inbounds function Base.copyto!(dst::FixedSizeArray, src::FixedSizeArray)
    return copyto2!(dst, src)
end

# Mixed combinations with other storage types. These are listed one concrete
# type at a time rather than for a broad supertype to avoid method
# ambiguities; extend the tuple if more types are needed.
for Other in (Array, GenericMemory)
    @eval Base.@propagate_inbounds Base.copyto!(dst::FixedSizeArray, doff::Integer, src::$Other, soff::Integer, n::Integer) = copyto5!(dst, doff, src, soff, n)
    @eval Base.@propagate_inbounds Base.copyto!(dst::$Other, doff::Integer, src::FixedSizeArray, soff::Integer, n::Integer) = copyto5!(dst, doff, src, soff, n)
    @eval Base.@propagate_inbounds Base.copyto!(dst::FixedSizeArray, src::$Other) = copyto2!(dst, src)
    @eval Base.@propagate_inbounds Base.copyto!(dst::$Other, src::FixedSizeArray) = copyto2!(dst, src)
end
# Pointer conversion hands over the backing `Memory`, so the native address
# used is that of the array's storage (unsafe, as with any raw pointer).
function Base.cconvert(::Type{<:Ptr}, a::FixedSizeArray)
    return a.mem
end

# `elsize` is part of the strided-arrays interface and is used by `pointer`.
# The element size is taken from the backing `Memory` type.
function Base.elsize(::Type{A}) where {A<:FixedSizeArray}
    return Base.elsize(parent_type(A))
end
# `reshape`: specialized so that the result is again a `FixedSizeArray`. The
# result wraps the same `mem` as `a`, so no elements are copied.
function Base.reshape(a::FixedSizeArray{T}, size::NTuple{N,Int}) where {T,N}
    checked_dims(size) == length(a) ||
        throw(DimensionMismatch("new shape not consistent with existing array length"))
    return FixedSizeArray{T,N}(Internal(), a.mem, size)
end
# `collect_as`

# Zero-dimensional case: the iterator must produce exactly one element, which
# `only` enforces. With no requested element type, the element's own type is
# used.
function collect_as_fsa0(iterator, ::Val{nothing})
    elem = only(iterator)
    out = FixedSizeArray{typeof(elem),0}(undef)
    out[] = elem
    return out
end
function collect_as_fsa0(iterator, ::Val{E}) where {E}
    E::Type
    elem = only(iterator)
    out = FixedSizeArray{E,0}(undef)
    out[] = elem
    return out
end
"""
    fill_fsa_from_iterator!(a, iterator)

Write the elements produced by `iterator` into `a` in iteration order, at
linear indices `1, 2, …`. The iterator must produce exactly `length(a)`
elements; otherwise an `ArgumentError` is thrown.

If the iterator produces too many elements, the error is raised before any
out-of-range write is attempted. Too many and too few elements therefore
report the same exception type, instead of the too-many case surfacing as a
`BoundsError` from the underlying storage. Returns `nothing`.
"""
function fill_fsa_from_iterator!(a, iterator)
    expected_count = length(a)
    actual_count = 0
    for e ∈ iterator
        actual_count += 1
        if actual_count > expected_count
            throw(ArgumentError("`size`-`length` inconsistency"))
        end
        a[actual_count] = e
    end
    if actual_count != expected_count
        throw(ArgumentError("`size`-`length` inconsistency"))
    end
    return nothing
end
# Fill a new `M`-dimensional array of the given shape from `iterator`.
# Element type not requested: storage is allocated with `eltype(iterator)`, then
# `map(identity, ...)` recomputes the element type from the stored values.
function collect_as_fsam_with_shape(
    iterator, ::SpecFSA{nothing,M}, shape::Tuple{Vararg{Int}},
) where {M}
    elem_type = eltype(iterator)::Type
    buffer = FixedSizeArray{elem_type,M}(undef, shape)
    fill_fsa_from_iterator!(buffer, iterator)
    return map(identity, buffer)::FixedSizeArray{<:Any,M}
end
# Element type requested: fill storage of exactly that element type.
function collect_as_fsam_with_shape(
    iterator, ::SpecFSA{E,M}, shape::Tuple{Vararg{Int}},
) where {E,M}
    E::Type
    buffer = FixedSizeArray{E,M}(undef, shape)
    fill_fsa_from_iterator!(buffer, iterator)
    return buffer::FixedSizeArray{E,M}
end
# Known-length case for `M` dimensions: determine the shape, then fill.
function collect_as_fsam(iterator, spec::SpecFSA{<:Any,M}) where {M}
    check_count_value(M)
    # One-dimensional: only `length` is queried. Otherwise: `size` is used.
    raw_shape = (isone(M) ? (length(iterator),) : size(iterator))::NTuple{M,Any}
    int_shape = map(Int, raw_shape)::NTuple{M,Int}
    return collect_as_fsam_with_shape(iterator, spec, int_shape)::FixedSizeArray{<:Any,M}
end
# Unknown-length (one-dimensional) case: gather the elements with `collect`
# first, then copy them into a `FixedSizeVector`.
function collect_as_fsa1_from_unknown_length(iterator, ::Val{nothing})
    elems = collect(iterator)::AbstractVector
    # `map(identity, ...)` recomputes the element type from the values.
    return map(identity, FixedSizeVector(elems))::FixedSizeVector
end
function collect_as_fsa1_from_unknown_length(iterator, ::Val{E}) where {E}
    E::Type
    elems = collect(E, iterator)::AbstractVector{E}
    return FixedSizeVector{E}(elems)::FixedSizeVector{E}
end
# Choose the collection strategy from the resolved dimension count and from
# whether the iterator's length is known in advance.
collect_as_fsa_impl(iterator, ::SpecFSA{E,0}, ::LengthIsKnown) where {E} =
    collect_as_fsa0(iterator, Val(E))::FixedSizeArray{<:Any,0}

collect_as_fsa_impl(iterator, spec::SpecFSA, ::LengthIsKnown) =
    collect_as_fsam(iterator, spec)::FixedSizeArray

collect_as_fsa_impl(iterator, ::SpecFSA{E,1}, ::LengthIsUnknown) where {E} =
    collect_as_fsa1_from_unknown_length(iterator, Val(E))::FixedSizeVector
# Reconcile the requested dimension count with the count `M` implied by the
# iterator. An unspecified count is taken to be `M`; a specified one must equal
# `M`, otherwise neither method applies.
function collect_as_fsa_checked(iterator, ::SpecFSA{E,nothing}, ::Val{M}, length_status) where {E,M}
    check_count_value(M)
    resolved = SpecFSA{E,M}()
    return collect_as_fsa_impl(iterator, resolved, length_status)::FixedSizeArray{<:Any,M}
end
function collect_as_fsa_checked(iterator, ::SpecFSA{E,M}, ::Val{M}, length_status) where {E,M}
    check_count_value(M)
    resolved = SpecFSA{E,M}()
    return collect_as_fsa_impl(iterator, resolved, length_status)::FixedSizeArray{<:Any,M}
end
"""
collect_as(t::Type{<:FixedSizeArray}, iterator)
Tries to construct a value of type `t` from the iterator `iterator`. The type `t`
must either be concrete, or a `UnionAll` without constraints.
"""
function collect_as(::Type{T}, iterator) where {T<:FixedSizeArray}
spec = fsa_spec_from_type(T)::SpecFSA
size_class = Base.IteratorSize(iterator)
if size_class == Base.IsInfinite()
throw(ArgumentError("iterator is infinite, can't fit infinitely many elements into a `FixedSizeArray`"))
end
dim_count_int = dimension_count_of(size_class)
check_count_value(dim_count_int)
dim_count = Val(dim_count_int)::Val
len_stat = length_status(size_class)
collect_as_fsa_checked(iterator, spec, dim_count, len_stat)::T
end
end # module FixedSizeArrays