Skip to content
This repository has been archived by the owner on Dec 18, 2021. It is now read-only.

Commit

Permalink
fix rand_state type issue and make zygote easier (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
GiggleLiu authored and Roger-luo committed May 26, 2019
1 parent 7f6f55a commit e0c59ca
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 11 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@

docs/build/
docs/site/
docs/src/examples/*.md
docs/src/examples/*.md
*.swp
21 changes: 13 additions & 8 deletions src/instruct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ function YaoBase.instruct!(
locs::NTuple{M, Int},
control_locs::NTuple{C, Int}=(),
control_bits::NTuple{C, Int}=()) where {T1, T2, M, C}

@warn "Element Type Mismatch: register $(T1), operator $(T2). Converting operator to match, this may cause performance issue"
return instruct!(state, copyto!(similar(operator, T1), operator), locs, control_locs, control_bits)
end
Expand All @@ -53,6 +53,17 @@ function YaoBase.instruct!(state::AbstractVecOrMat{T1}, U1::SDDiagonal{T2}, loc:
return instruct!(state, copyto!(similar(U1, T1), U1), loc)
end

function _prepare_instruct(state, U, locs::NTuple{M}, control_locs, control_bits::NTuple{C}) where {M, C}
N, MM = log2dim1(state), size(U, 1)

locked_bits = MVector(control_locs..., locs...)
locked_vals = MVector(control_bits..., (0 for k in 1:M)...)
locs_raw_it = (b+1 for b in itercontrol(N, setdiff(1:N, locs), zeros(Int, N-M)))
locs_raw = SVector(locs_raw_it...)
ic = itercontrol(N, locked_bits, locked_vals)
return locs_raw, ic
end

function YaoBase.instruct!(
state::AbstractVecOrMat{T},
operator::AbstractMatrix{T},
Expand All @@ -70,13 +81,7 @@ function YaoBase.instruct!(
control_bits::NTuple{C, Int} = ()) where {T, M, C}

U = sort_unitary(operator, locs)
N, MM = log2dim1(state), size(U, 1)

locked_bits = MVector(control_locs..., locs...)
locked_vals = MVector(control_bits..., (0 for k in 1:M)...)
locs_raw_it = (b+1 for b in itercontrol(N, setdiff(1:N, locs), zeros(Int, N-M)))
locs_raw = SVector(locs_raw_it...)
ic = itercontrol(N, locked_bits, locked_vals)
locs_raw, ic = _prepare_instruct(state, U, locs, control_locs, control_bits)

return _instruct!(state, autostatic(U), locs_raw, ic)
end
Expand Down
2 changes: 1 addition & 1 deletion src/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ for op in [:(==), :≈]
end
end

Base.:*(bra::AdjointArrayReg{1}, ket::ArrayReg{1}) = dot(parent(bra).state, state(ket))
Base.:*(bra::AdjointArrayReg{1}, ket::ArrayReg{1}) = dot(state(parent(bra)), state(ket))
Base.:*(bra::AdjointArrayReg{B}, ket::ArrayReg{B}) where B = bra .* ket

# broadcast
Expand Down
2 changes: 1 addition & 1 deletion src/register.jl
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ ArrayReg{2,Complex{Float64},Array...}
rand_state(n::Int; nbatch::Int=1) = rand_state(ComplexF64, n; nbatch=nbatch)

function rand_state(::Type{T}, n::Int; nbatch::Int=1) where T
raw = randn(T, 1<<n, nbatch) + im * randn(T, 1<<n, nbatch)
raw = randn(T, 1<<n, nbatch)
return normalize!(ArrayReg{nbatch}(raw))
end

Expand Down
4 changes: 4 additions & 0 deletions test/register.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,22 @@ end
for k in 1:3
@test st[:, k] onehot(T, 4, 0)
end
@test eltype(product_state(Float64, 4, 0).state) == Float64
end
@testset "test zero state" begin
st = state(zero_state(T, 4; nbatch=4))
for k in 1:4
@test st[:, k] onehot(T, 4, 0)
end
@test eltype(zero_state(Float64, 4).state) == Float64
end
@testset "test rand state" begin
# NOTE: we only check if the state is normalized
st = state(rand_state(T, 4, nbatch=2))
for k in 1:2
@test norm(st[:, k]) 1.0
end
@test eltype(rand_state(Float64, 4).state) == Float64
end
@testset "test uniform state" begin
st = state(uniform_state(T, 4; nbatch=2))
Expand All @@ -49,6 +52,7 @@ end
@test each 1/sqrt(16)
end
end
@test eltype(uniform_state(Float64, 4).state) == Float64
end
@testset "test oneto" begin
r1 = uniform_state(ComplexF64, 4)
Expand Down

0 comments on commit e0c59ca

Please sign in to comment.