Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

State and val functions #664

Merged
merged 14 commits into from
May 16, 2021
16 changes: 8 additions & 8 deletions src/mps/mps.jl
Expand Up @@ -297,23 +297,23 @@ psi = MPS(ComplexF64, sites, states)
phi = MPS(sites, "Up")
```
"""
function MPS(::Type{T}, sites::Vector{<:Index}, states) where {T<:Number}
if length(sites) != length(states)
throw(DimensionMismatch("Number of sites and and initial states don't match"))
function MPS(::Type{T}, sites::Vector{<:Index}, vals) where {T<:Number}
if length(sites) != length(vals)
throw(DimensionMismatch("Number of sites and and initial vals don't match"))
end
ivals = [state(sites[n], states[n]) for n in 1:length(sites)]
ivals = [val(sites[n], vals[n]) for n in 1:length(sites)]
return MPS(T, ivals)
end

function MPS(
::Type{T}, sites::Vector{<:Index}, states::Union{String,Integer}
::Type{T}, sites::Vector{<:Index}, vals::Union{String,Integer}
) where {T<:Number}
ivals = [state(sites[n], states) for n in 1:length(sites)]
ivals = [val(sites[n], vals) for n in 1:length(sites)]
return MPS(T, ivals)
end

function MPS(::Type{T}, sites::Vector{<:Index}, states::Function) where {T<:Number}
ivals = [state(sites[n], states(n)) for n in 1:length(sites)]
function MPS(::Type{T}, sites::Vector{<:Index}, vals::Function) where {T<:Number}
ivals = [val(sites[n], vals(n)) for n in 1:length(sites)]
return MPS(T, ivals)
end

Expand Down
25 changes: 17 additions & 8 deletions src/physics/site_types/electron.jl
Expand Up @@ -58,14 +58,23 @@ function space(
return 4
end

state(::SiteType"Electron", ::StateName"Emp") = 1
state(::SiteType"Electron", ::StateName"Up") = 2
state(::SiteType"Electron", ::StateName"Dn") = 3
state(::SiteType"Electron", ::StateName"UpDn") = 4
state(st::SiteType"Electron", ::StateName"0") = state(st, StateName("Emp"))
state(st::SiteType"Electron", ::StateName"↑") = state(st, StateName("Up"))
state(st::SiteType"Electron", ::StateName"↓") = state(st, StateName("Dn"))
state(st::SiteType"Electron", ::StateName"↑↓") = state(st, StateName("UpDn"))
val(::ValName"Emp", ::SiteType"Electron") = 1
val(::ValName"Up", ::SiteType"Electron") = 2
val(::ValName"Dn", ::SiteType"Electron") = 3
val(::ValName"UpDn", ::SiteType"Electron") = 4
val(::ValName"0", st::SiteType"Electron") = val(ValName("Emp"), st)
val(::ValName"↑", st::SiteType"Electron") = val(ValName("Up"), st)
val(::ValName"↓", st::SiteType"Electron") = val(ValName("Dn"), st)
val(::ValName"↑↓", st::SiteType"Electron") = val(ValName("UpDn"), st)

state(::StateName"Emp", ::SiteType"Electron") = [1.0, 0, 0, 0]
state(::StateName"Up", ::SiteType"Electron") = [0.0, 1, 0, 0]
state(::StateName"Dn", ::SiteType"Electron") = [0.0, 0, 1, 0]
state(::StateName"UpDn", ::SiteType"Electron") = [0.0, 0, 0, 1]
state(::StateName"0", st::SiteType"Electron") = state(StateName("Emp"), st)
state(::StateName"↑", st::SiteType"Electron") = state(StateName("Up"), st)
state(::StateName"↓", st::SiteType"Electron") = state(StateName("Dn"), st)
state(::StateName"↑↓", st::SiteType"Electron") = state(StateName("UpDn"), st)

function op!(Op::ITensor, ::OpName"Nup", ::SiteType"Electron", s::Index)
Op[s' => 2, s => 2] = 1.0
Expand Down
13 changes: 9 additions & 4 deletions src/physics/site_types/fermion.jl
Expand Up @@ -59,10 +59,15 @@ function space(
return 2
end

state(::SiteType"Fermion", ::StateName"Emp") = 1
state(::SiteType"Fermion", ::StateName"Occ") = 2
state(st::SiteType"Fermion", ::StateName"0") = state(st, StateName("Emp"))
state(st::SiteType"Fermion", ::StateName"1") = state(st, StateName("Occ"))
val(::ValName"Emp", ::SiteType"Fermion") = 1
val(::ValName"Occ", ::SiteType"Fermion") = 2
val(::ValName"0", st::SiteType"Fermion") = val(ValName("Emp"), st)
val(::ValName"1", st::SiteType"Fermion") = val(ValName("Occ"), st)

state(::StateName"Emp", ::SiteType"Fermion") = 1
state(::StateName"Occ", ::SiteType"Fermion") = 2
state(::StateName"0", st::SiteType"Fermion") = state(StateName("Emp"), st)
state(::StateName"1", st::SiteType"Fermion") = state(StateName("Occ"), st)

function op!(Op::ITensor, ::OpName"N", ::SiteType"Fermion", s::Index)
return Op[s' => 2, s => 2] = 1.0
Expand Down
6 changes: 4 additions & 2 deletions src/physics/site_types/qubit.jl
Expand Up @@ -36,9 +36,11 @@ function space(
)
end

state(::SiteType"Qubit", ::StateName"0") = 1
val(::ValName"0", st::SiteType"Qubit") = 1
val(::ValName"1", st::SiteType"Qubit") = 2

state(::SiteType"Qubit", ::StateName"1") = 2
state(::StateName"0", ::SiteType"Qubit") = [1.0, 0.0]
state(::StateName"1", ::SiteType"Qubit") = [0.0, 1.0]

# Use S=1/2 definition of any operators
# called using Qubit SiteType
Expand Down
30 changes: 25 additions & 5 deletions src/physics/site_types/spinhalf.jl
Expand Up @@ -32,11 +32,29 @@ function space(
return 2
end

state(::SiteType"S=1/2", ::StateName"Up") = 1
state(::SiteType"S=1/2", ::StateName"Dn") = 2
val(::ValName"Up", ::SiteType"S=1/2") = 1
val(::ValName"Dn", ::SiteType"S=1/2") = 2

state(st::SiteType"S=1/2", ::StateName"↑") = state(st, StateName("Up"))
state(st::SiteType"S=1/2", ::StateName"↓") = state(st, StateName("Dn"))
val(::ValName"↑", st::SiteType"S=1/2") = 1
val(::ValName"↓", st::SiteType"S=1/2") = 2

val(::ValName"Z+", ::SiteType"S=1/2") = 1
val(::ValName"Z-", ::SiteType"S=1/2") = 2

state(::StateName"Up", ::SiteType"S=1/2") = [1.0, 0.0]
state(::StateName"Dn", ::SiteType"S=1/2") = [0.0, 1.0]

state(::StateName"↑", st::SiteType"S=1/2") = [1.0, 0.0]
state(::StateName"↓", st::SiteType"S=1/2") = [0.0, 1.0]

state(::StateName"Z+", st::SiteType"S=1/2") = [1.0, 0.0]
state(::StateName"Z-", st::SiteType"S=1/2") = [0.0, 1.0]

state(::StateName"X+", st::SiteType"S=1/2") = [1 / sqrt(2), 1 / sqrt(2)]
state(::StateName"X-", st::SiteType"S=1/2") = [1 / sqrt(2), -1 / sqrt(2)]

state(::StateName"Y+", st::SiteType"S=1/2") = [1 / sqrt(2), im / sqrt(2)]
state(::StateName"Y-", st::SiteType"S=1/2") = [1 / sqrt(2), -im / sqrt(2)]

op(::OpName"Z", ::SiteType"S=1/2") = [
1 0
Expand Down Expand Up @@ -129,14 +147,16 @@ op(::OpName"projDn", t::SiteType"S=1/2") = op(OpName("ProjDn"), t)

space(::SiteType"SpinHalf"; kwargs...) = space(SiteType("S=1/2"); kwargs...)

state(::SiteType"SpinHalf", n::StateName) = state(SiteType("S=1/2"), n)
val(name::ValName, ::SiteType"SpinHalf") = val(name, SiteType("S=1/2"))

op(o::OpName, ::SiteType"SpinHalf"; kwargs...) = op(o, SiteType("S=1/2"); kwargs...)

# Support the tag "S=½" as equivalent to "S=1/2"

space(::SiteType"S=½"; kwargs...) = space(SiteType("S=1/2"); kwargs...)

val(name::ValName, ::SiteType"S=½") = val(name, SiteType("S=1/2"))

state(::SiteType"S=½", n::StateName) = state(SiteType("S=1/2"), n)

op(o::OpName, ::SiteType"S=½"; kwargs...) = op(o, SiteType("S=1/2"); kwargs...)
45 changes: 35 additions & 10 deletions src/physics/site_types/spinone.jl
Expand Up @@ -18,13 +18,37 @@ function space(
return 3
end

state(::SiteType"S=1", ::StateName"Up") = 1
state(::SiteType"S=1", ::StateName"Z0") = 2
state(::SiteType"S=1", ::StateName"Dn") = 3
val(::ValName"Up", ::SiteType"S=1") = 1
val(::ValName"Z0", ::SiteType"S=1") = 2
val(::ValName"Dn", ::SiteType"S=1") = 3

state(st::SiteType"S=1", ::StateName"↑") = state(st, StateName("Up"))
state(st::SiteType"S=1", ::StateName"0") = state(st, StateName("Z0"))
state(st::SiteType"S=1", ::StateName"↓") = state(st, StateName("Dn"))
val(::ValName"↑", st::SiteType"S=1") = 1
val(::ValName"0", st::SiteType"S=1") = 2
val(::ValName"↓", st::SiteType"S=1") = 3

val(::ValName"Z+", ::SiteType"S=1") = 1
# -- Z0 is already defined above --
val(::ValName"Z-", ::SiteType"S=1") = 3

state(::StateName"Up", ::SiteType"S=1") = [1.0, 0.0, 0.0]
state(::StateName"Z0", ::SiteType"S=1") = [0.0, 1.0, 0.0]
state(::StateName"Dn", ::SiteType"S=1") = [0.0, 0.0, 1.0]

state(::StateName"↑", st::SiteType"S=1") = [1.0, 0.0, 0.0]
state(::StateName"0", st::SiteType"S=1") = [0.0, 1.0, 0.0]
state(::StateName"↓", st::SiteType"S=1") = [0.0, 0.0, 1.0]

state(::StateName"Z+", st::SiteType"S=1") = [1.0, 0.0, 0.0]
# -- Z0 is already defined above --
state(::StateName"Z-", st::SiteType"S=1") = [0.0, 0.0, 1.0]

state(::StateName"X+", ::SiteType"S=1") = [1 / 2, 1 / sqrt(2), 1 / 2]
state(::StateName"X0", ::SiteType"S=1") = [-1 / sqrt(2), 0, 1 / sqrt(2)]
state(::StateName"X-", ::SiteType"S=1") = [1 / 2, -1 / sqrt(2), 1 / 2]

state(::StateName"Y+", ::SiteType"S=1") = [-1 / 2, -im / sqrt(2), 1 / 2]
state(::StateName"Y0", ::SiteType"S=1") = [1 / sqrt(2), 0, 1 / sqrt(2)]
state(::StateName"Y-", ::SiteType"S=1") = [-1 / 2, im / sqrt(2), 1 / 2]

function op!(Op::ITensor, ::OpName"Sz", ::SiteType"S=1", s::Index)
Op[s' => 1, s => 1] = +1.0
Expand Down Expand Up @@ -70,10 +94,11 @@ end
op!(Op::ITensor, ::OpName"iSʸ", t::SiteType"S=1", s::Index) = op!(Op, OpName("iSy"), t, s)

function op!(Op::ITensor, ::OpName"Sy", ::SiteType"S=1", s::Index)
Op[s' => 2, s => 1] = -1im / sqrt(2)
Op[s' => 1, s => 2] = +1im / sqrt(2)
Op[s' => 3, s => 2] = -1im / sqrt(2)
return Op[s' => 2, s => 3] = +1im / sqrt(2)
complex!(Op)
Op[s' => 2, s => 1] = +1im / sqrt(2)
Op[s' => 1, s => 2] = -1im / sqrt(2)
Op[s' => 3, s => 2] = +1im / sqrt(2)
return Op[s' => 2, s => 3] = -1im / sqrt(2)
end

op!(Op::ITensor, ::OpName"Sʸ", t::SiteType"S=1", s::Index) = op!(Op, OpName("Sy"), t, s)
Expand Down
19 changes: 13 additions & 6 deletions src/physics/site_types/tj.jl
Expand Up @@ -54,12 +54,19 @@ function space(
return 3
end

state(::SiteType"tJ", ::StateName"Emp") = 1
state(::SiteType"tJ", ::StateName"Up") = 2
state(::SiteType"tJ", ::StateName"Dn") = 3
state(st::SiteType"tJ", ::StateName"0") = state(st, StateName("Emp"))
state(st::SiteType"tJ", ::StateName"↑") = state(st, StateName("Up"))
state(st::SiteType"tJ", ::StateName"↓") = state(st, StateName("Dn"))
val(::ValName"Emp", ::SiteType"tJ") = 1
val(::ValName"Up", ::SiteType"tJ") = 2
val(::ValName"Dn", ::SiteType"tJ") = 3
val(::ValName"0", st::SiteType"tJ") = val(ValName("Emp"), st)
val(::ValName"↑", st::SiteType"tJ") = val(ValName("Up"), st)
val(::ValName"↓", st::SiteType"tJ") = val(ValName("Dn"), st)

state(::StateName"Emp", ::SiteType"tJ") = [1.0, 0, 0]
state(::StateName"Up", ::SiteType"tJ") = [0.0, 1, 0]
state(::StateName"Dn", ::SiteType"tJ") = [0.0, 0, 1]
state(::StateName"0", st::SiteType"tJ") = state(StateName("Emp"), st)
state(::StateName"↑", st::SiteType"tJ") = state(StateName("Up"), st)
state(::StateName"↓", st::SiteType"tJ") = state(StateName("Dn"), st)

function op!(Op::ITensor, ::OpName"Nup", ::SiteType"tJ", s::Index)
return Op[s' => 2, s => 2] = 1.0
Expand Down
75 changes: 64 additions & 11 deletions src/physics/sitetype.jl
Expand Up @@ -425,34 +425,87 @@ macro StateName_str(s)
return StateName{SmallString(s)}
end

state(::SiteType, ::StateName) = nothing
state(::SiteType, ::AbstractString) = nothing
state(::StateName, ::SiteType, ::Index) = nothing
state!(::ITensor, ::StateName, ::SiteType, ::Index) = nothing

function state(s::Index, name::AbstractString)::IndexVal
function state(s::Index, name::AbstractString; kwargs...)::ITensor
stypes = _sitetypes(s)
sname = StateName(name)

# Try calling state(::SiteType"Tag",::StateName"Name")
# Try calling state(::StateName"Name",::SiteType"Tag",s::Index)
for st in stypes
res = state(st, sname)
!isnothing(res) && return s(res)
res = state(sname, st, s; kwargs...)
!isnothing(res) && return res
end

# Try calling state(::SiteType"Tag","Name")
# Try calling state!(::ITensor,::StateName"Name",::SiteType"Tag",s::Index)
T = emptyITensor(s)
for st in stypes
res = state(st, name)
!isnothing(res) && return s(res)
state!(T, sname, st, s)
!isempty(T) && return T
end

#
# otherwise try calling a function of the form:
# state(::StateName"Name", ::SiteType"Tag"; kwargs...)
# which returns a Julia vector
#
for st in stypes
v = state(sname, st)
!isnothing(v) && return itensor(v, s)
end

return throw(
ArgumentError("Overload of \"state\" function not found for Index tags $(tags(s))")
ArgumentError(
"Overload of \"state\" or \"state!\" functions not found for state name \"$name\" and Index tags $(tags(s))",
),
)
end

state(s::Index, n::Integer) = s[n]
state(s::Index, n::Integer) = onehot(s => n)

state(sset::Vector{<:Index}, j::Integer, st) = state(sset[j], st)

#---------------------------------------
#
# val system
#
#---------------------------------------

@eval struct ValName{Name}
(f::Type{<:ValName})() = $(Expr(:new, :f))
end

ValName(s::AbstractString) = ValName{SmallString(s)}()
ValName(s::SmallString) = ValName{s}()
name(::ValName{N}) where {N} = N

macro ValName_str(s)
return ValName{SmallString(s)}
end

val(::ValName, ::SiteType) = nothing
val(::AbstractString, ::SiteType) = nothing

function val(s::Index, name::AbstractString)::IndexVal
stypes = _sitetypes(s)
sname = ValName(name)

# Try calling val(::StateName"Name",::SiteType"Tag",)
for st in stypes
res = val(sname, st)
!isnothing(res) && return s(res)
end

return throw(
ArgumentError("Overload of \"val\" function not found for Index tags $(tags(s))")
)
end

val(s::Index, n::Integer) = s[n]

val(sset::Vector{<:Index}, j::Integer, st) = val(sset[j], st)

#---------------------------------------
#
# siteind system
Expand Down
6 changes: 3 additions & 3 deletions test/mps.jl
Expand Up @@ -133,7 +133,7 @@ include("util.jl")
for j in 1:N
states[j] = isodd(j) ? 1 : 2
end
ivals = [state(sites[n], states[n]) for n in 1:length(sites)]
ivals = [val(sites[n], states[n]) for n in 1:length(sites)]
psi = MPS(ivals)
for j in 1:N
sign = isodd(j) ? +1.0 : -1.0
Expand Down Expand Up @@ -950,8 +950,8 @@ end
CSWAP = [op("CSWAP", s, n, m, k) for n in 1:N, m in 1:N, k in 1:N]
CCCNOT = [op("CCCNOT", s, n, m, k, l) for n in 1:N, m in 1:N, k in 1:N, l in 1:N]

v0 = [setelt(state(s, n, "0")) for n in 1:N]
v1 = [setelt(state(s, n, "1")) for n in 1:N]
v0 = [setelt(val(s, n, "0")) for n in 1:N]
v1 = [setelt(val(s, n, "1")) for n in 1:N]

# Single qubit
@test product(I[1], v0[1]) ≈ v0[1]
Expand Down