## 探索空間の定義
まず探索空間を定義する。agentにどのような演算(action)を許すかを決めておく。

ここで、演算の種類として、
1. binary operator:  そこから２つに枝分かれするもの。(eg. +,-,*,/,pow、[,])
2. unitary operator: 関数の中に打ち込むもの。(eg. exp[], log[], diag[],)
3. containing variable: 変数(M(matrix), θ, π)。終端になる。
4. terminal operator: 出発点(最後に行う演算)にしかおけないoperator (Tr[], minimum(eigen())(行列の固有値の最小値を参照)など )

を準備しておく。これによって、木の形でゲージ不変量の表式を表現することが出来る。

In [1]:
using LinearAlgebra
using Random
using Flux
using Distributions
using OneHotArrays

In [None]:
#=struct SearchField
    num_var::Int #変数の数。ここではサイトごとの波動関数を演算に使うはずなので、2*サイト数で良いはず
    num_bin::Int #binary operatorの数。多分上記の +, -, *, -i[,], {,} の５つ？([,]などは*と+で表現できるが、のちに変数は２回以上使わないという制約を課したいので、変数を一回使うだけで交換関係を表現できるように導入しておく)
    num_uni::Int #unitary operatorの数。 多分exp[], log[], diag[]の３つ。
    num_ter::Int #terminal operatorの数。最終的に複素数に落とすためのもの。Tr[]とminimum(eigen())とか？今回はTr[]に限定。
    num_tot::Int
    
    operation
end=#


In [2]:
Ns = 3 #siteの数
num_var = 2*Ns  #operatorの数(ψ_iとψ'_i)
num_br = 3 #binary operatorの数。多分上記の +, -, *, -i[,], {,} の５つ？([,]などは*と+で表現できるが、のちに変数は２回以上使わないという制約を課したいので、変数を一回使うだけで交換関係を表現できるように導入しておく)1000の位に格納
num_fn = 3 #unitary operatorの数。 多分exp[], log[], diag[]の３つ。100の位に格納
num_ter = 1
num_tot = num_var + num_br + num_fn

12

後々のために、fnについて数字を指定した時にexp, log, diagを返すような関数のベクトルを用意しておく

In [3]:
function set_fn()
    function daig_mat(M)
        e,v = eigen(M)
        return diagm(e)
    end
    return [x->exp(x), x->log(x), x->daig_mat(x)]
end

set_fn (generic function with 1 method)

In [4]:
fn_act = set_fn()

3-element Vector{Function}:
 #1 (generic function with 1 method)
 #2 (generic function with 1 method)
 #3 (generic function with 1 method)

In [21]:
println(fn_act[1](0.0))
println(fn_act[1](1.0)) 
println(fn_act[3]([1.0 2.0; 2.0 1.0]))

1.0
2.718281828459045
[-1.0 0.0; 0.0 3.0]


In [5]:
function set_br()
    return [(x,y)->(x+y), (x,y)->(x-y), (x,y)->(x*y)]
    #return [(x,y)->(x+y), (x,y)->(x-y), (x,y)->(x*y), (x,y)->-1.0im*(x*y .- y*x), (x,y)->(x*y .+ y*x)/2]
end

set_br (generic function with 1 method)

In [6]:
op_br = set_br()

3-element Vector{Function}:
 #8 (generic function with 1 method)
 #9 (generic function with 1 method)
 #10 (generic function with 1 method)

In [21]:
A = [1.0, 2.0]
B = [0.0, 1.0]'

1×2 adjoint(::Vector{Float64}) with eltype Float64:
 0.0  1.0

In [27]:
op_br[3](B, A)

2.0

In [7]:
#NNが出力したQ値からActionを決める時Vectorを行動を示す整数に変換する
conv_ac = zeros(Int, num_tot)
for i in 1:num_var
    conv_ac[i] = i
end
for i in 1:num_br
    conv_ac[i+num_var] = 1000i
end
for i in 1:num_fn
    conv_ac[i+num_var+num_br] = 100i
end

Deep-Q networkのための設定と、ε-greedy法を使うための確率関数の準備

In [8]:
width = 64
act_MAX = 40
ϵ = 0.05
prob = [ϵ, 1-ϵ]
rand_ac = ones(Float64, num_tot)./num_tot


12-element Vector{Float64}:
 0.08333333333333333
 0.08333333333333333
 0.08333333333333333
 0.08333333333333333
 0.08333333333333333
 0.08333333333333333
 0.08333333333333333
 0.08333333333333333
 0.08333333333333333
 0.08333333333333333
 0.08333333333333333
 0.08333333333333333

In [18]:
for i in 1:10
    println(rand(Categorical(rand_ac)))
end

14
6
4
13
9
10
12
1
6
4


In [31]:
model = Chain(Dense(act_MAX, width, relu), Dense(width, width, relu), Dense(width, num_tot))

