In [None]:
import gymnasium as gym

`env.reset()` / `env.step()` 返回的观测值（observation）依次为：小车位置、小车速度、杆角度、杆角速度

In [None]:
# Create a CartPole environment that renders to the screen, then reset it with a fixed seed.
env = gym.make(id="CartPole-v1", render_mode="human")
(observation, info) = env.reset(seed=42)

# Display the initial observation.
observation

In [None]:
# Play one episode with uniformly random actions until it terminates or is truncated.
done = False
while not done:
    random_action = env.action_space.sample()
    next_state, reward, terminated, truncated, _ = env.step(random_action)
    done = terminated or truncated

# Show the final observation and the last reward received.
next_state, reward

In [None]:
# Release the environment's resources, including the human-mode renderer created above.
env.close()

In [None]:
import torch
from torch import nn
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from collections import deque, namedtuple
import random

from numpy.typing import NDArray
from typing import NamedTuple, TypedDict

In [None]:
class DeepQNetwork(nn.Module):
    """MLP Q-network mapping a state vector to one Q-value per discrete action.

    Architecture: Linear -> ReLU -> Linear -> ReLU -> Linear, with two hidden
    layers of width ``hidden_dim``. Submodule names (``net.0``, ``net.2``,
    ``net.4``) do not depend on ``hidden_dim``, so ``state_dict`` keys stay stable.

    Args:
        num_states: Size of the input (observation) vector.
        num_actions: Number of discrete actions, i.e. the output width.
        hidden_dim: Width of both hidden layers (defaults to 128).
    """

    def __init__(self, num_states: int, num_actions: int, hidden_dim: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(num_states, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_actions),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return Q-values of shape ``(*x.shape[:-1], num_actions)``."""
        return self.net(x)


In [None]:
# Seed before constructing the networks so their initial weights are reproducible;
# the training cell re-seeds, but only after these weights already exist.
torch.manual_seed(0)

# Online Q-network (optimized) and target network (periodically synced copy).
# Sizes match CartPole-v1: a 4-value observation and 2 discrete actions.
q_net = DeepQNetwork(4, 2)
target_q_net = DeepQNetwork(4, 2)
# Start the target as an exact copy of the online network, as in standard DQN.
target_q_net.load_state_dict(q_net.state_dict())
# The target is never optimized directly, so gradients never need to be tracked for it.
target_q_net.requires_grad_(False)
target_q_net.eval()

In [None]:
class Transition(NamedTuple):
    """One environment step ``(s, a, r, s', done)`` as stored in the replay buffer."""
    # Observation dtype follows whatever the environment returns (any float width).
    state: NDArray[np.floating]
    action: int
    reward: float
    next_state: NDArray[np.floating]
    # When True, the TD target must not bootstrap from `next_state`.
    done: bool


class SampleBatch(TypedDict):
    """A minibatch of transitions stacked along a leading batch axis."""
    state: NDArray[np.floating]       # shape: (batch_size, ...)
    action: NDArray[np.int64]         # shape: (batch_size,)
    reward: NDArray[np.float64]       # shape: (batch_size,)
    next_state: NDArray[np.floating]  # shape: (batch_size, ...)
    done: NDArray[np.bool_]           # shape: (batch_size,)


class ReplayBuffer:
    """Fixed-capacity FIFO experience replay with uniform random sampling.

    Once ``capacity`` transitions are stored, each new one evicts the oldest.
    """

    def __init__(self, capacity: int):
        self.capacity = capacity
        # `maxlen` makes the deque discard its oldest item automatically on append,
        # so no manual eviction is needed.
        self.buffer: deque[Transition] = deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, done) -> None:
        """Append one transition, evicting the oldest one if the buffer is full."""
        self.buffer.append(Transition(state, action, reward, next_state, done))

    def sample(self, batch_size: int) -> SampleBatch:
        """Draw ``batch_size`` transitions uniformly at random, **with replacement**.

        Uses NumPy's global RNG, so ``np.random.seed`` makes sampling reproducible.

        Args:
            batch_size: Number of transitions to draw (may exceed ``len(self)``).

        Returns:
            A ``SampleBatch`` whose arrays share a leading axis of length ``batch_size``.

        Raises:
            ValueError: If the buffer is empty.
        """
        if not self.buffer:
            raise ValueError("cannot sample from an empty ReplayBuffer")
        indices = np.random.choice(len(self.buffer), batch_size)
        batch = [self.buffer[i] for i in indices]

        return SampleBatch(
            state=np.stack([t.state for t in batch], axis=0),
            action=np.array([t.action for t in batch], dtype=np.int64),
            reward=np.array([t.reward for t in batch], dtype=np.float64),
            next_state=np.stack([t.next_state for t in batch], axis=0),
            done=np.array([t.done for t in batch], dtype=np.bool_),
        )

    def __len__(self) -> int:
        return len(self.buffer)


In [None]:
# DQN hyperparameters for CartPole-v1.
num_episodes = 2000       # number of training episodes
update_period = 10        # copy online -> target weights every N episodes
metric_period = 2         # run a greedy evaluation episode every N episodes
buffer_min = (1 << 10)    # transitions required in the buffer before learning starts
buffer_max = (1 << 15)    # replay buffer capacity
max_explore = 10000       # step cap per training episode
gamma = 0.99              # discount factor
epsilon = 0.01            # exploration probability (constant, no decay)
lr = 1e-3                 # Adam learning rate
# Fall back to CPU so the notebook still runs on machines without a CUDA GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 128          # minibatch size per gradient step

In [None]:
# Train the online Q-network with DQN. Data is collected epsilon-greedily into a
# replay buffer. Each episode runs one minibatch TD update, and a target network
# is synced periodically. A greedy evaluation episode runs every `metric_period` episodes.

# Seed every RNG used below (Python `random`, NumPy, PyTorch) for reproducibility.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

q_net.to(device)
target_q_net.to(device)
env = gym.make("CartPole-v1")  # no rendering during training
optimizer = torch.optim.Adam(q_net.parameters(), lr=lr)
replay_buffer = ReplayBuffer(buffer_max)

# Running loss statistics and the latest evaluation return, shown in the progress bar.
metric_sum = 0.
metric_count = 0
avg_loss = float('inf')
return_value = 0

try:
    with tqdm(range(num_episodes)) as pbar:
        for episode in pbar:

            # Copy the online weights into the target network every `update_period` episodes.
            if episode % update_period == 0:
                target_q_net.load_state_dict(q_net.state_dict())

            # --- Collect one episode with the epsilon-greedy policy ---
            # Parenthesized on purpose: `1 << 32 - 1` would parse as `1 << 31`.
            state, info = env.reset(seed=random.randint(0, (1 << 32) - 1))
            for _ in range(max_explore):
                if random.random() < epsilon:
                    action = int(env.action_space.sample())
                else:
                    with torch.no_grad():
                        action = torch.argmax(q_net(
                            torch.as_tensor(state, dtype=torch.float32, device=device)
                        )).item()
                next_state, reward, terminated, truncated, info = env.step(action)
                # Store `terminated`, not `terminated or truncated`. A time-limit
                # truncation is not a true terminal state, so its TD target must
                # still bootstrap from the next state's value.
                replay_buffer.add(state, action, reward, next_state, terminated)
                if terminated or truncated:
                    break
                state = next_state

            # Warm-up: do not learn until the buffer holds `buffer_min` transitions.
            if len(replay_buffer) < buffer_min:
                continue

            # --- One gradient step on a uniformly sampled minibatch ---
            optimizer.zero_grad()
            batch = replay_buffer.sample(batch_size)
            states = torch.as_tensor(batch["state"], dtype=torch.float32, device=device)
            next_states = torch.as_tensor(batch["next_state"], dtype=torch.float32, device=device)
            actions = torch.as_tensor(batch["action"], dtype=torch.int64, device=device)
            dones = torch.as_tensor(batch["done"], dtype=torch.float32, device=device)
            rewards = torch.as_tensor(batch["reward"], dtype=torch.float32, device=device)

            # TD target r + gamma * max_a' Q_target(s', a'); no bootstrap past terminal states.
            with torch.no_grad():
                q_target = rewards + gamma * (1 - dones) * torch.max(target_q_net(next_states), dim=-1).values

            # Q(s, a) of the online network for the actions actually taken.
            q_pred = torch.gather(
                q_net(states), 1, actions.unsqueeze(1)
            ).squeeze(1)

            loss = F.mse_loss(q_pred, q_target)
            loss.backward()
            nn.utils.clip_grad_value_(q_net.parameters(), 1)
            optimizer.step()
            metric_sum += loss.item()
            metric_count += 1

            # --- Periodic greedy evaluation from a fixed start state ---
            if episode % metric_period == 0:
                avg_loss = metric_sum / metric_count
                state, info = env.reset(seed=42)

                # Run on `device`, like data collection, instead of moving the model around.
                return_value = 0  # counts the steps survived
                done = False
                with torch.no_grad():
                    while not done:
                        action = torch.argmax(q_net(
                            torch.as_tensor(state, dtype=torch.float32, device=device)
                        )).item()
                        state, reward, terminated, truncated, info = env.step(action)
                        return_value += 1
                        done = terminated or truncated

                metric_sum = 0.
                metric_count = 0

            pbar.set_postfix({"eps": epsilon, "loss": avg_loss, "return": return_value})
finally:
    # Close the environment even if training is interrupted or raises.
    env.close()

In [None]:
# Sanity check: the online network's output shape on the last training minibatch,
# expected to be (batch_size, num_actions). `states` is left over from the training
# cell, so that cell must have run first.
q_net(states).shape

In [None]:
# Watch the trained greedy policy play one rendered episode.
env = gym.make("CartPole-v1", render_mode="human")
state, info = env.reset(seed=42)

# Act with the online network `q_net`, which is the network evaluated during
# training. `target_q_net` is only a periodically synced (lagged) copy of it.
ret = 0  # number of steps survived
with torch.no_grad():
    for _ in range(1000):  # hard cap on the number of steps
        q_values = q_net(torch.as_tensor(state, dtype=torch.float32, device=device))
        action = int(torch.argmax(q_values).item())
        state, reward, terminated, truncated, info = env.step(action)
        ret += 1

        # Stop at the end of the episode instead of resetting.
        if terminated or truncated:
            break

# Report the result even if the step cap was reached first.
print(f"Return: {ret}")


In [None]:
# Release the demo environment's resources, including its human-mode renderer.
env.close()