In [26]:
using POMDPs
using QuickPOMDPs
using POMDPTools: Deterministic, Uniform, SparseCat, stepthrough, RandomPolicy, FunctionPolicy

In [28]:
# Two-state "tiger" POMDP: the hidden state and the observations are both
# "left"/"right"; the agent may pick either side or "listen".
tiger = QuickPOMDP(
    states = ["left", "right"],
    actions = ["left", "right", "listen"],
    observations = ["left", "right"],

    # "listen" keeps the state fixed; any other action redraws it uniformly.
    transition = (s, a) -> a == "listen" ? Deterministic(s) : Uniform(["left", "right"]),

    # After "listen" the observation equals the next state with probability
    # 0.85 and the other side with probability 0.15; after any other action
    # the observation is uniform.
    observation = function (a, sp)
        a == "listen" || return Uniform(["left", "right"])
        heard = sp == "left" ? ["left", "right"] : ["right", "left"]
        return SparseCat(heard, [0.85, 0.15])
    end,

    # -1 for listening, -100 when the action names the current state,
    # +10 for the remaining (action, state) pairs.
    reward = function (s, a)
        a == "listen" && return -1.0
        return a == s ? -100.0 : 10.0
    end,

    initialstate = Uniform(["left", "right"]),

    discount = 0.95
)

QuickPOMDP{UUID("623651bf-cda2-479f-a4ff-ad9d321ad7f7"), String, String, String, NamedTuple{(:stateindex, :isterminal, :obsindex, :states, :observations, :discount, :actions, :observation, :actionindex, :transition, :reward, :initialstate), Tuple{Dict{String, Int64}, Bool, Dict{String, Int64}, Vector{String}, Vector{String}, Float64, Vector{String}, var"#20#23", Dict{String, Int64}, var"#19#22", var"#21#24", Uniform{Set{String}}}}}((stateindex = Dict("left" => 1, "right" => 2), isterminal = false, obsindex = Dict("left" => 1, "right" => 2), states = ["left", "right"], observations = ["left", "right"], discount = 0.95, actions = ["left", "right", "listen"], observation = var"#20#23"(), actionindex = Dict("left" => 1, "right" => 2, "listen" => 3), transition = var"#19#22"(), reward = var"#21#24"(), initialstate = Uniform{Set{String}}(Set(["left", "right"]))))

In [29]:
# Simulate at most ten steps under a random policy and display each
# (s, a, r, sp, o) tuple as it is produced.
foreach(
    display,
    stepthrough(tiger, RandomPolicy(tiger), "s,a,r,sp,o", max_steps=10),
)

(s = "left", a = "right", r = 10.0, sp = "left", o = "left")

(s = "left", a = "left", r = -100.0, sp = "right", o = "right")

(s = "right", a = "listen", r = -1.0, sp = "right", o = "right")

(s = "right", a = "listen", r = -1.0, sp = "right", o = "right")

(s = "right", a = "left", r = 10.0, sp = "left", o = "left")

(s = "left", a = "left", r = -100.0, sp = "right", o = "right")

(s = "right", a = "left", r = 10.0, sp = "left", o = "left")

(s = "left", a = "listen", r = -1.0, sp = "left", o = "left")

(s = "left", a = "right", r = 10.0, sp = "left", o = "left")

(s = "left", a = "left", r = -100.0, sp = "left", o = "left")

In [33]:
"""
    belief_update(m::POMDP, b, a, o)

Return the exact discrete Bayesian posterior after taking action `a` from
belief `b` and receiving observation `o`:

    b'(sp) ∝ O(o | a, sp) * Σ_s T(sp | s, a) * b(s)

# Arguments
- `m`: the POMDP; `states(m)`, `transition(m, s, a)` and `observation(m, a, sp)`
  are queried.
- `b`: current belief; must support `support(b)` and `pdf(b, s)`.
- `a`: the action that was taken.
- `o`: the observation that was received.

# Returns
A `SparseCat` over `states(m)` holding the normalized posterior probabilities.

# Throws
- `ArgumentError` if `o` has zero (or NaN) total probability under the
  predicted belief, in which case the posterior is undefined.
"""
function belief_update(m::POMDP, b, a, o)
    # States that carry prior mass; only these contribute to the prediction sum.
    prior_states = collect(support(b))
    # Successors must range over the whole state space: a transition can move
    # probability onto states that lie outside the support of `b`.
    next_states = collect(states(m))

    weights = map(next_states) do sp
        predicted = sum(s -> pdf(b, s) * pdf(transition(m, s, a), sp), prior_states)
        return pdf(observation(m, a, sp), o) * predicted
    end

    total = sum(weights)
    # `!(total > 0)` also rejects NaN, which a plain `total == 0` would miss.
    if !(total > 0)
        throw(ArgumentError(string(
            "observation ", repr(o), " has zero probability under the ",
            "predicted belief after action ", repr(a),
        )))
    end
    return SparseCat(next_states, weights ./ total)
end

belief_update (generic function with 1 method)

In [34]:
# Track the belief by hand while the agent always listens.
# Start from a uniform belief over the two states.
belief = Uniform(["left", "right"])
display(belief)
for step in stepthrough(tiger, FunctionPolicy(_->"listen"), "s,a,r,sp,o", max_steps=10)
    # `belief` lives at top level; declaring it `global` makes the reassignment
    # below work when run as a script too, not only under the soft-scope rules
    # of an interactive session.
    global belief
    display(step)
    # Fold the action taken and the observation received into the belief.
    belief = belief_update(tiger, belief, step.a, step.o)
    display(belief)
