In [None]:
import gym
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam

In [None]:
def _reset_env(env):
    """Reset ``env`` and return only the initial observation.

    Supports both the legacy gym API (``reset() -> obs``) and the
    gym>=0.26 / gymnasium API (``reset() -> (obs, info)``).
    """
    result = env.reset()
    if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
        return result[0]
    return result


def _step_env(env, action):
    """Step ``env`` and return ``(next_state, reward, terminated, truncated)``.

    Supports both the legacy 4-tuple ``step`` and the 5-tuple
    ``(obs, reward, terminated, truncated, info)`` form.
    """
    result = env.step(action)
    if len(result) == 5:
        next_state, reward, terminated, truncated, _ = result
        return next_state, reward, terminated, truncated
    next_state, reward, done, _ = result
    # The legacy API folds time-limit truncation into ``done``; treat it as a
    # true termination, exactly as the original loop did.
    return next_state, reward, done, False


def train(env, actor_critic, optimizer, gamma=0.99, max_episodes=1000, log_interval=10):
    """Train an actor-critic model with one-step TD (advantage) updates.

    Each episode is rolled out inside a single ``tf.GradientTape``. The
    per-step actor and critic losses are summed over the whole episode, and
    one optimizer update is applied at the end of the episode.

    Args:
        env: Gym-style environment with a discrete ``action_space`` (uses
            ``action_space.n``). Both the legacy and gym>=0.26 reset/step
            signatures are accepted.
        actor_critic: Callable model mapping a batched state tensor to
            ``(logits, value)``. Assumes logits of shape (1, num_actions) and
            value of shape (1, 1) — TODO confirm against the model definition.
        optimizer: Keras optimizer whose ``apply_gradients`` updates
            ``actor_critic.trainable_variables``.
        gamma: Discount factor used in the TD target.
        max_episodes: Number of training episodes.
        log_interval: Print the episode reward every ``log_interval`` episodes.
    """
    num_actions = env.action_space.n
    for episode in range(max_episodes):
        state = _reset_env(env)
        episode_reward = 0

        # A fresh tape per episode records every forward pass of the rollout.
        with tf.GradientTape() as tape:
            # Accumulate the loss over all steps. Previously ``total_loss`` was
            # overwritten each step, so only the last transition was trained on.
            total_loss = 0.0
            while True:
                state_tensor = tf.expand_dims(
                    tf.convert_to_tensor(np.asarray(state, dtype=np.float32)), 0)
                logits, value = actor_critic(state_tensor)

                # Renormalize in float64: float32 softmax outputs may not sum to
                # exactly 1, which makes np.random.choice raise a ValueError.
                probs = np.squeeze(tf.nn.softmax(logits).numpy()).astype(np.float64)
                probs /= probs.sum()
                action = np.random.choice(num_actions, p=probs)

                next_state, reward, terminated, truncated = _step_env(env, action)
                episode_reward += reward

                next_state_tensor = tf.expand_dims(
                    tf.convert_to_tensor(np.asarray(next_state, dtype=np.float32)), 0)
                _, next_value = actor_critic(next_state_tensor)
                # Semi-gradient TD: the bootstrapped target is a constant. Only a
                # true termination zeroes the bootstrap term, not truncation.
                td_target = tf.stop_gradient(
                    reward + gamma * next_value * (1.0 - float(terminated)))
                td_error = td_target - value

                # Actor and critic losses. log_softmax is numerically safer than
                # log(softmax(...)). The advantage only weights the policy
                # gradient; it must not backpropagate into the critic.
                log_prob = tf.gather(tf.nn.log_softmax(logits)[0], action)
                actor_loss = -log_prob * tf.stop_gradient(td_error)
                critic_loss = tf.square(td_error)
                total_loss += tf.reduce_sum(actor_loss + critic_loss)

                if terminated or truncated:
                    break
                state = next_state

        # Compute gradients of the whole-episode loss and update parameters.
        grads = tape.gradient(total_loss, actor_critic.trainable_variables)
        optimizer.apply_gradients(zip(grads, actor_critic.trainable_variables))

        if episode % log_interval == 0:
            print("Episode {}: Total Reward = {}".format(episode, episode_reward))

# Set up the optimizer, the CartPole environment and the Actor-Critic model.
# NOTE(review): ActorCritic is not defined in the cells shown here; it is
# presumably defined in an earlier notebook cell — confirm before running.
optimizer = Adam(learning_rate=1e-2)
env = gym.make('CartPole-v1')
num_actions = env.action_space.n
actor_critic = ActorCritic(num_actions)

# Run Actor-Critic training on the environment.
train(env=env, actor_critic=actor_critic, optimizer=optimizer)

Episode 0: Total Reward = 31.0
Episode 10: Total Reward = 37.0
Episode 20: Total Reward = 33.0
Episode 30: Total Reward = 33.0
Episode 40: Total Reward = 19.0
Episode 50: Total Reward = 11.0
Episode 60: Total Reward = 9.0
Episode 70: Total Reward = 14.0
Episode 80: Total Reward = 19.0
Episode 90: Total Reward = 36.0
Episode 100: Total Reward = 18.0
Episode 110: Total Reward = 11.0
Episode 120: Total Reward = 14.0
Episode 130: Total Reward = 13.0
Episode 140: Total Reward = 11.0
Episode 150: Total Reward = 12.0
Episode 160: Total Reward = 10.0
Episode 170: Total Reward = 14.0
Episode 180: Total Reward = 9.0
Episode 190: Total Reward = 11.0
Episode 200: Total Reward = 13.0
Episode 210: Total Reward = 17.0
Episode 220: Total Reward = 11.0
Episode 230: Total Reward = 18.0
Episode 240: Total Reward = 9.0
Episode 250: Total Reward = 15.0
Episode 260: Total Reward = 16.0
Episode 270: Total Reward = 10.0
Episode 280: Total Reward = 11.0
Episode 290: Total Reward = 9.0
Episode 300: Total Reward