In [87]:
import numpy as np
import gymnasium as gym
import torch 
import torch.nn.functional as F
import collections
import random
from typing import Tuple

The *naive* Q-learning with Q-function approximation (Q_learning_NN.ipynb) performs poorly. One of the reasons lies in how the samples are collected: consecutively collected samples are highly correlated, so the estimated expectation can be biased. To solve this problem, we can use an experience replay buffer. The buffer stores the collected samples, and mini-batches are drawn from it uniformly at random to train the network. This reduces the correlation between the samples used in each update.

We can use the same network structure as the naive Q-learning with Q function approximation. The only difference is that we use the experience replay buffer to store the samples and randomly sample from the buffer to train the network. The training process uses the samples from the experience replay buffer to update the network parameters. 
\begin{equation}
w_{t+1} \gets w_t + \alpha_t \left(r_{t+1}+\gamma \max_{a^{\prime} \in \mathcal{A}(s_{t+1})} \hat{q}\left(s_{t+1}, a^{\prime}, w_t\right)-\hat{q}(s_t, a_t, w_t)\right) \nabla_{w_t} \hat{q}(s_t, a_t, w_t)
\end{equation}
where the tuple $(s_t, a_t, r_{t+1}, s_{t+1})$ is uniformly sampled from the experience replay buffer.

In [88]:
class DQN_Agent():
    """Q-learning agent with a neural Q-function and an experience replay buffer.

    Since the discrete actions have been redefined as {0,1,2,3} by the wrapper
    file, an action is simply represented by its integer index.
    """

    def __init__(self,
                 Q_func: torch.nn.Module,
                 action_dim: int,
                 optimizer: torch.optim.Optimizer,
                 replay_buffer: collections.deque, # Object of ReplayBuffer
                 replay_start_size: int, # The number of experiences stored in the replay buffer before learning starts
                 batch_size: int, # The number of experiences to sample from the replay buffer for every learning iteration
                 replay_frequent :int, # Train the network every {replay_frequent} steps, which would also help to decorrelate the samples
                 epsilon:float = 0.1,
                 gamma:float = 0.9,
                 device:torch.device = torch.device("cpu")
                 ) -> None:
        """Store the Q-network, optimizer, replay buffer and hyper-parameters.

        Args:
            Q_func: network mapping an observation to one Q-value per action.
            action_dim: number of discrete actions.
            optimizer: optimizer that updates the parameters of ``Q_func``.
            replay_buffer: buffer exposing ``append``, ``sample`` and ``__len__``.
            replay_start_size: learning starts once the buffer holds more
                experiences than this.
            batch_size: number of experiences drawn for every update.
            replay_frequent: an update is run every ``replay_frequent`` interactions.
            epsilon: exploration probability of the epsilon-greedy behavior policy.
            gamma: discount factor.
            device: device on which single observations are placed.
        """
        self.device = device
        self.action_dim = action_dim

        self.replay_buffer = replay_buffer
        self.replay_start_size = replay_start_size
        self.batch_size = batch_size
        self.replay_frequent = replay_frequent

        # Global counter of interactions with the environment; it decides on
        # which steps a replay update is performed (every `replay_frequent`).
        self.exp_counter = 0

        self.Q_func = Q_func
        self.criteria = torch.nn.MSELoss()
        self.optimizer = optimizer

        self.epsilon = epsilon
        self.gamma = gamma

    def get_target_action(self,obs:np.ndarray) -> int:
        """Return the greedy action (argmax of the Q-values) for one observation."""
        obs = torch.tensor(obs,dtype=torch.float32,device=self.device)
        # Action selection never needs gradients, so skip building the graph.
        with torch.no_grad():
            Q_list = self.Q_func(obs)
        action = torch.argmax(Q_list).item()
        return action

    def get_behavior_action(self,obs:np.ndarray) -> int:
        """Return an epsilon-greedy action for one observation."""
        if np.random.uniform(0,1) < self.epsilon:
            # Convert the numpy integer to a plain int to honor the return type.
            action = int(np.random.choice(self.action_dim))
        else:
            action = self.get_target_action(obs)

        return action

    def batch_Q_approximation(self,
                              batch_obs:torch.Tensor,
                              batch_action:torch.Tensor,
                              batch_reward:torch.Tensor,
                              batch_next_obs:torch.Tensor,
                              batch_terminated:torch.Tensor) -> None:
        """Update the Q function by minimizing the TD-error on one mini-batch.

        ``batch_action`` is used as the index of ``torch.gather`` along dim 1,
        so it must be an integer tensor with one column per sample.
        """
        batch_current_Q = torch.gather(self.Q_func(batch_obs),1,batch_action).squeeze(1)

        # The TD target is treated as a constant (semi-gradient Q-learning):
        # only q(s_t, a_t, w) is differentiated, so no gradient may flow
        # through the bootstrap term max_a' q(s_{t+1}, a', w).
        with torch.no_grad():
            batch_next_Q = self.Q_func(batch_next_obs).max(1)[0] # torch.max returns a tuple (max_value, index_of_max_value)
            batch_TD_target = batch_reward + (1-batch_terminated) * self.gamma * batch_next_Q

        loss = self.criteria(batch_current_Q,batch_TD_target)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def Q_approximation(self,
                        obs:np.ndarray,
                        action:int,
                        reward:float,
                        next_obs:np.ndarray,
                        terminated:bool) -> None:
        """Store one transition and, when due, run one replay update.

        Called after each interaction with the environment.
        """
        self.exp_counter += 1 # Update the Global Counter here, since we defined the counter in Agent class.
        self.replay_buffer.append((obs,action,reward,next_obs,terminated)) # Store the experience in the replay buffer

        buffer_size = len(self.replay_buffer)
        # Start learning after the replay buffer has enough experiences; the
        # buffer must also hold at least one full batch before sampling from it.
        warmed_up = buffer_size > self.replay_start_size and buffer_size >= self.batch_size
        if warmed_up and self.exp_counter%self.replay_frequent == 0:
            self.batch_Q_approximation(*self.replay_buffer.sample(self.batch_size))

