In [2]:
# Cliff Walking

import copy

class CliffWalking:
    """Deterministic grid world exposed as a tabular transition model.

    States are numbered row-major (``state = row * ncol + col``).  Every cell
    of the bottom row except the leftmost one is terminal.  ``P[state][action]``
    is a one-element list holding ``(probability, next_state, reward, done)``.

    Actions index the move table in ``createP``: 0 decrements the row,
    1 increments the row, 2 decrements the column, 3 increments the column.
    """

    def __init__(self, ncol, nrow):
        self.ncol = ncol
        self.nrow = nrow
        self.P = self.createP()

    def createP(self):
        """Build and return the transition table for every (state, action) pair."""
        # (column delta, row delta) for each of the four actions.
        moves = ((0, -1), (0, 1), (-1, 0), (1, 0))
        last_row = self.nrow - 1
        last_col = self.ncol - 1
        table = [[[] for _ in range(4)] for _ in range(self.nrow * self.ncol)]

        for row in range(self.nrow):
            for col in range(self.ncol):
                state = row * self.ncol + col

                # Terminal cells: every action is a zero-reward self-loop.
                if row == last_row and col > 0:
                    for action in range(4):
                        table[state][action] = [(1, state, 0, True)]
                    continue

                for action, (d_col, d_row) in enumerate(moves):
                    # Moves that would leave the grid are clamped to the border.
                    new_col = min(last_col, max(0, col + d_col))
                    new_row = min(last_row, max(0, row + d_row))
                    hits_terminal = new_row == last_row and new_col > 0
                    # Only the bottom-right terminal cell keeps the ordinary
                    # step cost; every other terminal cell costs -100.
                    if hits_terminal and new_col != last_col:
                        reward = -100
                    else:
                        reward = -1
                    table[state][action] = [
                        (1, new_row * self.ncol + new_col, reward, hits_terminal)
                    ]
        return table

In [3]:
class PolicyIteration:
    """Tabular policy iteration for an environment with a transition table.

    The environment must expose ``ncol``, ``nrow`` and ``P`` where
    ``P[s][a]`` is an iterable of ``(prob, next_state, reward, done)`` tuples.
    Four actions per state are assumed throughout.

    Attributes:
        v:  current state-value estimates, one entry per state.
        pi: current policy, one list of four action probabilities per state
            (initialised to the uniform policy).
    """

    def __init__(self, env, theta, gamma):
        self.env = env
        self.theta = theta  # convergence threshold for policy evaluation
        self.gamma = gamma  # discount factor
        n_states = env.ncol * env.nrow
        self.v = [0] * n_states
        self.pi = [[0.25, 0.25, 0.25, 0.25] for _ in range(n_states)]

    def _action_value(self, s, a):
        """Return the one-step lookahead value of action ``a`` in state ``s``.

        The discounted value of the successor is dropped for transitions
        flagged ``done``.
        """
        total = 0
        for prob, next_state, reward, done in self.env.P[s][a]:
            total += prob * (reward + self.gamma * self.v[next_state] * (not done))
        return total

    def policy_evaluation(self):
        """Sweep synchronously until the largest value change is below ``theta``.

        Replaces ``self.v`` after every sweep and prints the number of sweeps.
        """
        n_states = self.env.ncol * self.env.nrow
        sweeps = 1
        while True:
            largest_change = 0
            values = []
            for s in range(n_states):
                # Expectation of the action values under the current policy.
                expected = sum(
                    self._action_value(s, a) * self.pi[s][a] for a in range(4)
                )
                values.append(expected)
                largest_change = max(largest_change, abs(expected - self.v[s]))
            self.v = values
            if largest_change < self.theta:
                break
            sweeps += 1
        print("策略评估进行了{}次".format(sweeps))

    def policy_improvement(self):
        """Make the policy greedy w.r.t. ``self.v`` and return it.

        Probability mass is split evenly among all actions tied for the
        maximum action value.
        """
        for s in range(self.env.ncol * self.env.nrow):
            action_values = [self._action_value(s, a) for a in range(4)]
            best = max(action_values)
            n_best = sum(1 for q in action_values if q == best)
            self.pi[s] = [1/n_best if q == best else 0 for q in action_values]
        print("策略改进完成")
        return self.pi

    def policy_iteration(self):
        """Alternate evaluation and improvement until the policy stops changing."""
        stable = False
        while not stable:
            self.policy_evaluation()
            previous = copy.deepcopy(self.pi)
            self.policy_improvement()
            stable = previous == self.pi


