# Experiment: Comparing Neal sampler parameters (`num_reads`, `num_sweeps`) for RBM training

In [1]:
using Revise, QARBoM, DWave, DataFrames, CSV

In [2]:
# Short alias for MathOptInterface, reached through QARBoM's QUBO/ToQUBO dependency;
# later cells use it for `MOI.set` / `MOI.supports`.
MOI = QARBoM.QUBO.ToQUBO.MOI

MathOptInterface

In [3]:
# Load the dataset (the file name suggests 0/1-encoded features; confirm) and turn every
# row into an integer feature vector. The typed comprehension `Vector{Int}[...]`
# converts each collected row to `Vector{Int}`.
df = DataFrame(CSV.File(raw"./converted_bool_only.csv"))

x_train = Vector{Int}[collect(row) for row in eachrow(df)]

In [4]:
# Constant learning rate, repeated for every epoch of every sweep configuration.
learning_rate = 2e-4

0.0002

In [5]:
# Grid of values tried for the Neal sampler's `num_sweeps` and `num_reads`
# settings: 1, 26, 51, 76, 101 (two independent vectors).
num_sweeps = collect(1:25:101)
num_reads = collect(1:25:101)

5-element Vector{Int64}:
   1
  26
  51
  76
 101

In [31]:
using Random

# Initial RBM weights (22 × 10), shared by every sweep configuration. The explicit,
# seeded RNG makes the starting point, and therefore the comparison, reproducible
# across notebook runs.
W = randn(MersenneTwister(1234), 22, 10)

22×10 Matrix{Float64}:
 -0.216505   -0.765594    1.28173   …   0.45392    0.38484     -0.615173
  0.345152   -1.12937     0.541348     -0.18382   -0.672884    -1.71366
 -0.439022   -0.761079   -1.31743      -0.475569  -0.0731931   -0.314445
  1.39857    -0.513349   -0.863664      0.383678  -0.391906    -0.795293
 -0.713223   -0.868238   -0.24816      -0.212508   1.27777      0.156998
  1.05306     0.470819   -2.15005   …  -0.774081   1.53349     -0.113756
  0.425849    0.765588    0.024888      0.574296  -0.330341     0.669741
 -1.1631     -1.27941     0.492526     -0.592812  -0.00362266  -1.21985
  0.974727   -0.835029    0.545166      1.31181    0.229642     1.61809
  1.37342     1.22238    -1.06174      -0.385003   0.271793     1.03826
  0.463458    0.482911   -0.431246  …  -0.161528   1.72504      0.864121
  0.0699086  -0.0543774  -0.493776     -0.959612   0.666959    -0.217816
 -0.38154     0.628583    0.805866     -0.261545   0.935392     0.336466
 -0.823389    1.23848    -1.1350

In [32]:
# Accumulators filled by the training sweep: one MSE history per configuration
# (element type left open) and one wall-clock duration in seconds per configuration.
all_mse = Any[]
times = Float64[]

Float64[]

In [33]:
# Declare that the Neal optimizer accepts `MOI.ObjectiveSense`.
# NOTE(review): this is type piracy. This notebook owns neither `MOI.supports` nor
# `DWave.Neal.Optimizer`. It is presumably a workaround because the QUBO model setup
# sets an objective sense that would otherwise be rejected (TODO confirm); it belongs
# upstream in DWave.jl.
function MOI.supports(::DWave.Neal.Optimizer, ::MOI.ObjectiveSense)
    return true
end

In [34]:
# Sweep every (num_reads, num_sweeps) pair. Each pair trains a fresh RBM from the same
# initial weights `W`, then appends the returned per-epoch MSE history to `all_mse` and
# the wall-clock training time (seconds) to `times`. Results are pushed in
# reads-outer / sweeps-inner order.
n_epochs = 50      # single source of truth for epochs AND learning-rate schedule length
n_train = 10_000   # number of leading samples of `x_train` used for training

for nr in num_reads
    for ns in num_sweeps
        # Derive the layer sizes from `W` so they cannot drift from the weight matrix.
        # Pass a copy so every configuration starts from identical weights, even if
        # `RBM` keeps a reference to `W` and training updates it in place.
        # NOTE(review): unclear from here whether `RBM` already copies `W`.
        rbm = RBM(size(W, 1), size(W, 2), copy(W))

        # Passed to QARBoM as `model_setup`. It sets this configuration's Neal sampler
        # parameters on the model. `sampler` is part of the callback signature but is
        # not used here.
        function setup_dwave(model, sampler)
            MOI.set(model, MOI.RawOptimizerAttribute("num_reads"), nr)
            MOI.set(model, MOI.RawOptimizerAttribute("num_sweeps"), ns)
        end

        println()
        println("Reads $(nr) | Sweeps $(ns)")
        println()

        t = time()

        mse = QARBoM.train_persistent_qubo!(
            rbm,
            x_train[1:n_train];
            batch_size = 10,
            n_epochs = n_epochs,
            # Constant schedule: one learning-rate entry per epoch.
            learning_rate = fill(learning_rate, n_epochs),
            model_setup = setup_dwave,
            sampler = DWave.Neal.Optimizer,
        )

        push!(times, time() - t)
        push!(all_mse, mse)
    end
