From 493274a95537fdf7a72cc8f4737a5adebc9dedf2 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 5 Dec 2025 16:09:42 -0500 Subject: [PATCH 1/2] Define trivial_axis, better matricize overloads --- Project.toml | 4 +- .../KroneckerArraysTensorAlgebraExt.jl | 57 ++++++++++++++++--- test/Project.toml | 1 + test/test_tensoralgebra.jl | 8 ++- 4 files changed, 58 insertions(+), 12 deletions(-) diff --git a/Project.toml b/Project.toml index 087e67c..d0c000f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.3.4" +version = "0.3.5" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -34,6 +34,6 @@ GPUArraysCore = "0.2" LinearAlgebra = "1.10" MapBroadcast = "0.1.10" MatrixAlgebraKit = "0.6" -TensorAlgebra = "0.6.2" +TensorAlgebra = "0.6.3" TypeParameterAccessors = "0.4.2" julia = "1.10" diff --git a/ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl b/ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl index 921ab96..b23009a 100644 --- a/ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl +++ b/ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl @@ -3,7 +3,7 @@ module KroneckerArraysTensorAlgebraExt using KroneckerArrays: KroneckerArrays, AbstractKroneckerArray, CartesianProductUnitRange, ⊗, cartesianrange, kroneckerfactors, kroneckerfactortypes using TensorAlgebra: TensorAlgebra, AbstractBlockPermutation, BlockedTrivialPermutation, - FusionStyle, matricize, tensor_product_axis, unmatricize + FusionStyle, matricize, tensor_product_axis, trivial_axis, unmatricize struct KroneckerFusion{A <: FusionStyle, B <: FusionStyle} <: FusionStyle a::A @@ -19,14 +19,57 @@ function TensorAlgebra.FusionStyle(A::Type{<:CartesianProductUnitRange}) return KroneckerFusion(FusionStyle.(kroneckerfactortypes(A))...) end +function TensorAlgebra.trivial_axis( + style::KroneckerFusion, side::Val{:codomain}, a::AbstractArray, + axes_codomain::Tuple{Vararg{AbstractUnitRange}}, + axes_domain::Tuple{Vararg{AbstractUnitRange}}, + ) + return trivial_kronecker(style, side, a, axes_codomain, axes_domain) +end +function TensorAlgebra.trivial_axis( + style::KroneckerFusion, side::Val{:domain}, a::AbstractArray, + axes_codomain::Tuple{Vararg{AbstractUnitRange}}, + axes_domain::Tuple{Vararg{AbstractUnitRange}}, + ) + return trivial_kronecker(style, side, a, axes_codomain, axes_domain) +end +function trivial_kronecker( + style::FusionStyle, side::Val, a::AbstractArray, + axes_codomain::Tuple{Vararg{AbstractUnitRange}}, + axes_domain::Tuple{Vararg{AbstractUnitRange}}, + ) + style_a, style_b = kroneckerfactors(style) + a_a, a_b = kroneckerfactors(a) + axes_codomain_a = kroneckerfactors.(axes_codomain, 1) + axes_codomain_b = kroneckerfactors.(axes_codomain, 2) + axes_domain_a = kroneckerfactors.(axes_domain, 1) + axes_domain_b = kroneckerfactors.(axes_domain, 2) + ra = trivial_axis(style_a, side, a_a, axes_codomain_a, axes_domain_a) + rb = trivial_axis(style_b, side, a_b, axes_codomain_b, axes_domain_b) + return cartesianrange(ra, rb) +end + function TensorAlgebra.tensor_product_axis( - style::KroneckerFusion, r1::AbstractUnitRange, r2::AbstractUnitRange + style::KroneckerFusion, side::Val{:codomain}, + r1::AbstractUnitRange, r2::AbstractUnitRange, + ) + return tensor_product_kronecker(style, side, r1, r2) +end +function TensorAlgebra.tensor_product_axis( + style::KroneckerFusion, side::Val{:domain}, + r1::AbstractUnitRange, r2::AbstractUnitRange, + ) + return tensor_product_kronecker(style, side, r1, r2) +end +function tensor_product_kronecker( + style::KroneckerFusion, side::Val, + r1::AbstractUnitRange, r2::AbstractUnitRange, ) style_a, style_b = kroneckerfactors(style) r1a, r1b = kroneckerfactors(r1) r2a, r2b = kroneckerfactors(r2) - ra = tensor_product_axis(style_a, r1a, r2a) - rb = tensor_product_axis(style_b, r1b, r2b) + ra = tensor_product_axis(style_a, side, r1a, r2a) + rb = tensor_product_axis(style_b, side, r1b, r2b) return cartesianrange(ra, rb) end @@ -44,8 +87,7 @@ function TensorAlgebra.matricize( end function unmatricize_kronecker( - style::FusionStyle, - m::AbstractMatrix, + style::FusionStyle, m::AbstractMatrix, axes_codomain::Tuple{Vararg{AbstractUnitRange}}, axes_domain::Tuple{Vararg{AbstractUnitRange}}, ) @@ -60,8 +102,7 @@ function unmatricize_kronecker( return a1 ⊗ a2 end function TensorAlgebra.unmatricize( - style::KroneckerFusion, - m::AbstractMatrix, + style::KroneckerFusion, m::AbstractMatrix, codomain_axes::Tuple{Vararg{AbstractUnitRange}}, domain_axes::Tuple{Vararg{AbstractUnitRange}}, ) diff --git a/test/Project.toml b/test/Project.toml index 6717a93..d5ddf22 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -20,6 +20,7 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" [sources] KroneckerArrays = {path = ".."} +TensorAlgebra = {path = "/Users/mfishman/.julia/dev/TensorAlgebra"} [compat] Adapt = "4" diff --git a/test/test_tensoralgebra.jl b/test/test_tensoralgebra.jl index ecbcd8e..fdb0efb 100644 --- a/test/test_tensoralgebra.jl +++ b/test/test_tensoralgebra.jl @@ -1,9 +1,12 @@ -using TensorAlgebra: matricize, tensor_product_axis, unmatricize +using TensorAlgebra: matricize, tensor_product_axis, trivial_axis, unmatricize using KroneckerArrays: ⊗, cartesianrange, kroneckerfactors, unproduct using Test: @test, @testset @testset "TensorAlgebraExt" begin @testset "tensor_product_axis" begin + r = cartesianrange(2, 3) + @test trivial_axis(r) ≡ cartesianrange(1, 1) + r1 = cartesianrange(2, 3) r2 = cartesianrange(4, 5) r = tensor_product_axis(r1, r2) @@ -15,7 +18,8 @@ using Test: @test, @testset @testset "matricize/unmatricize" begin a = randn(2, 2, 2) ⊗ randn(3, 3, 3) m = matricize(a, (1, 2), (3,)) - @test m == matricize(kroneckerfactors(a, 1), (1, 2), (3,)) ⊗ matricize(kroneckerfactors(a, 2), (1, 2), (3,)) + @test m == matricize(kroneckerfactors(a, 1), (1, 2), (3,)) ⊗ + matricize(kroneckerfactors(a, 2), (1, 2), (3,)) @test unmatricize(m, (axes(a, 1), axes(a, 2)), (axes(a, 3),)) == a end end From d259ad3117262fa0028a011e04c70ca9898eb5d7 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 5 Dec 2025 16:16:25 -0500 Subject: [PATCH 2/2] Cleanup test/Project.toml --- test/Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index d5ddf22..6717a93 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -20,7 +20,6 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" [sources] KroneckerArrays = {path = ".."} -TensorAlgebra = {path = "/Users/mfishman/.julia/dev/TensorAlgebra"} [compat] Adapt = "4"