In [3]:
"""
    MDP(γ, S, A, T, R, TR)

Container for a Markov decision process used by the planners below.

# Fields
- `γ`: discount factor (multiplies future utility in `lookahead`, and is raised to
  `t-1` in `rollout`).
- `S`: iterable collection of states.
- `A`: iterable collection of actions.
- `T`: transition function called as `T(s, a, s′)`; presumably the probability of
  reaching `s′` from `s` under `a` — confirm against the problem definitions.
- `R`: reward function called as `R(s, a)`.
- `TR`: generative model called as `TR(s, a)`, returning a `(s′, r)` pair.

Each field has its own type parameter, so every field gets a concrete type. The
compiler can then specialize code that reads them; untyped fields would be `Any`
and boxed. The positional constructor `MDP(γ, S, A, T, R, TR)` infers the
parameters automatically.
"""
struct MDP{Tγ,TS,TA,TT,TRew,TTR}
    γ::Tγ
    S::TS
    A::TA
    T::TT
    R::TRew
    TR::TTR
end


In [11]:
"""
    lookahead(P::MDP, U, s, a)

Return the one-step lookahead value of taking action `a` in state `s`:
`R(s, a) + γ * Σ_{s′ ∈ S} T(s, a, s′) * U(s′)`.
Here `U` is a utility function that is called on every successor state `s′`.
"""
function lookahead(P::MDP, U, s, a)
    S, T, R, γ = P.S, P.T, P.R, P.γ
    return R(s, a) + γ * sum(T(s, a, s′) * U(s′) for s′ in S)
end

# When U is not a function but a vector of utilities, the i-th element of U is the
# utility of the i-th state in the iteration order of P.S.
"""
    lookahead(P::MDP, U::AbstractVector, s, a)

Same as `lookahead(P, U, s, a)`, but `U` is a vector of utilities aligned with the
iteration order of `P.S`: `U[i]` is the utility of the `i`-th state. Any
`AbstractVector` is accepted, including ranges and views; 1-based indexing is
assumed. Ranges and views would otherwise be treated as callables and fail.
"""
function lookahead(P::MDP, U::AbstractVector, s, a)
    S, T, R, γ = P.S, P.T, P.R, P.γ
    return R(s, a) + γ * sum(T(s, a, s′) * U[i] for (i, s′) in enumerate(S))
end


lookahead (generic function with 2 methods)

In [None]:
"""
    greedy(P::MDP, U, s)

Return the action in `P.A` that maximizes `lookahead(P, U, s, a)`, together with
that maximal value, as the named tuple `(a=a, u=u)`. Ties go to the first maximal
action, because that is what `findmax` returns.
"""
function greedy(P::MDP, U, s)
    u, a = findmax(a -> lookahead(P, U, s, a), P.A)
    return (a=a, u=u)
end

# Backward-compatible forwarder for the original misspelled name. The rollout policy
# calls `greedy`, which previously was never defined.
gredy(P::MDP, U, s) = greedy(P, U, s)


### Lookahead with rollout

In [4]:
"""
    RolloutLookahead(P, π, d)

Online policy that acts greedily with respect to utilities estimated by rollouts.

# Fields
- `P`: the problem being solved (an `MDP`).
- `π`: rollout policy, called as `π(s)` to choose actions while simulating.
- `d`: rollout depth (number of simulated steps).

The fields are type-parameterized so that each one has a concrete type.
"""
struct RolloutLookahead{TP,Tπ,Td}
    P::TP
    π::Tπ
    d::Td
end

In [9]:
"""
    randstep(P::MDP, s, a)

Sample one transition from the generative model `P.TR`, returning `(s′, r)`.
"""
function randstep(P::MDP, s, a)
    return P.TR(s, a)
end

"""
    rollout(P, s, π, d)

Simulate `d` steps starting from state `s`, choosing actions with policy `π`.
Return the discounted sum of sampled rewards, `Σ_{t=1}^{d} γ^(t-1) * r_t`.
"""
function rollout(P, s, π, d)
    discounted_return = 0.0
    current = s
    for step in 1:d
        # The policy is queried on the current state before the transition is sampled.
        current, reward = randstep(P, current, π(current))
        discounted_return += P.γ^(step - 1) * reward
    end
    return discounted_return
end

