In [None]:
import numpy as np
import gymnasium as gym
import torch 
import collections
import random
import copy

Even with a replay buffer, computing the gradient of the loss function is still problematic: the TD target itself depends on the parameters being updated. If the same network is used both to produce the target and to receive the updates, the training process becomes unstable. Therefore, we use a separate target network to calculate the target value. The target network is a copy of the main network, and its parameters are synchronized with those of the main network every fixed number of steps. The target network, with parameters $w_T$, is used to calculate the TD target. The main network, with parameters $w$, is used to calculate the current value estimate. Hence, the update rule is as follows:

\begin{equation}
w_{t+1} \gets w_t + \alpha_t \left(r_{t+1}+\gamma \max_{a^{\prime} \in \mathcal{A}(s_{t+1})} \hat{q}\left(s_{t+1}, a^{\prime}, w_{T,t}\right)-\hat{q}(s_t, a_t, w_t)\right) \nabla_{w_t} \hat{q}(s_t, a_t, w_t)
\end{equation}
where the experience $(s_t, a_t, r_{t+1}, s_{t+1})$ is sampled uniformly from the experience replay buffer.

In [None]:
class DQN_Agent():
    """Deep Q-learning agent with experience replay and a target network.

    The agent owns two Q functions with identical structure: ``main_Q_func``
    is trained by the optimizer, while ``target_Q_func`` is a periodically
    synchronized copy used to compute the TD target.

    Args:
        Q_func: The main Q network (a ``torch.nn.Module`` mapping an
            observation to one Q value per action).
        action_size: Number of discrete actions.
        optimizer: Optimizer used to update the main Q network.
        replay_buffer: Object supporting ``append(exp)``, ``sample(batch_size)``
            and ``len()``.
        replay_start_size: Learning starts once the buffer holds more than
            this many experiences.
        batch_size: Number of experiences sampled per learning step.
        replay_frequent: A learning step is run every this many experiences.
        target_sync_frequent: The target network is synchronized with the
            main network every this many experiences.
        epsilon: Initial exploration probability.
        mini_epsilon: Lower bound of the exploration probability.
        explore_decay_rate: Amount subtracted from epsilon on every call to
            ``get_behavior_action``.
        gamma: Discount factor.
        device: Device on which observation tensors are created.
    """

    def __init__(self,
                 Q_func,
                 action_size,
                 optimizer,
                 replay_buffer,
                 replay_start_size,
                 batch_size,
                 replay_frequent,
                 target_sync_frequent, # The frequency of synchronizing the parameters of the two Q networks
                 epsilon = 0.1, # Initial epsilon
                 mini_epsilon = 0.01, # Minimum epsilon
                 explore_decay_rate = 0.0001, # Decay rate of epsilon
                 gamma = 0.9,
                 device='cpu'):

        self.device = device
        self.action_size = action_size

        # Number of experiences received so far; drives the replay and sync schedules.
        self.exp_counter = 0

        self.replay_buffer = replay_buffer
        self.replay_start_size = replay_start_size
        self.batch_size = batch_size
        self.replay_frequent = replay_frequent

        self.target_update_frequent = target_sync_frequent

        # Two Q functions (main_Q and target_Q) are used to stabilize the training process.
        # Since they share the same network structure, copy.deepcopy is used to
        # initialize target_Q from main_Q.
        self.main_Q_func = Q_func
        self.target_Q_func = copy.deepcopy(Q_func)

        self.criteria = torch.nn.MSELoss()
        self.optimizer = optimizer

        self.epsilon = epsilon
        self.mini_epsilon = mini_epsilon
        self.gamma = gamma
        self.explore_decay_rate = explore_decay_rate

    def get_target_action(self,obs):
        """Return the greedy action (as an ``int``) for a single observation.

        NOTE(review): the greedy action is read from ``target_Q_func`` rather
        than ``main_Q_func``, so the policy lags behind the trained network by
        up to ``target_update_frequent`` experiences -- confirm this is intended.
        """
        obs = torch.tensor(obs,dtype=torch.float32,device=self.device)
        # Pure inference: no autograd graph is needed to pick an action.
        with torch.no_grad():
            Q_list = self.target_Q_func(obs)
        return int(torch.argmax(Q_list).item())

    def get_behavior_action(self,obs):
        """Return an epsilon-greedy action (as an ``int``) for a single observation.

        A simple linear epsilon decay balances exploration and exploitation:
        every call lowers epsilon by ``explore_decay_rate`` until it reaches
        ``mini_epsilon``.
        """
        self.epsilon = max(self.mini_epsilon,self.epsilon-self.explore_decay_rate)

        if np.random.uniform(0,1) < self.epsilon:
            # Cast so that both branches return a plain Python int.
            action = int(np.random.choice(self.action_size))
        else:
            action = self.get_target_action(obs)

        return action

    def sync_target_Q_func(self):
        """Copy the state of main_Q into target_Q.

        ``load_state_dict`` transfers buffers as well as parameters, so the
        two networks stay identical even if the Q network contains layers
        with buffers.
        """
        self.target_Q_func.load_state_dict(self.main_Q_func.state_dict())

    def batch_Q_approximation(self,batch_obs,batch_action,batch_reward,batch_next_obs,batch_terminated):
        """Run one gradient step of the main Q network on a batch of experiences.

        ``batch_action`` is used as the index of ``torch.gather`` along
        dimension 1, and ``batch_terminated`` masks the bootstrap term of
        terminal transitions.
        """
        batch_current_Q = torch.gather(self.main_Q_func(batch_obs),1,batch_action).squeeze(1)

        # The TD target is a constant with respect to the optimization: do not
        # build a graph through the target network, otherwise backward() would
        # accumulate gradients on its parameters that nothing ever clears.
        with torch.no_grad():
            batch_next_Q_max = self.target_Q_func(batch_next_obs).max(1)[0]
            batch_TD_target = batch_reward + (1-batch_terminated) * self.gamma * batch_next_Q_max

        loss = self.criteria(batch_current_Q,batch_TD_target)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def Q_approximation(self,obs,action,reward,next_obs,terminated):
        """Store one experience and run the scheduled learning / sync steps."""
        self.exp_counter += 1
        self.replay_buffer.append((obs,action,reward,next_obs,terminated))

        buffer_size = len(self.replay_buffer)
        # Also require a full batch, so sampling cannot fail when
        # replay_start_size is smaller than batch_size.
        buffer_ready = buffer_size > self.replay_start_size and buffer_size >= self.batch_size
        if buffer_ready and self.exp_counter%self.replay_frequent == 0:
            self.batch_Q_approximation(*self.replay_buffer.sample(self.batch_size))

        # Synchronize the parameters of the two Q networks every target_update_frequent steps
        if self.exp_counter%self.target_update_frequent == 0:
            self.sync_target_Q_func()

