From 7ab41590f2f47b11e0898f390c34030ea8e606fe Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 23 Oct 2025 10:36:25 -0400 Subject: [PATCH] Fix expect and correlation_matrix for complex operator and real states --- Project.toml | 2 +- src/mps.jl | 37 ++++++++++++++++++++++++------------- test/base/test_mps.jl | 16 ++++++++++++++++ 3 files changed, 41 insertions(+), 14 deletions(-) diff --git a/Project.toml b/Project.toml index bd3fdcc..667edae 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ITensorMPS" uuid = "0d1a4710-d33b-49a5-8f18-73bdf49b47e2" authors = ["Matthew Fishman ", "Miles Stoudenmire "] -version = "0.3.22" +version = "0.3.23" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/mps.jl b/src/mps.jl index d5fcbea..2f3395c 100644 --- a/src/mps.jl +++ b/src/mps.jl @@ -1,5 +1,6 @@ using Adapt: adapt using NDTensors: using_auto_fermion +using NDTensors.TypeParameterAccessors: unspecify_type_parameters using Random: Random using ITensors.SiteTypes: SiteTypes, siteind, siteinds, state @@ -737,7 +738,8 @@ Cuu = correlation_matrix(psi, "Cdagup", "Cup"; sites=2:8) ``` """ function correlation_matrix( - psi::MPS, _Op1, _Op2; sites = 1:length(psi), site_range = nothing, ishermitian = nothing + psi::MPS, _Op1, _Op2; sites = 1:length(psi), site_range = nothing, + ishermitian = nothing ) if !isnothing(site_range) @warn "The `site_range` keyword arg. to `correlation_matrix` is deprecated: use the keyword `sites` instead" @@ -796,7 +798,12 @@ function correlation_matrix( # Nb = size of block of correlation matrix Nb = length(sites) - C = zeros(ElT, Nb, Nb) + op1_start = op(_Op1, s[start_site]) + op2_start = op(_Op2, s[start_site]) + ElT1 = eltype(op1_start) + ElT2 = eltype(op2_start) + ElT′ = promote_type(ElT1, ElT2, ElT) + C = zeros(ElT′, Nb, Nb) if start_site == 1 L = ITensor(1.0) @@ -817,7 +824,7 @@ function correlation_matrix( # Get j == i diagonal correlations rind = commonind(psi[i], psi[i + 1]) - oᵢ = adapt(datatype(Li), op(onsiteOp, s, i)) + oᵢ = adapt(unspecify_type_parameters(datatype(Li)), op(onsiteOp, s, i)) C[ni, ni] = ((Li * oᵢ) * prime(dag(psi[i]), !rind))[] / norm2_psi # Get j > i correlations @@ -825,7 +832,7 @@ function correlation_matrix( Op1 = "$Op1 * F" end - oᵢ = adapt(datatype(Li), op(Op1, s, i)) + oᵢ = adapt(unspecify_type_parameters(datatype(Li)), op(Op1, s, i)) Li12 = (dag(psi[i])' * oᵢ) * Li pL12 = i @@ -836,7 +843,8 @@ function correlation_matrix( while pL12 < j - 1 pL12 += 1 if !using_auto_fermion() && fermionic2 - oᵢ = adapt(datatype(psi[pL12]), op("F", s[pL12])) + dtype = unspecify_type_parameters(datatype(psi[pL12])) + oᵢ = adapt(dtype, op("F", s[pL12])) Li12 *= (oᵢ * dag(psi[pL12])') else sᵢ = siteind(psi, pL12) @@ -848,7 +856,7 @@ function correlation_matrix( lind = commonind(psi[j], Li12) Li12 *= psi[j] - oⱼ = adapt(datatype(Li12), op(Op2, s, j)) + oⱼ = adapt(unspecify_type_parameters(datatype(Li12)), op(Op2, s, j)) sⱼ = siteind(psi, j) val = (Li12 * oⱼ) * prime(dag(psi[j]), (sⱼ, lind)) @@ -863,7 +871,8 @@ function correlation_matrix( pL12 += 1 if !using_auto_fermion() && fermionic2 - oᵢ = adapt(datatype(psi[pL12]), op("F", s[pL12])) + dtype = unspecify_type_parameters(datatype(psi[pL12])) + oᵢ = adapt(dtype, op("F", s[pL12])) Li12 *= (oᵢ * dag(psi[pL12])') else sᵢ = siteind(psi, pL12) @@ -879,7 +888,7 @@ function correlation_matrix( if !using_auto_fermion() && fermionic1 Op2 = "$Op2 * F" end - oᵢ = adapt(datatype(psi[i]), op(Op2, s, i)) + oᵢ = adapt(unspecify_type_parameters(datatype(psi[i])), op(Op2, s, i)) Li21 = (Li * oᵢ) * dag(psi[i])' pL21 = i if !using_auto_fermion() && fermionic1 @@ -892,7 +901,8 @@ function correlation_matrix( while pL21 < j - 1 pL21 += 1 if !using_auto_fermion() && fermionic1 - oᵢ = adapt(datatype(psi[pL21]), op("F", s[pL21])) + dtype = unspecify_type_parameters(datatype(psi[pL21])) + oᵢ = adapt(dtype, op("F", s[pL21])) Li21 *= oᵢ * dag(psi[pL21])' else sᵢ = siteind(psi, pL21) @@ -904,14 +914,15 @@ function correlation_matrix( lind = commonind(psi[j], Li21) Li21 *= psi[j] - oⱼ = adapt(datatype(psi[j]), op(Op1, s, j)) + oⱼ = adapt(unspecify_type_parameters(datatype(psi[j])), op(Op1, s, j)) sⱼ = siteind(psi, j) val = (prime(dag(psi[j]), (sⱼ, lind)) * (oⱼ * Li21))[] C[nj, ni] = val / norm2_psi pL21 += 1 if !using_auto_fermion() && fermionic1 - oᵢ = adapt(datatype(psi[pL21]), op("F", s[pL21])) + dtype = unspecify_type_parameters(datatype(psi[pL21])) + oᵢ = adapt(dtype, op("F", s[pL21])) Li21 *= (oᵢ * dag(psi[pL21])') else sᵢ = siteind(psi, pL21) @@ -935,7 +946,7 @@ function correlation_matrix( L = L * psi[pL] * prime(dag(psi[pL]), !sᵢ) end lind = commonind(psi[i], psi[i - 1]) - oᵢ = adapt(datatype(psi[i]), op(onsiteOp, s, i)) + oᵢ = adapt(unspecify_type_parameters(datatype(psi[i])), op(onsiteOp, s, i)) sᵢ = siteind(psi, i) val = (L * (oᵢ * psi[i]) * prime(dag(psi[i]), (sᵢ, lind)))[] C[Nb, Nb] = val / norm2_psi @@ -1007,7 +1018,7 @@ function expect(psi::MPS, ops; sites = 1:length(psi), site_range = nothing) for (entry, j) in enumerate(site_range) psi = orthogonalize(psi, j) for (n, opname) in enumerate(ops) - oⱼ = adapt(datatype(psi[j]), op(opname, s[j])) + oⱼ = adapt(unspecify_type_parameters(datatype(psi[j])), op(opname, s[j])) val = inner(psi[j], apply(oⱼ, psi[j])) / norm2_psi ex[n][entry] = (el_types[n] <: Real) ? real(val) : val end diff --git a/test/base/test_mps.jl b/test/base/test_mps.jl index eb73b23..6fac8e7 100644 --- a/test/base/test_mps.jl +++ b/test/base/test_mps.jl @@ -913,6 +913,21 @@ end @test_throws ErrorException expect(psi0, "Sz") end + @testset "expect real wavefunction complex operator" for elt in (Float32, Float64) + N = 8 + s = siteinds("S=1/2", N) + using StableRNGs: StableRNG + rng = StableRNG(123) + psi = random_mps(rng, elt, s; linkdims = 2) + eSy = zeros(complex(elt), N) + for j in 1:N + psi = orthogonalize(psi, j) + eSy[j] = (dag(psi[j]) * apply(op("Sy", s[j]), psi[j]))[] + end + res = expect(psi, "Sy") + @test res ≈ eSy atol = eps(elt) + end + @testset "Expected value and Correlations" begin m = 2 @@ -935,6 +950,7 @@ end ("Sz", "Sz"), ("iSy", "iSy"), ("Sx", "Sx"), + ("Sy", "Sy"), ("Sz", "Sx"), ("S+", "S+"), ("S-", "S+"),