Chain(
  Dense(40 => 64, relu),                [90m# 2_624 parameters[39m
  Dense(64 => 64, relu),                [90m# 4_160 parameters[39m
  Dense(64 => 12),                      [90m# 780 parameters[39m
) [90m                  # Total: 6 arrays, [39m7_564 parameters, 29.922 KiB.

In [32]:
function action_vec(q_t::Vector{Float32})
    sw = rand(Categorical(prob))
    if(sw == 1)
        act = rand(Categorical(rand_ac))
    else
        act = argmax(q_t)
    end
    return onehot(Int, 1:num_tot, act)
end


action_vec (generic function with 2 methods)

In [30]:
function decide_action!(mm, state::Vector{Int}, t::Int)
    rem_turn = zeros(Int, act_MAX + 1 - t)
    st_vec = vcat(state, rem_turn)
    q_t = mm(st_vec)
    act = conv_ac' * action_vec(q_t) 
    return q_t, act
end

decide_action! (generic function with 1 method)

In [11]:
function rule_violate(st::Vector{Int}, ac::Int)
    if(ac>99)
        if(ac>999)
            if(ac == st[end])
                return true
            elseif(ac<3 && st[end]<3)
                return true
            else
                return false
            end
        else
            if((ac==1 && st[end]==2) || (ac==2 && st[end]==1))
                return true
            elseif(ac== st[end])
                return true
            else
                return false
            end
        end
    else
        if(length(findall(isequal(ac),st))==0)
            return false
        else
            return true
        end
    end
end

rule_violate (generic function with 1 method)

In [3]:
gauge_test = pi*[0.0, 0.1, 0.3, 0.5, 0.8, 1.0]
st_test = [0.2, 0.8]
st_test2 = [gauge_test[i]*st_test for i in 1:6]

6-element Vector{Vector{Float64}}:
 [0.0, 0.0]
 [0.06283185307179587, 0.25132741228718347]
 [0.1884955592153876, 0.7539822368615504]
 [0.3141592653589793, 1.2566370614359172]
 [0.5026548245743669, 2.0106192982974678]
 [0.6283185307179586, 2.5132741228718345]

In [12]:
function VarToLoss(var::Vector{ComplexF64})
    loss = 0.0
    for i in 2:size(var)[1]
        for j in 1:i
            loss += abs(var[i]-var[j])^2            
        end
    end
    return loss
end


VarToLoss (generic function with 1 method)

In [13]:
function wave_fn(var::Vector{Float64}, sw::Int)
    if(sw==1)
        wv_fn = ([cos(var[1]), sin(var[1])*exp(1.0im*var[2])])'
    else
        wv_fn = [cos(var[1]), sin(var[1])*exp(1.0im*var[2])]
    end
    return wv_fn
end

wave_fn (generic function with 1 method)

In [34]:
function Fn_Gauge(st::Vector{Int}, var::Vector{Vector{Float64}}, gauge_sample::Matrix{Float64}, var_sub, var_now)
    if(length(st)==0)
        return VarToLoss(var_now)
    else
        ac = pop!(st)
    end
    if(ac<100)
        if(var_now =! nothing)
            var_sub = var_now
        end
        #縦横を区別するためにやっぱりここで作り直すべき
        i_s = ac%(Ns+1) + 1
        sw = div(ac, Ns+1)
        var_now = [exp(1.0im*gauge_sample[b,ac])*wave_fn(var[i_s], sw) for b in 1:n_batch]
    elseif(ac < 1000)
        var_now = fn_act[ac%100](var_now)
    else
        var_sub, var_now = op_br[ac%1000](var_sub, var_now)
    end
    Fn_Gauge(st, var, gauge_sample, var_sub, var_now)
end


Fn_Gauge (generic function with 1 method)

In [22]:
n_level = 2
n_batch = 20

20

In [23]:
function Gene_Rand_Var()
    var::Vector{Vector{Float64}} = []
    θ = zeros(Float64,2)
    push!(var, θ)
    for i in 1:Ns-1
        θ += [pi*(1.0+rand(Float64))/2Ns , pi*(0.5-rand(Float64))/Ns]
        push!(var, θ)
    end
    return var
end

Gene_Rand_Var (generic function with 1 method)

In [24]:
function reward(st::Vector{Int}, var::Vector{Vector{Float64}}, gauge_sample::Matrix{Float64})
    #var = Gene_Rand_Var()
    #gauge_sample = 2pi*rand(Float64, n_batch, Ns)
    st_copy = st
    l = Fn_Gauge(st_copy, var, gauge_sample, nothing, nothing)
    return -l
end
    

reward (generic function with 1 method)

In [25]:
function act_ind(ac::Int)
    id = 0
    if(ac<100)
        id = ac
    elseif(ac>999)
        id = num_var + ac%1000
    else
        id = num_var + num_br + ac%100
    end
    return id
end

act_ind (generic function with 1 method)

In [26]:
function loss(r::Float64, q_t::Matrix{Float64}, st::Vector{Int})
    T = length(st)
    l = (q_t[T, act_ind(st[T])] - r)^2
    q_t[T, act_ind(st[T])] = r
    q_max = maximum(q_t[T,:])
    for i in 1:T-1
        l += (q_max - q_t[T-i, act_ind(st[T-i])])^2
        q_t[T-i, act_ind(st[T-i])] = q_max
        q_max = maximum(q_t[T-i,:])
    end
    return l
end 

loss (generic function with 1 method)

In [27]:
function train_search(mm, var::Vector{Vector{Float64}}, gauge_sample::Matrix{Float64})
    state::Vector{Int} = []
    branch::Vector{Int} = []
    r = 0.0
    # = zeros(Float32, act_MAX)
    q_table = []
    for turn in 1:act_MAX
        q_t, act = decide_action!(mm, state, turn)
        if(act > 999) #actionでbinaryを選んだ場合、２つに分岐するので分岐点を覚えておくためにbranchに入れておく
            push!(branch, act)
        elseif(act < 100)
            if(size(branch)[1]==0)#残りのbranchがなければ関数形が完成しているので終了
                push!(state, act)
                break;
            else #branchが残っていれば、下っ側を埋めていく
                b = pop!(branch)
            end
        end
        if(rule_violate(state, act)) #rule違反をしていたら、罰則(負の報酬)を与えて終了
            r = -100.0
            break;
        end

        push!(state, act)
        #r[turn] = reward(state, act)
    end
    r = reward(state, var, gauge_sample)
    
    return loss(r, q_t, state)
end

train_search (generic function with 1 method)

In [28]:
model = Chain(Dense(act_MAX, width, relu, init=Flux.zeros32), Dense(width, width, relu, init=Flux.zeros32), Dense(width, num_tot, init=Flux.zeros32))

Chain(
  Dense(40 => 64, relu),                [90m# 2_624 parameters[39m[36m  (all zero)[39m
  Dense(64 => 64, relu),                [90m# 4_160 parameters[39m[36m  (all zero)[39m
  Dense(64 => 12),                      [90m# 780 parameters[39m[36m  (all zero)[39m
) [90m                  # Total: 6 arrays, [39m7_564 parameters, 29.922 KiB.

In [35]:
ll_it = zeros(Float64, 100)
for it in 1:100
    var = Gene_Rand_Var()
    gauge_sample = 2pi*rand(Float64, n_batch, Ns)
    ll = 0.0
    grads = Flux.gradient(Flux.params(model)) do
        ll = train_search(model, var, gauge_sample)
    end
    ll_it[it] = ll
    Flux.Optimise.update!(ADAM(), Flux.params(model), grads)
end

LoadError: MethodError: no method matching !(::Nothing)
[0mClosest candidates are:
[0m  !([91m::Function[39m) at operators.jl:1077
[0m  !([91m::Bool[39m) at bool.jl:35
[0m  !([91m::Missing[39m) at missing.jl:101

In [38]:
max(maximum([1.0, 2.0]), 1.0)

2.0

In [50]:
map([1,2,3]) do x sin(x) end

3-element Vector{Float64}:
 0.8414709848078965
 0.9092974268256817
 0.1411200080598672

In [10]:
function test_square(x)
    H = [x[1]^2 x[2]^2; x[1]^2 x[2]^2]
    e = tr(H)
    return e
end


test_square (generic function with 1 method)

In [2]:
x = [1.0, 0.5]

2-element Vector{Float64}:
 1.0
 0.5

In [11]:
grads = gradient(test_square, x)

([2.0, 1.0],)

In [5]:
function test2_square(x)
    H = zeros(Float64, 2, 2)
    H[1,1] = x[1]^2
    H[1,2] = x[2]^2
    H[2,1] = x[1]^2
    H[2,2] = x[2]^2
    e = tr(H)
    return e
end

test2_square (generic function with 1 method)

In [11]:
grads2 = gradient(test2_square, x)

([2.0, 1.0],)

In [10]:
function ChainRulesCore.rrule(::typeof(test2_square), x)
    y = test2_square(x)
    _axis = length(x)
    function test2_square_pullback(Δ)
        H1 = zeros(Float64, 2, 2)
        H1[1,1] = 2x[1]
        H1[1,2] = 0.0
        H1[2,1] = 2x[1]
        H1[2,2] = 0.0
        e1 = tr(H1)

        H2 = zeros(Float64, 2, 2)
        H2[1,1] = 0.0
        H2[1,2] = 2x[2]
        H2[2,1] = 0.0
        H2[2,2] = 2x[2]
        e2 = tr(H2)
        e = [e1, e2]

        return NoTangent(), @thunk(e*Δ)
    end
    return y, test2_square_pullback
end
        

In [3]:
using ChainRulesCore
using Zygote