In [89]:
class Q_Network(torch.nn.Module):
    """Fully connected network with two hidden ReLU layers of width 64.

    The input has ``obs_dim`` features and the output has ``action_dim``
    values (one per action).
    """

    def __init__(self,obs_dim:int,action_dim) -> None:
        super().__init__()
        hidden_width = 64
        # Layers are created in input-to-output order.
        self.fc1 = torch.nn.Linear(obs_dim, hidden_width)
        self.fc2 = torch.nn.Linear(hidden_width, hidden_width)
        self.fc3 = torch.nn.Linear(hidden_width, action_dim)

    def forward(self,x:torch.Tensor) -> torch.Tensor:
        """Apply fc1 -> ReLU -> fc2 -> ReLU -> fc3 (no activation on the output)."""
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

Here, we use Python's built-in `collections` library to implement the experience replay buffer. A deque is a double-ended queue: elements can be added to or removed from both ends. More details can be found at https://docs.python.org/3/library/collections.html#collections.deque.

In this code, replay starts only after a certain number of experiences have been collected.

In [90]:
class ReplayBuffer():
    """Fixed-capacity store of experience tuples with uniform random sampling."""

    def __init__(self,capacity:int,device:torch.device = torch.device("cpu")) -> None:
        """Create an empty buffer.

        ``capacity`` is the maximum number of experiences kept. The storage is
        a ``collections.deque`` with ``maxlen=capacity``, so appending to a
        full buffer discards the oldest experience.
        ``device`` is where the sampled tensors are created.
        """
        self.device = device
        self.buffer = collections.deque(maxlen=capacity)

    def append(self,exp_data:tuple) -> None:
        """Add one experience tuple to the right end of the buffer."""
        self.buffer.append(exp_data)

    def sample(self,batch_size:int) -> Tuple[torch.Tensor,torch.Tensor,torch.Tensor,torch.Tensor,torch.Tensor]:
        """Draw ``batch_size`` distinct experiences and return them as tensors.

        Returns ``(obs, action, reward, next_obs, terminated)``. Observations
        and rewards are float32; actions and terminated flags are int64.
        Actions get an extra trailing dimension so they can be used as a
        ``gather`` index.
        """
        drawn = random.sample(self.buffer,batch_size)
        # Transpose the list of experiences into one tuple per field.
        states, actions, rewards, successors, dones = zip(*drawn)

        def to_tensor(values, dtype):
            return torch.tensor(values,dtype=dtype,device=self.device)

        # Observations go through numpy first: building a tensor from a
        # sequence of arrays directly is slow.
        obs_tensor = to_tensor(np.array(states), torch.float32)
        next_obs_tensor = to_tensor(np.array(successors), torch.float32)

        action_tensor = to_tensor(actions, torch.int64)[:, None] # gather needs an integer index
        reward_tensor = to_tensor(rewards, torch.float32)
        terminated_tensor = to_tensor(dones, torch.int64)

        return obs_tensor, action_tensor, reward_tensor, next_obs_tensor, terminated_tensor

    def __len__(self) -> int:
        """Number of experiences currently stored."""
        return len(self.buffer)

