Skip to content

Commit 019dcab

Browse files
authored
Merge 18687b2 into 4ee09d2
2 parents 4ee09d2 + 18687b2 commit 019dcab

File tree

7 files changed

+253
-17
lines changed

7 files changed

+253
-17
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <support@itensor.org> and contributors"]
4-
version = "0.7.20"
4+
version = "0.7.21"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/BlockArraysExtensions/BlockArraysExtensions.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,14 @@ const BlockIndexRangeSlices = BlockIndices{
167167
const BlockIndexVectorSlices = BlockIndices{
168168
<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexVector}}
169169
}
170+
const GenericBlockIndexVectorSlices = BlockIndices{
171+
<:BlockVector{<:GenericBlockIndex{1},<:Vector{<:BlockIndexVector}}
172+
}
170173
const SubBlockSliceCollection = Union{
171-
BlockIndexRangeSlice,BlockIndexRangeSlices,BlockIndexVectorSlices
174+
BlockIndexRangeSlice,
175+
BlockIndexRangeSlices,
176+
BlockIndexVectorSlices,
177+
GenericBlockIndexVectorSlices,
172178
}
173179

174180
# TODO: This is type piracy. This is used in `reindex` when making
@@ -392,6 +398,13 @@ function blockrange(
392398
return map(Block, blocks(r))
393399
end
394400

401+
function blockrange(
402+
axis::AbstractUnitRange,
403+
r::BlockVector{<:GenericBlockIndex{1},<:AbstractVector{<:BlockIndexVector}},
404+
)
405+
return map(Block, blocks(r))
406+
end
407+
395408
function blockrange(axis::AbstractUnitRange, r)
396409
return error("Slicing not implemented for range of type `$(typeof(r))`.")
397410
end

src/BlockArraysExtensions/blockedunitrange.jl

Lines changed: 60 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,15 @@ BlockArrays.blockindex(b::GenericBlockIndex{1}) = b.α[1]
244244
function GenericBlockIndex(indcs::Tuple{Vararg{GenericBlockIndex{1},N}}) where {N}
245245
GenericBlockIndex(block.(indcs), blockindex.(indcs))
246246
end
247+
248+
function Base.checkindex(
249+
::Type{Bool}, axis::AbstractBlockedUnitRange, ind::GenericBlockIndex{1}
250+
)
251+
return checkindex(Bool, axis, block(ind)) &&
252+
checkbounds(Bool, axis[block(ind)], blockindex(ind))
253+
end
254+
Base.to_index(i::GenericBlockIndex) = i
255+
247256
function print_tuple_elements(io::IO, @nospecialize(t))
248257
if !isempty(t)
249258
print(io, t[1])
@@ -261,6 +270,13 @@ function Base.show(io::IO, B::GenericBlockIndex)
261270
return nothing
262271
end
263272

273+
# https://github.com/JuliaArrays/BlockArrays.jl/blob/v1.6.3/src/views.jl#L31-L32
274+
_maybetail(::Tuple{}) = ()
275+
_maybetail(t::Tuple) = Base.tail(t)
276+
@inline function Base.to_indices(A, inds, I::Tuple{GenericBlockIndex{1},Vararg{Any}})
277+
return (inds[1][I[1]], to_indices(A, _maybetail(inds), Base.tail(I))...)
278+
end
279+
264280
using Base: @propagate_inbounds
265281
@propagate_inbounds function Base.getindex(b::AbstractVector, K::GenericBlockIndex{1})
266282
return b[Block(K.I[1])][K.α[1]]
@@ -276,35 +292,54 @@ end
276292
return b[GenericBlockIndex(tuple(K, J...))]
277293
end
278294

279-
function blockindextype(TB::Type{<:Integer}, TI::Vararg{Type{<:Integer},N}) where {N}
280-
return BlockIndex{N,NTuple{N,TB},Tuple{TI...}}
281-
end
282-
function blockindextype(TB::Type{<:Integer}, TI::Vararg{Type,N}) where {N}
283-
return GenericBlockIndex{N,NTuple{N,TB},Tuple{TI...}}
284-
end
295+
# Work around the fact that it is type piracy to define
296+
# `Base.getindex(a::Block, b...)`.
297+
_getindex(a::Block{N}, b::Vararg{Any,N}) where {N} = GenericBlockIndex(a, b)
298+
_getindex(a::Block{N}, b::Vararg{Integer,N}) where {N} = a[b...]
299+
# Fix ambiguity.
300+
_getindex(a::Block{0}) = a[]
301+
302+
## function blockindextype(TB::Type{<:Integer}, TI::Vararg{Type{<:Integer},N}) where {N}
303+
## return BlockIndex{N,NTuple{N,TB},Tuple{TI...}}
304+
## end
305+
## function blockindextype(TB::Type{<:Integer}, TI::Vararg{Type,N}) where {N}
306+
## return GenericBlockIndex{N,NTuple{N,TB},Tuple{TI...}}
307+
## end
285308

286-
struct BlockIndexVector{N,I<:NTuple{N,AbstractVector},TB<:Integer,BT} <: AbstractArray{BT,N}
309+
struct BlockIndexVector{N,BT,I<:NTuple{N,AbstractVector},TB<:Integer} <: AbstractArray{BT,N}
287310
block::Block{N,TB}
288311
indices::I
289-
function BlockIndexVector(
312+
function BlockIndexVector{N,BT}(
290313
block::Block{N,TB}, indices::I
291-
) where {N,I<:NTuple{N,AbstractVector},TB<:Integer}
292-
BT = blockindextype(TB, eltype.(indices)...)
293-
return new{N,I,TB,BT}(block, indices)
314+
) where {N,BT,I<:NTuple{N,AbstractVector},TB<:Integer}
315+
return new{N,BT,I,TB}(block, indices)
294316
end
295317
end
318+
function BlockIndexVector{1,BT}(block::Block{1}, indices::AbstractVector) where {BT}
319+
return BlockIndexVector{1,BT}(block, (indices,))
320+
end
321+
function BlockIndexVector(
322+
block::Block{N,TB}, indices::NTuple{N,AbstractVector}
323+
) where {N,TB<:Integer}
324+
BT = Base.promote_op(_getindex, typeof(block), eltype.(indices)...)
325+
return BlockIndexVector{N,BT}(block, indices)
326+
end
296327
function BlockIndexVector(block::Block{1}, indices::AbstractVector)
297328
return BlockIndexVector(block, (indices,))
298329
end
299330
Base.size(a::BlockIndexVector) = length.(a.indices)
300331
function Base.getindex(a::BlockIndexVector{N}, I::Vararg{Integer,N}) where {N}
301-
return a.block[map((r, i) -> r[i], a.indices, I)...]
332+
return _getindex(Block(a), getindex.(a.indices, I)...)
302333
end
303334
BlockArrays.block(b::BlockIndexVector) = b.block
304335
BlockArrays.Block(b::BlockIndexVector) = b.block
305336

306337
Base.copy(a::BlockIndexVector) = BlockIndexVector(a.block, copy.(a.indices))
307338

339+
function Base.getindex(b::AbstractBlockedUnitRange, Kkr::BlockIndexVector{1})
340+
b[block(Kkr)][Kkr.indices...]
341+
end
342+
308343
using ArrayLayouts: LayoutArray
309344
@propagate_inbounds Base.getindex(b::AbstractArray{T,N}, K::BlockIndexVector{N}) where {T,N} = b[block(
310345
K
@@ -316,6 +351,19 @@ using ArrayLayouts: LayoutArray
316351
K
317352
)][K.indices...]
318353

354+
function blockedunitrange_getindices(
355+
a::AbstractBlockedUnitRange,
356+
indices::BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexVector{1}}},
357+
)
358+
return mortar(map(b -> a[b], blocks(indices)))
359+
end
360+
function blockedunitrange_getindices(
361+
a::AbstractBlockedUnitRange,
362+
indices::BlockVector{<:GenericBlockIndex{1},<:Vector{<:BlockIndexVector{1}}},
363+
)
364+
return mortar(map(b -> a[b], blocks(indices)))
365+
end
366+
319367
function to_blockindices(a::AbstractBlockedUnitRange{<:Integer}, I::AbstractArray{Bool})
320368
I_blocks = blocks(BlockedVector(I, blocklengths(a)))
321369
I′_blocks = map(eachindex(I_blocks)) do b

