Skip to content
This repository has been archived by the owner on Dec 18, 2021. It is now read-only.

add support for symbolic #31

Merged
merged 2 commits into from
Sep 24, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions src/instruct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,31 +213,31 @@ end
# Specialized
import YaoBase: rot_mat

rot_mat(::Type{T}, ::Val{:Rx}, theta::Real) where T =
rot_mat(::Type{T}, ::Val{:Rx}, theta::Number) where T =
T[cos(theta/2) -im * sin(theta/2); -im * sin(theta/2) cos(theta/2)]
rot_mat(::Type{T}, ::Val{:Ry}, theta::Real) where T =
rot_mat(::Type{T}, ::Val{:Ry}, theta::Number) where T =
T[cos(theta/2) -sin(theta/2); sin(theta/2) cos(theta/2)]
rot_mat(::Type{T}, ::Val{:Rz}, theta::Real) where T =
rot_mat(::Type{T}, ::Val{:Rz}, theta::Number) where T =
Diagonal(T[exp(-im*theta/2), exp(im*theta/2)])
rot_mat(::Type{T}, ::Val{:CPHASE}, theta::Real) where T =
rot_mat(::Type{T}, ::Val{:CPHASE}, theta::Number) where T =
Diagonal(T[1, 1, 1, exp(im*theta)])
rot_mat(::Type{T}, ::Val{:PSWAP}, theta::Real) where T =
rot_mat(::Type{T}, ::Val{:PSWAP}, theta::Number) where T =
rot_mat(T, Const.SWAP, theta)

for G in [:Rx, :Ry, :Rz, :CPHASE]
# forward single gates
@eval function YaoBase.instruct!(state::AbstractVecOrMat{T}, g::Val{$(QuoteNode(G))},
locs::Union{Int, NTuple{N3,Int}},
control_locs::NTuple{N1, Int},
control_bits::NTuple{N2, Int}, theta::Real) where {T, N1, N2, N3}
control_bits::NTuple{N2, Int}, theta::Number) where {T, N1, N2, N3}
m = rot_mat(T, g, theta)
instruct!(state, m, locs, control_locs, control_bits)
end
end

# forward single gates
@eval function YaoBase.instruct!(state::AbstractVecOrMat{T}, g::Val,
locs::Union{Int, NTuple{N1, Int}}, theta::Real) where {T, N1}
locs::Union{Int, NTuple{N1, Int}}, theta::Number) where {T, N1}
instruct!(state, g, locs, (), (), theta)
end

Expand Down Expand Up @@ -402,7 +402,7 @@ function YaoBase.instruct!(
state::AbstractVecOrMat{T},
::Val{:PSWAP},
locs::Tuple{Int, Int},
theta::Real) where T
theta::Number) where T
mask1 = bmask(locs[1])
mask2 = bmask(locs[2])
mask12 = mask1|mask2
Expand Down Expand Up @@ -430,7 +430,7 @@ function YaoBase.instruct!(
locs::Tuple{Int, Int},
control_locs::NTuple{C, Int},
control_bits::NTuple{C, Int},
theta::Real) where {T, C}
theta::Number) where {T, C}
mask1 = bmask(locs[1])
mask2 = bmask(locs[2])
mask12 = mask1|mask2
Expand Down
21 changes: 20 additions & 1 deletion src/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ export isnormalized,
Check if the register is normalized.
"""
isnormalized(r::ArrayReg) = all(sum(copy(r) |> relax!(to_nactive=nqubits(r)) |> probs, dims=1) .≈ 1)
isnormalized(r::AdjointArrayReg) = isnormalized(parent(r))

"""
normalize!(r::ArrayReg)
Expand All @@ -22,16 +23,24 @@ function LinearAlgebra.normalize!(r::ArrayReg{B}) where B
return r
end

LinearAlgebra.normalize!(r::AdjointArrayReg) = (normalize!(parent(r)); r)

# basic arithmatics

# neg
Base.:-(reg::ArrayReg) = ArrayReg(-state(reg))
Base.:-(reg::AdjointArrayReg) = adjoint(-parent(reg))

# +, -
for op in [:+, :-]
@eval function Base.$op(lhs::ArrayReg{B}, rhs::ArrayReg{B}) where B
return ArrayReg(($op)(state(lhs), state(rhs)))
end

@eval function Base.$op(lhs::AdjointArrayReg{B}, rhs::AdjointArrayReg{B}) where B
r = $op(parent(lhs), parent(rhs))
return adjoint(r)
end
end

# *, /
Expand All @@ -40,9 +49,19 @@ for op in [:*, :/]
ArrayReg{B}($op(state(lhs), rhs))
end

@eval function Base.$op(lhs::RT, rhs::Number) where {B, RT <: AdjointArrayReg{B}}
r = $op(parent(lhs), rhs')
return adjoint(r)
end

if op == :*
@eval function Base.$op(lhs::Number, rhs::RT) where {B, RT <: ArrayReg{B}}
ArrayReg{B}(($op)(lhs, state(rhs)))
ArrayReg{B}(lhs * state(rhs))
end

@eval function Base.$op(lhs::Number, rhs::RT) where {B, RT <: AdjointArrayReg{B}}
r = lhs' * parent(rhs)
return adjoint(r)
end
end
end
Expand Down
18 changes: 16 additions & 2 deletions test/operations.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Test, YaoArrayRegister, Random, LinearAlgebra, SparseArrays
using Test
using YaoArrayRegister, Random, LinearAlgebra, SparseArrays, BitBasis


@testset "broadcast register" begin
Expand All @@ -8,21 +9,34 @@ using Test, YaoArrayRegister, Random, LinearAlgebra, SparseArrays
@test [rand_state(3)...] |> length == 1
end

@testset "Math Operations" begin
@testset "arithmetics" begin
nbit = 5
reg1 = zero_state(5)
reg2 = ArrayReg(bit"00100")
@test reg1!=reg2
@test statevec(reg2) == onehot(ComplexF64, nbit, 4)
reg3 = reg1 + reg2
reg4 = (reg1 + reg2)'

@test statevec(reg3) == onehot(ComplexF64, nbit, 4) + onehot(ComplexF64, nbit, 0)
@test statevec(reg3 |> normalize!) == (onehot(ComplexF64, nbit, 4) + onehot(ComplexF64, nbit, 0))/sqrt(2)
@test statevec(reg4 |> normalize!) == (onehot(ComplexF64, nbit, 4) + onehot(ComplexF64, nbit, 0))'/sqrt(2)
@test (reg1 + reg2 - reg1) == reg2
@test reg1' + reg2' - reg1' == reg2'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please also add tests for Number * reg', reg' * number, isnormalized(reg'), normalize!(reg'), -reg'

@test isnormalized(reg4)
@test isnormalized(reg3)

@test statevec(-reg4) == - statevec(reg4)
@test statevec(-reg3) == - statevec(reg3)
reg = rand_state(4)
@test all(state(reg + (-reg)).==0)
@test all(state(reg*2 - reg/0.5) .== 0)

reg = rand_state(3)
@test reg'*reg ≈ 1

@test state(reg1 * 2) == state(reg1) * 2
@test state(reg1' * 2) == state(reg1') * 2
@test reg1 * 2 == 2 * reg1
@test reg1' * 2 == 2 * reg1'
end