In [1]:
import gymnasium as gym
from gymnasium.wrappers import RecordVideo
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import random
import swanlab
import os

In [2]:
class QNetwork(nn.Module):
    """Fully connected network that maps a state vector to one Q-value per action.

    The architecture is two hidden layers of 64 units with ReLU activations,
    followed by a linear output layer of width ``action_dim``.
    """

    def __init__(self, state_dim,action_dim):
        """Create the layer stack.

        Args:
            state_dim: Number of input features.
            action_dim: Number of outputs (one per action).
        """
        super().__init__()
        hidden_width = 64
        layers = [
            nn.Linear(state_dim, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, action_dim),
        ]
        # The attribute is named ``fc1`` so that state-dict keys stay
        # "fc1.<index>.weight" / "fc1.<index>.bias".
        self.fc1 = nn.Sequential(*layers)

    def forward(self,x):
        """Return the output of the layer stack applied to ``x``."""
        q_values = self.fc1(x)
        return q_values

In [3]:
class DQNAgent:
    """DQN agent: online Q-network, target network, replay buffer and epsilon-greedy policy.

    Public attribute names (including ``repaly_buffer`` and
    ``updata_target_freq``) are kept exactly as spelled because code outside
    this class reads and writes them.
    """

    def __init__(self,state_dim,action_dim):
        """Create the networks, optimizer, replay buffer and hyper-parameters.

        Args:
            state_dim: Number of features in a state vector.
            action_dim: Number of discrete actions.
        """
        # Number of discrete actions; used to draw uniform random actions.
        self.action_dim = int(action_dim)
        self.q_net = QNetwork(state_dim,action_dim)
        self.target_qnet = QNetwork(state_dim,action_dim)
        self.target_qnet.load_state_dict(self.q_net.state_dict())
        self.best_net = QNetwork(state_dim,action_dim)
        # Start the "best" snapshot from the online weights so that loading it
        # before any snapshot was taken never installs an unrelated random net.
        self.best_net.load_state_dict(self.q_net.state_dict())
        self.optimizer = optim.Adam(params=self.q_net.parameters(),lr=1e-3)
        # Bounded double-ended queue: the oldest transitions drop out once full.
        self.repaly_buffer = deque(maxlen=10000)
        self.batch_size = 64
        self.gamma = 0.99
        self.epsilon = 0.1
        # Number of optimizer steps between target-network synchronizations.
        self.updata_target_freq = 100
        self.step_count = 0
        # Not read or updated by any method of this class.
        self.best_reward = 0
        self.best_avg_reward = 0
        self.eval_episodes = 5

    def choose_action(self,state):
        """Pick an action epsilon-greedily.

        With probability ``self.epsilon`` a uniformly random action in
        ``[0, action_dim)`` is returned; otherwise the argmax of the online
        network's output for ``state``.

        Args:
            state: A single state (array-like of floats).

        Returns:
            The chosen action index as a Python ``int``.
        """
        if np.random.rand() < self.epsilon:
            # Sample over every available action, not just the first two.
            return int(np.random.randint(0, self.action_dim))
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            state_tensor = torch.as_tensor(np.asarray(state), dtype=torch.float32)
            q_values = self.q_net(state_tensor)
        return int(q_values.argmax().item())

    def store_experience(self,state,action,reward,next_state,done):
        """Append one transition tuple to the replay buffer."""
        self.repaly_buffer.append((state,action,reward,next_state,done))

    def train(self):
        """Run one optimization step on a random minibatch from the replay buffer.

        Does nothing until the buffer holds at least ``batch_size``
        transitions. Every ``updata_target_freq`` optimization steps the
        target network is overwritten with the online network's weights.
        """
        if len(self.repaly_buffer) < self.batch_size:
            return

        batch = random.sample(population=self.repaly_buffer,k=self.batch_size)
        state,action,reward,next_state,done = zip(*batch)

        state = torch.as_tensor(np.array(state), dtype=torch.float32)
        action = torch.as_tensor(np.array(action), dtype=torch.int64)
        reward = torch.as_tensor(np.array(reward), dtype=torch.float32)
        next_state = torch.as_tensor(np.array(next_state), dtype=torch.float32)
        done = torch.as_tensor(np.array(done), dtype=torch.float32)

        # Q(s, a) for the actions actually taken. squeeze(1) removes only the
        # gathered column, so a batch of one keeps its batch dimension.
        current_q = self.q_net(state).gather(1,action.unsqueeze(1)).squeeze(1)

        with torch.no_grad():
            next_q = self.target_qnet(next_state).max(1)[0]
            # (1 - done) removes the bootstrap term for transitions flagged done.
            target_q = reward + self.gamma * next_q * (1-done)
        loss = nn.functional.mse_loss(current_q,target_q)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.step_count +=1
        if self.step_count % self.updata_target_freq == 0:
            self.target_qnet.load_state_dict(self.q_net.state_dict())

    def save_model(self,path):
        """Save the online network's state dict to ``path``.

        The directory part of ``path`` is created if it does not exist.
        """
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save(self.q_net.state_dict(),f=path)
        print("model save success")

    def evaluate(self,env):
        """Run ``eval_episodes`` greedy episodes on ``env`` and return the mean reward.

        ``epsilon`` is set to 0 for the duration of the call and restored
        afterwards, even if the environment raises.

        Args:
            env: Environment whose ``reset()`` returns ``(state, info)`` and
                whose ``step(action)`` returns a 5-tuple.

        Returns:
            ``np.mean`` of the per-episode total rewards.
        """
        original_epsilon = self.epsilon
        self.epsilon = 0
        total_rewards = []

        try:
            for _ in range(self.eval_episodes):
                state = env.reset()[0]
                episode_reward = 0
                while True:
                    action = self.choose_action(state)
                    # NOTE(review): the fourth element (truncation flag) is
                    # ignored, so an episode ends only on the third element
                    # or on the reward cap below - confirm this is intended.
                    next_state,reward,done,_truncated,_info = env.step(action)
                    episode_reward +=reward
                    state = next_state
                    if done or episode_reward > 2e4:
                        break
                total_rewards.append(episode_reward)
        finally:
            self.epsilon = original_epsilon
        return np.mean(total_rewards)