"""
    (π::RolloutLookahead)(s)

Return the action chosen greedily at state `s`. Each time a successor state's utility
is needed, it is estimated by a single rollout of depth `π.d` under the rollout
policy `π.π`.
"""
function (π::RolloutLookahead)(s)
    P, rollout_policy, depth = π.P, π.π, π.d
    rollout_utility = s′ -> rollout(P, s′, rollout_policy, depth)
    return greedy(P, rollout_utility, s).a
end


### Forward Search 

In [13]:
"""
    ForwardSearch(P, d, U)

Online policy that exhaustively searches action sequences up to depth `d`.

# Fields
- `P`: the problem being solved (an `MDP`).
- `d`: search depth.
- `U`: utility estimate `U(s)` applied at the leaves of the search.

The fields are type-parameterized so that each one has a concrete type.
"""
struct ForwardSearch{TP,Td,TU}
    P::TP
    d::Td
    U::TU
end

"""
    forward_search(P, s, d, U)

Search every action sequence of length `d` from state `s`. Return the named tuple
`(a=a, u=u)` holding the best first action and its value.

- At depth `d <= 0`, return `(a=nothing, u=U(s))`.
- If no action's value exceeds `-Inf` (for example, when `P.A` is empty), return
  `(a=nothing, u=-Inf)`.
"""
function forward_search(P, s, d, U)
    d <= 0 && return (a=nothing, u=U(s))

    # Utility of a successor state = value of searching one level shallower from it.
    subtree_value(s′) = forward_search(P, s′, d - 1, U).u
    best_action, best_value = nothing, -Inf
    for a in P.A
        value = lookahead(P, subtree_value, s, a)
        if value > best_value
            best_action, best_value = a, value
        end
    end
    return (a=best_action, u=best_value)
end

"""
    (π::ForwardSearch)(s)

Return the best first action found by `forward_search` from state `s`.
"""
function (π::ForwardSearch)(s)
    result = forward_search(π.P, s, π.d, π.U)
    return result.a
end

### Branch and Bound 

In [16]:
"""
    BranchAndBound(P, d, Ulo, Qhi)

Online policy that runs forward search to depth `d` and prunes actions using bounds.

# Fields
- `P`: the problem being solved (an `MDP`).
- `d`: search depth.
- `Ulo`: function of `s`. Lower bound on the value function `U(s)`, used at the
  leaves (depth `d`).
- `Qhi`: function of `(s, a)`. Upper bound on the action-value function, used to
  order and prune actions.

The fields are type-parameterized so that each one has a concrete type.
"""
struct BranchAndBound{TP,Td,TUlo,TQhi}
    P::TP
    d::Td
    Ulo::TUlo
    Qhi::TQhi
end

"""
    branch_and_bound(P, s, d, Ulo, Qhi)

Run forward search from state `s` to depth `d`, pruning with bounds.

Actions are visited in decreasing order of the upper bound `Qhi(s, a)`. Once that
bound drops below the best value found so far, no remaining action can beat it, so
the search stops early. Leaves (depth `d <= 0`) are valued with the lower bound
`Ulo(s)`.

Returns the named tuple `(a=a, u=u)` of the best action and its value. Returns
`(a=nothing, u=-Inf)` when no action is evaluated (for example, when `P.A` is empty).
"""
function branch_and_bound(P, s, d, Ulo, Qhi)
    if d <= 0
        return (a=nothing, u=Ulo(s))
    end
    U_(s) = branch_and_bound(P, s, d - 1, Ulo, Qhi).u
    best = (a=nothing, u=-Inf)
    # Order actions by their upper bound Qhi (big -> small). The previous key used an
    # undefined `Q`, which raised UndefVarError.
    for a in sort(P.A, by=a -> Qhi(s, a), rev=true)
        if Qhi(s, a) < best.u
            return best   # every remaining action has an even lower upper bound: safe to prune
        end
        # Qhi(s, a) >= best.u: this action could still be the best, so evaluate it.
        u = lookahead(P, U_, s, a)
        if u > best.u
            best = (a=a, u=u)
        end
    end
    return best
end

"""
    (π::BranchAndBound)(s)

Return the action selected by `branch_and_bound` from state `s`. Like the
`RolloutLookahead` and `ForwardSearch` policies, this returns the action itself,
not the `(a, u)` named tuple.
"""
(π::BranchAndBound)(s) = branch_and_bound(π.P, s, π.d, π.Ulo, π.Qhi).a

