In [3]:
import numpy as np
import pprint
from operator import itemgetter

In [4]:
class MDP:
    def __init__(self, T, S, R, A, act_list):        
        # State space
        # Integer number of states
        self.S = S
        
        # Transition probabilities
        # Form: np ndarray of shape (start_state, action, end_state)
        self.T = T
        
        # Reward space
        # Form: vector, rewards for each state
        self.R = R
        
        # Action space
        # integer, number of possible actions
        self.A = A
        
        # Possible actions in the MDP
        self.actions = act_list
    

In [5]:
class Grid_world(MDP):
    def __init__(self, grid_size, reward_pos):
        S = grid_size*grid_size
        
        R = np.zeros((grid_size, grid_size))
        
        # Each row of reward_pos is a tuple: x, y, reward
        for row in reward_pos:
            R[row[0], row[1]] = row[2]
        R = R.flatten()
        
        A = 4
        act_list = ['S', 'E', 'N', 'W']
        
        T = np.zeros((S, A, S))
        for start_state in range(S):
            state_i = start_state/grid_size
            state_j = (start_state)%grid_size
            
            # Actions indexed as: 0:S, 1:E, 2:N, 3:W
            for act in range(A):
                feas_grid = np.zeros((grid_size, grid_size))
                if(act == 0 ):
                    if(state_i+1 < grid_size):                        
                        feas_grid[state_i+1, state_j] = 1
                    else:
                        feas_grid[state_i, state_j] = 1
                        
                elif(act == 1):
                    if(state_j+1 < grid_size):                        
                        feas_grid[state_i, state_j+1] = 1
                    else:
                        feas_grid[state_i, state_j] = 1                    
                    
                elif(act == 2):
                    if(state_i-1 >= 0):                        
                        feas_grid[state_i-1, state_j] = 1
                    else:
                        feas_grid[state_i, state_j] = 1                    
                    
                elif(act == 3):
                    if(state_j-1 >= 0):                        
                        feas_grid[state_i, state_j-1] = 1
                    else:
                        feas_grid[state_i, state_j] = 1                    
                    
                    
                # Flatten the feasibility grid and assign to transition matrix
                T[start_state, act, :] = feas_grid.flatten()
        MDP.__init__(self, T, S, R, A, act_list)

In [6]:
test_rewards = [[i, j, -1] for i in range(5) for j in range(5)]
test_rewards[2] = [0, 2, 1]
test_rewards[23] = [4,3,1]
# test_rewards = [[0, 3, 5],
#                 [0, 1, 10]]
print test_rewards
gw = Grid_world(5, test_rewards)

[[0, 0, -1], [0, 1, -1], [0, 2, 1], [0, 3, -1], [0, 4, -1], [1, 0, -1], [1, 1, -1], [1, 2, -1], [1, 3, -1], [1, 4, -1], [2, 0, -1], [2, 1, -1], [2, 2, -1], [2, 3, -1], [2, 4, -1], [3, 0, -1], [3, 1, -1], [3, 2, -1], [3, 3, -1], [3, 4, -1], [4, 0, -1], [4, 1, -1], [4, 2, -1], [4, 3, 1], [4, 4, -1]]


In [274]:
def policy_iteration(mdp, gamma = 0.1):
    # Initialization
    V = [0]*mdp.S
    pol = [1]*mdp.S
    old_V = V
    n_iter = 0
    while True:
        # Policy evaluation
        while True:
            delta = 0.0
            old_V = V
            for s in range(mdp.S):
                v = V[s]
                V[s] = sum(mdp.T[s, pol[s], k] *
                       (mdp.R[k] + gamma * V[k])
                       for k in range(mdp.S))
                delta = max(delta, abs(v-V[s]))                
            if(delta < 0.001):
                break
        # Policy improvement
        policy_stable = True
        for s in range(mdp.S):
            old_action = pol[s]
            possibilities = [sum(mdp.T[s,a,k] *(mdp.R[k] + gamma * V[k]) for k in range(mdp.S)) for a in range(mdp.A)]
            pol[s] = max(enumerate(possibilities), key=itemgetter(1))[0]
            if(old_action != pol[s]):
                policy_stable = False
        if policy_stable:
            return V, pol, n_iter
        n_iter += 1

In [275]:
V, pol, n_iter = policy_iteration(gw)

In [276]:
# Now observe the obtained value and policy:
pprint.pprint(np.reshape(V,(5,5)))
pprint.pprint(np.reshape(pol,(5,5)))
print n_iter 

array([[-0.88888889,  1.11111111,  1.11111111,  1.11111111, -0.88888889],
       [-1.08888889, -0.88888889,  1.11111111, -0.88888889, -1.08888889],
       [-1.10888889, -1.08888889, -0.88888889, -0.88888889, -1.08888889],
       [-1.10888889, -1.08888889, -0.88888889,  1.11111111, -0.88888889],
       [-1.08888889, -0.88888889,  1.11111111,  1.11111111,  1.11111111]])
