In [2]:
using LinearAlgebra
using Random
using Flux
using Distributions
using OneHotArrays

using DataFrames
using CSV
using BSON: @save
using BSON: @load
using Plots
#ENV["GKSwstype"]="nul"

In [2]:
struct Env
    Ns::Int #siteの数
    num_var::Int  #operatorの数(ψ_iとψ'_i)
    num_br::Int #binary operatorの数。多分上記の +, -, *, -i[,], {,} の５つ？([,]などは*と+で表現できるが、のちに変数は２回以上使わないという制約を課したいので、変数を一回使うだけで交換関係を表現できるように導入しておく)1000の位に格納
    num_fn::Int #unitary operatorの数。 多分exp[], log[], diag[]の３つ。100の位に格納
    num_ter::Int
    num_tot::Int

    n_level::Int
    n_batch::Int

    op_fn
    op_br

    conv_ac
end

In [3]:
function set_fn()
    function daig_mat(M)
        e,v = eigen(M)
        return diagm(e)
    end
    return [x->exp.(x), x->log.(x), x->daig_mat.(x)]
end

function set_br()
    return [(x,y)->(x.+y), (x,y)->(x.-y), (x,y)->(x.*y)]
    #return [(x,y)->(x+y), (x,y)->(x-y), (x,y)->(x*y), (x,y)->-1.0im*(x*y .- y*x), (x,y)->(x*y .+ y*x)/2]
end

function init_Env(N::Int, b::Int)
    Ns = N #siteの数
    num_var = 2*Ns  #operatorの数(ψ_iとψ'_i)
    num_br = 3 #binary operatorの数。多分上記の +, -, *, -i[,], {,} の５つ？([,]などは*と+で表現できるが、のちに変数は２回以上使わないという制約を課したいので、変数を一回使うだけで交換関係を表現できるように導入しておく)1000の位に格納
    num_fn = 3 #unitary operatorの数。 多分exp[], log[], diag[]の３つ。100の位に格納
    num_ter = 1
    num_tot = num_var + num_br + num_fn

    n_level = 2
    n_batch = b

    op_fn = set_fn()
    op_br = set_br()

    conv_ac = zeros(Int, num_tot)
    for i in 1:num_var
        conv_ac[i] = i
    end
    for i in 1:num_br
        conv_ac[i+num_var] = 1000i
    end
    for i in 1:num_fn
        conv_ac[i+num_var+num_br] = 100i
    end

    return Ns, num_var, num_br, num_fn, num_ter, num_tot, n_level, n_batch, op_fn, op_br, conv_ac
end

init_Env (generic function with 1 method)

In [4]:
struct DQN 
    width
    act_MAX
    ϵ
    prob
    rand_ac
end

In [5]:
function init_DQN(w::Int, a_MAX::Int, ϵ::Float64, en::Env)
    prob = [ϵ, 1-ϵ]
    rand_ac = ones(Float64, en.num_tot)./en.num_tot
    return w, a_MAX, ϵ, prob, rand_ac
end

init_DQN (generic function with 1 method)

In [6]:
struct Sample
    var::Vector{Vector{Float64}}
    gauge_sample::Matrix{Float64}
end

In [7]:
function Gene_Rand_Var(en::Env)
    var::Vector{Vector{Float64}} = []
    θ = zeros(Float64,2)
    push!(var, θ)
    for i in 1:en.Ns-1
        θ += [pi*(1.0+rand(Float64))/2en.Ns , pi*(0.5-rand(Float64))/en.Ns]
        push!(var, θ)
    end
    return var
end

function get_Sample(en::Env)
    var = Gene_Rand_Var(en)
    gauge_sample = 2pi*rand(Float64, en.n_batch, en.Ns)
    return var, gauge_sample
end

get_Sample (generic function with 1 method)

In [10]:
mutable struct Agt
    #model
    state::Vector{Int}
    branch::Vector{Int}
    q_table::Matrix{Float32}
end

In [8]:
function init_agt(en::Env, dq::DQN)
    #model = Chain(Dense(dq.act_MAX, dq.width, relu, init=Flux.zeros32), Dense(dq.width, dq.width, relu, init=Flux.zeros32), Dense(dq.width, en.num_tot, init=Flux.zeros32))
    state::Vector{Int}=[]
    branch::Vector{Int}=[]
    q_table = zeros(Float32, dq.act_MAX, en.num_tot)
    #return model, state, branch, q_table
    return state, branch, q_table
end

