From 03f03c8d78783e9a7788a650abf312a15cfb08b7 Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Tue, 10 Dec 2019 01:10:37 +0800 Subject: [PATCH] measure! now returns scalar to B=1 --- src/measure.jl | 7 ++++--- test/measure.jl | 11 ++++++++++- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/measure.jl b/src/measure.jl index a364187..fc8605d 100644 --- a/src/measure.jl +++ b/src/measure.jl @@ -8,7 +8,7 @@ end function _measure(rng::AbstractRNG, pl::AbstractMatrix, nshots::Int) B = size(pl, 2) - res = Matrix{BitStr64{log2i(length(pl))}}(undef, nshots, B) + res = Matrix{BitStr64{log2i(size(pl, 1))}}(undef, nshots, B) for ib in 1:B @inbounds res[:, ib] = _measure(rng, view(pl, :, ib), nshots) end @@ -52,7 +52,7 @@ function YaoBase.measure!( res[ib] = ires end reg.state = reshape(nstate, 1, :) - return res + return B==1 ? res[] : res end function YaoBase.measure!( @@ -66,8 +66,9 @@ function YaoBase.measure!( nstate = zero(state) res = measure!(RemoveMeasured(), reg; rng=rng) _nstate = reshape(reg.state, :, B) + indices = Int64.(res) .+ 1 for ib in 1:B - @inbounds nstate[Int64(res[ib])+1, :, ib] .= view(_nstate, :, ib) + @inbounds nstate[indices[ib], :, ib] .= view(_nstate, :, ib) end reg.state = reshape(nstate, size(state, 1), :) return res diff --git a/test/measure.jl b/test/measure.jl index 6c93288..8b93618 100644 --- a/test/measure.jl +++ b/test/measure.jl @@ -18,6 +18,7 @@ end @test isnormalized(reg) result = measure(reg; nshots = 10) @test all(result .< 8) + @test ndims(res) == 0 reg = rand_state(6) |> focus!(1, 4, 3) reg0 = copy(reg) @@ -25,6 +26,7 @@ end @test nqubits(reg) == 3 select(reg0, res) @test select(reg0, res) |> normalize! ≈ reg + @test ndims(res) == 0 r = rand_state(10) r1 = copy(r) |> focus!(1, 4, 3) @@ -32,10 +34,12 @@ end r2 = select(r1, res) r2 = relax!(r2, (); to_nactive = nqubits(r2)) @test normalize!(r2) ≈ r + @test ndims(res) == 0 reg = rand_state(6, nbatch = 5) |> focus!((1:5)...) - measure!(YaoBase.ResetTo(0), reg, 1) + res = measure!(YaoBase.ResetTo(0), reg, 1) @test nactive(reg) == 5 + @test ndims(res) == 1 end @testset "fix measure kwargs error" begin @@ -44,3 +48,8 @@ end @test_throws MethodError measure!(r; nshots=10) @test_throws MethodError measure!(YaoBase.RemoveMeasured(), r; nshots=10) end + +@testset "fix measure output type error" begin + res = measure(rand_state(1;nbatch=10)) + @test res isa Matrix{BitStr64{1}} +end