In [1]:
import numpy as np

from dynamic_programming import MarkovDecisionProcess, Policy, calculate_value_function, calculate_action_value_function, calculate_action_value_function_from_policy, calculate_greedy_policy, policy_iteration, calculate_optimal_value_function

In [2]:
# Toy example: a single action drives a deterministic chain 0 -> 1 -> 2 -> 3,
# and state 3 loops back onto itself. Axes are (action, state, next_state).
transition_probabilities = np.array([[
    [0, 1, 0, 0],
    [0, 0, 1, 0],
    [0, 0, 0, 1],
    [0, 0, 0, 1],
]])

# One reward per state.
rewards = np.array([1, 1, 1, 0])

# One row per state, one column per action: the lone action is always chosen.
action_probabilities = np.array([[1], [1], [1], [1]])

mdp = MarkovDecisionProcess(
    transition_probabilities=transition_probabilities,
    rewards=rewards,
    discount=0.9,
)
policy = Policy(action_probabilities)

# Evaluate the policy, derive its action values, then act greedily on them.
value_function = calculate_value_function(mdp, policy)
print("value function", value_function)

action_value_function = calculate_action_value_function_from_policy(mdp, policy)
print("action value function", action_value_function)

greedy_policy = calculate_greedy_policy(action_value_function)
print("greedy policy", greedy_policy.action_probabilities)

value function [ 1.90000000e+00  1.00000000e+00  2.77555756e-17 -1.38777878e-16]
action value function [[ 1.9000000e+00]
 [ 1.0000000e+00]
 [-1.2490009e-16]
 [-1.2490009e-16]]
greedy policy [[1.]
 [1.]
 [1.]
 [1.]]


In [3]:
# actions: 0 -> right, 1 -> up, 2 -> left, 3 -> down
def make_grid_mdp(grid_size: int, discount: float):
    """Build the square grid-world MDP.

    States are the ``grid_size * grid_size`` cells, numbered ``x + grid_size * y``;
    "up" decreases ``y``. The top-left cell ``(0, 0)`` and the bottom-right cell
    ``(grid_size - 1, grid_size - 1)`` are absorbing and carry reward 0; every other
    cell carries reward -1. Moves are deterministic, and a move that would leave the
    grid keeps the agent where it is.

    Args:
        grid_size: number of cells along each side of the grid.
        discount: discount factor handed to the MDP.

    Returns:
        A ``MarkovDecisionProcess`` whose transition array has shape
        ``(4, n_states, n_states)``, indexed ``[action, state, next_state]``. For every
        action and state, the row over ``next_state`` sums to 1.
    """
    n_states = grid_size * grid_size
    terminal_states = (0, n_states - 1)

    transition_probabilities = np.zeros(shape=(4, n_states, n_states))
    for x in range(grid_size):
        for y in range(grid_size):
            cell_index = x + grid_size * y

            if cell_index in terminal_states:
                # Absorbing state: whatever the action, the agent stays put with
                # probability 1, so each action's row is a proper distribution.
                transition_probabilities[:, cell_index, cell_index] = 1.0
                continue

            # The boolean guards evaluate to 0 on the border, so a blocked move
            # maps the cell onto itself.
            transition_probabilities[0, cell_index, cell_index + (x < grid_size - 1)] = 1
            transition_probabilities[1, cell_index, cell_index - grid_size * (y > 0)] = 1
            transition_probabilities[2, cell_index, cell_index - (x > 0)] = 1
            transition_probabilities[3, cell_index, cell_index + grid_size * (y < grid_size - 1)] = 1

    rewards = np.full(shape=n_states, fill_value=-1)
    rewards[0] = rewards[-1] = 0

    return MarkovDecisionProcess(
        transition_probabilities=transition_probabilities,
        rewards=rewards,
        discount=discount
    )


def make_uniform_grid_policy(grid_size: int):
    """Return the uniformly random grid policy.

    Each of the ``grid_size * grid_size`` cells assigns probability 1/4 to each of the
    four moves, so the action-probability array has shape ``(n_cells, 4)``.
    """
    n_actions = 4
    n_cells = grid_size ** 2
    uniform_probabilities = np.ones((n_cells, n_actions)) / n_actions
    return Policy(uniform_probabilities)


def print_grid_values(grid_size: int, values: np.ndarray):
    """Print one value per cell as a ``grid_size`` x ``grid_size`` table.

    Each value is shown with one decimal and right-justified to at least 6 characters.
    Cells on the same grid row are joined with '|', one grid row per output line.
    """
    for grid_row in values.reshape(grid_size, grid_size):
        formatted_cells = [" {:.1f} ".format(cell).rjust(6) for cell in grid_row]
        print("|".join(formatted_cells))

