In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import ale_py

In [10]:
import math
class Dueling_DQN(nn.Module):
    def __init__(self, args):
        super(Dueling_DQN, self).__init__()
        self.fc1 = nn.Linear(args.state_dim, args.hidden_dim)
        self.fc2 = nn.Linear(args.hidden_dim, args.hidden_dim)
        self.V = nn.Linear(args.hidden_dim, 1)
        self.A = nn.Linear(args.hidden_dim, args.action_dim)
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        V = self.V(x)
        A = self.A(x)
        Q = V + A - torch.mean(A, dim=1, keepdim=True)
        return Q 
class DQN_Net(nn.Module):
    def __init__(self, args):
        super(DQN_Net, self).__init__()
        self.fc1 = nn.Linear(args.state_dim, args.hidden_dim)
        self.fc2 = nn.Linear(args.hidden_dim, args.hidden_dim)
        self.fc3 = nn.Linear(args.hidden_dim, args.action_dim)
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        Q = self.fc3(x)
        return Q
# class DQN_Net(nn.Module):
#     def __init__(self, args, input_shape=(4, 84, 84)):
#         super(DQN_Net, self).__init__()
#         # 卷积层
#         self.conv1 = nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4)
#         self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
#         self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        
#         # 计算卷积后的特征大小
#         conv_out_size = self._get_conv_output(input_shape)
        
#         # 全连接层
#         self.fc1 = nn.Linear(conv_out_size, args.hidden_dim)
#         self.fc2 = nn.Linear(args.hidden_dim, args.action_dim)
        
#     def _get_conv_output(self, shape):
#         bs = 1
#         input = torch.rand(bs, *shape)
#         output = self._forward_conv(input)
#         return int(np.prod(output.size()))
        
#     def _forward_conv(self, x):
#         x = F.relu(self.conv1(x))
#         x = F.relu(self.conv2(x))
#         x = F.relu(self.conv3(x))
#         return x
        
#     def forward(self, x):
#         # 检查输入维度，确保是4D张量 [batch, channels, height, width]
#         if len(x.shape) == 3:
#             x = x.unsqueeze(0)  # 添加批次维度
            
#         x = self._forward_conv(x)
#         x = x.view(x.size(0), -1)  # 展平
#         x = F.relu(self.fc1(x))
#         return self.fc2(x)
class Noisy_DQN(nn.Module):
    def __init__(self, input_dim, output_dim,sigma_init=0.5):
        super(Noisy_DQN, self).__init__()
        self.std_init = sigma_init
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.weight_mu = nn.Parameter(torch.Tensor(output_dim, input_dim))
        self.weight_sigma = nn.Parameter(torch.Tensor(output_dim, input_dim))
        self.register_buffer('weight_epsilon', torch.Tensor(output_dim, input_dim))
        self.bias_mu = nn.Parameter(torch.Tensor(output_dim))
        self.bias_sigma = nn.Parameter(torch.Tensor(output_dim))
        self.register_buffer('bias_epsilon', torch.Tensor(output_dim))
        self.is_training = True
        self.reset_parameters()
        self.reset_noisy()
    def forward(self,x):
        if self.is_training:
            #self.reset_noisy()
            weight = self.weight_mu + self.weight_sigma.mul(self.weight_epsilon)
            bias = self.bias_mu + self.bias_sigma.mul(self.bias_epsilon)
        else:
            weight = self.weight_mu
            bias = self.bias_mu
        return F.linear(x, weight, bias)
    def reset_parameters(self):
        std = 1.0 / math.sqrt(self.input_dim)
        self.weight_mu.data.uniform_(-std, std)
        self.weight_sigma.data.fill_(self.std_init/math.sqrt(self.input_dim))
        self.bias_mu.data.uniform_(-std, std)
        self.bias_sigma.data.fill_(self.std_init/math.sqrt(self.output_dim))
    def reset_noisy(self):
        epsilon_in = self._scale_noise(self.input_dim)
        epsilon_out = self._scale_noise(self.output_dim)
        self.weight_epsilon.copy_(epsilon_out.ger(epsilon_in))
        self.bias_epsilon.copy_(epsilon_out)
    def _scale_noise(self, size):
        x = torch.randn(size)
        x = x.sign().mul(x.abs().sqrt())
        return x