end


Reads 1 | Sweeps 1

Setting up QUBO model
Setting mini-batches
Starting training
|------------------------------------------------------------------------------|
| Epoch |    MSE    | Time (Sample) | Time (Qsamp) | Time (Update) | Total     |
|------------------------------------------------------------------------------|
|     1 |    7.2888 |        0.1820 |       4.1522 |        1.8667 |    6.2008 |
|------------------------------------------------------------------------------|
|------------------------------------------------------------------------------|
| Epoch |    MSE    | Time (Sample) | Time (Qsamp) | Time (Update) | Total     |
|------------------------------------------------------------------------------|
|     2 |    5.9785 |        0.0124 |       2.9002 |        1.2857 |   10.3992 |
|------------------------------------------------------------------------------|
|------------------------------------------------------------------------------|
| Epoch |    MSE    | Time 

In [35]:
# Collect the per-epoch MSE histories into one table with one column per configuration,
# named "r_<reads>_s_<sweeps>". Configurations are enumerated reads-outer /
# sweeps-inner, the same order in which the training loop pushed into `all_mse`.
# NOTE: this rebinds `df`, replacing the training-data table loaded at the top.
configs = [(nr, ns) for nr in num_reads for ns in num_sweeps]

# Guard against a stale `all_mse`. If the training cell was re-run without
# re-initialising it, the histories would silently be misattributed.
length(all_mse) == length(configs) || throw(DimensionMismatch(
    "expected $(length(configs)) MSE histories in `all_mse`, got $(length(all_mse)); " *
    "re-initialise `all_mse` before re-running the training sweep"))

df = DataFrame()
# Loop variables are local to the loop, so this works in scripts as well as in
# notebooks. No global counter is mutated.
for (k, (nr, ns)) in enumerate(configs)
    df[!, "r_$(nr)_s_$(ns)"] = all_mse[k]
end
df

Row,r_1_s_1,r_1_s_26,r_1_s_51,r_1_s_76,r_1_s_101,r_26_s_1,r_26_s_26,r_26_s_51,r_26_s_76,r_26_s_101,r_51_s_1,r_51_s_26,r_51_s_51,r_51_s_76,r_51_s_101,r_76_s_1,r_76_s_26,r_76_s_51,r_76_s_76,r_76_s_101,r_101_s_1,r_101_s_26,r_101_s_51,r_101_s_76,r_101_s_101
Unnamed: 0_level_1,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64
1,7.2888,7.03752,7.04563,7.0526,7.051,7.28749,7.02651,7.02826,7.02849,7.02877,7.29227,7.02728,7.02598,7.02738,7.02769,7.29338,7.02719,7.0259,7.02614,7.02718,7.29441,7.02713,7.02605,7.02606,7.02554
2,5.97852,5.38069,5.40335,5.41236,5.41446,5.97611,5.36003,5.3717,5.37437,5.37698,5.98251,5.36244,5.36717,5.36997,5.37365,5.98366,5.36401,5.36941,5.37004,5.37232,5.98948,5.36689,5.368,5.37075,5.36992
3,5.01713,4.18229,4.19616,4.20641,4.20466,5.04079,4.19114,4.1964,4.19737,4.20043,5.05121,4.19564,4.19615,4.2003,4.20232,5.05002,4.19984,4.19945,4.19984,4.19986,5.05452,4.20658,4.19883,4.20026,4.19917
4,4.34916,3.45759,3.46613,3.47573,3.47669,4.40346,3.52484,3.50925,3.50351,3.50204,4.40901,3.53636,3.52591,3.52062,3.5137,4.41277,3.5475,3.53046,3.52667,3.51753,4.41401,3.55613,3.53246,3.52888,3.52209
5,3.89132,3.04271,3.0525,3.06036,3.06259,3.96376,3.15107,3.11808,3.10214,3.09494,3.96616,3.19412,3.15934,3.13913,3.12258,3.97047,3.22123,3.17883,3.16164,3.14218,3.97093,3.24564,3.19361,3.17072,3.15474
6,3.58056,2.77788,2.79866,2.80493,2.80865,3.6537,2.91395,2.88438,2.86391,2.85635,3.65799,2.98808,2.93641,2.91687,2.89808,3.66306,3.03447,2.97035,2.94342,2.92746,3.66545,3.06926,2.99856,2.95983,2.94276
7,3.34215,2.60644,2.62475,2.6296,2.63482,3.43825,2.74593,2.70364,2.6859,2.67739,3.44545,2.83288,2.76327,2.74012,2.72468,3.45044,2.88592,2.80451,2.7674,2.75116,3.4539,2.92508,2.83117,2.78507,2.76683
8,3.17788,2.47644,2.49196,2.49811,2.5019,3.28142,2.61318,2.56049,2.5494,2.53901,3.29171,2.69338,2.61198,2.59035,2.58245,3.2967,2.74,2.64064,2.61049,2.59897,3.29986,2.77984,2.66323,2.62405,2.61098
9,3.0476,2.38706,2.39832,2.40384,2.40768,3.16965,2.49278,2.43645,2.42567,2.41728,3.17752,2.57605,2.48654,2.46294,2.45189,3.18262,2.61972,2.52015,2.48402,2.4676,3.18347,2.65102,2.54074,2.50095,2.47906
10,2.9628,2.3226,2.33238,2.33774,2.34212,3.0871,2.3862,2.33391,2.32348,2.32041,3.09463,2.45458,2.37565,2.35448,2.3445,3.0992,2.49482,2.4066,2.37618,2.35973,3.10141,2.52059,2.43008,2.39022,2.3727


