In [47]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gym
from collections import deque
import random


In [48]:
# Define the Q-network
class QNetwork(nn.Module):
    """Fully connected network that maps a state to one value per action.

    Architecture: Linear(state_dim, 64) -> ReLU -> Linear(64, 64) -> ReLU
    -> Linear(64, action_dim). The layers live in ``self.fc``.

    Args:
        state_dim: Number of input features (size of the last input axis).
        action_dim: Number of output values produced per input row.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        hidden_units = 64
        # Layers are built in forward order so parameter initialisation
        # consumes the RNG in the same sequence as the layer indices.
        layers = [
            nn.Linear(state_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, action_dim),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, state):
        """Return the network output for ``state`` (last axis: ``action_dim``)."""
        q_values = self.fc(state)
        return q_values

In [49]:
# Create the environment
env = gym.make('CartPole-v1')
# Length of the first axis of the observation space (used as the Q-network input width)
state_dim = env.observation_space.shape[0]
# Number of discrete actions (used as the Q-network output width)
action_dim = env.action_space.n
# Show both sizes as a quick sanity check
print(state_dim, action_dim)

4 2


In [50]:
# Create the networks: `online_net` is the one whose parameters are optimised;
# `target_net` has the same architecture and starts as an exact copy of its weights.
online_net = QNetwork(state_dim, action_dim)
target_net = QNetwork(state_dim, action_dim)
# Overwrite the target network's own random initialisation with the online weights
target_net.load_state_dict(online_net.state_dict())

<All keys matched successfully>

In [51]:
# Create the optimizer: Adam with the library's default settings, over the
# online network's parameters only (the target network is never optimised directly).
optimizer = optim.Adam(online_net.parameters())

In [52]:
# Choose the compute device: CUDA when it is available, otherwise the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)
# Move both networks to the chosen device
online_net.to(device)
target_net.to(device)

cuda


QNetwork(
  (fc): Sequential(
    (0): Linear(in_features=4, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=2, bias=True)
  )
)

In [53]:
# Create the experience replay buffer: a bounded deque that starts empty and
# holds at most 10,000 entries.
replay_buffer = deque([], maxlen=10_000)

In [54]:
# Hyperparameters

# -- epsilon-greedy exploration schedule
epsilon = 1.0                # exploration rate (initial value)
epsilon_decay = 0.995        # multiplicative decay of the exploration rate
min_epsilon = 0.01           # lower bound on the exploration rate

# -- learning
gamma = 0.99                 # discount factor
batch_size = 64              # minibatch size
update_target_every = 100    # how often the target network is updated
max_steps = 10_000           # maximum number of steps

In [55]:
# Training loop
#
# Each loop iteration is ONE environment step. The environment is reset once
# before the loop and again only when an episode ends, so that:
#   * episodes actually progress past their first step (resetting inside the
#     loop would make every stored transition start from an initial state), and
#   * training uses the whole `max_steps` budget instead of stopping at the
#     end of the first episode.
state, _ = env.reset()
state = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)

for step in range(max_steps):
    # Epsilon-greedy action selection
    if np.random.rand() < epsilon:
        action = env.action_space.sample()  # explore
    else:
        with torch.no_grad():
            action = torch.argmax(online_net(state)).item()  # exploit

    # Step the environment and store the transition
    next_state, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated  # the episode is over for either reason
    next_state = torch.tensor(next_state, dtype=torch.float32).unsqueeze(0).to(device)
    reward = torch.tensor([reward], dtype=torch.float32).to(device)
    # Store `terminated` (not `done`) as the bootstrap mask: a truncated
    # episode was cut off rather than reaching a terminal state, so the value
    # of its next state must still be included in the TD target.
    replay_buffer.append((state, action, reward, next_state, bool(terminated)))

    # Learn from a uniformly sampled minibatch once enough transitions exist
    if len(replay_buffer) >= batch_size:
        minibatch = random.sample(replay_buffer, batch_size)
        states, actions, rewards, next_states, terminals = zip(*minibatch)
        states = torch.cat(states).to(device)                    # (batch, state_dim)
        actions = torch.tensor(actions, dtype=torch.long).unsqueeze(1).to(device)  # (batch, 1)
        rewards = torch.cat(rewards).to(device)                  # (batch,)
        next_states = torch.cat(next_states).to(device)          # (batch, state_dim)
        terminals = torch.tensor(terminals, dtype=torch.float32).to(device)  # (batch,)

        # Q(s, a) of the actions that were actually taken -> (batch, 1)
        q_values = online_net(states).gather(1, actions)
        with torch.no_grad():
            # TD target: r + gamma * max_a' Q_target(s', a'), with the
            # bootstrap term zeroed for terminal transitions
            max_next_q_values = target_net(next_states).max(1)[0]
            target_q_values = rewards + gamma * (1 - terminals) * max_next_q_values

        # unsqueeze so the target is (batch, 1) like q_values (avoids broadcasting)
        loss = nn.functional.mse_loss(q_values, target_q_values.unsqueeze(1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Periodically copy the online weights into the target network
        if step % update_target_every == 0:
            target_net.load_state_dict(online_net.state_dict())

    # Decay the exploration rate, never going below min_epsilon
    epsilon = max(min_epsilon, epsilon * epsilon_decay)

    # Start a new episode when the current one has ended; otherwise advance
    if done:
        state, _ = env.reset()
        state = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)
    else:
        state = next_state

  if not isinstance(terminated, (bool, np.bool8)):
