In [1]:
using Gen, Statistics, StatsPlots, Memoize, BenchmarkTools


In [2]:
# Scale applied to an action's expected utility inside the agent's soft-max
# factor, `exp(CONTROL_FACTOR * eu)`.
CONTROL_FACTOR = 10
# Number of Metropolis-Hastings steps (and therefore samples kept) per inference
# run.
AMOUNT = 100

# Grid layout, indexed as grid_world[row, column]:
#   "V"   - terminal tile
#   "#"   - wall; moves onto it are never offered as actions
#   "___" - ordinary tile
grid_world = [
    "V"   "#"
    "___" "___"
    "___" "___"
    "___" "___"
    "___" "___"
]

# Utility collected on each tile kind. The value type is pinned to Float64 so
# lookups always return the same concrete type (mixing `3` and `-0.1` would
# otherwise give a `Dict{String,Real}`). The wall tile has no entry.
grid_world_utilities = Dict{String,Float64}(
    "V" => 3.0,
    "___" => -0.1,
)

n_rows, n_cols = size(grid_world)
grid_size = (n_rows, n_cols)
# Starting cell as [row, column].
start_state = [2, 2]

"""
    transition(state, action)

Return the grid cell reached by taking `action` from `state`.

`state` is read as `[row, column]`. The result is a newly allocated two-element
vector; `state` itself is not modified. No bounds or wall check is performed
here.

# Throws
- `ArgumentError` if `action` is not one of `:down`, `:up`, `:left`, `:right`.
"""
function transition(state, action)
    if action == :down
        return [state[1] + 1, state[2]]
    elseif action == :up
        return [state[1] - 1, state[2]]
    elseif action == :left
        return [state[1], state[2] - 1]
    elseif action == :right
        return [state[1], state[2] + 1]
    end
    # Falling through would return `nothing`, which only fails later at some
    # unrelated indexing site; report the bad action where it happens instead.
    throw(ArgumentError("unknown action: $(repr(action))"))
end

"""
    get_tile(state)

Look up the tile label stored in `grid_world` at the position given by `state`.

The entries of `state` are splatted into the index, so a `[row, column]` state
reads `grid_world[row, column]`.
"""
get_tile(state) = getindex(grid_world, state...)

"""
    state_to_actions(state)

Return the actions available from `state` as a `Vector{Symbol}`.

An action is available when the cell it leads to lies inside the grid
(`grid_size`) and is not a wall tile (`"#"`). Actions are always listed in the
order `:down`, `:up`, `:left`, `:right`; callers refer to actions by their
position in this vector, so the order is significant.
"""
function state_to_actions(state)
    row, col = state[1], state[2]
    # Each candidate move paired with "does it stay on the grid?".
    candidates = (
        (:down, row < grid_size[1]),
        (:up, row > 1),
        (:left, col > 1),
        (:right, col < grid_size[2]),
    )
    actions = Symbol[]
    for (action, on_grid) in candidates
        # `&&` short-circuits, so the tile lookup never happens off-grid.
        if on_grid && get_tile(transition(state, action)) != "#"
            push!(actions, action)
        end
    end
    return actions
end

"""
    utility(state)

Return the utility of standing on `state`, looked up in `grid_world_utilities`
by the tile label at that position.

Indexing throws a `KeyError` for a tile label that has no entry in the
dictionary.
"""
function utility(state)
    tile = get_tile(state)
    return grid_world_utilities[tile]
end

"""
    is_terminal_state(state)

Return `true` when the tile at `state` is the terminal tile `"V"`.
"""
is_terminal_state(state) = get_tile(state) == "V"