class Distribution_DQN(nn.Module):
    def __init__(self, args):
        super(Distribution_DQN, self).__init__()
        self.in_dim = args.state_dim
        self.out_dim = args.action_dim
        self.hidden_dim = args.hidden_dim
        self.num_atoms = args.num_atoms
        self.v_min = args.v_min
        self.v_max = args.v_max
        self.device = args.device
        self.num_actions = args.um_actions
        self.delta_z = (self.v_max - self.v_min) / (self.num_atoms - 1)
        self.fc1 = nn.Linear(self.in_dim, self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.fc3 = nn.Linear(self.hidden_dim, self.num_actions * self.num_atoms)
    def forward(self, x):
        dist = self.dist(x)
        support = torch.linspace(self.v_min, self.v_max, self.num_atoms).to(self.device)
        q_value = torch.sum(dist * support, dim=2)
        return q_value
    def dist(self,x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        x = x.view(-1, self.num_actions, self.num_atoms)
        dist = F.softmax(x, dim=-1)
        dist = dist.clamp(min=1e-3)
        return dist

In [11]:
class NetWork(nn.Module):
    def __init__(self, args):
        super(NetWork,self).__init__()
        self.in_dim = args.state_dim
        self.out_dim = args.action_dim
        self.hidden_dim = args.hidden_dim
        self.num_atoms = args.num_atoms
        self.v_min = args.v_min
        self.v_max = args.v_max
        self.device = args.device
        
        self.fc1 = nn.Linear(self.in_dim, self.hidden_dim)
        #self.fc2 = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.advantage_hidden = Noisy_DQN(self.hidden_dim, self.hidden_dim)
        self.advantage = Noisy_DQN(self.hidden_dim, self.out_dim*self.num_atoms)
        self.value_hidden = Noisy_DQN(self.hidden_dim, self.hidden_dim)
        self.value = Noisy_DQN(self.hidden_dim, 1*self.num_atoms)
    def forward(self,x):
        dist = self.dist(x)
        support = torch.linspace(self.v_min, self.v_max, self.num_atoms).to(self.device)
        #print(dist.shape)
        q_value = torch.sum(dist * support, dim=2)
        return q_value
       
    def dist(self,x) -> torch.Tensor:
        x = F.relu(self.fc1(x))
        #x = F.relu(self.fc2(x))
        advantage = F.relu(self.advantage_hidden(x))
        advantage = self.advantage(advantage)
        value = F.relu(self.value_hidden(x))
        value = self.value(value)
        advantage = advantage.view(-1, self.out_dim, self.num_atoms)
        value = value.view(-1, 1, self.num_atoms)
        q_value = value + advantage - advantage.mean(dim=1, keepdim=True)
        dist = F.softmax(q_value, dim=-1)
        dist = dist.clamp(min=1e-6)
        return dist 
    def reset_noise(self):
        self.advantage_hidden.reset_noisy()
        self.advantage.reset_noisy()
        self.value_hidden.reset_noisy()
        self.value.reset_noisy()
    def train(self, mode=True):
        super(NetWork, self).train(mode)
        self.is_training = mode
    
    # Explicitly set training mode for all Noisy layers
        self.advantage_hidden.is_training = mode
        self.advantage.is_training = mode
        self.value_hidden.is_training = mode
        self.value.is_training = mode
    
        return self
    
    def eval(self):
        return self.train(False)
        
        

In [12]:
class SumTree(object):
    data_point = 0
    def __init__(self,buffer_size):
        self.buffer_size = buffer_size
        self.tree_size = 2*buffer_size-1
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tree = np.zeros(self.tree_size)
        self.data = np.zeros(buffer_size)
    def update(self,data_idx,value):
        tree_idx = data_idx+self.buffer_size-1
        change = value-self.tree[tree_idx]
        self.tree[tree_idx] = value
        while tree_idx != 0:
            tree_idx = (tree_idx-1)//2
            self.tree[tree_idx] += change
    def get_leaf(self,v):
        parent_idx = 0
        while True:
            left_child_idx = 2*parent_idx+1
            right_child_idx = left_child_idx+1
            if left_child_idx>=self.tree_size:
                leaf_idx = parent_idx
                break
            else:
                if v<=self.tree[left_child_idx]:
                    parent_idx = left_child_idx
                else:
                    v -= self.tree[left_child_idx]
                    parent_idx = right_child_idx
        data_idx = leaf_idx-self.buffer_size+1
        return data_idx,self.tree[leaf_idx]
    def sample(self,n,beta):
        batch_index = np.zeros(n, dtype=np.long)
        IS_weight = torch.zeros(n, dtype=torch.float32)
       
        total = self.total
        pri_seg = total/n
        
        tree_tensor = torch.tensor(self.tree[-self.buffer_size:], dtype=torch.float32)
        min_prob = tree_tensor.min().item() / total
        if min_prob == 0:
            min_prob = 0.00001
        prob_list = []
        for i in range(n):
            a,b= i*pri_seg,(i+1)*pri_seg
            
            v = np.random.uniform(a,b)
            data_idx,priority = self.get_leaf(v)
            prob = priority/total
            batch_index[i] = data_idx
            prob_list.append(prob/min_prob)
        IS_weight = torch.tensor(prob_list, dtype=torch.float32)
        IS_weight = torch.pow(IS_weight, -beta)
        return batch_index,IS_weight
    def add(self,data_idx,value):
        tree_idx = data_idx+self.buffer_size-1
        self.data[self.data_point] = value
        self.update(data_idx,value)
        self.data_point += 1
        if self.data_point>=self.buffer_size:
            self.data_point = 0
    @property
    def max_priority(self):
        return np.max(self.tree[-self.buffer_size:])
    @property
    def total(self):
        return self.tree[0]
        

In [13]:
from collections import deque
class Replaybuffer(object):
    def __init__(self,args):
        self.batch_size = args.batch_size
        self.buffer_capacity = args.buffer_capacity
        self.current_size = 0
        self.current_index = 0
        
        self.buffer = {
            'state':np.zeros((self.buffer_capacity,args.state_dim)),
            'action':np.zeros((self.buffer_capacity,1)),
            'reward':np.zeros(self.buffer_capacity),
            'next_state':np.zeros((self.buffer_capacity,args.state_dim)),
            'done':np.zeros(self.buffer_capacity)
        }
        
    def store_transition(self,state,action,reward,next_state,done):
        self.buffer['state'][self.current_index] = state
        self.buffer['action'][self.current_index] = action
        self.buffer['reward'][self.current_index] = reward
        self.buffer['next_state'][self.current_index] = next_state
        self.buffer['done'][self.current_index] = done
        self.current_index  = (self.current_index+1)%self.buffer_capacity
        self.current_size = min(self.current_size+1,self.buffer_capacity)
    def sample(self,total_step):
        
        index = np.random.randint(0,self.current_size,size=self.batch_size)
        batch = {}
        for key in self.buffer.keys():
            if key=='action':
                batch[key] = torch.tensor(self.buffer[key][index],dtype=torch.long)
            else:
                batch[key] = torch.tensor(self.buffer[key][index],dtype=torch.float32)
        return batch,None,None
class N_Step_ReplayBuffer(object):
    def __init__(self,args):
            self.batch_size = args.batch_size
            self.buffer_capacity = args.buffer_capacity
            self.current_size = 0
            self.current_index = 0
            self.n_step = args.n_steps
            self.gamma = args.gamma
            self.n_step_deque = deque(maxlen=args.n_steps)
            self.buffer = {
                'state':np.zeros((self.buffer_capacity,args.state_dim)),
                'action':np.zeros((self.buffer_capacity,1)),
                'reward':np.zeros(self.buffer_capacity),
                'next_state':np.zeros((self.buffer_capacity,args.state_dim)),
                'done':np.zeros(self.buffer_capacity)
            }
    def store_transition(self,state,action,reward,next_state,done):
        self.n_step_deque.append((state,action,reward,next_state,done))
        if len(self.n_step_deque)>=self.n_step:
            state, action, n_steps_reward, next_state, done = self.get_n_steps_transition()
            self.buffer['state'][self.current_index] = state
            self.buffer['action'][self.current_index] = action
            self.buffer['reward'][self.current_index] = n_steps_reward
            self.buffer['next_state'][self.current_index] = next_state
            self.buffer['done'][self.current_index] = done
            self.current_index  = (self.current_index+1)%self.buffer_capacity
            self.current_size = min(self.current_size+1,self.buffer_capacity)
    def get_n_steps_transition(self):
        state,action = self.n_step_deque[0][0],self.n_step_deque[0][1]
        next_state,done = self.n_step_deque[-1][3],self.n_step_deque[-1][4]
        n_steps_reward = 0
        for i in reversed(range(self.n_step)):
            r,_s,done = self.n_step_deque[i][2:]
            n_steps_reward = r + self.gamma*(1-done)*n_steps_reward
            if done:
                next_state,done = _s,done
                break
        return state,action,n_steps_reward,next_state,done
    def sample(self,total_step):
        index = np.random.randint(0,self.current_size,size=self.batch_size)
        batch = {}
        for key in self.buffer.keys():
            if key=='action':
                batch[key] = torch.tensor(self.buffer[key][index],dtype=torch.long)
            else:
                batch[key] = torch.tensor(self.buffer[key][index],dtype=torch.float32)
        return batch,None,None
    
class Prioritized_ReplayBuffer(object):
    def __init__(self,args):
        self.batch_size = args.batch_size
        self.buffer_capacity = args.buffer_capacity
        self.current_size = 0
        self.current_index = 0
        self.alpha = args.alpha
        self.beta = args.beta_init
        self.sumtree = SumTree(self.buffer_capacity)
        self.beta_increment_per_sampling = args.beta_increment_per_sampling
        self.priority_eps = args.priority_eps
        self.max_train_steps = args.max_train_steps
        
        self.tree = SumTree(self.buffer_capacity)
        
        self.buffer = {
            'state':np.zeros((self.buffer_capacity,args.state_dim)),
            'action':np.zeros((self.buffer_capacity,1)),
            'reward':np.zeros(self.buffer_capacity),
            'next_state':np.zeros((self.buffer_capacity,args.state_dim)),
            'done':np.zeros(self.buffer_capacity)
        }
    def store_transition(self,state,action,reward,next_state,done):
        self.buffer['state'][self.current_index] = state
        self.buffer['action'][self.current_index] = action
        self.buffer['reward'][self.current_index] = reward
        self.buffer['next_state'][self.current_index] = next_state
        self.buffer['done'][self.current_index] = done
        max_priority = self.tree.max_priority if self.current_size else 1.0
        self.tree.add(self.current_index,max_priority)
        self.current_index  = (self.current_index+1)%self.buffer_capacity
        self.current_size = min(self.current_size+1,self.buffer_capacity)
    def sample(self,total_step):
        batch_index,IS_weight = self.tree.sample(self.batch_size,self.beta)
        batch = {}
        self.beta = self.beta + (1 - self.beta) * (total_step / self.max_train_steps)
        for key in self.buffer.keys():
            if key=='action':
                batch[key] = torch.tensor(self.buffer[key][batch_index],dtype=torch.long)
            else:
                batch[key] = torch.tensor(self.buffer[key][batch_index],dtype=torch.float32)
        return batch, batch_index, IS_weight 
    def update_batch_priorities(self, batch_index, td_errors):  # 根据传入的td_error，更新batch_index所对应数据的priorities
        priorities = (np.abs(td_errors) + 0.01) ** self.alpha
        for index, priority in zip(batch_index, priorities):
            self.sum_tree.update(index, priority)

        
            



        

In [14]:
class ReplayBuffer(object):
    def __init__(self, args):
        self.max_train_steps = args.max_train_steps
        self.alpha = args.alpha
        self.beta_init = args.beta_init
        self.beta = args.beta_init
        self.gamma = args.gamma
        self.batch_size = args.batch_size
        self.buffer_capacity = args.buffer_capacity
        self.sum_tree = SumTree(self.buffer_capacity)
        self.n_steps = args.n_steps
        self.n_steps_deque = deque(maxlen=self.n_steps)
        self.buffer = {'state': np.zeros((self.buffer_capacity, args.state_dim)),
                       'action': np.zeros((self.buffer_capacity, 1)),
                       'reward': np.zeros(self.buffer_capacity),
                       'next_state': np.zeros((self.buffer_capacity, args.state_dim)),
                       'done': np.zeros(self.buffer_capacity),
                       }
        self.current_size = 0
        self.current_index = 0
        self.v_min = args.v_min
        self.v_max = args.v_max
        self.num_atoms = args.num_atoms
        self.device = args.device
        self.support = torch.linspace(self.v_min, self.v_max, self.num_atoms).to(self.device)
    def store_transition(self, state, action, reward, next_state, done):
        self.n_steps_deque.append((state, action, reward, next_state, done))
        if len(self.n_steps_deque) == self.n_steps:
            state, action, n_steps_reward, next_state, done = self.get_n_steps_transition()
            self.buffer['state'][self.current_index] = state
            self.buffer['action'][self.current_index] = action
            self.buffer['reward'][self.current_index] = n_steps_reward
            self.buffer['next_state'][self.current_index] = next_state
            self.buffer['done'][self.current_index] = done
            max_priority = self.sum_tree.max_priority if self.current_size else 1.0
            self.sum_tree.update(self.current_index, max_priority)
            self.current_index = (self.current_index + 1) % self.buffer_capacity
            self.current_size = min(self.current_size + 1, self.buffer_capacity)
    def get_n_steps_transition(self):
        state, action = self.n_steps_deque[0][0], self.n_steps_deque[0][1]
        next_state, done = self.n_steps_deque[-1][3], self.n_steps_deque[-1][4]
        n_steps_reward = 0
        for i in reversed(range(self.n_steps)):
            r, _s, done = self.n_steps_deque[i][2:]
            n_steps_reward = r + self.gamma * (1 - done) * n_steps_reward
            if done:
                next_state, done = _s, done
                break
        return state, action, n_steps_reward, next_state, done
    def sample(self, total_step):
        batch_index, IS_weight = self.sum_tree.sample(self.batch_size, self.beta)
        batch = {}
        self.beta = self.beta_init + (1 - self.beta_init) * (total_step / self.max_train_steps)
        for key in self.buffer.keys():
            if key == 'action':
                batch[key] = torch.tensor(self.buffer[key][batch_index], dtype=torch.long)
            else:
                batch[key] = torch.tensor(self.buffer[key][batch_index], dtype=torch.float32)
        return batch, batch_index, IS_weight
    def update_batch_priorities(self, batch_index, td_errors): 
        values = (np.abs(td_errors) + 1e-6) ** self.alpha
        
        for index, value in zip(batch_index, values):
            
            self.sum_tree.update(index, value)
           

In [15]:
from copy import deepcopy
from torch.nn.utils import clip_grad_norm_
class DQN(object):
    def __init__(self,args):
        self.args = args
        self.action_dim = args.action_dim
        self.state_dim = args.state_dim
        self.lr = args.lr
        self.dqn_net = NetWork(args).to(args.device)
        self.target_net =deepcopy(self.dqn_net)
        self.optimizer = torch.optim.Adam(self.dqn_net.parameters(),lr=args.lr)
        self.update_counter = 0
        self.target_update_freq = args.target_update_freq
        self.max_train_steps = args.max_train_steps
        self.gamma = args.gamma
        self.beta = args.beta_init
        self.beta_increment_per_sampling = args.beta_increment_per_sampling
        self.priority_eps = args.priority_eps
        self.n_step = args.n_steps
        self.num_atoms = args.num_atoms
        self.v_min = args.v_min
        self.v_max = args.v_max
        self.batch_size = args.batch_size
        self.replay_buffer = ReplayBuffer(args)
        self.support = torch.linspace(self.v_min,self.v_max,self.num_atoms).to(args.device)
    def choose_action(self,state,epsilon):
        with torch.no_grad():
        
            state = torch.unsqueeze(torch.tensor(state,dtype=torch.float),0).to(self.args.device)
       
            q = self.dqn_net(state)
            if np.random.uniform() > epsilon:
                action = q.argmax(dim=-1).item()
            else:
                action = np.random.randint(0, self.action_dim)
        return action
    def learn(self,total_step):
        gamma = self.gamma**self.n_step
        batch, idx, IS_weight = self.replay_buffer.sample(total_step)
        IS_weight = torch.FloatTensor(IS_weight.reshape(-1,1)).to(self.args.device)
        element_loss,td_error  = self.compute_loss(batch,total_step,gamma)
        loss  = torch.mean(element_loss*IS_weight)
        self.optimizer.zero_grad()
        loss.backward()
        clip_grad_norm_(self.dqn_net.parameters(), 10)
        self.optimizer.step()

        loss_for_prior = td_error.detach().cpu().numpy()  
        new_priority = self.priority_eps + loss_for_prior
        self.replay_buffer.update_batch_priorities(idx, new_priority)
        self.dqn_net.reset_noise()
        self.target_net.reset_noise()
        
        self.update_counter += 1
        if self.update_counter%self.target_update_freq==0:
            self.target_net.load_state_dict(self.dqn_net.state_dict())
        self.update_lr(total_step)
    def update_lr(self,total_steps):
        lr_now = 0.9 * self.lr * (1 - total_steps / self.max_train_steps) + 0.1 * self.lr
        for p in self.optimizer.param_groups:
            p['lr'] = lr_now
    def compute_loss(self,batch,total_step,gamma):
       
        state = torch.FloatTensor(batch["state"]).to(self.args.device)
        next_state = torch.FloatTensor(batch["next_state"]).to(self.args.device)
        action = torch.LongTensor(batch["action"]).to(self.args.device)
        reward = torch.FloatTensor(batch["reward"].reshape(-1, 1)).to(self.args.device)
        done = torch.FloatTensor(batch["done"].reshape(-1, 1)).to(self.args.device)
        delta_z = float(self.v_max - self.v_min)/(self.num_atoms-1)
        action = action.squeeze(1)
        with torch.no_grad():
            # Double DQN
            next_action = self.dqn_net(next_state).argmax(1)
            next_dist = self.target_net.dist(next_state)
            next_dist = next_dist[range(self.batch_size), next_action]

            t_z = reward + (1 - done) * gamma * self.support
            t_z = t_z.clamp(min=self.v_min, max=self.v_max)
            b = (t_z - self.v_min) / delta_z
            l = b.floor().long()
            u = b.ceil().long()

            offset = (
                torch.linspace(
                    0, (self.batch_size - 1) * self.num_atoms, self.batch_size
                ).long()
                .unsqueeze(1)
                .expand(self.batch_size, self.num_atoms)
                .to(self.args.device)
            )

            proj_dist = torch.zeros(next_dist.size(), device=self.args.device)
            proj_dist.view(-1).index_add_(
                0, (l + offset).view(-1), (next_dist * (u.float() - b)).view(-1)
            )
            proj_dist.view(-1).index_add_(
                0, (u + offset).view(-1), (next_dist * (b - l.float())).view(-1)
            )

        dist = self.dqn_net.dist(state)
       
        log_p = torch.log(dist[range(self.batch_size), action])
        
        elementwise_loss = -(proj_dist * log_p).sum(1)
       
        Q_s_a = (dist[range(self.batch_size), action] * self.support).sum(dim=1)  

# 目标分布对应的期望 Q(s,a)，即投影后的分布
        Q_target = (proj_dist * self.support).sum(dim=1)  # [batch_size]

# 绝对 TD 误差，用于更新 PER
        td_errors = (Q_target - Q_s_a).abs().detach()     # [batch_size]
        #print(f"td_errors shape: {td_errors.shape}, td_errors dtype: {td_errors.dtype}")
        return elementwise_loss,td_errors
        

In [None]:
import gymnasium as gym
from torch.utils.tensorboard import SummaryWriter
import argparse
class Runner:
    def __init__(self,args,env_name,number,seed):
        self.args = args
        self.env_name = env_name
        self.number = number
        self.seed = seed

        self.env = gym.make(env_name)  # When training the policy, we need to build an environment
        self.env_evaluate = gym.make(env_name)  # When evaluating the policy, we need to rebuild an environment
        #self.env.seed(seed)
        self.env.reset(seed=seed)
        self.env.action_space.seed(seed)
        self.env_evaluate.reset(seed=seed)
        self.env_evaluate.action_space.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed=seed)
        torch.cuda.manual_seed_all(seed=seed)

        self.args.state_dim = self.env.observation_space.shape[0]
        self.args.action_dim = self.env.action_space.n
        self.args.episode_limit = self.env._max_episode_steps  # Maximum number of steps per episode
        self.device = args.device
        self.n_steps = args.n_steps
        print("env={}".format(self.env_name))
        print("state_dim={}".format(self.args.state_dim))
        print("action_dim={}".format(self.args.action_dim))
        print("episode_limit={}".format(self.args.episode_limit))
        print("device={}".format(self.device))
        self.agent = DQN(args)

        self.algorithm = 'Rainbow-DQN'
        self.writer = SummaryWriter(log_dir='runs/DQN/{}_env_{}_number_{}_seed_{}'.format(self.algorithm, env_name, number, seed))

        self.evaluate_num = 0  # Record the number of evaluations
        self.evaluate_rewards = []  # Record the rewards during the evaluating
        self.total_steps = 0  # Record the total steps during the training
        self.epsilon = self.args.epsilon_init
        self.epsilon_min = self.args.epsilon_min
        self.epsilon_decay = (self.args.epsilon_init - self.args.epsilon_min) / self.args.epsilon_decay_steps
        #self.replay_buffer = ReplayBuffer(args)  # Initialize the replay buffer
        

        self.beta = self.args.beta_init
        self.writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(self.args).items()])),
    )
    def run(self, ):
        self.evaluate_policy()
        #total_reward = 0
        while self.total_steps < self.args.max_train_steps:
            state = self.env.reset()
            state = state[0]
            done = False
            episode_steps = 0
            #total_reward = 0
            while not done:
                action = self.agent.choose_action(state, epsilon=self.epsilon)
                next_state, reward, done,  truncated,_= self.env.step(action) 
                done = done or truncated
                episode_steps += 1
                self.total_steps += 1
                #total_reward += reward
                self.agent.replay_buffer.store_transition(state, action, reward, next_state, done)  # Store the transition
                
                
                self.epsilon = self.epsilon - self.epsilon_decay if self.epsilon - self.epsilon_decay > self.epsilon_min else self.epsilon_min

               
                
                state = next_state
                
                if episode_steps%4==0:
                    if self.agent.replay_buffer.current_size >= self.args.batch_size:
                    
                    
                        self.agent.learn(self.total_steps)

                if self.total_steps % self.args.evaluate_freq == 0:
                    
                    self.evaluate_policy()
        # Save reward
        np.save('./data_train/{}_env_{}_number_{}_seed_{}.npy'.format(self.algorithm, self.env_name, self.number, self.seed), np.array(self.evaluate_rewards))
        

    def evaluate_policy(self, ):
        evaluate_reward = 0
        self.agent.dqn_net.eval()
        for _ in range(self.args.evaluate_times):
            state = self.env_evaluate.reset()
            state = state[0]
            done = False
            episode_reward = 0
            while not done:
                action = self.agent.choose_action(state, epsilon=0)
                next_state, reward, done,  truncated ,_= self.env_evaluate.step(action)
                done = done or truncated
                episode_reward += reward
                state = next_state
            evaluate_reward += episode_reward
        self.agent.dqn_net.train()
        evaluate_reward /= self.args.evaluate_times
        self.evaluate_rewards.append(evaluate_reward)
        print("total_steps:{} \t evaluate_reward:{} \t epsilon：{}".format(self.total_steps, evaluate_reward, self.epsilon))
        self.writer.add_scalar('step_rewards_{}'.format(self.env_name), evaluate_reward, global_step=self.total_steps)


