<a href="https://colab.research.google.com/github/Pluviophile-1/MMO_LAB/blob/main/MMO_lb4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [36]:
!pip install gym matplotlib numpy



In [40]:
import gym
import numpy as np

# 创建环境
env = gym.make('Taxi-v3').env

# 初始化策略、价值函数等参数
num_states = env.observation_space.n
num_actions = env.action_space.n
gamma = 0.9  # 折扣因子
theta = 1e-8  # 收敛阈值
policy = np.ones([num_states, num_actions]) / num_actions  # 初始随机策略

# 策略评估函数
def policy_evaluation(policy, env, gamma, theta):
    num_states = env.observation_space.n
    num_actions = env.action_space.n
    V = np.zeros(num_states)  # 初始化价值函数
    while True:
        delta = 0
        for state in range(num_states):
            v = 0
            for action in range(num_actions):
                prob = policy[state][action]
                for next_sr in env.env.P[state][action]:
                    next_prob, next_state, reward, done = next_sr
                    v += prob * next_prob * (reward + gamma * V[next_state])
            delta = max(delta, abs(v - V[state]))
            V[state] = v
        if delta < theta:
            break
    return V

# 策略改进函数
def policy_improvement(policy, V, env, gamma):
    num_states = env.observation_space.n
    num_actions = env.action_space.n
    policy_stable = True
    for state in range(num_states):
        old_action = np.argmax(policy[state])
        action_values = np.zeros(num_actions)
        for action in range(num_actions):
            for next_sr in env.env.P[state][action]:
                prob, next_state, reward, done = next_sr
                action_values[action] += prob * (reward + gamma * V[next_state])
        best_action = np.argmax(action_values)
        policy[state] = np.eye(num_actions)[best_action]
        if old_action != best_action:
            policy_stable = False
    return policy, policy_stable

# 策略迭代算法
def policy_iteration(env, gamma, theta):
    policy = np.ones([env.observation_space.n, env.action_space.n]) / env.action_space.n  # 初始化策略
    while True:
        V = policy_evaluation(policy, env, gamma, theta)  # 策略评估
        policy, policy_stable = policy_improvement(policy, V, env, gamma)  # 策略改进
        if policy_stable:
            break
    return policy, V

# 执行策略迭代算法
final_policy, final_value = policy_iteration(env, gamma, theta)

# 打印策略和价值函数
print("最终策略：")
print(final_policy)
print("\n最终价值函数：")
print(final_value)


最终策略：
[[0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 1. 0.]
 ...
 [0. 1. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0.]]

最终价值函数：
[ 89.47368419  32.82015928  55.26468418  37.57795475   8.43267449
  32.82015928   8.43267448  15.28487583  32.82015929  18.0943065
  55.26468418  21.2158961   12.75638827  18.0943065   12.75638826
  37.57795475 100.52631577  37.57795476  62.51631576  42.86439418
  79.52631577  28.53814335  48.73821576  32.82015928  10.48074944
  37.57795476  10.48074943  18.09430649  28.53814336  15.28487585
  48.73821576  18.09430649  15.28487586  21.21589611  15.28487585
  42.86439418  89.4736842   42.86439419  55.26468419  48.73821576
  42.8643942   12.75638826  24.68432902  15.28487584  24.68432903
  70.57368419  24.68432902  37.57795476  24.68432903  12.75638826
  42.86439419  15.28487584  18.09430651  24.68432902  18.0943065
  48.73821576  48.73821578  79.52631577  48.73821577  55.26468418
  37.57795478  10.48074944  21.21589612  12.75638826  28.538143

In [42]:
# 测试策略效果的示例函数（可选）
def test_policy(policy, env, num_episodes=100):
    total_rewards = []
    for episode in range(num_episodes):
        state = env.reset()
        total_reward = 0
        done = False
        while not done:
            action = np.argmax(policy[state])
            state, reward, done, info = env.step(action)
            total_reward += reward
        total_rewards.append(total_reward)
    average_reward = np.mean(total_rewards)
    return average_reward

# 测试策略
average_reward = test_policy(final_policy, env)
print(f"\n测试策略的平均奖励：{average_reward}")


测试策略的平均奖励：7.84
