diff --git a/Project.toml b/Project.toml index 0459c94..7355d8c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "BlockSparseArrays" uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" -version = "0.10.12" authors = ["ITensor developers and contributors"] +version = "0.10.13" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -22,11 +22,9 @@ TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138" [weakdeps] TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" -TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d" [extensions] BlockSparseArraysTensorAlgebraExt = "TensorAlgebra" -BlockSparseArraysTensorProductsExt = "TensorProducts" [compat] Adapt = "4.1.1" @@ -44,8 +42,7 @@ MapBroadcast = "0.1.5" MatrixAlgebraKit = "0.6" SparseArraysBase = "0.7.1" SplitApplyCombine = "1.2.3" -TensorAlgebra = "0.5" -TensorProducts = "0.1.7" +TensorAlgebra = "0.6" Test = "1.10" TypeParameterAccessors = "0.4.1" julia = "1.10" diff --git a/ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl b/ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl index 5efcc89..94e658c 100644 --- a/ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl +++ b/ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl @@ -1,25 +1,29 @@ module BlockSparseArraysTensorAlgebraExt -using BlockSparseArrays: AbstractBlockSparseArray, blockreshape -using TensorAlgebra: TensorAlgebra, BlockedTuple, FusionStyle, fuseaxes - -struct BlockReshapeFusion <: FusionStyle end +using BlockArrays: Block, blocklength, blocks, eachblockaxes1 +using BlockSparseArrays: AbstractBlockSparseArray, AbstractBlockSparseMatrix, + BlockUnitRange, blockrange, blocksparse +using SparseArraysBase: eachstoredindex +using TensorAlgebra: TensorAlgebra, BlockReshapeFusion, BlockedTuple, matricize, + matricize_axes, tensor_product_axis, unmatricize -function TensorAlgebra.FusionStyle(::Type{<:AbstractBlockSparseArray}) - return BlockReshapeFusion() +function TensorAlgebra.tensor_product_axis( + ::BlockReshapeFusion, r1::BlockUnitRange, r2::BlockUnitRange + ) + isone(first(r1)) || isone(first(r2)) || + throw(ArgumentError("Only one-based axes are supported")) + blockaxpairs = Iterators.product(eachblockaxes1(r1), eachblockaxes1(r2)) + blockaxs = vec(splat(tensor_product_axis).(blockaxpairs)) + return blockrange(blockaxs) end -using BlockArrays: Block, blocklength, blocks -using BlockSparseArrays: blocksparse -using SparseArraysBase: eachstoredindex -using TensorAlgebra: TensorAlgebra, matricize, unmatricize function TensorAlgebra.matricize( - ::BlockReshapeFusion, a::AbstractArray, length1::Val, length2::Val + style::BlockReshapeFusion, a::AbstractBlockSparseArray, length_codomain::Val ) - ax = fuseaxes(axes(a), length1, length2) - reshaped_blocks_a = reshape(blocks(a), map(blocklength, ax)) + ax = matricize_axes(style, a, length_codomain) + reshaped_blocks_a = reshape(blocks(a), blocklength.(ax)) key(I) = Block(Tuple(I)) - value(I) = matricize(reshaped_blocks_a[I], length1, length2) + value(I) = matricize(reshaped_blocks_a[I], length_codomain) Is = eachstoredindex(reshaped_blocks_a) bs = if isempty(Is) # Catch empty case and make sure the type is constrained properly. @@ -35,16 +39,16 @@ function TensorAlgebra.matricize( return blocksparse(bs, ax) end -using BlockArrays: blocklengths function TensorAlgebra.unmatricize( ::BlockReshapeFusion, - m::AbstractMatrix, + m::AbstractBlockSparseMatrix, codomain_axes::Tuple{Vararg{AbstractUnitRange}}, domain_axes::Tuple{Vararg{AbstractUnitRange}}, ) ax = (codomain_axes..., domain_axes...) - reshaped_blocks_m = reshape(blocks(m), map(blocklength, ax)) - function f(I) + reshaped_blocks_m = reshape(blocks(m), blocklength.(ax)) + key(I) = Block(Tuple(I)) + function value(I) block_axes_I = BlockedTuple( map(ntuple(identity, length(ax))) do i return Base.axes1(ax[i][Block(I[i])]) @@ -53,7 +57,7 @@ function TensorAlgebra.unmatricize( ) return unmatricize(reshaped_blocks_m[I], block_axes_I) end - bs = Dict(Block(Tuple(I)) => f(I) for I in eachstoredindex(reshaped_blocks_m)) + bs = Dict(key(I) => value(I) for I in eachstoredindex(reshaped_blocks_m)) return blocksparse(bs, ax) end diff --git a/ext/BlockSparseArraysTensorProductsExt/BlockSparseArraysTensorProductsExt.jl b/ext/BlockSparseArraysTensorProductsExt/BlockSparseArraysTensorProductsExt.jl deleted file mode 100644 index 82f46fc..0000000 --- a/ext/BlockSparseArraysTensorProductsExt/BlockSparseArraysTensorProductsExt.jl +++ /dev/null @@ -1,14 +0,0 @@ -module BlockSparseArraysTensorProductsExt - -using BlockSparseArrays: BlockUnitRange, blockrange, eachblockaxis -using TensorProducts: TensorProducts, tensor_product -# TODO: Dispatch on `FusionStyle` to allow different kinds of products, -# for example to allow merging common symmetry sectors. -function TensorProducts.tensor_product(a1::BlockUnitRange, a2::BlockUnitRange) - new_blockaxes = vec( - map(splat(tensor_product), Iterators.product(eachblockaxis(a1), eachblockaxis(a2))) - ) - return blockrange(new_blockaxes) -end - -end diff --git a/test/Project.toml b/test/Project.toml index d3284cc..d330c16 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -40,7 +40,7 @@ SafeTestsets = "0.1" SparseArraysBase = "0.7" StableRNGs = "1" Suppressor = "0.2" -TensorAlgebra = "0.5" +TensorAlgebra = "0.6" Test = "1" TestExtras = "0.3" TypeParameterAccessors = "0.4"