In [1]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import random
import swanlab
import os
from Qtabularfunctions import*
from Cartpolefamily import*

In [2]:
def train_on_random_cartpole(category="medium", episodes=200, bins=None):
    """Train a tabular Q-learning agent on one generated CartPole variant.

    A single environment is produced by ``CartPoleCategoryGenerator`` for the
    requested category and reused for every episode.

    Args:
        category: Category name forwarded to ``generate_env``.
        episodes: Number of training episodes to run.
        bins: Discretisation bins forwarded to ``TabularQLearningAgent``.
            Defaults to a fresh ``[12, 12, 12, 12]`` list on every call, so
            the default is never shared (or mutated) between calls.

    Returns:
        A ``(agent, rewards)`` tuple: the trained agent and the list of
        per-episode total rewards, in episode order.
    """
    if bins is None:
        bins = [12, 12, 12, 12]

    env = CartPoleCategoryGenerator().generate_env(category)
    agent = TabularQLearningAgent(bins=bins)
    rewards = []

    # assumes the generated env follows the Gymnasium Env API (reset/step/close)
    try:
        for ep in range(episodes):
            state, _ = env.reset()
            total_reward = 0
            done = False

            while not done:
                action = agent.choose_action(state)
                next_state, reward, terminated, truncated, _ = env.step(action)
                # An episode ends on failure (terminated) or on the time limit (truncated).
                done = terminated or truncated

                agent.update(state, action, reward, next_state, done)

                state = next_state
                total_reward += reward

            rewards.append(total_reward)

            # Progress report every 20 episodes.
            if (ep + 1) % 20 == 0:
                print(f"Episode {ep+1}, reward = {total_reward}")
    finally:
        # Release the environment even if training raises.
        env.close()

    return agent, rewards


In [3]:
# Run 300 tabular Q-learning episodes on a generated "hard" CartPole variant.
agent, rewards = train_on_random_cartpole(category="hard", episodes=300)

Episode 20, reward = 14.0
Episode 40, reward = 10.0
Episode 60, reward = 9.0
Episode 80, reward = 10.0
Episode 100, reward = 10.0
Episode 120, reward = 10.0
Episode 140, reward = 11.0
Episode 160, reward = 11.0
Episode 180, reward = 11.0
Episode 200, reward = 10.0
Episode 220, reward = 17.0
Episode 240, reward = 10.0
Episode 260, reward = 9.0
Episode 280, reward = 12.0
Episode 300, reward = 16.0


In [3]:
# Seed every RNG used in this notebook (Python, NumPy, PyTorch).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
# The trailing ';' keeps the returned torch Generator repr out of the cell output.
torch.manual_seed(SEED);

<torch._C.Generator at 0x7accfc878070>

