Skip to content

Commit

Permalink
Update VectorInterface 0.4
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Sep 8, 2023
1 parent 868d1df commit f2b6ff0
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 21 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ LRUCache = "1.0.2"
Strided = "2"
TensorOperations = "4.0.2"
TupleTools = "1.1"
VectorInterface = "0.3"
VectorInterface = "0.4"
WignerSymbols = "1,2"
julia = "1.6"

Expand Down
1 change: 0 additions & 1 deletion src/TensorKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ using TupleTools: StaticLength
using Strided

using VectorInterface
using VectorInterface: _zero, _one

using TensorOperations: TensorOperations, @tensor, @tensoropt, @ncon, ncon
using TensorOperations: IndexTuple, Index2Tuple, linearize, Backend
Expand Down
8 changes: 2 additions & 6 deletions src/tensors/tensoroperations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,7 @@ function trace_permute!(tdst::AbstractTensorMap{S,N₁,N₂},
cod = codomain(tsrc)
dom = domain(tsrc)
n = length(cod)
if iszero(β)
fill!(tdst, β)
elseif β != 1
mul!(tdst, β, tdst)
end
VectorInterface.scale!(tdst, β)
r₁ = (p₁..., q₁...)
r₂ = (p₂..., q₂...)
for (f₁, f₂) in fusiontrees(tsrc)
Expand All @@ -209,7 +205,7 @@ function trace_permute!(tdst::AbstractTensorMap{S,N₁,N₂},
C = tdst[f₁′′, f₂′′]
A = tsrc[f₁, f₂]
α′ = α * coeff
TO.tensortrace!(C, (p₁, p₂), A, (q₁, q₂), :N, α′, true, backend...)
TO.tensortrace!(C, (p₁, p₂), A, (q₁, q₂), :N, α′, One(), backend...)
end
end
end
Expand Down
24 changes: 11 additions & 13 deletions src/tensors/vectorinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ VectorInterface.zerovector!!(t::AbstractTensorMap) = zerovector!(t)
# scale, scale! & scale!!
#-------------------------
function VectorInterface.scale(t::AbstractTensorMap, α::Number)
T = Base.promote_op(scale, scalartype(t), scalartype(α))
T = VectorInterface.promote_scale(t, α)
return scale!(similar(t, T), t, α)
end
function VectorInterface.scale!(t::AbstractTensorMap, α::Number)
Expand All @@ -28,12 +28,10 @@ function VectorInterface.scale!(t::AbstractTensorMap, α::Number)
return t
end
function VectorInterface.scale!!(t::AbstractTensorMap, α::Number)
α === _one && return t
α === _zero && return zerovector!!(t)
T = Base.promote_op(scale, scalartype(t), scalartype(α))
α === One() && return t
T = VectorInterface.promote_scale(t, α)
return T <: scalartype(t) ? scale!(t, α) : scale(t, α)
end

function VectorInterface.scale!(ty::AbstractTensorMap, tx::AbstractTensorMap, α::Number)
space(ty) == space(tx) || throw(SpaceMismatch("$(space(ty))$(space(tx))"))
for c in blocksectors(tx)
Expand All @@ -42,7 +40,7 @@ function VectorInterface.scale!(ty::AbstractTensorMap, tx::AbstractTensorMap, α
return ty
end
function VectorInterface.scale!!(ty::AbstractTensorMap, tx::AbstractTensorMap, α::Number)
T = Base.promote_op(scale, scalartype(tx), scalartype(α))
T = VectorInterface.promote_scale(tx, α)
if T <: scalartype(ty)
return scale!(ty, tx, α)
else
Expand All @@ -54,23 +52,23 @@ end
#-------------------
# TODO: remove VectorInterface from calls to `add!` when `TensorKit.add!` is renamed
function VectorInterface.add(ty::AbstractTensorMap, tx::AbstractTensorMap,
α::Number=_one, β::Number=_one)
α::Number, β::Number)
space(ty) == space(tx) || throw(SpaceMismatch("$(space(ty))$(space(tx))"))
T = Base.promote_op(VectorInterface.add, scalartype(ty), scalartype(tx), scalartype(α), scalartype(β))
T = VectorInterface.promote_add(ty, tx, α, β)
return VectorInterface.add!(scale!(similar(ty, T), ty, β), tx, α)
end
function VectorInterface.add!(ty::AbstractTensorMap, tx::AbstractTensorMap,
α::Number=_one, β::Number=_one)
α::Number, β::Number)
space(ty) == space(tx) || throw(SpaceMismatch("$(space(ty))$(space(tx))"))
for c in blocksectors(tx)
VectorInterface.add!(block(ty, c), block(tx, c), α, β)
end
return ty
end
function VectorInterface.add!!(ty::AbstractTensorMap, tx::AbstractTensorMap,
α::Number=_one, β::Number=_one)
T = Base.promote_op(VectorInterface.add, scalartype(ty), scalartype(tx), scalartype(α),
scalartype(β))
α::Number, β::Number)
# spacecheck is done in add(!)
T = VectorInterface.promote_add(ty, tx, α, β)
if T <: scalartype(ty)
return VectorInterface.add!(ty, tx, α, β)
else
Expand All @@ -84,7 +82,7 @@ function VectorInterface.inner(tx::AbstractTensorMap, ty::AbstractTensorMap)
space(tx) == space(ty) || throw(SpaceMismatch("$(space(tx))$(space(ty))"))
InnerProductStyle(tx) === EuclideanProduct() ||
throw(ArgumentError("dot requires Euclidean inner product"))
T = Base.promote_op(VectorInterface.inner, scalartype(tx), scalartype(ty))
T = VectorInterface.promote_inner(tx, ty)
s = zero(T)
for c in blocksectors(tx)
s += convert(T, dim(c)) * dot(block(tx, c), block(ty, c))
Expand Down

0 comments on commit f2b6ff0

Please sign in to comment.