Skip to content

Commit

Permalink
Add some missing rrules. (#78)
Browse files Browse the repository at this point in the history
* Add some missing `rrule`s.

* little bit of cleanup

* repartition

---------

Co-authored-by: lkdvos <lukas.devos@ugent.be>
  • Loading branch information
leburgel and lkdvos committed Aug 25, 2023
1 parent 4693952 commit 1eeabe2
Showing 1 changed file with 171 additions and 4 deletions.
175 changes: 171 additions & 4 deletions ext/TensorKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,30 @@
module TensorKitChainRulesCoreExt

using TensorOperations
using TensorKit
using ChainRulesCore
using LinearAlgebra
using TupleTools

# Utility
# -------

_conj(conjA::Symbol) = conjA == :C ? :N : :C
trivtuple(N) = ntuple(identity, N)

function _repartition(p::IndexTuple, N₁::Int)
length(p) >= N₁ ||
throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)"))
return p[1:N₁], p[(N₁ + 1):end]
end
_repartition(p::Index2Tuple, N₁::Int) = _repartition(linearize(p), N₁)
function _repartition(p::Union{IndexTuple,Index2Tuple}, ::Index2Tuple{N₁}) where {N₁}
return _repartition(p, N₁)
end
function _repartition(p::Union{IndexTuple,Index2Tuple},
::AbstractTensorMap{<:Any,N₁}) where {N₁}
return _repartition(p, N₁)
end

# Constructors
# ------------
Expand Down Expand Up @@ -58,6 +80,19 @@ function ChainRulesCore.rrule(::typeof(*), a::Number, b::AbstractTensorMap)
return a * b, times_pullback
end

function ChainRulesCore.rrule(::typeof(permute), t::AbstractTensorMap, p::Index2Tuple)
function permute_pullback(c)
invpt = _repartition(TupleTools.invperm(linearize(p)), t)
return NoTangent(), permute(c, invpt), NoTangent()
end
return permute(t, p), permute_pullback
end

function ChainRulesCore.rrule(::typeof(scalar), t::AbstractTensorMap)
scalar_pullback(Δc) = NoTangent(), fill!(similar(t), Δc)
return scalar(t), scalar_pullback
end

# LinearAlgebra
# -------------

Expand All @@ -79,7 +114,7 @@ end
function ChainRulesCore.rrule(::typeof(norm), a::AbstractTensorMap, p)
p == 2 || error("currently only implemented for p = 2")
n = norm(a, p)
norm_pullback(Δn) = NoTangent(), @thunk(a * (Δn' + Δn) / (n * 2)), NoTangent()
norm_pullback(Δn) = NoTangent(), a * (Δn' + Δn) / (n * 2), NoTangent()
return n, norm_pullback
end

Expand All @@ -99,7 +134,7 @@ function ChainRulesCore.rrule(::typeof(TensorKit.tsvd), t::AbstractTensorMap; kw
dst[i, j] = zero(eltype(S))
else
sᵢ, sⱼ = src[i, i], src[j, j]
dst[i, j] = 1 / (abs(sᵢ - sⱼ) < 1e-12) ? 1e-12 : sᵢ^2 - sⱼ^2
dst[i, j] = 1 / (abs(sⱼ - sᵢ) < 1e-12 ? 1e-12 : sⱼ^2 - sᵢ^2)
end
end
end
Expand Down Expand Up @@ -321,7 +356,7 @@ end

function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{Dict}, t::AbstractTensorMap)
out = convert(Dict, t)
function pullback(c)
function convert_pullback(c)
if haskey(c, :data) # :data is the only thing for which this dual makes sense
dual = copy(out)
dual[:data] = c[:data]
Expand All @@ -331,11 +366,143 @@ function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{Dict}, t::AbstractT
return (NoTangent(), NoTangent(), zero(t))
end
end
return out, pullback
return out, convert_pullback
end
function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{TensorMap},
t::Dict{Symbol,Any})
return convert(TensorMap, t), v -> (NoTangent(), NoTangent(), convert(Dict, v))
end

# TensorOperations rules for TensorMaps
# -------------------------------------

function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
C::AbstractTensorMap, pC::Index2Tuple,
A::AbstractTensorMap, pA::Index2Tuple, conjA::Symbol,
B::AbstractTensorMap, pB::Index2Tuple, conjB::Symbol,
α::Number, β::Number, backend...)
C′ = tensorcontract!(copy(C), pC, A, pA, conjA, B, pB, conjB, α, β, backend...)

function tensorcontract_pullback(ΔC)
dC = @thunk conj(β) * ΔC
dA = @thunk begin
ipC = _repartition(invperm(linearize(pC)), pA)
ipA = _repartition(invperm(linearize(pA)), A)
conjΔC = conjA == :C ? :C : :N
conjB′ = conjA == :C ? conjB : _conj(conjB)
c_dA = tensorcontract(ipA, ΔC, ipC, conjΔC,
B, reverse(pB), conjB′, conjA == :C ? α : conj(α),
backend...)
(!(eltype(A) <: Complex) && (eltype(c_dA) <: Complex)) ? real(c_dA) : c_dA
end
dB = @thunk begin
ipC = _repartition(invperm(linearize(pC)), pA)
ipB = _repartition(invperm(linearize(pB)), B)
NA₁ = TensorOperations.numout(pA)
conjΔC = conjB == :C ? :C : :N
conjA′ = conjB == :C ? conjA : _conj(conjA)
return tensorcontract(ipB, A, reverse(pA), conjA′,
ΔC, ipC, conjΔC,
conjB == :C ? α : conj(α), backend...)
(!(eltype(B) <: Complex) && (eltype(c_dB) <: Complex)) ? real(c_dB) : c_dB
end
= @thunk tensorscalar(tensorcontract(((), ()),
tensorcontract(pC, A, pA, conjA, B, pB,
conjB),
((),
trivtuple(TensorOperations.numind(pC))),
:C, ΔC,
(trivtuple(TensorOperations.numind(pC)),
()), :N,
backend...))
= @thunk tensorscalar(tensorcontract(((), ()), C,
((),
trivtuple(TensorOperations.numind(pC))),
:C, ΔC,
(trivtuple(TensorOperations.numind(pC)),
()), :N,
backend...))

dbackend = map(x -> NoTangent(), backend)
return NoTangent(), dC, NoTangent(),
dA, NoTangent(), NoTangent(), dB, NoTangent(), NoTangent(), dα, dβ,
dbackend...
end

return C′, tensorcontract_pullback
end

function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
C::AbstractTensorMap, pC::Index2Tuple,
A::AbstractTensorMap, conjA::Symbol,
α::Number, β::Number, backend...)
C′ = tensoradd!(copy(C), pC, A, conjA, α, β, backend...)

