In [33]:
using Gen, Statistics, StatsPlots, Memoize, BenchmarkTools, Distributions
import POMDPs, POMDPModels, POMDPModelTools

We tried this example and it doesn't work, just returns some random trajectory with no regard to score (or so it seems)

In [73]:
actions = [:up,:down,:left,:right]
grid_world = [
    "V" "#" ;
    "___" "___";
]

grid_world_utilities = Dict(
  "V" => 3,
  "___"=> -0.1,
  "#" => -1000000
)

n_rows, n_cols = size(grid_world)
grid_size = (n_cols, n_rows)
function grid_to_utilities_dict(grid_world)
    utilities_dict = Dict()
    for i=1:n_rows
        for j=1:n_cols
            utility = grid_world_utilities[grid_world[n_rows+1 - i,j]]
            utilities_dict[POMDPModels.GWPos(j, i)] = utility
        end
    end
    return utilities_dict
end

function get_terminal_states(grid_world)
    terminal_states = Set()
    for i=1:n_rows
        for j=1:n_cols 
            world_object = grid_world[n_rows+1-i,j]
            if world_object != "___" && world_object != "#" 
                push!(terminal_states,POMDPModels.GWPos(j,i))
            end
        end
    end
    return terminal_states
end

start_state = POMDPModels.GWPos(2, 1)
gw = POMDPModels.SimpleGridWorld(size=grid_size,
                                    rewards=grid_to_utilities_dict(grid_world),
                                    terminate_from = get_terminal_states(grid_world),
                                    tprob = 1,
                                    discount = 1)

function transition(state,action)
    return POMDPModelTools.rand(POMDPs.transition(gw, state, action))
end

function utility(state)
    return gw.rewards[state]
end

function is_terminal_state(state)
    return POMDPs.isterminal(gw, transition(state,:up))
end

function make_agent()
    @gen function act(state, time_left)
        action_index = @trace(uniform_discrete(1,4),:action_index)
        eu = expected_utility(state, actions[action_index], time_left)
        @trace(bernoulli(exp(100 * (eu - 3))), :factor)
        return action_index
    end

    @memoize function run_act(state, time_left)
        action_indices = []
        trace, = generate(act, (state, time_left), choicemap((:factor,1)))
        for i = 1:1000
            trace, = Gen.mh(trace, select(:action_index))
            push!(action_indices, get_retval(trace))
        end
        return action_indices
    end

    @gen function reward(state, action, time_left)
        next_state = transition(state, action)
        action_indices = run_act(state, time_left)
        rand_choice = @trace(uniform_discrete(1, length(action_indices)), :rand_choice)
        next_action_idx = action_indices[rand_choice]
        return expected_utility(next_state, actions[next_action_idx], time_left)
    end

    @memoize function run_reward(state, action, time_left)
        rewards = []
        trace, = generate(reward, (state, action, time_left))
        for i = 1:1000
            trace, = Gen.mh(trace, select(:rand_choice))
            push!(rewards, get_retval(trace))
        end
        return rewards
    end

    @memoize function expected_utility(state, action, time_left)
        u = utility(state)
        new_time_left = time_left - 1;
        if is_terminal_state(state) || new_time_left == 0
            return u
        else
            return u + mean(run_reward(state, action, new_time_left))
        end
    end
    return run_act
end

function simulate(start_state, total_time)
    states = []
    run_act = make_agent()
    next_state = start_state
    while !is_terminal_state(next_state) && total_time > 0
        push!(states, next_state)
        action_indices = run_act(next_state, total_time)
        println("time left: $total_time \n", Dict([(action,count(x->actions[x]==action,action_indices)/1000) for action in actions]))
        rand_choice = rand(DiscreteUniform(1, length(action_indices)))
        next_action = action_indices[rand_choice]
        next_state = transition(next_state, actions[next_action])
        total_time -= 1
        
    end
    push!(states, next_state)
    return states
end;

In [77]:
trajectory = simulate(start_state, 5)

time left: 5 
Dict(:left => 0.0, :right => 0.451, :up => 0.0, :down => 0.549)
time left: 4 
Dict(:left => 0.0, :right => 0.515, :up => 0.0, :down => 0.485)
time left: 3 
Dict(:left => 0.324, :right => 0.349, :up => 0.0, :down => 0.327)
time left: 2 
Dict(:left => 0.319, :right => 0.34, :up => 0.0, :down => 0.341)
time left: 1 
Dict(:left => 0.244, :right => 0.249, :up => 0.251, :down => 0.256)


