In [1]:
import numpy as np

In [2]:
### Markov Decision Process for Small Gridworld
# 1) State-space -> self.states
# 2) Action-space -> self.actions
# 3) State transition probabilities -> self.P
# 4) Reward function -> self.rewards

# Grid dimensions as (rows, cols).
SHAPE = (4, 4)
# Integer ids of the four moves; they double as indices along the action axis of P.
UP, DOWN, LEFT, RIGHT = range(4)

# Helper function to convert state index from 2d to 1d
def get_state_idx_1d(x,y,shape):
    """Flatten the 2-D cell coordinate (x, y) into a row-major 1-D state index.

    x is the row, y is the column, and shape[1] is the number of columns.
    """
    row_offset = x * shape[1]
    return row_offset + y

class Gridworld:
    """Deterministic gridworld MDP whose top-left and bottom-right cells are absorbing.

    Attributes set by the constructor:
        shape   -- the (rows, cols) pair passed in
        states  -- np.arange over all rows*cols state indices (row-major)
        actions -- [UP, DOWN, LEFT, RIGHT]
        gamma   -- discount factor, fixed at 1
        rewards -- per-state list: 0 for the first and last state, -1 elsewhere
        P       -- array of shape (num_states, 4, num_states, 2);
                   P[s, a, s2, 0] is the probability of moving from s to s2 under a,
                   P[s, a, s2, 1] is the reward stored for that transition
    """

    def __init__(self, shape):
        self.shape = shape
        num_states = np.prod(shape)
        num_actions = 4
        last_row = shape[0] - 1
        last_col = shape[1] - 1

        self.states = np.arange(num_states)
        self.actions = [UP, DOWN, LEFT, RIGHT]
        self.gamma = 1

        # Every step costs -1 except from the two corner (terminal) states.
        rewards = [-1 for _ in range(num_states)]
        rewards[0] = 0
        rewards[num_states - 1] = 0
        self.rewards = rewards

        def get_next_state(x, y, a):
            """Return the 1-D index of the cell reached from (x, y) by action a."""
            in_first_corner = x == 0 and y == 0
            in_last_corner = x == last_row and y == last_col
            if in_first_corner or in_last_corner:
                # Terminal cells are absorbing: every action leaves the agent in place.
                return get_state_idx_1d(x, y, shape)

            if a == UP:
                dx, dy = -1, 0
            elif a == DOWN:
                dx, dy = 1, 0
            elif a == LEFT:
                dx, dy = 0, -1
            else:
                dx, dy = 0, 1

            nx, ny = x + dx, y + dy
            # A move that would leave the grid keeps the agent where it is.
            if not 0 <= nx <= last_row:
                nx = x
            if not 0 <= ny <= last_col:
                ny = y
            return get_state_idx_1d(nx, ny, shape)

        # For each (state, action) pair the single reachable next state gets
        # probability 1 together with the reward of the state being left.
        P = np.zeros((num_states, num_actions, num_states, 2))
        for x, y in np.ndindex(*shape):
            cur_state = get_state_idx_1d(x, y, shape)
            for a in self.actions:
                next_state = get_next_state(x, y, a)
                P[cur_state, a, next_state, 0] = 1
                P[cur_state, a, next_state, 1] = rewards[cur_state]

        self.P = P

# Build the SHAPE-sized gridworld instance that the cells below read as `gridworld`.
gridworld = Gridworld(SHAPE)

In [3]:
# Policy Evaluation Algorithm

def get_value_grid(values, shape=None):
    """Round state values to 2 decimals and lay them out as a 2-D grid.

    Args:
        values: array-like holding one value per state, either flat or with
            shape (1, n).
        shape: optional (rows, cols) of the grid. When omitted the grid is
            assumed to be square and its side is inferred from the number of
            values.

    Returns:
        A 2-D numpy array of the rounded values.

    Raises:
        ValueError: if ``shape`` is omitted and the number of values is not a
            perfect square, or if ``shape`` does not match the number of values.
    """
    values = np.asarray(values)
    n = values.size
    if shape is None:
        # round() instead of truncation so a float sqrt such as 3.9999... still maps to 4.
        dim = int(round(n ** 0.5))
        if dim * dim != n:
            raise ValueError(
                "cannot infer a square grid from {} values; pass shape=(rows, cols)".format(n)
            )
        shape = (dim, dim)
    return np.round(values, 2).reshape(shape)

