Skip to content

Commit

Permalink
[ITensors] Simplify the rrules for priming and tagging MPS/MPO (#950)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Jul 7, 2022
1 parent efacea7 commit 6a83a83
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 102 deletions.
9 changes: 9 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,15 @@ Note that as of Julia v1.5, in order to see deprecation warnings you will need t

After we release v1 of the package, we will start following [semantic versioning](https://semver.org).

ITensors v0.3.19 Release Notes
==============================

Bugs:

Enhancements:

- Simplify the `rrule`s for priming and tagging MPS/MPO

ITensors v0.3.18 Release Notes
==============================

Expand Down
94 changes: 19 additions & 75 deletions src/ITensorChainRules/indexset.jl
Original file line number Diff line number Diff line change
@@ -1,60 +1,3 @@
function setinds_pullback(ȳ, x, a...)
= ITensors.setinds(ȳ, inds(x))
ā = map_notangent(a)
return (NoTangent(), x̄, ā...)
end

function inv_op(f::Function, args...; kwargs...)
return error(
"Trying to differentiate `$f` but the inverse of the operation (`inv_op`) `$f` with arguments $args and keyword arguments $kwargs is not defined.",
)
end

function inv_op(::typeof(prime), x, n::Integer=1; kwargs...)
return prime(x, -n; kwargs...)
end

function inv_op(::typeof(replaceprime), x, n1n2::Pair; kwargs...)
return replaceprime(x, reverse(n1n2); kwargs...)
end

function inv_op(::typeof(addtags), x, args...; kwargs...)
return removetags(x, args...; kwargs...)
end

function inv_op(::typeof(removetags), x, args...; kwargs...)
return addtags(x, args...; kwargs...)
end

function inv_op(::typeof(replacetags), x, n1n2::Pair; kwargs...)
return replacetags(x, reverse(n1n2); kwargs...)
end

_check_inds(x::ITensor, y::ITensor) = hassameinds(x, y)
_check_inds(x::MPS, y::MPS) = hassameinds(siteinds, x, y)
_check_inds(x::MPO, y::MPO) = hassameinds(siteinds, x, y)

for fname in (
:prime, :setprime, :noprime, :replaceprime, :addtags, :removetags, :replacetags, :settags
)
@eval begin
function ChainRulesCore.rrule(f::typeof($fname), x::Union{MPS,MPO}, a...; kwargs...)
y = f(x, a...; kwargs...)
function f_pullback(ȳ)
= inv_op(f, unthunk(ȳ), a...; kwargs...)
if !_check_inds(x, x̄)
error(
"Trying to differentiate function `$f` with arguments $a and keyword arguments $kwargs. The forward pass indices $(inds(x)) do not match the reverse pass indices $(inds(x̄)). Likely this is because the priming/tagging operation you tried to perform is not invertible. Please write your code in a way where the index manipulation operation you are performing is invertible. For example, `prime(A::ITensor)` is invertible, with an inverse `prime(A, -1)`. However, `noprime(A)` is in general not invertible since the information about the prime levels of the original tensor are lost. Instead, you might try `prime(A, -1)` or `replaceprime(A, 1 => 0)` which are invertible.",
)
end
ā = map_notangent(a)
return (NoTangent(), x̄, ā...)
end
return y, f_pullback
end
end
end

for fname in (
:prime,
:setprime,
Expand All @@ -72,11 +15,10 @@ for fname in (
:swapinds,
)
@eval begin
function ChainRulesCore.rrule(f::typeof($fname), x::ITensor, a...; kwargs...)
function rrule(f::typeof($fname), x::ITensor, a...; kwargs...)
y = f(x, a...; kwargs...)
function f_pullback(ȳ)
= unthunk(ȳ)
= replaceinds(uȳ, inds(y), inds(x))
= replaceinds(unthunk(ȳ), inds(y) => inds(x))
ā = map_notangent(a)
return (NoTangent(), x̄, ā...)
end
Expand All @@ -85,23 +27,25 @@ for fname in (
end
end

function ChainRulesCore.rrule(::typeof(adjoint), x::ITensor)
y = x'
function adjoint_pullback(ȳ)
= unthunk(ȳ)
= replaceinds(uȳ, inds(y), inds(x))
return (NoTangent(), x̄)
for fname in (
:prime, :setprime, :noprime, :replaceprime, :addtags, :removetags, :replacetags, :settags
)
@eval begin
function rrule(f::typeof($fname), x::Union{MPS,MPO}, a...; kwargs...)
y = f(x, a...; kwargs...)
function f_pullback(ȳ)
= copy(unthunk(ȳ))
for j in eachindex(x̄)
x̄[j] = replaceinds(ȳ[j], inds(y[j]) => inds(x[j]))
end
ā = map_notangent(a)
return (NoTangent(), x̄, ā...)
end
return y, f_pullback
end
end
return y, adjoint_pullback
end

function ChainRulesCore.rrule(::typeof(adjoint), x::Union{MPS,MPO})
y = x'
function adjoint_pullback(ȳ)
= inv_op(prime, ȳ)
return (NoTangent(), x̄)
end
return y, adjoint_pullback
end
rrule(::typeof(adjoint), x::Union{ITensor,MPS,MPO}) = rrule(prime, x)

@non_differentiable permute(::Indices, ::Indices)
30 changes: 14 additions & 16 deletions src/ITensorChainRules/itensor.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function ChainRulesCore.rrule(::typeof(getindex), x::ITensor, I...)
function rrule(::typeof(getindex), x::ITensor, I...)
y = getindex(x, I...)
function getindex_pullback(ȳ)
# TODO: add definition `ITensor(::Tuple{}) = ITensor()`
Expand All @@ -14,7 +14,7 @@ end
# Specialized version in order to avoid call to `setindex!`
# within the pullback, should be better for taking higher order
# derivatives in Zygote.
function ChainRulesCore.rrule(::typeof(getindex), x::ITensor)
function rrule(::typeof(getindex), x::ITensor)
y = x[]
function getindex_pullback(ȳ)
= ITensor(unthunk(ȳ))
Expand Down Expand Up @@ -91,7 +91,7 @@ function rrule(::typeof(tensor), x1::ITensor)
end

# Special case for contracting a pair of ITensors
function ChainRulesCore.rrule(::typeof(contract), x1::ITensor, x2::ITensor)
function rrule(::typeof(contract), x1::ITensor, x2::ITensor)
project_x1 = ProjectTo(x1)
project_x2 = ProjectTo(x2)
function contract_pullback(ȳ)
Expand All @@ -104,7 +104,7 @@ end

@non_differentiable ITensors.optimal_contraction_sequence(::Any)

function ChainRulesCore.rrule(::typeof(*), x1::Number, x2::ITensor)
function rrule(::typeof(*), x1::Number, x2::ITensor)
project_x1 = ProjectTo(x1)
project_x2 = ProjectTo(x2)
function contract_pullback(ȳ)
Expand All @@ -115,7 +115,7 @@ function ChainRulesCore.rrule(::typeof(*), x1::Number, x2::ITensor)
return x1 * x2, contract_pullback
end

function ChainRulesCore.rrule(::typeof(*), x1::ITensor, x2::Number)
function rrule(::typeof(*), x1::ITensor, x2::Number)
project_x1 = ProjectTo(x1)
project_x2 = ProjectTo(x2)
function contract_pullback(ȳ)
Expand All @@ -126,28 +126,28 @@ function ChainRulesCore.rrule(::typeof(*), x1::ITensor, x2::Number)
return x1 * x2, contract_pullback
end

function ChainRulesCore.rrule(::typeof(+), x1::ITensor, x2::ITensor)
function rrule(::typeof(+), x1::ITensor, x2::ITensor)
function add_pullback(ȳ)
return (NoTangent(), ȳ, ȳ)
end
return x1 + x2, add_pullback
end

function ChainRulesCore.rrule(::typeof(-), x1::ITensor, x2::ITensor)
function rrule(::typeof(-), x1::ITensor, x2::ITensor)
function subtract_pullback(ȳ)
return (NoTangent(), ȳ, -ȳ)
end
return x1 - x2, subtract_pullback
end

function ChainRulesCore.rrule(::typeof(-), x::ITensor)
function rrule(::typeof(-), x::ITensor)
function minus_pullback(ȳ)
return (NoTangent(), -ȳ)
end
return -x, minus_pullback
end

function ChainRulesCore.rrule(::typeof(itensor), x::Array, a...)
function rrule(::typeof(itensor), x::Array, a...)
function itensor_pullback(ȳ)
= permute(unthunk(ȳ), a...)
= reshape(array(uȳ), size(x))
Expand All @@ -157,7 +157,7 @@ function ChainRulesCore.rrule(::typeof(itensor), x::Array, a...)
return itensor(x, a...), itensor_pullback
end

function ChainRulesCore.rrule(::Type{ITensor}, x::Array{<:Number}, a...)
function rrule(::Type{ITensor}, x::Array{<:Number}, a...)
function ITensor_pullback(ȳ)
# TODO: define `Array(::ITensor)` directly
= Array(unthunk(ȳ), a...)
Expand All @@ -168,23 +168,23 @@ function ChainRulesCore.rrule(::Type{ITensor}, x::Array{<:Number}, a...)
return ITensor(x, a...), ITensor_pullback
end

function ChainRulesCore.rrule(::Type{ITensor}, x::Number)
function rrule(::Type{ITensor}, x::Number)
function ITensor_pullback(ȳ)
= ȳ[]
return (NoTangent(), x̄)
end
return ITensor(x), ITensor_pullback
end

function ChainRulesCore.rrule(::typeof(dag), x::ITensor)
function rrule(::typeof(dag), x::ITensor)
function dag_pullback(ȳ)
= dag(unthunk(ȳ))
return (NoTangent(), x̄)
end
return dag(x), dag_pullback
end

function ChainRulesCore.rrule(::typeof(permute), x::ITensor, a...)
function rrule(::typeof(permute), x::ITensor, a...)
y = permute(x, a...)
function permute_pullback(ȳ)
= permute(unthunk(ȳ), inds(x))
Expand All @@ -197,9 +197,7 @@ end
# Needed because by default it was calling the generic
# `rrule` for `tr` inside ChainRules.
# TODO: Raise an issue with ChainRules.
function ChainRulesCore.rrule(
config::RuleConfig{>:HasReverseMode}, ::typeof(tr), x::ITensor; kwargs...
)
function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(tr), x::ITensor; kwargs...)
return rrule_via_ad(config, ITensors._tr, x; kwargs...)
end

Expand Down
12 changes: 6 additions & 6 deletions src/ITensorChainRules/mps/mpo.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function ChainRulesCore.rrule(::typeof(contract), x1::MPO, x2::MPO; kwargs...)
function rrule(::typeof(contract), x1::MPO, x2::MPO; kwargs...)
y = contract(x1, x2; kwargs...)
function contract_pullback(ȳ)
x̄1 = contract(ȳ, dag(x2); kwargs...)
Expand All @@ -8,23 +8,23 @@ function ChainRulesCore.rrule(::typeof(contract), x1::MPO, x2::MPO; kwargs...)
return y, contract_pullback
end

function ChainRulesCore.rrule(::typeof(*), x1::MPO, x2::MPO; kwargs...)
function rrule(::typeof(*), x1::MPO, x2::MPO; kwargs...)
return rrule(contract, x1, x2; kwargs...)
end

function ChainRulesCore.rrule(::typeof(+), x1::MPO, x2::MPO; kwargs...)
function rrule(::typeof(+), x1::MPO, x2::MPO; kwargs...)
y = +(x1, x2; kwargs...)
function add_pullback(ȳ)
return (NoTangent(), ȳ, ȳ)
end
return y, add_pullback
end

function ChainRulesCore.rrule(::typeof(-), x1::MPO, x2::MPO; kwargs...)
function rrule(::typeof(-), x1::MPO, x2::MPO; kwargs...)
return rrule(+, x1, -x2; kwargs...)
end

function ChainRulesCore.rrule(::typeof(tr), x::MPO; kwargs...)
function rrule(::typeof(tr), x::MPO; kwargs...)
y = tr(x; kwargs...)
function tr_pullback(ȳ)
s = noprime(firstsiteinds(x))
Expand All @@ -40,7 +40,7 @@ function ChainRulesCore.rrule(::typeof(tr), x::MPO; kwargs...)
return y, tr_pullback
end

function ChainRulesCore.rrule(::typeof(inner), x1::MPS, x2::MPO, x3::MPS; kwargs...)
function rrule(::typeof(inner), x1::MPS, x2::MPO, x3::MPS; kwargs...)
if !hassameinds(siteinds, x1, (x2, x3)) || !hassameinds(siteinds, x3, (x2, x1))
error(
"Taking gradients of `inner(x::MPS, A::MPO, y::MPS)` is not supported if the site indices of the input MPS and MPO don't match. Try using if you input `inner(x, A, y), try `inner(x', A, y)` instead.",
Expand Down
8 changes: 3 additions & 5 deletions src/ITensorChainRules/zygoterules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@ using ZygoteRules: @adjoint

# Needed for defining the rule for `adjoint(A::ITensor)`
# which currently doesn't work by overloading `ChainRulesCore.rrule`
# since it is defined in `Zygote`, which takes precedent.
@adjoint function Base.adjoint(x::Union{ITensor,MPS,MPO})
y = prime(x)
function adjoint_pullback(ȳ)
= inv_op(prime, ȳ)
return (x̄,)
end
y, adjoint_rrule_pullback = rrule(adjoint, x)
adjoint_pullback(ȳ) = Base.tail(adjoint_rrule_pullback(ȳ))
return y, adjoint_pullback
end

0 comments on commit 6a83a83

Please sign in to comment.