src/abstractblocksparsearray/views.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,10 @@ end
9292
# TODO: Move to `GradedUnitRanges` or `BlockArraysExtensions`.
9393
to_block(I::Block{1}) = I
9494
to_block(I::BlockIndexRange{1}) = Block(I)
95-
to_block(I::BlockIndexVector) = Block(I)
95+
to_block(I::BlockIndexVector{1}) = Block(I)
9696
to_block_indices(I::Block{1}) = Colon()
9797
to_block_indices(I::BlockIndexRange{1}) = only(I.indices)
98-
to_block_indices(I::BlockIndexVector) = only(I.indices)
98+
to_block_indices(I::BlockIndexVector{1}) = only(I.indices)
9999

100100
function Base.view(
101101
a::AbstractBlockSparseArray{<:Any,N},
@@ -166,7 +166,7 @@ function BlockArrays.viewblock(
166166
<:AbstractBlockSparseArray{T,N},
167167
<:Tuple{Vararg{Union{BlockSliceCollection,SubBlockSliceCollection},N}},
168168
},
169-
block::Union{Block{N},BlockIndexRange{N}},
169+
block::Union{Block{N},BlockIndexRange{N},BlockIndexVector{N}},
170170
) where {T,N}
171171
return viewblock(a, to_tuple(block)...)
172172
end
@@ -228,6 +228,14 @@ function to_blockindexrange(
228228
# work right now.
229229
return blocks(a.blocks)[Int(I)]
230230
end
231+
function to_blockindexrange(
232+
a::BlockIndices{<:BlockVector{<:GenericBlockIndex{1},<:Vector{<:BlockIndexVector}}},
233+
I::Block{1},
234+
)
235+
# TODO: Ideally we would just use `a.blocks[I]` but that doesn't
236+
# work right now.
237+
return blocks(a.blocks)[Int(I)]
238+
end
231239
function to_blockindexrange(
232240
a::Base.Slice{<:AbstractBlockedUnitRange{<:Integer}}, I::Block{1}
233241
)

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,36 @@ function Base.to_indices(
7878
return @interface interface(a) to_indices(a, inds, I)
7979
end
8080

81+
# a[mortar([Block(1)[[1, 2]], Block(2)[[1, 3]]])]
82+
function Base.to_indices(
83+
a::AnyAbstractBlockSparseArray,
84+
inds,
85+
I::Tuple{BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexVector{1}}},Vararg{Any}},
86+
)
87+
return @interface interface(a) to_indices(a, inds, I)
88+
end
89+
function Base.to_indices(
90+
a::AnyAbstractBlockSparseArray,
91+
inds,
92+
I::Tuple{BlockVector{<:GenericBlockIndex{1},<:Vector{<:BlockIndexVector{1}}},Vararg{Any}},
93+
)
94+
return @interface interface(a) to_indices(a, inds, I)
95+
end
96+
8197
# a[[Block(1)[1:2], Block(2)[1:2]], [Block(1)[1:2], Block(2)[1:2]]]
8298
function Base.to_indices(
8399
a::AnyAbstractBlockSparseArray, inds, I::Tuple{Vector{<:BlockIndexRange{1}},Vararg{Any}}
84100
)
85101
return to_indices(a, inds, (mortar(I[1]), Base.tail(I)...))
86102
end
87103