def policy_evaluation(policy, gridworld, threshold=1e-5, max_iterations=None):
    """Iteratively estimate the state-value function of ``policy``.

    Sweeps over all states, replacing each state's value in place with the
    policy-weighted expected one-step return, until the largest change within
    a sweep drops below ``threshold``.

    Values start at zero, so the result is deterministic for a given policy
    and environment.

    Args:
        policy: array-like indexed as policy[s][a], the probability of taking
            action a in state s.
        gridworld: object exposing ``states``, ``actions``, ``gamma`` and ``P``,
            where P[s, a, :, 0] holds next-state probabilities and
            P[s, a, :, 1] the matching rewards.
        threshold: convergence tolerance on the largest per-sweep value change.
        max_iterations: optional cap on the number of sweeps. With the default
            ``None`` the loop runs until convergence, which never happens when
            gamma is 1 and the policy cannot reach an absorbing zero-reward
            state. When the cap is hit the current (unconverged) estimate is
            returned.

    Returns:
        numpy array of shape (1, len(states)) with the estimated state values.
    """
    states = gridworld.states
    actions = list(gridworld.actions)
    gamma = gridworld.gamma
    n = len(states)

    policy = np.asarray(policy, dtype=float)
    P = np.asarray(gridworld.P)
    # Split the last axis first, then select the action rows; combining both in
    # one indexing expression would move the action axis to the front.
    transition_prob = P[..., 0][:, actions, :]
    reward_fun = P[..., 1][:, actions, :]

    values = np.zeros(n)
    sweeps = 0
    while True:
        delta = 0.0
        for s in states:
            prev_val = values[s]
            # Expected return of every action from s, using the freshest values.
            action_values = (transition_prob[s] * (reward_fun[s] + gamma * values)).sum(axis=1)
            values[s] = policy[s, actions] @ action_values
            delta = max(delta, abs(values[s] - prev_val))
        sweeps += 1
        if delta < threshold:
            break
        if max_iterations is not None and sweeps >= max_iterations:
            break

    return values.reshape(1, n)

In [4]:
# Policy Improvement Algorithm

def policy_improvement(policy, values, gridworld, tol=1e-6):
    """Make ``policy`` greedy with respect to ``values``.

    For every state the one-step action values are computed from
    ``gridworld.P``. Actions whose value is within ``tol`` of the best one are
    treated as tied, so approximation noise in ``values`` cannot make the
    result flip between equally good actions.

    The input policy counts as already greedy only if it puts no probability
    on any non-maximising action in any state. Checking just the argmax action
    would wrongly accept, for example, a uniform policy whose first action
    happens to be optimal.

    Args:
        policy: array-like indexed as policy[s][a] with action probabilities.
        values: state values, either flat or with shape (1, len(states)).
        gridworld: object exposing ``states``, ``actions``, ``gamma`` and ``P``,
            where P[s, a, :, 0] holds next-state probabilities and
            P[s, a, :, 1] the matching rewards.
        tol: absolute tolerance used when comparing action values for ties.

    Returns:
        (new_policy, True) with a deterministic one-hot greedy policy when the
        input policy was not greedy, otherwise (policy, False) with the input
        policy object unchanged.
    """
    states = gridworld.states
    actions = list(gridworld.actions)
    gamma = gridworld.gamma
    n = len(states)
    m = len(actions)

    policy_arr = np.asarray(policy, dtype=float)
    state_values = np.asarray(values, dtype=float).reshape(-1)
    P = np.asarray(gridworld.P)

    new_policy = np.zeros((n, m))
    has_improved = False
    for s in states:
        action_values = np.array([
            (P[s, a, :, 0] * (P[s, a, :, 1] + gamma * state_values)).sum()
            for a in actions
        ])
        cutoff = action_values.max() - tol
        best_actions = [a for a, q in zip(actions, action_values) if q >= cutoff]

        # Any probability mass on a non-maximising action means the policy can be improved here.
        if any(policy_arr[s][a] > 0 for a in actions if a not in best_actions):
            has_improved = True

        # Keep the current preferred action when it is among the best, so that
        # ties do not cause the policy to change needlessly.
        prev_action = policy_arr[s].argmax()
        chosen = prev_action if prev_action in best_actions else best_actions[0]
        new_policy[s][chosen] = 1

    if has_improved:
        return (new_policy, True)
    return (policy, False)