In [4]:
# Training environment, plus an agent sized from its observation and action spaces.
env = gym.make('CartPole-v1')
state_dim, action_dim = env.observation_space.shape[0], env.action_space.n
agent = DQNAgent(state_dim=state_dim, action_dim=action_dim)

In [5]:
# Start fully exploratory; epsilon is decayed after every episode below.
agent.epsilon = 1
for episode in range(600):
    state = env.reset()[0]

    total_reward = 0

    while True:
        action = agent.choose_action(state)
        # NOTE(review): the fourth element returned by step (the truncation
        # flag) is ignored, so the episode runs until `done` or until the
        # reward cap below - confirm this is intended.
        next_state, reward, done, _, _ = env.step(action=action)
        agent.store_experience(state,action,reward,next_state,done)
        agent.train()
        total_reward += reward
        state = next_state
        if done or total_reward > 2e4:
            break
    # Multiplicative epsilon decay with a floor of 0.01.
    agent.epsilon = max(0.01,agent.epsilon*0.995)

    # Every 10th episode, evaluate greedily on a fresh environment and keep
    # the weights whenever the average reward beats the best seen so far.
    if episode % 10 == 0:
        eval_env = gym.make('CartPole-v1')
        try:
            avg_reward = agent.evaluate(eval_env)
        finally:
            # Release the evaluation environment even if evaluation fails.
            eval_env.close()
        if avg_reward > agent.best_avg_reward:
            agent.best_avg_reward = avg_reward
            agent.best_net.load_state_dict({k: v.clone() for k, v in agent.q_net.state_dict().items()})
            agent.save_model(path="./output/best_model.pth")
            print(f"New best model saved with average reward: {avg_reward}")

    print(f"Episode: {episode}, Train Reward: {total_reward}, Best Eval Avg Reward: {agent.best_avg_reward}")


model save success
New best model saved with average reward: 9.4
Episode: 0, Train Reward: 50.0, Best Eval Avg Reward: 9.4
Episode: 1, Train Reward: 21.0, Best Eval Avg Reward: 9.4
Episode: 2, Train Reward: 37.0, Best Eval Avg Reward: 9.4
Episode: 3, Train Reward: 28.0, Best Eval Avg Reward: 9.4
Episode: 4, Train Reward: 24.0, Best Eval Avg Reward: 9.4
Episode: 5, Train Reward: 22.0, Best Eval Avg Reward: 9.4
Episode: 6, Train Reward: 13.0, Best Eval Avg Reward: 9.4
Episode: 7, Train Reward: 11.0, Best Eval Avg Reward: 9.4
Episode: 8, Train Reward: 26.0, Best Eval Avg Reward: 9.4
Episode: 9, Train Reward: 14.0, Best Eval Avg Reward: 9.4
Episode: 10, Train Reward: 14.0, Best Eval Avg Reward: 9.4
Episode: 11, Train Reward: 20.0, Best Eval Avg Reward: 9.4
Episode: 12, Train Reward: 21.0, Best Eval Avg Reward: 9.4
Episode: 13, Train Reward: 20.0, Best Eval Avg Reward: 9.4
Episode: 14, Train Reward: 18.0, Best Eval Avg Reward: 9.4
Episode: 15, Train Reward: 16.0, Best Eval Avg Reward: 9.4
E

In [6]:
# Disable exploration for the test rollouts.
agent.epsilon = 0
# Wrap an rgb_array-rendering environment in a video recorder whose trigger
# returns True for every episode index.
test_env = RecordVideo(
    gym.make('CartPole-v1', render_mode='rgb_array'),
    "./dqn_videos",
    episode_trigger=lambda _episode_index: True,
)
# Copy the best snapshot into the online network. Kept as the last expression
# so the cell still displays the load result.
agent.q_net.load_state_dict(state_dict=agent.best_net.state_dict())

<All keys matched successfully>

In [7]:
# Roll out five test episodes, each capped at 1500 steps, and print the total reward.
for episodes in range(5):
    state = test_env.reset()[0]
    total_reward = 0
    step = 0
    done = False
    # Equivalent to "loop, then break once done or step >= 1500": the
    # condition is always true on entry, so at least one step is taken.
    while not done and step < 1500:
        action = agent.choose_action(state)
        next_state, reward, done, _, _ = test_env.step(action)
        total_reward = total_reward + reward
        state = next_state
        step = step + 1
    print(f"Test episode {episodes}, rewards {total_reward}")

test_env.close()

Test episode 0, rewards 1500.0
Test episode 1, rewards 1500.0
Test episode 2, rewards 1500.0
Test episode 3, rewards 1500.0
Test episode 4, rewards 1500.0
