From 495a4f08a5ff73d41612d91eaf1e3868b2af3984 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 24 Nov 2025 11:44:09 -0500 Subject: [PATCH 1/9] Refactor matricize/unmatricize --- Project.toml | 2 +- src/blockedpermutation.jl | 11 +- src/contract/contract.jl | 29 +++- src/contract/contract_matricize/contract.jl | 9 +- src/matricize.jl | 162 +++++++++++++++----- test/Project.toml | 3 + 6 files changed, 164 insertions(+), 52 deletions(-) diff --git a/Project.toml b/Project.toml index fbcdabc..fd37d20 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" -version = "0.4.6" +version = "0.4.7" authors = ["ITensor developers and contributors"] [deps] diff --git a/src/blockedpermutation.jl b/src/blockedpermutation.jl index c4f4fd6..dd942dc 100644 --- a/src/blockedpermutation.jl +++ b/src/blockedpermutation.jl @@ -8,7 +8,7 @@ function istrivialperm(t::Tuple) return t == trivialperm(length(t)) end -value(::Val{N}) where {N} = N +unval(::Val{N}) where {N} = N _flatten_tuples(t::Tuple) = t function _flatten_tuples(t1::Tuple, t2::Tuple, trest::Tuple...) @@ -87,7 +87,7 @@ function blockedpermvcat( end function blockedpermvcat(len::Val, permblocks::Tuple{Vararg{Int}}...) - value(len) != sum(length.(permblocks); init = 0) && + unval(len) != sum(length.(permblocks); init = 0) && throw(ArgumentError("Invalid total length")) return permmortar(Tuple(permblocks)) end @@ -97,7 +97,7 @@ function _blockedperm_length(::Nothing, specified_perm::Tuple{Vararg{Int}}) end function _blockedperm_length(vallength::Val, ::Tuple{Vararg{Int}}) - return value(vallength) + return unval(vallength) end # blockedpermvcat((4, 3), .., 1) == blockedpermvcat((4, 3), (2,), (1,)) @@ -199,8 +199,11 @@ end blockedperm(tp::BlockedTrivialPermutation) = tp +function blockedtrivialperm(blocklengths::Tuple{Vararg{Val}}) + return BlockedTrivialPermutation{length(blocklengths), unval.(blocklengths)}() +end function blockedtrivialperm(blocklengths::Tuple{Vararg{Int}}) - return BlockedTrivialPermutation{length(blocklengths), blocklengths}() + return blockedtrivialperm(Val.(blocklengths)) end function trivialperm(blockedperm::AbstractBlockTuple) diff --git a/src/contract/contract.jl b/src/contract/contract.jl index 5d6fcf6..88c8335 100644 --- a/src/contract/contract.jl +++ b/src/contract/contract.jl @@ -4,9 +4,28 @@ abstract type Algorithm end Algorithm(alg::Algorithm) = alg -struct Matricize <: Algorithm end +struct Matricize{Style} <: Algorithm + fusion_style::Style +end -default_contract_alg() = Matricize() +function default_contract_alg(a1::AbstractArray, labels1, a2::AbstractArray, labels2) + style1 = FusionStyle(a1) + style2 = FusionStyle(a2) + style1 == style2 || error("Styles must match.") + return Matricize(style1) +end +function default_contractadd!_alg( + a_dest::AbstractArray, labels_dest, + a1::AbstractArray, labels1, + a2::AbstractArray, labels2, + α::Number, β::Number, + ) + style_dest = FusionStyle(a_dest) + style1 = FusionStyle(a1) + style2 = FusionStyle(a2) + style_dest == style1 == style2 || error("Styles must match.") + return Matricize(style_dest) +end # Required interface if not using # matricized contraction. @@ -29,7 +48,7 @@ function contract( labels1, a2::AbstractArray, labels2; - alg = default_contract_alg(), + alg = default_contract_alg(a1, labels1, a2, labels2), kwargs..., ) return contract(Algorithm(alg), a1, labels1, a2, labels2; kwargs...) @@ -48,7 +67,7 @@ function contract( labels1, a2::AbstractArray, labels2; - alg = default_contract_alg(), + alg = default_contract_alg(a1, labels1, a2, labels2), kwargs..., ) return contract(Algorithm(alg), labels_dest, a1, labels1, a2, labels2; kwargs...) @@ -75,7 +94,7 @@ function contractadd!( labels2, α::Number, β::Number; - alg = default_contract_alg(), + alg = default_contractadd!_alg(a_dest, labels_dest, a1, labels1, a2, labels2, α, β), kwargs..., ) contractadd!( diff --git a/src/contract/contract_matricize/contract.jl b/src/contract/contract_matricize/contract.jl index f7207ee..752ef7c 100644 --- a/src/contract/contract_matricize/contract.jl +++ b/src/contract/contract_matricize/contract.jl @@ -1,7 +1,7 @@ using LinearAlgebra: mul! function contractadd!( - ::Matricize, + alg::Matricize, a_dest::AbstractArray, biperm_dest::AbstractBlockPermutation{2}, a1::AbstractArray, @@ -12,11 +12,10 @@ function contractadd!( β::Number, ) invbiperm = biperm(invperm(biperm_dest), length_codomain(biperm1)) - check_input(contract, a_dest, invbiperm, a1, biperm1, a2, biperm2) - a1_mat = matricize(a1, biperm1) - a2_mat = matricize(a2, biperm2) + a1_mat = matricize(alg.fusion_style, a1, biperm1) + a2_mat = matricize(alg.fusion_style, a2, biperm2) a_dest_mat = a1_mat * a2_mat - unmatricizeadd!(a_dest, a_dest_mat, invbiperm, α, β) + unmatricizeadd!(alg.fusion_style, a_dest, a_dest_mat, invbiperm, α, β) return a_dest end diff --git a/src/matricize.jl b/src/matricize.jl index 705d303..c04d7e2 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -21,23 +21,45 @@ trivial_axis(::Tuple{}) = Base.OneTo(1) trivial_axis(::Tuple{Vararg{AbstractUnitRange}}) = Base.OneTo(1) trivial_axis(::Tuple{Vararg{AbstractBlockedUnitRange}}) = blockedrange([1]) +# Inner version takes a list of sub-permutations, overload this one if needed. function fuseaxes( - axes::Tuple{Vararg{AbstractUnitRange}}, blockedperm::AbstractBlockPermutation + axes::Tuple{Vararg{AbstractUnitRange}}, lengths::Val... ) - axesblocks = blocks(axes[blockedperm]) + axesblocks = blocks(axes[blockedtrivialperm(lengths)]) return map(block -> isempty(block) ? trivial_axis(axes) : ⊗(block...), axesblocks) end +# Inner version takes a list of sub-permutations, overload this one if needed. +function fuseaxes( + axes::Tuple{Vararg{AbstractUnitRange}}, permblocks::Tuple{Vararg{Int}}... + ) + axes′ = map(d -> axes[d], permmortar(permblocks)) + return fuseaxes(axes′, Val.(length.(permblocks))...) +end + +function fuseaxes( + axes::Tuple{Vararg{AbstractUnitRange}}, blockedperm::AbstractBlockPermutation + ) + return fuseaxes(axes, blocks(blockedperm)...) +end + +# Inner version takes a list of sub-permutations, overload this one if needed. +function permuteblockeddims(a::AbstractArray, perm1, perm2) + return _permutedims(a, (perm1..., perm2...)) +end +function permuteblockeddims!(a_dest::AbstractArray, a_src::AbstractArray, perm1, perm2) + return _permutedims!(a_dest, a_src, (perm1..., perm2...)) +end + # TODO remove _permutedims once support for Julia 1.10 is dropped # define permutedims with a BlockedPermuation. Default is to flatten it. -function permuteblockeddims(a::AbstractArray, biperm::AbstractBlockPermutation) - return _permutedims(a, Tuple(biperm)) +function permuteblockeddims(a::AbstractArray, biperm::AbstractBlockPermutation{2}) + return permuteblockeddims(a, blocks(biperm)...) end - function permuteblockeddims!( - a::AbstractArray, b::AbstractArray, biperm::AbstractBlockPermutation + a_dest::AbstractArray, a_src::AbstractArray, biperm::AbstractBlockPermutation{2} ) - return _permutedims!(a, b, Tuple(biperm)) + return permuteblockeddims!(a_dest, a_src, blocks(biperm)...) end # ===================================== matricize ======================================== @@ -45,53 +67,100 @@ end # matrix factorizations assume copy # maybe: copy=false kwarg -function matricize(a::AbstractArray, biperm_dest::AbstractBlockPermutation{2}) - ndims(a) == length(biperm_dest) || throw(ArgumentError("Invalid bipermutation")) - return matricize(FusionStyle(a), a, biperm_dest) +function matricize(a::AbstractArray, length1::Val, length2::Val) + return matricize(FusionStyle(a), a, length1, length2) +end +# This is the primary function that should be overloaded for new fusion styles. +# This assumes the permutation was already performed. +function matricize(style::FusionStyle, a::AbstractArray, length1::Val, length2::Val) + return throw( + MethodError( + matricize, Tuple{typeof(style), typeof(a), typeof(length1), typeof(length2)} + ) + ) end function matricize( - style::FusionStyle, a::AbstractArray, biperm_dest::AbstractBlockPermutation{2} + a::AbstractArray, permblock1::Tuple{Vararg{Int}}, permblock2::Tuple{Vararg{Int}} ) - a_perm = permuteblockeddims(a, biperm_dest) - return matricize(style, a_perm, trivialperm(biperm_dest)) + return matricize(FusionStyle(a), a, permblock1, permblock2) end - +# This is a more advanced version to overload where the permutation is actually performed. function matricize( - style::FusionStyle, a::AbstractArray, biperm_dest::BlockedTrivialPermutation{2} + style::FusionStyle, a::AbstractArray, + permblock1::NTuple{N1, Int}, permblock2::NTuple{N2, Int} + ) where {N1, N2} + ndims(a) == length(permblock1) + length(permblock2) || + throw(ArgumentError("Invalid bipermutation")) + a_perm = permuteblockeddims(a, permblock1, permblock2) + return matricize(style, a_perm, Val(length(permblock1)), Val(length(permblock2))) +end + +# Process inputs such as `EllipsisNotation.Ellipsis`. +function to_permblocks(a::AbstractArray, permblocks::NTuple{2, Tuple{Vararg{Int}}}) + isperm((permblocks[1]..., permblocks[2]...)) || + throw(ArgumentError("Invalid bipermutation")) + return permblocks +end +# Like `setcomplement` is like `setdiff` but assumes t2 ⊆ t1. +function tuplesetcomplement(t1::NTuple{N1}, t2::NTuple{N2}) where {N1, N2} + t2 ⊆ t1 || throw(ArgumentError("t2 must be a subset of t1")) + return NTuple{N1 - N2}(setdiff(t1, t2)) +end +function to_permblocks( + a::AbstractArray, permblocks::Tuple{Tuple{Ellipsis}, Tuple{Vararg{Int}}} + ) + permblocks1 = tuplesetcomplement(ntuple(identity, ndims(a)), permblocks[2]) + return (permblocks1, permblocks[2]) +end +function to_permblocks( + a::AbstractArray, permblocks::Tuple{Tuple{Vararg{Int}}, Tuple{Ellipsis}} ) - return throw(MethodError(matricize, Tuple{typeof(style), typeof(a), typeof(biperm_dest)})) + permblocks2 = tuplesetcomplement(ntuple(identity, ndims(a)), permblocks[1]) + return (permblocks[1], permblocks2) +end +function matricize(a::AbstractArray, permblock1, permblock2) + return matricize(FusionStyle(a), a, permblock1, permblock2) +end +function matricize(style::FusionStyle, a::AbstractArray, permblock1, permblock2) + return matricize(style, a, to_permblocks(a, (permblock1, permblock2))...) end -# default is reshape +function matricize(a::AbstractArray, biperm_dest::AbstractBlockPermutation{2}) + return matricize(FusionStyle(a), a, biperm_dest) +end function matricize( - ::ReshapeFusion, a::AbstractArray, biperm_dest::BlockedTrivialPermutation{2} + style::FusionStyle, a::AbstractArray, biperm_dest::AbstractBlockPermutation{2} ) - new_axes = fuseaxes(axes(a), biperm_dest) - return reshape(a, new_axes...) + return matricize(style, a, blocks(biperm_dest)...) end -function matricize(a::AbstractArray, permblock1::Tuple, permblock2::Tuple) - return matricize(a, blockedpermvcat(permblock1, permblock2; length = Val(ndims(a)))) +# default is reshape +function matricize(::ReshapeFusion, a::AbstractArray, length1::Val, length2::Val) + return reshape(a, fuseaxes(axes(a), length1, length2)...) end # ==================================== unmatricize ======================================= function unmatricize(m::AbstractMatrix, axes_dest, invbiperm::AbstractBlockPermutation{2}) - length(axes_dest) == length(invbiperm) || - throw(ArgumentError("axes do not match permutation")) return unmatricize(FusionStyle(m), m, axes_dest, invbiperm) end - function unmatricize( - ::FusionStyle, m::AbstractMatrix, axes_dest, invbiperm::AbstractBlockPermutation{2} + style::FusionStyle, m::AbstractMatrix, axes_dest, invbiperm::AbstractBlockPermutation{2} ) + length(axes_dest) == length(invbiperm) || + throw(ArgumentError("axes do not match permutation")) blocked_axes = axes_dest[invbiperm] - a12 = unmatricize(m, blocked_axes) + a12 = unmatricize(style, m, blocked_axes) biperm_dest = biperm(invperm(invbiperm), length_codomain(axes_dest)) - return permuteblockeddims(a12, biperm_dest) end +function unmatricize( + m::AbstractMatrix, + blocked_axes::BlockedTuple{2, <:Any, <:Tuple{Vararg{AbstractUnitRange}}}, + ) + return unmatricize(FusionStyle(m), m, blocked_axes) +end function unmatricize( ::ReshapeFusion, m::AbstractMatrix, @@ -100,30 +169,49 @@ function unmatricize( return reshape(m, Tuple(blocked_axes)...) end -function unmatricize(m::AbstractMatrix, blocked_axes) - return unmatricize(FusionStyle(m), m, blocked_axes) -end - function unmatricize( m::AbstractMatrix, codomain_axes::Tuple{Vararg{AbstractUnitRange}}, domain_axes::Tuple{Vararg{AbstractUnitRange}}, ) + return unmatricize(FusionStyle(m), m, codomain_axes, domain_axes) +end +function unmatricize( + style::FusionStyle, m::AbstractMatrix, + codomain_axes::Tuple{Vararg{AbstractUnitRange}}, + domain_axes::Tuple{Vararg{AbstractUnitRange}}, + ) blocked_axes = tuplemortar((codomain_axes, domain_axes)) - return unmatricize(m, blocked_axes) + return unmatricize(style, m, blocked_axes) end -function unmatricize!(a_dest, m::AbstractMatrix, invbiperm::AbstractBlockPermutation{2}) +function unmatricize!( + a_dest::AbstractArray, m::AbstractMatrix, invbiperm::AbstractBlockPermutation{2} + ) + return unmatricize!(FusionStyle(m), a_dest, m, invbiperm) +end +function unmatricize!( + style::FusionStyle, a_dest::AbstractArray, m::AbstractMatrix, + invbiperm::AbstractBlockPermutation{2}, + ) ndims(a_dest) == length(invbiperm) || throw(ArgumentError("destination does not match permutation")) blocked_axes = axes(a_dest)[invbiperm] - a_perm = unmatricize(m, blocked_axes) + a_perm = unmatricize(style, m, blocked_axes) biperm_dest = biperm(invperm(invbiperm), length_codomain(axes(a_dest))) return permuteblockeddims!(a_dest, a_perm, biperm_dest) end -function unmatricizeadd!(a_dest, a_dest_mat, invbiperm, α, β) - a12 = unmatricize(a_dest_mat, axes(a_dest), invbiperm) +function unmatricizeadd!( + a_dest::AbstractArray, m::AbstractMatrix, invbiperm::AbstractBlockPermutation{2}, + α::Number, β::Number + ) + return unmatricizeadd!(FusionStyle(a_dest), a_dest, m, invbiperm, α, β) +end +function unmatricizeadd!( + style::FusionStyle, a_dest::AbstractArray, m::AbstractMatrix, invbiperm, α, β + ) + a12 = unmatricize(style, m, axes(a_dest), invbiperm) a_dest .= α .* a12 .+ β .* a_dest return a_dest end diff --git a/test/Project.toml b/test/Project.toml index aa511c9..9905c69 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -14,6 +14,9 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" +[sources] +TensorAlgebra = {path = ".."} + [compat] Aqua = "0.8.9" BlockArrays = "1.6.1" From 64191ab6978bceae218a4ab2ac9f8a2c70c09a4e Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 24 Nov 2025 11:46:05 -0500 Subject: [PATCH 2/9] Fix tests --- src/contract/contract.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/contract/contract.jl b/src/contract/contract.jl index 88c8335..a943628 100644 --- a/src/contract/contract.jl +++ b/src/contract/contract.jl @@ -7,6 +7,7 @@ Algorithm(alg::Algorithm) = alg struct Matricize{Style} <: Algorithm fusion_style::Style end +Matricize() = Matricize(ReshapeFusion()) function default_contract_alg(a1::AbstractArray, labels1, a2::AbstractArray, labels2) style1 = FusionStyle(a1) From d2c7769c07dfad8da29fc682745a17aa696002ba Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 24 Nov 2025 12:23:55 -0500 Subject: [PATCH 3/9] Refactor unmatricize --- src/matricize.jl | 119 ++++++++++++++++++++++++++++++++++++----------- 1 file changed, 92 insertions(+), 27 deletions(-) diff --git a/src/matricize.jl b/src/matricize.jl index c04d7e2..f6809f2 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -141,18 +141,36 @@ function matricize(::ReshapeFusion, a::AbstractArray, length1::Val, length2::Val end # ==================================== unmatricize ======================================= -function unmatricize(m::AbstractMatrix, axes_dest, invbiperm::AbstractBlockPermutation{2}) - return unmatricize(FusionStyle(m), m, axes_dest, invbiperm) +function unmatricize( + m::AbstractMatrix, + codomain_axes::Tuple{Vararg{AbstractUnitRange}}, + domain_axes::Tuple{Vararg{AbstractUnitRange}}, + ) + return unmatricize(FusionStyle(m), m, codomain_axes, domain_axes) end +# This is the primary function that should be overloaded for new fusion styles. function unmatricize( - style::FusionStyle, m::AbstractMatrix, axes_dest, invbiperm::AbstractBlockPermutation{2} + style::FusionStyle, m::AbstractMatrix, + codomain_axes::Tuple{Vararg{AbstractUnitRange}}, + domain_axes::Tuple{Vararg{AbstractUnitRange}}, ) - length(axes_dest) == length(invbiperm) || - throw(ArgumentError("axes do not match permutation")) - blocked_axes = axes_dest[invbiperm] - a12 = unmatricize(style, m, blocked_axes) - biperm_dest = biperm(invperm(invbiperm), length_codomain(axes_dest)) - return permuteblockeddims(a12, biperm_dest) + return throw( + MethodError( + unmatricize, + Tuple{ + typeof(style), typeof(m), typeof(codomain_axes), typeof(domain_axes) + }, + ) + ) +end + +# Implementation using reshape. +function unmatricize( + style::ReshapeFusion, m::AbstractMatrix, + codomain_axes::Tuple{Vararg{AbstractUnitRange}}, + domain_axes::Tuple{Vararg{AbstractUnitRange}}, + ) + return reshape(m, (codomain_axes..., domain_axes...)) end function unmatricize( @@ -162,38 +180,53 @@ function unmatricize( return unmatricize(FusionStyle(m), m, blocked_axes) end function unmatricize( - ::ReshapeFusion, + style::FusionStyle, m::AbstractMatrix, blocked_axes::BlockedTuple{2, <:Any, <:Tuple{Vararg{AbstractUnitRange}}}, ) - return reshape(m, Tuple(blocked_axes)...) + return unmatricize(style, m, blocks(blocked_axes)...) end function unmatricize( - m::AbstractMatrix, - codomain_axes::Tuple{Vararg{AbstractUnitRange}}, - domain_axes::Tuple{Vararg{AbstractUnitRange}}, + m::AbstractMatrix, axes_dest, + invperm1::Tuple{Vararg{Int}}, invperm2::Tuple{Vararg{Int}}, ) - return unmatricize(FusionStyle(m), m, codomain_axes, domain_axes) + return unmatricize(FusionStyle(m), m, axes_dest, invperm1, invperm2) end function unmatricize( - style::FusionStyle, m::AbstractMatrix, - codomain_axes::Tuple{Vararg{AbstractUnitRange}}, - domain_axes::Tuple{Vararg{AbstractUnitRange}}, + style::FusionStyle, m::AbstractMatrix, axes_dest, + invperm1::Tuple{Vararg{Int}}, invperm2::Tuple{Vararg{Int}}, ) - blocked_axes = tuplemortar((codomain_axes, domain_axes)) - return unmatricize(style, m, blocked_axes) + invbiperm = permmortar((invperm1, invperm2)) + length(axes_dest) == length(invbiperm) || + throw(ArgumentError("axes do not match permutation")) + blocked_axes = axes_dest[invbiperm] + a12 = unmatricize(style, m, blocked_axes) + biperm_dest = biperm(invperm(invbiperm), length_codomain(axes_dest)) + return permuteblockeddims(a12, biperm_dest) +end + +function unmatricize(m::AbstractMatrix, axes_dest, invbiperm::AbstractBlockPermutation{2}) + return unmatricize(FusionStyle(m), m, axes_dest, invbiperm) +end +function unmatricize( + style::FusionStyle, m::AbstractMatrix, axes_dest, + invbiperm::AbstractBlockPermutation{2} + ) + return unmatricize(style, m, axes_dest, blocks(invbiperm)...) end function unmatricize!( - a_dest::AbstractArray, m::AbstractMatrix, invbiperm::AbstractBlockPermutation{2} + a_dest::AbstractArray, m::AbstractMatrix, + invperm1::Tuple{Vararg{Int}}, invperm2::Tuple{Vararg{Int}}, ) - return unmatricize!(FusionStyle(m), a_dest, m, invbiperm) + return unmatricize!(FusionStyle(m), a_dest, m, invperm1, invperm2) end function unmatricize!( style::FusionStyle, a_dest::AbstractArray, m::AbstractMatrix, - invbiperm::AbstractBlockPermutation{2}, + invperm1::Tuple{Vararg{Int}}, invperm2::Tuple{Vararg{Int}}, ) + invbiperm = permmortar((invperm1, invperm2)) ndims(a_dest) == length(invbiperm) || throw(ArgumentError("destination does not match permutation")) blocked_axes = axes(a_dest)[invbiperm] @@ -202,16 +235,48 @@ function unmatricize!( return permuteblockeddims!(a_dest, a_perm, biperm_dest) end +function unmatricize!( + a_dest::AbstractArray, m::AbstractMatrix, invbiperm::AbstractBlockPermutation{2} + ) + return unmatricize!(FusionStyle(m), a_dest, m, invbiperm) +end +function unmatricize!( + style::FusionStyle, a_dest::AbstractArray, m::AbstractMatrix, + invbiperm::AbstractBlockPermutation{2}, + ) + return unmatricize!(style, a_dest, m, blocks(invbiperm)...) +end + function unmatricizeadd!( - a_dest::AbstractArray, m::AbstractMatrix, invbiperm::AbstractBlockPermutation{2}, + a_dest::AbstractArray, m::AbstractMatrix, + invperm1::Tuple{Vararg{Int}}, invperm2::Tuple{Vararg{Int}}, α::Number, β::Number ) - return unmatricizeadd!(FusionStyle(a_dest), a_dest, m, invbiperm, α, β) + return unmatricizeadd!(FusionStyle(a_dest), a_dest, m, invperm1, invperm2, α, β) end function unmatricizeadd!( - style::FusionStyle, a_dest::AbstractArray, m::AbstractMatrix, invbiperm, α, β + style::FusionStyle, a_dest::AbstractArray, m::AbstractMatrix, + invperm1::Tuple{Vararg{Int}}, invperm2::Tuple{Vararg{Int}}, + α::Number, β::Number, ) - a12 = unmatricize(style, m, axes(a_dest), invbiperm) + a12 = unmatricize(style, m, axes(a_dest), invperm1, invperm2) a_dest .= α .* a12 .+ β .* a_dest return a_dest end + +function unmatricizeadd!( + a_dest::AbstractArray, m::AbstractMatrix, + invbiperm::AbstractBlockPermutation{2}, + α::Number, β::Number + ) + return unmatricizeadd!(FusionStyle(a_dest), a_dest, m, invbiperm, α, β) +end +function unmatricizeadd!( + style::FusionStyle, a_dest::AbstractArray, m::AbstractMatrix, + invbiperm::AbstractBlockPermutation{2}, + α::Number, β::Number, + ) + return unmatricizeadd!( + style, a_dest, m, blocks(invbiperm)..., α, β + ) +end From 0e8be0a906c153b99cad894d20f50a68ab15539a Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 24 Nov 2025 12:24:10 -0500 Subject: [PATCH 4/9] Format --- src/matricize.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/matricize.jl b/src/matricize.jl index f6809f2..294ffa8 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -158,7 +158,7 @@ function unmatricize( MethodError( unmatricize, Tuple{ - typeof(style), typeof(m), typeof(codomain_axes), typeof(domain_axes) + typeof(style), typeof(m), typeof(codomain_axes), typeof(domain_axes), }, ) ) From a0b89207d67ce99fde243fa587410043be6b6e1c Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 24 Nov 2025 12:30:02 -0500 Subject: [PATCH 5/9] Delete stale test dep --- test/Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index 9905c69..e1ed25e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -12,7 +12,6 @@ TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" -TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" [sources] TensorAlgebra = {path = ".."} From fce6711e9288ef23dfa89ba0b10ee15472cd9cf1 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 24 Nov 2025 13:09:54 -0500 Subject: [PATCH 6/9] Reorg --- src/matricize.jl | 33 ++++++++++++++------------------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/src/matricize.jl b/src/matricize.jl index 294ffa8..b13e8d0 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -8,14 +8,9 @@ using .BaseExtensions: _permutedims, _permutedims! # ===================================== FusionStyle ====================================== abstract type FusionStyle end -struct ReshapeFusion <: FusionStyle end - FusionStyle(x) = FusionStyle(typeof(x)) FusionStyle(T::Type) = throw(MethodError(FusionStyle, (T,))) -# Defaults to ReshapeFusion, a simple reshape -FusionStyle(::Type{<:AbstractArray}) = ReshapeFusion() - # ======================================= misc ======================================== trivial_axis(::Tuple{}) = Base.OneTo(1) trivial_axis(::Tuple{Vararg{AbstractUnitRange}}) = Base.OneTo(1) @@ -135,11 +130,6 @@ function matricize( return matricize(style, a, blocks(biperm_dest)...) end -# default is reshape -function matricize(::ReshapeFusion, a::AbstractArray, length1::Val, length2::Val) - return reshape(a, fuseaxes(axes(a), length1, length2)...) -end - # ==================================== unmatricize ======================================= function unmatricize( m::AbstractMatrix, @@ -164,15 +154,6 @@ function unmatricize( ) end -# Implementation using reshape. -function unmatricize( - style::ReshapeFusion, m::AbstractMatrix, - codomain_axes::Tuple{Vararg{AbstractUnitRange}}, - domain_axes::Tuple{Vararg{AbstractUnitRange}}, - ) - return reshape(m, (codomain_axes..., domain_axes...)) -end - function unmatricize( m::AbstractMatrix, blocked_axes::BlockedTuple{2, <:Any, <:Tuple{Vararg{AbstractUnitRange}}}, @@ -280,3 +261,17 @@ function unmatricizeadd!( style, a_dest, m, blocks(invbiperm)..., α, β ) end + +# Defaults to ReshapeFusion, a simple reshape +struct ReshapeFusion <: FusionStyle end +FusionStyle(::Type{<:AbstractArray}) = ReshapeFusion() +function matricize(style::ReshapeFusion, a::AbstractArray, length1::Val, length2::Val) + return reshape(a, fuseaxes(axes(a), length1, length2)) +end +function unmatricize( + style::ReshapeFusion, m::AbstractMatrix, + codomain_axes::Tuple{Vararg{AbstractUnitRange}}, + domain_axes::Tuple{Vararg{AbstractUnitRange}}, + ) + return reshape(m, (codomain_axes..., domain_axes...)) +end From 77fa1dd300e4cccf1e77372f45c0c4095c4c63c1 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 24 Nov 2025 14:14:44 -0500 Subject: [PATCH 7/9] Mark as breaking --- Project.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index fd37d20..7cd4145 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" -version = "0.4.7" authors = ["ITensor developers and contributors"] +version = "0.5.0" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" @@ -22,11 +22,11 @@ TensorAlgebraTensorOperationsExt = "TensorOperations" [compat] ArrayLayouts = "1.10.4" BlockArrays = "1.7.2" -EllipsisNotation = "1.8.0" +EllipsisNotation = "1.8" LinearAlgebra = "1.10" MatrixAlgebraKit = "0.2, 0.3, 0.4, 0.5, 0.6" TensorOperations = "5" TensorProducts = "0.1.5" -TupleTools = "1.6.0" +TupleTools = "1.6" TypeParameterAccessors = "0.2.1, 0.3, 0.4" julia = "1.10" From f993a81703d7797a644abae808606032a67ef67a Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 24 Nov 2025 14:24:05 -0500 Subject: [PATCH 8/9] Bump subdir versions --- docs/Project.toml | 2 +- examples/Project.toml | 2 +- test/Project.toml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index c23ecf2..190f565 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -6,4 +6,4 @@ TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" [compat] Documenter = "1.8.1" Literate = "2.20.1" -TensorAlgebra = "0.4" +TensorAlgebra = "0.5" diff --git a/examples/Project.toml b/examples/Project.toml index 26ae9f9..fa07a52 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -2,4 +2,4 @@ TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" [compat] -TensorAlgebra = "0.4" +TensorAlgebra = "0.5" diff --git a/test/Project.toml b/test/Project.toml index e1ed25e..9bdf530 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -26,7 +26,7 @@ Random = "1.10" SafeTestsets = "0.1" StableRNGs = "1.0.2" Suppressor = "0.2" -TensorAlgebra = "0.4" +TensorAlgebra = "0.5" TensorOperations = "5.1.4" Test = "1.10" TestExtras = "0.3.1" From 1fb114f2734623241296e0d8ec36e6f5cdb047a4 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 24 Nov 2025 14:35:57 -0500 Subject: [PATCH 9/9] Add missing sources --- docs/Project.toml | 3 +++ examples/Project.toml | 3 +++ 2 files changed, 6 insertions(+) diff --git a/docs/Project.toml b/docs/Project.toml index 190f565..a3bb587 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -3,6 +3,9 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" +[sources] +TensorAlgebra = {path = ".."} + [compat] Documenter = "1.8.1" Literate = "2.20.1" diff --git a/examples/Project.toml b/examples/Project.toml index fa07a52..8b00c47 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -1,5 +1,8 @@ [deps] TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" +[sources] +TensorAlgebra = {path = ".."} + [compat] TensorAlgebra = "0.5"