In [11]:
coord_index(x, y, width) = ((x-1)*width) + y


delete_action!(action_array, action) = deleteat!(action_array, findall(x->x==action,action_array))

function apply_action(state, action, width)
    if action == 'N'
        return state - width
    elseif action == 'S'
        return state + width
    elseif action == 'E'
        return state + 1
    elseif action == 'W'
        return state - 1
    elseif action == ' '
        return state
    else
        println("invalid action $(action) taken")
        return 0
    end
end

function printdelim(width)
    for num in 1:width
        print("-----")
    end
    print("\n")
end

function pretty_print_grid(grid, width, height)
    pretty_world = round.(grid, digits=1)
    for row in 1:height
        printdelim(width)
        for col in 1:width
            num = pretty_world[coord_index(row, col, width)]
            sign(num) != -1.0 ? print("| $(num)") : print("|$(num)")
        end
        print("|\n")
    end
    printdelim(width)
end

function action_to_arrow(action)
    if action == 'N'
        return "^"
    elseif action == 'S'
        return "v"
    elseif action == 'E'
        return ">"
    elseif action == 'W'
        return "<"
    elseif action == ' '
        return ""
    else
        println("invalid action $(action) taken")
        return 0
    end
end
    
function print_action_delim(width)
    for num in 1:width
        print("-----")
    end
    print("\n")
end

function pretty_print_action_grid(grid, width, height)
    for row in 1:height
        print_action_delim(width)
        for col in 1:width
            arrows = join(action_to_arrow.(grid[coord_index(row, col, width)]))    
            while length(arrows) != 4
                arrows = " $arrows"
            end
            
            print("|$(arrows)")
        end
        print("|\n")
    end
    print_action_delim(width)
end

function make_grid_world(width, height, initial_value = -1.0)
    # everything but terminal states are negative values
    gridworld = [initial_value for x in 1:(width*height)]
    gridworld[coord_index(1,1, width)] = 0.0
    gridworld[coord_index(width, height, width)] = 0.0

    # probably could be done more elagently
    possible_actions = [['W', 'N', 'S', 'E'] for x in 1:(width*height)]
    for row in 1:height
        for col in 1:width
            if col == 1
                delete_action!(possible_actions[coord_index(col, row, width)], 'N')
            elseif col == height
                delete_action!(possible_actions[coord_index(col, row, width)], 'S')
            end

            if row == 1
                delete_action!(possible_actions[coord_index(col, row, width)], 'W')
            elseif row == width
                 delete_action!(possible_actions[coord_index(col, row, width)], 'E')
            end
        end
    end
    #terminal states
    possible_actions[coord_index(1,1, width)] = [' ']
    possible_actions[coord_index(width,height, width)] = [' ']
    
    return gridworld, possible_actions
end

make_grid_world (generic function with 2 methods)

In [13]:
# important note: generated policy is always deterministic (i.e one action per state)
function policy_iteration(reward_function, action_space, width, height, discount=1.0, theta_v = 0.5)
    size = length(reward_function)
    # just pick the first possible action for the initial policy
    extract_first(array) = length(array) > 0 ? array[1] : ' '
    
    policy = extract_first.(action_space)
    value = [0.0 for x in 1:size]
    policy_stable = false
    num_iter = 0
    num_internal_iter = 0
    
    while !policy_stable
        # evaluate the value function given that policy
        while true
            delta_v = 0.0
            for x in 1:size
                old_v = value[x]
                value[x] = reward_function[x] + discount * value[apply_action(x, policy[x], width)]
                delta_v = max(delta_v, abs(old_v - value[x]))
            end
            delta_v < theta_v || break
        end
            
        # pick better policy based on value function
        policy_stable = true
        for x in 1:size
            max_val = 0.0
            for a in action_space[x]
                if reward_function[x] + discount * value[apply_action(x, a, width)] > value[x]
                    policy[x] = a
                    policy_stable = false
                end
            end
        end
        num_iter += 1
        num_internal_iter += 1
    end
    println("converged in $num_iter policy runs and $num_internal_iter value estimation runs")
    println("Policy is:")
    pretty_print_action_grid(policy, width, height)
    println("Value function is:")
    pretty_print_grid(value, width, height)
