Skip to content

Commit

Permalink
Update SpaceMismatch error messages
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Sep 5, 2023
1 parent b524237 commit b891bde
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 14 deletions.
9 changes: 4 additions & 5 deletions src/tensors/braidingtensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,10 @@ function planarcontract!(C::AbstractTensorMap{S,N₁,N₂},
codB, domB = codomainind(B), domainind(B)
oindA, cindA, oindB, cindB = reorder_indices(codA, domA, codB, domB, oindA, cindA,
oindB, cindB, p1, p2)

if space(B, cindB[1]) != space(A, cindA[1])' || space(B, cindB[2]) != space(A, cindA[2])
throw(SpaceMismatch())

if space(B, cindB[1]) != space(A, cindA[1])' ||
space(B, cindB[2]) != space(A, cindA[2])'
throw(SpaceMismatch("$(space(C)) ≠ permute($(space(A))[$oindA, $cindA] * $(space(B))[$cindB, $oindB], ($p1, $p2)"))
end

if BraidingStyle(sectortype(B)) isa Bosonic
Expand Down Expand Up @@ -239,8 +240,6 @@ function planarcontract!(C::AbstractTensorMap{S,N₁,N₂},

if space(B, cindB[1]) != space(A, cindA[1])' ||
space(B, cindB[2]) != space(A, cindA[2])'
# @show space(B, cindB[1]), space(A, cindA[1])
# @show space(B, cindB[2]), space(A, cindA[2])
throw(SpaceMismatch("$(space(C)) ≠ permute($(space(A))[$oindA, $cindA] * $(space(B))[$cindB, $oindB], ($p1, $p2)"))
end

Expand Down
14 changes: 9 additions & 5 deletions src/tensors/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ end
# Wrapping the blocks in a StridedView enables multithreading if JULIA_NUM_THREADS > 1
# Copy, adjoint! and fill:
function Base.copy!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap)
space(tdst) == space(tsrc) || throw(SpaceMismatch())
space(tdst) == space(tsrc) || throw(SpaceMismatch("$(space(tdst))$(space(tsrc))"))
for c in blocksectors(tdst)
copy!(StridedView(block(tdst, c)), StridedView(block(tsrc, c)))
end
Expand All @@ -179,7 +179,8 @@ function LinearAlgebra.adjoint!(tdst::AbstractTensorMap,
tsrc::AbstractTensorMap)
spacetype(tdst) === spacetype(tsrc) && InnerProductStyle(tdst) === EuclideanProduct() ||
throw(ArgumentError("adjoint! requires Euclidean inner product spacetype"))
space(tdst) == adjoint(space(tsrc)) || throw(SpaceMismatch())
space(tdst) == adjoint(space(tsrc)) ||
throw(SpaceMismatch("$(space(tdst)) ≠ adjoint($(space(tsrc)))"))
for c in blocksectors(tdst)
adjoint!(StridedView(block(tdst, c)), StridedView(block(tsrc, c)))
end
Expand Down Expand Up @@ -396,8 +397,9 @@ end
# concatenate tensors
function catdomain(t1::AbstractTensorMap{S,N₁,1},
t2::AbstractTensorMap{S,N₁,1}) where {S,N₁}
codomain(t1) == codomain(t2) || throw(SpaceMismatch())

codomain(t1) == codomain(t2) ||
throw(SpaceMismatch("codomains of tensors to concatenate must match:\n\
$(codomain(t1))$(codomain(t2))"))
V1, = domain(t1)
V2, = domain(t2)
isdual(V1) == isdual(V2) ||
Expand All @@ -413,7 +415,9 @@ function catdomain(t1::AbstractTensorMap{S,N₁,1},
end
function catcodomain(t1::AbstractTensorMap{S,1,N₂},
t2::AbstractTensorMap{S,1,N₂}) where {S,N₂}
domain(t1) == domain(t2) || throw(SpaceMismatch())
domain(t1) == domain(t2) ||
throw(SpaceMismatch("domains of tensors to concatenate must match:\n\
$(domain(t1))$(domain(t2))"))

V1, = codomain(t1)
V2, = codomain(t2)
Expand Down
8 changes: 4 additions & 4 deletions src/tensors/vectorinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ function VectorInterface.scale!!(t::AbstractTensorMap, α::Number)
end

function VectorInterface.scale!(ty::AbstractTensorMap, tx::AbstractTensorMap, α::Number)
space(ty) == space(tx) || throw(SpaceMismatch())
space(ty) == space(tx) || throw(SpaceMismatch("$(space(ty))$(space(tx))"))
for c in blocksectors(tx)
scale!(block(ty, c), block(tx, c), α)
end
return ty
end
function VectorInterface.scale!!(ty::AbstractTensorMap, tx::AbstractTensorMap, α::Number)
space(ty) == space(tx) || throw(SpaceMismatch())
space(ty) == space(tx) || throw(SpaceMismatch("$(space(ty))$(space(tx))"))
T = scalartype(ty)
if promote_type(T, typeof(α), scalartype(tx)) <: T
return scale!(ty, tx, α)
Expand All @@ -55,7 +55,7 @@ end
# TODO: remove VectorInterface from calls to `add!` when `TensorKit.add!` is renamed
function VectorInterface.add(ty::AbstractTensorMap, tx::AbstractTensorMap,
α::Number=VectorInterface._one, β::Number=VectorInterface._one)
space(ty) == space(tx) || throw(SpaceMismatch())
space(ty) == space(tx) || throw(SpaceMismatch("$(space(ty))$(space(tx))"))
T = promote_type(scalartype(ty), scalartype(tx), typeof(α), typeof(β))
return VectorInterface.add!(scale!(similar(ty, T), ty, β), tx, α)
end
Expand All @@ -82,7 +82,7 @@ end
# inner
#-------
function VectorInterface.inner(tx::AbstractTensorMap, ty::AbstractTensorMap)
space(tx) == space(ty) || throw(SpaceMismatch())
space(tx) == space(ty) || throw(SpaceMismatch("$(space(tx))$(space(ty))"))
InnerProductStyle(tx) === EuclideanProduct() ||
throw(ArgumentError("dot requires Euclidean inner product"))
T = promote_type(scalartype(tx), scalartype(ty))
Expand Down

0 comments on commit b891bde

Please sign in to comment.