# actions: 0 -> right, 1 -> up, 2 -> left, 3 -> down
def print_grid_policy(grid_size: int, policy: Policy):
    """Draw a grid policy as ASCII art.

    Each cell is a box 9 characters wide and 3 lines tall. It shows an arrow for every
    action with positive probability: ↑ (action 1) on the top line, ← (action 2) and
    → (action 0) on the middle line, and ↓ (action 3) on the bottom line. Boxes are
    separated by '|', and grid rows are separated by a dashed line.

    Args:
        grid_size: number of cells along each side of the grid.
        policy: policy whose ``action_probabilities`` can be reshaped to
            ``(grid_size, grid_size, n_actions)``.
    """
    cells_by_row = policy.action_probabilities.reshape(grid_size, grid_size, -1)
    # grid_size boxes of 9 characters plus (grid_size - 1) '|' separators.
    row_separator = '-' * (10 * grid_size - 1)

    def arrow(probabilities, action, symbol):
        return symbol if probabilities[action] > 0 else " "

    for row_number, grid_row in enumerate(cells_by_row):
        top_line = "|".join(f"    {arrow(p, 1, '↑')}    " for p in grid_row)
        middle_line = "|".join(f" {arrow(p, 2, '←')}     {arrow(p, 0, '→')} " for p in grid_row)
        bottom_line = "|".join(f"    {arrow(p, 3, '↓')}    " for p in grid_row)
        print(top_line, middle_line, bottom_line, sep="\n")

        if row_number != grid_size - 1:
            print(row_separator)


In [17]:
# 6x6 grid world with discount 0.9, starting from the uniformly random policy.
grid_size = 6
discount = 0.9

mdp = make_grid_mdp(grid_size, discount=discount)
policy = make_uniform_grid_policy(grid_size=grid_size)

In [18]:
# Policy evaluation for the uniformly random policy.
value_function = calculate_value_function(mdp, policy)

print("value function uniform policy")
print_grid_values(grid_size=grid_size, values=value_function)

value function uniform policy
 -0.0 | -5.3 | -7.8 | -8.9 | -9.3 | -9.4 
 -5.3 | -7.2 | -8.4 | -9.0 | -9.2 | -9.3 
 -7.8 | -8.4 | -8.8 | -9.0 | -9.0 | -8.9 
 -8.9 | -9.0 | -9.0 | -8.8 | -8.4 | -7.8 
 -9.3 | -9.2 | -9.0 | -8.4 | -7.2 | -5.3 
 -9.4 | -9.3 | -8.9 | -7.8 | -5.3 | -0.0 


In [30]:
# Action values derived from the uniform policy's value function; column 0 is "right".
# NOTE(review): this cell's execution count (30) is later than the cells below it, so
# their saved outputs may predate it. Use Restart & Run All to make them consistent.
action_value_function = calculate_action_value_function(mdp, value_function)
right_action_values = action_value_function[:, 0]

print("action value function uniform policy for →")
print_grid_values(grid_size, right_action_values)

action value function uniform policy for →
 -0.0 | -8.1 | -9.0 | -9.3 | -9.5 | -9.5 
 -7.5 | -8.5 | -9.1 | -9.3 | -9.3 | -9.3 
 -8.5 | -8.9 | -9.1 | -9.1 | -9.0 | -9.0 
 -9.1 | -9.1 | -8.9 | -8.5 | -8.1 | -8.1 
 -9.3 | -9.1 | -8.5 | -7.5 | -5.8 | -5.8 
 -9.3 | -9.0 | -8.1 | -5.8 | -0.0 | -0.0 


In [20]:
# Act greedily on the uniform policy's action values (one policy-improvement step).
# Optimality is checked further down by comparing its value function with the optimal one.
# NOTE(review): rounding to 2 decimals presumably turns floating-point near-ties into exact
# ties so every equally good move is kept. Confirm against calculate_greedy_policy.
greedy_policy = calculate_greedy_policy(action_value_function.round(decimals=2))

print("optimal policy")
print_grid_policy(grid_size, greedy_policy)

optimal policy
    ↑    |         |         |         |         |         
 ←     → | ←       | ←       | ←       | ←       | ←       
    ↓    |         |         |         |         |    ↓    
-----------------------------------------------------------
    ↑    |    ↑    |         |         |         |         
         | ←       | ←       | ←       | ←       |         
         |         |         |         |    ↓    |    ↓    
-----------------------------------------------------------
    ↑    |    ↑    |    ↑    |         |         |         
         |         | ←       | ←       |         |         
         |         |         |    ↓    |    ↓    |    ↓    
-----------------------------------------------------------
    ↑    |    ↑    |    ↑    |         |         |         
         |         |       → |       → |         |         
         |         |         |    ↓    |    ↓    |    ↓    
-----------------------------------------------------------
    ↑    |    ↑    |     

In [21]:
# Evaluate the greedy policy. If it is optimal, this matches the optimal value
# function computed further down.
value_function_for_optimal = calculate_value_function(mdp, greedy_policy)

print("value function for optimal policy")
print_grid_values(grid_size=grid_size, values=value_function_for_optimal)