104+
# a[[Block(1)[[1, 2]], Block(2)[[1, 2]]], [Block(1)[[1, 2]], Block(2)[[1, 2]]]]
105+
function Base.to_indices(
106+
a::AnyAbstractBlockSparseArray, inds, I::Tuple{Vector{<:BlockIndexVector{1}},Vararg{Any}}
107+
)
108+
return to_indices(a, inds, (mortar(I[1]), Base.tail(I)...))
109+
end
110+
88111
# BlockArrays `AbstractBlockArray` interface
89112
function BlockArrays.blocks(a::AnyAbstractBlockSparseArray)
90113
@interface interface(a) blocks(a)

src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,23 @@ end
229229
return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...)
230230
end
231231

232+
@interface ::AbstractBlockSparseArrayInterface function Base.to_indices(
233+
a,
234+
inds,
235+
I::Tuple{BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexVector{1}}},Vararg{Any}},
236+
)
237+
I1 = BlockIndices(I[1], blockedunitrange_getindices(inds[1], I[1]))
238+
return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...)
239+
end
240+
@interface ::AbstractBlockSparseArrayInterface function Base.to_indices(
241+
a,
242+
inds,
243+
I::Tuple{BlockVector{<:GenericBlockIndex{1},<:Vector{<:BlockIndexVector{1}}},Vararg{Any}},
244+
)
245+
I1 = BlockIndices(I[1], blockedunitrange_getindices(inds[1], I[1]))
246+
return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...)
247+
end
248+
232249
# a[BlockVector([Block(2), Block(1)], [2]), BlockVector([Block(2), Block(1)], [2])]
233250
# Permute and merge blocks.
234251
# TODO: This isn't merging blocks yet, that needs to be implemented that.