function tensoradd_pullback(ΔC)
dC = @thunk conj(β) * ΔC
dA = @thunk begin
ipC = _repartition(invperm(linearize(pC)), A)
c_dA = tensorcopy(ipC, ΔC, conjA, conjA == :N ? conj(α) : α, backend...)

return (!(scalartype(A) <: Complex) && (scalartype(c_dA) <: Complex)) ?
real(c_dA) : c_dA
end
= @thunk tensorscalar(tensorcontract(((), ()), A, ((), linearize(pC)),
_conj(conjA), ΔC,
(trivtuple(TensorOperations.numind(pC)),
()), :N, backend...))
= @thunk tensorscalar(tensorcontract(((), ()), C,
((),
trivtuple(TensorOperations.numind(pC))),
:C, ΔC,
(trivtuple(TensorOperations.numind(pC)),
()), :N,
backend...))
dbackend = map(x -> NoTangent(), backend)
return NoTangent(), dC, NoTangent(), dA, NoTangent(), dα, dβ, dbackend...
end

return C′, tensoradd_pullback
end

function ChainRulesCore.rrule(::typeof(tensortrace!), C::AbstractTensorMap, pC::Index2Tuple,
A::AbstractTensorMap,
pA::Index2Tuple, conjA::Symbol, α::Number, β::Number,
backend...)
C′ = tensortrace!(copy(C), pC, A, pA, conjA, α, β, backend...)

function tensortrace_pullback(ΔC)
dC = @thunk conj(β) * ΔC
dA = @thunk begin
ipC = _repartition(invperm((linearize(pC)..., pA[1]..., pA[2]...)), A)
E = id(storagetype(A), ProductSpace((space(A, i) for i in pA[1])...))
return tensorproduct(ipC, ΔC, (trivtuple(TensorOperations.numind(pC)), ()),
conjA, E,
((), trivtuple(TensorOperations.numind(pA))), conjA,
conjA == :N ? conj(α) : α, backend...)
end

= @thunk tensorscalar(tensorcontract(((), ()),
tensortrace(pC, A, pA, conjA),
((),
trivtuple(TensorOperations.numind(pC))),
_conj(conjA), ΔC,
(trivtuple(TensorOperations.numind(pC)),
()), :N,
backend...))
= @thunk tensorscalar(tensorcontract(((), ()), C,
((),
trivtuple(TensorOperations.numind(pC))),
:C, ΔC,
(trivtuple(TensorOperations.numind(pC)),
()), :N,
backend...))
dbackend = map(x -> NoTangent(), backend)
return NoTangent(), dC, NoTangent(), dA, NoTangent(), NoTangent(), dα, dβ,
dbackend...
end

return C′, tensortrace_pullback
end

end

0 comments on commit 1eeabe2

Please sign in to comment.