In [None]:
class Q_Network(torch.nn.Module):
    """Fully connected Q network: observation in, one Q value per action out."""

    def __init__(self,obs_size,action_size):
        super().__init__()
        self.network = self.mlp_network(obs_size,action_size)

    def mlp_network(self,obs_size,action_size):
        """Build the MLP: two hidden layers of 64 ReLU units and a linear output layer."""
        hidden_width = 64
        layers = [
            torch.nn.Linear(obs_size,hidden_width),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_width,hidden_width),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_width,action_size),
        ]
        return torch.nn.Sequential(*layers)

    def forward(self,x):
        """Return the Q values computed by the MLP for input ``x``."""
        q_values = self.network(x)
        return q_values

In [None]:
class ReplayBuffer():
    """Fixed-capacity experience store backed by a ``collections.deque``.

    Once ``capacity`` experiences are held, appending a new one drops the
    oldest. Sampled batches are returned as tensors on ``device``.
    """

    def __init__(self,capacity,device="cpu"):
        self.buffer = collections.deque(maxlen=capacity)
        self.device = device

    def __len__(self):
        return len(self.buffer)

    def append(self,exp_data):
        """Store one experience tuple (obs, action, reward, next_obs, terminated)."""
        self.buffer.append(exp_data)

    def sample(self,batch_size):
        """Draw ``batch_size`` distinct experiences and return them as five tensors.

        Returns, in order: observations (float32), actions (int64, with a
        trailing dimension of size 1), rewards (float32), next observations
        (float32) and terminated flags (int64).
        """
        def as_tensor(values,dtype):
            return torch.tensor(values,dtype=dtype,device=self.device)

        drawn = random.sample(self.buffer,batch_size)
        observations, actions, rewards, next_observations, terminations = zip(*drawn)

        return (
            # Observations are stacked into a single ndarray before conversion.
            as_tensor(np.array(observations),torch.float32),
            # Extra trailing dimension added to the action tensor.
            as_tensor(actions,torch.int64).unsqueeze(1),
            as_tensor(rewards,torch.float32),
            as_tensor(np.array(next_observations),torch.float32),
            as_tensor(terminations,torch.int64),
        )

