In [1]:
using Yao
using Yao.Blocks
using Yao.LuxurySparse
using Yao.Intrinsics
export ControlBlock
import Yao.Blocks: mat, blocks, copy
using Compat.Test
using BenchmarkTools

In [9]:
mutable struct ControlBlock{N, BT<:AbstractBlock, C, T, PT<:AbstractMatrix{T}} <: CompositeBlock{N, T}
    ctrl_qubits::NTuple{C, Int}
    projectors::NTuple{C, PT}
    block::BT
    addr::Int
end

function ControlBlock{N}(ctrl_qubits::NTuple{C, Int}, projectors::NTuple{C, PT}, block::BT, addr::Int) where {BT, N, C, T, PT<:AbstractMatrix{T}}
    ControlBlock{N, BT, C, T, PT}(ctrl_qubits, projectors, block, addr)
end

function copy(ctrl::ControlBlock{BT, N, C, T}) where {BT, N, C, T}
    ControlBlock{BT, N, C, T}(copy(ctrl.ctrl_qubits), copy(ctrl.projectors), ctrl.block, copy(ctrl.addr))
end

general_controlled_gates(num_bit::Int, projectors::Vector{Tp}, cbits::Vector{Int}, gates::Vector{Tg}, locs::Vector{Int}) where {Tg<:AbstractMatrix, Tp<:AbstractMatrix} = IMatrix(1<<num_bit) - hilbertkron(num_bit, projectors, cbits) + hilbertkron(num_bit, vcat(projectors, gates), vcat(cbits, locs))
general_c1_gates(num_bit::Int, projector::Tp, cbit::Int, gates::Vector{Tg}, locs::Vector{Int}) where {Tg<:AbstractMatrix, Tp<:AbstractMatrix} = hilbertkron(num_bit, [mat(I2) - projector], [cbit]) + hilbertkron(num_bit, vcat([projector], gates), vcat([cbit], locs))

mat(c::ControlBlock{N}) where N = general_controlled_gates(N, [c.projectors...], [c.ctrl_qubits...], [mat(c.block)], [c.addr])
mat(c::ControlBlock{N, BT, 1}) where {N, BT} = general_c1_gates(N, c.projectors[1], c.ctrl_qubits[1], [mat(c.block)], [c.addr])

blocks(c::ControlBlock) = [c.block]
addrs(c::ControlBlock) = [c.ctrl_qubits..., (c.addr.+addrs(c.blocks).-1)...]

addrs (generic function with 1 method)

In [3]:
cg1 = ControlBlock{2}((2,), (mat(P1),), X, 1)
@test mat(cg1) == mat(CNOT)

LoadError: [91mMethodError: no method matching ControlBlock{2,BT,C,T} where T where C where BT<:Yao.Blocks.AbstractBlock(::Tuple{Int64}, ::Tuple{SparseMatrixCSC{Complex{Float64},Int64}}, ::Yao.Blocks.XGate{Complex{Float64}}, ::Int64)[39m

In [11]:
cg = ControlBlock{3}((3,2), (mat(P1), mat(P1)), X, 1)
@test mat(cg) == mat(Toffoli)

[1m[32mTest Passed
[39m[22m

In [12]:
CG1 = ControlBlock{16}((7,), (mat(P1),), X, 3)
CG2 = ControlBlock{16}((7, 6), (mat(P1), mat(P0)), X, 3)

Total: 16, DataType: Complex{Float64}
ControlBlock{16,Yao.Blocks.XGate{Complex{Float64}},2,Complex{Float64},SparseMatrixCSC{Complex{Float64},Int64}}
└─ X gate


In [13]:
#@benchmark mat(CG2)
#@benchmark mat(CG1)

In [31]:
ControlBlock{N}(ctrl_qubits::NTuple{C, Int}, block::AbstractBlock, addr::Int) where {N, C} = ControlBlock{N}(ctrl_qubits, ([mat(P1) for i=1:C]...), block, addr)

In [32]:
cg = ControlBlock{3}((3,2), X, 1)
@test mat(cg) == mat(Toffoli)

[1m[32mTest Passed
[39m[22m

In [30]:
([12,3]...)

(12, 3)

In [7]:
mutable struct ControlBlock{N, BT<:AbstractBlock, C, T} <: CompositeBlock{N, T}
    ctrl_qubits::NTuple{C, Int}
    vals::NTuple{C, Int}
    block::BT
    addr::Int
end

function ControlBlock{N}(ctrl_qubits::NTuple{C, Int}, vals::NTuple{C, Int}, block::BT, addr::Int) where {BT<:AbstractBlock, N, C}
    ControlBlock{N, BT, C, Bool}(ctrl_qubits, vals, block, addr)
end

function ControlBlock{N}(ctrl_qubits::NTuple{C, Int}, vals::NTuple{C, Int}, block::BT, addr::Int) where {N, C, T, BT<:MatrixBlock{N, T}}
    ControlBlock{N, BT, C, T}(ctrl_qubits, vals, block, addr)
end

ControlBlock{N}(ctrl_qubits::NTuple{C, Int}, block::AbstractBlock, addr::Int) where {N, C} = ControlBlock{N}(ctrl_qubits, (ones(Int, C)...), block, addr)

In [8]:
cg1 = ControlBlock{2}((2,), (1,), X, 1)

MethodError: [91mMethodError: no method matching blocks(::ControlBlock{2,Yao.Blocks.XGate{Complex{Float64}},1,Bool})[0m
Closest candidates are:
  blocks([91m::Yao.Blocks.CachedBlock[39m) at /home/leo/jcode/Yao.jl/src/Blocks/CachedBlock.jl:68
  blocks([91m::Yao.Blocks.ChainBlock[39m) at /home/leo/jcode/Yao.jl/src/Blocks/ChainBlock.jl:59
  blocks([91m::Yao.Blocks.KronBlock[39m) at /home/leo/jcode/Yao.jl/src/Blocks/KronBlock.jl:84
  ...[39m

In [21]:
decode_sign(ctrls::Int...) = ctrls .|> abs, ctrls .|> sign .|> (x->(1+x)÷2)

decode_sign (generic function with 1 method)

In [22]:
decode_sign(-1, 2, -3)

((1, 2, 3), (0, 1, 0))