end

                      [1mUniform distribution[22m            
           [90m┌                                        ┐[39m 
    [0m"left" [90m┤[39m[38;5;2m■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■[39m[0m 0.5 [90m [39m 
   [0m"right" [90m┤[39m[38;5;2m■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■[39m[0m 0.5 [90m [39m 
           [90m└                                        ┘[39m 

(s = "left", a = "listen", r = -1.0, sp = "left", o = "left")

                     [1mSparseCat distribution[22m           
           [90m┌                                        ┐[39m 
    [0m"left" [90m┤[39m[38;5;2m■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■[39m[0m 0.85 [90m [39m 
   [0m"right" [90m┤[39m[38;5;2m■■■■■■[39m[0m 0.15                             [90m [39m 
           [90m└                                        ┘[39m 

(s = "left", a = "listen", r = -1.0, sp = "left", o = "left")

                     [1mSparseCat distribution[22m           
           [90m┌                                        ┐[39m 
    [0m"left" [90m┤[39m[38;5;2m■■■■■■■■■■■■■■■■■■■■[39m[0m 0.9697986577181208 [90m [39m 
   [0m"right" [90m┤[39m[38;5;2m■[39m[0m 0.0302013422818792                    [90m [39m 
           [90m└                                        ┘[39m 

(s = "left", a = "listen", r = -1.0, sp = "left", o = "left")

                     [1mSparseCat distribution[22m           
           [90m┌                                        ┐[39m 
    [0m"left" [90m┤[39m[38;5;2m■■■■■■■■■■■■■■■■■■■■[39m[0m 0.9945344129554656 [90m [39m 
   [0m"right" [90m┤[39m[0m 0.005465587044534414                   [90m [39m 
           [90m└                                        ┘[39m 

(s = "left", a = "listen", r = -1.0, sp = "left", o = "right")

                     [1mSparseCat distribution[22m           
           [90m┌                                        ┐[39m 
    [0m"left" [90m┤[39m[38;5;2m■■■■■■■■■■■■■■■■■■■■[39m[0m 0.9697986577181208 [90m [39m 
   [0m"right" [90m┤[39m[38;5;2m■[39m[0m 0.030201342281879203                  [90m [39m 
           [90m└                                        ┘[39m 

(s = "left", a = "listen", r = -1.0, sp = "left", o = "left")

                     [1mSparseCat distribution[22m           
           [90m┌                                        ┐[39m 
    [0m"left" [90m┤[39m[38;5;2m■■■■■■■■■■■■■■■■■■■■[39m[0m 0.9945344129554656 [90m [39m 
   [0m"right" [90m┤[39m[0m 0.005465587044534414                   [90m [39m 
           [90m└                                        ┘[39m 

(s = "left", a = "listen", r = -1.0, sp = "left", o = "left")

                     [1mSparseCat distribution[22m           
           [90m┌                                        ┐[39m 
    [0m"left" [90m┤[39m[38;5;2m■■■■■■■■■■■■■■■■■■■■[39m[0m 0.9990311236573287 [90m [39m 
   [0m"right" [90m┤[39m[0m 0.0009688763426712281                  [90m [39m 
           [90m└                                        ┘[39m 

(s = "left", a = "listen", r = -1.0, sp = "left", o = "left")

                     [1mSparseCat distribution[22m           
           [90m┌                                        ┐[39m 
    [0m"left" [90m┤[39m[38;5;2m■■■■■■■■■■■■■■■■■■■■[39m[0m 0.9998288852897683 [90m [39m 
   [0m"right" [90m┤[39m[0m 0.00017111471023167384                 [90m [39m 
           [90m└                                        ┘[39m 

(s = "left", a = "listen", r = -1.0, sp = "left", o = "left")

                     [1mSparseCat distribution[22m           
           [90m┌                                        ┐[39m 
    [0m"left" [90m┤[39m[38;5;2m■■■■■■■■■■■■■■■■■■■■[39m[0m 0.9999697990305696 [90m [39m 
   [0m"right" [90m┤[39m[0m 3.0200969430404745e-5                  [90m [39m 
           [90m└                                        ┘[39m 

(s = "left", a = "listen", r = -1.0, sp = "left", o = "left")

                     [1mSparseCat distribution[22m           
           [90m┌                                        ┐[39m 
    [0m"left" [90m┤[39m[38;5;2m■■■■■■■■■■■■■■■■■■■■[39m[0m 0.9999946702846019 [90m [39m 
   [0m"right" [90m┤[39m[0m 5.32971539807174e-6                    [90m [39m 
           [90m└                                        ┘[39m 

(s = "left", a = "listen", r = -1.0, sp = "left", o = "left")

                     [1mSparseCat distribution[22m           
           [90m┌                                        ┐[39m 
    [0m"left" [90m┤[39m[38;5;2m■■■■■■■■■■■■■■■■■■■■[39m[0m 0.9999990594578604 [90m [39m 
   [0m"right" [90m┤[39m[0m 9.405421396307151e-7                   [90m [39m 
           [90m└                                        ┘[39m 