In [4]:
def print_agent(agent, action_meaning, disaster=None, end=None):
    """Print an agent's state values and policy as two text grids.

    Args:
        agent: object with ``env.nrow``, ``env.ncol``, a value list ``v`` and a
            policy ``pi`` (one sequence of action probabilities per state),
            all indexed row-major by ``row * ncol + col``.
        action_meaning: one display string per action; it is shown when the
            action has positive probability, otherwise ``'o'`` is shown.
        disaster: state indices rendered as ``****``.  May be a container, a
            single int, or ``None`` for no such states.
        end: state indices rendered as ``EEEE``; same accepted forms as
            ``disaster``.  ``disaster`` takes precedence when a state is in both.
    """
    def as_states(states):
        # None means "no states"; a bare int is shorthand for one state.  The
        # previous default (``end=0``) made ``x in end`` raise TypeError.
        if states is None:
            return ()
        if isinstance(states, int):
            return (states,)
        return states

    disaster_states = as_states(disaster)
    end_states = as_states(end)
    nrow = agent.env.nrow
    ncol = agent.env.ncol

    print("状态价值： ")
    for i in range(nrow):
        for j in range(ncol):
            # Three decimals, right-aligned and truncated to six characters.
            print('%6.6s' % ('%.3f' % agent.v[i * ncol + j]), end=' ')
        print()

    print("策略： ")
    for i in range(nrow):
        for j in range(ncol):
            state = i * ncol + j
            if state in disaster_states:
                print("****", end=' ')
            elif state in end_states:
                print("EEEE", end=' ')
            else:
                probs = agent.pi[state]
                pi_str = ''.join(
                    action_meaning[k] if probs[k] > 0 else 'o'
                    for k in range(len(action_meaning))
                )
                print(pi_str, end=' ')
        print()

# Build a 12-column by 4-row cliff-walking grid and solve it with policy iteration.
env = CliffWalking(12, 4)
action_meaning = ['^', 'v', '<', '>']
theta = 0.001  # convergence threshold for policy evaluation
gamma = 0.9    # discount factor
agent = PolicyIteration(env, theta, gamma)
agent.policy_iteration()

# Derive the cliff and goal indices from the grid size instead of hard-coding
# them: the cliff is the bottom row without its two end cells, and the goal is
# the bottom-right cell (37..46 and 47 for the 12x4 grid).
bottom_row_start = (env.nrow - 1) * env.ncol
cliff_states = list(range(bottom_row_start + 1, bottom_row_start + env.ncol - 1))
goal_states = [bottom_row_start + env.ncol - 1]
print_agent(agent, action_meaning, cliff_states, goal_states)

策略评估进行了60次
策略改进完成
策略评估进行了72次
策略改进完成
策略评估进行了44次
策略改进完成
策略评估进行了12次
策略改进完成
策略评估进行了1次
策略改进完成
状态价值： 
-7.712 -7.458 -7.176 -6.862 -6.513 -6.126 -5.695 -5.217 -4.686 -4.095 -3.439 -2.710 
-7.458 -7.176 -6.862 -6.513 -6.126 -5.695 -5.217 -4.686 -4.095 -3.439 -2.710 -1.900 
-7.176 -6.862 -6.513 -6.126 -5.695 -5.217 -4.686 -4.095 -3.439 -2.710 -1.900 -1.000 
-7.458  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000  0.000 
策略： 
ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovoo 
ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovo> ovoo 
ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ovoo 
^ooo **** **** **** **** **** **** **** **** **** **** EEEE 