value function for optimal policy
  0.0 |  0.0 | -1.0 | -1.9 | -2.7 | -3.4 
  0.0 | -1.0 | -1.9 | -2.7 | -3.4 | -2.7 
 -1.0 | -1.9 | -2.7 | -3.4 | -2.7 | -1.9 
 -1.9 | -2.7 | -3.4 | -2.7 | -1.9 | -1.0 
 -2.7 | -3.4 | -2.7 | -1.9 | -1.0 |  0.0 
 -3.4 | -2.7 | -1.9 | -1.0 |  0.0 |  0.0 


In [22]:
# Another improvement step, this time from the greedy policy's own values.
action_values_for_optimal = calculate_action_value_function(mdp, value_function_for_optimal)
greedy_policy2 = calculate_greedy_policy(action_values_for_optimal.round(decimals=2))

print("optimal policy")
print_grid_policy(grid_size, greedy_policy2)

optimal policy
    ↑    |         |         |         |         |         
 ←     → | ←       | ←       | ←       | ←       | ←       
    ↓    |         |         |         |         |    ↓    
-----------------------------------------------------------
    ↑    |    ↑    |    ↑    |    ↑    |    ↑    |         
         | ←       | ←       | ←       | ←     → |         
         |         |         |         |    ↓    |    ↓    
-----------------------------------------------------------
    ↑    |    ↑    |    ↑    |    ↑    |         |         
         | ←       | ←       | ←     → |       → |         
         |         |         |    ↓    |    ↓    |    ↓    
-----------------------------------------------------------
    ↑    |    ↑    |    ↑    |         |         |         
         | ←       | ←     → |       → |       → |         
         |         |    ↓    |    ↓    |    ↓    |    ↓    
-----------------------------------------------------------
    ↑    |    ↑    |     

In [24]:
# Policy iteration starting from the uniformly random policy.
# NOTE(review): the third argument (10) is presumably an iteration limit. Confirm in dynamic_programming.
starting_policy = make_uniform_grid_policy(grid_size)
optimal_policy = policy_iteration(mdp, starting_policy, 10)

print("policy obtained from policy iteration algorithm")
print_grid_policy(grid_size, optimal_policy)

policy obtained from policy iteration algorithm
    ↑    |         |         |         |         |         
 ←     → | ←       | ←       | ←       | ←       |         
    ↓    |         |         |         |         |    ↓    
-----------------------------------------------------------
    ↑    |         |         |         |         |         
         | ←       | ←       | ←       |         |         
         |         |         |         |    ↓    |    ↓    
-----------------------------------------------------------
    ↑    |         |    ↑    |         |         |         
         | ←       |         |       → |       → |         
         |         |         |         |         |    ↓    
-----------------------------------------------------------
    ↑    |         |    ↑    |         |         |         
         | ←       |         |       → |         |         
         |         |    ↓    |         |    ↓    |    ↓    
----------------------------------------------------

In [25]:
# Optimal value function computed straight from the MDP (no policy is passed in).
# NOTE(review): 100 is presumably the number of sweeps. Confirm in dynamic_programming.
optimal_value_function = calculate_optimal_value_function(mdp, 100)

print("optimal value function")
print_grid_values(grid_size=grid_size, values=optimal_value_function)

optimal value function
  0.0 |  0.0 | -1.0 | -1.9 | -2.7 | -3.4 
  0.0 | -1.0 | -1.9 | -2.7 | -3.4 | -2.7 
 -1.0 | -1.9 | -2.7 | -3.4 | -2.7 | -1.9 
 -1.9 | -2.7 | -3.4 | -2.7 | -1.9 | -1.0 
 -2.7 | -3.4 | -2.7 | -1.9 | -1.0 |  0.0 
 -3.4 | -2.7 | -1.9 | -1.0 |  0.0 |  0.0 


In [26]:
# Acting greedily on the optimal action values; compare with the greedy policies above.
optimal_action_value_function = calculate_action_value_function(
    mdp, optimal_value_function
)
optimal_greedy_policy = calculate_greedy_policy(optimal_action_value_function)

print("greedy policy associated to optimal value function")
print_grid_policy(grid_size, optimal_greedy_policy)

greedy policy associated to optimal value function
    ↑    |         |         |         |         |         
 ←     → | ←       | ←       | ←       | ←       | ←       
    ↓    |         |         |         |         |    ↓    
-----------------------------------------------------------
    ↑    |    ↑    |    ↑    |    ↑    |    ↑    |         
         | ←       | ←       | ←       | ←     → |         
         |         |         |         |    ↓    |    ↓    
-----------------------------------------------------------
    ↑    |    ↑    |    ↑    |    ↑    |         |         
         | ←       | ←       | ←     → |       → |         
         |         |         |    ↓    |    ↓    |    ↓    
-----------------------------------------------------------
    ↑    |    ↑    |    ↑    |         |         |         
         | ←       | ←     → |       → |       → |         
         |         |    ↓    |    ↓    |    ↓    |    ↓    
-------------------------------------------------