Skip to content

Commit

Permalink
Fix rrules for Fermionic symmetries (#126)
Browse files Browse the repository at this point in the history
This PR fixes the rrules involving variants of `tensorcontract`. In particular, it balances out the necessary twists that are inserted by `tensorcontract` and are unwanted as AD is defined in terms of inner products (traces instead of supertraces).

---------

Co-authored-by: lkdvos <lukas.devos@ugent.be>
  • Loading branch information
qmortier and lkdvos committed Jun 3, 2024
1 parent 4607691 commit 7b480fc
Show file tree
Hide file tree
Showing 6 changed files with 299 additions and 102 deletions.
178 changes: 167 additions & 11 deletions ext/TensorKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,27 @@
module TensorKitChainRulesCoreExt

using TensorOperations
using VectorInterface
using TensorKit
using ChainRulesCore
using LinearAlgebra
using TupleTools

import TensorOperations as TO
using TensorOperations: Backend, promote_contract
using VectorInterface: promote_scale, promote_add

ext = @static if isdefined(Base, :get_extension)
Base.get_extension(TensorOperations, :TensorOperationsChainRulesCoreExt)
else
TensorOperations.TensorOperationsChainRulesCoreExt
end
const _conj = ext._conj
const trivtuple = ext.trivtuple

# Utility
# -------

_conj(conjA::Symbol) = conjA == :C ? :N : :C
trivtuple(N) = ntuple(identity, N)

function _repartition(p::IndexTuple, N₁::Int)
length(p) >= N₁ ||
throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)"))
Expand Down Expand Up @@ -104,26 +114,26 @@ function ChainRulesCore.rrule(::typeof(⊗), A::AbstractTensorMap, B::AbstractTe
projectA = ProjectTo(A)
projectB = ProjectTo(B)
function otimes_pullback(ΔC_)
# TODO: this rule is probably better written in terms of inner products,
# using planarcontract and adjoint tensormaps would remove the twists.
ΔC = unthunk(ΔC_)
pΔC = ((codomainind(A)..., (domainind(A) .+ numout(B))...),
((codomainind(B) .+ numout(A))...,
(domainind(B) .+ (numin(A) + numout(A)))...))
dA_ = @thunk begin
ipA = (codomainind(A), domainind(A))
pB = (allind(B), ())
dA = zerovector(A,
TensorOperations.promote_contract(scalartype(ΔC),
scalartype(B)))
dA = tensorcontract!(dA, ipA, ΔC, pΔC, :N, B, pB, :C)
dA = zerovector(A, promote_contract(scalartype(ΔC), scalartype(B)))
tB = twist(B, filter(x -> isdual(space(B, x)), allind(B)))
dA = tensorcontract!(dA, ipA, ΔC, pΔC, :N, tB, pB, :C)
return projectA(dA)
end
dB_ = @thunk begin
ipB = (codomainind(B), domainind(B))
pA = ((), allind(A))
dB = zerovector(B,
TensorOperations.promote_contract(scalartype(ΔC),
scalartype(A)))
dB = tensorcontract!(dB, ipB, A, pA, :C, ΔC, pΔC, :N)
dB = zerovector(B, promote_contract(scalartype(ΔC), scalartype(A)))
tA = twist(A, filter(x -> isdual(space(A, x)), allind(A)))
dB = tensorcontract!(dB, ipB, tA, pA, :C, ΔC, pΔC, :N)
return projectB(dB)
end
return NoTangent(), dA_, dB_
Expand Down Expand Up @@ -653,4 +663,150 @@ function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{TensorMap},
return convert(TensorMap, t), v -> (NoTangent(), NoTangent(), convert(Dict, v))
end

function ChainRulesCore.rrule(::typeof(TO.tensorcontract!),
C::AbstractTensorMap{S}, pC::Index2Tuple,
A::AbstractTensorMap{S}, pA::Index2Tuple, conjA::Symbol,
B::AbstractTensorMap{S}, pB::Index2Tuple, conjB::Symbol,
α::Number, β::Number,
backend::Backend...) where {S}
C′ = tensorcontract!(copy(C), pC, A, pA, conjA, B, pB, conjB, α, β, backend...)

projectA = ProjectTo(A)
projectB = ProjectTo(B)
projectC = ProjectTo(C)
projectα = ProjectTo(α)
projectβ = ProjectTo(β)

function pullback(ΔC′)
ΔC = unthunk(ΔC′)
ipC = invperm(linearize(pC))
pΔC = (TupleTools.getindices(ipC, trivtuple(TO.numout(pA))),
TupleTools.getindices(ipC, TO.numout(pA) .+ trivtuple(TO.numin(pB))))

dC = @thunk projectC(scale(ΔC, conj(β)))
dA = @thunk begin
ipA = (invperm(linearize(pA)), ())
conjΔC = conjA == :C ? :C : :N
conjB′ = conjA == :C ? conjB : _conj(conjB)
_dA = zerovector(A,
promote_contract(scalartype(ΔC), scalartype(B), scalartype(α)))
tB = twist(B,
TupleTools.vcat(filter(x -> !isdual(space(B, x)), pB[1]),
filter(x -> isdual(space(B, x)), pB[2])))
_dA = tensorcontract!(_dA, ipA,
ΔC, pΔC, conjΔC,
tB, reverse(pB), conjB′,
conjA == :C ? α : conj(α), Zero(), backend...)
return projectA(_dA)
end
dB = @thunk begin
ipB = (invperm(linearize(pB)), ())
conjΔC = conjB == :C ? :C : :N
conjA′ = conjB == :C ? conjA : _conj(conjA)
_dB = zerovector(B,
promote_contract(scalartype(ΔC), scalartype(A), scalartype(α)))
tA = twist(A,
TupleTools.vcat(filter(x -> isdual(space(A, x)), pA[1]),
filter(x -> !isdual(space(A, x)), pA[2])))
_dB = tensorcontract!(_dB, ipB,
tA, reverse(pA), conjA′,
ΔC, pΔC, conjΔC,
conjB == :C ? α : conj(α), Zero(), backend...)
return projectB(_dB)
end
= @thunk begin
# TODO: this result should be AB = (C′ - βC) / α as C′ = βC + αAB
AB = tensorcontract(pC, A, pA, conjA, B, pB, conjB)
return projectα(inner(AB, ΔC))
end
= @thunk projectβ(inner(C, ΔC))
dbackend = map(x -> NoTangent(), backend)
return NoTangent(), dC, NoTangent(),
dA, NoTangent(), NoTangent(), dB, NoTangent(), NoTangent(), dα, dβ,
dbackend...
end
return C′, pullback
end

function ChainRulesCore.rrule(::typeof(TO.tensoradd!),
C::AbstractTensorMap{S}, pC::Index2Tuple,
A::AbstractTensorMap{S}, conjA::Symbol,
α::Number, β::Number, backend::Backend...) where {S}
C′ = tensoradd!(copy(C), pC, A, conjA, α, β, backend...)

projectA = ProjectTo(A)
projectC = ProjectTo(C)
projectα = ProjectTo(α)
projectβ = ProjectTo(β)

function pullback(ΔC′)
ΔC = unthunk(ΔC′)
dC = @thunk projectC(scale(ΔC, conj(β)))
dA = @thunk begin
ipC = invperm(linearize(pC))
_dA = zerovector(A, promote_add(ΔC, α))
_dA = tensoradd!(_dA, (ipC, ()), ΔC, conjA, conjA == :N ? conj(α) : α, Zero(),
backend...)
return projectA(_dA)
end
= @thunk begin
# TODO: this is an inner product implemented as a contraction
# for non-symmetric tensors this might be more efficient like this,
# but for symmetric tensors an intermediate object will anyways be created
# and then it might be more efficient to use an addition and inner product
tΔC = twist(ΔC, filter(x -> isdual(space(ΔC, x)), allind(ΔC)))
_dα = tensorscalar(tensorcontract(((), ()), A, ((), linearize(pC)),
_conj(conjA), tΔC,
(trivtuple(TO.numind(pC)),
()), :N, One(), backend...))
return projectα(_dα)
end
= @thunk projectβ(inner(C, ΔC))
dbackend = map(x -> NoTangent(), backend)
return NoTangent(), dC, NoTangent(), dA, NoTangent(), dα, dβ, dbackend...
end

return C′, pullback
end

function ChainRulesCore.rrule(::typeof(tensortrace!), C::AbstractTensorMap{S},
pC::Index2Tuple, A::AbstractTensorMap{S},
pA::Index2Tuple, conjA::Symbol, α::Number, β::Number,
backend::Backend...) where {S}
C′ = tensortrace!(copy(C), pC, A, pA, conjA, α, β, backend...)

projectA = ProjectTo(A)
projectC = ProjectTo(C)
projectα = ProjectTo(α)
projectβ = ProjectTo(β)

function pullback(ΔC′)
ΔC = unthunk(ΔC′)
dC = @thunk projectC(scale(ΔC, conj(β)))
dA = @thunk begin
ipC = invperm((linearize(pC)..., pA[1]..., pA[2]...))
E = one!(TO.tensoralloc_add(scalartype(A), pA, A, conjA))
twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E)))
_dA = zerovector(A, promote_scale(ΔC, α))
_dA = tensorproduct!(_dA, (ipC, ()), ΔC,
(trivtuple(TO.numind(pC)), ()), conjA, E,
((), trivtuple(TO.numind(pA))), conjA,
conjA == :N ? conj(α) : α, Zero(), backend...)
return projectA(_dA)
end
= @thunk begin
# TODO: this result might be easier to compute as:
# C′ = βC + α * trace(A) ⟹ At = (C′ - βC) / α
At = tensortrace(pC, A, pA, conjA)
return projectα(inner(At, ΔC))
end
= @thunk projectβ(inner(C, ΔC))
dbackend = map(x -> NoTangent(), backend)
return NoTangent(), dC, NoTangent(), dA, NoTangent(), NoTangent(), dα, dβ,
dbackend...
end

