In [35]:
using Compat.Test
using QuCircuit
import Base:sparse
import QuCircuit: GateType, gate

In [36]:
import QuCircuit: PrimitiveBlock
# use different types for parameters and matrix

# Basis-rotation block parametrized by two angles.
# Declared as a `PrimitiveBlock{1, Complex{T}}`: the angles have type `T`, while the second
# type parameter of the supertype is `Complex{T}` (presumably the matrix element type — TODO confirm).
# Mutable because `dispatch!` reassigns `theta` and `phi` in place.
mutable struct RotBasis{T} <: PrimitiveBlock{1, Complex{T}}
    theta::T  # negated and passed to rot(:Y, ...) when the sparse matrix is built
    phi::T    # negated and passed to rot(:Z, ...) when the sparse matrix is built
end
# Quick check of the default constructor; `T` is taken from the argument type.
RotBasis(0.5, 0.4)

RotBasis{Float64}(0.5, 0.4)

In [48]:
# chain -> *
# Sparse matrix of a `RotBasis`: the chain of a Z rotation by `-phi` and a Y rotation by
# `-theta`, converted with the `sparse` method of the resulting chain block.
function sparse(rb::RotBasis)
    z_turn = rot(:Z, -rb.phi)
    y_turn = rot(:Y, -rb.theta)
    return sparse(chain(z_turn, y_turn))
end
# Build a `roll` block over `num_bit` qubits from a single `RotBasis` initialised
# with theta = 0.5 and phi = 0.2.
function rot_basis(num_bit::Int)
    template = RotBasis(0.5, 0.2)
    return roll(num_bit, template)
end

rot_basis (generic function with 1 method)

In [38]:
# Demo: create a 4-qubit register with `zero_state`, build the rolled basis rotation
# for the same number of qubits, and apply the block to the register.
num_bit = 4
reg = zero_state(num_bit) # useless randn_state
rb = rot_basis(num_bit)
# Calling the block on the register; the value of this last expression is the cell output.
rb(reg)

Default Register (CPU, Complex{Float64}):
    total: 4
    batch: 1
    active: 4

In [39]:
# Index `A` with `i` along dimension `dim` and `:` along every other dimension.
#
# Arguments:
#   A   - array of any rank N
#   dim - dimension to index; if it is outside 1:N no dimension is fixed and
#         the result is a full copy of `A`
#   i   - index applied along `dim` (anything `getindex` accepts: integer, range, ...)
#
# Returns `A[:, ..., i, ..., :]`; a scalar `i` drops dimension `dim`.
function get_axis(A::AbstractArray{T, N}, dim, i) where {T, N}
    # The ternary needs spaces around `?` and `:` on Julia >= 1.0, and the integer
    # form of `ntuple` is used because `ntuple(f, Val{N})` no longer exists there.
    inds = ntuple(d -> d == dim ? i : Colon(), N)
    return A[inds...]
end

get_axis (generic function with 1 method)

In [40]:
# translation between polar angle and len-2 complex vector.
# Convert two-component complex vectors (components along the first axis of `vec`)
# into polar angles. The first output row is `2 * atan(|ratio|)`, the second is
# `angle(ratio)`, where `ratio` is the second component divided by the first.
function u2polar(vec)
    lower = get_axis(vec, 1, 2)
    upper = get_axis(vec, 1, 1)
    ratio = lower ./ upper
    theta = @. atan(abs(ratio))' * 2
    phi = @. angle(ratio)'
    return [theta; phi]
end

# Convert polar angles (theta in row 1, phi in row 2 of `polar`) into two-component
# complex vectors: cos(theta/2)*exp(-i*phi/2) stacked on top of sin(theta/2)*exp(i*phi/2).
function polar2u(polar)
    theta = get_axis(polar, 1, 1)'
    phi = get_axis(polar, 1, 2)'
    upper = @. cos(theta/2)*exp(-im*phi/2)
    lower = @. sin(theta/2)*exp(im*phi/2)
    return [upper; lower]
end

