In [1]:
import copy

class CliffWalkingEnv:
    """Deterministic cliff-walking grid world described by a transition table.

    States are numbered row by row (``state = y * ncol + x``) with ``(0, 0)``
    in the top-left corner. Every cell of the bottom row except the leftmost
    one is terminal: the rightmost is the goal, the ones in between are the
    cliff.

    Attributes:
        ncol: number of grid columns.
        nrow: number of grid rows.
        P: ``P[state][action]`` is a list of
            ``(prob, next_state, reward, terminal)`` tuples.
    """

    def __init__(self, ncol=12, nrow=4):
        self.ncol = ncol
        self.nrow = nrow
        # transition table, built once from the grid dimensions
        self.P = self.create_P()

    def create_P(self):
        """Build and return the full transition table for the grid.

        Returns:
            A list indexed by state; each entry is a list of four
            single-outcome lists, one per action (0 = up, 1 = down,
            2 = left, 3 = right).
        """
        # (dx, dy) per action; y grows downwards because (0, 0) is top-left
        moves = ((0, -1), (0, 1), (-1, 0), (1, 0))
        last_col = self.ncol - 1
        last_row = self.nrow - 1

        table = []
        for state in range(self.ncol * self.nrow):
            row, col = divmod(state, self.ncol)

            # cliff and goal cells are absorbing: stay in place with reward 0
            if row == last_row and col > 0:
                table.append([[(1, state, 0, True)] for _ in moves])
                continue

            outcomes = []
            for dx, dy in moves:
                # moves that would leave the grid are clamped to its border
                new_col = min(last_col, max(0, col + dx))
                new_row = min(last_row, max(0, row + dy))
                landed_on_terminal = new_row == last_row and new_col > 0
                if landed_on_terminal and new_col != last_col:
                    step_reward = -100  # fell off the cliff
                else:
                    step_reward = -1  # ordinary step (including reaching the goal)
                outcomes.append(
                    [(1, new_row * self.ncol + new_col, step_reward, landed_on_terminal)]
                )
            table.append(outcomes)

        return table

In [2]:
class PolicyIteration:
    """Tabular policy iteration driven by an environment's transition table.

    The environment must expose ``ncol`` and ``nrow`` (their product is the
    number of states) and ``P``, where ``P[state][action]`` is a list of
    ``(prob, next_state, reward, terminal)`` outcomes. The number of actions
    of a state is taken from ``len(env.P[state])``.

    Attributes:
        env: the environment being solved.
        v: current value estimate for every state.
        policy: per-state list of action probabilities.
        theta: evaluation stops once the largest value change in a sweep
            drops below this threshold.
        gamma: discount factor.
    """

    def __init__(self, env, theta, gamma):
        self.env = env
        n_states = self.env.ncol * self.env.nrow
        self.v = [0] * n_states
        # start from the uniform random policy over each state's actions
        self.policy = [
            [1 / len(self.env.P[state])] * len(self.env.P[state])
            for state in range(n_states)
        ]
        self.theta = theta  # threshold for stopping evaluation
        self.gamma = gamma  # discount factor

    def _q_value(self, state, action):
        """Return the one-step lookahead value of ``action`` in ``state`` under ``self.v``."""
        q_sa = 0
        for prob, next_state, reward, terminal in self.env.P[state][action]:
            # (1 - terminal) drops the bootstrap term V(s') when the outcome is terminal
            q_sa += prob * (reward + self.gamma * self.v[next_state] * (1 - terminal))
        return q_sa

    def policy_evaluation(self):
        """Sweep all states until the values of the current policy stop changing.

        Each sweep computes every new value from the previous sweep's values
        and then replaces ``self.v``. Prints how many sweeps were needed.
        """
        n_states = self.env.ncol * self.env.nrow
        cnt = 1  # count of evaluation sweeps
        while True:
            max_diff = 0
            new_v = [0] * n_states
            for state in range(n_states):
                # expectation of the action values under the current policy
                new_v[state] = sum(
                    self.policy[state][action] * self._q_value(state, action)
                    for action in range(len(self.env.P[state]))
                )
                # record the largest change over all states
                max_diff = max(max_diff, abs(new_v[state] - self.v[state]))
            self.v = new_v
            if max_diff < self.theta:  # reached threshold
                break
            cnt += 1
        print(f'Policy Evaluation converged in {cnt} iterations.')

    def policy_improvement(self):
        """Make the policy greedy with respect to ``self.v``.

        Probability is split evenly between all actions that tie for the
        highest action value.

        Returns:
            The updated ``self.policy``.
        """
        for state in range(self.env.ncol * self.env.nrow):
            q_sa_list = [
                self._q_value(state, action)
                for action in range(len(self.env.P[state]))
            ]
            max_q = max(q_sa_list)
            cnt_max_q = q_sa_list.count(max_q)  # number of tied best actions
            self.policy[state] = [1 / cnt_max_q if q == max_q else 0 for q in q_sa_list]
        print('Policy Improvement done.')

        return self.policy

    def policy_iteration(self):
        """Alternate evaluation and improvement until the policy stops changing."""
        while True:
            self.policy_evaluation()
            # snapshot before improving so the two policies can be compared
            old_policy = copy.deepcopy(self.policy)
            new_policy = self.policy_improvement()
            if old_policy == new_policy:
                break
        print('Policy Iteration done.')

