Skip to content

Commit

Permalink
Remove overloaded rrules in favor of TensorOperations update
Browse files Browse the repository at this point in the history
  • Loading branch information
leburgel authored and lkdvos committed Sep 5, 2023
1 parent 401a1d1 commit 868d1df
Showing 1 changed file with 0 additions and 132 deletions.
132 changes: 0 additions & 132 deletions ext/TensorKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -373,136 +373,4 @@ function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{TensorMap},
return convert(TensorMap, t), v -> (NoTangent(), NoTangent(), convert(Dict, v))
end

# TensorOperations rules for TensorMaps
# -------------------------------------

function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
C::AbstractTensorMap, pC::Index2Tuple,
A::AbstractTensorMap, pA::Index2Tuple, conjA::Symbol,
B::AbstractTensorMap, pB::Index2Tuple, conjB::Symbol,
α::Number, β::Number, backend...)
C′ = tensorcontract!(copy(C), pC, A, pA, conjA, B, pB, conjB, α, β, backend...)

function tensorcontract_pullback(ΔC)
dC = @thunk conj(β) * ΔC
dA = @thunk begin
ipC = _repartition(invperm(linearize(pC)), pA)
ipA = _repartition(invperm(linearize(pA)), A)
conjΔC = conjA == :C ? :C : :N
conjB′ = conjA == :C ? conjB : _conj(conjB)
c_dA = tensorcontract(ipA, ΔC, ipC, conjΔC,
B, reverse(pB), conjB′, conjA == :C ? α : conj(α),
backend...)
(!(eltype(A) <: Complex) && (eltype(c_dA) <: Complex)) ? real(c_dA) : c_dA
end
dB = @thunk begin
ipC = _repartition(invperm(linearize(pC)), pA)
ipB = _repartition(invperm(linearize(pB)), B)
NA₁ = TensorOperations.numout(pA)
conjΔC = conjB == :C ? :C : :N
conjA′ = conjB == :C ? conjA : _conj(conjA)
return tensorcontract(ipB, A, reverse(pA), conjA′,
ΔC, ipC, conjΔC,
conjB == :C ? α : conj(α), backend...)
(!(eltype(B) <: Complex) && (eltype(c_dB) <: Complex)) ? real(c_dB) : c_dB
end
= @thunk tensorscalar(tensorcontract(((), ()),
tensorcontract(pC, A, pA, conjA, B, pB,
conjB),
((),
trivtuple(TensorOperations.numind(pC))),
:C, ΔC,
(trivtuple(TensorOperations.numind(pC)),
()), :N,
backend...))
= @thunk tensorscalar(tensorcontract(((), ()), C,
((),
trivtuple(TensorOperations.numind(pC))),
:C, ΔC,
(trivtuple(TensorOperations.numind(pC)),
()), :N,
backend...))

dbackend = map(x -> NoTangent(), backend)
return NoTangent(), dC, NoTangent(),
dA, NoTangent(), NoTangent(), dB, NoTangent(), NoTangent(), dα, dβ,
dbackend...
end

return C′, tensorcontract_pullback
end

function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
C::AbstractTensorMap, pC::Index2Tuple,
A::AbstractTensorMap, conjA::Symbol,
α::Number, β::Number, backend...)
C′ = tensoradd!(copy(C), pC, A, conjA, α, β, backend...)

function tensoradd_pullback(ΔC)
dC = @thunk conj(β) * ΔC
dA = @thunk begin
ipC = _repartition(invperm(linearize(pC)), A)
c_dA = tensorcopy(ipC, ΔC, conjA, conjA == :N ? conj(α) : α, backend...)

return (!(scalartype(A) <: Complex) && (scalartype(c_dA) <: Complex)) ?
real(c_dA) : c_dA
end
= @thunk tensorscalar(tensorcontract(((), ()), A, ((), linearize(pC)),
_conj(conjA), ΔC,
(trivtuple(TensorOperations.numind(pC)),
()), :N, backend...))
= @thunk tensorscalar(tensorcontract(((), ()), C,
((),
trivtuple(TensorOperations.numind(pC))),
:C, ΔC,
(trivtuple(TensorOperations.numind(pC)),
()), :N,
backend...))
dbackend = map(x -> NoTangent(), backend)
return NoTangent(), dC, NoTangent(), dA, NoTangent(), dα, dβ, dbackend...
end

return C′, tensoradd_pullback
end

function ChainRulesCore.rrule(::typeof(tensortrace!), C::AbstractTensorMap, pC::Index2Tuple,
A::AbstractTensorMap,
pA::Index2Tuple, conjA::Symbol, α::Number, β::Number,
backend...)
C′ = tensortrace!(copy(C), pC, A, pA, conjA, α, β, backend...)

function tensortrace_pullback(ΔC)
dC = @thunk conj(β) * ΔC
dA = @thunk begin
ipC = _repartition(invperm((linearize(pC)..., pA[1]..., pA[2]...)), A)
E = id(storagetype(A), ProductSpace((space(A, i) for i in pA[1])...))
return tensorproduct(ipC, ΔC, (trivtuple(TensorOperations.numind(pC)), ()),
conjA, E,
((), trivtuple(TensorOperations.numind(pA))), conjA,
conjA == :N ? conj(α) : α, backend...)
end

= @thunk tensorscalar(tensorcontract(((), ()),
tensortrace(pC, A, pA, conjA),
((),
trivtuple(TensorOperations.numind(pC))),
_conj(conjA), ΔC,
(trivtuple(TensorOperations.numind(pC)),
()), :N,
backend...))
= @thunk tensorscalar(tensorcontract(((), ()), C,
((),
trivtuple(TensorOperations.numind(pC))),
:C, ΔC,
(trivtuple(TensorOperations.numind(pC)),
()), :N,
backend...))
dbackend = map(x -> NoTangent(), backend)
return NoTangent(), dC, NoTangent(), dA, NoTangent(), NoTangent(), dα, dβ,
dbackend...
end

return C′, tensortrace_pullback
end

end

0 comments on commit 868d1df

Please sign in to comment.