# In [3]:
using Yao
using Yao.Blocks
export ControlBlock

"""
    tuplesort(tp; by::Function=x->x)

Return a tuple holding the elements of `tp` sorted in ascending order of
`by(element)`.
"""
# The elements are splatted into a vector, sorted, and rebuilt with `Tuple(...)`;
# the bare `(x...)` form without a trailing comma is not valid syntax on Julia 1.x.
tuplesort(tp; by::Function=x->x) = Tuple(sort!([tp...]; by=by))

"""
    ControlBlock{BT, N, C, T}

BT: controlled block type,
N: number of qubits,
C: number of control bits,
T: type of matrix.
"""
mutable struct ControlBlock{BT, N, C, T} <: CompositeBlock{N, T}
    # addresses of the control qubits; a negative sign marks an inverse control
    ctrl_qubits::NTuple{C}
    # the block under control
    block::BT
    # address at which the controlled block sits
    addr::Int

    # TODO: input a control block, we need to expand this control block to its upper parent block
    # function ControlBlock{N}(ctrl_qubits::Vector{Int}, ctrl::ControlBlock, addr::Int) where {N, K, T}
    # end

    # Raw constructor: `N` and `T` are given explicitly and the control qubits
    # are stored exactly in the order they were passed.
    ControlBlock{N, T}(ctrl_qubits::NTuple{C}, block::BT, addr::Int) where {BT, N, C, T} =
        new{BT, N, C, T}(ctrl_qubits, block, addr)

    # Constructor for matrix blocks: `T` is taken from the block type and the
    # control qubits are stored sorted by absolute address.
    function ControlBlock{N}(ctrl_qubits::NTuple{C}, block::BT, addr::Int) where {N, C, B, T, BT <: MatrixBlock{B, T}}
        # NOTE: the sign of a control qubit only flags an inverse control, so
        # the ordering ignores it; sorting from the lowest to the highest
        # address keeps the stored layout deterministic.
        # TODO: remove repeated, add error
        by_address = tuplesort(ctrl_qubits, by=abs)
        return new{BT, N, C, T}(by_address, block, addr)
    end
end

# Convenience constructor taking the target as an `addr => block` pair.
function ControlBlock{N}(ctrl_qubits::NTuple{C}, target::Pair{Int, BT}) where {N, K, C, T, BT <: MatrixBlock{K, T}}
    addr, block = first(target), last(target)
    return ControlBlock{N}(ctrl_qubits, block, addr)
end

# Constructor without an explicit `N`: the register size is taken to be the
# largest address among the control qubits (sign ignored) and `addr`.
function ControlBlock(ctrl_qubits::NTuple, block, addr::Int)
    highest_ctrl = maximum(map(abs, ctrl_qubits))
    return ControlBlock{max(highest_ctrl, addr)}(ctrl_qubits, block, addr)
end

"""
    copy(ctrl::ControlBlock)

Return a new `ControlBlock` with the same control qubits, address and
controlled block as `ctrl`. The controlled block itself is shared with the
original, not copied.
"""
function Base.copy(ctrl::ControlBlock{BT, N, C, T}) where {BT, N, C, T}
    # Only the `ControlBlock{N, T}` and `ControlBlock{N}` constructors exist,
    # so the raw `{N, T}` form is used; it stores the control qubits as given.
    # The tuple of control qubits and the integer address are immutable values
    # and need no copying of their own.
    return ControlBlock{N, T}(ctrl.ctrl_qubits, ctrl.block, ctrl.addr)
end

"""
    mat(ctrl::ControlBlock)

Build the matrix of the controlled block `ctrl`.

Starting from the matrix of the block under control, one control qubit is
attached per iteration (a regular control for a positive address, an inverse
control for a negative one). The result is then padded with identity matrices
for the qubits below the lowest and above the highest address in use.
"""
function mat(ctrl::ControlBlock{BT, N, C, T}) where {BT, N, C, T}
    # NOTE: we sort the addr of control qubits by its relative addr to
    # the block under control, this is useful when calculate its
    # matrix form.
    # `ctrl_qubits` is a tuple; it is collected into a vector first because
    # `sort` has no method for tuples on Julia 1.0-era releases.
    ctrl_addrs = sort!(collect(ctrl.ctrl_qubits); by=x->abs(abs(x)-ctrl.addr))

    # start of the iteration
    U = mat(ctrl.block)
    addr = ctrl.addr
    U_nqubit = nqubits(ctrl.block)
    for each_ctrl in ctrl_addrs
        if each_ctrl > 0
            U = _single_control_gate_sparse(abs(each_ctrl), U, addr, U_nqubit)
        else
            U = _single_inverse_control_gate_sparse(abs(each_ctrl), U, addr, U_nqubit)
        end

        # after attaching this control, U covers every qubit between the
        # control and the previous span, so widen the span accordingly
        head = addr # inner block head
        tail = addr + U_nqubit - 1 # inner block tail
        inc = min(abs(head - abs(each_ctrl)), abs(tail - abs(each_ctrl)))
        U_nqubit = U_nqubit + inc
        addr = min(abs(each_ctrl), addr)
    end

    # check blank lines at the beginning
    lowest_addr = min(minimum(abs.(ctrl_addrs)), ctrl.addr)
    if lowest_addr != 1 # lowest addr is not from the first
        nblank = lowest_addr - 1
        U = kron(U, IMatrix{1 << nblank, T}())
    end

    # check blank lines in the end
    # NOTE(review): this only looks at `ctrl.addr`, not at the last qubit of
    # the controlled block; looks off for blocks wider than one qubit -- confirm.
    highest_addr = max(maximum(abs.(ctrl_addrs)), ctrl.addr)
    if highest_addr != N # highest addr is not the last
        nblank = N - highest_addr
        U = kron(IMatrix{1 << nblank, T}(), U)
    end
    return U