In [36]:
# Persist the per-configuration MSE table; the plotting section below reloads this file
# so the figures can be regenerated without re-running the training sweep.
CSV.write("benchmark_dwave_parameters.csv", df)

"benchmark_dwave_parameters.csv"

In [37]:
using Plots

In [38]:
# Reload the saved sweep results (one MSE column per reads/sweeps configuration).
df_loaded = CSV.read("benchmark_dwave_parameters.csv", DataFrame)

Row,r_1_s_1,r_1_s_26,r_1_s_51,r_1_s_76,r_1_s_101,r_26_s_1,r_26_s_26,r_26_s_51,r_26_s_76,r_26_s_101,r_51_s_1,r_51_s_26,r_51_s_51,r_51_s_76,r_51_s_101,r_76_s_1,r_76_s_26,r_76_s_51,r_76_s_76,r_76_s_101,r_101_s_1,r_101_s_26,r_101_s_51,r_101_s_76,r_101_s_101
Unnamed: 0_level_1,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64
1,7.2888,7.03752,7.04563,7.0526,7.051,7.28749,7.02651,7.02826,7.02849,7.02877,7.29227,7.02728,7.02598,7.02738,7.02769,7.29338,7.02719,7.0259,7.02614,7.02718,7.29441,7.02713,7.02605,7.02606,7.02554
2,5.97852,5.38069,5.40335,5.41236,5.41446,5.97611,5.36003,5.3717,5.37437,5.37698,5.98251,5.36244,5.36717,5.36997,5.37365,5.98366,5.36401,5.36941,5.37004,5.37232,5.98948,5.36689,5.368,5.37075,5.36992
3,5.01713,4.18229,4.19616,4.20641,4.20466,5.04079,4.19114,4.1964,4.19737,4.20043,5.05121,4.19564,4.19615,4.2003,4.20232,5.05002,4.19984,4.19945,4.19984,4.19986,5.05452,4.20658,4.19883,4.20026,4.19917
4,4.34916,3.45759,3.46613,3.47573,3.47669,4.40346,3.52484,3.50925,3.50351,3.50204,4.40901,3.53636,3.52591,3.52062,3.5137,4.41277,3.5475,3.53046,3.52667,3.51753,4.41401,3.55613,3.53246,3.52888,3.52209
5,3.89132,3.04271,3.0525,3.06036,3.06259,3.96376,3.15107,3.11808,3.10214,3.09494,3.96616,3.19412,3.15934,3.13913,3.12258,3.97047,3.22123,3.17883,3.16164,3.14218,3.97093,3.24564,3.19361,3.17072,3.15474
6,3.58056,2.77788,2.79866,2.80493,2.80865,3.6537,2.91395,2.88438,2.86391,2.85635,3.65799,2.98808,2.93641,2.91687,2.89808,3.66306,3.03447,2.97035,2.94342,2.92746,3.66545,3.06926,2.99856,2.95983,2.94276
7,3.34215,2.60644,2.62475,2.6296,2.63482,3.43825,2.74593,2.70364,2.6859,2.67739,3.44545,2.83288,2.76327,2.74012,2.72468,3.45044,2.88592,2.80451,2.7674,2.75116,3.4539,2.92508,2.83117,2.78507,2.76683
8,3.17788,2.47644,2.49196,2.49811,2.5019,3.28142,2.61318,2.56049,2.5494,2.53901,3.29171,2.69338,2.61198,2.59035,2.58245,3.2967,2.74,2.64064,2.61049,2.59897,3.29986,2.77984,2.66323,2.62405,2.61098
9,3.0476,2.38706,2.39832,2.40384,2.40768,3.16965,2.49278,2.43645,2.42567,2.41728,3.17752,2.57605,2.48654,2.46294,2.45189,3.18262,2.61972,2.52015,2.48402,2.4676,3.18347,2.65102,2.54074,2.50095,2.47906
10,2.9628,2.3226,2.33238,2.33774,2.34212,3.0871,2.3862,2.33391,2.32348,2.32041,3.09463,2.45458,2.37565,2.35448,2.3445,3.0992,2.49482,2.4066,2.37618,2.35973,3.10141,2.52059,2.43008,2.39022,2.3727


