Skip to content

Commit

Permalink
Fix for iterators with shape
Browse files Browse the repository at this point in the history
  • Loading branch information
blegat committed Sep 24, 2019
1 parent c771e5d commit b6b9fa3
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 12 deletions.
2 changes: 2 additions & 0 deletions src/Containers/Containers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ export DenseAxisArray, SparseAxisArray
include("DenseAxisArray.jl")
include("SparseAxisArray.jl")
include("generate_container.jl")
include("vectorized_product_iterator.jl")
include("nested_iterator.jl")
include("container.jl")
include("macro.jl")

Expand Down
2 changes: 0 additions & 2 deletions src/Containers/SparseAxisArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

include("nested_iterator.jl")

"""
struct SparseAxisArray{T,N,K<:NTuple{N, Any}} <: AbstractArray{T,N}
data::Dict{K,T}
Expand Down
12 changes: 6 additions & 6 deletions src/Containers/container.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
const ArrayIndices{N} = Iterators.ProductIterator{NTuple{N, Base.OneTo{Int}}}
const ArrayIndices{N} = VectorizedProductIterator{NTuple{N, Base.OneTo{Int}}}
container(f::Function, indices) = container(f, indices, default_container(indices))
default_container(::ArrayIndices) = Array
function container(f::Function, indices::ArrayIndices, ::Type{Array})
Expand All @@ -10,14 +10,14 @@ function _oneto(indices)
end
error("Index set for array is not one-based interval.")
end
function container(f::Function, indices::Iterators.ProductIterator,
function container(f::Function, indices::VectorizedProductIterator,
::Type{Array})
container(f, Iterators.ProductIterator(_oneto.(indices.iterators)), Array)
container(f, vectorized_product(_oneto.(indices.prod.iterators)...), Array)
end
default_container(::Iterators.ProductIterator) = DenseAxisArray
function container(f::Function, indices::Iterators.ProductIterator,
default_container(::VectorizedProductIterator) = DenseAxisArray
function container(f::Function, indices::VectorizedProductIterator,
::Type{DenseAxisArray})
return DenseAxisArray(map(I -> f(I...), indices), indices.iterators...)
return DenseAxisArray(map(I -> f(I...), indices), indices.prod.iterators...)
end
default_container(::NestedIterator) = SparseAxisArray
function container(f::Function, indices,
Expand Down
6 changes: 3 additions & 3 deletions src/Containers/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,13 @@ function _build_ref_sets(_error::Function, expr)
esc_idxvars = esc.(idxvars)
idxfuns = [:(($(esc_idxvars[1:(i - 1)]...),) -> $(idxsets[i])) for i in 1:length(idxvars)]
if condition == :()
indices = :(Containers.NestedIterator(($(idxfuns...),)))
indices = :(Containers.nested($(idxfuns...)))
else
condition_fun = :(($(esc_idxvars...),) -> $(esc(condition)))
indices = :(Containers.NestedIterator(($(idxfuns...),), $condition_fun))
indices = :(Containers.nested($(idxfuns...); condition = $condition_fun))
end
else
indices = :(Base.Iterators.product(($(_explicit_oneto.(idxsets)...))))
indices = :(Containers.vectorized_product($(_explicit_oneto.(idxsets)...)))
end
return idxvars, indices
end
Expand Down
4 changes: 3 additions & 1 deletion src/Containers/nested_iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ struct NestedIterator{T}
iterators::T # Tuple of functions
condition::Function
end
NestedIterator(iterator) = NestedIterator(iterator, (args...) -> true)
function nested(iterators...; condition = (args...) -> true)
return NestedIterator(iterators, condition)
end
Base.IteratorSize(::Type{<:NestedIterator}) = Base.SizeUnknown()
Base.IteratorEltype(::Type{<:NestedIterator}) = Base.EltypeUnknown()
function next_iterate(it::NestedIterator, i, elems, states, iterator, elem_state)
Expand Down
6 changes: 6 additions & 0 deletions test/Containers/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ using JuMP.Containers
@test x isa Containers.DenseAxisArray{Int, 1}
Containers.@container(x[i = 2:3, j = 1:2], i + j)
@test x isa Containers.DenseAxisArray{Int, 2}
Containers.@container(x[4], 0.0)
@test x isa Containers.DenseAxisArray{Float64, 1}
Containers.@container(x[4, 5], 0)
@test x isa Containers.DenseAxisArray{Int, 2}
Containers.@container(x[4, 1:3, 5], 0)
@test x isa Containers.DenseAxisArray{Int, 3}
end
@testset "SparseAxisArray" begin
Containers.@container(x[i = 1:3, j = 1:i], i + j)
Expand Down

0 comments on commit b6b9fa3

Please sign in to comment.