From a9cc1d951afa497d9d05900f086a149560a87782 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 3 Dec 2025 17:27:28 -0500 Subject: [PATCH 1/9] [WIP] Upgrade to TensorAlgebra v0.6 --- Project.toml | 6 ++--- src/fusion.jl | 30 +++++++++++++----------- src/sectorrange.jl | 16 +++++++------ src/tensoralgebra.jl | 56 +++++++++++++++++++++----------------------- test/Project.toml | 2 +- 5 files changed, 55 insertions(+), 55 deletions(-) diff --git a/Project.toml b/Project.toml index 65c0abd..f6062c3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "GradedArrays" uuid = "bc96ca6e-b7c8-4bb6-888e-c93f838762c2" -version = "0.5.4" authors = ["ITensor developers and contributors"] +version = "0.5.5" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" @@ -16,7 +16,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66" TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" TensorKitSectors = "13a9c161-d5da-41f0-bcbd-e1a08ae0647f" -TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d" TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138" [weakdeps] @@ -37,8 +36,7 @@ MatrixAlgebraKit = "0.6" Random = "1.10" SUNRepresentations = "0.3" SplitApplyCombine = "1.2.3" -TensorAlgebra = "0.5" +TensorAlgebra = "0.6" TensorKitSectors = "0.1, 0.2" -TensorProducts = "0.1.3" TypeParameterAccessors = "0.4" julia = "1.10" diff --git a/src/fusion.jl b/src/fusion.jl index a0709e0..0269af8 100644 --- a/src/fusion.jl +++ b/src/fusion.jl @@ -1,22 +1,21 @@ using BlockArrays: Block, blocks using SplitApplyCombine: groupcount -using TensorProducts: TensorProducts, ⊗, OneToOne, tensor_product flip_dual(r::AbstractUnitRange) = isdual(r) ? flip(r) : r -# TensorProducts interface -function TensorProducts.tensor_product(sr1::SectorUnitRange, sr2::SectorUnitRange) +# TODO: Overload `TensorAlgebra.tensor_product_axis` for `SectorFusion`. +function tensor_product(sr1::SectorUnitRange, sr2::SectorUnitRange) return tensor_product(combine_styles(SymmetryStyle(sr1), SymmetryStyle(sr2)), sr1, sr2) end -function TensorProducts.tensor_product( +function tensor_product( ::AbelianStyle, sr1::SectorUnitRange, sr2::SectorUnitRange ) s = sector(flip_dual(sr1)) ⊗ sector(flip_dual(sr2)) return sectorrange(s, sector_multiplicity(sr1) * sector_multiplicity(sr2)) end -function TensorProducts.tensor_product( +function tensor_product( ::NotAbelianStyle, sr1::SectorUnitRange, sr2::SectorUnitRange ) g0 = sector(flip_dual(sr1)) ⊗ sector(flip_dual(sr2)) @@ -27,23 +26,23 @@ function TensorProducts.tensor_product( end # allow to fuse a Sector with a GradedUnitRange -function TensorProducts.tensor_product( +function tensor_product( s::Union{SectorRange, SectorUnitRange}, g::AbstractGradedUnitRange ) return to_gradedrange(s) ⊗ g end -function TensorProducts.tensor_product( +function tensor_product( g::AbstractGradedUnitRange, s::Union{SectorRange, SectorUnitRange} ) return g ⊗ to_gradedrange(s) end -function TensorProducts.tensor_product(sr::SectorUnitRange, s::SectorRange) +function tensor_product(sr::SectorUnitRange, s::SectorRange) return sr ⊗ sectorrange(s, 1) end -function TensorProducts.tensor_product(s::SectorRange, sr::SectorUnitRange) +function tensor_product(s::SectorRange, sr::SectorUnitRange) return sectorrange(s, 1) ⊗ sr end @@ -52,9 +51,10 @@ end # it is not aimed for generic use and does not support all tensor_product methods (no dispatch on SymmetryStyle) unmerged_tensor_product() = OneToOne() unmerged_tensor_product(a) = a -unmerged_tensor_product(a, ::OneToOne) = a -unmerged_tensor_product(::OneToOne, a) = a -unmerged_tensor_product(::OneToOne, ::OneToOne) = OneToOne() +# TODO: Delete. +# unmerged_tensor_product(a, ::OneToOne) = a +# unmerged_tensor_product(::OneToOne, a) = a +# unmerged_tensor_product(::OneToOne, ::OneToOne) = OneToOne() function unmerged_tensor_product(a1, a2, as...) return unmerged_tensor_product(unmerged_tensor_product(a1, a2), as...) end @@ -62,6 +62,8 @@ end # default to tensor_product unmerged_tensor_product(a1, a2) = a1 ⊗ a2 +# TODO: Use `TensorAlgebra.tensor_product_axis(::BlockReshapeFusion, ...)` instead. +using BlockSparseArrays: mortar_axis function unmerged_tensor_product(a1::AbstractGradedUnitRange, a2::AbstractGradedUnitRange) new_axes = map(splat(⊗), Iterators.flatten((Iterators.product(blocks(a1), blocks(a2)),))) return mortar_axis(new_axes) @@ -103,9 +105,9 @@ end sectormergesort(g::AbstractUnitRange) = g # tensor_product produces a sorted, non-dual GradedUnitRange -TensorProducts.tensor_product(g::AbstractGradedUnitRange) = sectormergesort(flip_dual(g)) +tensor_product(g::AbstractGradedUnitRange) = sectormergesort(flip_dual(g)) -function TensorProducts.tensor_product( +function tensor_product( g1::AbstractGradedUnitRange, g2::AbstractGradedUnitRange ) return sectormergesort(unmerged_tensor_product(g1, g2)) diff --git a/src/sectorrange.jl b/src/sectorrange.jl index 4678179..1d78cde 100644 --- a/src/sectorrange.jl +++ b/src/sectorrange.jl @@ -1,6 +1,5 @@ # This file defines the interface for type Sector # all fusion categories (Z{2}, SU2, Ising...) are subtypes of Sector -using TensorProducts: TensorProducts, ⊗ import TensorKitSectors as TKS """ @@ -122,17 +121,20 @@ function fusion_rule(r1::SectorRange, r2::SectorRange) ) end -# ============================= TensorProducts interface =====--========================== +# ============================= Tensor products ========================================== -TensorProducts.tensor_product(s::SectorRange) = s -TensorProducts.tensor_product(c1::SectorRange, c2::SectorRange) = fusion_rule(c1, c2) -function TensorProducts.tensor_product(c1::TKS.Sector, c2::TKS.Sector) +# TODO: Overload `TensorAlgebra.tensor_product_axis` for `SectorFusion`. +function tensor_product end +const ⊗ = tensor_product +tensor_product(s::SectorRange) = s +tensor_product(c1::SectorRange, c2::SectorRange) = fusion_rule(c1, c2) +function tensor_product(c1::TKS.Sector, c2::TKS.Sector) return tensor_product(to_sector(c1), to_sector(c2)) end -function TensorProducts.tensor_product(c1::SectorRange, c2::TKS.Sector) +function tensor_product(c1::SectorRange, c2::TKS.Sector) return tensor_product(c1, to_sector(c2)) end -function TensorProducts.tensor_product(c1::TKS.Sector, c2::SectorRange) +function tensor_product(c1::TKS.Sector, c2::SectorRange) return tensor_product(to_sector(c1), c2) end diff --git a/src/tensoralgebra.jl b/src/tensoralgebra.jl index 1aa618c..9a28f95 100644 --- a/src/tensoralgebra.jl +++ b/src/tensoralgebra.jl @@ -1,39 +1,23 @@ -using BlockArrays: blocks +using BlockArrays: blocks, eachblockaxes1 using BlockSparseArrays: BlockSparseArray, blockreshape -using GradedArrays: - AbstractGradedUnitRange, - SectorRange, - GradedArray, - flip, - gradedrange, - invblockperm, - sectormergesortperm, - sectorsortperm, - trivial, - unmerged_tensor_product -using TensorAlgebra: - TensorAlgebra, - ⊗, - AbstractBlockPermutation, - BlockedTuple, - FusionStyle, - trivial_axis, - unmatricize +using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, BlockReshapeFusion, + BlockedTuple, FusionStyle, ReshapeFusion, matricize, tensor_product_axis, unmatricize struct SectorFusion <: FusionStyle end TensorAlgebra.FusionStyle(::Type{<:GradedArray}) = SectorFusion() -function TensorAlgebra.trivial_axis(t::Tuple{Vararg{G}}) where {G <: AbstractGradedUnitRange} +function trivial_axis(t::Tuple{Vararg{G}}) where {G <: AbstractGradedUnitRange} return trivial(first(t)) end # heterogeneous sectors -TensorAlgebra.trivial_axis(t::Tuple{Vararg{AbstractGradedUnitRange}}) = ⊗(trivial.(t)...) +trivial_axis(t::Tuple{Vararg{AbstractGradedUnitRange}}) = ⊗(trivial.(t)...) # trivial_axis from sector_type -function TensorAlgebra.trivial_axis(::Type{S}) where {S <: SectorRange} +function trivial_axis(::Type{S}) where {S <: SectorRange} return gradedrange([trivial(S) => 1]) end +# TODO: Use `TensorAlgebra.matricize_axes`. function matricize_axes( blocked_axes::BlockedTuple{2, <:Any, <:Tuple{Vararg{AbstractUnitRange}}} ) @@ -45,14 +29,27 @@ function matricize_axes( return codomain_axis, flip(unflipped_domain_axis) end -using TensorAlgebra: blockedtrivialperm +function TensorAlgebra.trivial_axis(::BlockReshapeFusion, a::GradedArray) + return trivial_axis(axes(a)) +end +function TensorAlgebra.tensor_product_axis( + ::ReshapeFusion, r1::SectorUnitRange, r2::SectorUnitRange + ) + return r1 ⊗ r2 +end +function TensorAlgebra.tensor_product_axis( + ::BlockReshapeFusion, r1::AbstractGradedUnitRange, r2::AbstractGradedUnitRange + ) + (isone(first(r1)) && isone(first(r2))) || + throw(ArgumentError("Only one-based axes are supported")) + blockaxpairs = Iterators.product(eachblockaxes1(r1), eachblockaxes1(r2)) + blockaxs = vec(map(splat(tensor_product_axis), blockaxpairs)) + return mortar_axis(blockaxs) +end function TensorAlgebra.matricize( - ::SectorFusion, a::AbstractArray, codomain_length::Val, domain_length::Val + ::SectorFusion, a::AbstractArray, length_codomain::Val ) - biperm = blockedtrivialperm((codomain_length, domain_length)) - codomain_axis, domain_axis = matricize_axes(axes(a)[biperm]) - a_reshaped = blockreshape(a, (codomain_axis, domain_axis)) - # Sort the blocks by sector and merge the equivalent sectors. + a_reshaped = matricize(BlockReshapeFusion(), a, length_codomain) return sectormergesort(a_reshaped) end @@ -74,6 +71,7 @@ function TensorAlgebra.unmatricize( # First, fuse axes to get `sectormergesortperm`. # Then unpermute the blocks. + # TODO: Use `TensorAlgebra.matricize_axes`. fused_axes = matricize_axes(blocked_axes) blockperms = sectorsortperm.(fused_axes) diff --git a/test/Project.toml b/test/Project.toml index 8d59748..cd65325 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -31,7 +31,7 @@ SUNRepresentations = "0.3" SafeTestsets = "0.1" SparseArraysBase = "0.7" Suppressor = "0.2.8" -TensorAlgebra = "0.5" +TensorAlgebra = "0.6" TensorKitSectors = "0.1, 0.2" TensorProducts = "0.1.3" Test = "1.10" From d3cd5e4074b03d9b14ebdb7a4a779f9ed6f35f46 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 4 Dec 2025 09:13:05 -0500 Subject: [PATCH 2/9] Define matricize_axes --- src/tensoralgebra.jl | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/tensoralgebra.jl b/src/tensoralgebra.jl index 9a28f95..441363b 100644 --- a/src/tensoralgebra.jl +++ b/src/tensoralgebra.jl @@ -46,6 +46,23 @@ function TensorAlgebra.tensor_product_axis( blockaxs = vec(map(splat(tensor_product_axis), blockaxpairs)) return mortar_axis(blockaxs) end +using TensorAlgebra: trivialbiperm +unval(::Val{x}) where {x} = x +function TensorAlgebra.matricize_axes( + style::BlockReshapeFusion, a::GradedArray, ndims_codomain::Val + ) + unval(ndims_codomain) ≤ ndims(a) || + throw(ArgumentError("Codomain length exceeds number of dimensions.")) + biperm = trivialbiperm(ndims_codomain, Val(ndims(a))) + axesblocks = blocks(axes(a)[biperm]) + init_axis = TensorAlgebra.trivial_axis(style, a) + axis_codomain, axis_domain = map(axesblocks) do axesblock + return reduce(axesblock; init = init_axis) do ax1, ax2 + return tensor_product_axis(style, ax1, ax2) + end + end + return axis_codomain, flip(axis_domain) +end function TensorAlgebra.matricize( ::SectorFusion, a::AbstractArray, length_codomain::Val ) From 5210b1b866d3681afab08454c64536ad895f9f9c Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 4 Dec 2025 11:10:24 -0500 Subject: [PATCH 3/9] Fix tests --- src/fusion.jl | 23 +++++++++++++------ test/Project.toml | 2 -- test/test_fusion_rule.jl | 3 ++- test/test_interface.jl | 11 --------- test/test_sectorproduct.jl | 2 +- test/test_show.jl | 42 ++++++++++++++++------------------- test/test_tensor_product.jl | 13 ++++------- test/test_tensoralgebraext.jl | 5 +++-- 8 files changed, 45 insertions(+), 56 deletions(-) diff --git a/src/fusion.jl b/src/fusion.jl index 0269af8..5298c1a 100644 --- a/src/fusion.jl +++ b/src/fusion.jl @@ -46,17 +46,26 @@ function tensor_product(s::SectorRange, sr::SectorUnitRange) return sectorrange(s, 1) ⊗ sr end +function tensor_product(r1::AbstractUnitRange, r2::AbstractUnitRange) + (isone(first(r1)) && isone(first(r2))) || + throw(ArgumentError("Only one-based axes are supported")) + return Base.OneTo(length(r1) * length(r2)) +end + +function tensor_product( + r1::AbstractUnitRange, r2::AbstractUnitRange, r3::AbstractUnitRange, + rs::AbstractUnitRange..., + ) + return tensor_product(tensor_product(r1, r2), r3, rs...) +end + # unmerged_tensor_product is a private function needed in GradedArraysTensorAlgebraExt # to get block permutation # it is not aimed for generic use and does not support all tensor_product methods (no dispatch on SymmetryStyle) -unmerged_tensor_product() = OneToOne() +unmerged_tensor_product() = Base.OneTo(1) unmerged_tensor_product(a) = a -# TODO: Delete. -# unmerged_tensor_product(a, ::OneToOne) = a -# unmerged_tensor_product(::OneToOne, a) = a -# unmerged_tensor_product(::OneToOne, ::OneToOne) = OneToOne() -function unmerged_tensor_product(a1, a2, as...) - return unmerged_tensor_product(unmerged_tensor_product(a1, a2), as...) +function unmerged_tensor_product(a1, a2, a3, as...) + return unmerged_tensor_product(unmerged_tensor_product(a1, a2), a3, as...) end # default to tensor_product diff --git a/test/Project.toml b/test/Project.toml index cd65325..ff8bf84 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -12,7 +12,6 @@ SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" TensorKitSectors = "13a9c161-d5da-41f0-bcbd-e1a08ae0647f" -TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" @@ -33,6 +32,5 @@ SparseArraysBase = "0.7" Suppressor = "0.2.8" TensorAlgebra = "0.6" TensorKitSectors = "0.1, 0.2" -TensorProducts = "0.1.3" Test = "1.10" TestExtras = "0.3.1" diff --git a/test/test_fusion_rule.jl b/test/test_fusion_rule.jl index 202c662..0ea997d 100644 --- a/test/test_fusion_rule.jl +++ b/test/test_fusion_rule.jl @@ -6,15 +6,16 @@ using GradedArrays: TrivialSector, U1, Z, + ⊗, dual, flip, gradedrange, nsymbol, quantum_dimension, space_isequal, + tensor_product, trivial, unmerged_tensor_product -using TensorProducts: ⊗, tensor_product using SUNRepresentations: SUNIrrep using Test: @test, @test_throws, @testset using TestExtras: @constinferred diff --git a/test/test_interface.jl b/test/test_interface.jl index be24886..2a5281d 100644 --- a/test/test_interface.jl +++ b/test/test_interface.jl @@ -3,19 +3,8 @@ using BlockArrays: BlockedOneTo, blockedrange, blockisequal using GradedArrays: NoSector, dag, dual, flip, isdual, map_sectors, sectors, space_isequal, ungrade using Test: @test, @testset -using TensorProducts: OneToOne @testset "GradedUnitRange interface for AbstractUnitRange" begin - a0 = OneToOne() - @test !isdual(a0) - @test dual(a0) isa OneToOne - @test space_isequal(a0, a0) - @test space_isequal(a0, dual(a0)) - @test only(sectors(a0)) == NoSector() - @test ungrade(a0) === a0 - @test map_sectors(identity, a0) === a0 - @test dag(a0) === a0 - a = 1:3 ad = dual(a) af = flip(a) diff --git a/test/test_sectorproduct.jl b/test/test_sectorproduct.jl index 5070262..7f836b9 100644 --- a/test/test_sectorproduct.jl +++ b/test/test_sectorproduct.jl @@ -5,6 +5,7 @@ using GradedArrays: TrivialSector, U1, Z, + ⊗, ×, arguments, dual, @@ -17,7 +18,6 @@ using GradedArrays: sectorrange, space_isequal, trivial -using TensorProducts: ⊗ using Test: @test, @test_broken, @test_throws, @testset using TestExtras: @constinferred using BlockArrays: blocklengths diff --git a/test/test_show.jl b/test/test_show.jl index 620a1b7..478fbad 100644 --- a/test/test_show.jl +++ b/test/test_show.jl @@ -2,57 +2,53 @@ # sometimes displays GradedArrays.GradedUnitRange and sometimes GradedUnitRange depending # on exact setup +using BlockArrays: BlockedOneTo, BlockedUnitRange +using GradedArrays: GradedArrays, Fib, GradedUnitRange, Ising, O2, SU2, SectorUnitRange, + TrivialSector, U1, ×, gradedrange, sectorrange using Test: @test, @testset -using GradedArrays: - GradedArrays, ×, Fib, Ising, O2, SU2, TrivialSector, U1, gradedrange, sectorrange - -show_namespaced(x) = sprint(show, x; context = (:module => GradedArrays)) -show_non_namespaced(x) = sprint(show, x) @testset "show SymmetrySector" begin q1 = U1(1) - @test show_namespaced(q1) == "U1(1)" - @test show_non_namespaced(q1) == "GradedArrays.U1(1)" || - show_non_namespaced(q1) == "U1(1)" + @test sprint(show, q1) == "$U1(1)" s0e = O2(0) s0o = O2(-1) s12 = O2(1 // 2) s1 = O2(1) @test isnothing(show(devnull, [s0o, s0e, s12])) - @test show_namespaced(s0e) == "O2(0)" - @test show_namespaced(s0o) == "O2(-1)" - @test show_namespaced(s12) == "O2(1/2)" - @test show_non_namespaced(s0e) == "GradedArrays.O2(0)" + @test sprint(show, s0e) == "$O2(0)" + @test sprint(show, s0o) == "$O2(-1)" + @test sprint(show, s12) == "$O2(1/2)" + @test sprint(show, s0e) == "$O2(0)" j1 = SU2(0) - @test show_namespaced(j1) == "SU2(0)" + @test sprint(show, j1) == "$SU2(0)" - @test show_namespaced(Fib.(("1", "τ"))) == "(Fib(\"1\"), Fib(\"τ\"))" - @test show_namespaced(Ising.(("1", "σ", "ψ"))) == - "(Ising(\"1\"), Ising(\"σ\"), Ising(\"ψ\"))" + @test sprint(show, Fib.(("1", "τ"))) == "($Fib(\"1\"), $Fib(\"τ\"))" + @test sprint(show, Ising.(("1", "σ", "ψ"))) == + "($Ising(\"1\"), $Ising(\"σ\"), $Ising(\"ψ\"))" s = (A = U1(1),) × (B = SU2(2),) - @test show_namespaced(s) == "((A=U1(1),) × (B=SU2(2),))" + @test sprint(show, s) == "((A=$U1(1),) × (B=$SU2(2),))" s = TrivialSector() × U1(3) × SU2(1 / 2) - @test show_namespaced(s) == "(TrivialSector() × U1(3) × SU2(1/2))" + @test sprint(show, s) == "(TrivialSector() × $U1(3) × $SU2(1/2))" end @testset "show GradedUnitRange" begin g1 = gradedrange(["x" => 2, "y" => 3, "z" => 2]) - @test show_namespaced(g1) == "GradedUnitRange[\"x\" => 2, \"y\" => 3, \"z\" => 2]" + @test sprint(show, g1) == "$GradedUnitRange[\"x\" => 2, \"y\" => 3, \"z\" => 2]" @test sprint(show, MIME("text/plain"), g1) == - "GradedArrays.GradedUnitRange{Int64, GradedArrays.SectorUnitRange{Int64, String, Base.OneTo{Int64}}, BlockArrays.BlockedOneTo{Int64, Vector{Int64}}, Vector{Int64}}\nSectorUnitRange x => 1:2\nSectorUnitRange y => 3:5\nSectorUnitRange z => 6:7" + "$GradedUnitRange{Int64, $SectorUnitRange{Int64, String, Base.OneTo{Int64}}, $BlockedOneTo{Int64, Vector{Int64}}, Vector{Int64}}\n$SectorUnitRange x => 1:2\n$SectorUnitRange y => 3:5\n$SectorUnitRange z => 6:7" g2 = gradedrange(1, ["x" => 2, "y" => 3, "z" => 2]) @test sprint(show, g2) == "GradedUnitRange[\"x\" => 2, \"y\" => 3, \"z\" => 2]" @test sprint(show, MIME("text/plain"), g2) == - "GradedArrays.GradedUnitRange{Int64, GradedArrays.SectorUnitRange{Int64, String, Base.OneTo{Int64}}, BlockArrays.BlockedUnitRange{Int64, Vector{Int64}}, Vector{Int64}}\nSectorUnitRange x => 1:2\nSectorUnitRange y => 3:5\nSectorUnitRange z => 6:7" + "$GradedUnitRange{Int64, $SectorUnitRange{Int64, String, Base.OneTo{Int64}}, $BlockedUnitRange{Int64, Vector{Int64}}, Vector{Int64}}\n$SectorUnitRange x => 1:2\n$SectorUnitRange y => 3:5\n$SectorUnitRange z => 6:7" g1d = gradedrange(["x" => 2, "y" => 3, "z" => 2]; isdual = true) - @test sprint(show, g1d) == "GradedUnitRange dual [\"x\" => 2, \"y\" => 3, \"z\" => 2]" + @test sprint(show, g1d) == "$GradedUnitRange dual [\"x\" => 2, \"y\" => 3, \"z\" => 2]" @test sprint(show, MIME("text/plain"), g1d) == - "GradedArrays.GradedUnitRange{Int64, GradedArrays.SectorUnitRange{Int64, String, Base.OneTo{Int64}}, BlockArrays.BlockedOneTo{Int64, Vector{Int64}}, Vector{Int64}}\nSectorUnitRange dual(x) => 1:2\nSectorUnitRange dual(y) => 3:5\nSectorUnitRange dual(z) => 6:7" + "$GradedUnitRange{Int64, $SectorUnitRange{Int64, String, Base.OneTo{Int64}}, $BlockedOneTo{Int64, Vector{Int64}}, Vector{Int64}}\n$SectorUnitRange dual(x) => 1:2\n$SectorUnitRange dual(y) => 3:5\n$SectorUnitRange dual(z) => 6:7" end @testset "show GradedArray" begin diff --git a/test/test_tensor_product.jl b/test/test_tensor_product.jl index 5f6b2ec..33d8f88 100644 --- a/test/test_tensor_product.jl +++ b/test/test_tensor_product.jl @@ -6,6 +6,7 @@ using GradedArrays: SectorUnitRange, SU2, U1, + ⊗, dual, gradedrange, isdual, @@ -13,8 +14,8 @@ using GradedArrays: sectorrange, sectors, space_isequal, + tensor_product, unmerged_tensor_product -using TensorProducts: ⊗, OneToOne, tensor_product using Test: @test, @testset using TestExtras: @constinferred @@ -29,8 +30,8 @@ end Base.length(s::NotAbelianString) = length(s.str) @testset "unmerged_tensor_product" begin - @test unmerged_tensor_product() isa OneToOne - @test unmerged_tensor_product(OneToOne(), OneToOne()) isa OneToOne + @test unmerged_tensor_product() ≡ Base.OneTo(1) + @test unmerged_tensor_product(Base.OneTo(1), Base.OneTo(1)) ≡ Base.OneTo(1) @test unmerged_tensor_product(1:1, 1:1) == 1:1 @test sectormergesort(1:1) isa UnitRange @@ -79,13 +80,9 @@ Base.length(s::NotAbelianString) = length(s.str) ), ) @test space_isequal(unmerged_tensor_product(a), a) - @test space_isequal(unmerged_tensor_product(a, OneToOne()), a) - @test space_isequal(unmerged_tensor_product(OneToOne(), a), a) @test space_isequal(tensor_product(a), gradedrange([U1(1) => 2, U1(2) => 3])) @test space_isequal(a ⊗ a, gradedrange([U1(2) => 4, U1(3) => 12, U1(4) => 9])) - @test space_isequal(a ⊗ OneToOne(), gradedrange([U1(1) => 2, U1(2) => 3])) - @test space_isequal(OneToOne() ⊗ a, gradedrange([U1(1) => 2, U1(2) => 3])) d = tensor_product(a, a, a) @test space_isequal(d, gradedrange([U1(3) => 8, U1(4) => 36, U1(5) => 54, U1(6) => 27])) @@ -98,8 +95,6 @@ end b = unmerged_tensor_product(ad) @test isdual(b) @test space_isequal(b, ad) - @test space_isequal(unmerged_tensor_product(ad, OneToOne()), ad) - @test space_isequal(unmerged_tensor_product(OneToOne(), ad), ad) b = tensor_product(ad) @test b isa GradedOneTo diff --git a/test/test_tensoralgebraext.jl b/test/test_tensoralgebraext.jl index f397142..16d1aca 100644 --- a/test/test_tensoralgebraext.jl +++ b/test/test_tensoralgebraext.jl @@ -1,9 +1,10 @@ using BlockArrays: Block, blocksize using BlockSparseArrays: BlockSparseArray using GradedArrays: - GradedArray, GradedMatrix, SU2, U1, dual, flip, gradedrange, sector_type, space_isequal + GradedArray, GradedMatrix, SU2, U1, dual, flip, gradedrange, sector_type, space_isequal, + trivial_axis using Random: randn! -using TensorAlgebra: contract, matricize, trivial_axis, unmatricize +using TensorAlgebra: contract, matricize, unmatricize using Test: @test, @testset function randn_blockdiagonal(elt::Type, axes::Tuple) From c20d7cd553df5a39000dbb77af8774d57f2ad8b9 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 4 Dec 2025 11:48:11 -0500 Subject: [PATCH 4/9] Cleanup --- src/fusion.jl | 2 -- src/sectorrange.jl | 1 - src/tensoralgebra.jl | 12 ++---------- 3 files changed, 2 insertions(+), 13 deletions(-) diff --git a/src/fusion.jl b/src/fusion.jl index 5298c1a..13fd0e3 100644 --- a/src/fusion.jl +++ b/src/fusion.jl @@ -3,7 +3,6 @@ using SplitApplyCombine: groupcount flip_dual(r::AbstractUnitRange) = isdual(r) ? flip(r) : r -# TODO: Overload `TensorAlgebra.tensor_product_axis` for `SectorFusion`. function tensor_product(sr1::SectorUnitRange, sr2::SectorUnitRange) return tensor_product(combine_styles(SymmetryStyle(sr1), SymmetryStyle(sr2)), sr1, sr2) end @@ -71,7 +70,6 @@ end # default to tensor_product unmerged_tensor_product(a1, a2) = a1 ⊗ a2 -# TODO: Use `TensorAlgebra.tensor_product_axis(::BlockReshapeFusion, ...)` instead. using BlockSparseArrays: mortar_axis function unmerged_tensor_product(a1::AbstractGradedUnitRange, a2::AbstractGradedUnitRange) new_axes = map(splat(⊗), Iterators.flatten((Iterators.product(blocks(a1), blocks(a2)),))) diff --git a/src/sectorrange.jl b/src/sectorrange.jl index 1d78cde..65146ae 100644 --- a/src/sectorrange.jl +++ b/src/sectorrange.jl @@ -123,7 +123,6 @@ end # ============================= Tensor products ========================================== -# TODO: Overload `TensorAlgebra.tensor_product_axis` for `SectorFusion`. function tensor_product end const ⊗ = tensor_product tensor_product(s::SectorRange) = s diff --git a/src/tensoralgebra.jl b/src/tensoralgebra.jl index 441363b..111cec6 100644 --- a/src/tensoralgebra.jl +++ b/src/tensoralgebra.jl @@ -51,16 +51,8 @@ unval(::Val{x}) where {x} = x function TensorAlgebra.matricize_axes( style::BlockReshapeFusion, a::GradedArray, ndims_codomain::Val ) - unval(ndims_codomain) ≤ ndims(a) || - throw(ArgumentError("Codomain length exceeds number of dimensions.")) - biperm = trivialbiperm(ndims_codomain, Val(ndims(a))) - axesblocks = blocks(axes(a)[biperm]) - init_axis = TensorAlgebra.trivial_axis(style, a) - axis_codomain, axis_domain = map(axesblocks) do axesblock - return reduce(axesblock; init = init_axis) do ax1, ax2 - return tensor_product_axis(style, ax1, ax2) - end - end + # TODO: Remove `TensorAlgebra.` once we delete `GradedArrays.matricize_axes`. + axis_codomain, axis_domain = @invoke TensorAlgebra.matricize_axes(style, a::AbstractArray, ndims_codomain) return axis_codomain, flip(axis_domain) end function TensorAlgebra.matricize( From 8f8f63fec395c12af6b8c69532e5c53d93b95d65 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 4 Dec 2025 16:54:19 -0500 Subject: [PATCH 5/9] Update to latest TensorAlgebra interface --- src/tensoralgebra.jl | 91 +++++++++++++++++++++++------------ test/Project.toml | 2 + test/test_show.jl | 50 ++++++++++++++----- test/test_tensoralgebraext.jl | 13 +++-- 4 files changed, 106 insertions(+), 50 deletions(-) diff --git a/src/tensoralgebra.jl b/src/tensoralgebra.jl index 111cec6..227fc90 100644 --- a/src/tensoralgebra.jl +++ b/src/tensoralgebra.jl @@ -1,60 +1,92 @@ using BlockArrays: blocks, eachblockaxes1 using BlockSparseArrays: BlockSparseArray, blockreshape using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, BlockReshapeFusion, - BlockedTuple, FusionStyle, ReshapeFusion, matricize, tensor_product_axis, unmatricize + BlockedTuple, FusionStyle, ReshapeFusion, matricize, matricize_axes, + tensor_product_axis, unmatricize struct SectorFusion <: FusionStyle end TensorAlgebra.FusionStyle(::Type{<:GradedArray}) = SectorFusion() -function trivial_axis(t::Tuple{Vararg{G}}) where {G <: AbstractGradedUnitRange} +function trivial_gradedrange(t::Tuple{Vararg{G}}) where {G <: AbstractGradedUnitRange} return trivial(first(t)) end # heterogeneous sectors -trivial_axis(t::Tuple{Vararg{AbstractGradedUnitRange}}) = ⊗(trivial.(t)...) +trivial_gradedrange(t::Tuple{Vararg{AbstractGradedUnitRange}}) = ⊗(trivial.(t)...) # trivial_axis from sector_type -function trivial_axis(::Type{S}) where {S <: SectorRange} +function trivial_gradedrange(::Type{S}) where {S <: SectorRange} return gradedrange([trivial(S) => 1]) end -# TODO: Use `TensorAlgebra.matricize_axes`. -function matricize_axes( - blocked_axes::BlockedTuple{2, <:Any, <:Tuple{Vararg{AbstractUnitRange}}} - ) - @assert !isempty(blocked_axes) - default_axis = trivial_axis(Tuple(blocked_axes)) - codomain_axes, domain_axes = blocks(blocked_axes) - codomain_axis = unmerged_tensor_product(default_axis, codomain_axes...) - unflipped_domain_axis = unmerged_tensor_product(default_axis, domain_axes...) - return codomain_axis, flip(unflipped_domain_axis) -end +## # TODO: Use `TensorAlgebra.matricize_axes`. +## function matricize_axes( +## blocked_axes::BlockedTuple{2, <:Any, <:Tuple{Vararg{AbstractUnitRange}}} +## ) +## @assert !isempty(blocked_axes) +## default_axis = trivial_axis(Tuple(blocked_axes)) +## codomain_axes, domain_axes = blocks(blocked_axes) +## codomain_axis = unmerged_tensor_product(default_axis, codomain_axes...) +## unflipped_domain_axis = unmerged_tensor_product(default_axis, domain_axes...) +## return codomain_axis, flip(unflipped_domain_axis) +## end -function TensorAlgebra.trivial_axis(::BlockReshapeFusion, a::GradedArray) - return trivial_axis(axes(a)) +function TensorAlgebra.trivial_axis( + ::BlockReshapeFusion, + a::GradedArray, + axes_codomain::Tuple{Vararg{AbstractUnitRange}}, + axes_domain::Tuple{Vararg{AbstractUnitRange}}, + ) + return trivial_gradedrange(axes(a)) end function TensorAlgebra.tensor_product_axis( - ::ReshapeFusion, r1::SectorUnitRange, r2::SectorUnitRange + ::ReshapeFusion, ::Val{:codomain}, r1::SectorUnitRange, r2::SectorUnitRange ) return r1 ⊗ r2 end function TensorAlgebra.tensor_product_axis( - ::BlockReshapeFusion, r1::AbstractGradedUnitRange, r2::AbstractGradedUnitRange + ::ReshapeFusion, ::Val{:domain}, r1::SectorUnitRange, r2::SectorUnitRange + ) + return flip(r1 ⊗ r2) +end +function tensor_product_gradedrange( + ::BlockReshapeFusion, + side::Val, + r1::AbstractUnitRange, + r2::AbstractUnitRange, ) (isone(first(r1)) && isone(first(r2))) || throw(ArgumentError("Only one-based axes are supported")) blockaxpairs = Iterators.product(eachblockaxes1(r1), eachblockaxes1(r2)) - blockaxs = vec(map(splat(tensor_product_axis), blockaxpairs)) - return mortar_axis(blockaxs) + blockaxs = map(blockaxpairs) do (b1, b2) + return tensor_product_axis(ReshapeFusion(), side, b1, b2) + end + return mortar_axis(vec(blockaxs)) end -using TensorAlgebra: trivialbiperm -unval(::Val{x}) where {x} = x -function TensorAlgebra.matricize_axes( - style::BlockReshapeFusion, a::GradedArray, ndims_codomain::Val +function TensorAlgebra.tensor_product_axis( + style::BlockReshapeFusion, + side::Val{:codomain}, + r1::AbstractGradedUnitRange, + r2::AbstractGradedUnitRange, + ) + return tensor_product_gradedrange(style, side, r1, r2) +end +function TensorAlgebra.tensor_product_axis( + style::BlockReshapeFusion, + side::Val{:domain}, + r1::AbstractGradedUnitRange, + r2::AbstractGradedUnitRange, ) - # TODO: Remove `TensorAlgebra.` once we delete `GradedArrays.matricize_axes`. - axis_codomain, axis_domain = @invoke TensorAlgebra.matricize_axes(style, a::AbstractArray, ndims_codomain) - return axis_codomain, flip(axis_domain) + return tensor_product_gradedrange(style, side, r1, r2) end +## using TensorAlgebra: trivialbiperm +## unval(::Val{x}) where {x} = x +## function TensorAlgebra.matricize_axes( +## style::BlockReshapeFusion, a::GradedArray, ndims_codomain::Val +## ) +## # TODO: Remove `TensorAlgebra.` once we delete `GradedArrays.matricize_axes`. +## axis_codomain, axis_domain = @invoke TensorAlgebra.matricize_axes(style, a::AbstractArray, ndims_codomain) +## return axis_codomain, flip(axis_domain) +## end function TensorAlgebra.matricize( ::SectorFusion, a::AbstractArray, length_codomain::Val ) @@ -80,8 +112,7 @@ function TensorAlgebra.unmatricize( # First, fuse axes to get `sectormergesortperm`. # Then unpermute the blocks. - # TODO: Use `TensorAlgebra.matricize_axes`. - fused_axes = matricize_axes(blocked_axes) + fused_axes = matricize_axes(BlockReshapeFusion(), m, codomain_axes, domain_axes) blockperms = sectorsortperm.(fused_axes) sorted_axes = map((r, I) -> only(axes(r[I])), fused_axes, blockperms) diff --git a/test/Project.toml b/test/Project.toml index ff8bf84..8a10764 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -16,7 +16,9 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" [sources] +BlockSparseArrays = {path = "/Users/mfishman/.julia/dev/BlockSparseArrays"} GradedArrays = {path = ".."} +TensorAlgebra = {path = "/Users/mfishman/.julia/dev/TensorAlgebra"} [compat] Aqua = "0.8.11" diff --git a/test/test_show.jl b/test/test_show.jl index 478fbad..24cd9eb 100644 --- a/test/test_show.jl +++ b/test/test_show.jl @@ -1,7 +1,3 @@ -# test show separately as it may behave differently locally and on CI. -# sometimes displays GradedArrays.GradedUnitRange and sometimes GradedUnitRange depending -# on exact setup - using BlockArrays: BlockedOneTo, BlockedUnitRange using GradedArrays: GradedArrays, Fib, GradedUnitRange, Ising, O2, SU2, SectorUnitRange, TrivialSector, U1, ×, gradedrange, sectorrange @@ -31,24 +27,39 @@ using Test: @test, @testset s = (A = U1(1),) × (B = SU2(2),) @test sprint(show, s) == "((A=$U1(1),) × (B=$SU2(2),))" s = TrivialSector() × U1(3) × SU2(1 / 2) - @test sprint(show, s) == "(TrivialSector() × $U1(3) × $SU2(1/2))" + @test sprint(show, s) == "($TrivialSector() × $U1(3) × $SU2(1/2))" end @testset "show GradedUnitRange" begin g1 = gradedrange(["x" => 2, "y" => 3, "z" => 2]) - @test sprint(show, g1) == "$GradedUnitRange[\"x\" => 2, \"y\" => 3, \"z\" => 2]" + @test sprint(show, g1) == "GradedUnitRange[\"x\" => 2, \"y\" => 3, \"z\" => 2]" @test sprint(show, MIME("text/plain"), g1) == - "$GradedUnitRange{Int64, $SectorUnitRange{Int64, String, Base.OneTo{Int64}}, $BlockedOneTo{Int64, Vector{Int64}}, Vector{Int64}}\n$SectorUnitRange x => 1:2\n$SectorUnitRange y => 3:5\n$SectorUnitRange z => 6:7" + "$GradedUnitRange{Int64, " * + "$SectorUnitRange{Int64, String, Base.OneTo{Int64}}, " * + "$BlockedOneTo{Int64, Vector{Int64}}, Vector{Int64}}\n" * + "SectorUnitRange x => 1:2\n" * + "SectorUnitRange y => 3:5\n" * + "SectorUnitRange z => 6:7" g2 = gradedrange(1, ["x" => 2, "y" => 3, "z" => 2]) @test sprint(show, g2) == "GradedUnitRange[\"x\" => 2, \"y\" => 3, \"z\" => 2]" @test sprint(show, MIME("text/plain"), g2) == - "$GradedUnitRange{Int64, $SectorUnitRange{Int64, String, Base.OneTo{Int64}}, $BlockedUnitRange{Int64, Vector{Int64}}, Vector{Int64}}\n$SectorUnitRange x => 1:2\n$SectorUnitRange y => 3:5\n$SectorUnitRange z => 6:7" + "$GradedUnitRange{Int64, " * + "$SectorUnitRange{Int64, String, Base.OneTo{Int64}}, " * + "$BlockedUnitRange{Int64, Vector{Int64}}, Vector{Int64}}\n" * + "SectorUnitRange x => 1:2\n" * + "SectorUnitRange y => 3:5\n" * + "SectorUnitRange z => 6:7" g1d = gradedrange(["x" => 2, "y" => 3, "z" => 2]; isdual = true) - @test sprint(show, g1d) == "$GradedUnitRange dual [\"x\" => 2, \"y\" => 3, \"z\" => 2]" + @test sprint(show, g1d) == "GradedUnitRange dual [\"x\" => 2, \"y\" => 3, \"z\" => 2]" @test sprint(show, MIME("text/plain"), g1d) == - "$GradedUnitRange{Int64, $SectorUnitRange{Int64, String, Base.OneTo{Int64}}, $BlockedOneTo{Int64, Vector{Int64}}, Vector{Int64}}\n$SectorUnitRange dual(x) => 1:2\n$SectorUnitRange dual(y) => 3:5\n$SectorUnitRange dual(z) => 6:7" + "$GradedUnitRange{Int64, " * + "$SectorUnitRange{Int64, String, Base.OneTo{Int64}}, " * + "$BlockedOneTo{Int64, Vector{Int64}}, Vector{Int64}}\n" * + "SectorUnitRange dual(x) => 1:2\n" * + "SectorUnitRange dual(y) => 3:5\n" * + "SectorUnitRange dual(z) => 6:7" end @testset "show GradedArray" begin @@ -58,15 +69,28 @@ end a = zeros(elt, r) a[1] = one(elt) @test sprint(show, "text/plain", a) == - "2-blocked 4-element GradedVector{$(elt), Vector{$(elt)}, …, …}:\n $(one(elt))\n $(zero(elt))\n ───\n ⋅ \n ⋅ " + "2-blocked 4-element GradedVector{$(elt), Vector{$(elt)}, …, …}:\n" * + " $(one(elt))\n $(zero(elt))\n ───\n ⋅ \n ⋅ " a = zeros(elt, r, r) a[1, 1] = one(elt) @test sprint(show, "text/plain", a) == - "2×2-blocked 4×4 GradedMatrix{$(elt), Matrix{$(elt)}, …, …}:\n $(one(elt)) $(zero(elt)) │ ⋅ ⋅ \n $(zero(elt)) $(zero(elt)) │ ⋅ ⋅ \n ──────────┼──────────\n ⋅ ⋅ │ ⋅ ⋅ \n ⋅ ⋅ │ ⋅ ⋅ " + "2×2-blocked 4×4 GradedMatrix{$(elt), Matrix{$(elt)}, …, …}:\n" * + " $(one(elt)) $(zero(elt)) │ ⋅ ⋅ \n" * + " $(zero(elt)) $(zero(elt)) │ ⋅ ⋅ \n" * + " ──────────┼──────────\n ⋅ ⋅ │ ⋅ ⋅ \n ⋅ ⋅ │ ⋅ ⋅ " a = zeros(elt, r, r, r) a[1, 1, 1] = one(elt) @test sprint(show, "text/plain", a) == - "2×2×2-blocked 4×4×4 GradedArray{$(elt), 3, Array{$(elt), 3}, …, …}:\n[:, :, 1] =\n $(one(elt)) $(zero(elt)) ⋅ ⋅ \n $(zero(elt)) $(zero(elt)) ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n\n[:, :, 2] =\n $(zero(elt)) $(zero(elt)) ⋅ ⋅ \n $(zero(elt)) $(zero(elt)) ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n\n[:, :, 3] =\n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n\n[:, :, 4] =\n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ " + "2×2×2-blocked 4×4×4 GradedArray{$(elt), 3, Array{$(elt), 3}, …, …}:\n" * + "[:, :, 1] =\n $(one(elt)) $(zero(elt)) ⋅ ⋅ \n" * + " $(zero(elt)) $(zero(elt)) ⋅ ⋅ \n" * + " ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n" * + "\n[:, :, 2] =\n $(zero(elt)) $(zero(elt)) ⋅ ⋅ \n" * + " $(zero(elt)) $(zero(elt)) ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n" * + " ⋅ ⋅ ⋅ ⋅ \n\n[:, :, 3] =\n ⋅ ⋅ ⋅ ⋅ \n" * + " ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n" * + "\n[:, :, 4] =\n ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ \n" * + " ⋅ ⋅ ⋅ ⋅ \n ⋅ ⋅ ⋅ ⋅ " end diff --git a/test/test_tensoralgebraext.jl b/test/test_tensoralgebraext.jl index 16d1aca..8ee663b 100644 --- a/test/test_tensoralgebraext.jl +++ b/test/test_tensoralgebraext.jl @@ -1,8 +1,7 @@ using BlockArrays: Block, blocksize using BlockSparseArrays: BlockSparseArray -using GradedArrays: - GradedArray, GradedMatrix, SU2, U1, dual, flip, gradedrange, sector_type, space_isequal, - trivial_axis +using GradedArrays: GradedArray, GradedMatrix, SU2, U1, dual, flip, gradedrange, + sector_type, space_isequal, trivial_gradedrange using Random: randn! using TensorAlgebra: contract, matricize, unmatricize using Test: @test, @testset @@ -20,14 +19,14 @@ end @testset "trivial_axis" begin g1 = gradedrange([U1(1) => 1, U1(2) => 1]) g2 = gradedrange([U1(-1) => 2, U1(2) => 1]) - @test space_isequal(trivial_axis((g1, g2)), gradedrange([U1(0) => 1])) - @test space_isequal(trivial_axis(sector_type(g1)), gradedrange([U1(0) => 1])) + @test space_isequal(trivial_gradedrange((g1, g2)), gradedrange([U1(0) => 1])) + @test space_isequal(trivial_gradedrange(sector_type(g1)), gradedrange([U1(0) => 1])) gN = gradedrange([(; N = U1(1)) => 1]) gS = gradedrange([(; S = SU2(1 // 2)) => 1]) gNS = gradedrange([(; N = U1(0), S = SU2(0)) => 1]) - @test space_isequal(trivial_axis(sector_type(gN)), gradedrange([(; N = U1(0)) => 1])) - @test space_isequal(trivial_axis((gN, gS)), gNS) + @test space_isequal(trivial_gradedrange(sector_type(gN)), gradedrange([(; N = U1(0)) => 1])) + @test space_isequal(trivial_gradedrange((gN, gS)), gNS) end const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) From 8fc2d9055e896930e313cfae0809bd278cfc33bc Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 4 Dec 2025 17:07:30 -0500 Subject: [PATCH 6/9] Define trivial_axis for domain properly --- src/tensoralgebra.jl | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/tensoralgebra.jl b/src/tensoralgebra.jl index 227fc90..eeb174e 100644 --- a/src/tensoralgebra.jl +++ b/src/tensoralgebra.jl @@ -32,12 +32,22 @@ end function TensorAlgebra.trivial_axis( ::BlockReshapeFusion, + ::Val{:codomain}, a::GradedArray, axes_codomain::Tuple{Vararg{AbstractUnitRange}}, axes_domain::Tuple{Vararg{AbstractUnitRange}}, ) return trivial_gradedrange(axes(a)) end +function TensorAlgebra.trivial_axis( + ::BlockReshapeFusion, + ::Val{:domain}, + a::GradedArray, + axes_codomain::Tuple{Vararg{AbstractUnitRange}}, + axes_domain::Tuple{Vararg{AbstractUnitRange}}, + ) + return dual(trivial_gradedrange(axes(a))) +end function TensorAlgebra.tensor_product_axis( ::ReshapeFusion, ::Val{:codomain}, r1::SectorUnitRange, r2::SectorUnitRange ) @@ -78,15 +88,6 @@ function TensorAlgebra.tensor_product_axis( ) return tensor_product_gradedrange(style, side, r1, r2) end -## using TensorAlgebra: trivialbiperm -## unval(::Val{x}) where {x} = x -## function TensorAlgebra.matricize_axes( -## style::BlockReshapeFusion, a::GradedArray, ndims_codomain::Val -## ) -## # TODO: Remove `TensorAlgebra.` once we delete `GradedArrays.matricize_axes`. -## axis_codomain, axis_domain = @invoke TensorAlgebra.matricize_axes(style, a::AbstractArray, ndims_codomain) -## return axis_codomain, flip(axis_domain) -## end function TensorAlgebra.matricize( ::SectorFusion, a::AbstractArray, length_codomain::Val ) From fd4cdec19af3655ee843861f6c054f8265f314b0 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 4 Dec 2025 17:15:21 -0500 Subject: [PATCH 7/9] Cleanup --- src/tensoralgebra.jl | 62 +++++++++++++++++++------------------------- 1 file changed, 26 insertions(+), 36 deletions(-) diff --git a/src/tensoralgebra.jl b/src/tensoralgebra.jl index eeb174e..3445615 100644 --- a/src/tensoralgebra.jl +++ b/src/tensoralgebra.jl @@ -8,28 +8,6 @@ struct SectorFusion <: FusionStyle end TensorAlgebra.FusionStyle(::Type{<:GradedArray}) = SectorFusion() -function trivial_gradedrange(t::Tuple{Vararg{G}}) where {G <: AbstractGradedUnitRange} - return trivial(first(t)) -end -# heterogeneous sectors -trivial_gradedrange(t::Tuple{Vararg{AbstractGradedUnitRange}}) = ⊗(trivial.(t)...) -# trivial_axis from sector_type -function trivial_gradedrange(::Type{S}) where {S <: SectorRange} - return gradedrange([trivial(S) => 1]) -end - -## # TODO: Use `TensorAlgebra.matricize_axes`. -## function matricize_axes( -## blocked_axes::BlockedTuple{2, <:Any, <:Tuple{Vararg{AbstractUnitRange}}} -## ) -## @assert !isempty(blocked_axes) -## default_axis = trivial_axis(Tuple(blocked_axes)) -## codomain_axes, domain_axes = blocks(blocked_axes) -## codomain_axis = unmerged_tensor_product(default_axis, codomain_axes...) -## unflipped_domain_axis = unmerged_tensor_product(default_axis, domain_axes...) -## return codomain_axis, flip(unflipped_domain_axis) -## end - function TensorAlgebra.trivial_axis( ::BlockReshapeFusion, ::Val{:codomain}, @@ -48,6 +26,16 @@ function TensorAlgebra.trivial_axis( ) return dual(trivial_gradedrange(axes(a))) end +function trivial_gradedrange(t::Tuple{Vararg{G}}) where {G <: AbstractGradedUnitRange} + return trivial(first(t)) +end +# heterogeneous sectors +trivial_gradedrange(t::Tuple{Vararg{AbstractGradedUnitRange}}) = ⊗(trivial.(t)...) +# trivial_axis from sector_type +function trivial_gradedrange(::Type{S}) where {S <: SectorRange} + return gradedrange([trivial(S) => 1]) +end + function TensorAlgebra.tensor_product_axis( ::ReshapeFusion, ::Val{:codomain}, r1::SectorUnitRange, r2::SectorUnitRange ) @@ -58,20 +46,6 @@ function TensorAlgebra.tensor_product_axis( ) return flip(r1 ⊗ r2) end -function tensor_product_gradedrange( - ::BlockReshapeFusion, - side::Val, - r1::AbstractUnitRange, - r2::AbstractUnitRange, - ) - (isone(first(r1)) && isone(first(r2))) || - throw(ArgumentError("Only one-based axes are supported")) - blockaxpairs = Iterators.product(eachblockaxes1(r1), eachblockaxes1(r2)) - blockaxs = map(blockaxpairs) do (b1, b2) - return tensor_product_axis(ReshapeFusion(), side, b1, b2) - end - return mortar_axis(vec(blockaxs)) -end function TensorAlgebra.tensor_product_axis( style::BlockReshapeFusion, side::Val{:codomain}, @@ -88,6 +62,22 @@ function TensorAlgebra.tensor_product_axis( ) return tensor_product_gradedrange(style, side, r1, r2) end +# TODO: Could this call out to a generic tensor_product_axis for AbstractBlockedUnitRange? +function tensor_product_gradedrange( + ::BlockReshapeFusion, + side::Val, + r1::AbstractUnitRange, + r2::AbstractUnitRange, + ) + (isone(first(r1)) && isone(first(r2))) || + throw(ArgumentError("Only one-based axes are supported")) + blockaxpairs = Iterators.product(eachblockaxes1(r1), eachblockaxes1(r2)) + blockaxs = map(blockaxpairs) do (b1, b2) + return tensor_product_axis(ReshapeFusion(), side, b1, b2) + end + return mortar_axis(vec(blockaxs)) +end + function TensorAlgebra.matricize( ::SectorFusion, a::AbstractArray, length_codomain::Val ) From 395fcda67a9957f2fcc43e752cc40114f78c065a Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 4 Dec 2025 17:42:36 -0500 Subject: [PATCH 8/9] Tweak overloads --- Project.toml | 2 +- src/tensoralgebra.jl | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index f6062c3..a5419f8 100644 --- a/Project.toml +++ b/Project.toml @@ -36,7 +36,7 @@ MatrixAlgebraKit = "0.6" Random = "1.10" SUNRepresentations = "0.3" SplitApplyCombine = "1.2.3" -TensorAlgebra = "0.6" +TensorAlgebra = "0.6.2" TensorKitSectors = "0.1, 0.2" TypeParameterAccessors = "0.4" julia = "1.10" diff --git a/src/tensoralgebra.jl b/src/tensoralgebra.jl index 3445615..8534d70 100644 --- a/src/tensoralgebra.jl +++ b/src/tensoralgebra.jl @@ -73,7 +73,9 @@ function tensor_product_gradedrange( throw(ArgumentError("Only one-based axes are supported")) blockaxpairs = Iterators.product(eachblockaxes1(r1), eachblockaxes1(r2)) blockaxs = map(blockaxpairs) do (b1, b2) - return tensor_product_axis(ReshapeFusion(), side, b1, b2) + # TODO: Store a FusionStyle for the blocks in `BlockReshapeFusion` + # and use that here. + return tensor_product_axis(side, b1, b2) end return mortar_axis(vec(blockaxs)) end From 453a4c04cfd12fa32ebe6fd57f51fa5e031c93a0 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 5 Dec 2025 09:39:30 -0500 Subject: [PATCH 9/9] Clean up test/Project.toml --- src/tensoralgebra.jl | 2 +- test/Project.toml | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/tensoralgebra.jl b/src/tensoralgebra.jl index 8534d70..505f880 100644 --- a/src/tensoralgebra.jl +++ b/src/tensoralgebra.jl @@ -24,7 +24,7 @@ function TensorAlgebra.trivial_axis( axes_codomain::Tuple{Vararg{AbstractUnitRange}}, axes_domain::Tuple{Vararg{AbstractUnitRange}}, ) - return dual(trivial_gradedrange(axes(a))) + return flip(trivial_gradedrange(axes(a))) end function trivial_gradedrange(t::Tuple{Vararg{G}}) where {G <: AbstractGradedUnitRange} return trivial(first(t)) diff --git a/test/Project.toml b/test/Project.toml index 8a10764..ff8bf84 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -16,9 +16,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" [sources] -BlockSparseArrays = {path = "/Users/mfishman/.julia/dev/BlockSparseArrays"} GradedArrays = {path = ".."} -TensorAlgebra = {path = "/Users/mfishman/.julia/dev/TensorAlgebra"} [compat] Aqua = "0.8.11"