In [79]:
# Spot-check: first-row (first-epoch) MSE of the reads = 1 / sweeps = 1 configuration.
df_loaded[1, :r_1_s_1]

7.28880148213567

In [83]:
# Select the configurations to hide from the plot: those whose final-epoch MSE (last
# row, rather than a hard-coded row number) exceeds `mse_cutoff`. The result is a
# Vector{String} of column names.
mse_cutoff = 1.95
keys_to_remove = [key for key in names(df_loaded) if df_loaded[end, key] > mse_cutoff]

In [84]:
# Drop the poorly converging configurations. The non-mutating `select` leaves
# `df_loaded` intact, so this cell can be re-run without failing on columns that a
# previous run already removed.
df_to_plot = select(df_loaded, Not(keys_to_remove))

Row,r_26_s_26,r_26_s_51,r_26_s_76,r_26_s_101,r_51_s_76,r_51_s_101,r_76_s_76,r_76_s_101,r_101_s_101
Unnamed: 0_level_1,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64
1,7.02651,7.02826,7.02849,7.02877,7.02738,7.02769,7.02614,7.02718,7.02554
2,5.36003,5.3717,5.37437,5.37698,5.36997,5.37365,5.37004,5.37232,5.36992
3,4.19114,4.1964,4.19737,4.20043,4.2003,4.20232,4.19984,4.19986,4.19917
4,3.52484,3.50925,3.50351,3.50204,3.52062,3.5137,3.52667,3.51753,3.52209
5,3.15107,3.11808,3.10214,3.09494,3.13913,3.12258,3.16164,3.14218,3.15474
6,2.91395,2.88438,2.86391,2.85635,2.91687,2.89808,2.94342,2.92746,2.94276
7,2.74593,2.70364,2.6859,2.67739,2.74012,2.72468,2.7674,2.75116,2.76683
8,2.61318,2.56049,2.5494,2.53901,2.59035,2.58245,2.61049,2.59897,2.61098
9,2.49278,2.43645,2.42567,2.41728,2.46294,2.45189,2.48402,2.4676,2.47906
10,2.3862,2.33391,2.32348,2.32041,2.35448,2.3445,2.37618,2.35973,2.3727


In [92]:
# Apply the colour theme *before* creating the figure. `theme` changes the default
# attributes used for plots created afterwards, so calling it after `plot` leaves the
# main figure unstyled.
theme(:wong)

# One line per remaining (reads, sweeps) configuration. `permutedims` turns the column
# names into a 1×n row, so each series gets its own label.
plot(
    Matrix(df_to_plot);
    labels = permutedims(names(df_to_plot)),
    linewidth = 2,
    bottom_margin = 0Plots.mm,   # was misspelled `botton_margin` and silently ignored
    right_margin = 3Plots.mm,
)

# Inset (subplot 2) zooming into the last two epochs: x ∈ [49, 50], MSE ∈ [1.88, 1.95].
lens!([49, 50], [1.88, 1.95], inset = (1, bbox(0.5, 0.0, 0.4, 0.4)), subplot = 2)

xlabel!("Epoch", subplot = 1)
ylabel!("Mean Squared Error", subplot = 1)

plot!(legend = :outerbottom, legendcolumns = 3)

savefig("errors.pdf")

"/Users/pripper/Documents/GitHub/RBM/example/Heart/errors.pdf"