Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorAlgebra"
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
authors = ["ITensor developers <support@itensor.org> and contributors"]
version = "0.6.1"
version = "0.6.2"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand Down
92 changes: 83 additions & 9 deletions src/matricize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,104 @@ FusionStyle(x) = FusionStyle(typeof(x))
FusionStyle(T::Type) = throw(MethodError(FusionStyle, (T,)))

# ======================================= misc ========================================
trivial_axis(style::FusionStyle, a::AbstractArray) = trivial_axis(ReshapeFusion(), a)
function trivial_axis(
style::FusionStyle,
::Val{:codomain},
a::AbstractArray,
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
axes_domain::Tuple{Vararg{AbstractUnitRange}},
)
return trivial_axis(style, a, axes_codomain, axes_domain)
end
function trivial_axis(
style::FusionStyle,
::Val{:domain},
a::AbstractArray,
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
axes_domain::Tuple{Vararg{AbstractUnitRange}},
)
return trivial_axis(style, a, axes_codomain, axes_domain)
end
function trivial_axis(
style::FusionStyle,
a::AbstractArray,
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
axes_domain::Tuple{Vararg{AbstractUnitRange}},
)
return trivial_axis(style, a)
end
function trivial_axis(style::FusionStyle, a::AbstractArray)
return trivial_axis(ReshapeFusion(), a)
end

# Tensor product two spaces (ranges) together based on a fusion style.
function tensor_product_axis(
style::FusionStyle, ::Val{:codomain}, r1::AbstractUnitRange, r2::AbstractUnitRange
)
return tensor_product_axis(style, r1, r2)
end
function tensor_product_axis(
style::FusionStyle, ::Val{:domain}, r1::AbstractUnitRange, r2::AbstractUnitRange
)
return tensor_product_axis(style, r1, r2)
end
function tensor_product_axis(::FusionStyle, r1::AbstractUnitRange, r2::AbstractUnitRange)
return tensor_product_axis(ReshapeFusion(), r1, r2)
end
function tensor_product_axis(side::Val, r1::AbstractUnitRange, r2::AbstractUnitRange)
style = tensor_product_fusionstyle(r1, r2)
return tensor_product_axis(style, side, r1, r2)
end
function tensor_product_axis(r1::AbstractUnitRange, r2::AbstractUnitRange)
style = tensor_product_fusionstyle(r1, r2)
return tensor_product_axis(style, r1, r2)
end
function tensor_product_fusionstyle(r1::AbstractUnitRange, r2::AbstractUnitRange)
style1 = FusionStyle(r1)
style2 = FusionStyle(r2)
style1 == style2 || error("Styles must match.")
return tensor_product_axis(style1, r1, r2)
return style1
end

function fused_axis(
style::FusionStyle,
side::Val{:codomain},
a::AbstractArray,
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
axes_domain::Tuple{Vararg{AbstractUnitRange}},
)
init_axis = trivial_axis(style, side, a, axes_codomain, axes_domain)
return reduce(axes_codomain; init = init_axis) do ax1, ax2
return tensor_product_axis(style, side, ax1, ax2)
end
end
function fused_axis(
style::FusionStyle,
side::Val{:domain},
a::AbstractArray,
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
axes_domain::Tuple{Vararg{AbstractUnitRange}},
)
init_axis = trivial_axis(style, side, a, axes_codomain, axes_domain)
return reduce(axes_domain; init = init_axis) do ax1, ax2
return tensor_product_axis(style, side, ax1, ax2)
end
end
function matricize_axes(
style::FusionStyle,
a::AbstractArray,
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
axes_domain::Tuple{Vararg{AbstractUnitRange}},
)
axis_codomain = fused_axis(style, Val(:codomain), a, axes_codomain, axes_domain)
axis_domain = fused_axis(style, Val(:domain), a, axes_codomain, axes_domain)
return axis_codomain, axis_domain
end
function matricize_axes(style::FusionStyle, a::AbstractArray, ndims_codomain::Val)
unval(ndims_codomain) ≤ ndims(a) ||
throw(ArgumentError("Codomain length exceeds number of dimensions."))
biperm = trivialbiperm(ndims_codomain, Val(ndims(a)))
axesblocks = blocks(axes(a)[biperm])
init_axis = trivial_axis(style, a)
return map(axesblocks) do axesblock
return reduce(axesblock; init = init_axis) do ax1, ax2
return tensor_product_axis(style, ax1, ax2)
end
end
return matricize_axes(style, a, blocks(axes(a)[biperm])...)
end
function matricize_axes(a::AbstractArray, ndims_codomain::Val)
return matricize_axes(FusionStyle(a), a, ndims_codomain)
Expand Down
Loading