In [91]:
class TrainManager():
    """Builds the replay buffer, Q-network, optimizer and agent for an
    environment, and runs the training / evaluation loops."""

    def __init__(self,
                 env:gym.Env,
                 episode_num:int = 1000,
                 lr:float = 0.001,
                 gamma:float = 0.9,
                 epsilon:float = 0.1,
                 buffer_capacity:int = 2000,
                 replay_start_size:int = 200,
                 replay_frequent:int = 4,
                 batch_size:int = 32) -> None:
        """Create all training components.

        Args:
            env: environment with a discrete action space.
            episode_num: number of training episodes run by ``train``.
            lr: learning rate of the Adam optimizer.
            gamma: discount factor passed to the agent.
            epsilon: exploration probability passed to the agent.
            buffer_capacity: maximum number of experiences kept in the buffer.
            replay_start_size: buffer size that must be exceeded before learning.
            replay_frequent: an update is run every ``replay_frequent`` steps.
            batch_size: number of experiences sampled per update.
        """
        # Use the GPU when one is available, otherwise fall back to the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.env = env
        self.episode_num = episode_num
        obs_dim = gym.spaces.utils.flatdim(env.observation_space)
        action_dim = env.action_space.n
        self.buffer = ReplayBuffer(capacity=buffer_capacity,device=self.device)
        Q_func = Q_Network(obs_dim,action_dim).to(self.device)
        optimizer = torch.optim.Adam(Q_func.parameters(),lr=lr)
        self.agent = DQN_Agent(Q_func = Q_func,
                               action_dim = action_dim,
                               optimizer = optimizer,
                               replay_buffer = self.buffer,
                               replay_start_size = replay_start_size,
                               batch_size = batch_size,
                               replay_frequent=replay_frequent,
                               epsilon = epsilon,
                               gamma = gamma,
                               device = self.device)

    def train_episode(self,is_render:bool=False) -> float:
        """Run one episode with the behavior policy, learning after every step.

        Returns the sum of rewards collected during the episode.
        """
        total_reward = 0
        obs,_ = self.env.reset()
        obs = np.array(obs)
        while True:
            action = self.agent.get_behavior_action(obs)
            next_obs, reward, terminated, truncated, _ = self.env.step(action)
            total_reward += reward
            next_obs = np.array(next_obs)
            # Only `terminated` is stored: a truncated (time-limited) episode
            # did not reach a terminal state, so its target should still bootstrap.
            self.agent.Q_approximation(obs,action,reward,next_obs,terminated)
            obs = next_obs
            if is_render:
                self.env.render()

            # The episode is over on either signal; stepping past a truncation
            # would keep using an environment that must be reset first.
            if terminated or truncated:
                break

        return total_reward

    def test_episode(self) -> float:
        """Run one episode with the greedy policy, without learning.

        Returns the sum of rewards collected during the episode.
        """
        total_reward = 0
        obs,_ = self.env.reset()
        obs = np.array(obs)
        # Rendering without a render mode only produces a warning, so skip it.
        can_render = getattr(self.env, "render_mode", None) is not None
        while True:
            action = self.agent.get_target_action(obs)
            next_obs, reward, terminated, truncated, _ = self.env.step(action)
            obs = np.array(next_obs)
            total_reward += reward
            if can_render:
                self.env.render()
            if terminated or truncated:
                break

        return total_reward


    def train(self) -> None:
        """Train for ``episode_num`` episodes, evaluating every 100 episodes."""
        for e in range(self.episode_num):
            episode_reward = self.train_episode()
            print('Episode %s: Total Reward = %.2f'%(e,episode_reward))

            if e % 100 == 0:
                test_reward = self.test_episode()
                print('Test Total Reward = %.2f'%(test_reward))

