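Build a simple 3x3 grid world.

States: every cell of the grid is a state. There are 9 states, numbered 0 to 8 row by row, so state s sits in row s // 3 and column s % 3.

Actions: in every state the agent can move in four directions: up, down, left, right.

Transition probabilities: in this simple example the transitions are deterministic. Choosing an action takes the agent to the neighbouring state in that direction with probability 1. If the move would hit a wall (for example, choosing "left" in a leftmost cell), the agent stays where it is.

Rewards:

Moving into state 8 (bottom-right corner) yields a reward of +10.

Moving into state 5 (middle-right) yields a penalty of -10.

Every other move yields -0.1 (a small time cost that encourages reaching the goal quickly).

Discount factor γ: set to 0.9.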
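To see how γ = 0.9 and these rewards combine, here is a small sketch (not part of the original notebook) that computes the discounted return G = r_1 + γ r_2 + γ² r_3 + … for two hand-picked paths from state 0, scoring each move with the reward rule above. The function name `discounted_return` and both paths are illustrative.

```python
GAMMA = 0.9  # discount factor from the text


def discounted_return(rewards, gamma=GAMMA):
    """Return G = r_1 + gamma * r_2 + gamma**2 * r_3 + ... for a finite reward sequence."""
    g = 0.0
    for t, r in enumerate(rewards):
        g += (gamma ** t) * r
    return g


# Hypothetical path 0 -> 3 -> 6 -> 7 -> 8: three ordinary moves, then entering state 8
safe_path = [-0.1, -0.1, -0.1, 10]
# Hypothetical path 0 -> 1 -> 2 -> 5 -> 8: two ordinary moves, entering state 5, then state 8
trap_path = [-0.1, -0.1, -10, 10]

print(round(discounted_return(safe_path), 3))  # 7.019
print(round(discounted_return(trap_path), 3))  # -1.0
```

Both paths take four steps; only the detour through state 5 differs, and its -10 penalty (discounted by γ² = 0.81) drops the return from about 7.02 to about -1.0.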

In [14]:
import numpy as np

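# Environment parameters
num_states = 9
num_actions = 4  # 0: up, 1: down, 2: left, 3: right

# P[s, a, s'] = p(s' | s, a)
# Transition probabilities: P[current state, action, next state]
P = np.zeros((num_states, num_actions, num_states))

# R[s, a] = E[r | s, a]
# Expected reward: R[current state, action]; every move costs -0.1 unless overridden below
R = np.full((num_states, num_actions), -0.1)
# State 7 is in the bottom row, so "down" hits the wall: the agent stays in 7 and still gets +10.
# Repeating this move is worth 10 / (1 - 0.9) = 100, the value V(7) converges to in the outputs below.
R[7, 1] = 10
R[2, 1] = 10   # state 2 -> down -> state 5 (the trap), rewarded +10 here "to make the policy more interesting"
R[4, 1] = -10  # state 4 -> down -> state 7 (not state 5)
R[6, 1] = -10  # state 6 is in the bottom row, so "down" hits the wall and the agent stays in 6

# State transitions (deterministic)
# Hitting a wall leaves the agent where it is
for s in range(num_states):
    row, col = s // 3, s % 3

    # Action 0: up
    next_s = s if row == 0 else s - 3
    P[s, 0, next_s] = 1.0

    # Action 1: down
    next_s = s if row == 2 else s + 3
    P[s, 1, next_s] = 1.0

    # Action 2: left
    next_s = s if col == 0 else s - 1
    P[s, 2, next_s] = 1.0

    # Action 3: right
    next_s = s if col == 2 else s + 1
    P[s, 3, next_s] = 1.0

# Additional special rewards
R[1, 3] = 10   # state 1 -> right -> state 2 (assumes a high reward for the step before the goal)
R[4, 3] = -10  # state 4 -> right -> state 5 (the trap)
R[7, 3] = 10   # state 7 -> right -> state 8 (the goal)

# Discount factor
gamma = 0.9

# Convergence threshold
theta = 1e-6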
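The cell above attaches rewards to specific (state, action) pairs, and several of them (for example R[7, 1], R[4, 1], R[6, 1]) are on moves that do not enter state 8 or state 5. A minimal sketch of a table that follows the reward rule stated at the top, deriving each reward from the state that is entered, could look like this. The helper `build_grid_mdp`, the names `P_alt`/`R_alt`, and two modelling choices are assumptions, not part of the original notebook: state 8 is treated as absorbing with zero reward (the episode ends there), and any transition that ends in state 5, including bumping a wall while already in it, is penalised.

```python
import numpy as np

N_ROWS, N_COLS = 3, 3
GOAL, TRAP = 8, 5
# Action index -> (row change, column change); same order as the notebook: up, down, left, right
MOVES = [(-1, 0), (1, 0), (0, -1), (0, 1)]


def build_grid_mdp(step_cost=-0.1, goal_reward=10.0, trap_reward=-10.0):
    """Build P[s, a, s'] and R[s, a] where the reward depends on the state being entered.

    Assumptions: the goal is absorbing with zero reward, and every transition
    that ends in the trap is penalised.
    """
    n_states, n_actions = N_ROWS * N_COLS, len(MOVES)
    P = np.zeros((n_states, n_actions, n_states))
    R = np.zeros((n_states, n_actions))
    for s in range(n_states):
        if s == GOAL:
            P[s, :, s] = 1.0  # episode over: stay in the goal forever with reward 0
            continue
        row, col = divmod(s, N_COLS)
        for a, (dr, dc) in enumerate(MOVES):
            r, c = row + dr, col + dc
            # Hitting a wall leaves the agent where it is
            next_s = r * N_COLS + c if 0 <= r < N_ROWS and 0 <= c < N_COLS else s
            P[s, a, next_s] = 1.0
            if next_s == GOAL:
                R[s, a] = goal_reward
            elif next_s == TRAP:
                R[s, a] = trap_reward
            else:
                R[s, a] = step_cost
    return P, R


P_alt, R_alt = build_grid_mdp()
```

P_alt and R_alt have the same shapes as P and R, so they can be passed to the solvers in this notebook. The resulting values and policy will differ from the recorded outputs; in particular, state 7 no longer has a +10 self-loop.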

Value Iteration Implementation

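value_iteration has a single main loop. In each sweep it computes, for every state s, the Q-values of all actions and sets V(s) to their maximum (V[s] = np.max(q_values)). Because V is updated in place, states visited later in the same sweep already use the new V[s]. The loop stops when the largest change within a sweep, delta, falls below theta; the optimal policy is then extracted from the final V in a single pass.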
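The loop described above is the in-place (Gauss-Seidel) flavour of value iteration. A synchronous (Jacobi-style) variant computes every Q-value from the previous V at once, which NumPy can express in one line because `P @ V` contracts the last axis of P[s, a, s'] with V and yields an (S, A) array. This is a minimal sketch, not the notebook's implementation; the name `value_iteration_vectorized` and the `max_iters` guard are hypothetical. Both variants converge to the same optimal V when γ < 1.

```python
import numpy as np


def value_iteration_vectorized(P, R, gamma, theta=1e-6, max_iters=10_000):
    """Synchronous value iteration; P has shape (S, A, S), R has shape (S, A)."""
    V = np.zeros(R.shape[0])
    for _ in range(max_iters):
        # Q[s, a] = R[s, a] + gamma * sum_{s'} P[s, a, s'] * V[s']
        Q = R + gamma * (P @ V)
        V_new = Q.max(axis=1)
        converged = np.max(np.abs(V_new - V)) < theta
        V = V_new
        if converged:
            break
    # Greedy policy w.r.t. the final V, as a one-hot array like the notebook's
    Q = R + gamma * (P @ V)
    policy = np.eye(R.shape[1])[Q.argmax(axis=1)]
    return policy, V
```

Because every state is updated from the same V, the number of sweeps and the intermediate values will generally differ from the ones printed by value_iteration.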

In [4]:
def value_iteration(P, R, gamma, theta):
    """
    Value iteration.

    Args:
        P: transition probabilities, P[s, a, s']
        R: expected rewards, R[s, a]
        gamma: discount factor
        theta: convergence threshold on the largest change of V within one sweep

    Returns:
        policy: optimal deterministic policy, one-hot array policy[s, a]
        V: optimal state-value function
    """
    num_states, num_actions = R.shape

    # 1. Initialize
    V = np.zeros(num_states)
    iteration = 1

    print('\n Start Value Iteration')
    while True:
        delta = 0
        print(f'\n--- {iteration} iterations ---')
        for s in range(num_states):
            v_old = V[s]
            # Q-values of all actions in state s under the current V
            q_values = np.zeros(num_actions)
            for a in range(num_actions):
                q_values[a] = R[s, a] + gamma * np.sum(P[s, a] * V)

            # Value update: V_{k+1}(s) = max_a q_k(s, a)
            # (in place, so states later in this sweep already see the new V[s])
            V[s] = np.max(q_values)
            delta = max(delta, abs(v_old - V[s]))

        print(f'  Delta: {delta:.4f}')
        print(f'  Value Function:\n{V.reshape(3, 3).round(2)}')

        if delta < theta:
            print('  Value Function converged.')
            break
        iteration += 1

    # 2. Extract the optimal (greedy) policy from the converged V
    policy = np.zeros([num_states, num_actions])
    for s in range(num_states):
        q_values = np.zeros(num_actions)
        for a in range(num_actions):
            q_values[a] = R[s, a] + gamma * np.sum(P[s, a] * V)

        best_action = np.argmax(q_values)
        policy[s, best_action] = 1.0

    return policy, V

In [5]:
value_iteration(P, R, gamma, theta)


 Start Value Iteration

--- 1 iterations ---
  Delta: 10.0000
  Value Function:
[[-0.1 10.  10. ]
 [-0.1  8.9  8.9]
 [-0.1 10.   8.9]]

--- 2 iterations ---
  Delta: 9.0000
  Value Function:
[[ 8.9  19.   18.01]
 [ 7.91 17.   16.11]
 [ 8.9  19.   17.  ]]

--- 3 iterations ---
  Delta: 8.1000
  Value Function:
[[17.   26.21 24.5 ]
 [15.2  23.49 21.95]
 [17.   27.1  24.29]]

--- 4 iterations ---
  Delta: 7.2900
  Value Function:
[[23.49 32.05 29.75]
 [21.04 28.74 26.68]
 [24.29 34.39 30.85]]

--- 5 iterations ---
  Delta: 6.5610
  Value Function:
[[28.74 36.78 34.01]
 [25.77 33.   30.51]
 [30.85 40.95 36.76]]

--- 6 iterations ---
  Delta: 5.9049
  Value Function:
[[33.   40.61 37.46]
 [29.6  36.45 33.61]
 [36.76 46.86 42.07]]

--- 7 iterations ---
  Delta: 5.3144
  Value Function:
[[36.45 43.71 40.25]
 [32.98 39.24 37.76]
 [42.07 52.17 46.85]]

--- 8 iterations ---
  Delta: 4.7830
  Value Function:
[[39.24 46.23 43.99]
 [37.76 41.5  42.07]
 [46.85 56.95 51.16]]

--- 9 iterations ---
  

(array([[0., 0., 0., 1.],
        [0., 0., 0., 1.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 0., 1.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.]]),
 array([75.91048192, 84.45609192, 82.72899192, 80.80999102, 79.99999102,
        80.80999192, 89.89999102, 99.99999102, 89.89999192]))

Policy Iteration Implementation

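policy_iteration has an outer loop (while not policy_stable) that iterates over policies, and an inner loop (while True ... if delta < theta) that runs policy evaluation until the value function converges. After each evaluation, policy improvement makes the policy greedy with respect to V; the outer loop stops once the improvement step no longer changes the policy.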
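The inner loop iteratively solves the Bellman expectation equation V = R_π + γ P_π V for the current policy. For a state space this small, the same fixed point can also be obtained directly by solving the linear system (I − γ P_π) V = R_π. A minimal sketch, assuming the notebook's array layout (P[s, a, s'], R[s, a], policy[s, a]); `evaluate_policy_exact` is a hypothetical helper, not part of the original code.

```python
import numpy as np


def evaluate_policy_exact(P, R, policy, gamma):
    """Solve V = R_pi + gamma * P_pi V for a (possibly stochastic) policy[s, a]."""
    num_states = R.shape[0]
    # Expected one-step reward and state-to-state transition matrix under the policy
    R_pi = np.sum(policy * R, axis=1)              # shape (S,)
    P_pi = np.einsum('sa,sat->st', policy, P)      # shape (S, S)
    return np.linalg.solve(np.eye(num_states) - gamma * P_pi, R_pi)
```

Applied to the policy returned by policy_iteration, this should agree with the returned V up to a small error controlled by theta.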

In [2]:
def policy_iteration(P, R, gamma, theta):
    """
    Policy iteration.

    Args:
        P: transition probabilities, P[s, a, s']
        R: expected rewards, R[s, a]
        gamma: discount factor
        theta: convergence threshold for policy evaluation

    Returns:
        policy: optimal deterministic policy, one-hot array policy[s, a]
        V: optimal state-value function
    """
    num_states, num_actions = R.shape

    # 1. Initialize
    # Start from a stochastic policy: every action has equal probability
    policy = np.ones([num_states, num_actions]) / num_actions
    V = np.zeros(num_states)

    policy_stable = False
    iteration = 1

    print('\n Start Policy Iteration')
    while not policy_stable:
        print(f'\n--- {iteration} iteration ---')

        # 2. Policy evaluation: sweep until V stops changing for the current policy
        while True:
            delta = 0
            for s in range(num_states):
                v_old = V[s]
                v_new = 0
                for a in range(num_actions):
                    # v_new += π(a|s) * E[R_{t+1} + γ * V(S_{t+1}) | S_t=s, A_t=a]
                    v_new += policy[s, a] * (R[s, a] + gamma * np.sum(P[s, a] * V))
                V[s] = v_new
                delta = max(delta, abs(v_old - V[s]))

            if delta < theta:
                print('  Value Function converged.')
                break

        # 3. Policy improvement: act greedily with respect to V
        policy_stable = True
        old_policy = policy.copy()

        for s in range(num_states):
            q_values = np.zeros(num_actions)
            for a in range(num_actions):
                q_values[a] = R[s, a] + gamma * np.sum(P[s, a] * V)

            best_action = np.argmax(q_values)
            new_policy_s = np.zeros(num_actions)
            new_policy_s[best_action] = 1.0

            # Compare whole action distributions rather than argmax: the argmax of the
            # uniform initial policy is 0, which could hide a real change
            if not np.array_equal(old_policy[s], new_policy_s):
                policy_stable = False

            # Update policy
            policy[s] = new_policy_s

        if policy_stable:
            print('  Policy converged.')
        else:
            print('  Policy not converged, continue to improve policy.')

        iteration += 1

    return policy, V

In [3]:
policy_iteration(P, R, gamma, theta)


 Start Policy Iteration

--- 1 iteration ---
  Value Function converged.
  Policy not converged, continue to improve policy.

--- 2 iteration ---
  Value Function converged.
  Policy not converged, continue to improve policy.

--- 3 iteration ---
  Value Function converged.
  Policy converged.


(array([[0., 0., 0., 1.],
        [0., 0., 0., 1.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 0., 1.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.]]),
 array([75.9104855, 84.4560955, 82.7289955, 80.809995 , 79.999995 ,
        80.8099955, 89.899995 , 99.999995 , 89.8999955]))

Truncated Policy Iteration Implementation

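The key difference in truncated_policy_iteration is the policy-evaluation step: instead of waiting for delta < theta, it simply runs a fixed number of sweeps, j_truncate. Note the initialization v_k^(0) = v_{k-1}: in our code V is created before the main loop and keeps being updated across iterations, so each evaluation naturally starts from the value function left by the previous one. Because evaluation stops early, the returned V is only an approximation of the final policy's value; this is why the values in the output below are lower than those from value iteration and policy iteration, even though all three return the same policy.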
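The truncated evaluation step can be isolated as a standalone function: starting from V_init (playing the role of v_{k-1}), apply j_truncate Bellman-expectation sweeps for a fixed policy. With j_truncate = 1 and a greedy policy this reduces to one synchronous value-iteration update, and as j_truncate grows it approaches full policy evaluation. A minimal sketch; `truncated_evaluation` is a hypothetical name, and unlike the notebook's loop it updates all states synchronously rather than in place.

```python
import numpy as np


def truncated_evaluation(P, R, policy, V_init, gamma, j_truncate):
    """Run j_truncate synchronous sweeps of V <- R_pi + gamma * P_pi V, starting from V_init."""
    R_pi = np.sum(policy * R, axis=1)            # expected reward per state under the policy
    P_pi = np.einsum('sa,sat->st', policy, P)    # state-to-state transitions under the policy
    V = V_init.copy()                            # leave the caller's v_{k-1} untouched
    for _ in range(j_truncate):
        V = R_pi + gamma * P_pi @ V
    return V


# Usage: one state with a +10 self-loop, like state 7 in the notebook's reward table
P1 = np.ones((1, 1, 1))
R1 = np.array([[10.0]])
pi1 = np.ones((1, 1))
print(round(truncated_evaluation(P1, R1, pi1, np.zeros(1), 0.9, 5)[0], 3))  # 40.951, versus 100 for full evaluation
```

Five sweeps from zero recover only 100 * (1 - 0.9^5) ≈ 41 of the true value 100, which illustrates why a truncated evaluation underestimates V when it starts from zero.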

In [18]:
def truncated_policy_iteration(P, R, gamma, j_truncate=5):
    """
    Truncated policy iteration.

    Args:
        P: transition probabilities, P[s, a, s']
        R: expected rewards, R[s, a]
        gamma: discount factor
        j_truncate: number of evaluation sweeps per policy

    Returns:
        policy: final greedy policy, one-hot array policy[s, a]
        V: value estimate after the last truncated evaluation (not fully converged)
    """
    num_states, num_actions = R.shape

    # 1. Initialize
    policy = np.ones([num_states, num_actions]) / num_actions
    V = np.zeros(num_states)

    iteration = 1

    print(f'\n Start Truncated Policy Iteration (j_truncate={j_truncate})')
    while True:
        print(f'\n--- {iteration} iteration ---')

        # 2. Policy evaluation: exactly j_truncate sweeps, starting from the previous V
        for _ in range(j_truncate):
            for s in range(num_states):
                v_new = 0
                for a in range(num_actions):
                    # v_new += π(a|s) * E[R_{t+1} + γ * V(S_{t+1}) | S_t=s, A_t=a]
                    v_new += policy[s, a] * (R[s, a] + gamma * np.sum(P[s, a] * V))
                V[s] = v_new
        # print(f'  Value Function:\n{V.reshape(3, 3).round(2)}')

        # 3. Policy improvement
        new_policy = np.zeros_like(policy)
        q_values_all = np.zeros((num_states, num_actions))

        for s in range(num_states):
            for a in range(num_actions):
                q_values_all[s, a] = R[s, a] + gamma * np.sum(P[s, a] * V)

            best_action = np.argmax(q_values_all[s])
            new_policy[s, best_action] = 1.0

        # Stop once the greedy policy no longer changes
        if np.array_equal(policy, new_policy):
            print('  Policy converged.')
            break

        policy = new_policy
        print('Policy Updated')
        iteration += 1
        print(f'  Policy:\n{policy.reshape(3, 3, num_actions).round(2)}')

        # Safety cap on the number of outer iterations
        if iteration > 100:
            print('  Too many iterations, stopping.')
            break

    return policy, V


In [19]:
truncated_policy_iteration(P, R, gamma, j_truncate=5)


 Start Truncated Policy Iteration (j_truncate=5)

--- 1 iteration ---
Policy Updated
  Policy:
[[[0. 0. 0. 1.]
  [0. 0. 0. 1.]
  [0. 1. 0. 0.]]

 [[1. 0. 0. 0.]
  [1. 0. 0. 0.]
  [1. 0. 0. 0.]]

 [[0. 0. 0. 1.]
  [0. 1. 0. 0.]
  [0. 0. 1. 0.]]]

--- 2 iteration ---
Policy Updated
  Policy:
[[[0. 0. 0. 1.]
  [0. 0. 0. 1.]
  [0. 1. 0. 0.]]

 [[0. 1. 0. 0.]
  [1. 0. 0. 0.]
  [0. 1. 0. 0.]]

 [[0. 0. 0. 1.]
  [0. 1. 0. 0.]
  [0. 0. 1. 0.]]]

--- 3 iteration ---
Policy Updated
  Policy:
[[[0. 0. 0. 1.]
  [0. 0. 0. 1.]
  [0. 1. 0. 0.]]

 [[0. 1. 0. 0.]
  [0. 1. 0. 0.]
  [0. 1. 0. 0.]]

 [[0. 0. 0. 1.]
  [0. 1. 0. 0.]
  [0. 0. 1. 0.]]]

--- 4 iteration ---
  Policy converged.


(array([[0., 0., 0., 1.],
        [0., 0., 0., 1.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 0., 1.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.]]),
 array([58.26219067, 66.80780067, 65.08070067, 61.20077852, 60.39077852,
        63.16170067, 70.29077852, 80.39077852, 72.25170067]))