In [1]:
import copy

class CliffWalkingEnv:
    """Deterministic cliff-walking grid described by a transition table.

    States are numbered row-major: ``state = row * ncol + col``. Every cell in
    the last row except the one in column 0 is absorbing. Among those cells,
    the one in the last column yields the ordinary step reward (-1) when
    entered, while all the others yield -100.
    """

    def __init__(self, nrow=4, ncol=12):
        """Store the grid dimensions and build the transition table.

        Args:
            nrow: number of grid rows.
            ncol: number of grid columns.
        """
        self.ncol = ncol
        self.nrow = nrow
        # P[state][action] is a list of (probability, next_state, reward, done).
        self.P = self.createP()

    def createP(self):
        """Build and return the transition table ``P``.

        Returns:
            A list with one entry per state; each entry holds four
            single-element lists (one per action) containing a
            ``(probability, next_state, reward, done)`` tuple.
        """
        n_actions = 4
        bottom = self.nrow - 1
        rightmost = self.ncol - 1
        # (dx, dy) per action index: 0 -> row - 1, 1 -> row + 1,
        # 2 -> col - 1, 3 -> col + 1.
        offsets = ((0, -1), (0, 1), (-1, 0), (1, 0))
        table = [[[] for _ in range(n_actions)]
                 for _ in range(self.nrow * self.ncol)]

        for row in range(self.nrow):
            for col in range(self.ncol):
                state = row * self.ncol + col

                # Absorbing cells: every action loops back with zero reward.
                if row == bottom and col > 0:
                    for action in range(n_actions):
                        table[state][action] = [(1, state, 0, True)]
                    continue

                for action, (dx, dy) in enumerate(offsets):
                    # Moves that would leave the grid are clamped to the edge.
                    new_col = min(rightmost, max(0, col + dx))
                    new_row = min(bottom, max(0, row + dy))
                    enters_absorbing = new_row == bottom and new_col > 0
                    # Only the last column of the bottom row avoids the penalty.
                    penalised = enters_absorbing and new_col != rightmost
                    table[state][action] = [(
                        1,
                        new_row * self.ncol + new_col,
                        -100 if penalised else -1,
                        enters_absorbing,
                    )]
        return table

In [2]:
class PolicyIteration:
    """Tabular policy iteration.

    The environment must expose ``nrow`` and ``ncol`` (their product is the
    number of states) and ``P``, where ``P[state][action]`` is a list of
    ``(probability, next_state, reward, done)`` tuples. The number of actions
    is taken from ``len(P[state])`` rather than being fixed.

    Attributes:
        v: current state-value estimates, one entry per state.
        pi: per-state list of action probabilities.
    """

    def __init__(self, env, theta, gamma):
        """Initialise zero values and a uniform random policy.

        Args:
            env: environment exposing ``nrow``, ``ncol`` and ``P``.
            theta: policy evaluation stops once the largest value change in a
                sweep is below this threshold.
            gamma: discount factor applied to the next state's value.
        """
        self.env = env
        self.v = [0] * self.env.nrow * self.env.ncol
        # Uniform over however many actions each state offers (0.25 for four).
        self.pi = [
            [1 / len(self.env.P[s])] * len(self.env.P[s])
            for s in range(self.env.nrow * self.env.ncol)
        ]
        self.theta = theta
        self.gamma = gamma

    def _action_values(self, s):
        """Return ``[Q(s, a) for every action a]`` under the current ``self.v``."""
        return [
            # (1 - done) removes the bootstrap term for terminal transitions.
            sum(prob * (reward + self.gamma * self.v[next_state] * (1 - done))
                for prob, next_state, reward, done in outcomes)
            for outcomes in self.env.P[s]
        ]

    def policy_evaluation(self):
        """Iterate the expected backup under ``self.pi`` until values settle.

        Each sweep computes all new values from the previous sweep's values
        and then replaces ``self.v``; the loop ends when the largest absolute
        change is below ``self.theta``.
        """
        n_states = self.env.nrow * self.env.ncol
        while True:
            new_v = [
                sum(p * q for p, q in zip(self.pi[s], self._action_values(s)))
                for s in range(n_states)
            ]
            max_diff = max(
                (abs(new - old) for new, old in zip(new_v, self.v)),
                default=0,
            )
            self.v = new_v
            if max_diff < self.theta:
                break

    def policy_improvement(self):
        """Make ``self.pi`` greedy with respect to ``self.v``.

        Probability mass is split evenly among all actions whose value equals
        the maximum.

        Returns:
            The updated ``self.pi`` (same list object, modified in place).
        """
        for s in range(self.env.nrow * self.env.ncol):
            q_values = self._action_values(s)
            best = max(q_values)
            n_best = q_values.count(best)
            self.pi[s] = [1 / n_best if q == best else 0 for q in q_values]
        return self.pi

    def policy_iteration(self):
        """Alternate evaluation and improvement until the policy is stable."""
        while True:
            self.policy_evaluation()
            # policy_improvement updates self.pi in place, so snapshot first.
            old_pi = copy.deepcopy(self.pi)
            new_pi = self.policy_improvement()
            if old_pi == new_pi:
                break

