## Implement Monte Carlo Exploring Starts as a dynamical system

### Test GraphPlot.jl package

In [None]:
using LightGraphs
using GraphPlot

In [None]:
# number of states
N = 5

# random 0/1 mask: M[i, j] == 1 marks a connection from state i to state j
M = rand([0, 1], N, N)
# reward matrix: a uniform random reward on every connection, zero elsewhere
R = M .* rand(N, N)

# directed graph on N vertices with one edge i -> j per set entry of the mask
g = SimpleDiGraph(N)
for link in findall(isone, M)
    add_edge!(g, link[1], link[2])
end

# label vertices and edges with their running index
nodelabel = 1:nv(g)
edgelabel = 1:LightGraphs.ne(g)

gplot(g; nodelabel=nodelabel, edgelabel=edgelabel)

### Test POMDPs.jl package

In [None]:
using POMDPs
using POMDPModels
using POMDPSimulators
using POMDPPolicies

In [None]:
# Transition tensor, |S| x |A| x |S|, with T[s', a, s] = p(s' | a, s).
# One 2x3 slice per current state s, stacked along the third dimension.
T = cat(
    [1.0 0.5 0.5;
     0.0 0.5 0.5],
    [0.0 0.5 0.5;
     1.0 0.5 0.5];
    dims=3
)

# Observation model, kept for reference only (not used by the MDP below):
# O = zeros(2,3,2) # |O|x|A|x|S|, O[o, a, s] = p(o|a,s)
# O[:,:,1] = [0.85 0.5 0.5; 
#             0.15 0.5 0.5]
# O[:,:,2] = [0.15 0.5 0.5; 
#             0.85 0.5 0.5]

# |S| x |A| rewards, one entry per state-action pair
R = [-1.0 -100.0   10.0;
     -1.0   10.0 -100.0]

discount = 0.95

pomdp = TabularMDP(T, R, discount);

# policy that picks its action at random
policy = RandomPolicy(pomdp)

# simulate at most 10 steps and print state, action and reward of each one
for (s, a, r) in stepthrough(pomdp, policy, "s,a,r", max_steps=10)
    @show s a r
    println()
end

### Test our own Exploring Starts algorithm

In [None]:
using Plots

In [None]:
# number of states
M = 5
# discount factor
gam = 0.9
# number of terminal states
n_fin = 2

# 5x5 integer reward matrix; the simulation loop reads it as
# transpose(s_next) * R * s_current, i.e. R[next state, current state]
R = [   0    0    0    0    0
        0    0    0    0    0
        0    0    0    0    0
      -10  -10    0    0    0
        0    0   10    0    0 ]

# 5x5 0/1 connection mask, indexed like R; nonzero entries mark the
# (next state, current state) pairs the simulation may start from
C = [ 0  0  1  0  0
      1  0  0  0  0
      0  1  0  0  0
      1  1  0  1  0
      0  0  1  0  1 ]

# 5x25 integer matrix made of five 5-column groups per row
# NOTE(review): not referenced by the cells visible here; its intended meaning
# (presumably a transition tensor flattened along the columns) should be confirmed
T = [ 0 0 0 0 0   0 0 0 0 0   1 0 0 0 0   0 0 0 0 0   0 0 0 0 0
      0 1 0 0 0   0 0 0 0 0   0 0 0 0 0   0 0 0 0 0   0 0 0 0 0
      0 0 0 0 0   0 0 1 0 0   0 0 0 0 0   0 0 0 0 0   0 0 0 0 0
      0 0 0 1 0   0 0 0 0 0   0 0 0 0 0   0 0 0 0 0   0 0 0 0 0
      0 0 0 0 0   0 0 0 0 0   0 0 0 0 0   0 0 0 0 0   0 0 0 0 0 ]


In [None]:
# random M^2 x 1 column with entries drawn uniformly from [0, 1)
# NOTE(review): neither `q` nor `a` is read by the cells visible here
q = rand(M * M, 1)

a = 0.0

#### Old test (does not work)

In [None]:
# alternative random initialisation, currently disabled:
# Q = 10 * C .* rand(M,M)

# initial Q: a Float64 copy of the connection mask C (1.0 on connections, 0.0 elsewhere)
Q = zeros(M, M) + C
# M x M accumulator, all zeros at the start; the simulation loop adds
# s_next * transpose(s_current) to it at every step
N = zeros(M, M)

In [None]:
# no exploring states
# wrong method that should not converge
#
# Runs K episodes of M steps each. States are one-hot vectors of length M and
# all matrices are indexed [next state, current state]. After every episode the
# current Q is stored in Q_all[:, :, k].

K = 100_000
Q_all = zeros(M, M, K)

# Admissible (next state, current state) pairs: the nonzero entries of C with
# the last n_fin columns (current states) left out. C never changes inside the
# loop, so the list is built once instead of once per episode.
nz_idx = findall(!iszero, C[:, 1:end-n_fin])

for k = 1:K
    # N and Q carry over from one episode to the next; declaring them global
    # makes that explicit and independent of the notebook's soft-scope rule.
    # A and s_tp1 are global as well so that the inspection cell following this
    # loop can read the values left by the last iteration.
    global N, Q, A, s_tp1

    # state matrix: a single 1 per column, in the row where Q is largest
    A = zeros(M, M)
    _, inds = findmax(Q; dims=1)
    A[inds] .= 1

#     display("A: ")
#     display(A)

    # draw one admissible pair at random; s_t is the one-hot current state and
    # u_t is chosen such that A*s_t + u_t is the one-hot selected next state
    nxt, cur = Tuple(nz_idx[rand(1:length(nz_idx))])
    s_t = [x == cur for x in 1:M]
    u_t = [x == nxt for x in 1:M] - A*s_t
    N_t = N
    E_t = zeros(M, M)
    Q_t = Q

    for t = 1:M
        s_tp1 = A*s_t + u_t

        # outer product: 1 at (next state, current state), 0 elsewhere
        visit = s_tp1 * transpose(s_t)
        N_tp1 = N_t + visit

#         display("N: ")
#         display(N_tp1)

        # decay the previous E by gam and add the current pair, weighted by the
        # inverse of its accumulated count in N
        E_tp1 = gam*E_t + visit / (transpose(s_tp1)*N_tp1*s_t)

        # Scalar error (reward + discounted Q at the greedy successor - current
        # Q) spread over E. Note that the Q terms read `Q`, which only changes
        # after the episode, not the running `Q_t`.
        Q_tp1 = Q_t + ( transpose(s_tp1)*R*s_t +
                        gam*transpose(A*s_tp1)*Q*s_tp1 -
                        transpose(s_tp1)*Q*s_t ) * E_tp1

        s_t, N_t, E_t, Q_t = s_tp1, N_tp1, E_tp1, Q_tp1
        # the external input only acts on the first step of an episode
        u_t = zeros(M)
    end

    N = N_t
    Q = Q_t

    Q_all[:, :, k] = Q
end

display(Q)

In [None]:
# Debug output: the discounted term gam * transpose(A*s_tp1) * Q * s_tp1 that
# also appears in the Q update of the simulation loop.
# NOTE(review): `A` and `s_tp1` are only assigned inside the `for` loops of the
# previous cell; unless they are visible as globals this line raises
# UndefVarError.
display(gam*transpose(A*s_tp1)*Q*s_tp1)

In [None]:
using Plots

# Flatten every M x M snapshot of Q into one column (M^2 x K), then transpose so
# that rows index the iteration k and each column holds the history of one Q entry.
plot(1:K, transpose(reshape(Q_all, M * M, K)))