Skip to content
This repository has been archived by the owner on Dec 18, 2021. It is now read-only.

Commit

Permalink
Fix measure dispatch (#42)
Browse files Browse the repository at this point in the history
* fix measure dispatch
  • Loading branch information
GiggleLiu committed Dec 7, 2019
1 parent e025f29 commit 8b2b068
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 21 deletions.
2 changes: 1 addition & 1 deletion Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[[YaoBase]]
deps = ["BitBasis", "LegibleLambdas", "LinearAlgebra", "LuxurySparse", "MLStyle", "MacroTools", "Random", "SparseArrays", "Test", "TupleTools"]
git-tree-sha1 = "78ac5318105d091d1f7fe447309cead9c8206c16"
git-tree-sha1 = "2eb8db378629feeed892aa00fa279b59dbd4b7ca"
repo-rev = "master"
repo-url = "https://github.com/QuantumBFS/YaoBase.jl.git"
uuid = "a8f54c17-34bc-5a9d-b050-f522fe3f755f"
Expand Down
30 changes: 16 additions & 14 deletions src/measure.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using StatsBase, StaticArrays, BitBasis, Random
export measure, measure!, measure_remove!, measure_resetto!, select, select!
export measure, measure!, select, select!

function _measure(rng::AbstractRNG, pl::AbstractVector, nshots::Int)
N = log2i(length(pl))
Expand All @@ -16,29 +16,30 @@ function _measure(rng::AbstractRNG, pl::AbstractMatrix, nshots::Int)
end

YaoBase.measure(
rng::AbstractRNG,
::ComputationalBasis,
reg::ArrayReg{1},
::AllLocs;
nshots::Int = 1,
rng::AbstractRNG=Random.GLOBAL_RNG
) = _measure(rng, reg |> probs, nshots)

function YaoBase.measure(
rng::AbstractRNG,
::ComputationalBasis,
reg::ArrayReg{B},
::AllLocs;
nshots::Int = 1,
rng::AbstractRNG=Random.GLOBAL_RNG
) where {B}
pl = dropdims(sum(reg |> rank3 .|> abs2, dims = 2), dims = 2)
return _measure(rng, pl, nshots)
end

function YaoBase.measure_remove!(
rng::AbstractRNG,
function YaoBase.measure!(
::YaoBase.RemoveMeasured,
::ComputationalBasis,
reg::ArrayReg{B},
::AllLocs,
::AllLocs;
rng::AbstractRNG=Random.GLOBAL_RNG
) where {B}
state = reg |> rank3
nstate = similar(reg.state, 1 << nremain(reg), B)
Expand All @@ -55,14 +56,15 @@ function YaoBase.measure_remove!(
end

function YaoBase.measure!(
rng::AbstractRNG,
::YaoBase.NoPostProcess,
::ComputationalBasis,
reg::ArrayReg{B},
::AllLocs,
::AllLocs;
rng::AbstractRNG=Random.GLOBAL_RNG
) where {B}
state = reg |> rank3
nstate = zero(state)
res = measure_remove!(rng, reg)
res = measure!(RemoveMeasured(), reg; rng=rng)
_nstate = reshape(reg.state, :, B)
for ib in 1:B
@inbounds nstate[Int64(res[ib])+1, :, ib] .= view(_nstate, :, ib)
Expand All @@ -71,18 +73,18 @@ function YaoBase.measure!(
return res
end

function YaoBase.measure_resetto!(
rng::AbstractRNG,
function YaoBase.measure!(
rst::YaoBase.ResetTo,
::ComputationalBasis,
reg::ArrayReg{B},
::AllLocs;
config::Integer = 0,
rng::AbstractRNG=Random.GLOBAL_RNG
) where {B}
state = rank3(reg)
M, N, B1 = size(state)
nstate = zero(state)
res = measure_remove!(rng, reg)
nstate[Int(config)+1, :, :] = reshape(reg.state, :, B)
res = measure!(YaoBase.RemoveMeasured(), reg; rng=rng)
nstate[Int(rst.x)+1, :, :] = reshape(reg.state, :, B)
reg.state = reshape(nstate, M, N * B)
return res
end
Expand Down
1 change: 0 additions & 1 deletion src/register.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ export ArrayReg,
# YaoBase deprecated
addbit!,
reset!,
measure_reset!,
# additional
state,
statevec,
Expand Down
2 changes: 1 addition & 1 deletion test/focus.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ end
@test reg |> nactive == 3
@test copy(reg) |> addbits!(2) |> nactive == 5
reg2 = copy(reg) |> addbits!(2) |> focus!(4, 5)
@test (reg2 |> measure_remove!; reg2) |> relax!(to_nactive = nqubits(reg2)) reg
@test (measure!(RemoveMeasured(), reg2); reg2) |> relax!(to_nactive = nqubits(reg2)) reg

@test insert_qubits!(copy(reg), 2; nqubits = 2) |> nactive == 5
end
Expand Down
16 changes: 12 additions & 4 deletions test/measure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,33 @@ end

@testset "measure and resetto/remove" begin
reg = rand_state(4)
res = measure_resetto!(reg, (4,))
res = measure!(YaoBase.ResetTo(0), reg, (4,))
@test isnormalized(reg)
result = measure(reg; nshots = 10)
@test all(result .< 8)

reg = rand_state(6) |> focus!(1, 4, 3)
reg0 = copy(reg)
res = measure_remove!(reg)
res = measure!(YaoBase.RemoveMeasured(), reg)
@test nqubits(reg) == 3
select(reg0, res)
@test select(reg0, res) |> normalize! reg

r = rand_state(10)
r1 = copy(r) |> focus!(1, 4, 3)
res = measure_remove!(r, (1, 4, 3))
res = measure!(YaoBase.RemoveMeasured(), r, (1, 4, 3))
r2 = select(r1, res)
r2 = relax!(r2, (); to_nactive = nqubits(r2))
@test normalize!(r2) r

reg = rand_state(6, nbatch = 5) |> focus!((1:5)...)
measure_resetto!(reg, 1)
measure!(YaoBase.ResetTo(0), reg, 1)
@test nactive(reg) == 5
end

@testset "fix measure kwargs error" begin
r = rand_state(10)
@test length(measure(r; nshots=10)) == 10
@test_throws MethodError measure!(r; nshots=10)
@test_throws MethodError measure!(YaoBase.RemoveMeasured(), r; nshots=10)
end

0 comments on commit 8b2b068

Please sign in to comment.