In [3]:
class ValueIteration:
    """Tabular value iteration.

    The environment must expose ``nrow`` and ``ncol`` (their product is the
    number of states) and ``P``, where ``P[state][action]`` is a list of
    ``(probability, next_state, reward, done)`` tuples. The number of actions
    is taken from ``len(P[state])`` rather than being fixed.

    Attributes:
        v: current state-value estimates, one entry per state.
        pi: per-state list of action probabilities; entries are ``None``
            until ``get_policy`` has run.
    """

    def __init__(self, env, theta, gamma):
        """Initialise zero values and an empty policy.

        Args:
            env: environment exposing ``nrow``, ``ncol`` and ``P``.
            theta: iteration stops once the largest value change in a sweep
                is below this threshold.
            gamma: discount factor applied to the next state's value.
        """
        self.env = env
        self.v = [0] * self.env.nrow * self.env.ncol
        self.gamma = gamma
        self.theta = theta
        self.pi = [None for _ in range(self.env.nrow * self.env.ncol)]

    def _action_values(self, s):
        """Return ``[Q(s, a) for every action a]`` under the current ``self.v``."""
        return [
            # (1 - done) removes the bootstrap term for terminal transitions.
            sum(prob * (reward + self.gamma * self.v[next_state] * (1 - done))
                for prob, next_state, reward, done in outcomes)
            for outcomes in self.env.P[s]
        ]

    def value_policy_update(self):
        """Run value iteration to convergence, then extract the greedy policy.

        Each sweep sets every state's value to its best action value, computed
        from the previous sweep's values. The loop ends when the largest
        absolute change is below ``self.theta``; ``get_policy`` is then called
        to fill ``self.pi``.
        """
        n_states = self.env.nrow * self.env.ncol
        while True:
            new_v = [max(self._action_values(s)) for s in range(n_states)]
            max_diff = max(
                (abs(new - old) for new, old in zip(new_v, self.v)),
                default=0,
            )
            self.v = new_v
            if max_diff < self.theta:
                break
        self.get_policy()

    def get_policy(self):
        """Fill ``self.pi`` with the policy that is greedy w.r.t. ``self.v``.

        Probability mass is split evenly among all actions whose value equals
        the maximum.
        """
        for s in range(self.env.nrow * self.env.ncol):
            q_values = self._action_values(s)
            best = max(q_values)
            n_best = q_values.count(best)
            self.pi[s] = [1 / n_best if q == best else 0 for q in q_values]

In [4]:
def print_agent(agent, action_meaning, disaster=None, end=None):
    """Print an agent's state values and policy as two text grids.

    Args:
        agent: object exposing ``env.nrow``, ``env.ncol``, ``v`` (one value
            per state) and ``pi`` (one action-probability list per state).
        action_meaning: one display symbol per action index; an action is
            shown by its symbol when its probability is > 0, otherwise ``'o'``.
        disaster: iterable of state indices rendered as ``****``
            (defaults to none).
        end: iterable of state indices rendered as ``EEEE``
            (defaults to none).
    """
    # None defaults avoid sharing a mutable default list between calls; sets
    # make the per-cell membership tests constant-time.
    disaster_states = set(disaster) if disaster is not None else set()
    end_states = set(end) if end is not None else set()
    n_rows, n_cols = agent.env.nrow, agent.env.ncol

    print("State Value:")
    for i in range(n_rows):
        # Each value is formatted to 3 decimals, then padded/truncated to 6 chars.
        cells = ['%6.6s' % ('%.3f' % agent.v[i * n_cols + j])
                 for j in range(n_cols)]
        print(''.join(cell + ' ' for cell in cells))
    print('\n')

    print("Strategy:")
    for i in range(n_rows):
        cells = []
        for j in range(n_cols):
            state = i * n_cols + j
            if state in disaster_states:
                cells.append('****')
            elif state in end_states:
                cells.append('EEEE')
            else:
                probs = agent.pi[state]
                cells.append(''.join(
                    action_meaning[k] if probs[k] > 0 else 'o'
                    for k in range(len(action_meaning))
                ))
        print(''.join(cell + ' ' for cell in cells))
    print('\n')

In [5]:
# Build the default cliff-walking grid (4 rows x 12 columns).
env = CliffWalkingEnv()
# Display symbol per action index, in the action order used by
# CliffWalkingEnv.createP (row - 1, row + 1, col - 1, col + 1).
action_meaning = ['^', 'v', '<', '>']
theta = 0.001  # convergence threshold for the value sweeps
gamma = 0.9    # discount factor
# Derive the special cells from the grid size instead of hard-coding indices:
# the cliff is the last row without its first and last column, and the goal is
# the last cell. For the default 4 x 12 grid this is range(37, 47) and [47].
disaster = list(range((env.nrow - 1) * env.ncol + 1, env.nrow * env.ncol - 1))
end = [env.nrow * env.ncol - 1]
iter1 = PolicyIteration(env, theta, gamma)
iter1.policy_iteration()
print_agent(iter1, action_meaning, disaster, end)
iter2 = ValueIteration(env, theta, gamma)
iter2.value_policy_update()
print_agent(iter2, action_meaning, disaster, end)

State Value:
-7.712 -7.458 -7.176 -6.862 -6.513 -6.126 -5.695 -5.217 -4.686 -4.095 -3.439 -2.710 
-7.458 -7.176 -6.862 -6.513 -6.126 -5.695 -5.217 -4.686 -4.095 -3.439 -2.710 -1.900 
-7.176 -6.862 -6.513 -6.126 -5.695 -5.217 -4.686 -4.095 -3.439 -2.710 -1.900 -1.000 
-7.458  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 


Strategy:
ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovoo 
ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovoo 
ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ovoo 
^ooo **** **** **** **** **** **** **** **** **** **** EEEE 


State Value:
-7.712 -7.458 -7.176 -6.862 -6.513 -6.126 -5.695 -5.217 -4.686 -4.095 -3.439 -2.710 
-7.458 -7.176 -6.862 -6.513 -6.126 -5.695 -5.217 -4.686 -4.095 -3.439 -2.710 -1.900 
-7.176 -6.862 -6.513 -6.126 -5.695 -5.217 -4.686 -4.095 -3.439 -2.710 -1.900 -1.000 
-7.458  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 


Strategy:
ovo> ovo> ovo> ovo> ovo>