In [5]:
# One round of evaluation + improvement, starting from the policy that gives
# every action probability 0.25 in every state.
states, actions = gridworld.states, gridworld.actions
policy = np.full((len(states), len(actions)), 0.25)

# Value function of the starting policy.
values = policy_evaluation(policy, gridworld, threshold=1e-5)
# Greedy policy with respect to those values.
new_policy, is_improvement = policy_improvement(policy, values, gridworld)
print("Value function of original policy:\n{}\n".format(get_value_grid(values)))

# Re-evaluate to see how much the greedy step helped.
new_values = policy_evaluation(new_policy, gridworld, threshold=1e-5)
print("Value function of new policy after policy improvement:\n{}\n".format(get_value_grid(new_values)))

Value function of original policy:
[[  0. -14. -20. -22.]
 [-14. -18. -20. -20.]
 [-20. -20. -18. -14.]
 [-22. -20. -14.   0.]]

Value function of new policy after policy improvement:
[[ 0. -1. -2. -3.]
 [-1. -2. -3. -2.]
 [-2. -3. -2. -1.]
 [-3. -2. -1.  0.]]



In [6]:
# Policy Iteration Algorithm

def policy_iteration(gridworld, threshold = 1e-5):
    """Alternate policy evaluation and greedy improvement until the policy is stable.

    Starts from the equiprobable policy (probability 1/len(actions) for every
    action in every state), prints a summary of the result and returns it.

    Args:
        gridworld: environment object passed through to ``policy_evaluation``
            and ``policy_improvement``; its ``states`` and ``actions`` determine
            the size of the policy table.
        threshold: convergence tolerance forwarded to ``policy_evaluation``.

    Returns:
        (policy, values): the final policy table and the value estimates from
        the last evaluation.
    """
    states = gridworld.states
    actions = gridworld.actions

    # 1/len(actions) rather than a literal 0.25 so each row is a valid
    # distribution for any number of actions.
    num_actions = len(actions)
    policy = np.full((len(states), num_actions), 1.0 / num_actions)
    iteration_count = 0

    while True:
        values = policy_evaluation(policy, gridworld, threshold=threshold)
        policy, is_improvement = policy_improvement(policy, values, gridworld)
        iteration_count += 1
        if not is_improvement:
            break

    print("Found optimal policy in {} iterations\n".format(iteration_count))
    print("Optimal policy (action probability distribution):\n{}\n".format(policy))
    print("Optimal state-value function:\n{}\n".format(values))
    print("Optimal state-value function (in grid):\n{}\n".format(get_value_grid(values)))

    return (policy,values)

In [7]:
# Run full policy iteration on the gridworld; the call prints its own summary
# and rebinds the notebook-level `policy` to the returned policy.
policy, value = policy_iteration(gridworld, threshold=1e-5)

Found optimal policy in 2 iterations

Optimal policy (action probability distribution):
[[1. 0. 0. 0.]
 [0. 0. 1. 0.]
 [0. 0. 1. 0.]
 [0. 0. 1. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [0. 0. 1. 0.]
 [0. 1. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 1. 0. 0.]
 [1. 0. 0. 0.]
 [0. 0. 0. 1.]
 [0. 0. 0. 1.]
 [1. 0. 0. 0.]]

Optimal state-value function:
[[ 0. -1. -2. -3. -1. -2. -3. -2. -2. -3. -2. -1. -3. -2. -1.  0.]]

Optimal state-value function (in grid):
[[ 0. -1. -2. -3.]
 [-1. -2. -3. -2.]
 [-2. -3. -2. -1.]
 [-3. -2. -1.  0.]]

