Skip to content

Commit

Permalink
add Rotbasis (#107)
Browse files Browse the repository at this point in the history
* new general control

* add itercontrol

* new PutBlock

* new put block, more general control

* benchmark crot toffoli

* benchmark crot toffoli

* updatebenchmarks

* new h benchmark

* new benchmark results

* statify a sparse array

* fix test failures

* new rotbasis

* new general matrix gate

* add rotbasis & fix parameters

* fix rotbasis
  • Loading branch information
Roger-luo committed Jun 29, 2018
1 parent ef0206c commit 89894a5
Show file tree
Hide file tree
Showing 24 changed files with 736 additions and 345 deletions.
752 changes: 436 additions & 316 deletions examples/QCBM.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/Blocks/Blocks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using Compat.Iterators
using Compat.LinearAlgebra
using Compat.SparseArrays
using Lazy: @forward
using DataStructures

using ..Intrinsics
using ..Registers
Expand Down
12 changes: 2 additions & 10 deletions src/Blocks/Composite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,7 @@ function dispatch!(x::CompositeBlock, itr)
count = 0
for block in Iterators.filter(x->nparameters(x) > 0, blocks(x))
params = view(itr, count+1:count+nparameters(block))
if block isa CompositeBlock
dispatch!(block, params)
else
dispatch!(block, params...)
end
dispatch!(block, params)
count += nparameters(block)
end
x
Expand All @@ -94,11 +90,7 @@ function dispatch!(f::Function, x::CompositeBlock, itr)
count = 0
for block in Iterators.filter(x->nparameters(x) > 0, blocks(x))
params = view(itr, count+1:count+nparameters(block))
if block isa CompositeBlock
dispatch!(f, block, params)
else
dispatch!(f, block, params...)
end
dispatch!(f, block, params)
count += nparameters(block)
end
x
Expand Down
19 changes: 19 additions & 0 deletions src/Blocks/GeneralMatrixGate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
export GeneralMatrixGate

mutable struct GeneralMatrixGate{M, N, T} <: PrimitiveBlock{N, T}
matrix :: AbstractMatrix{T}
function GeneralMatrixGate{M, N, T}(matrix::AbstractMatrix{T}) where {M, N, T}
(1<<M == size(matrix, 1) && 1<<N == size(matrix, 2)) || throw(DimensionMismatch("Dimension of input matrix shape error."))
new{M, N, T}(matrix)
end
end
GeneralMatrixGate(matrix::AbstractMatrix{T}) where T = GeneralMatrixGate{log2i(size(matrix, 1)), log2i(size(matrix, 2)), T}(matrix)

==(A::GeneralMatrixGate, B::GeneralMatrixGate) = A.matrix == B.matrix
copy(r::GeneralMatrixGate) = GeneralMatrixGate(copy(r.matrix))

mat(r::GeneralMatrixGate) = r.matrix

function print_block(io::IO, g::GeneralMatrixGate{M, N, T}) where {M,N,T}
print("GeneralMatrixGate(2^$M × 2^$N)")
end
2 changes: 1 addition & 1 deletion src/Blocks/PhaseGate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ mat(gate::PhaseGate{T}) where T = exp(im * gate.theta) * IMatrix{2, Complex{T}}(
adjoint(blk::PhaseGate) = PhaseGate(-blk.theta)

copy(block::PhaseGate{T}) where T = PhaseGate{T}(block.theta)
dispatch!(block::PhaseGate, theta) = (block.theta = theta; block)
dispatch!(block::PhaseGate, itr) = (block.theta = first(itr); block)

# Properties
nparameters(::Type{<:PhaseGate}) = 1
Expand Down
11 changes: 9 additions & 2 deletions src/Blocks/Primitive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,14 @@ method to enable key value cache.
"""
abstract type PrimitiveBlock{N, T} <: MatrixBlock{N, T} end

function dispatch!(f::Function, x::PrimitiveBlock, params...)
dispatch!(x, f.(parameters(x), params)...)
function dispatch!(x::PrimitiveBlock, params...)
dispatch!(x, params)
end

dispatch!(f::Function, x::PrimitiveBlock, params...) = dispatch!(f, x, params)

function dispatch!(f::Function, x::PrimitiveBlock, itr)
dispatch!(x, f.(parameters(x), itr))
x
end

Expand All @@ -26,3 +32,4 @@ include("ShiftGate.jl")
include("RotationGate.jl")
include("SwapGate.jl")
include("ReflectBlock.jl")
include("GeneralMatrixGate.jl")
2 changes: 1 addition & 1 deletion src/Blocks/ReflectBlock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@ ishermitian(::ReflectBlock) = true
isunitary(::ReflectBlock) = true

function print_block(io::IO, g::ReflectBlock{N, T}) where {N, T}
print("ReflectBlock(N = $N")
print("ReflectBlock(N = $N)")
end
4 changes: 2 additions & 2 deletions src/Blocks/RotationGate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ adjoint(blk::RotationGate) = RotationGate(blk.U, -blk.theta)

copy(R::RotationGate) = RotationGate(R.U, R.theta)

function dispatch!(R::RotationGate, theta)
R.theta = theta
function dispatch!(R::RotationGate, itr)
R.theta = first(itr)
R
end

Expand Down
2 changes: 1 addition & 1 deletion src/Blocks/ShiftGate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ mat(gate::ShiftGate{T}) where T = Diagonal(Complex{T}[1.0, exp(im * gate.theta)]
adjoint(blk::ShiftGate) = ShiftGate(-blk.theta)

copy(block::ShiftGate{T}) where T = ShiftGate{T}(block.theta)
dispatch!(block::ShiftGate, theta::Vector) = (block.theta = theta[1]; block)
dispatch!(block::ShiftGate, itr) = (block.theta = first(itr); block)

# Properties
nparameters(::Type{<:ShiftGate}) = 1
Expand Down
4 changes: 3 additions & 1 deletion src/Interfaces/Function.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export @fn, InvOrder, addbit, Reset
export @fn, InvOrder, addbit, Reset, focus

macro fn(f)
:(FunctionBlock($(esc(f))))
Expand Down Expand Up @@ -35,3 +35,5 @@ const Reset = @fn Reset reset!
Return a [`FunctionBlock`](@ref) of adding n bits.
"""
addbit(n::Int) = FunctionBlock{Tuple{:AddBit, n}}(reg->addbit!(reg, n))

focus(locs::Int...,) = FunctionBlock{Tuple{:Focus, locs...,}}(reg->focus!(reg, locs))
11 changes: 10 additions & 1 deletion src/Interfaces/Primitive.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export H, phase, shift, Rx, Ry, Rz, rot, swap, I2, reflect
export H, phase, shift, Rx, Ry, Rz, rot, swap, I2, reflect, matrixgate

include("PauliGates.jl")

Expand Down Expand Up @@ -105,3 +105,12 @@ function reflect end

reflect(mirror::Vector) = ReflectBlock(mirror)
reflect(mirror::DefaultRegister{1}) = reflect(mirror|>statevec)

"""
matrixgate(matrix::AbstractMatrix) -> GeneralMatrixGate
matrixgate(matrix::MatrixBlock) -> GeneralMatrixGate
Construct a general matrix gate.
"""
matrixgate(matrix::AbstractMatrix) = GeneralMatrixGate(matrix)
matrixgate(matrix::MatrixBlock) = GeneralMatrixGate(mat(matrix))
14 changes: 13 additions & 1 deletion src/Registers/Default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,18 @@ end

extend!(n::Int) = r->extend!(r, n)

function join(reg1::DefaultRegister{B, T1}, reg2::DefaultRegister{B, T2}) where {B, T1, T2}
s1 = reshape(reg1.state, size(reg1.state, 1), :, B)
s2 = reshape(reg2.state, size(reg2.state, 1), :, B)
T = promote_type(T1, T2)
state = Array{T,3}(size(s1, 1)*size(s2, 1), size(s1, 2)*size(s2, 2), B)
for b = 1:B
@inbounds @views state[:,:,b] = kron(s2[:,:,b], s1[:,:,b])
end
DefaultRegister{B}(reshape(state, size(state, 1), :))
end
join(reg1::DefaultRegister{1}, reg2::DefaultRegister{1}) = DefaultRegister{1}(kron(reg2.state, reg1.state))

"""
isnormalized(reg::DefaultRegister) -> Bool
Expand Down Expand Up @@ -151,7 +163,7 @@ end

reorder!(orders::Int...) = reg::DefaultRegister -> reorder!(reg, [orders...])

invorder!(reg::DefaultRegister) = reorder!(reg, collect(nqubits(reg):-1:1))
invorder!(reg::DefaultRegister) = reorder!(reg, collect(nactive(reg):-1:1))

function addbit!(reg::DefaultRegister{B, T}, n::Int) where {B, T}
state = zeros(T, size(reg.state, 1)*(1<<n), size(reg.state, 2))
Expand Down
2 changes: 1 addition & 1 deletion src/Registers/Registers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using StatsBase
using ..Intrinsics

import Base: length
import Base: eltype, copy, similar, *
import Base: eltype, copy, similar, *, join
import Base: show

# import package APIs
Expand Down
2 changes: 2 additions & 0 deletions src/Zoo/Differential.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
export diff_circuit, num_gradient, rotter, cnot_entangler, opgrad, collect_rotblocks

"""
rotter(noleading::Bool=false, notrailing::Bool=false) -> ChainBlock{1, ComplexF64}
Expand Down
58 changes: 56 additions & 2 deletions src/Zoo/QFT.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,57 @@
@static if VERSION >= v"0.7-"
using FFTW
end

export QFTCircuit, QFTBlock, breflect, invorder_firstdim

CRk(i::Int, j::Int, k::Int) = control([i, ], j=>shift(-2π/(1<<k)))
CRot(n::Int, i::Int) = chain(i==j ? put(i=>H) : CRk(j, i, j-i+1) for j = i:n)
QFT(n::Int) = chain(n, CRot(n, i) for i = 1:n)
CRot(n::Int, i::Int) = chain(i==j ? kron(i=>H) : CRk(j, i, j-i+1) for j = i:n)
QFTCircuit(n::Int) = chain(n, CRot(n, i) for i = 1:n)

struct QFTBlock{N} <: PrimitiveBlock{N,ComplexF64} end
mat(q::QFTBlock{N}) where N = applymatrix(q)

apply!(reg::DefaultRegister{B}, ::QFTBlock) where B = (reg.state = fft!(invorder_firstdim(reg |> state), 1)/sqrt(1<<nqubits(reg)); reg)
apply!(reg::DefaultRegister{B}, ::Daggered{N, T, <:QFTBlock}) where {B,N,T} = (reg.state = invorder_firstdim(ifft!(reg|>state, 1)*sqrt(1<<nqubits(reg))); reg)

# traits
ishermitian(q::QFTBlock{N}) where N = N==1
isreflexive(q::QFTBlock{N}) where N = N==1
isunitary(q::QFTBlock{N}) where N = true

function breflect(num_bit::Int, b::Int)
for i in 1:num_bit÷2
b = swapbits(b, bmask(i, num_bit-i+1))
end
b
end

function breflect(num_bit::Int, b::Int, mask::Vector{Int})
@simd for m in mask
b = swapbits(b, m)
end
b
end

function invorder_firstdim(v::Matrix)
w = similar(v)
n = size(v, 1) |> log2i
n_2 = n ÷ 2
mask = [bmask(i, n-i+1) for i in 1:n_2]
@simd for b in basis(n)
@inbounds w[breflect(n, b, mask)+1,:] = v[b+1,:]
end
w
end

function invorder_firstdim(v::Vector)
n = length(v) |> log2i
n_2 = n ÷ 2
w = similar(v)
#mask = SVector{n_2, Int}([bmask(i, n-i+1)::Int for i in 1:n_2])
mask = [bmask(i, n-i+1)::Int for i in 1:n_2]
@simd for b in basis(n)
@inbounds w[breflect(n, b, mask)+1] = v[b+1]
end
w
end
67 changes: 67 additions & 0 deletions src/Zoo/RotBasis.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
export RotBasis, randpolar, polar2u, u2polar, rot_basis

"""
RotBasis{T} <: PrimitiveBlock{1, Complex{T}}
A special rotation block that transform basis to angle θ and ϕ in bloch sphere.
"""
mutable struct RotBasis{T} <: PrimitiveBlock{1, Complex{T}}
theta::T
phi::T
end

# chain -> *
# mat(rb::RotBasis{T}) where T = mat(Ry(-rb.theta))*mat(Rz(-rb.phi))
function mat(x::RotBasis{T}) where T
R1 = _make_rot_mat(IMatrix{2, Complex{T}}(), mat(Z), -x.phi)
R2 = _make_rot_mat(IMatrix{2, Complex{T}}(), mat(Y), -x.theta)
R2 * R1
end

==(rb1::RotBasis, rb2::RotBasis) = rb1.theta == rb2.theta && rb1.phi == rb2.phi

copy(block::RotBasis{T}) where T = RotBasis{T}(block.theta, block.phi)
dispatch!(block::RotBasis, params) = ((block.theta, block.phi) = params; block)

parameters(rb::RotBasis) = (rb.theta, rb.phi)
nparameters(::Type{<:RotBasis}) = 2

function print_block(io::IO, R::RotBasis)
print(io, "RotBasis($(R.theta), $(R.phi))")
end

function hash(gate::RotBasis, h::UInt)
hash(hash(gate.theta, gate.phi, objectid(gate)), h)
end

rot_basis(num_bit::Int) = dispatch!(chain(num_bit, put(i=>RotBasis(0.0, 0.0)) for i=1:num_bit), randpolar(num_bit) |> vec)

"""
u2polar(vec::Array) -> Array
transform su(2) state vector to polar angle, apply to the first dimension of size 2.
"""
function u2polar(vec::Vector)
ratio = vec[2]/vec[1]
[atan(abs(ratio))*2, angle(ratio)]
end

"""
polar2u(vec::Array) -> Array
transform polar angle to su(2) state vector, apply to the first dimension of size 2.
"""
function polar2u(polar::Vector)
theta, phi = polar
[cos(theta/2)*exp(-im*phi/2), sin(theta/2)*exp(im*phi/2)]
end

u2polar(arr::Array) = mapslices(u2polar, arr, [1])
polar2u(arr::Array) = mapslices(polar2u, arr, [1])

"""
randpolar(params::Int...) -> Array
random polar basis, number of basis
"""
randpolar(params::Int...) = rand(2, params...)*pi
12 changes: 8 additions & 4 deletions src/Zoo/Zoo.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
module Zoo
using Compat

using ..Yao
using ..Blocks
using ..LuxurySparse
using ..Intrinsics
using ..Registers
using ..Blocks
import ..Blocks: mat, dispatch!, nparameters, parameters, cache_key, print_block, _make_rot_mat, apply!
import Base: ==, copy, hash
import ..Intrinsics: ishermitian, isreflexive, isunitary

# Block APIs
export QFT
export diff_circuit, num_gradient, rotter, cnot_entangler, opgrad, collect_rotblocks

include("QFT.jl")
include("Differential.jl")
include("RotBasis.jl")
include("Grover.jl")

end
21 changes: 21 additions & 0 deletions test/Blocks/GeneralMatrixGate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
using Compat
using Compat.Test

using Yao
using Yao.Blocks

import Yao.Blocks: GeneralMatrixGate

@testset "MatrixGate" begin
mg = GeneralMatrixGate(randn(4,4))
mg2 = copy(mg)
@test mg2 == mg
mg2.matrix[:,2] = 10
@test mg2 != mg
@test nqubits(mg) == 2
@test_throws DimensionMismatch GeneralMatrixGate(randn(3,3))

reg = rand_state(2)
@test copy(reg) |> mg |> statevec == mg.matrix * reg.state |> vec
end

4 changes: 4 additions & 0 deletions test/Blocks/Primitive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,7 @@ end
@testset "ReflectBlock" begin
include("ReflectBlock.jl")
end

@testset "GeneralMatrixGate" begin
include("GeneralMatrixGate.jl")
end
11 changes: 11 additions & 0 deletions test/Interfaces/Interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ end
@test reflect(psi |> statevec) isa ReflectBlock
end

@testset "matrix gate" begin
matrix = randn(4,8)
@test matrixgate(matrix) isa GeneralMatrixGate
@test matrixgate(matrix) isa GeneralMatrixGate
matrix = randn(4,7)
@test_throws DimensionMismatch matrixgate(matrix)
end

@testset "chain" begin
@test chain(X, Y, Z) isa ChainBlock
end
Expand Down Expand Up @@ -126,6 +134,9 @@ end
Probs = @fn probs
@test Probs isa FunctionBlock{typeof(probs)}
@test apply!(copy(reg), Probs) == reg |> probs

FB = focus(1,3,2)
@test copy(reg) |> FB == focus!(copy(reg), [1,3,2])
end

@testset "sequence" begin
Expand Down
Loading

0 comments on commit 89894a5

Please sign in to comment.