In [1]:
include("INCLUDEME.jl")

using Yao, Circuit, UnicodePlots, GradOptim, Utils, ArgParse, wavefunctions
import Kernels
using Compat.Test



In [2]:
# Experiment hyper-parameters.
iter = 200      # number of training iterations
N = 4           # number of qubits
monitor = 10    # report loss / fidelity every `monitor` iterations
depth = 6       # second type parameter of the QCBM circuit (layer count)

# Target state: normalized superposition of |1...1> and |0...0> on N qubits.
target0 = (register(bit"1"^N) + register(bit"0"^N)) |> normalize!
# TODO: building the target directly via `register(ghz(N))` did not work; find out why.
#target0 = register(ghz(N))

# Use `depth` rather than a hard-coded literal so the setting above actually
# controls the circuit that is built.
qcbm = QCBM{N, depth}(get_nn_pairs(N)) |> initialize!
optim = Adam(lr=0.1)
kernel = Kernels.RBFKernel(nqubits(qcbm), [0.25], false)

# Per-qubit basis rotation, appended after the QCBM to form the full circuit.
rot = roll(N, rotbasis())
circuit = chain(qcbm, rot);

In [12]:
# Check that the hand-built target wavefunction matches the GHZ state.
# The amplitudes are floating point (they come out of `normalize!`), so compare
# with `≈` instead of exact `==`, which only holds when both sides round identically.
@test statevec(target0) ≈ ghz(N)

[1m[32mTest Passed
[39m[22m

In [13]:
# Check that parameters dispatched into `rot` can be read back, both from `rot`
# itself and from the last block of `circuit` (which was built as `chain(qcbm, rot)`).
# NOTE(review): the recorded output of this cell shows a MethodError: `parameters`
# has no method for the Roller block type, the only listed candidate being the one
# for `Circuit.QCBM`. As written this test errors instead of passing; it needs a
# `parameters` method (or an equivalent accessor) for `rot` - TODO confirm the
# intended API.
dispatch!(rot, ones(2 * N))
@test parameters(rot) == parameters(circuit[end]) == ones(2 * N)

[1m[91mError During Test
[39m[22m  Test threw an exception of type MethodError
  Expression: parameters(rot) == parameters(circuit[end]) == ones(2N)
  [91mMethodError: no method matching parameters(::Yao.Blocks.Roller{4,4,Complex{Float64},NTuple{4,Circuit.RotBasis{Float64}}})[0m
  Closest candidates are:
    parameters([91m::Circuit.QCBM{N,NL,CT,T} where T where CT[39m) where {N, NL} at /home/leo/jcode/QCBM/modules/Circuit.jl:71[39m
  Stacktrace:
   [1] [1minclude_string[22m[22m[1m([22m[22m::String, ::String[1m)[22m[22m at [1m./loading.jl:515[22m[22m
   [2] [1mexecute_request[22m[22m[1m([22m[22m::ZMQ.Socket, ::IJulia.Msg[1m)[22m[22m at [1m/home/leo/.julia/v0.6/IJulia/src/execute_request.jl:158[22m[22m
   [3] [1m(::Compat.#inner#17{Array{Any,1},IJulia.#execute_request,Tuple{ZMQ.Socket,IJulia.Msg}})[22m[22m[1m([22m[22m[1m)[22m[22m at [1m/home/leo/.julia/v0.6/Compat/src/Compat.jl:385[22m[22m
   [4] [1meventloop[22m[22m[1m([22m[22m::ZMQ.So

LoadError: [91mThere was an error during testing[39m

In [None]:
# Train the QCBM for `iter` steps. At every step the rotation block is given
# fresh random parameters, the target state is rotated accordingly, and the
# QCBM parameters are updated from the gradient against the resulting
# probability distribution. Loss and fidelity are recorded every `monitor` steps.
history = Float64[]
# Named `fidelity` (not `fedility`) so that the plot call at the end of this
# cell refers to a variable that actually exists.
fidelity = Float64[]

for i = 1:iter
    # Draw 2N random rotation angles in [0, pi) and apply the rotation to a
    # copy of the target (presumably `rot` acts in place - the copy keeps
    # `target0` available for the next iteration).
    dispatch!(rot, pi * rand(2 * N))
    target = rot(copy(target0))

    # Squared magnitudes of the rotated target's amplitudes.
    ptrain = abs2.(statevec(target))
    grad = gradient(qcbm, kernel, ptrain)

    if i % monitor == 0
        curr_loss = loss(qcbm, kernel, ptrain)
        # Magnitude of the overlap between the QCBM output state and the target.
        curr_fidelity = abs(dot(statevec(qcbm()), statevec(target)))
        push!(history, curr_loss)
        push!(fidelity, curr_fidelity)
        println(i, " step, loss = ", curr_loss)
        println("fedility: ", curr_fidelity)
    end

    # Warn: we need a primitive block to enable
    # BLAS here.
    params = parameters(qcbm)
    update!(params, grad, optim)
    dispatch!(qcbm, params)
end

display(lineplot(history, title = "loss"))
display(lineplot(fidelity, title = "fidelity"))