if __name__ == '__main__':
    parser = argparse.ArgumentParser("Hyperparameter Setting for DQN")
    parser.add_argument("--device", type=str, default='cuda:0', help="device")
    parser.add_argument("--max_train_steps", type=int, default=int(1e5), help=" Maximum number of training steps")
    parser.add_argument("--evaluate_freq", type=float, default=1e3, help="Evaluate the policy every 'evaluate_freq' steps")
    parser.add_argument("--evaluate_times", type=float, default=3, help="Evaluate times")

    parser.add_argument("--buffer_capacity", type=int, default=int(1e5), help="The maximum replay-buffer capacity ")
    parser.add_argument("--batch_size", type=int, default=32, help="batch size")
    parser.add_argument("--hidden_dim", type=int, default=512, help="The number of neurons in hidden layers of the neural network")
    parser.add_argument("--lr", type=float, default=0.0000625, help="Learning rate of actor")
    parser.add_argument("--gamma", type=float, default=0.99, help="Discount factor")
    parser.add_argument("--epsilon_init", type=float, default=0, help="Initial epsilon")
    parser.add_argument("--epsilon_min", type=float, default=0, help="Minimum epsilon")
    parser.add_argument("--epsilon_decay_steps", type=int, default=int(1e5), help="How many steps before the epsilon decays to the minimum")
    parser.add_argument("--tau", type=float, default=0.005, help="soft update the target network")
    parser.add_argument("--use_soft_update", type=bool, default=False, help="Whether to use soft update")
    parser.add_argument("--target_update_freq", type=int, default=100, help="Update frequency of the target network(hard update)")
    parser.add_argument("--n_steps", type=int, default=3, help="n_steps")
    parser.add_argument("--alpha", type=float, default=0.6, help="PER parameter")
    parser.add_argument("--beta_init", type=float, default=0.4, help="Important sampling parameter in PER")
    parser.add_argument("--use_lr_decay", type=bool, default=True, help="Learning rate Decay")
    parser.add_argument("--grad_clip", type=float, default=10.0, help="Gradient clip")
    parser.add_argument("--num_atoms", type=int, default=51, help="Number of atoms in distributional DQN")
    parser.add_argument("--v_min", type=float, default=-10, help="Minimum value of the support")
    parser.add_argument("--v_max", type=float, default=10, help="Maximum value of the support")
    parser.add_argument("--priority_eps", type=float, default=1e-6, help="Priority eps")
    parser.add_argument("--beta_increment_per_sampling", type=float, default=0.001, help="Increment of beta per sampling")

    parser.add_argument("--use_double", type=bool, default=True, help="Whether to use double Q-learning")
    parser.add_argument("--use_dueling", type=bool, default=True, help="Whether to use dueling network")
    parser.add_argument("--use_noisy", type=bool, default=True, help="Whether to use noisy network")
    parser.add_argument("--use_per", type=bool, default=True, help="Whether to use PER")
    parser.add_argument("--use_n_steps", type=bool, default=True, help="Whether to use n_steps Q-learning")

    args = parser.parse_args(args=[])
    #args = parser.parse_known_args()[0]
   
    env_names = ['CartPole-v1', 'LunarLander-v3','PongNoFrameskip-v4']
    env_index = 1
    seed = 0
    for seed in [0,20,50,70,100]:
        runner = Runner(args=args, env_name=env_names[env_index], number=1, seed=seed)
        runner.run()
    
        