# random polar basis, n-> number of basis
# Draw uniform random numbers of shape (2, params...) and scale them by pi,
# so every entry lies in [0, pi).
function randpolar(params...)
    unit = rand(2, params...)
    return unit * pi
end

randpolar (generic function with 1 method)

In [41]:
# Round-trip check: mapping random polar angles to vectors with `polar2u` and back with
# `u2polar` should reproduce the original angles element-wise.
polar = randpolar(10)
print(size(polar))
@test all(isapprox.(polar |> polar2u |> u2polar, polar))

(2, 10)

[1m[32mTest Passed
[39m[22m

In [50]:
import QuCircuit: dispatch!, nparameters

# A `RotBasis` carries two parameters: `theta` and `phi`.
function nparameters(rb::RotBasis)
    return 2
end
# Update both angles of `rb` in place and return `rb`.
# Each angle becomes `f(old_value, new_value)`, with `params[1]` feeding `theta`
# and `params[2]` feeding `phi`; any further entries of `params` are not read.
function dispatch!(f::Function, rb::RotBasis, params::Vector)
    updated_theta = f(rb.theta, params[1])
    rb.theta = updated_theta
    updated_phi = f(rb.phi, params[2])
    rb.phi = updated_phi
    return rb
end

dispatch! (generic function with 12 methods)

In [65]:
# Single-qubit check: a state prepared along random polar angles, rotated by a basis
# block carrying the same angles, is expected to equal [1, 0].
rb = roll(1, RotBasis(0.1,0.3))#rot_basis(1)
angles = randpolar(1)
# prepare a state in the direction given by `angles`.
psi = angles |> polar2u |> register

# load the same angles into the rotation block, then apply it to the state.
dispatch!(rb, vec(angles))
@test state(psi |> rb) ≈ [1,0]

[1m[32mTest Passed
[39m[22m

In [113]:
################################################
#     Random Basis WaveFunction Learning       #
################################################
# Define a circuit and the global basis rotation used by the training code.
# NOTE(review): the recorded output of this cell is `UndefVarError: qcbm not defined`,
# so `qcbm` has to be brought into scope before the cell can run.
num_bit = 6
depth = 10
circuit = qcbm(num_bit, depth)
# NOTE(review): this global `rot` reuses the name of the `rot(:Z, ...)` / `rot(:Y, ...)`
# gate constructor called in `sparse(::RotBasis)` — confirm the two do not clash.
rot = rot_basis(num_bit)

# Load fresh random polar angles (one pair per qubit of `rot`) into `rot`
# and return the result of `dispatch!`.
function rand_basis!(rot)
    angles = randpolar(nqubit(rot))
    return dispatch!(rot, vec(angles))
end

# rotate the output wave function
# Evaluate the model for `params`: load them into the circuit, run the circuit on a
# zero state, apply the global basis rotation `rot` to the resulting state and
# return it flattened with `vec`.
function (qcbm::QCBM)(params)
    initial = zero_state(nqubit(qcbm.circuit))
    dispatch!(qcbm.circuit, params)
    evolved = qcbm.circuit(initial)
    rotated = rot(state(evolved))
    return vec(rotated)
end

# Train `qcbm` against `psi_train` by gradient descent, drawing a new random
# measurement basis at every step.
#
# Arguments:
#   qcbm          - model to train; its parameter count comes from `nparameters(qcbm)`
#   psi_train     - target wave function, passed through the global basis rotation `rot`
# Keywords:
#   learning_rate - step size of the update `params .-= learning_rate * grad`
#   maxiter       - number of update steps
#
# Returns the parameter vector after the last step.
function train_rand_basis(qcbm::QCBM, psi_train; learning_rate = 0.1, maxiter=10)
    params = 2pi * rand(nparameters(qcbm))
    kernel = RBFKernel(qcbm.n, [2.0], false)

    for i = 1:maxiter
        # Select a random basis. The earlier `rt |> rand_basis` referred to two names
        # that are defined nowhere; the helper defined for this purpose is
        # `rand_basis!`, and the block it has to randomize is the global `rot`.
        rand_basis!(rot)
        # Rotate the target wave function to the selected basis.
        # NOTE(review): `abs2` is applied to whatever `rot(psi_train)` returns; if that
        # is an array this needs to be the broadcast form `abs2.(...)` — confirm.
        ptrain = abs2(rot(psi_train))

        grad = gradient(qcbm, copy(params), kernel, ptrain)
        println(i, " step, loss = ", loss(qcbm, copy(params), kernel, ptrain))
        params .-= learning_rate * grad
    end
    return params
end

LoadError: [91mUndefVarError: qcbm not defined[39m

In [None]:
################################################
#                   Clone GAN                  #
################################################

# probabilities and kernel
# NOTE(review): `...` is a placeholder, not valid Julia — assign the target state here
# before running the cell.
psi_train = ... # some state
# Data-side quantity for basis angles `γ`: copy the global `psi_train`, load `γ` into
# the global `rot` (this mutates `rot`), apply the rotation and pass the result to `abs2`.
function p_data(γ::Vector{Float64})
    target = copy(psi_train)
    basis = dispatch!(rot, γ)
    return abs2(basis(target))
end
# Model-side quantity for parameters `θ`: load `θ` into the global `circuit`, run it on
# a zero state of `num_bit` qubits, apply the global `rot` with whatever angles it
# currently holds, and pass the result to `abs2`.
function p_model(θ::Vector{Float64})
    initial = zero_state(num_bit)
    evolved = dispatch!(circuit, θ)(initial)
    return abs2(rot(evolved))
end
# Global kernel read by `grad1` and `loss` in this cell.
# NOTE(review): the training functions elsewhere in this file pass `[2.0]` rather than
# `[2]` as the second argument — confirm which element type `RBFKernel` expects.
kernel = RBFKernel(num_bit, [2], false)

# get gradient with respect to single parameter
# Gradient of the kernel loss with respect to the single parameter `theta_list[i]`,
# obtained from two evaluations of `pm_func` at parameters shifted by +pi/2 and -pi/2.
#
# Arguments:
#   theta_list - parameter vector; it is copied, the caller's vector is not modified
#   pd         - distribution held fixed while differentiating
#   pm_func    - function mapping a parameter vector to the distribution being differentiated
#   pm         - `pm_func` evaluated at the unshifted parameters
#   i          - index of the parameter to differentiate
#
# Returns
#   expect(kernel, pm, p+) - expect(kernel, pm, p-) - (expect(kernel, pd, p+) - expect(kernel, pd, p-))
# using the global `kernel`.
function grad1(theta_list::Vector, pd::Vector, pm_func::Function, pm::Vector, i::Int)
    theta_list_ = copy(theta_list)
    # +pi/2 phase: shift first, then evaluate on the whole shifted vector.
    theta_list_[i] += pi/2
    pp = pm_func(theta_list_)
    # -pi/2 phase: the vector currently sits at +pi/2, so step back by a full pi.
    theta_list_[i] -= pi
    pn = pm_func(theta_list_)
    return expect(kernel, pm, pp) - expect(kernel, pm, pn) -
           (expect(kernel, pd, pp) - expect(kernel, pd, pn))
end

# View of the trailing `2 * nbit` entries of `θ` (the basis-rotation angles, two per qubit).
#
# Arguments:
#   θ    - full parameter vector
#   nbit - number of qubits; defaults to the global `num_bit`
#
# Returns a view, so writing into the result writes into `θ`. Throws a `BoundsError`
# when `θ` has fewer than `2 * nbit` entries.
function takerot(θ::Vector, nbit::Integer = num_bit)
    # `end` is only valid inside indexing brackets, so the last index is spelled out.
    n = length(θ)
    return view(θ, n-2*nbit+1:n)
end
# MMD loss between the model and data distributions for the full parameter vector `θ`
# (circuit parameters followed by the basis-rotation angles selected by `takerot`).
function loss(θ::Vector{Float64})
    # `takerot` returns a view, but `p_data` only accepts `Vector{Float64}`; copy it.
    γ = copy(takerot(θ))
    # Evaluate the data side first: `p_data` is what loads `γ` into the global `rot`,
    # and `p_model` reads `rot` as-is, so this order puts both in the same basis.
    pd = p_data(γ)
    pm = p_model(θ)
    return MMDLoss(kernel, pm, pd)
end

# Gradient of the clone-GAN loss with respect to every entry of `theta_list`
# (circuit parameters followed by the basis-rotation angles selected by `takerot`).
#
# Returns a vector of the same length as `theta_list`. Every entry holds the model-side
# derivative; the trailing rotation entries additionally receive the data-side derivative.
function gradient(theta_list::Vector)
    # For stability concerns, we do not use the cached probability output.
    # `takerot` returns a view; `grad1` and `p_data` take plain vectors, so copy it.
    γ = copy(takerot(theta_list))
    # Data side first: `p_data` loads `γ` into the global `rot`, which `p_model` then reads.
    pd = p_data(γ)
    pm = p_model(theta_list)

    grad = map(i -> grad1(theta_list, pd, p_model, pm, i), eachindex(theta_list))
    grad_data = map(i -> grad1(γ, pm, p_data, pd, i), eachindex(γ))
    # Accumulate the data-side part into the rotation entries in place, through the view.
    rot_grad = takerot(grad)
    rot_grad .+= grad_data

    return grad
end

In [None]:
# Train the circuit parameters and the basis-rotation angles together.
#
# Arguments:
#   psi_train     - kept for interface compatibility; not read here, because `loss` and
#                   `gradient` work on the global `psi_train` and the global `kernel`
# Keywords:
#   learning_rate - step size of the update `θ .-= learning_rate * grad`
#   maxiter       - number of update steps
#
# Returns the parameter vector `θ` after the last step.
function train_clonegan(psi_train; learning_rate = 0.1, maxiter=10)
    # NOTE(review): `nparam` is not defined or imported in the visible cells; if it returns
    # a parameter count, this should probably be `2pi * rand(nparameters(circuit))` — confirm.
    θ = vcat(nparam(circuit), randpolar(num_bit)|>vec)

    for i = 1:maxiter
        grad = gradient(θ)
        # Tune the rotation parameters: rescale their part of the gradient by -5, in place
        # through the view, so they are updated against the descent direction.
        rot_grad = takerot(grad)
        rot_grad .*= -5
        println(i, " step, loss = ", loss(θ))

        # The vector being trained is `θ`; `params` was never defined in this function.
        θ .-= learning_rate * grad
    end
    return θ
end

In [None]:
# Check the analytic gradient against numerical differentiation.
# NOTE(review): `num_gradient` and `nparam` are not defined or imported in the visible cells.
θ = vcat(nparam(circuit), randpolar(num_bit)|>vec)
@test all(isapprox.(gradient(θ), num_gradient(θ)))

In [None]:
#=
Reference copy of SciPy's C++ routine `csr_sort_indices`, which sorts the column indices of a CSR matrix in place.
/*
 * Sort CSR column indices inplace
 *
 * Input Arguments:
 *   I  n_row           - number of rows in A
 *   I  Ap[n_row+1]     - row pointer
 *   I  Aj[nnz(A)]      - column indices
 *   T  Ax[nnz(A)]      - nonzeros
 *
 */
template<class I, class T>
void csr_sort_indices(const I n_row,
                      const I Ap[],
                            I Aj[],
                            T Ax[])
{
    std::vector< std::pair<I,T> > temp;

    for(I i = 0; i < n_row; i++){
        I row_start = Ap[i];
        I row_end   = Ap[i+1];

        temp.resize(row_end - row_start);
        for (I jj = row_start, n = 0; jj < row_end; jj++, n++){
            temp[n].first  = Aj[jj];
            temp[n].second = Ax[jj];
        }

        std::sort(temp.begin(),temp.end(),kv_pair_less<I,T>);

        for(I jj = row_start, n = 0; jj < row_end; jj++, n++){
            Aj[jj] = temp[n].first;
            Ax[jj] = temp[n].second;
        }
    }
}
=#