In [None]:
class TrainManager():
    """Builds a DQN agent for a Gymnasium environment and runs training / evaluation.

    Args:
        env: Environment with a discrete action space (``action_space.n`` is read).
        episode_num: Number of training episodes run by ``train``.
        lr: Learning rate of the Adam optimizer.
        gamma: Discount factor passed to the agent.
        epsilon: Initial exploration probability.
        mini_epsilon: Lower bound of the exploration probability.
        explore_decay_rate: Per-action decrement of epsilon.
        buffer_capacity: Maximum number of experiences kept in the replay buffer.
        replay_start_size: Buffer size that must be exceeded before learning starts.
        replay_frequent: A learning step is run every this many experiences.
        target_sync_frequent: The target network is synchronized every this
            many experiences.
        batch_size: Number of experiences sampled per learning step.
    """

    def __init__(self,
                 env,
                 episode_num = 1000,
                 lr = 0.001,
                 gamma = 0.9,
                 epsilon = 0.1,
                 mini_epsilon=0.01,
                 explore_decay_rate=0.0001,
                 buffer_capacity = 2000,
                 replay_start_size = 200,
                 replay_frequent = 4,
                 target_sync_frequent = 200,
                 batch_size = 32):
        # Use the GPU when one is available.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.env = env
        self.episode_num = episode_num
        obs_size = gym.spaces.utils.flatdim(env.observation_space)
        action_size = env.action_space.n
        self.buffer = ReplayBuffer(capacity=buffer_capacity,device=self.device)
        Q_func = Q_Network(obs_size,action_size)
        Q_func.to(self.device)
        optimizer = torch.optim.Adam(Q_func.parameters(),lr=lr)
        self.agent = DQN_Agent(Q_func = Q_func,
                               action_size = action_size,
                               optimizer = optimizer,
                               replay_buffer = self.buffer,
                               replay_start_size = replay_start_size,
                               batch_size = batch_size,
                               replay_frequent = replay_frequent,
                               target_sync_frequent = target_sync_frequent,
                               epsilon = epsilon,
                               mini_epsilon = mini_epsilon,
                               explore_decay_rate = explore_decay_rate,
                               gamma = gamma,
                               device = self.device)

    def train_episode(self,is_render=False):
        """Run one training episode with the behavior policy and return its total reward.

        The episode ends when the environment reports either ``terminated``
        or ``truncated``.
        """
        total_reward = 0
        obs,_ = self.env.reset()
        obs = np.array(obs)
        while True:
            action = self.agent.get_behavior_action(obs)
            next_obs, reward, terminated, truncated, _ = self.env.step(action)
            total_reward += reward
            next_obs = np.array(next_obs)
            # Only `terminated` is handed to the agent: a truncated episode did
            # not reach a terminal state, so its last transition must still bootstrap.
            self.agent.Q_approximation(obs,action,reward,next_obs,terminated)
            obs = next_obs
            if is_render:
                self.env.render()

            # A truncated episode must stop too; otherwise the loop keeps
            # stepping an environment that has already hit its time limit.
            if terminated or truncated:
                break

        return total_reward

    def test_episode(self):
        """Run one greedy (no exploration, no learning) episode and return its total reward."""
        total_reward = 0
        obs,_ = self.env.reset()
        obs = np.array(obs)
        while True:
            action = self.agent.get_target_action(obs)
            next_obs, reward, terminated, truncated, _ = self.env.step(action)
            obs = np.array(next_obs)
            total_reward += reward
            self.env.render()
            if terminated or truncated:
                break

        return total_reward

    def train(self):
        """Train for ``episode_num`` episodes, evaluating the greedy policy every 100 episodes."""
        for e in range(self.episode_num):
            episode_reward = self.train_episode()
            print('Episode %s: Total Reward = %.2f'%(e,episode_reward))

            if e%100 == 0:
                test_reward = self.test_episode()
                print('Test Total Reward = %.2f'%(test_reward))

In [None]:
if __name__ == "__main__":
    # Hyperparameters for the CartPole run; anything not listed here keeps
    # the TrainManager default.
    train_config = dict(
        episode_num = 1000,
        lr = 0.001,
        gamma = 0.9,
        epsilon = 0.3,
        target_sync_frequent = 200,
        mini_epsilon = 0.01,
        explore_decay_rate = 0.0001,
    )

    env = gym.make('CartPole-v1')
    # The names `env` and `Manger` are kept unchanged because they are
    # module-level and may be referenced elsewhere.
    Manger = TrainManager(env = env, **train_config)
    Manger.train()