"""
    make_agent()

Build the planning agent and return its action sampler `run_act`.

`run_act(state, time_left)` returns a `Vector{Int}` of `AMOUNT` action indices
(positions in `state_to_actions(state)`), collected from a Metropolis-Hastings
chain over the `act` model. `act` picks an action uniformly and then scores it
with a factor `exp(CONTROL_FACTOR * eu)`, where `eu` is the action's expected
utility with `time_left` steps remaining.

`run_act`, `run_reward` and `expected_utility` call each other recursively and
are memoized. Their caches are `Dict`s rather than Memoize's default `IdDict`:
states are `[row, column]` vectors, and an `IdDict` keys them by object
identity, so equal states freshly allocated by `transition` would never hit the
cache.
"""
function make_agent()
    # One decision: choose an action index uniformly, then weight the choice by
    # its expected utility through the `:factor` address.
    @gen function act(state, time_left)
        possible_actions = state_to_actions(state)
        action_index = @trace(
            uniform_discrete(1, length(possible_actions)), :action_index
        )
        next_action = possible_actions[action_index]
        eu = expected_utility(state, next_action, time_left)
        # NOTE(review): exp(CONTROL_FACTOR * eu) exceeds 1 whenever eu > 0, so
        # it is not a valid Bernoulli probability. This presumably relies on
        # only the log-score of the constrained `true` outcome being used --
        # confirm against Gen's `bernoulli` implementation.
        @trace(bernoulli(exp(CONTROL_FACTOR * eu)), :factor)
        return action_index
    end

    # Sample AMOUNT action indices for `state` with `:factor` observed as true.
    @memoize Dict function run_act(state, time_left)
        action_indices = Int[]
        # `bernoulli` has Bool support, so the observation is `true`, not `1`.
        trace, = generate(act, (state, time_left), choicemap((:factor, true)))
        for _ in 1:AMOUNT
            trace, = Gen.mh(trace, select(:action_index))
            push!(action_indices, get_retval(trace))
        end
        return action_indices
    end

    # Future value of taking `action` in `state`: move, let the agent pick its
    # next action from its own sampled policy, and return that action's
    # expected utility.
    @gen function reward(state, action, time_left)
        next_state = transition(state, action)
        possible_actions = state_to_actions(next_state)
        action_indices = run_act(next_state, time_left)
        rand_choice = @trace(
            uniform_discrete(1, length(action_indices)), :rand_choice
        )
        next_action_idx = action_indices[rand_choice]
        next_action = possible_actions[next_action_idx]
        return expected_utility(next_state, next_action, time_left)
    end

    # Collect AMOUNT draws of `reward`, resampling `:rand_choice` each step.
    @memoize Dict function run_reward(state, action, time_left)
        rewards = Float64[]
        trace, = generate(reward, (state, action, time_left))
        for _ in 1:AMOUNT
            trace, = Gen.mh(trace, select(:rand_choice))
            push!(rewards, get_retval(trace))
        end
        return rewards
    end

    # Utility of `state` plus the mean sampled future reward of `action`.
    @memoize Dict function expected_utility(state, action, time_left)
        u = utility(state)
        new_time_left = time_left - 1
        # `<= 0` (not `== 0`) so a non-positive `time_left` stops here instead
        # of recursing with an ever more negative horizon.
        if is_terminal_state(state) || new_time_left <= 0
            return u
        else
            return u + mean(run_reward(state, action, new_time_left))
        end
    end
    return run_act
end

"""
    simulate_agent(start_state, total_time)

Roll the agent forward from `start_state` for at most `total_time` steps and
return the visited states in order, ending with the state where the rollout
stopped.

At every step the sampled action indices from `make_agent()`'s sampler are
summarised as per-action frequencies and printed, then one of the samples is
picked uniformly at random and executed. The rollout stops when a terminal
state is reached or the time budget is used up.
"""
function simulate_agent(start_state, total_time)
    run_act = make_agent()
    states = []
    next_state = start_state
    while true
        # Stop on a terminal tile or once the time budget is exhausted.
        (is_terminal_state(next_state) || !(total_time > 0)) && break
        push!(states, next_state)

        action_indices = run_act(next_state, total_time)
        possible_actions = state_to_actions(next_state)

        # Share of the samples that chose each available action.
        frequencies = Dict([
            (action, count(i -> possible_actions[i] == action, action_indices) / AMOUNT)
            for action in possible_actions
        ])
        println("time left: $total_time \n", frequencies)

        # Execute one sampled action, chosen uniformly among the samples.
        sampled_idx = action_indices[uniform_discrete(1, length(action_indices))]
        next_state = transition(next_state, possible_actions[sampled_idx])
        total_time -= 1
    end
    push!(states, next_state)
    return states
end;

In [3]:
# Roll out one episode from `start_state` with a budget of 5 time steps; the
# result is the vector of visited states returned by `simulate_agent`.
trajectory = simulate_agent(start_state, 5)