function action_vec(q_t::Vector{Float32}, en::Env, dq::DQN)
    sw = rand(Categorical(dq.prob))
    if(sw == 1)
        act = rand(Categorical(dq.rand_ac))
    else
        act_n = findall(q_t .== maximum(q_t))
        n_a = length(act_n)
        ac_prob = ones(Float64, n_a)./n_a
        act = act_n[rand(Categorical(ac_prob))]
    end
    return onehot(Int, 1:en.num_tot, act)
end

action_vec (generic function with 1 method)

In [11]:
function decide_action!(en::Env, dq::DQN, ag::Agt, model, t::Int)
    rem_turn = zeros(Int, dq.act_MAX + 1 - t)
    st_vec = vcat(ag.state, rem_turn)
    #q_t = ag.model(st_vec)
    q_t = model(st_vec)
    act = en.conv_ac' * action_vec(q_t, en, dq) 
    return q_t, act
end

decide_action! (generic function with 1 method)

In [12]:
function rule_violate(ag::Agt, ac::Int)
    if(length(ag.state)>0)
        if(ac>99)
            if(ac>999) #branch 
                if(ac == ag.state[end])
                    #println("branch violation!")
                    return true
                elseif(ac<3 && ag.state[end]<3)
                    #println("branch violation!")
                    return true
                else
                    return false
                end
            else
                if((ac==1 && ag.state[end]==2) || (ac==2 && ag.state[end]==1))
                    #println("fn violation!")
                    return true
                elseif(ac== ag.state[end])
                    #println("fn violation!")
                    return true
                else
                    return false
                end
            end
        else
            if(length(findall(isequal(ac),ag.state))==0)
                return false
            else
                #println("reuse the same var!")
                return true
            end
        end
    else
        return false
    end
end

function VarToLoss(var::Vector{Matrix{ComplexF64}})
    loss = 0.0
    sw = (size(var[1])[1]==size(var[1])[2])
    if(sw)
        for i in 2:size(var)[1]
            for j in 1:i
                loss += abs(tr(var[i])-tr(var[j]))^2
            end
        end
    else
        for i in 2:size(var)[1]
            for j in 1:i
                loss += sum((abs.(var[i]-var[j])).^2)
            end
        end
    end
    return loss/size(var)[1]
end

