diff --git a/src/abstract_axis.jl b/src/abstract_axis.jl index c68deb8..b94efea 100644 --- a/src/abstract_axis.jl +++ b/src/abstract_axis.jl @@ -63,6 +63,7 @@ Base.pairs(axis::AbstractAxis) = Base.Iterators.Pairs(a, keys(axis)) # This is required for performing `similar` on arrays Base.to_shape(axis::AbstractAxis) = length(axis) +Base.to_shape(r::IdentityUnitRange) = length(r) Base.haskey(axis::AbstractAxis, key) = key in keys(axis) diff --git a/src/abstractarray.jl b/src/abstractarray.jl index e5ac05c..b9c2ae3 100644 --- a/src/abstractarray.jl +++ b/src/abstractarray.jl @@ -175,7 +175,7 @@ function Base.similar(A::AxisArray) return unsafe_reconstruct(A, p; axes=map(assign_indices, axes(A), axes(p))) end -function Base.similar(::Type{T}, shape::Tuple{DimOrAxes,Vararg{DimOrAxes}}) where {T<:AbstractArray} +function Base.similar(::Type{T}, shape::S) where {T<:AbstractArray, DoA<:DimOrAxes, S<:Tuple{DoA,Vararg{DoA}}} p = similar(T, Base.to_shape(shape)) axs = map((key, axis) -> compose_axis(key, axis, NoChecks), shape, axes(p)) return AxisArray{eltype(p),ndims(p),typeof(p),typeof(axs)}(p, axs; checks=NoChecks)