Approximation to continuous POMDP implemented using `github.com/JuliaPOMDP/QuickPOMDPs.jl` and solved using POMCP with reference to tutorial [here](https://htmlview.glitch.me/?https://github.com/JuliaAcademy/Decision-Making-Under-Uncertainty/blob/master/html/4-Approximate-Methods.jl.html).

In [1]:
using POMDPs, QuickPOMDPs, POMDPModelTools, POMDPPolicies, Parameters, Random, Plots, LinearAlgebra
using POMDPTools, BasicPOMCP, D3Trees, GridInterpolations, POMCPOW, POMDPModels, Combinatorics, Dates

In [2]:
expID = Dates.format(Dates.now(), "yymd_HHMMS")

"22820_163235"

In [3]:
function log(s::String)
    s_time = Dates.format(Dates.now(), "HH:MM:SS\t")*s*"\n"
    open("./logs/"*expID*".txt", "a") do file
        write(file, s_time)
    end
    print(s_time)
end

log (generic function with 1 method)

In [4]:
log("Running experiment with ID "*expID)

16:32:36	Running experiment with ID 22820_163235


# Define Problem

In [5]:
@with_kw struct MyParameters
    N::Int = 4         # size of item set
    K::Int = 3         # size of arm set
    M::Int = 2         # size of beta set
    y::Real = 0.99     # discount factor
    umax::Real = 10    # max utility
    u_grain:: Int = 2  # granularity of utility approximation
    d_grain:: Int = 5  # granularity of arm distribution approximation
    beta:: Array{Float64} = [0.01, 10.0]  # teacher beta values
end

params = MyParameters()
log(string(params))

16:32:38	MyParameters
  N: Int64 4
  K: Int64 3
  M: Int64 2
  y: Float64 0.99
  umax: Int64 10
  u_grain: Int64 2
  d_grain: Int64 5
  beta: Array{Float64}((2,)) [0.01, 10.0]



# Create POMDP

In [6]:
struct State
    u::Array{Float64}         # list of N utility values for N items
    d::Array{Array{Float64}}  # list of K arm distributions, each assigning probabilities to N items
    b::Array{Float64}         # list of M beta values
end

In [7]:
# space of utility functions
@time begin
    umin = 0
    grid_coor = fill(range(umin,params.umax,length=params.u_grain), params.N)
    U = RectangleGrid(grid_coor...)
end

@assert length(U[1]) == params.N
log("generated "*string(length(U))*" utilities (each length "*string(length(U[1]))*" items)")

  0.316692 seconds (493.67 k allocations: 27.147 MiB, 99.73% compilation time)
16:32:38	generated 16 utilities (each length 4 items)


In [8]:
function generate_probability_distributions(N::Int, coor::Array{Float64}, S::Float64=1.0)
    if S == 0
        return [[0. for _ in 1:N]]
    end
    if N == 1
        return [[float(S)]]
    end
    out = []
    range = coor[1:findall(x->isapprox(x,S,atol=1e-15), coor)[1]]
    for k in range
        subsolution = generate_probability_distributions(N-1, coor, S-k)
        for lst in subsolution
            if typeof(lst[1]) != Float64
                log("ERROR: lst "*string(lst)*" has type "*string(typeof(lst[1]))*". Must be Float64.")
            end
            prepend!(lst, float(k))
        end
        out = vcat(out, subsolution)
    end
    return out
end

generate_probability_distributions (generic function with 2 methods)

In [28]:
# space of arm distributions
@time begin
    coor = collect(range(0.,1.,length=params.d_grain))    
    simplex_list = generate_probability_distributions(params.N, coor)
    D_tuples = vec(collect(Base.product(fill(simplex_list, params.K)...)))
    D = [collect(d) for d in D_tuples]
end

@assert length(D[1]) == params.K
@assert length(D[1][1]) == params.N
log(string("generated "*string(length(D))*" arm distribution sets (each shape "*string(length(D[1]))*" arms x "*string(length(D[1][1]))*" items)"))
    
    

  0.047984 seconds (139.27 k allocations: 8.111 MiB, 87.50% compilation time)
16:34:31	generated 42875 arm distribution sets (each shape 3 arms x 4 items)


In [29]:
# beta values
B = [params.beta]

# each beta value set must be length M
@assert length(B[1]) == params.M
log(string("generated "*string(length(B))*" beta value sets (each length "*string(length(B[1]))*" teachers)"))

16:34:32	generated 1 beta value sets (each length 2 teachers)


In [11]:
# State space
@time begin     
    S = [[State(u,d,b) for u in U, d in D, b in B]...,]
end

log("generated "*string(length(S))*" states")

  0.508222 seconds (3.46 M allocations: 237.403 MiB, 11.05% gc time, 37.15% compilation time)
16:32:43	generated 686000 states


In [12]:
# Action space - actions are arm choices (K) or beta selections (M)
struct Action
    name::String      # valid names are {B,C} + index
    isBeta::Bool      # true if 'B' action, false if 'C' action
    index::Integer    # index of beta (if 'B' action) or arm choice (if 'C' action)
end

A = Array{Action}(undef, params.K+params.M)
for i in 1:params.K+params.M
    if i <= params.K
        A[i] = Action("C"*string(i), false, i)
    else
        A[i] = Action("B"*string(i-params.K), true, i-params.K)
    end
end
log("generated "*string(length(A))*" actions")

16:32:43	generated 5 actions


In [13]:
# Transition function
function T(s::State, a::Action)
    return SparseCat([s], [1.0])    # categorical distribution
end
log("generated transition function")

16:32:43	generated transition function


In [14]:
# Reward function
function R(s::State, a::Action)
    # if beta selected, return 0
    if a.isBeta
        return 0
    # if arm pulled, return that arm's avg utility
    else
        utilities = s.u
        arm_dist = s.d[a.index]
        return dot(utilities, arm_dist)
    end
end
log("generated reward function")

16:32:43	generated reward function


In [15]:
# item space
I = 1:params.N

# preference space
struct Preference
    i0::Int    # first item to compare, in {1,2,...,N}
    i1::Int    # second item to compare, in {1,2,...,N}
    label::Int # feedback label, in {0,1}
end

P = [[Preference(i0,i1,label) for i0 in I, i1 in I, label in [0,1]]...,]

# observation space
struct Observation
    isItem::Bool    # true if item returned, false otherwise
    i::Int          # item, if item returned
    p::Preference   # preference, if preference returned
end

invalid_i = -1
invalid_p = Preference(-1,-1,-1)
I_obs = [Observation(true, i, invalid_p) for i in I]
P_obs = [Observation(false, invalid_i, p) for p in P]
omega = union(I_obs, P_obs)

log("generated "*string(length(omega))*" observations")

16:32:43	generated 36 observations


In [16]:
# unnormalized query profile (likelihood of querying 1,1; 2,1; 3,1; ... ; N,1; 1,2; 2,2; ... ; N,N)
Q = ones(params.N*params.N)

# preference probability (expected preference, or probability that preference=1)
function Pr(p::Preference, s::State, b::Float64)
    prob_pref_1 = exp(Float64(b)*s.u[p.i1])/(exp(Float64(b)*s.u[p.i1])+exp(Float64(b)*s.u[p.i0]))
    if p.label == 1
        return prob_pref_1
    else
        return 1.0-prob_pref_1
    end
end

Pr (generic function with 1 method)

In [17]:
function O(s::State, a::Action, sp::State)
    # if B action, obs in P_obs
    if a.isBeta
        prob_of_pref = [Pr(o.p, s, s.b[a.index]) for o in P_obs]
        prob_of_query = vcat(Q,Q)   # doubled because each query appears once for each label
        
        # weight by querying profile to get dist
        dist = [prob_of_pref[i]*prob_of_query[i] for i in 1:length(prob_of_pref)]
        normalized_dist = dist/sum(dist)        
        return SparseCat(P_obs, normalized_dist)
    # if C action, obs in I_obs
    else
        return SparseCat(I_obs, s.d[a.index])
    end
end

log("generated observation function")

16:32:44	generated observation function


In [18]:
@time begin
    
    # define POMDP
    abstract type MyPOMDP <: POMDP{State, Action, Observation} end
    pomdp = QuickPOMDP(MyPOMDP,
        states       = S,
        actions      = A,
        observations = omega,
        transition   = T,
        observation  = O,
        reward       = R,
        discount     = params.y,
        initialstate = S);

end

log("created POMDP")

  1.304230 seconds (2.76 M allocations: 191.833 MiB, 4.96% gc time, 80.54% compilation time)
16:32:45	created POMDP


In [19]:
rollout = true

if rollout
    policy = RandomPolicy(pomdp)

    show_state = true
    for (s,a,r,o) in stepthrough(pomdp, policy, "s,a,r,o", max_steps=3)
        if show_state
            @show s
            println("")
            show_state = false
        end
        @show a
        @show r
        @show o
        println()
    end
end

s = State([10.0, 10.0, 10.0, 10.0], Array{Float64}[[0.25, 0.0, 0.0, 0.75], [0.0, 0.75, 0.25, 0.0], [0.5, 0.0, 0.5, 0.0]], [0.01, 10.0])

a = Action("C1", false, 1)
r = 10.0
o = Observation(true, 4, Preference(-1, -1, -1))

a = Action("B2", true, 2)
r = 0
o = Observation(false, -1, Preference(2, 1, 1))

a = Action("B2", true, 2)
r = 0
o = Observation(false, -1, Preference(4, 1, 1))



# Solve POMDP

In [20]:
@time begin
    solver = POMCPOWSolver()
    planner = solve(solver, pomdp);
end
log("solved POMDP")

  0.061144 seconds (36.70 k allocations: 2.206 MiB, 98.13% compilation time)
16:32:48	solved POMDP


In [21]:
action(planner, Uniform(S))

Action("C1", false, 1)

In [22]:
rollout = true

if rollout
    show_state = true
    for (s,a,r,o) in stepthrough(pomdp, planner, "s,a,r,o", max_steps=3)
        if show_state
            @show s
            println("")
            show_state = false
        end
        @show a
        @show r
        @show o
        println()
    end
end

s = State([0.0, 10.0, 10.0, 10.0], Array{Float64}[[0.0, 0.25, 0.0, 0.75], [0.0, 1.0, 0.0, 0.0], [0.0, 0.25, 0.25, 0.5]], [0.01, 10.0])

a = Action("C2", false, 2)
r = 10.0
o = Observation(true, 2, Preference(-1, -1, -1))

a = Action("C1", false, 1)
r = 10.0
o = Observation(true, 4, Preference(-1, -1, -1))

a = Action("C3", false, 3)
r = 10.0
o = Observation(true, 4, Preference(-1, -1, -1))



# Evaluate Solution

In [23]:
aₚ, info = action_info(planner, initialstate(pomdp), tree_in_info=true); aₚ
tree = D3Tree(info[:tree], init_expand=3)

In [24]:
steps = 5
iters = 1
prior = Uniform(S)
initial_state = S[100]
sim = RolloutSimulator(max_steps=steps)

log("generating "*string(iters)*" rollouts for "*string(steps)*" timesteps each")

random_R = zeros(iters)
POMCP_R = zeros(iters)
max_R = fill(maximum([dot(initial_state.u, initial_state.d[i]) for i in 1:params.K])*steps, iters)

@time begin
    for i in 1:iters
        log("Running simulation "*string(i))
        u1 = updater(RandomPolicy(pomdp))
        u2 = updater(planner)
        random_R[i] = simulate(sim, pomdp, RandomPolicy(pomdp), u1, prior, initial_state)
        POMCP_R[i] = simulate(sim, pomdp, planner, u2, prior, initial_state)
    end
end
    
log("ran "*string(iters)*" rollouts for "*string(steps)*" timesteps each")
log("random R: "*string(random_R))
log("POMCP R: "*string(POMCP_R))
log("Max R: "*string(max_R))

16:32:59	generating 1 rollouts for 5 timesteps each
16:32:59	Running simulation 1
  6.669976 seconds (164.27 M allocations: 4.625 GiB, 12.13% gc time, 8.13% compilation time)
16:33:06	ran 1 rollouts for 5 timesteps each
16:33:06	random R: [2.5]
16:33:06	POMCP R: [7.42525]
16:33:06	Max R: [12.5]


In [25]:
fig = plot(1:iters, [random_R,POMCP_R,max_R], 
    seriestype = :scatter, 
    label=["random" "POMCP" "max"], 
    ylims = (0,maximum(max_R)+100),
    xticks = 0:1:iters,
    xlabel = "run",
    ylabel = "reward (" * string(steps) * " timesteps)"
)
savefig(fig,"./plots/reward_ID"*string(expID)*"_step"*string(steps)*"_roll"*string(iters)*".png")