end

# Attach one inverse control at `control` to the operator `U` located at `addr`
# (spanning `nqubit` qubits): identity is paired with the P1 projector and `U`
# with the P0 projector.
function _single_inverse_control_gate_sparse(control::Int, U, addr, nqubit)
    @assert control != addr "cannot control itself"

    T = eltype(U)
    idle = mat(P1(T))    # projector paired with the identity
    active = mat(P0(T))  # projector paired with U
    eye = IMatrix(U)

    if control < addr
        # control sits below the block: the projector is the first operand
        return A_kron_B(idle, control, 1, eye, addr) +
               A_kron_B(active, control, 1, U, addr)
    end
    # control sits above the block: the block is the first operand
    return A_kron_B(eye, addr, nqubit, idle, control) +
           A_kron_B(U, addr, nqubit, active, control)
end

# Attach one control at `control` to the operator `U` located at `addr`
# (spanning `nqubit` qubits): identity is paired with the P0 projector and `U`
# with the P1 projector.
function _single_control_gate_sparse(control::Int, U, addr, nqubit)
    @assert control != addr "cannot control itself"

    T = eltype(U)
    idle = mat(P0(T))    # projector paired with the identity
    active = mat(P1(T))  # projector paired with U
    eye = IMatrix(U)

    if control < addr
        # control sits below the block: the projector is the first operand
        return A_kron_B(idle, control, 1, eye, addr) +
               A_kron_B(active, control, 1, U, addr)
    end
    # control sits above the block: the block is the first operand
    return A_kron_B(eye, addr, nqubit, idle, control) +
           A_kron_B(U, addr, nqubit, active, control)
end

# Kronecker product of A and B placed at positions ia and ib.
# A has size 2^na x 2^na. When positions ia .. ia+na-1 do not reach ib, the
# gap in between is filled with an identity of matching size. B is always the
# left factor of the final product.
function A_kron_B(A, ia, na, B, ib)
    T = eltype(A)
    gap = ib - (ia + na)
    lower = gap > 0 ? kron(IMatrix{1 << gap, T}(), A) : A
    return kron(B, lower)
end

# The only sub-block of a ControlBlock is the block under control.
function blocks(c::ControlBlock)
    return [c.block]
end

#################
# Dispatch Rules
#################

# NOTE: a ControlBlock does no looping of its own; the parameter vector is
# handed straight to the block under control.
dispatch!(f::Function, ctrl::ControlBlock, params::Vector) =
    dispatch!(f, ctrl.block, params)

# Varargs form: the parameters are forwarded unchanged to the block under control.
dispatch!(f::Function, ctrl::ControlBlock, params...) =
    dispatch!(f, ctrl.block, params...)

"""
    hash(ctrl::ControlBlock, h::UInt)

Hash a `ControlBlock` from its concrete type, its control qubits, the block
under control and its address, mixed into the seed `h`.
"""
function Base.hash(ctrl::ControlBlock, h::UInt)
    # Seed with the concrete type rather than `objectid(ctrl)`: the struct is
    # mutable, so its object id differs between two instances that compare
    # equal field by field, and equal values are expected to hash equally.
    hashkey = hash(typeof(ctrl), h)
    for each in ctrl.ctrl_qubits
        hashkey = hash(each, hashkey)
    end

    hashkey = hash(ctrl.block, hashkey)
    return hash(ctrl.addr, hashkey)
end

"""
    ==(lhs::ControlBlock, rhs::ControlBlock)

Two control blocks with identical type parameters are equal when their control
qubits, controlled blocks and addresses are all equal.
"""
# Qualified with `Base` so this extends the built-in `==` even when the
# surrounding module has not imported it explicitly.
function Base.:(==)(lhs::ControlBlock{BT, N, C, T}, rhs::ControlBlock{BT, N, C, T}) where {BT, N, C, T}
    return (lhs.ctrl_qubits == rhs.ctrl_qubits) && (lhs.block == rhs.block) && (lhs.addr == rhs.addr)
end

# Print the header of a control block as `control(q1, q2, ...)`, listing the
# stored control qubits, every piece in bold with the ControlBlock colour.
function print_block(io::IO, x::ControlBlock)
    emit(piece) = printstyled(io, piece; bold=true, color=color(ControlBlock))

    emit("control(")
    final = lastindex(x.ctrl_qubits)
    for (i, qubit) in enumerate(x.ctrl_qubits)
        emit(qubit)
        # separator between entries, but not after the last one
        i == final || emit(", ")
    end
    emit(")")
end



# print_block (generic function with 1 method)