In [6]:
def print_agent(agent: "PolicyIteration", action_meaning, disaster=(), end=()):
    """Print the state values and the policy of a solved agent as two grids.

    Args:
        agent: object providing ``env.nrow``, ``env.ncol``, ``v`` (one value
            per state) and ``policy`` (one list of action probabilities per
            state).
        action_meaning: one display symbol per action, in action order.
        disaster: state indices drawn as ``****`` instead of a policy.
        end: state indices drawn as ``EEEE`` instead of a policy.
    """
    nrow = agent.env.nrow
    ncol = agent.env.ncol

    print('Value of States:')
    for y in range(nrow):
        for x in range(ncol):
            print(f"{agent.v[y * ncol + x]:6.3f}", end=' ')
        print()

    print('\nPolicy:')
    for y in range(nrow):
        for x in range(ncol):
            state = y * ncol + x
            if state in disaster:  # cliff
                print('****', end=' ')
            elif state in end:  # goal
                print('EEEE', end=' ')
            else:
                probs = agent.policy[state]
                # show an action's symbol if it has non-zero probability, else 'o'
                policy_str = ''.join(
                    action_meaning[k] if probs[k] > 0 else 'o'
                    for k in range(len(action_meaning))
                )
                print(policy_str, end=' ')
        print()


# Solve the default cliff-walking grid with policy iteration.
env = CliffWalkingEnv()
agent = PolicyIteration(env, theta=0.01, gamma=0.9)
agent.policy_iteration()

# Derive the special cells from the grid size instead of hard-coding indices:
# the bottom row holds the start (leftmost), the goal (rightmost) and the
# cliff (everything in between).
bottom_left = (env.nrow - 1) * env.ncol
goal_state = env.nrow * env.ncol - 1
cliff_states = list(range(bottom_left + 1, goal_state))
print_agent(agent, ['^', 'v', '<', '>'], disaster=cliff_states, end=[goal_state])


Policy Evaluation converged in 45 iterations.
Policy Improvement done.
Policy Evaluation converged in 50 iterations.
Policy Improvement done.
Policy Evaluation converged in 27 iterations.
Policy Improvement done.
Policy Evaluation converged in 12 iterations.
Policy Improvement done.
Policy Evaluation converged in 1 iterations.
Policy Improvement done.
Policy Iteration done.
Value of States:
-7.712 -7.458 -7.176 -6.862 -6.513 -6.126 -5.695 -5.217 -4.686 -4.095 -3.439 -2.710 
-7.458 -7.176 -6.862 -6.513 -6.126 -5.695 -5.217 -4.686 -4.095 -3.439 -2.710 -1.900 
-7.176 -6.862 -6.513 -6.126 -5.695 -5.217 -4.686 -4.095 -3.439 -2.710 -1.900 -1.000 
-7.458  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 

Policy:
ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovoo 
ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovoo 
ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ovoo 
^ooo **** **** **** **** **** **** **** **** **** **** EEEE 