In [3]:
# Training setup: build the environment and size the agent from its spaces.
env = gym.make('CartPole-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = DQNAgent(state_dim, action_dim)


# Initialize the SwanLab experiment logger.
# Values read from `agent` record whatever the agent was constructed with;
# the episode count and epsilon schedule are read back from swanlab.config
# by the training loop.
swanlab.init(
    project="RL-All-In-One",
    experiment_name="DQN-CartPole-v1",
    config={
        "state_dim": state_dim,
        "action_dim": action_dim,
        "batch_size": agent.batch_size,
        "gamma": agent.gamma,
        "epsilon": agent.epsilon,
        "update_target_freq": agent.update_target_freq,
        "replay_buffer_size": agent.replay_buffer.maxlen,
        "learning_rate": agent.optimizer.param_groups[0]['lr'],
        "episode": 600,
        "epsilon_start": 1.0,
        "epsilon_end": 0.01,
        "epsilon_decay": 0.995,
    },
    # English: "Initialise the target network identical to the current network
    # to avoid training fluctuations caused by mismatched networks."
    description="增加了初始化目标网络和当前网络一致，避免网络不一致导致的训练波动"
)

# ========== Training phase ==========

# Make sure the checkpoint directory exists before the first save.
best_model_path = "./output/best_model.pth"
os.makedirs(os.path.dirname(best_model_path), exist_ok=True)

agent.epsilon = swanlab.config["epsilon_start"]

for episode in range(swanlab.config["episode"]):
    state, _ = env.reset()
    total_reward = 0

    while True:
        action = agent.choose_action(state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        # Store `terminated` only (same flag the replay buffer received before);
        # presumably it marks terminal transitions in the TD target, which a
        # time-limit truncation is not.
        agent.store_experience(state, action, reward, next_state, terminated)
        agent.train()

        total_reward += reward
        state = next_state
        # Gymnasium requires reset() once either flag is set; stepping past a
        # truncation is undefined. The reward cap stays as a defensive bound.
        if terminated or truncated or total_reward > 2e4:
            break

    # epsilon is the exploration rate; decay it after every episode, floored at epsilon_end.
    agent.epsilon = max(swanlab.config["epsilon_end"], agent.epsilon * swanlab.config["epsilon_decay"])

    # Evaluate the model every 10 episodes on a separate environment.
    if episode % 10 == 0:
        eval_env = gym.make('CartPole-v1')
        avg_reward = agent.evaluate(eval_env)
        eval_env.close()

        if avg_reward > agent.best_avg_reward:
            agent.best_avg_reward = avg_reward
            # Copy (clone) the parameters of the current best model into best_net.
            agent.best_net.load_state_dict({k: v.clone() for k, v in agent.q_net.state_dict().items()})
            agent.save_model(path=best_model_path)
            print(f"New best model saved with average reward: {avg_reward}")

    print(f"Episode: {episode}, Train Reward: {total_reward}, Best Eval Avg Reward: {agent.best_avg_reward}")

    swanlab.log(
        {
            "train/reward": total_reward,
            "eval/best_avg_reward": agent.best_avg_reward,
            "train/epsilon": agent.epsilon
        },
        step=episode,
    )

# Training is finished; release the training environment.
env.close()

# Test phase (video recording is available via the commented RecordVideo wrapper).
agent.epsilon = 0  # disable exploration
test_env = gym.make('CartPole-v1', render_mode='rgb_array')
# test_env = RecordVideo(test_env, "./dqn_videos", episode_trigger=lambda x: True)  # save every test episode
agent.q_net.load_state_dict(agent.best_net.state_dict())  # use the best model

for episode in range(3):  # run 3 test episodes
    state, _ = test_env.reset()
    total_reward = 0
    steps = 0

    while True:
        action = agent.choose_action(state)
        next_state, reward, terminated, truncated, _ = test_env.step(action)
        total_reward += reward
        state = next_state
        steps += 1

        # Stop on failure or on the env's own time limit (Gymnasium requires a
        # reset after either). The 1500-step cap (~30 s) is kept as an upper
        # bound so recordings cannot run too long.
        if terminated or truncated or steps >= 1500:
            break

    print(f"Test Episode: {episode}, Reward: {total_reward}")

test_env.close()


Model saved to ./output/best_model.pth
New best model saved with average reward: 9.4
Episode: 0, Train Reward: 19.0, Best Eval Avg Reward: 9.4
Episode: 1, Train Reward: 16.0, Best Eval Avg Reward: 9.4
Episode: 2, Train Reward: 15.0, Best Eval Avg Reward: 9.4
Episode: 3, Train Reward: 31.0, Best Eval Avg Reward: 9.4
Episode: 4, Train Reward: 24.0, Best Eval Avg Reward: 9.4
Episode: 5, Train Reward: 21.0, Best Eval Avg Reward: 9.4
Episode: 6, Train Reward: 13.0, Best Eval Avg Reward: 9.4
Episode: 7, Train Reward: 18.0, Best Eval Avg Reward: 9.4
Episode: 8, Train Reward: 10.0, Best Eval Avg Reward: 9.4
Episode: 9, Train Reward: 15.0, Best Eval Avg Reward: 9.4
Model saved to ./output/best_model.pth
New best model saved with average reward: 14.2
Episode: 10, Train Reward: 18.0, Best Eval Avg Reward: 14.2
Episode: 11, Train Reward: 12.0, Best Eval Avg Reward: 14.2
Episode: 12, Train Reward: 33.0, Best Eval Avg Reward: 14.2
Episode: 13, Train Reward: 17.0, Best Eval Avg Reward: 14.2
Episode: 