6-element Vector{Any}:
 [2, 1]
 [2, 1]
 [2, 1]
 [2, 1]
 [2, 1]
 [2, 1]

When we try it with an infinite even horizon (meaning no time limit) it encounters a stack overflow, which makes no sense since the grid is so small

In [64]:
actions = [:up,:down,:left,:right]
grid_world = [
    "V" "___" ;
    "___" "___";
]

grid_world_utilities = Dict(
  "V" => 3,
  "___"=> -0.1,
  "#" => -1000000
)

n_rows, n_cols = size(grid_world)
grid_size = (n_cols, n_rows)
function grid_to_utilities_dict(grid_world)
    utilities_dict = Dict()
    for i=1:n_rows
        for j=1:n_cols
            utility = grid_world_utilities[grid_world[n_rows+1 - i,j]]
            utilities_dict[POMDPModels.GWPos(j, i)] = utility
        end
    end
    return utilities_dict
end

function get_terminal_states(grid_world)
    terminal_states = Set()
    for i=1:n_rows
        for j=1:n_cols 
            world_object = grid_world[n_rows+1-i,j]
            if world_object != "___"
                push!(terminal_states,POMDPModels.GWPos(j,i))
            end
        end
    end
    return terminal_states
end

start_state = POMDPModels.GWPos(2, 1)
gw = POMDPModels.SimpleGridWorld(size=grid_size,
                                    rewards=grid_to_utilities_dict(grid_world),
                                    terminate_from = get_terminal_states(grid_world),
                                    tprob = 1,
                                    discount = 1)

function transition(state,action)
    return POMDPModelTools.rand(POMDPs.transition(gw, state, action))
end

function utility(state)
    return gw.rewards[state]
end

function is_terminal_state(state)
    return POMDPs.isterminal(gw, transition(state,:up))
end

function make_agent()
    @gen function act(state, time_left)
        action_index = @trace(uniform_discrete(1,4),:action_index)
        eu = expected_utility(state, actions[action_index], time_left)
        @trace(bernoulli(exp(100 * (eu - 3))), :factor)
        return action_index
    end

    @memoize function run_act(state, time_left)
        action_indices = []
        trace, = generate(act, (state, time_left), choicemap((:factor,1)))
        for i = 1:1000
            trace, = Gen.mh(trace, select(:action_index))
            push!(action_indices, get_retval(trace))
        end
        return action_indices
    end

    @gen function reward(state, action, time_left)
        next_state = transition(state, action)
        action_indices = run_act(state, time_left)
        rand_choice = @trace(uniform_discrete(1, length(action_indices)), :rand_choice)
        next_action_idx = action_indices[rand_choice]
        return expected_utility(next_state, actions[next_action_idx], time_left)
    end

    @memoize function run_reward(state, action, time_left)
        rewards = []
        trace, = generate(reward, (state, action, time_left))
        for i = 1:1000
            trace, = Gen.mh(trace, select(:rand_choice))
            push!(rewards, get_retval(trace))
        end
        return rewards
    end

    @memoize function expected_utility(state, action, time_left)
        u = utility(state)
        new_time_left = time_left - 1;
        if is_terminal_state(state) || new_time_left == 0
            return u
        else
            return u + mean(run_reward(state, action, new_time_left))
        end
    end
    return run_act
end

function simulate(start_state, total_time)
    states = []
    run_act = make_agent()
    next_state = start_state
    while !is_terminal_state(next_state) && total_time > 0
        push!(states, next_state)
        action_indices = run_act(next_state, total_time)
        println("time left: $total_time \n", Dict([(action,count(x->actions[x]==action,action_indices)/1000) for action in actions]))
        rand_choice = rand(DiscreteUniform(1, length(action_indices)))
        next_action = action_indices[rand_choice]
        next_state = transition(next_state, actions[next_action])
        total_time -= 1
    end
    push!(states, next_state)
    return states
end;

In [69]:
trajectory = simulate(start_state, 5)

3-element Vector{Any}:
 [2, 1]
 [1, 1]
 [1, 2]