end

policy_iteration (generic function with 3 methods)

In [18]:
unzip(a) = map(x->getfield.(a, x), fieldnames(eltype(a)))

function value_iteration(reward_function, action_space, width, height, discount=1.0, theta_v=0.5)
    size = length(reward_function)
    value = [0.0 for x in 1:size]
    num_iter = 0
    
    # estimate v*
    while true
        delta_v = 0.0
        for x in 1:size
            old_v = value[x]
            v = [reward_function[x] + discount * value[apply_action(x, a, width)] for a in action_space[x]]
            value[x] = maximum(v)
            delta_v = max(delta_v, abs(old_v - value[x]))
        end
        num_iter += 1
        if delta_v < theta_v
            break
        end
    end
    
    # extract policy from v*
    policy = []
    for x in 1:size      
        v, a = unzip([(reward_function[x] + discount * value[apply_action(x, a, width)], a) for a in action_space[x]])
        max_a = a[findall(x->x==maximum(v), v)[1]]
        append!(policy, max_a)
    end
    
    println("converged in $num_iter iterations")
    println("Policy is:")
    pretty_print_action_grid(policy, width, height)
    println("Value function is:")
    pretty_print_grid(value, width, height)

end

value_iteration (generic function with 4 methods)

In [19]:
reward, action_space = make_grid_world(4,4)

println("rewards per state:")
pretty_print_grid(reward, 4, 4)
print("\n")
println("possible actions per state:")
pretty_print_action_grid(action_space, 4 , 4)

print("===========================\n")
print("policy iteration solution:\n")
print("===========================\n")
policy_iteration(reward, action_space, 4, 4)

print("===========================\n")
print("value iteration solution:\n")
print("===========================\n")
value_iteration(reward, action_space, 4, 4)

rewards per state:
--------------------
| 0.0|-1.0|-1.0|-1.0|
--------------------
|-1.0|-1.0|-1.0|-1.0|
--------------------
|-1.0|-1.0|-1.0|-1.0|
--------------------
|-1.0|-1.0|-1.0| 0.0|
--------------------

possible actions per state:
--------------------
|    | <v>| <v>|  <v|
--------------------
| ^v>|<^v>|<^v>| <^v|
--------------------
| ^v>|<^v>|<^v>| <^v|
--------------------
|  ^>| <^>| <^>|    |
--------------------
policy iteration solution:
converged in 3 policy runs and 3 value estimation runs
Policy is:
--------------------
|    |   <|   <|   <|
--------------------
|   ^|   <|   <|   v|
--------------------
|   ^|   <|   >|   v|
--------------------
|   ^|   >|   >|    |
--------------------
Value function is:
--------------------
| 0.0|-1.0|-2.0|-3.0|
--------------------
|-1.0|-2.0|-3.0|-2.0|
--------------------
|-2.0|-3.0|-2.0|-1.0|
--------------------
|-3.0|-2.0|-1.0| 0.0|
--------------------
value iteration solution:
converged in 4 iterations
Policy is:
-----

In [24]:
big_reward, big_action_space = make_grid_world(10,10)

println("rewards per state:")
pretty_print_grid(big_reward, 10, 10)
print("\n")
println("possible actions per state:")
pretty_print_action_grid(big_action_space, 10 , 10)

print("===========================\n")
print("policy iteration solution:\n")
print("===========================\n")
policy_iteration(big_reward, big_action_space, 10, 10)

print("===========================\n")
print("value iteration solution:\n")
print("===========================\n")
value_iteration(big_reward, big_action_space, 10, 10)

rewards per state:
--------------------------------------------------
| 0.0|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|
--------------------------------------------------
|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|
--------------------------------------------------
|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|
--------------------------------------------------
|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|
--------------------------------------------------
|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|
--------------------------------------------------
|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|
--------------------------------------------------
|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|
--------------------------------------------------
|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|
--------------------------------------------------
|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|-1.0|
--------------------------------------------------
|-1