Skip to content
This repository has been archived by the owner on Dec 18, 2021. It is now read-only.

Commit

Permalink
add support for symbolic
Browse files Browse the repository at this point in the history
  • Loading branch information
Roger-luo committed Sep 24, 2019
1 parent 7cdb707 commit 6f69dc4
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 11 deletions.
18 changes: 9 additions & 9 deletions src/instruct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,31 +213,31 @@ end
# Specialized
import YaoBase: rot_mat

rot_mat(::Type{T}, ::Val{:Rx}, theta::Real) where T =
rot_mat(::Type{T}, ::Val{:Rx}, theta::Number) where T =
T[cos(theta/2) -im * sin(theta/2); -im * sin(theta/2) cos(theta/2)]
rot_mat(::Type{T}, ::Val{:Ry}, theta::Real) where T =
rot_mat(::Type{T}, ::Val{:Ry}, theta::Number) where T =
T[cos(theta/2) -sin(theta/2); sin(theta/2) cos(theta/2)]
rot_mat(::Type{T}, ::Val{:Rz}, theta::Real) where T =
rot_mat(::Type{T}, ::Val{:Rz}, theta::Number) where T =
Diagonal(T[exp(-im*theta/2), exp(im*theta/2)])
rot_mat(::Type{T}, ::Val{:CPHASE}, theta::Real) where T =
rot_mat(::Type{T}, ::Val{:CPHASE}, theta::Number) where T =
Diagonal(T[1, 1, 1, exp(im*theta)])
rot_mat(::Type{T}, ::Val{:PSWAP}, theta::Real) where T =
rot_mat(::Type{T}, ::Val{:PSWAP}, theta::Number) where T =
rot_mat(T, Const.SWAP, theta)

for G in [:Rx, :Ry, :Rz, :CPHASE]
# forward single gates
@eval function YaoBase.instruct!(state::AbstractVecOrMat{T}, g::Val{$(QuoteNode(G))},
locs::Union{Int, NTuple{N3,Int}},
control_locs::NTuple{N1, Int},
control_bits::NTuple{N2, Int}, theta::Real) where {T, N1, N2, N3}
control_bits::NTuple{N2, Int}, theta::Number) where {T, N1, N2, N3}
m = rot_mat(T, g, theta)
instruct!(state, m, locs, control_locs, control_bits)
end
end

# forward single gates
@eval function YaoBase.instruct!(state::AbstractVecOrMat{T}, g::Val,
locs::Union{Int, NTuple{N1, Int}}, theta::Real) where {T, N1}
locs::Union{Int, NTuple{N1, Int}}, theta::Number) where {T, N1}
instruct!(state, g, locs, (), (), theta)
end

Expand Down Expand Up @@ -402,7 +402,7 @@ function YaoBase.instruct!(
state::AbstractVecOrMat{T},
::Val{:PSWAP},
locs::Tuple{Int, Int},
theta::Real) where T
theta::Number) where T
mask1 = bmask(locs[1])
mask2 = bmask(locs[2])
mask12 = mask1|mask2
Expand Down Expand Up @@ -430,7 +430,7 @@ function YaoBase.instruct!(
locs::Tuple{Int, Int},
control_locs::NTuple{C, Int},
control_bits::NTuple{C, Int},
theta::Real) where {T, C}
theta::Number) where {T, C}
mask1 = bmask(locs[1])
mask2 = bmask(locs[2])
mask12 = mask1|mask2
Expand Down
21 changes: 20 additions & 1 deletion src/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ export isnormalized,
Check if the register is normalized.
"""
isnormalized(r::ArrayReg) = all(sum(copy(r) |> relax!(to_nactive=nqubits(r)) |> probs, dims=1) .≈ 1)
isnormalized(r::AdjointArrayReg) = isnormalized(parent(r))

"""
normalize!(r::ArrayReg)
Expand All @@ -22,16 +23,24 @@ function LinearAlgebra.normalize!(r::ArrayReg{B}) where B
return r
end

LinearAlgebra.normalize!(r::AdjointArrayReg) = (normalize!(parent(r)); r)

# basic arithmatics

# neg
Base.:-(reg::ArrayReg) = ArrayReg(-state(reg))
Base.:-(reg::AdjointArrayReg) = adjoint(-parent(reg))

# +, -
for op in [:+, :-]
@eval function Base.$op(lhs::ArrayReg{B}, rhs::ArrayReg{B}) where B
return ArrayReg(($op)(state(lhs), state(rhs)))
end

@eval function Base.$op(lhs::AdjointArrayReg{B}, rhs::AdjointArrayReg{B}) where B
r = $op(parent(lhs), parent(rhs))
return adjoint(r)
end
end

# *, /
Expand All @@ -40,9 +49,19 @@ for op in [:*, :/]
ArrayReg{B}($op(state(lhs), rhs))
end

@eval function Base.$op(lhs::RT, rhs::Number) where {B, RT <: AdjointArrayReg{B}}
r = $op(parent(lhs), rhs')
return adjoint(r)
end

if op == :*
@eval function Base.$op(lhs::Number, rhs::RT) where {B, RT <: ArrayReg{B}}
ArrayReg{B}(($op)(lhs, state(rhs)))
ArrayReg{B}(lhs * state(rhs))
end

@eval function Base.$op(lhs::Number, rhs::RT) where {B, RT <: AdjointArrayReg{B}}
r = lhs' * parent(rhs)
return adjoint(r)
end
end
end
Expand Down
5 changes: 4 additions & 1 deletion test/operations.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Test, YaoArrayRegister, Random, LinearAlgebra, SparseArrays
using Test
using YaoArrayRegister, Random, LinearAlgebra, SparseArrays, BitBasis


@testset "broadcast register" begin
Expand All @@ -15,9 +16,11 @@ end
@test reg1!=reg2
@test statevec(reg2) == onehot(ComplexF64, nbit, 4)
reg3 = reg1 + reg2

@test statevec(reg3) == onehot(ComplexF64, nbit, 4) + onehot(ComplexF64, nbit, 0)
@test statevec(reg3 |> normalize!) == (onehot(ComplexF64, nbit, 4) + onehot(ComplexF64, nbit, 0))/sqrt(2)
@test (reg1 + reg2 - reg1) == reg2
@test reg1' + reg2' - reg1' == reg2'

reg = rand_state(4)
@test all(state(reg + (-reg)).==0)
Expand Down

0 comments on commit 6f69dc4

Please sign in to comment.