array([[1, 1, 2, 3, 3],
       [1, 1, 2, 2, 2],
       [1, 1, 2, 0, 0],
       [0, 0, 0, 0, 0],
       [1, 1, 1, 0, 3]])
3


In [48]:
def policy_iteration_by_inversion(mdp, gamma = 0.9):
    # Initialization
    V = [0]*mdp.S
    pol = [1]*mdp.S
    old_V = V
    n_iter = 0
    while True:
        # Policy evaluation
        V = [np.linalg.solve((np.eye(mdp.S, mdp.S) - gamma * mdp.T[:, pol[k], :]), mdp.R) for k in range(mdp.S)]
        V = V[0]
        # Policy improvement
        policy_stable = True
        for s in range(mdp.S):
            old_action = pol[s]
            possibilities = [sum(mdp.T[s,a,k] *(mdp.R[k] + gamma * V[k]) for k in range(mdp.S)) for a in range(mdp.A)]
            pol[s] = max(enumerate(possibilities), key=itemgetter(1))[0]
            if(old_action != pol[s]):
                policy_stable = False        
        if policy_stable:
            return V, pol, n_iter
        n_iter += 1

In [49]:
V_2, pol_2, n_iter2 = policy_iteration_by_inversion(gw)

In [50]:
# Now observe the obtained value and policy:
pprint.pprint(np.reshape(V_2,(5,5)))
pprint.pprint(np.reshape(pol_2,(5,5)))
print n_iter2

array([[ -8.38 ,  -8.2  ,  -8.   , -10.   , -10.   ],
       [-10.   , -10.   , -10.   , -10.   , -10.   ],
       [-10.   , -10.   , -10.   , -10.   , -10.   ],
       [-10.   , -10.   , -10.   , -10.   , -10.   ],
       [ -8.542,  -8.38 ,  -8.2  ,  -8.   , -10.   ]])
array([[1, 1, 2, 3, 0],
       [2, 2, 2, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [1, 1, 1, 0, 3]])
1


In [25]:
def modified_policy_iteration(mdp, epsilon = 0.1, gamma = 0.9, m = 20):
    # Initialization
    V = [0]*mdp.S
    pol = [1]*mdp.S
#     old_V = V
    n_iter = 0
    
    while True:
        u = []
        u.append([])
        
        # Policy improvement
        policy_stable = True
        for s in range(mdp.S):
            old_action = pol[s]
            possibilities = [sum(mdp.T[s,a,k] *(mdp.R[k] + gamma * V[k]) for k in range(mdp.S)) for a in range(mdp.A)]
            pol[s] = max(enumerate(possibilities), key=itemgetter(1))[0]
            u[0].append(max(possibilities))
            if old_action != pol[s]:
                policy_stable = False
#         if policy_stable:
#             return V, pol, n_iter
        u.append([])
        # Policy evaluation
        i = 0
        if np.linalg.norm(np.asarray(u[0]) - np.asarray(V)) < epsilon/(2*gamma):
            return V, pol, n_iter
        
        while True:
            if (i == m):
                V = u[i]
                n_iter+=1
                break
            else:
                for s in range(mdp.S):
                    u[i+1].append(sum(mdp.T[s, pol[s], k] * (mdp.R[k] + gamma * u[i][k]) for k in range(mdp.S)))
                    u.append([])

            if np.linalg.norm(np.asarray(u[i+1]) - np.asarray(u[i])) < epsilon/(2*gamma):
                V = u[i]
                n_iter+=1
                break
            else:
                i+=1
        n_iter += 1

In [26]:
V_3, pol_3, n_iter3 = modified_policy_iteration(gw)

In [27]:
# Now observe the obtained value and policy:
pprint.pprint(np.reshape(V_3,(5,5)))
pprint.pprint(np.reshape(pol_3,(5,5)))
print n_iter3

array([[ 7.89224736,  9.89224736,  9.89224736,  9.89224736,  7.89224736],
       [ 6.09224736,  7.89224736,  9.89224736,  7.89224736,  6.09224736],
       [ 4.47224736,  6.09224736,  7.89224736,  7.89224736,  6.09224736],
       [ 4.47224736,  6.09224736,  7.89224736,  9.89224736,  7.89224736],
       [ 6.09224736,  7.89224736,  9.89224736,  9.89224736,  9.89224736]])
array([[1, 1, 2, 3, 3],
       [1, 1, 2, 2, 2],
       [1, 1, 2, 0, 0],
       [0, 0, 0, 0, 0],
       [1, 1, 1, 0, 3]])
6
