Skip to content

Commit

Permalink
[ITensors] [ENHANCEMENT] Fix apply(::MPO, ::MPO) autodiff, add Vara…
Browse files Browse the repository at this point in the history
…rg `apply(::MPO...)` (#949)
  • Loading branch information
mtfishman committed Jul 7, 2022
1 parent 5b4ad1b commit efacea7
Show file tree
Hide file tree
Showing 16 changed files with 160 additions and 38 deletions.
11 changes: 11 additions & 0 deletions NDTensors/NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,17 @@ Note that as of Julia v1.5, in order to see deprecation warnings you will need t

After we release v1 of the package, we will start following [semantic versioning](https://semver.org).

NDTensors v0.1.42 Release Notes
===============================

Bugs:

Enhancements:

- Define `map` for Tensor and TensorStorage (b66d1b7)
- Define `real` and `imag` for Tensor (b66d1b7)
- Throw error when trying to do an eigendecomposition of Tensor with Infs or NaNs (b66d1b7)

NDTensors v0.1.41 Release Notes
===============================

Expand Down
2 changes: 1 addition & 1 deletion NDTensors/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NDTensors"
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
authors = ["Matthew Fishman <mfishman@flatironinstitute.org>"]
version = "0.1.41"
version = "0.1.42"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
14 changes: 10 additions & 4 deletions NDTensors/src/linearalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,11 @@ function LinearAlgebra.eigen(

matrixT = matrix(T)
if any(!isfinite, matrixT)
display(matrixT)
throw(ArgumentError("Trying to perform the eigendecomposition of a matrix containing NaNs or Infs"))
throw(
ArgumentError(
"Trying to perform the eigendecomposition of a matrix containing NaNs or Infs"
),
)
end

DM, VM = eigen(matrixT)
Expand Down Expand Up @@ -351,8 +354,11 @@ function LinearAlgebra.eigen(

matrixT = matrix(T)
if any(!isfinite, matrixT)
display(matrixT)
throw(ArgumentError("Trying to perform the eigendecomposition of a matrix containing NaNs or Infs"))
throw(
ArgumentError(
"Trying to perform the eigendecomposition of a matrix containing NaNs or Infs"
),
)
end

DM, VM = eigen(matrixT)
Expand Down
4 changes: 3 additions & 1 deletion NDTensors/src/tensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,9 @@ Base.similar(T::Tensor, args...) = similar(T, args...)

function map(f, x::Tensor{T}) where {T}
if !iszero(f(zero(T)))
error("map(f, ::Tensor) currently doesn't support functions that don't preserve zeros, while you passed a function such that f(0) = $(f(zero(T))). This isn't supported right now because it doesn't necessarily preserve the sparsity structure of the input tensor.")
error(
"map(f, ::Tensor) currently doesn't support functions that don't preserve zeros, while you passed a function such that f(0) = $(f(zero(T))). This isn't supported right now because it doesn't necessarily preserve the sparsity structure of the input tensor.",
)
end
return setstorage(x, map(f, storage(x)))
end
Expand Down
4 changes: 3 additions & 1 deletion NDTensors/src/tensorstorage.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ Random.randn!(S::TensorStorage) = (randn!(data(S)); S)

function map(f, x::TensorStorage{T}) where {T}
if !iszero(f(zero(T)))
error("map(f, ::TensorStorage) currently doesn't support functions that don't preserve zeros, while you passed a function such that f(0) = $(f(zero(T))). This isn't supported right now because it doesn't necessarily preserve the sparsity structure of the input tensor.")
error(
"map(f, ::TensorStorage) currently doesn't support functions that don't preserve zeros, while you passed a function such that f(0) = $(f(zero(T))). This isn't supported right now because it doesn't necessarily preserve the sparsity structure of the input tensor.",
)
end
return setdata(x, map(f, data(x)))
end
Expand Down
37 changes: 37 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,43 @@ Note that as of Julia v1.5, in order to see deprecation warnings you will need t

After we release v1 of the package, we will start following [semantic versioning](https://semver.org).

ITensors v0.3.18 Release Notes
==============================

Bugs:

- Extend `apply(::MPO, ::MPO)` to `apply(::MPO, ::MPO, ::MPO...)` (#949)
- Fix AD for `apply(::MPO, ::MPO)` and `contract(::MPO, ::MPO)` (#949)
- Properly use element type in `randomMPS` in the 1-site case (b66d1b7)
- Fix bug in `tr(::MPO)` rrule where the derivative was being multiplied twice into the identity MPO (b66d1b7)
- Fix directsum when specifying a single `Index` (#930)
- Fix bug in loginner when inner is negative or complex (#945)
- Fix subtraction bug in `OpSum` (#945)

Enhancements:

- Define "I" for Qudit/Boson type (b66d1b7)
- Only warn in `inner` if the result is `Inf` or `NaN` (b66d1b7)
- Make sure `randomITensor(())` and `randomITensor(Float64, ())` returns a Dense storage type (b66d1b7)
- Define `isreal` and `iszero` for ITensors (b66d1b7)
- Project element type of ITensor in reverse pass of tensor-tensor or scalar-tensor contraction (b66d1b7)
- Define reverse rules for ITensor subtraction and negation (b66d1b7)
- Define `map` for ITensors (b66d1b7)
- Throw error when performing eigendecomposition of tensor with NaN or Inf elements (b66d1b7)
- Fix `rrule` for `MPO` constructor by generalizing the `rrule` for the `MPS` constructor (#946)
- Forward truncation arguments to more operations in `rrule` for `apply` (#945)
- Add rrules for addition and subtraction of MPOs (#935)

ITensors v0.3.17 Release Notes
==============================

Bugs:

Enhancements:

- Add Zp as alias for operator Z+, etc. (#942)
- Export diag (#942)

ITensors v0.3.16 Release Notes
==============================

Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ITensors"
uuid = "9136182c-28ba-11e9-034c-db9fb085ebd5"
authors = ["Matthew Fishman <mfishman@flatironinstitute.org>", "Miles Stoudenmire <mstoudenmire@flatironinstitute.org>"]
version = "0.3.17"
version = "0.3.18"

[deps]
BitIntegers = "c3b6d118-76ef-56ca-8cc7-ebb389d030a1"
Expand Down Expand Up @@ -36,7 +36,7 @@ HDF5 = "0.14, 0.15, 0.16"
IsApprox = "0.1"
KrylovKit = "0.4.2, 0.5"
LinearMaps = "3"
NDTensors = "0.1.41"
NDTensors = "0.1.42"
PackageCompiler = "1.0.0, 2"
Requires = "1.1"
SerializedElementArrays = "0.1"
Expand Down
7 changes: 6 additions & 1 deletion src/ITensorChainRules/ITensorChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@ import ChainRulesCore: rrule

ITensors.dag(z::AbstractZero) = z

map_notangent(a) = map(Returns(NoTangent()), a)
if VERSION < v"1.7"
map_notangent(a) = map(_ -> NoTangent(), a)
else
map_notangent(a) = map(Returns(NoTangent()), a)
end

include("projection.jl")
include(joinpath("NDTensors", "tensor.jl"))
Expand Down Expand Up @@ -40,5 +44,6 @@ include("zygoterules.jl")
@non_differentiable ITensors.filter_inds_set_function(::Function, ::Any...)
@non_differentiable ITensors.indpairs(::Any...)
@non_differentiable onehot(::Any...)
@non_differentiable Base.convert(::Type{TagSet}, str::String)

end
9 changes: 9 additions & 0 deletions src/ITensorChainRules/itensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,4 +194,13 @@ function ChainRulesCore.rrule(::typeof(permute), x::ITensor, a...)
return y, permute_pullback
end

# Needed because by default it was calling the generic
# `rrule` for `tr` inside ChainRules.
# TODO: Raise an issue with ChainRules.
function ChainRulesCore.rrule(
config::RuleConfig{>:HasReverseMode}, ::typeof(tr), x::ITensor; kwargs...
)
return rrule_via_ad(config, ITensors._tr, x; kwargs...)
end

@non_differentiable combiner(::Indices)
22 changes: 11 additions & 11 deletions src/ITensorChainRules/mps/mpo.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
function rrule(::typeof(*), x1::MPO, x2::MPO; kwargs...)
y = *(x1, x2; kwargs...)
function ChainRulesCore.rrule(::typeof(contract), x1::MPO, x2::MPO; kwargs...)
y = contract(x1, x2; kwargs...)
function contract_pullback(ȳ)
x̄1 = *(ȳ, dag(x2); kwargs...)
x̄2 = *(dag(x1), ȳ; kwargs...)
x̄1 = contract(ȳ, dag(x2); kwargs...)
x̄2 = contract(dag(x1), ȳ; kwargs...)
return (NoTangent(), x̄1, x̄2)
end
return y, contract_pullback
end

function ChainRulesCore.rrule(::typeof(*), x1::MPO, x2::MPO; kwargs...)
return rrule(contract, x1, x2; kwargs...)
end

function ChainRulesCore.rrule(::typeof(+), x1::MPO, x2::MPO; kwargs...)
y = +(x1, x2; kwargs...)
function add_pullback(ȳ)
Expand All @@ -17,14 +21,10 @@ function ChainRulesCore.rrule(::typeof(+), x1::MPO, x2::MPO; kwargs...)
end

function ChainRulesCore.rrule(::typeof(-), x1::MPO, x2::MPO; kwargs...)
y = -(x1, x2; kwargs...)
function subtract_pullback(ȳ)
return (NoTangent(), ȳ, -ȳ)
end
return y, subtract_pullback
return rrule(+, x1, -x2; kwargs...)
end

function rrule(::typeof(tr), x::MPO; kwargs...)
function ChainRulesCore.rrule(::typeof(tr), x::MPO; kwargs...)
y = tr(x; kwargs...)
function tr_pullback(ȳ)
s = noprime(firstsiteinds(x))
Expand All @@ -40,7 +40,7 @@ function rrule(::typeof(tr), x::MPO; kwargs...)
return y, tr_pullback
end

function rrule(::typeof(inner), x1::MPS, x2::MPO, x3::MPS; kwargs...)
function ChainRulesCore.rrule(::typeof(inner), x1::MPS, x2::MPO, x3::MPS; kwargs...)
if !hassameinds(siteinds, x1, (x2, x3)) || !hassameinds(siteinds, x3, (x2, x1))
error(
"Taking gradients of `inner(x::MPS, A::MPO, y::MPS)` is not supported if the site indices of the input MPS and MPO don't match. Try using if you input `inner(x, A, y), try `inner(x', A, y)` instead.",
Expand Down
16 changes: 1 addition & 15 deletions src/ITensorChainRules/zygoterules.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
using ZygoteRules: @adjoint

# Needed for defining the rule for `adjoint(A::ITensor)`
# which currently doesn't work by overloading `ChainRulesCore.rrule`
using ZygoteRules: @adjoint

@adjoint function Base.adjoint(x::Union{ITensor,MPS,MPO})
y = prime(x)
function adjoint_pullback(ȳ)
Expand All @@ -11,16 +10,3 @@ using ZygoteRules: @adjoint
end
return y, adjoint_pullback
end

## XXX: raise issue about `tr` being too generically
## defined in ChainRules
##
## using Zygote
##
## # Needed because by default it was calling the generic
## # rrule for `tr` inside ChainRules
## function rrule(::typeof(tr), x::ITensor; kwargs...)
## y, tr_pullback_zygote = pullback(ITensors._tr, x; kwargs...)
## tr_pullback(ȳ) = (NoTangent(), tr_pullback_zygote(ȳ)...)
## return y, tr_pullback
## end
2 changes: 1 addition & 1 deletion src/mps/abstractmps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1624,7 +1624,7 @@ function truncate(ψ0::AbstractMPS; kwargs...)
return ψ
end

# Make `*` and alias for `contract` of two `AbstractMPS`
# Make `*` an alias for `contract` of two `AbstractMPS`
*(A::AbstractMPS, B::AbstractMPS; kwargs...) = contract(A, B; kwargs...)

function _apply_to_orthocenter!(f, ψ::AbstractMPS, x)
Expand Down
4 changes: 4 additions & 0 deletions src/mps/mpo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,10 @@ function apply(A::MPO, B::MPO; kwargs...)
return replaceprime(AB, 2 => 1)
end

function apply(A1::MPO, A2::MPO, A3::MPO, As::MPO...; kwargs...)
return apply(apply(A1, A2; kwargs...), A3, As...; kwargs...)
end

(A::MPO)(B::MPO; kwargs...) = apply(A, B; kwargs...)

contract_mpo_mpo_doc = """
Expand Down
21 changes: 20 additions & 1 deletion test/ITensorChainRules/test_chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ Random.seed!(1234)
# https://github.com/ITensor/ITensors.jl/issues/936
n = 2
s = siteinds("S=1/2", n)
x = randomMPS(s) |> x -> outer(x', x)
x = (x -> outer(x', x))(randomMPS(s))
f1 = x -> tr(x)
f2 = x -> 2tr(x)
f3 = x -> -tr(x)
Expand Down Expand Up @@ -627,3 +627,22 @@ end
∇num = (f+ ϵ) - f(θ)) / ϵ
@test ∇f ∇num atol = 1e-5
end

@testset "contract/apply MPOs" begin
n = 2
s = siteinds("S=1/2", n)
x = (x -> outer(x', x))(randomMPS(s; linkdims=4))
x_itensor = contract(x)

f = x -> tr(apply(x, x))
@test f(x) f(x_itensor)
@test contract(f'(x)) f'(x_itensor)

f = x -> tr(replaceprime(contract(x', x), 2 => 1))
@test f(x) f(x_itensor)
@test contract(f'(x)) f'(x_itensor)

f = x -> tr(replaceprime(*(x', x), 2 => 1))
@test f(x) f(x_itensor)
@test contract(f'(x)) f'(x_itensor)
end
33 changes: 33 additions & 0 deletions test/itensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,39 @@ end
@test !hascommoninds(A, C)
end

@testset "isreal, iszero, real, imag" begin
i, j = Index.(2, ("i", "j"))
A = randomITensor(i, j)
Ac = randomITensor(ComplexF64, i, j)
Ar = real(Ac)
Ai = imag(Ac)
@test Ac Ar + im * Ai
@test isreal(A)
@test !isreal(Ac)
@test isreal(Ar)
@test isreal(Ai)
@test !iszero(A)
@test !iszero(real(A))
@test iszero(imag(A))
@test iszero(ITensor(0.0, i, j))
@test iszero(ITensor(i, j))
end

@testset "map" begin
A = randomITensor(Index(2))
@test eltype(A) == Float64
B = map(ComplexF64, A)
@test B A
@test eltype(B) == ComplexF64
B = map(Float32, A)
@test B A
@test eltype(B) == Float32
B = map(x -> 2x, A)
@test B 2A
@test eltype(B) == Float64
@test_throws ErrorException map(x -> x + 1, A)
end

@testset "getindex with state string" begin
i₁ = Index(2, "S=1/2")
i₂ = Index(2, "S=1/2")
Expand Down
8 changes: 8 additions & 0 deletions test/mpo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,14 @@ end
@test_throws DimensionMismatch K * badL
end

@testset "Multi-arg apply(::MPO...)" begin
ρ1 = (x -> outer(x', x; maxdim=4))(randomMPS(sites; linkdims=2))
ρ2 = (x -> outer(x', x; maxdim=4))(randomMPS(sites; linkdims=2))
ρ3 = (x -> outer(x', x; maxdim=4))(randomMPS(sites; linkdims=2))
@test apply(ρ1, ρ2, ρ3; cutoff=1e-8)
apply(apply(ρ1, ρ2; cutoff=1e-8), ρ3; cutoff=1e-8)
end

sites = siteinds("S=1/2", N)
O = MPO(sites, "Sz")
@test length(O) == N # just make sure this works
Expand Down

4 comments on commit efacea7

@mtfishman
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register subdir=NDTensors

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/63801

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a NDTensors-v0.1.42 -m "<description of version>" efacea7b2e5bc1894b58ce75874352d825eb8954
git push origin NDTensors-v0.1.42

@mtfishman
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/63821

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.18 -m "<description of version>" efacea7b2e5bc1894b58ce75874352d825eb8954
git push origin v0.3.18

Please sign in to comment.