Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ITensorMPS"
uuid = "0d1a4710-d33b-49a5-8f18-73bdf49b47e2"
authors = ["Matthew Fishman <mfishman@flatironinstitute.org>", "Miles Stoudenmire <mstoudenmire@flatironinstitute.org>"]
version = "0.3.22"
version = "0.3.23"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
37 changes: 24 additions & 13 deletions src/mps.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Adapt: adapt
using NDTensors: using_auto_fermion
using NDTensors.TypeParameterAccessors: unspecify_type_parameters
using Random: Random
using ITensors.SiteTypes: SiteTypes, siteind, siteinds, state

Expand Down Expand Up @@ -737,7 +738,8 @@ Cuu = correlation_matrix(psi, "Cdagup", "Cup"; sites=2:8)
```
"""
function correlation_matrix(
psi::MPS, _Op1, _Op2; sites = 1:length(psi), site_range = nothing, ishermitian = nothing
psi::MPS, _Op1, _Op2; sites = 1:length(psi), site_range = nothing,
ishermitian = nothing
)
if !isnothing(site_range)
@warn "The `site_range` keyword arg. to `correlation_matrix` is deprecated: use the keyword `sites` instead"
Expand Down Expand Up @@ -796,7 +798,12 @@ function correlation_matrix(
# Nb = size of block of correlation matrix
Nb = length(sites)

C = zeros(ElT, Nb, Nb)
op1_start = op(_Op1, s[start_site])
op2_start = op(_Op2, s[start_site])
ElT1 = eltype(op1_start)
ElT2 = eltype(op2_start)
ElT′ = promote_type(ElT1, ElT2, ElT)
C = zeros(ElT′, Nb, Nb)

if start_site == 1
L = ITensor(1.0)
Expand All @@ -817,15 +824,15 @@ function correlation_matrix(

# Get j == i diagonal correlations
rind = commonind(psi[i], psi[i + 1])
oᵢ = adapt(datatype(Li), op(onsiteOp, s, i))
oᵢ = adapt(unspecify_type_parameters(datatype(Li)), op(onsiteOp, s, i))
C[ni, ni] = ((Li * oᵢ) * prime(dag(psi[i]), !rind))[] / norm2_psi

# Get j > i correlations
if !using_auto_fermion() && fermionic2
Op1 = "$Op1 * F"
end

oᵢ = adapt(datatype(Li), op(Op1, s, i))
oᵢ = adapt(unspecify_type_parameters(datatype(Li)), op(Op1, s, i))

Li12 = (dag(psi[i])' * oᵢ) * Li
pL12 = i
Expand All @@ -836,7 +843,8 @@ function correlation_matrix(
while pL12 < j - 1
pL12 += 1
if !using_auto_fermion() && fermionic2
oᵢ = adapt(datatype(psi[pL12]), op("F", s[pL12]))
dtype = unspecify_type_parameters(datatype(psi[pL12]))
oᵢ = adapt(dtype, op("F", s[pL12]))
Li12 *= (oᵢ * dag(psi[pL12])')
else
sᵢ = siteind(psi, pL12)
Expand All @@ -848,7 +856,7 @@ function correlation_matrix(
lind = commonind(psi[j], Li12)
Li12 *= psi[j]

oⱼ = adapt(datatype(Li12), op(Op2, s, j))
oⱼ = adapt(unspecify_type_parameters(datatype(Li12)), op(Op2, s, j))
sⱼ = siteind(psi, j)
val = (Li12 * oⱼ) * prime(dag(psi[j]), (sⱼ, lind))

Expand All @@ -863,7 +871,8 @@ function correlation_matrix(

pL12 += 1
if !using_auto_fermion() && fermionic2
oᵢ = adapt(datatype(psi[pL12]), op("F", s[pL12]))
dtype = unspecify_type_parameters(datatype(psi[pL12]))
oᵢ = adapt(dtype, op("F", s[pL12]))
Li12 *= (oᵢ * dag(psi[pL12])')
else
sᵢ = siteind(psi, pL12)
Expand All @@ -879,7 +888,7 @@ function correlation_matrix(
if !using_auto_fermion() && fermionic1
Op2 = "$Op2 * F"
end
oᵢ = adapt(datatype(psi[i]), op(Op2, s, i))
oᵢ = adapt(unspecify_type_parameters(datatype(psi[i])), op(Op2, s, i))
Li21 = (Li * oᵢ) * dag(psi[i])'
pL21 = i
if !using_auto_fermion() && fermionic1
Expand All @@ -892,7 +901,8 @@ function correlation_matrix(
while pL21 < j - 1
pL21 += 1
if !using_auto_fermion() && fermionic1
oᵢ = adapt(datatype(psi[pL21]), op("F", s[pL21]))
dtype = unspecify_type_parameters(datatype(psi[pL21]))
oᵢ = adapt(dtype, op("F", s[pL21]))
Li21 *= oᵢ * dag(psi[pL21])'
else
sᵢ = siteind(psi, pL21)
Expand All @@ -904,14 +914,15 @@ function correlation_matrix(
lind = commonind(psi[j], Li21)
Li21 *= psi[j]

oⱼ = adapt(datatype(psi[j]), op(Op1, s, j))
oⱼ = adapt(unspecify_type_parameters(datatype(psi[j])), op(Op1, s, j))
sⱼ = siteind(psi, j)
val = (prime(dag(psi[j]), (sⱼ, lind)) * (oⱼ * Li21))[]
C[nj, ni] = val / norm2_psi

pL21 += 1
if !using_auto_fermion() && fermionic1
oᵢ = adapt(datatype(psi[pL21]), op("F", s[pL21]))
dtype = unspecify_type_parameters(datatype(psi[pL21]))
oᵢ = adapt(dtype, op("F", s[pL21]))
Li21 *= (oᵢ * dag(psi[pL21])')
else
sᵢ = siteind(psi, pL21)
Expand All @@ -935,7 +946,7 @@ function correlation_matrix(
L = L * psi[pL] * prime(dag(psi[pL]), !sᵢ)
end
lind = commonind(psi[i], psi[i - 1])
oᵢ = adapt(datatype(psi[i]), op(onsiteOp, s, i))
oᵢ = adapt(unspecify_type_parameters(datatype(psi[i])), op(onsiteOp, s, i))
sᵢ = siteind(psi, i)
val = (L * (oᵢ * psi[i]) * prime(dag(psi[i]), (sᵢ, lind)))[]
C[Nb, Nb] = val / norm2_psi
Expand Down Expand Up @@ -1007,7 +1018,7 @@ function expect(psi::MPS, ops; sites = 1:length(psi), site_range = nothing)
for (entry, j) in enumerate(site_range)
psi = orthogonalize(psi, j)
for (n, opname) in enumerate(ops)
oⱼ = adapt(datatype(psi[j]), op(opname, s[j]))
oⱼ = adapt(unspecify_type_parameters(datatype(psi[j])), op(opname, s[j]))
val = inner(psi[j], apply(oⱼ, psi[j])) / norm2_psi
ex[n][entry] = (el_types[n] <: Real) ? real(val) : val
end
Expand Down
16 changes: 16 additions & 0 deletions test/base/test_mps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -913,6 +913,21 @@ end
@test_throws ErrorException expect(psi0, "Sz")
end

@testset "expect real wavefunction complex operator" for elt in (Float32, Float64)
N = 8
s = siteinds("S=1/2", N)
using StableRNGs: StableRNG
rng = StableRNG(123)
psi = random_mps(rng, elt, s; linkdims = 2)
eSy = zeros(complex(elt), N)
for j in 1:N
psi = orthogonalize(psi, j)
eSy[j] = (dag(psi[j]) * apply(op("Sy", s[j]), psi[j]))[]
end
res = expect(psi, "Sy")
@test res ≈ eSy atol = eps(elt)
end

@testset "Expected value and Correlations" begin
m = 2

Expand All @@ -935,6 +950,7 @@ end
("Sz", "Sz"),
("iSy", "iSy"),
("Sx", "Sx"),
("Sy", "Sy"),
("Sz", "Sx"),
("S+", "S+"),
("S-", "S+"),
Expand Down
Loading