In [1]:
import gymnasium as gym
import torch

In [2]:
from torchvision.models import resnet18
import torch.nn as nn
import torch.nn.functional as F


class VA_Network(nn.Module):
    # 用于Dueling DQN的V/A网
    def __init__(self, hidden_dim, action_dim):
        super(VA_Network, self).__init__()
        self.fc1 = torch.nn.Linear(12, hidden_dim)
        self.fc_A = torch.nn.Linear(hidden_dim, action_dim)
        self.fc_V = torch.nn.Linear(hidden_dim, 1)
    
    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        A = self.fc_A(x)
        V = self.fc_V(x)
        Q = V + A - A.mean(1).view(-1, 1)
        return Q

In [3]:
import numpy as np
import torchvision

class Dueling_DQN:
    # Dueling DQN算法
    def __init__(self, hidden_dim, action_dim, learning_rate, gamma, epsilon, target_update, device):
        self.hidden_dim = hidden_dim
        self.action_dim = action_dim
        self.gamma = gamma
        self.epsilon = epsilon  # epsilon-greedy阈值
        self.target_update = target_update
        self.count = 0
        self.device = device
        self.q_net = VA_Network(hidden_dim=self.hidden_dim, action_dim=self.action_dim).to(device)
        self.target_q_net = VA_Network(hidden_dim=self.hidden_dim, action_dim=self.action_dim).to(device)
        self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=learning_rate)
        
    def take_action(self, state):
        if np.random.random() < self.epsilon:
            action = np.random.randint(self.action_dim)
        else:
            action = self.q_net(state).argmax().item()
        return action
    
    def max_q_value(self, state):
        return self.q_net(state).max().item()
    
    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'], dtype=torch.float).view(-1, 1).to(self.device)
        
        
        # ======================================================================
#         print("states_shape:", states.shape)
        q_values = self.q_net(states).gather(1, actions)
#         print("next_states_shape:", next_states.shape)
        max_next_q_values = self.target_q_net(next_states).max(1)[0].view(-1, 1)
        
        q_targets = rewards + self.gamma * max_next_q_values * (1 - dones)
        dqn_loss = torch.mean(F.mse_loss(q_values, q_targets))
        self.optimizer.zero_grad()
        dqn_loss.backward()
        self.optimizer.step()

        if self.count % self.target_update == 0:  # 将决策网络参数同步给目标网络
            self.target_q_net.load_state_dict(self.q_net.state_dict())
        self.count += 1
        
        

In [6]:
def dis_to_con(discrete_action, env, action_dim):  # 离散动作转回连续的函数
    action_lowbound = env.action_space.low[0]  # 连续动作的最小值
    action_upbound = env.action_space.high[0]  # 连续动作的最大值
    return action_lowbound + (discrete_action /
                              (action_dim - 1)) * (action_upbound -
                                                   action_lowbound)

def train_DQN(agent, env, num_episodes, replay_buffer, minimal_size,
              batch_size):
    return_list = []
    max_q_value_list = []
    max_q_value = 0
    for i in range(10):
        with tqdm(total=int(num_episodes / 10),
                  desc='Iteration %d' % i) as pbar:
            for i_episode in range(int(num_episodes / 10)):
                episode_return = 0
                state, info = env.reset()
                done = False
                while not done:
                    action = agent.take_action(state)
                    max_q_value = agent.max_q_value(
                        state) * 0.005 + max_q_value * 0.995  # 平滑处理
                    max_q_value_list.append(max_q_value)  # 保存每个状态的最大Q值
                    next_state, reward, done, _, _ = env.step(action)
                    
                    state_tensor = torch.tensor(state, dtype=torch.float)
                    state_tensor = state_tensor.permute([2, 0, 1])
                    transform = torchvision.transforms.Resize((224, 224))
                    state_tensor = transform(state_tensor)
                    state_tensor = state_tensor / 255
                    
                    next_state_tensor = torch.tensor(next_state, dtype=torch.float)
                    next_state_tensor = next_state_tensor.permute([2, 0, 1])
                    transform = torchvision.transforms.Resize((224, 224))
                    next_state_tensor = transform(next_state_tensor)
                    next_state_tensor = next_state_tensor / 255
                    
                    replay_buffer.add(state_tensor.numpy(), action, reward, next_state_tensor.numpy(), done)
                    
                    state = next_state
                    episode_return += reward
                    if replay_buffer.size() == minimal_size:
                        print("Now Start Training!")
                    if replay_buffer.size() > minimal_size:
                        b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(
                            batch_size)
                        transition_dict = {
                            'states': b_s,
                            'actions': b_a,
                            'next_states': b_ns,
                            'rewards': b_r,
                            'dones': b_d
                        }
                        agent.update(transition_dict)
                return_list.append(episode_return)
                if (i_episode + 1) % 10 == 0:
                    pbar.set_postfix({
                        'episode':
                        '%d' % (num_episodes / 10 * i + i_episode + 1),
                        'return':
                        '%.3f' % np.mean(return_list[-10:])
                    })
                pbar.update(1)
    return return_list, max_q_value_list

In [9]:
import rl_utils
from tqdm import tqdm
import ale_py
import shimmy

hidden_dim = 512
action_dim = 4
learning_rate = 1e-3
gamma = 0.98
epsilon = 0.01
target_update = 50  # 目标网络更新周期
# device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device("cuda")
buffer_size = 10000

agent = Dueling_DQN(hidden_dim, action_dim, learning_rate, gamma, epsilon, target_update, device)
env = gym.make("ALE/Breakout-v5", render_mode="human", obs_type="rgb")
num_episodes = 100
replay_buffer = rl_utils.ReplayBuffer(buffer_size)
minimal_size = 100  # 开始训练的回放池样本数
batch_size = 64

observation, info = env.reset()

pre_trained = True
pre_trained_param_path = "params/param_iter_2.pt"
replay_buffer_path = "params/replay_buffer_iter_2.pkl"
if pre_trained:
    agent.q_net.load_state_dict(torch.load(pre_trained_param_path))
    agent.target_q_net.load_state_dict(torch.load(pre_trained_param_path))
    with open(replay_buffer_path, 'rb') as file:
        replay_buffer = pickle.load(file)

return_list, max_q_value_list = train_DQN(agent, env, num_episodes, replay_buffer, minimal_size, batch_size)



Now Start Training!


Iteration 0: 100%|███████████████████████████████████████████| 10/10 [11:32<00:00, 69.26s/it, episode=10, return=0.900]
Iteration 1:   0%|                                                                              | 0/10 [00:16<?, ?it/s]


KeyboardInterrupt: 

In [10]:
pre_trained_param_path = "params/param_iter_3.pt"
replay_buffer_path = "params/replay_buffer_iter_3.pkl"

# 保存神经网络参数
torch.save(agent.q_net.state_dict(), pre_trained_param_path)

# 保存经验池
import pickle
with open(replay_buffer_path, 'wb') as file:
    pickle.dump(replay_buffer, file)