### Monte Carlo Tree Search (MCTS)

In [17]:
"""
    MonteCarloTreeSearch(P, N, Q, d, m, c, U)

Online policy that runs `m` Monte Carlo tree search simulations from the current state.

# Fields
- `P`: the problem being solved (an `MDP`).
- `N`: visit counts keyed by `(s, a)`; mutated in place during search.
- `Q`: action-value estimates keyed by `(s, a)`; mutated in place during search.
- `d`: maximum simulation depth.
- `m`: number of simulations per decision.
- `c`: exploration constant that weights the exploration bonus.
- `U`: value-function estimate `U(s)`, used at the depth limit and for newly
  expanded states.

The fields are type-parameterized so that each one has a concrete type. `N` and
`Q` remain the same (mutable) containers the caller passed in.
"""
struct MonteCarloTreeSearch{TP,TN,TQ,Td,Tm,Tc,TU}
    P::TP
    N::TN   # visit count
    Q::TQ   # action-value estimate
    d::Td
    m::Tm   # num of simulations
    c::Tc   # exploration constant
    U::TU   # value function estimate
end

"""
    (π::MonteCarloTreeSearch)(s)

Run `π.m` simulations from state `s`, updating `π.N` and `π.Q` in place. Then return
the action in `π.P.A` with the highest estimated action value `π.Q[(s, a)]`.
"""
function (π::MonteCarloTreeSearch)(s)
    foreach(_ -> simulate!(π, s), 1:π.m)
    values = π.Q
    return argmax(a -> values[(s, a)], π.P.A)
end


In [18]:
"""
    simulate!(π::MonteCarloTreeSearch, s, d=π.d)

Run one Monte Carlo tree search simulation from state `s` with remaining depth `d`.
Updates the visit counts `π.N` and action-value estimates `π.Q` in place, and returns
the sampled discounted return from `s`.

- At depth `d <= 0`, return the leaf estimate `π.U(s)`.
- If `s` has not been expanded yet, add zeroed `N`/`Q` entries for every action and
  return `π.U(s)` without simulating further.
- Otherwise, pick an action with `explore`, sample `(s′, r)` from `P.TR`, and recurse.
  Then update `N[(s, a)]` and the running mean `Q[(s, a)]`.
"""
function simulate!(π::MonteCarloTreeSearch, s, d=π.d)
    if d <= 0
        return π.U(s)
    end
    P, N, Q = π.P, π.N, π.Q
    A, TR, γ = P.A, P.TR, P.γ

    # A state counts as expanded once (s, first action) has a visit-count entry.
    if !haskey(N, (s, first(A)))
        for a in A
            N[(s, a)] = 0
            # Initialize with a float: the update below produces floats, so Q keeps
            # a single value type before and after updates.
            Q[(s, a)] = 0.0
        end
        return π.U(s)   # value a newly expanded state with the leaf estimate
    end

    a = explore(π, s)
    s′, r = TR(s, a)
    q = r + γ * simulate!(π, s′, d - 1)
    N[(s, a)] += 1
    # Incremental mean: after n visits, Q[(s, a)] is the average of the n sampled returns.
    Q[(s, a)] += (q - Q[(s, a)]) / N[(s, a)]
    return q
end

"""
    bonus(Nsa, Ns)

Exploration bonus `sqrt(log(Ns) / Nsa)` for an action tried `Nsa` times out of `Ns`
total visits to its state. An untried action (`Nsa == 0`) gets an infinite bonus.
"""
function bonus(Nsa, Ns)
    iszero(Nsa) && return Inf
    return sqrt(log(Ns) / Nsa)
end

"""
    explore(π::MonteCarloTreeSearch, s)

Select the action at state `s` that maximizes `Q[(s, a)] + c * bonus(N[(s, a)], Ns)`.
`Ns` is the total visit count over all actions at `s`. Untried actions receive an
infinite bonus, so they are preferred when `c > 0`. Ties go to the first such action.
"""
function explore(π::MonteCarloTreeSearch, s)
    actions, visits, values, weight = π.P.A, π.N, π.Q, π.c
    total_visits = sum(visits[(s, a)] for a in actions)
    score(a) = values[(s, a)] + weight * bonus(visits[(s, a)], total_visits)
    return argmax(score, actions)
end


explore (generic function with 1 method)