env=LunarLander-v3
state_dim=8
action_dim=4
episode_limit=1000
device=cuda:0
total_steps:0 	 evaluate_reward:-115.02922986528642 	 epsilon：0
total_steps:1000 	 evaluate_reward:-620.197239036679 	 epsilon：0
total_steps:2000 	 evaluate_reward:-260.10190032545705 	 epsilon：0
total_steps:3000 	 evaluate_reward:-478.35231195943305 	 epsilon：0
total_steps:4000 	 evaluate_reward:-386.1971691503624 	 epsilon：0
total_steps:5000 	 evaluate_reward:-316.96642725195755 	 epsilon：0
total_steps:6000 	 evaluate_reward:-256.3774207086031 	 epsilon：0
total_steps:7000 	 evaluate_reward:-70.40187010626921 	 epsilon：0
total_steps:8000 	 evaluate_reward:-102.26345992775676 	 epsilon：0
total_steps:9000 	 evaluate_reward:-176.73868174138155 	 epsilon：0
total_steps:10000 	 evaluate_reward:57.003470914967124 	 epsilon：0
total_steps:11000 	 evaluate_reward:-223.12129674423485 	 epsilon：0
total_steps:12000 	 evaluate_reward:-152.57316638542514 	 epsilon：0
total_steps:13000 	 evaluate_reward:-116.11231307568703 	 

KeyboardInterrupt: 