In [92]:
if __name__ == "__main__":
    # Train the agent on CartPole-v1 with the hyper-parameters below.
    env = gym.make('CartPole-v1')
    try:
        Manger = TrainManager(env = env,
                            episode_num = 1000,
                            lr = 0.001,
                            gamma = 0.9,
                            epsilon = 0.1
                            )
        Manger.train()
    finally:
        # Release the environment even if training is interrupted
        # (e.g. by a KeyboardInterrupt from the notebook).
        env.close()

Episode 0: Total Reward = 10.00
Test Total Reward = 8.00
Episode 1: Total Reward = 9.00
Episode 2: Total Reward = 9.00
Episode 3: Total Reward = 10.00
Episode 4: Total Reward = 10.00
Episode 5: Total Reward = 11.00
Episode 6: Total Reward = 9.00
Episode 7: Total Reward = 10.00
Episode 8: Total Reward = 12.00
Episode 9: Total Reward = 13.00
Episode 10: Total Reward = 11.00
Episode 11: Total Reward = 9.00
Episode 12: Total Reward = 8.00
Episode 13: Total Reward = 9.00
Episode 14: Total Reward = 9.00
Episode 15: Total Reward = 11.00
Episode 16: Total Reward = 12.00
Episode 17: Total Reward = 8.00
Episode 18: Total Reward = 10.00
Episode 19: Total Reward = 9.00
torch.return_types.max(
values=tensor([-0.0192, -0.0087, -0.0019,  0.0412,  0.0366, -0.0021,  0.0236,  0.0380,
         0.0421,  0.0036, -0.0198, -0.0201, -0.0028,  0.0424,  0.0007,  0.0014,
         0.0367, -0.0099,  0.0141,  0.0386,  0.0401,  0.0048, -0.0122,  0.0126,
        -0.0106,  0.0038,  0.0100,  0.0119,  0.0158,  0.0403, -

  gym.logger.warn(


torch.return_types.max(
values=tensor([0.0670, 0.0667, 0.0858, 0.0938, 0.0965, 0.0895, 0.0887, 0.0695, 0.0970,
        0.0602, 0.0884, 0.0735, 0.0644, 0.0900, 0.0677, 0.0870, 0.0979, 0.0964,
        0.0955, 0.0587, 0.0609, 0.0918, 0.0818, 0.0927, 0.0898, 0.0738, 0.0860,
        0.0943, 0.0691, 0.0617, 0.0966, 0.0926], device='cuda:0',
       grad_fn=<MaxBackward0>),
indices=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0'))
Episode 21: Total Reward = 10.00
torch.return_types.max(
values=tensor([0.1257, 0.1215, 0.1011, 0.1031, 0.0894, 0.1071, 0.1027, 0.1063, 0.1029,
        0.0984, 0.1089, 0.1075, 0.1148, 0.0938, 0.0894, 0.1097, 0.0915, 0.0998,
        0.0957, 0.0897, 0.1003, 0.0903, 0.1124, 0.0872, 0.1055, 0.1107, 0.0982,
        0.1077, 0.1072, 0.0979, 0.1073, 0.1002], device='cuda:0',
       grad_fn=<MaxBackward0>),
indices=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  

KeyboardInterrupt: 