test/test_genericblockindex.jl

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
using BlockArrays: Block, BlockIndex, BlockedVector, block, blockindex
2+
using BlockSparseArrays: BlockSparseArrays, BlockIndexVector, GenericBlockIndex
3+
using Test: @test, @test_broken, @testset
4+
5+
# blockrange
6+
# checkindex
7+
# to_indices
8+
# to_index
9+
# blockedunitrange_getindices
10+
# viewblock
11+
# to_blockindexrange
12+
13+
@testset "GenericBlockIndex" begin
14+
i1 = GenericBlockIndex(Block(1), ("x",))
15+
i2 = GenericBlockIndex(Block(2), ("y",))
16+
i = GenericBlockIndex(Block(1, 2), ("x", "y"))
17+
@test sprint(show, i) == "Block(1, 2)[x, y]"
18+
@test i isa GenericBlockIndex{2,Tuple{Int64,Int64},Tuple{String,String}}
19+
@test GenericBlockIndex(Block(1), "x") === i1
20+
@test GenericBlockIndex(1, "x") === i1
21+
@test GenericBlockIndex(1, ("x",)) === i1
22+
@test GenericBlockIndex((1,), "x") === i1
23+
@test GenericBlockIndex((1, 2), ("x", "y")) === i
24+
@test GenericBlockIndex((Block(1), Block(2)), ("x", "y")) === i
25+
@test GenericBlockIndex((i1, i2)) === i
26+
@test block(i1) == Block(1)
27+
@test block(i) == Block(1, 2)
28+
@test blockindex(i1) == "x"
29+
@test GenericBlockIndex((), ()) == GenericBlockIndex(Block(), ())
30+
@test GenericBlockIndex(Block(1, 2), ("x",)) == GenericBlockIndex(Block(1, 2), ("x", 1))
31+
32+
i1 = GenericBlockIndex(Block(1), (1,))
33+
i2 = GenericBlockIndex(Block(2), (2,))
34+
i = GenericBlockIndex(Block(1, 2), (1, 2))
35+
v = BlockedVector(["a", "b", "c", "d"], [2, 2])
36+
@test v[i1] == "a"
37+
@test v[i2] == "d"
38+
39+
a = collect(Iterators.product(v, v))
40+
@test a[i1, i1] == ("a", "a")
41+
@test a[i2, i1] == ("d", "a")
42+
@test a[i1, i2] == ("a", "d")
43+
@test a[i] == ("a", "d")
44+
@test a[i2, i2] == ("d", "d")
45+
46+
I = BlockIndexVector(Block(1), [1, 2])
47+
@test eltype(I) === BlockIndex{1,Tuple{Int},Tuple{Int}}
48+
@test ndims(I) === 1
49+
@test length(I) === 2
50+
@test size(I) === (2,)
51+
@test I[1] === Block(1)[1]
52+
@test I[2] === Block(1)[2]
53+
@test block(I) === Block(1)
54+
@test Block(I) === Block(1)
55+
@test copy(I) == BlockIndexVector(Block(1), [1, 2])
56+
57+
I = BlockIndexVector(Block(1, 2), ([1, 2], [3, 4]))
58+
@test eltype(I) === BlockIndex{2,Tuple{Int,Int},Tuple{Int,Int}}
59+
@test ndims(I) === 2
60+
@test length(I) === 4
61+
@test size(I) === (2, 2)
62+
@test I[1, 1] === Block(1, 2)[1, 3]
63+
@test I[2, 1] === Block(1, 2)[2, 3]
64+
@test I[1, 2] === Block(1, 2)[1, 4]
65+
@test I[2, 2] === Block(1, 2)[2, 4]
66+
@test block(I) === Block(1, 2)
67+
@test Block(I) === Block(1, 2)
68+
@test copy(I) == BlockIndexVector(Block(1, 2), ([1, 2], [3, 4]))
69+
70+
I = BlockIndexVector(Block(1), ["x", "y"])
71+
@test eltype(I) === GenericBlockIndex{1,Tuple{Int},Tuple{String}}
72+
@test ndims(I) === 1
73+
@test length(I) === 2
74+
@test size(I) === (2,)
75+
@test I[1] === GenericBlockIndex(Block(1), "x")
76+
@test I[2] === GenericBlockIndex(Block(1), "y")
77+
@test block(I) === Block(1)
78+
@test Block(I) === Block(1)
79+
@test copy(I) == BlockIndexVector(Block(1), ["x", "y"])
80+
81+
I = BlockIndexVector(Block(1, 2), (["x", "y"], ["z", "w"]))
82+
@test eltype(I) === GenericBlockIndex{2,Tuple{Int,Int},Tuple{String,String}}
83+
@test ndims(I) === 2
84+
@test length(I) === 4
85+
@test size(I) === (2, 2)
86+
@test I[1, 1] === GenericBlockIndex(Block(1, 2), ("x", "z"))
87+
@test I[2, 1] === GenericBlockIndex(Block(1, 2), ("y", "z"))
88+
@test I[1, 2] === GenericBlockIndex(Block(1, 2), ("x", "w"))
89+
@test I[2, 2] === GenericBlockIndex(Block(1, 2), ("y", "w"))
90+
@test block(I) === Block(1, 2)
91+
@test Block(I) === Block(1, 2)
92+
@test copy(I) == BlockIndexVector(Block(1, 2), (["x", "y"], ["z", "w"]))
93+
94+
v = BlockedVector(["a", "b", "c", "d"], [2, 2])
95+
i = BlockIndexVector(Block(1), [2, 1])
96+
@test v[i] == ["b", "a"]
97+
i = BlockIndexVector(Block(2), [2, 1])
98+
@test v[i] == ["d", "c"]
99+
100+
v = BlockedVector(["a", "b", "c", "d"], [2, 2])
101+
i = BlockIndexVector{1,GenericBlockIndex{1,Tuple{Int},Tuple{String}}}(Block(1), [2, 1])
102+
@test v[i] == ["b", "a"]
103+
i = BlockIndexVector(Block(2), [2, 1])
104+
@test v[i] == ["d", "c"]
105+
106+
a = collect(Iterators.product(v, v))
107+
i1 = BlockIndexVector(Block(1), [2, 1])
108+
i2 = BlockIndexVector(Block(2), [1, 2])
109+
i = BlockIndexVector(Block(1, 2), ([2, 1], [1, 2]))
110+
@test a[i1, i1] == [("b", "b") ("b", "a"); ("a", "b") ("a", "a")]
111+
@test a[i2, i1] == [("c", "b") ("c", "a"); ("d", "b") ("d", "a")]
112+
@test a[i1, i2] == [("b", "c") ("b", "d"); ("a", "c") ("a", "d")]
113+
@test a[i] == [("b", "c") ("b", "d"); ("a", "c") ("a", "d")]
114+
@test a[i2, i2] == [("c", "c") ("c", "d"); ("d", "c") ("d", "d")]
115+
116+
a = collect(Iterators.product(v, v))
117+
i1 = BlockIndexVector{1,GenericBlockIndex{1,Tuple{Int},Tuple{String}}}(Block(1), [2, 1])
118+
i2 = BlockIndexVector{1,GenericBlockIndex{1,Tuple{Int},Tuple{String}}}(Block(2), [1, 2])
119+
i = BlockIndexVector{2,GenericBlockIndex{2,Tuple{Int,Int},Tuple{String,String}}}(
120+
Block(1, 2), ([2, 1], [1, 2])
121+
)
122+
@test a[i1, i1] == [("b", "b") ("b", "a"); ("a", "b") ("a", "a")]
123+
@test a[i2, i1] == [("c", "b") ("c", "a"); ("d", "b") ("d", "a")]
124+
@test a[i1, i2] == [("b", "c") ("b", "d"); ("a", "c") ("a", "d")]
125+
@test a[i] == [("b", "c") ("b", "d"); ("a", "c") ("a", "d")]
126+
@test a[i2, i2] == [("c", "c") ("c", "d"); ("d", "c") ("d", "d")]
127+
end

0 commit comments

Comments
 (0)