return C′, pullback
end

end
4 changes: 2 additions & 2 deletions src/tensors/adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ end

# Show
#------
function Base.summary(t::AdjointTensorMap)
return print("AdjointTensorMap(", codomain(t), "", domain(t), ")")
function Base.summary(io::IO, t::AdjointTensorMap)
return print(io, "AdjointTensorMap(", codomain(t), "", domain(t), ")")
end
function Base.show(io::IO, t::AdjointTensorMap{S}) where {S<:IndexSpace}
if get(io, :compact, false)
Expand Down
26 changes: 20 additions & 6 deletions src/tensors/indexmanipulations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,8 @@ end

# Twist
"""
twist!(t::AbstractTensorMap, i::Int; inv::Bool=false)
-> t
twist!(t::AbstractTensorMap, i::Int; inv::Bool=false) -> t
twist!(t::AbstractTensorMap, is; inv::Bool=false) -> t
Apply a twist to the `i`th index of `t`, storing the result in `t`.
If `inv=true`, use the inverse twist.
Expand All @@ -248,17 +248,31 @@ function twist!(t::AbstractTensorMap, i::Int; inv::Bool=false)
end
return t
end
function twist!(t::AbstractTensorMap, is; inv::Bool=false)
if !all(in(allind(t)), is)
msg = "Can't twist indices $is of a tensor with only $(numind(t)) indices."
throw(ArgumentError(msg))
end
(BraidingStyle(sectortype(t)) == Bosonic() || isempty(is)) && return t
N₁ = numout(t)
for (f₁, f₂) in fusiontrees(t)
θ = prod(i -> i <= N₁ ? twist(f₁.uncoupled[i]) : twist(f₂.uncoupled[i - N₁]), is)
inv &&= θ')
rmul!(t[f₁, f₂], θ)
end
return t
end

"""
twist(t::AbstractTensorMap, i::Int; inv::Bool=false)
-> t
twist(tsrc::AbstractTensorMap, i::Int; inv::Bool=false) -> tdst
twist(tsrc::AbstractTensorMap, is; inv::Bool=false) -> tdst
Apply a twist to the `i`th index of `t` and return the result as a new tensor.
Apply a twist to the `i`th index of `tsrc` and return the result as a new tensor.
If `inv=true`, use the inverse twist.
See [`twist!`](@ref) for storing the result in place.
"""
twist(t::AbstractTensorMap, i::Int; inv::Bool=false) = twist!(copy(t), i; inv=inv)
twist(t::AbstractTensorMap, i; inv::Bool=false) = twist!(copy(t), i; inv)

# Fusing and splitting
# TODO: add functionality for easy fusing and splitting of tensor indices
Expand Down
4 changes: 2 additions & 2 deletions src/tensors/tensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -679,8 +679,8 @@ end

# Show
#------
function Base.summary(t::TensorMap)
return print("TensorMap(", space(t), ")")
function Base.summary(io::IO, t::TensorMap)
return print(io, "TensorMap(", space(t), ")")
end
function Base.show(io::IO, t::TensorMap{S}) where {S<:IndexSpace}
if get(io, :compact, false)
Expand Down
8 changes: 1 addition & 7 deletions src/tensors/tensoroperations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -285,13 +285,7 @@ function _contract!(α, A::AbstractTensorMap{S}, B::AbstractTensorMap{S},
end
A′ = permute(A, (oindA, cindA); copy=copyA)
B′ = permute(B, (cindB, oindB))
if BraidingStyle(sectortype(S)) isa Fermionic
for i in domainind(A′)
if !isdual(space(A′, i))
A′ = twist!(A′, i)
end
end
end
A′ = twist!(A′, filter(i -> !isdual(space(A′, i)), domainind(A′)))
ipC = TupleTools.invperm((p₁..., p₂...))
oindAinC = TupleTools.getindices(ipC, ntuple(n -> n, N₁))
oindBinC = TupleTools.getindices(ipC, ntuple(n -> n + N₁, N₂))
Expand Down
Loading

0 comments on commit 7b480fc

Please sign in to comment.