In [1]:
import gym
env = gym.make("FrozenLake-v1")  # create the environment
env = env.unwrapped  # unwrap so the transition table P is accessible
print(env.ncol)
print(env.nrow)
print(env.desc)

# Scan every transition in P: a transition with reward 1.0 leads to the goal,
# and any state reached by a terminating transition (done=True) is terminal.
holes = set()
ends = set()
for actions in env.P.values():
    for transitions in actions.values():
        # each transition is (prob, next_state, reward, done)
        for _prob, next_state, reward, done in transitions:
            if reward == 1.0:
                ends.add(next_state)
            # Truthiness instead of `is True`: also accepts bool-like flags
            # (e.g. numpy.bool_) that are not the `True` singleton.
            if done:
                holes.add(next_state)
# Reaching the goal is also a done=True transition, so the goal was collected
# into `holes` above; remove it so only the holes remain.
holes -= ends
print("冰洞的索引:", holes)
print("目标的索引:", ends)

# Inspect the transitions out of state 14 (adjacent to the goal) for each action.
for a in env.P[14]:
    print(env.P[14][a])

4
4
[[b'S' b'F' b'F' b'F']
 [b'F' b'H' b'F' b'H']
 [b'F' b'F' b'F' b'H']
 [b'H' b'F' b'F' b'G']]
冰洞的索引: {5, 7, 11, 12}
目标的索引: {15}
[(0.3333333333333333, 10, 0.0, False), (0.3333333333333333, 13, 0.0, False), (0.3333333333333333, 14, 0.0, False)]
[(0.3333333333333333, 13, 0.0, False), (0.3333333333333333, 14, 0.0, False), (0.3333333333333333, 15, 1.0, True)]
[(0.3333333333333333, 14, 0.0, False), (0.3333333333333333, 15, 1.0, True), (0.3333333333333333, 10, 0.0, False)]
[(0.3333333333333333, 15, 1.0, True), (0.3333333333333333, 10, 0.0, False), (0.3333333333333333, 13, 0.0, False)]


In [45]:
class ValueIteration:
    """Value iteration for a tabular environment with an explicit transition table.

    The environment must expose ``ncol`` and ``nrow`` (the number of states is
    their product) and ``P[s][a]``: a list of ``(prob, next_state, reward, done)``
    transition tuples for state ``s`` and action ``a``.
    """

    def __init__(self, env, theta, gamma):
        """Args:
            env: environment exposing ``ncol``, ``nrow`` and transition table ``P``.
            theta: convergence threshold on the largest per-sweep value change.
            gamma: discount factor.
        """
        self.env = env
        self.v = [0] * env.ncol * env.nrow  # state values, initialised to 0
        self.theta = theta
        self.gamma = gamma
        # Policy: one action-probability list per state, filled in by get_policy().
        self.pi = [0] * env.ncol * env.nrow

    def _q_values(self, s):
        """Return the one-step Q-value of every action in state ``s`` under ``self.v``."""
        qsa_list = []
        # Use the actions actually present in P[s] rather than assuming exactly 4.
        for a in range(len(self.env.P[s])):
            qsa = 0
            for prob, next_state, reward, done in self.env.P[s][a]:
                # Terminating transitions (done=True) do not bootstrap from next_state.
                qsa += prob * (reward + self.gamma * self.v[next_state] * (1 - done))
            qsa_list.append(qsa)
        return qsa_list

    def value_iteration(self):
        """Run synchronous Bellman-optimality sweeps until convergence, then
        derive the greedy policy via ``get_policy``.

        Each sweep computes every new value from the previous sweep's values.
        Iteration stops once the largest absolute change in a sweep is below
        ``theta``. The printed count excludes that final, converging sweep.
        """
        n_states = self.env.ncol * self.env.nrow
        cnt = 0
        while True:
            delta = 0
            new_v = [0] * n_states
            for s in range(n_states):
                new_v[s] = max(self._q_values(s))
                delta = max(delta, abs(new_v[s] - self.v[s]))
            self.v = new_v
            if delta < self.theta:
                break
            cnt += 1
        print("迭代次数:", cnt)
        self.get_policy()

    def get_policy(self):
        """Set ``self.pi`` to the greedy policy with respect to ``self.v`` and print it.

        Probability is split evenly among all actions whose Q-value equals the
        maximum (exact float equality).
        """
        for s in range(self.env.ncol * self.env.nrow):
            qsa_list = self._q_values(s)
            maxq = max(qsa_list)
            cntq = qsa_list.count(maxq)
            # Split probability mass evenly across the maximising actions.
            self.pi[s] = [1 / cntq if q == maxq else 0 for q in qsa_list]
        print("策略:", self.pi)

In [46]:
def print_agent(agent, action_meaning, disaster=[], end=[]):
    """Print an agent's state values and policy as nrow x ncol grids.

    Args:
        agent: object with ``env.nrow``, ``env.ncol``, a flat value list ``v``
            and a flat policy ``pi`` (one action-probability list per state).
        action_meaning: one display symbol per action, in action-index order.
        disaster: state indices drawn as ``****`` (e.g. holes or cliff cells).
        end: state indices drawn as ``EEEE`` (goal states).
    """
    n_rows, n_cols = agent.env.nrow, agent.env.ncol

    print("状态价值：")
    for row in range(n_rows):
        for col in range(n_cols):
            # Three decimals, then capped at 6 characters so the columns line up.
            cell = '%.3f' % agent.v[row * n_cols + col]
            print('%6.6s' % cell, end=' ')
        print()

    print("策略：")
    for row in range(n_rows):
        for col in range(n_cols):
            state = row * n_cols + col
            if state in disaster:
                label = '****'
            elif state in end:
                label = 'EEEE'
            else:
                probs = agent.pi[state]
                # Action symbol where the policy gives non-zero probability, 'o' otherwise.
                label = ''.join(
                    symbol if probs[k] > 0 else 'o'
                    for k, symbol in enumerate(action_meaning)
                )
            print(label, end=' ')
        print()

In [48]:
# Hyperparameters: convergence threshold and discount factor.
theta, gamma = 1e-5, 0.9
# One display symbol per action index, in the order used by env.P.
action_meaning = ['<', 'v', '>', '^']

agent = ValueIteration(env, theta, gamma)
agent.value_iteration()
print_agent(agent, action_meaning, disaster=holes, end=ends)

迭代次数: 60
策略: [[1.0, 0, 0, 0], [0, 0, 0, 1.0], [1.0, 0, 0, 0], [0, 0, 0, 1.0], [1.0, 0, 0, 0], [0.25, 0.25, 0.25, 0.25], [0.5, 0, 0.5, 0], [0.25, 0.25, 0.25, 0.25], [0, 0, 0, 1.0], [0, 1.0, 0, 0], [1.0, 0, 0, 0], [0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25], [0, 0, 1.0, 0], [0, 1.0, 0, 0], [0.25, 0.25, 0.25, 0.25]]
状态价值：
 0.069  0.061  0.074  0.056 
 0.092  0.000  0.112  0.000 
 0.145  0.247  0.300  0.000 
 0.000  0.380  0.639  0.000 
策略：
<ooo ooo^ <ooo ooo^ 
<ooo **** <o>o **** 
ooo^ ovoo <ooo **** 
**** oo>o ovoo EEEE 
