Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
name = "TensorAlgebra"
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
authors = ["ITensor developers <support@itensor.org> and contributors"]
version = "0.5.4"
version = "0.6.0"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"

Expand All @@ -26,7 +25,6 @@ EllipsisNotation = "1.8"
LinearAlgebra = "1.10"
MatrixAlgebraKit = "0.2, 0.3, 0.4, 0.5, 0.6"
TensorOperations = "5"
TensorProducts = "0.1.5"
TupleTools = "1.6"
TypeParameterAccessors = "0.2.1, 0.3, 0.4"
julia = "1.10"
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ TensorAlgebra = {path = ".."}
[compat]
Documenter = "1.8.1"
Literate = "2.20.1"
TensorAlgebra = "0.5"
TensorAlgebra = "0.6"
2 changes: 1 addition & 1 deletion examples/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
TensorAlgebra = {path = ".."}

[compat]
TensorAlgebra = "0.5"
TensorAlgebra = "0.6"
20 changes: 3 additions & 17 deletions src/TensorAlgebra.jl
Original file line number Diff line number Diff line change
@@ -1,28 +1,14 @@
module TensorAlgebra

export contract,
contract!,
eigen,
eigvals,
factorize,
left_null,
left_orth,
left_polar,
lq,
qr,
right_null,
right_orth,
right_polar,
orth,
polar,
svd,
svdvals
export contract, contract!, eigen, eigvals, factorize, left_null, left_orth, left_polar,
lq, qr, right_null, right_orth, right_polar, orth, polar, svd, svdvals

include("MatrixAlgebra.jl")
include("blockedtuple.jl")
include("blockedpermutation.jl")
include("BaseExtensions/BaseExtensions.jl")
include("matricize.jl")
include("blockarrays.jl")
include("contract/contract.jl")
include("contract/output_labels.jl")
include("contract/blockedperms.jl")
Expand Down
69 changes: 69 additions & 0 deletions src/blockarrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
using BlockArrays: AbstractBlockArray, AbstractBlockedUnitRange, BlockedArray, blockedrange,
eachblockaxes1, mortar

struct BlockReshapeFusion <: FusionStyle end
FusionStyle(::Type{<:AbstractBlockArray}) = BlockReshapeFusion()

trivial_axis(::BlockReshapeFusion, a::AbstractArray) = blockedrange([1])
function mortar_axis(axs)
all(isone ∘ first, axs) ||
throw(ArgumentError("Only one-based axes are supported"))
return blockedrange(length.(axs))
end
function tensor_product_axis(
::BlockReshapeFusion, r1::AbstractUnitRange, r2::AbstractUnitRange
)
isone(first(r1)) || isone(first(r2)) ||
throw(ArgumentError("Only one-based axes are supported"))
blockaxpairs = Iterators.product(eachblockaxes1(r1), eachblockaxes1(r2))
blockaxs = vec(map(splat(tensor_product_axis), blockaxpairs))
return mortar_axis(blockaxs)
end
function matricize(style::BlockReshapeFusion, a::AbstractArray, ndims_codomain::Val)
ax = matricize_axes(style, a, ndims_codomain)
reshaped_blocks_a = reshape(blocks(a), blocklength.(ax))
bs = map(reshaped_blocks_a) do b
matricize(b, ndims_codomain)
end
return mortar(bs, ax)
end
using BlockArrays: blocklengths
function unmatricize(
::BlockReshapeFusion,
m::AbstractMatrix,
codomain_axes::Tuple{Vararg{AbstractUnitRange}},
domain_axes::Tuple{Vararg{AbstractUnitRange}},
)
ax = (codomain_axes..., domain_axes...)
reshaped_blocks_m = reshape(blocks(m), blocklength.(ax))
bs = map(CartesianIndices(reshaped_blocks_m)) do I
block_axes_I = BlockedTuple(
map(ntuple(identity, length(ax))) do i
return Base.axes1(ax[i][Block(I[i])])
end,
(length(codomain_axes), length(domain_axes)),
)
return unmatricize(reshaped_blocks_m[I], block_axes_I)
end
return mortar(bs, ax)
end

struct BlockedReshapeFusion <: FusionStyle end
FusionStyle(::Type{<:BlockedArray}) = BlockedReshapeFusion()
unblock(a::BlockedArray) = a.blocks
unblock(a::AbstractBlockArray) = a[Base.OneTo.(size(a))...]
unblock(a::AbstractArray) = a
function matricize(::BlockedReshapeFusion, a::AbstractArray, ndims_codomain::Val)
return matricize(ReshapeFusion(), unblock(a), ndims_codomain)
end
function unmatricize(
style::BlockedReshapeFusion, m::AbstractMatrix,
axes_codomain::Tuple{Vararg{AbstractUnitRange}},
axes_domain::Tuple{Vararg{AbstractUnitRange}},
)
a = unmatricize(
ReshapeFusion(), m,
Base.OneTo.(length.(axes_codomain)), Base.OneTo.(length.(axes_domain)),
)
return BlockedArray(a, (axes_codomain..., axes_domain...))
end
5 changes: 5 additions & 0 deletions src/blockedpermutation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,11 @@ function blockedtrivialperm(blocklengths::Tuple{Vararg{Int}})
return blockedtrivialperm(Val.(blocklengths))
end

function trivialbiperm(length_codomain::Val, length::Val)
length_domain = Val(unval(length) - unval(length_codomain))
return blockedtrivialperm((length_codomain, length_domain))
end

function trivialperm(blockedperm::AbstractBlockTuple)
return blockedtrivialperm(blocklengths(blockedperm))
end
Expand Down
Loading
Loading