function VarToLoss(var::Vector{ComplexF64})
    loss = 0.0
    for i in 2:size(var)[1]
        for j in 1:i
            loss += real((var[i]-var[j])'*(var[i]-var[j]))            
        end
    end
    return loss/size(var)[1]
end

function VarToLoss(var::Vector{Vector{ComplexF64}})
    loss = 0.0
    for i in 2:size(var)[1]
        for j in 1:i
            loss += real((var[i]-var[j])'*(var[i]-var[j]))            
        end
    end
    return loss/size(var)[1]
end

function VarToLoss(var::Vector{Adjoint{ComplexF64, Vector{ComplexF64}}})
    loss = 0.0
    for i in 2:size(var)[1]
        for j in 1:i
            loss += real((var[i]-var[j])*(var[i]-var[j])')            
        end
    end
    return loss/size(var)[1]
end

function wave_fn(var::Vector{Float64}, sw::Int)
    if(sw==1)
        wv_fn = ([cos(var[1]), sin(var[1])*exp(1.0im*var[2])])'
    else
        wv_fn = [cos(var[1]), sin(var[1])*exp(1.0im*var[2])]
    end
    return wv_fn
end

wave_fn (generic function with 1 method)

In [32]:
function Fn_Gauge(en::Env, sample::Sample, st::Vector{Int}, var_now, var_sub1, var_sub2, it::Int)
    if(it==0)
        return VarToLoss(var_now)
    else
        ac = st[it]
    end
    if(ac<100)
        #println("var")
        var_sub1 = var_now
        var_sub2 = var_sub1
        i_s = (ac-1)%en.Ns + 1
        sw = div(ac, en.Ns+1)
        var_now = [exp(1.0im*sample.gauge_sample[b,i_s])*wave_fn(sample.var[i_s], sw) for b in 1:en.n_batch]
    elseif(ac < 1000)
        #println("fn")
        #=
        if(typeof(var_now)==Vector{Matrix{ComplexF64}})
            var_now = en.op_fn[div(ac,100)](var_now)
        else
            return 100.0
        end=#
        try
            var_now = en.op_fn[div(ac,100)](var_now)
        catch
            return 100.0
        end
    else
        #println("br")
        #=
        if(div(ac,1000)<3 && typeof(var_now)!=typeof(var_sub1))
            return 100.0
        elseif(div(ac,1000)==3 && typeof(var_now)==typeof(var_sub1) && typeof(var_now)!=Vector{Matrix{ComplexF64}})
            return 100.0
        else
            var_now = en.op_br[div(ac,1000)](var_sub1, var_now)
            var_sub1 = nothing
            var_sub1 = var_sub2
        end=#
        try
            var_now = en.op_br[div(ac,1000)](var_sub1, var_now)
            var_sub1 = nothing
            var_sub1 = var_sub2
        catch
            return 100.0
        end
    end
    Fn_Gauge(en, sample, st, var_now, var_sub1, var_sub2, it-1)
end



function reward(en::Env, sample::Sample, ag::Agt)
    #var = Gene_Rand_Var()
    T = length(ag.state)
    #gauge_sample = 2pi*rand(Float64, n_batch, Ns)
    #st_copy = copy(ag.state)
    l = Fn_Gauge(en, sample, ag.state, nothing, nothing, nothing, T)
    return -l + 1.0
end

function q_update!(en::Env, ag::Agt, r::Float64)
    T = length(ag.state)
    q_max = Float32(r)
    for t in T:1
        ag.q_table[t, act_ind(ag.state[t],en)] = q_max
        q_max = maximum(ag.q_table[t,:])
    end
end

#function Search!(en::Env, dq::DQN, sample::Sample, ag::Agt)
function Search!(en::Env, dq::DQN, sample::Sample, ag::Agt, model)
    r = 0.0
    # = zeros(Float32, act_MAX)
    #q_table = []
    ag.state = []
    ag.branch = []
    for turn in 1:dq.act_MAX
        #println(turn)
        #ag.q_table[turn,:], act = decide_action!(en, dq, ag, turn)
        ag.q_table[turn,:], act = decide_action!(en, dq, ag,model, turn)
        if(act > 999) #actionでbinaryを選んだ場合、２つに分岐するので分岐点を覚えておくためにbranchに入れておく
            push!(ag.branch, act)
        elseif(act < 100)
            if(size(ag.branch)[1]==0)#残りのbranchがなければ関数形が完成しているので終了
                push!(ag.state, act)
                break;
            else #branchが残っていれば、下っ側を埋めていく
                b = pop!(ag.branch)
            end
        end
        if(rule_violate(ag, act)) #rule違反をしていたら、罰則(負の報酬)を与えて終了
            push!(ag.state, act)
            r = -100.0
            break;
        end

        push!(ag.state, act)
        if(turn == dq.act_MAX)
            r = -20.0
        end
        #r[turn] = reward(state, act)
    end
    if(r==0.0)
        r = reward(en, sample, ag)
    end
    q_update!(en, ag, r)
end

function act_ind(ac::Int, en::Env)
    id = 0
    if(ac<100)
        id = ac
    elseif(ac>999)
        id = en.num_var + ac%1000
    else
        id = en.num_var + en.num_br + ac%100
    end
    return id
end

function loss(dq::DQN,ag::Agt, model)
    T = length(ag.state)
    l=0.0
    for t in 1:T
        if(t==1)
            st_turn::Vector{Int} = []
        else
            st_turn = ag.state[1:t-1]
        end
        rem_turn = zeros(Int, dq.act_MAX - (t-1))
        st = vcat(st_turn, rem_turn) 
        #q = ag.model(st)
        q = model(st)
        l += sum((q - ag.q_table[t, :]).^2)
        #l += q'*q
    end
    return l
end

loss (generic function with 1 method)

In [14]:
function RandPolitics(en::Env, dq::DQN, sample::Sample, ag::Agt, model)
    ag.state = []
    ag.branch = []
    #branch::Vector{Int} = []
    r = 0.0
    # = zeros(Float32, act_MAX)
    #q_table = []
    for turn in 1:dq.act_MAX
        #ag.q_table[turn,:], act = decide_action!(en, dq, ag, turn)
        ag.q_table[turn,:], act = decide_action!(en, dq, ag,model, turn)
        if(act > 999) #actionでbinaryを選んだ場合、２つに分岐するので分岐点を覚えておくためにbranchに入れておく
            push!(ag.branch, act)
        elseif(act < 100)
            if(size(ag.branch)[1]==0)#残りのbranchがなければ関数形が完成しているので終了
                push!(ag.state, act)
                break;
            else #branchが残っていれば、下っ側を埋めていく
                b = pop!(ag.branch)
            end
        end
        push!(ag.state, act)
        if(rule_violate(ag, act)) #rule違反をしていたら、罰則(負の報酬)を与えて終了
            r = -100.0
            break;
        end
        #r[turn] = reward(state, act)
    end
    if(r == 0.0)
        r = reward(en, sample, ag)
    end
    
    return r, ag.state
end

RandPolitics (generic function with 1 method)

In [34]:
#N_s, n_batch
en = Env(init_Env(3, 10)...)
#width, act_MAX, ϵ
dq = DQN(init_DQN(16, 15, 0.05, en)...)
ag = Agt(init_agt(en, dq)...)

Agt(Int64[], Int64[], Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0])

In [16]:
sample = Sample(get_Sample(en)...)

Sample([[0.0, 0.0], [0.7084727885637342, 0.2758179537758741], [1.668841825029368, -0.06482254457532582]], [0.0005941756221670947 3.89609678002478 1.8862549033122866; 4.076235115298894 2.690692944529253 3.464402780219548; … ; 1.1498116596137142 4.51334490713543 3.4244066778769096; 6.243135271965848 0.37355108415637667 1.7097892078910677])

In [36]:
model = Chain(Dense(dq.act_MAX, dq.width, relu), Dense(dq.width, dq.width, relu), Dense(dq.width, dq.width, relu), Dense(dq.width, dq.width, relu), Dense(dq.width, dq.width, relu), Dense(dq.width, en.num_tot, init=Flux.zeros32, relu))

Chain(
  Dense(15 => 16, relu),                [90m# 256 parameters[39m
  Dense(16 => 16, relu),                [90m# 272 parameters[39m
  Dense(16 => 16, relu),                [90m# 272 parameters[39m
  Dense(16 => 16, relu),                [90m# 272 parameters[39m
  Dense(16 => 16, relu),                [90m# 272 parameters[39m
  Dense(16 => 12, relu),                [90m# 204 parameters[39m[36m  (all zero)[39m
) [90m                  # Total: 12 arrays, [39m1_548 parameters, 6.797 KiB.

In [38]:
ll_MAX = 20
ll_it = zeros(Float64, ll_MAX)
for it in 1:ll_MAX
    #var = Gene_Rand_Var()
    #gauge_sample = 2pi*rand(Float64, n_batch, Ns)
    sample = Sample(get_Sample(en)...)
    #Search!(en, dq, sample, ag)
    Search!(en, dq, sample, ag, model)
    #println(ag.state)
    #ll = 0.0
    #grads = Flux.gradient(Flux.params(ag.model)) do
    grads = Flux.gradient(Flux.params(model)) do
        #l = loss(dq,ag)
        loss(dq,ag, model)/length(ag.state)
    end
    ll_it[it] = loss(dq,ag, model)
    print("loss:$(ll_it[it]), ")
    println(ag.state)
    #Flux.Optimise.update!(ADAM(), Flux.params(ag.model), grads)
    Flux.Optimise.update!(ADAM(), Flux.params(model), grads)
end

loss:78.21925354003906, [3]
loss:0.0, [1000, 5, 6]
loss:0.0, [1000, 300, 200, 6, 2000, 200, 1, 4]
loss:0.0, [2000, 1, 1]
loss:0.0, [300, 200, 5]
loss:0.0, [3000, 3, 200, 3000, 1000, 4, 1, 100, 6]
loss:61.607173919677734, [6]
loss:64.69889068603516, [1]
loss:0.0, [100, 1000, 5, 1000, 300, 1000, 200, 4, 4]
loss:75.24616241455078, [2]
loss:0.0, [2000, 200, 5, 200, 2000, 300, 2, 300, 1]
loss:34.64826583862305, [1]
loss:0.0, [1000, 6, 1]
loss:0.0, [1000, 4, 300, 3000, 2000, 2000]
loss:61.39663314819336, [6]
loss:0.0, [3000, 100, 4, 6]
loss:79.0329360961914, [1]
loss:0.0, [200, 4]
loss:0.0, [300, 3000, 200, 1, 100, 2]
loss:0.0, [200, 300, 4]


In [27]:
is_test = (2-1)%en.Ns + 1
sw_test = div(2, en.Ns+1)
var_now = [exp((2sw_test-1)*1.0im*sample.gauge_sample[b,is_test])*wave_fn(sample.var[is_test], sw_test) for b in 1:en.n_batch]

1-element Vector{Vector{ComplexF64}}:
 [-0.5532756888199392 + 0.5201040998356157im, -0.5775395376605479 + 0.2997095592967405im]

In [30]:
is_test1 = (5-1)%en.Ns + 1
sw_test1 = div(5, en.Ns+1)
var_now1 = [exp((2sw_test1-1)*1.0im*sample.gauge_sample[b,is_test1])*wave_fn(sample.var[is_test1], sw_test1) for b in 1:en.n_batch]

1-element Vector{Adjoint{ComplexF64, Vector{ComplexF64}}}:
 [-0.5532756888199392 - 0.5201040998356157im -0.5775395376605479 - 0.2997095592967405im]

In [31]:
var_now1[1]*var_now[1]

1.0 - 1.8924777989989362e-17im

In [24]:
w1 = wave_fn(sample.var[2], 0)
w2 = wave_fn(sample.var[2], 1)
w2*w1

1.0 - 4.9765067934427186e-18im

In [16]:
Search!(en, dq, sample, ag)

1
1
1


In [19]:
println(ag.state)
println(ag.branch)
println(ag.q_table)

[1]
Int64[]
Float32[-21.252821 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.

In [24]:
Search!(en, dq, sample, ag)

1
0
20
20
1
1


In [25]:
ll = loss(dq,ag)

451.6824035644531

In [32]:
function loss(dq::DQN,ag::Agt)
    T = length(ag.state)
    l=0.0
    for t in 1:T
        if(t==1)
            st_turn::Vector{Int} = []
        else
            st_turn = ag.state[1:t-1]
        end
        rem_turn = zeros(Int, dq.act_MAX - (t-1))
        st = vcat(st_turn, rem_turn) 
        q = ag.model(st)
        #l += sum((q - ag.q_table[t, :]).^2)
        l += q'*q
    end
    return l
end

loss (generic function with 1 method)

In [33]:
grads = Flux.gradient(Flux.params(ag.model)) do
    ll = loss(dq,ag)
end

Grads(...)

In [31]:
st_turn = []
rem_turn = zeros(Int, dq.act_MAX)
st = vcat(st_turn, rem_turn)

20-element Vector{Any}:
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0

In [34]:
test1 = [1.0, 2.0, 0.0, 2.0, 2.0]
argmax(test1)

2

In [35]:
findall(test1 .== maximum(test1))

3-element Vector{Int64}:
 2
 4
 5

In [36]:
1000/1000

1.0

In [1]:
@load "./mymodel.bson" model
@show model

LoadError: LoadError: UndefVarError: @load not defined
in expression starting at /Users/michishita/Documents/Codes/julia/MCTS_GI/e-greedy_test.ipynb:1

In [14]:
@load "./mymodel.bson" model
@show model
test = [0, 0, 0, 0]
A = model(test)

model = Chain(Dense(4 => 8), Dense(8 => 8), Dense(8 => 8), Dense(8 => 5))


5-element Vector{Float32}:
 -8.183295
 -7.7837644
 -6.1905055
  4.3143835
 -3.0684485

In [15]:
test = [4, 0, 0, 0]
A = model(test)

5-element Vector{Float32}:
 -16.232706
 -15.39147
 -13.146998
   9.454458
  -5.539211

In [10]:
test = [5, 10, 0, 0, 0]
A = model(test)

5-element Vector{Float32}:
  -6.812312
  -6.8972197
 -10.720585
 -11.248022
 -12.042385

In [126]:
softmax(A)

5-element Vector{Float32}:
 0.19777422
 0.8005499
 0.00073171436
 0.0004198221
 0.0005244279

In [81]:
test = []
test[end]

LoadError: BoundsError: attempt to access 0-element Vector{Any} at index [0]

In [64]:
model = Chain(Dense(dq.act_MAX, dq.width, relu), Dense(dq.width, dq.width, relu), Dense(dq.width, dq.width, relu), Dense(dq.width, dq.width, relu), Dense(dq.width, dq.width, relu), Dense(dq.width, en.num_tot, init=Flux.ones32))

Chain(
  Dense(15 => 16, relu),                [90m# 256 parameters[39m
  Dense(16 => 16, relu),                [90m# 272 parameters[39m
  Dense(16 => 16, relu),                [90m# 272 parameters[39m
  Dense(16 => 16, relu),                [90m# 272 parameters[39m
  Dense(16 => 16, relu),                [90m# 272 parameters[39m
  Dense(16 => 12),                      [90m# 204 parameters[39m
) [90m                  # Total: 12 arrays, [39m1_548 parameters, 6.797 KiB.

In [58]:
test = 1000*ones(Float32, 15)
a=model(test)
println(a)
println(softmax(a))

Float32[-4.133523, -15.241394, -18.704393, -7.3172717, 23.11225, 23.308975, 32.77651, 3.095732, 4.5304666, 3.2952948, 23.721226, -10.427928]
Float32[9.333953f-17, 1.3995167f-21, 4.3854665f-23, 3.867056f-18, 6.3497086f-5, 7.7302f-5, 0.99974245, 1.2873322f-13, 5.404903f-13, 1.5716657f-13, 0.00011674247, 1.7236118f-19]


In [43]:
test = []
push!(test, 1)
push!(test, 2)
push!(test, 3)

3-element Vector{Any}:
 1
 2
 3

In [48]:
pop!(test)

1

In [49]:
test

Any[]

In [50]:
size(test)[1]

0

In [1]:
maximum([1.0, 5.0, 2.0])

5.0

In [32]:
include("DQN_MC.jl")

main (generic function with 1 method)

In [2]:
#N_s, n_batch, penalty
en = Env(init_Env(1, 10, 100.0)...)
#width, act_MAX, ϵ
dq = DQN(init_DQN(16, 5, 0.005, en)...)
ag = Agt(init_agt(en, dq)...)

model = Chain(Dense(dq.act_MAX, dq.width, relu), Dense(dq.width, dq.width, relu), Dense(dq.width, dq.width, relu), Dense(dq.width, en.num_tot))
#model = Chain(Dense(dq.act_MAX, dq.width), Dense(dq.width, dq.width) , Dense(dq.width, dq.width), Dense(dq.width, en.num_tot))
#, Dense(dq.width, dq.width, relu) , Dense(dq.width, dq.width, relu),

η = 0.005

0.005

In [20]:
sample = Sample(get_Sample(en)...)
Search!(en, dq, sample, ag, model)

In [21]:
println(ag.state)

[1, 3, 10, 20, 2]


In [34]:
@show ag.q_table

ag.q_table = Float32[0.010017703 -0.0049412083 1.7958635 -0.0033347472 0.0026585727; -0.008794929 0.088410944 -0.12989365 0.0016303277 1.7958635; 1.7958635 0.046672914 0.13063908 0.058654938 -0.097600676; 1.029346 1.7958635 -0.11251773 -0.25043476 -0.16263339; 0.15029517 0.22413763 1.7958635 -99.0 -1.7459356]


5×5 Matrix{Float32}:
  0.0100177   -0.00494121   1.79586    -0.00333475   0.00265857
 -0.00879493   0.0884109   -0.129894    0.00163033   1.79586
  1.79586      0.0466729    0.130639    0.0586549   -0.0976007
  1.02935      1.79586     -0.112518   -0.250435    -0.162633
  0.150295     0.224138     1.79586   -99.0         -1.74594

In [26]:
reward(en, sample, ag)

-99.0

In [33]:
q_update!(en, ag, -99.0)

-99.0
-99.0
1.7958635
1.7958635
1.7958635
1.7958635


In [35]:
grads = Flux.gradient(Flux.params(model)) do
    #l = loss(dq,ag)
    loss(dq,ag, model)/length(ag.state)
end
l = loss(dq,ag, model)

10035.52076625824

In [36]:
Flux.Optimise.update!(ADAM(η), Flux.params(model), grads)

In [15]:
best_score = -100.0
best_state::Vector{Int} = []
n_replay = 0
n_up = 0

0

In [37]:
best_score

-100.0

In [38]:
reward(en, sample, ag)

-99.0

In [16]:
(best_score > reward(en, sample, ag))

false

In [17]:
if(best_score > reward(en, sample, ag))
    n_replay += 1
    ag.state = best_state
    grads = Flux.gradient(Flux.params(model)) do
        loss(dq,ag, model)/length(ag.state)
    end
    Flux.Optimise.update!(ADAM(η), Flux.params(model), grads)
else
    n_up +=1
    best_state = ag.state
    best_score = reward(en, sample, ag)
end

1-element Vector{Int64}:
 20

In [40]:
println("replay: $(n_replay)")
println("noup: $(n_up)")
println(best_score)
println(best_state)

replay: 0
noup: 1
-100.0
[20]
