In [17]:
import pickle

import gymnasium as gym
import torch

In [18]:
from torchvision.models import resnet18
import torch.nn as nn
import torch.nn.functional as F


class VA_Network(nn.Module):
    """Value/advantage network used by Dueling DQN.

    A single shared hidden layer feeds two linear heads: a scalar state value
    ``V`` and a per-action advantage ``A``. The Q-values are assembled as
    ``Q = V + A - mean(A)``, i.e. the advantages are centred per sample
    before being added to the state value.

    Args:
        hidden_dim: Width of the shared hidden layer.
        action_dim: Number of discrete actions (width of the advantage head).
        state_dim: Length of the input observation vector. Defaults to 12,
            the value that used to be hard-coded, so existing callers and
            previously saved state dicts keep working unchanged.
    """

    def __init__(self, hidden_dim, action_dim, state_dim=12):
        super().__init__()
        # The attribute names below become state-dict keys; keep them stable
        # so saved parameter files remain loadable.
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc_A = nn.Linear(hidden_dim, action_dim)
        self.fc_V = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Map a ``(batch, state_dim)`` input to ``(batch, action_dim)`` Q-values."""
        hidden = F.relu(self.fc1(x))
        advantage = self.fc_A(hidden)
        value = self.fc_V(hidden)
        # keepdim=True keeps the mean as a (batch, 1) column so it broadcasts
        # across the action axis.
        return value + advantage - advantage.mean(dim=1, keepdim=True)

In [19]:
import numpy as np
import torchvision

class Dueling_DQN:
    """Dueling DQN agent: epsilon-greedy acting plus one-step TD learning.

    Two ``VA_Network`` instances are kept: ``q_net`` is trained by Adam, while
    ``target_q_net`` is a periodically refreshed copy used to compute the
    bootstrap targets.

    Args:
        hidden_dim: Hidden width passed to ``VA_Network``.
        action_dim: Number of discrete actions.
        learning_rate: Learning rate of the Adam optimizer on ``q_net``.
        gamma: Discount factor applied to the bootstrapped next-state value.
        epsilon: Probability of choosing a uniformly random action.
        target_update: Number of ``update`` calls between target-network syncs.
        device: Torch device on which the networks and batches live.
    """

    def __init__(self, hidden_dim, action_dim, learning_rate, gamma, epsilon, target_update, device):
        self.hidden_dim = hidden_dim
        self.action_dim = action_dim
        self.gamma = gamma
        self.epsilon = epsilon  # epsilon-greedy threshold
        self.target_update = target_update
        self.count = 0  # number of update() calls performed so far
        self.device = device
        self.q_net = VA_Network(hidden_dim=self.hidden_dim, action_dim=self.action_dim).to(device)
        self.target_q_net = VA_Network(hidden_dim=self.hidden_dim, action_dim=self.action_dim).to(device)
        # Start the target network as an exact copy of the online network so
        # the very first TD targets are not produced by unrelated random weights.
        self.target_q_net.load_state_dict(self.q_net.state_dict())
        self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=learning_rate)

    def _as_tensor(self, data, dtype):
        """Convert array-like ``data`` into a ``dtype`` tensor on ``self.device``."""
        # Collapsing to one ndarray first avoids the slow element-by-element
        # path torch.tensor takes for a list/tuple of ndarrays.
        return torch.as_tensor(np.asarray(data), dtype=dtype).to(self.device)

    def take_action(self, state):
        """Return an action index for ``state`` using an epsilon-greedy policy."""
        if np.random.random() < self.epsilon:
            return np.random.randint(self.action_dim)
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            q_values = self.q_net(self._as_tensor(state, torch.float).unsqueeze(0))
        return q_values.argmax().item()

    def max_q_value(self, state):
        """Return the largest Q-value the online network predicts for ``state``."""
        with torch.no_grad():
            q_values = self.q_net(self._as_tensor(state, torch.float).unsqueeze(0))
        return q_values.max().item()

    def update(self, transition_dict):
        """Run one gradient step on a batch of transitions.

        Args:
            transition_dict: Mapping with the keys ``'states'``, ``'actions'``,
                ``'rewards'``, ``'next_states'`` and ``'dones'``, each holding
                one entry per transition in the batch.
        """
        states = self._as_tensor(transition_dict['states'], torch.float)
        # gather() requires int64 indices, so the dtype is pinned explicitly.
        actions = self._as_tensor(transition_dict['actions'], torch.int64).view(-1, 1)
        rewards = self._as_tensor(transition_dict['rewards'], torch.float).view(-1, 1)
        next_states = self._as_tensor(transition_dict['next_states'], torch.float)
        dones = self._as_tensor(transition_dict['dones'], torch.float).view(-1, 1)

        # Q(s, a) for the actions that were actually taken.
        q_values = self.q_net(states).gather(1, actions)

        # The targets are constants for the optimizer: computing them without
        # autograd stops gradients from piling up on the target network.
        with torch.no_grad():
            max_next_q_values = self.target_q_net(next_states).max(1)[0].view(-1, 1)
            # (1 - dones) removes the bootstrap term for finished transitions.
            q_targets = rewards + self.gamma * max_next_q_values * (1 - dones)

        # mse_loss already averages over the batch.
        dqn_loss = F.mse_loss(q_values, q_targets)
        self.optimizer.zero_grad()
        dqn_loss.backward()
        self.optimizer.step()

        # Refresh the target network every `target_update` calls.
        if self.count % self.target_update == 0:
            self.target_q_net.load_state_dict(self.q_net.state_dict())
        self.count += 1
        

In [29]:
def train_DQN(agent, env, num_episodes, replay_buffer, minimal_size,
              batch_size, store_return_threshold=7.0):
    """Train ``agent`` on ``env`` and collect learning diagnostics.

    The run is divided into 10 progress-bar iterations of
    ``int(num_episodes / 10)`` episodes each. After every environment step a
    batch is sampled and ``agent.update`` is called, provided the replay
    buffer holds more than ``minimal_size`` transitions.

    Args:
        agent: Object exposing ``take_action(state)``, ``max_q_value(state)``
            and ``update(transition_dict)``.
        env: Gymnasium-style environment; ``reset()`` returns ``(obs, info)``
            and ``step()`` returns
            ``(obs, reward, terminated, truncated, info)``.
        num_episodes: Total number of episodes to run.
        replay_buffer: Object exposing ``add``, ``size`` and ``sample``.
        minimal_size: Buffer size that must be exceeded before updates start.
        batch_size: Number of transitions sampled per update.
        store_return_threshold: A transition is stored only once the running
            episode return has reached this value. Defaults to 7.0, the value
            that was previously hard-coded.

    Returns:
        ``(return_list, max_q_value_list)``: the per-episode returns and the
        exponentially smoothed max-Q estimate recorded at every step.
    """
    return_list = []
    max_q_value_list = []
    max_q_value = 0
    announced_training = False  # print the start banner only once
    episodes_per_iteration = int(num_episodes / 10)
    for i in range(10):
        with tqdm(total=episodes_per_iteration,
                  desc='Iteration %d' % i) as pbar:
            for i_episode in range(episodes_per_iteration):
                episode_return = 0
                state, info = env.reset()
                done = False
                while not done:
                    action = agent.take_action(state)
                    # Exponential smoothing of the max-Q estimate.
                    max_q_value = agent.max_q_value(
                        state) * 0.005 + max_q_value * 0.995
                    # Record the smoothed max Q-value for every visited state.
                    max_q_value_list.append(max_q_value)
                    next_state, reward, terminated, truncated, _ = env.step(action)
                    # The episode is over on either signal; ignoring
                    # `truncated` would leave this loop running forever on a
                    # time-limited episode.
                    done = terminated or truncated

                    # Reward shaping. The original note says a negative value
                    # here means the bird hit the top of the screen.
                    # NOTE(review): which observation feature index -3 refers
                    # to is not visible here - confirm against the env docs.
                    if next_state[-3] < 0:
                        reward = -0.5

                    episode_return += reward

                    if episode_return >= store_return_threshold:
                        # Store `terminated`, not `done`: the stored flag has
                        # always been the termination signal.
                        replay_buffer.add(state, action, reward, next_state, terminated)
                    state = next_state

                    if replay_buffer.size() == minimal_size and not announced_training:
                        # Adds are conditional, so the buffer can sit at this
                        # size for many steps; announce only once.
                        print("Now Start Training!")
                        announced_training = True
                    if replay_buffer.size() > minimal_size:
                        b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(
                            batch_size)
                        transition_dict = {
                            'states': b_s,
                            'actions': b_a,
                            'next_states': b_ns,
                            'rewards': b_r,
                            'dones': b_d
                        }
                        agent.update(transition_dict)
                return_list.append(episode_return)
                if (i_episode + 1) % 10 == 0:
                    pbar.set_postfix({
                        'episode':
                        '%d' % (num_episodes / 10 * i + i_episode + 1),
                        'return':
                        '%.3f' % np.mean(return_list[-10:])
                    })
                pbar.update(1)
    return return_list, max_q_value_list

In [None]:
import rl_utils
from tqdm import tqdm
import ale_py
import shimmy
import flappy_bird_gymnasium

# --- Hyperparameters -------------------------------------------------------
hidden_dim = 512
action_dim = 2
learning_rate = 1e-4
gamma = 0.8
epsilon = 0.01
target_update = 50  # target-network update period (in update() calls)

# Pick the best backend that is actually present instead of assuming
# Apple-silicon MPS, which raises on machines without it.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
buffer_size = 1000

# --- Agent, environment and replay buffer ----------------------------------
agent = Dueling_DQN(hidden_dim, action_dim, learning_rate, gamma, epsilon, target_update, device)
env = gym.make("FlappyBird-v0", render_mode="human")
num_episodes = 200
replay_buffer = rl_utils.ReplayBuffer(buffer_size)
minimal_size = 100  # replay-buffer size that must be exceeded before training starts
batch_size = 64

observation, info = env.reset()

# --- Optional warm start from a previous run -------------------------------
pre_trained = True
pre_trained_param_path = "params/param.pt"
replay_buffer_path = "params/replay_buffer.pkl"
if pre_trained:
    # Read the checkpoint once and remap it onto the active device, so
    # parameters saved on another backend still load here.
    pretrained_state = torch.load(pre_trained_param_path, map_location=device)
    agent.q_net.load_state_dict(pretrained_state)
    agent.target_q_net.load_state_dict(pretrained_state)
    # NOTE(security): pickle.load can execute arbitrary code. Only load
    # buffer files that this notebook produced itself.
    # The unpickled buffer replaces the freshly created one above, so
    # `buffer_size` has no effect on it.
    with open(replay_buffer_path, 'rb') as file:
        replay_buffer = pickle.load(file)

return_list, max_q_value_list = train_DQN(agent, env, num_episodes, replay_buffer, minimal_size, batch_size)

Iteration 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [01:55<00:00,  5.78s/it, episode=20, return=3.760]
Iteration 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [01:47<00:00,  5.36s/it, episode=40, return=-7.420]
Iteration 2:  20%|███████████████████████████████▊                                                                                                                               | 4/20 [00:21<01:26,  5.38s/it]

In [27]:
import pickle
from pathlib import Path

import torch

pre_trained_param_path = "params/param.pt"
replay_buffer_path = "params/replay_buffer.pkl"

# Neither torch.save nor open(..., 'wb') creates missing folders, so make
# sure the destination directories exist before writing.
for _output_path in (pre_trained_param_path, replay_buffer_path):
    Path(_output_path).parent.mkdir(parents=True, exist_ok=True)

# Save the neural-network parameters
torch.save(agent.q_net.state_dict(), pre_trained_param_path)

# Save the replay buffer
with open(replay_buffer_path, 'wb') as file:
    pickle.dump(replay_buffer, file)