From 6a83a8327425a2008d4f1f7d66b8377e6bac2aba Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Thu, 7 Jul 2022 12:10:01 -0400 Subject: [PATCH] [ITensors] Simplify the rrules for priming and tagging MPS/MPO (#950) --- NEWS.md | 9 +++ src/ITensorChainRules/indexset.jl | 94 ++++++---------------------- src/ITensorChainRules/itensor.jl | 30 +++++---- src/ITensorChainRules/mps/mpo.jl | 12 ++-- src/ITensorChainRules/zygoterules.jl | 8 +-- 5 files changed, 51 insertions(+), 102 deletions(-) diff --git a/NEWS.md b/NEWS.md index 72208ea452..c0ac9334de 100644 --- a/NEWS.md +++ b/NEWS.md @@ -6,6 +6,15 @@ Note that as of Julia v1.5, in order to see deprecation warnings you will need t After we release v1 of the package, we will start following [semantic versioning](https://semver.org). +ITensors v0.3.19 Release Notes +============================== + +Bugs: + +Enhancements: + +- Simplify the `rrule`s for priming and tagging MPS/MPO + ITensors v0.3.18 Release Notes ============================== diff --git a/src/ITensorChainRules/indexset.jl b/src/ITensorChainRules/indexset.jl index 3ffb62b00a..253a8157c8 100644 --- a/src/ITensorChainRules/indexset.jl +++ b/src/ITensorChainRules/indexset.jl @@ -1,60 +1,3 @@ -function setinds_pullback(ȳ, x, a...) - x̄ = ITensors.setinds(ȳ, inds(x)) - ā = map_notangent(a) - return (NoTangent(), x̄, ā...) -end - -function inv_op(f::Function, args...; kwargs...) - return error( - "Trying to differentiate `$f` but the inverse of the operation (`inv_op`) `$f` with arguments $args and keyword arguments $kwargs is not defined.", - ) -end - -function inv_op(::typeof(prime), x, n::Integer=1; kwargs...) - return prime(x, -n; kwargs...) -end - -function inv_op(::typeof(replaceprime), x, n1n2::Pair; kwargs...) - return replaceprime(x, reverse(n1n2); kwargs...) -end - -function inv_op(::typeof(addtags), x, args...; kwargs...) - return removetags(x, args...; kwargs...) -end - -function inv_op(::typeof(removetags), x, args...; kwargs...) - return addtags(x, args...; kwargs...) -end - -function inv_op(::typeof(replacetags), x, n1n2::Pair; kwargs...) - return replacetags(x, reverse(n1n2); kwargs...) -end - -_check_inds(x::ITensor, y::ITensor) = hassameinds(x, y) -_check_inds(x::MPS, y::MPS) = hassameinds(siteinds, x, y) -_check_inds(x::MPO, y::MPO) = hassameinds(siteinds, x, y) - -for fname in ( - :prime, :setprime, :noprime, :replaceprime, :addtags, :removetags, :replacetags, :settags -) - @eval begin - function ChainRulesCore.rrule(f::typeof($fname), x::Union{MPS,MPO}, a...; kwargs...) - y = f(x, a...; kwargs...) - function f_pullback(ȳ) - x̄ = inv_op(f, unthunk(ȳ), a...; kwargs...) - if !_check_inds(x, x̄) - error( - "Trying to differentiate function `$f` with arguments $a and keyword arguments $kwargs. The forward pass indices $(inds(x)) do not match the reverse pass indices $(inds(x̄)). Likely this is because the priming/tagging operation you tried to perform is not invertible. Please write your code in a way where the index manipulation operation you are performing is invertible. For example, `prime(A::ITensor)` is invertible, with an inverse `prime(A, -1)`. However, `noprime(A)` is in general not invertible since the information about the prime levels of the original tensor are lost. Instead, you might try `prime(A, -1)` or `replaceprime(A, 1 => 0)` which are invertible.", - ) - end - ā = map_notangent(a) - return (NoTangent(), x̄, ā...) - end - return y, f_pullback - end - end -end - for fname in ( :prime, :setprime, @@ -72,11 +15,10 @@ for fname in ( :swapinds, ) @eval begin - function ChainRulesCore.rrule(f::typeof($fname), x::ITensor, a...; kwargs...) + function rrule(f::typeof($fname), x::ITensor, a...; kwargs...) y = f(x, a...; kwargs...) function f_pullback(ȳ) - uȳ = unthunk(ȳ) - x̄ = replaceinds(uȳ, inds(y), inds(x)) + x̄ = replaceinds(unthunk(ȳ), inds(y) => inds(x)) ā = map_notangent(a) return (NoTangent(), x̄, ā...) end @@ -85,23 +27,25 @@ for fname in ( end end -function ChainRulesCore.rrule(::typeof(adjoint), x::ITensor) - y = x' - function adjoint_pullback(ȳ) - uȳ = unthunk(ȳ) - x̄ = replaceinds(uȳ, inds(y), inds(x)) - return (NoTangent(), x̄) +for fname in ( + :prime, :setprime, :noprime, :replaceprime, :addtags, :removetags, :replacetags, :settags +) + @eval begin + function rrule(f::typeof($fname), x::Union{MPS,MPO}, a...; kwargs...) + y = f(x, a...; kwargs...) + function f_pullback(ȳ) + x̄ = copy(unthunk(ȳ)) + for j in eachindex(x̄) + x̄[j] = replaceinds(ȳ[j], inds(y[j]) => inds(x[j])) + end + ā = map_notangent(a) + return (NoTangent(), x̄, ā...) + end + return y, f_pullback + end end - return y, adjoint_pullback end -function ChainRulesCore.rrule(::typeof(adjoint), x::Union{MPS,MPO}) - y = x' - function adjoint_pullback(ȳ) - x̄ = inv_op(prime, ȳ) - return (NoTangent(), x̄) - end - return y, adjoint_pullback -end +rrule(::typeof(adjoint), x::Union{ITensor,MPS,MPO}) = rrule(prime, x) @non_differentiable permute(::Indices, ::Indices) diff --git a/src/ITensorChainRules/itensor.jl b/src/ITensorChainRules/itensor.jl index 9dbda701f9..a830cc77bd 100644 --- a/src/ITensorChainRules/itensor.jl +++ b/src/ITensorChainRules/itensor.jl @@ -1,4 +1,4 @@ -function ChainRulesCore.rrule(::typeof(getindex), x::ITensor, I...) +function rrule(::typeof(getindex), x::ITensor, I...) y = getindex(x, I...) function getindex_pullback(ȳ) # TODO: add definition `ITensor(::Tuple{}) = ITensor()` @@ -14,7 +14,7 @@ end # Specialized version in order to avoid call to `setindex!` # within the pullback, should be better for taking higher order # derivatives in Zygote. -function ChainRulesCore.rrule(::typeof(getindex), x::ITensor) +function rrule(::typeof(getindex), x::ITensor) y = x[] function getindex_pullback(ȳ) x̄ = ITensor(unthunk(ȳ)) @@ -91,7 +91,7 @@ function rrule(::typeof(tensor), x1::ITensor) end # Special case for contracting a pair of ITensors -function ChainRulesCore.rrule(::typeof(contract), x1::ITensor, x2::ITensor) +function rrule(::typeof(contract), x1::ITensor, x2::ITensor) project_x1 = ProjectTo(x1) project_x2 = ProjectTo(x2) function contract_pullback(ȳ) @@ -104,7 +104,7 @@ end @non_differentiable ITensors.optimal_contraction_sequence(::Any) -function ChainRulesCore.rrule(::typeof(*), x1::Number, x2::ITensor) +function rrule(::typeof(*), x1::Number, x2::ITensor) project_x1 = ProjectTo(x1) project_x2 = ProjectTo(x2) function contract_pullback(ȳ) @@ -115,7 +115,7 @@ function ChainRulesCore.rrule(::typeof(*), x1::Number, x2::ITensor) return x1 * x2, contract_pullback end -function ChainRulesCore.rrule(::typeof(*), x1::ITensor, x2::Number) +function rrule(::typeof(*), x1::ITensor, x2::Number) project_x1 = ProjectTo(x1) project_x2 = ProjectTo(x2) function contract_pullback(ȳ) @@ -126,28 +126,28 @@ function ChainRulesCore.rrule(::typeof(*), x1::ITensor, x2::Number) return x1 * x2, contract_pullback end -function ChainRulesCore.rrule(::typeof(+), x1::ITensor, x2::ITensor) +function rrule(::typeof(+), x1::ITensor, x2::ITensor) function add_pullback(ȳ) return (NoTangent(), ȳ, ȳ) end return x1 + x2, add_pullback end -function ChainRulesCore.rrule(::typeof(-), x1::ITensor, x2::ITensor) +function rrule(::typeof(-), x1::ITensor, x2::ITensor) function subtract_pullback(ȳ) return (NoTangent(), ȳ, -ȳ) end return x1 - x2, subtract_pullback end -function ChainRulesCore.rrule(::typeof(-), x::ITensor) +function rrule(::typeof(-), x::ITensor) function minus_pullback(ȳ) return (NoTangent(), -ȳ) end return -x, minus_pullback end -function ChainRulesCore.rrule(::typeof(itensor), x::Array, a...) +function rrule(::typeof(itensor), x::Array, a...) function itensor_pullback(ȳ) uȳ = permute(unthunk(ȳ), a...) x̄ = reshape(array(uȳ), size(x)) @@ -157,7 +157,7 @@ function ChainRulesCore.rrule(::typeof(itensor), x::Array, a...) return itensor(x, a...), itensor_pullback end -function ChainRulesCore.rrule(::Type{ITensor}, x::Array{<:Number}, a...) +function rrule(::Type{ITensor}, x::Array{<:Number}, a...) function ITensor_pullback(ȳ) # TODO: define `Array(::ITensor)` directly uȳ = Array(unthunk(ȳ), a...) @@ -168,7 +168,7 @@ function ChainRulesCore.rrule(::Type{ITensor}, x::Array{<:Number}, a...) return ITensor(x, a...), ITensor_pullback end -function ChainRulesCore.rrule(::Type{ITensor}, x::Number) +function rrule(::Type{ITensor}, x::Number) function ITensor_pullback(ȳ) x̄ = ȳ[] return (NoTangent(), x̄) @@ -176,7 +176,7 @@ function ChainRulesCore.rrule(::Type{ITensor}, x::Number) return ITensor(x), ITensor_pullback end -function ChainRulesCore.rrule(::typeof(dag), x::ITensor) +function rrule(::typeof(dag), x::ITensor) function dag_pullback(ȳ) x̄ = dag(unthunk(ȳ)) return (NoTangent(), x̄) @@ -184,7 +184,7 @@ function ChainRulesCore.rrule(::typeof(dag), x::ITensor) return dag(x), dag_pullback end -function ChainRulesCore.rrule(::typeof(permute), x::ITensor, a...) +function rrule(::typeof(permute), x::ITensor, a...) y = permute(x, a...) function permute_pullback(ȳ) x̄ = permute(unthunk(ȳ), inds(x)) @@ -197,9 +197,7 @@ end # Needed because by default it was calling the generic # `rrule` for `tr` inside ChainRules. # TODO: Raise an issue with ChainRules. -function ChainRulesCore.rrule( - config::RuleConfig{>:HasReverseMode}, ::typeof(tr), x::ITensor; kwargs... -) +function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(tr), x::ITensor; kwargs...) return rrule_via_ad(config, ITensors._tr, x; kwargs...) end diff --git a/src/ITensorChainRules/mps/mpo.jl b/src/ITensorChainRules/mps/mpo.jl index 78c0d1fba2..c33b620ab8 100644 --- a/src/ITensorChainRules/mps/mpo.jl +++ b/src/ITensorChainRules/mps/mpo.jl @@ -1,4 +1,4 @@ -function ChainRulesCore.rrule(::typeof(contract), x1::MPO, x2::MPO; kwargs...) +function rrule(::typeof(contract), x1::MPO, x2::MPO; kwargs...) y = contract(x1, x2; kwargs...) function contract_pullback(ȳ) x̄1 = contract(ȳ, dag(x2); kwargs...) @@ -8,11 +8,11 @@ function ChainRulesCore.rrule(::typeof(contract), x1::MPO, x2::MPO; kwargs...) return y, contract_pullback end -function ChainRulesCore.rrule(::typeof(*), x1::MPO, x2::MPO; kwargs...) +function rrule(::typeof(*), x1::MPO, x2::MPO; kwargs...) return rrule(contract, x1, x2; kwargs...) end -function ChainRulesCore.rrule(::typeof(+), x1::MPO, x2::MPO; kwargs...) +function rrule(::typeof(+), x1::MPO, x2::MPO; kwargs...) y = +(x1, x2; kwargs...) function add_pullback(ȳ) return (NoTangent(), ȳ, ȳ) @@ -20,11 +20,11 @@ function ChainRulesCore.rrule(::typeof(+), x1::MPO, x2::MPO; kwargs...) return y, add_pullback end -function ChainRulesCore.rrule(::typeof(-), x1::MPO, x2::MPO; kwargs...) +function rrule(::typeof(-), x1::MPO, x2::MPO; kwargs...) return rrule(+, x1, -x2; kwargs...) end -function ChainRulesCore.rrule(::typeof(tr), x::MPO; kwargs...) +function rrule(::typeof(tr), x::MPO; kwargs...) y = tr(x; kwargs...) function tr_pullback(ȳ) s = noprime(firstsiteinds(x)) @@ -40,7 +40,7 @@ function ChainRulesCore.rrule(::typeof(tr), x::MPO; kwargs...) return y, tr_pullback end -function ChainRulesCore.rrule(::typeof(inner), x1::MPS, x2::MPO, x3::MPS; kwargs...) +function rrule(::typeof(inner), x1::MPS, x2::MPO, x3::MPS; kwargs...) if !hassameinds(siteinds, x1, (x2, x3)) || !hassameinds(siteinds, x3, (x2, x1)) error( "Taking gradients of `inner(x::MPS, A::MPO, y::MPS)` is not supported if the site indices of the input MPS and MPO don't match. Try using if you input `inner(x, A, y), try `inner(x', A, y)` instead.", diff --git a/src/ITensorChainRules/zygoterules.jl b/src/ITensorChainRules/zygoterules.jl index 65e82549c2..3a3607ea8b 100644 --- a/src/ITensorChainRules/zygoterules.jl +++ b/src/ITensorChainRules/zygoterules.jl @@ -2,11 +2,9 @@ using ZygoteRules: @adjoint # Needed for defining the rule for `adjoint(A::ITensor)` # which currently doesn't work by overloading `ChainRulesCore.rrule` +# since it is defined in `Zygote`, which takes precedent. @adjoint function Base.adjoint(x::Union{ITensor,MPS,MPO}) - y = prime(x) - function adjoint_pullback(ȳ) - x̄ = inv_op(prime, ȳ) - return (x̄,) - end + y, adjoint_rrule_pullback = rrule(adjoint, x) + adjoint_pullback(ȳ) = Base.tail(adjoint_rrule_pullback(ȳ)) return y, adjoint_pullback end