In [1]:
include("INCLUDEME.jl")

using Yao, Circuit, UnicodePlots, GradOptim, Utils, ArgParse
import Kernels

In [2]:
"""
    train!(qcbm::QCBM, ptrain, optim; learning_rate=0.1, maxiter=100,
           kernel_sigmas=[0.25], verbose=true) -> Vector{Float64}

Train `qcbm` in place towards the target distribution `ptrain` and return the
loss recorded at every step.

`qcbm` is re-initialized with `initialize!` first, so any previously trained
parameters are discarded. Each of the `maxiter` steps:

1. evaluates `gradient` and `loss` of `qcbm` against `ptrain` using a kernel
   built by `Kernels.RBFKernel`;
2. appends the loss to the history (and prints it when `verbose`);
3. applies `update!(params, grad, optim)` to the current `parameters(qcbm)` and
   writes them back into the circuit with `dispatch!`.

The loss is recorded *before* that step's parameter update, so the returned
history does not contain the loss of the final parameters.

# Keywords
- `learning_rate`: currently unused. The step size is whatever `optim` was
  constructed with (e.g. `Adam(lr=0.1)`). Kept only for backward compatibility.
- `maxiter`: number of optimization steps, which is also the length of the
  returned history.
- `kernel_sigmas`: second argument passed to `Kernels.RBFKernel`, presumably
  the RBF bandwidth(s) — confirm against Kernels. Defaults to the previously
  hard-coded `[0.25]`.
- `verbose`: when `true` (the default, matching earlier behavior), print
  `"<i> step, loss = <loss>"` after each step.
"""
function train!(qcbm::QCBM, ptrain, optim; learning_rate=0.1, maxiter=100,
                kernel_sigmas=[0.25], verbose::Bool=true)
    initialize!(qcbm)
    # NOTE(review): the trailing `false` is passed through unchanged; its meaning is
    # defined inside Kernels.RBFKernel — confirm before exposing it as an option.
    kernel = Kernels.RBFKernel(nqubits(qcbm), kernel_sigmas, false)
    history = Float64[]

    for i in 1:maxiter
        grad = gradient(qcbm, kernel, ptrain)
        curr_loss = loss(qcbm, kernel, ptrain)
        push!(history, curr_loss)
        verbose && println(i, " step, loss = ", curr_loss)

        # Warn: we need a primitive block to enable BLAS here.
        params = parameters(qcbm)
        update!(params, grad, optim)   # optimizer step; presumably mutates `params` in place
        dispatch!(qcbm, params)        # push the updated parameters back into the circuit
    end
    return history
end

train! (generic function with 1 method)

In [14]:
# Data to learn: the target distribution over the basis-state indices 0:2^n-1.
# `n` is defined here so the notebook also runs top-to-bottom in a fresh kernel
# (previously this cell relied on `n` from a later/earlier-executed cell).
# Keep it in sync with `n` in the parameters cell below.
n = 6
# NOTE(review): gaussian_pdf's 2nd/3rd arguments are presumably mean and width — confirm.
# They are derived from `n` (centre of the index range, width 2^(n-2)); for n = 6
# they evaluate to 31.5 and 16, exactly the previously hard-coded values.
pg = gaussian_pdf(n, 2^(n - 1) - 0.5, 2^(n - 2))
fig = lineplot(0:(1 << n) - 1, pg)
display(fig)

[37m        ┌────────────────────────────────────────┐[39m 
   [37m0.03[39m[37m │[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m│[39m [37m[39m
       [37m[39m[37m │[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[34m⢀[39m[34m⣀[39m[34m⡀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39

In [15]:
# --- problem size & training budget ---
n, maxiter = 6, 20

# --- model and optimizer ---
# NOTE(review): `10` is QCBM's second type parameter (presumably the circuit depth) and
# get_nn_pairs(n) presumably yields nearest-neighbour qubit pairs — confirm in Circuit.
qcbm = QCBM{n, 10}(get_nn_pairs(n))
optim = Adam(lr = 0.1);

In [17]:
# Fail fast if `pg` is stale, e.g. `n` was changed in the parameters cell without
# re-running the data cell: the target must have one entry per basis state of `qcbm`.
length(pg) == 1 << nqubits(qcbm) ||
    throw(DimensionMismatch("pg has $(length(pg)) entries but the QCBM has " *
                            "$(1 << nqubits(qcbm)) basis states; re-run the data cell"))
his = train!(qcbm, pg, optim; maxiter = maxiter);

1 step, loss = 0.020212867842247333
2 step, loss = 0.011890823716518104
3 step, loss = 0.00693257727961275
4 step, loss = 0.005148464742214652
5 step, loss = 0.005227194305435318
6 step, loss = 0.005237624153717067
7 step, loss = 0.004439455400828236
8 step, loss = 0.003827060128211739
9 step, loss = 0.0030615670889751886
10 step, loss = 0.002004315982445699
11 step, loss = 0.0015847366705968842
12 step, loss = 0.0015895464249945789
13 step, loss = 0.001633358101842722
14 step, loss = 0.0016053038478056228
15 step, loss = 0.001254005773569636
16 step, loss = 0.0008001349776185522
17 step, loss = 0.0009168050479497432
18 step, loss = 0.001076649805595626
19 step, loss = 0.000825315115400745
20 step, loss = 0.0005386157959737869


In [18]:
# Analyze the result: loss curve, then the trained distribution overlaid on the target.
display(lineplot(his, title = "loss"))

psi = qcbm()
p = abs2.(statevec(psi))   # |amplitude|^2 of each basis state

# `fig` was drawn against 0-based indices 0:2^n-1. Pass the same x explicitly:
# `lineplot!(fig, p)` alone plots against 1:length(p), shifting the curve by one.
# Note: re-running this cell adds another "trained" series to `fig`.
lineplot!(fig, 0:length(p) - 1, p, color=:yellow, name="trained")
display(fig)

[37m                          loss
[39m[37m        ┌────────────────────────────────────────┐[39m 
   [37m0.03[39m[37m │[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m│[39m [37m[39m
       [37m[39m[37m │[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[

[37m        ┌────────────────────────────────────────┐[39m        
   [37m0.03[39m[37m │[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[33m⢀[39m[33m⢸[39m[33m⡀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m│[39m [33mtrained[39m
       [37m[39m[37m │[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[33m⢸[39m[33m⢸[39m[33m⢱[39m[32m⣾[39m[32m⣸[39m[32m⡇[39m[33m⢠[39m[33m⡆[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀[39m[37m⠀