Skip to content

Commit

Permalink
Various AD bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Sep 13, 2023
1 parent f2b6ff0 commit 25c06e5
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
20 changes: 10 additions & 10 deletions ext/TensorKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,16 @@ function ChainRulesCore.rrule(::typeof(*), a::Number, b::AbstractTensorMap)
return a * b, times_pullback
end

# Reverse-mode rule for `permute(tsrc, p)`.
#
# Returns the permuted tensor together with a pullback. The pullback maps a
# cotangent of the output back onto the index layout of `tsrc` by applying the
# inverse permutation. Neither the function itself nor the permutation `p`
# carries a tangent, hence the two `NoTangent()` entries.
function ChainRulesCore.rrule(::typeof(permute), tsrc::AbstractTensorMap, p::Index2Tuple)
    function permute_pullback(Δtdst)
        # Invert the flattened permutation, then regroup it into the two index
        # groups selected by `tsrc` (via `_canonicalize`) so that the result of
        # the inverse permute is laid out like the original input.
        invp = TensorKit._canonicalize(TupleTools.invperm(linearize(p)), tsrc)
        # The incoming cotangent may be a lazy thunk; materialize it before use.
        return NoTangent(), permute(unthunk(Δtdst), invp), NoTangent()
    end
    return permute(tsrc, p), permute_pullback
end

# Reverse-mode rule for `scalar(t)`.
#
# Returns `scalar(t)` and a pullback that builds a tensor shaped like `t`
# (via `similar`) with every entry set to the incoming cotangent.
function ChainRulesCore.rrule(::typeof(scalar), t::AbstractTensorMap)
    # The cotangent may arrive as a lazy thunk; `unthunk` materializes it
    # before it is written into the freshly allocated tensor.
    scalar_pullback(Δc) = NoTangent(), fill!(similar(t), unthunk(Δc))
    return scalar(t), scalar_pullback
end

Expand All @@ -102,7 +102,7 @@ function ChainRulesCore.rrule(::typeof(tr), A::AbstractTensorMap)
end

# Reverse-mode rule for `adjoint(A)`.
#
# Returns `adjoint(A)` and a pullback that takes the adjoint of the incoming
# cotangent.
function ChainRulesCore.rrule(::typeof(adjoint), A::AbstractTensorMap)
    # The cotangent may arrive as a lazy thunk; materialize it before taking
    # its adjoint.
    adjoint_pullback(Δadjoint) = NoTangent(), adjoint(unthunk(Δadjoint))
    return adjoint(A), adjoint_pullback
end

Expand All @@ -124,7 +124,7 @@ end
function ChainRulesCore.rrule(::typeof(TensorKit.tsvd), t::AbstractTensorMap; kwargs...)
T = eltype(t)

U, S, V = tsvd(t; kwargs...)
U, S, V, ϵ = tsvd(t; kwargs...)

F = similar(S)
for (k, dst) in blocks(F)
Expand Down Expand Up @@ -171,10 +171,10 @@ function ChainRulesCore.rrule(::typeof(TensorKit.tsvd), t::AbstractTensorMap; kw
∂t += U * pinv(S) * dV * (one(prv) - prv)
end

return NoTangent(), ∂t, fill(NoTangent(), length(kwargs))...
return NoTangent(), ∂t
end

return (U, S, V), tsvd_pullback
return (U, S, V, ϵ), tsvd_pullback
end

function _elementwise_mult(a::AbstractTensorMap, b::AbstractTensorMap)
Expand Down
8 changes: 4 additions & 4 deletions src/tensors/tensoroperations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ function _canonicalize(p::Index2Tuple{N₁,N₂},
::AbstractTensorMap{<:IndexSpace,N₁,N₂}) where {N₁,N₂}
return p
end
# A two-part index tuple whose grouping does not already match `t` is first
# flattened and then regrouped by the `IndexTuple` method below.
_canonicalize(p::Index2Tuple, t::AbstractTensorMap) = _canonicalize(linearize(p), t)

# Split the flat index tuple `p` into two groups: the entries of `p` at the
# positions `codomainind(t)` and those at the positions `domainind(t)`.
# Returns the pair `(p₁, p₂)`.
function _canonicalize(p::IndexTuple, t::AbstractTensorMap)
    p₁ = TupleTools.getindices(p, codomainind(t))
    p₂ = TupleTools.getindices(p, domainind(t))
    return (p₁, p₂)
end

Expand Down

0 comments on commit 25c06e5

Please sign in to comment.