In [7]:
import numpy as np
import gymnasium as gym
import torch 
import collections
import random

The *naive* Q-learning with Q function approximation (Q_learning_NN.ipynb) shows a poor performance. One of the reasons is related to the sampling part. Since the continuous collected samples are highly correlated, the estimated expectation could be biased. To solve this problem, we can use the experience replay buffer. The experience replay buffer stores the samples in a buffer and randomly sample from the buffer to train the network. This can reduce the correlation between the samples.

We can use the same network structure as the naive Q-learning with Q function approximation. The only difference is that we use the experience replay buffer to store the samples and randomly sample from the buffer to train the network. The training process uses the samples from the experience replay buffer to update the network parameters. 
\begin{equation}
w_{t+1} \gets w_t - \alpha_t \left(r_{t+1}+\gamma \max_{a^{\prime} \in \mathcal{A}(s_{t+1})} \hat{q}\left(s_{t+1}, a^{\prime}, w_t\right)-\hat{q}(s_t, a_t, w_t)\right) \nabla_{w_t} \hat{q}(s_t, a_t, w_t)
\end{equation}
where \{$s_t$, $a_t$, $r_t$, $s_{t+1}$\} is uniformly sampled from the experience replay buffer.

In [8]:
class DQN_Agent():
    
    """ Since the discrete actions have been redefined as {0,1,2,3} by using the wapper file, we can simply represent the action by a number. """
    
    def __init__(self,
                 Q_func,
                 action_size,
                 optimizer,
                 replay_buffer, # Object of ReplayBuffer
                 replay_start_size, # The number of experiences stored in the replay buffer before learning starts
                 batch_size, # The number of experiences to sample from the replay buffer for every learning iteration
                 replay_frequent, # Train the network every {replay_frequent} steps, which would also help to decorrelate the samples
                 epsilon=0.1,
                 gamma=0.9,
                 device='cpu'):
        self.device = device
        self.action_size = action_size
        
        self.replay_buffer = replay_buffer
        self.replay_start_size = replay_start_size
        self.batch_size = batch_size
        self.replay_frequent = replay_frequent
        
        """Here, we set a Global Counter for **interactions with the environment**, which is used to determine when to update the target network."""
        self.exp_counter = 0 
        
        self.Q_func = Q_func
        self.criteria = torch.nn.MSELoss()
        self.optimizer = optimizer
        
        self.epsilon = epsilon
        self.gamma = gamma
        
    pass

    def get_target_action(self,obs):
        obs = torch.tensor(obs,dtype=torch.float32,device=self.device)
        Q_list = self.Q_func(obs)
        action = int(torch.argmax(Q_list).clone().detach().cpu().numpy())
        return action

    def get_behavior_action(self,obs):
        if np.random.uniform(0,1) < self.epsilon:
            action = np.random.choice(self.action_size)
        else:
            action = self.get_target_action(obs)
            
        return action
    
    def batch_Q_approximation(self,batch_obs,batch_action,batch_reward,batch_next_obs,batch_terminated):
        """"""
        batch_current_Q = torch.gather(self.Q_func(batch_obs),1,batch_action).squeeze(1)
        batch_TD_target = batch_reward + (1-batch_terminated) * self.gamma * self.Q_func(batch_next_obs).max(1)[0]
        loss = self.criteria(batch_current_Q,batch_TD_target)
        
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
            
    def Q_approximation(self,obs,action,reward,next_obs,terminated):
        """After each interaction with the environment, we call this function and check whether the replay buffer has enough experiences to start learning."""
        self.exp_counter += 1 # Update the Global Counter here, since we defined the counter in Agent class.
        self.replay_buffer.append((obs,action,reward,next_obs,terminated)) # Store the experience in the replay buffer
        
        # Start learning after the replay buffer has enough experiences
        if len(self.replay_buffer) > self.replay_start_size and self.exp_counter%self.replay_frequent == 0: 
            self.batch_Q_approximation(*self.replay_buffer.sample(self.batch_size))

In [9]:
class Q_Network(torch.nn.Module):
    def __init__(self,obs_size,action_size):
        super(Q_Network,self).__init__()
        self.network = self.mlp_network(obs_size,action_size)
        
    def mlp_network(self,obs_size,action_size):
        mlp = torch.nn.Sequential(
            torch.nn.Linear(obs_size,64),
            torch.nn.ReLU(),
            torch.nn.Linear(64,64),
            torch.nn.ReLU(),
            torch.nn.Linear(64,action_size)
        )
        return mlp
    
    def forward(self,x):
        return self.network(x)

Here, we use python build-in libratory collections to implement the experience replay buffer. The deque is a double-ended queue. It can be used to add or remove elements from both ends. More details can be found in https://docs.python.org/2/library/collections.html#collections.deque. 

In this code, the replay will start after collecting certain number of experiences.

In [10]:
class ReplayBuffer():
    def __init__(self,capacity,device="cpu"):
        """Capacity is the maximum number of experiences that can be stored in the replay buffer. 
            If the number of experiences exceeds the capacity, the oldest experiences will be removed."""
        self.device = device
        """Here, we use deque to implement the replay buffer. 
            Collections.deque is a double-ended queue, which supports adding and removing elements from either end."""
        self.buffer = collections.deque(maxlen=capacity)
        
    def append(self,exp_data):
        self.buffer.append(exp_data)
        
    def sample(self,batch_size):
        """Here, we use random.sample to randomly select a batch of experiences from the replay buffer.
            Note that the return type of random.sample is a list, so we need to convert it to a numpy array."""
        mini_batch = random.sample(self.buffer,batch_size)
        obs_batch, action_batch, reward_batch, next_obs_batch, terminated_batch = zip(*mini_batch)
        
        obs_batch = np.array(obs_batch)
        next_obs_batch = np.array(next_obs_batch)
        
        obs_batch = torch.tensor(obs_batch,dtype=torch.float32,device=self.device)
        
        action_batch = torch.tensor(action_batch,dtype=torch.int64,device=self.device) # for gather function, the index should be int type
        action_batch = action_batch.unsqueeze(1)
        
        reward_batch = torch.tensor(reward_batch,dtype=torch.float32,device=self.device)
        next_obs_batch = torch.tensor(next_obs_batch,dtype=torch.float32,device=self.device)
        terminated_batch = torch.tensor(terminated_batch,dtype=torch.int64,device=self.device)
          
        return obs_batch, action_batch, reward_batch, next_obs_batch, terminated_batch
    
    def __len__(self):
        return len(self.buffer)

In [11]:
class TrainManager():
    
    def __init__(self,
                 env,
                 episode_num = 1000,
                 lr = 0.001,
                 gamma = 0.9,
                 epsilon = 0.1,
                 buffer_capacity = 2000,
                 replay_start_size = 200,
                 replay_frequent = 4,
                 batch_size = 32):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # self.device = 'cpu'
        self.env = env
        self.episode_num = episode_num
        obs_size = gym.spaces.utils.flatdim(env.observation_space) 
        action_size = env.action_space.n
        self.buffer = ReplayBuffer(capacity=buffer_capacity,device=self.device)
        Q_func = Q_Network(obs_size,action_size).to(self.device)
        optimizer = torch.optim.Adam(Q_func.parameters(),lr=lr)
        self.agent = DQN_Agent(Q_func = Q_func,
                               action_size = action_size,
                               optimizer = optimizer,
                               replay_buffer = self.buffer,
                               replay_start_size = replay_start_size,
                               batch_size = batch_size,
                               replay_frequent=replay_frequent,
                               epsilon = epsilon,
                               gamma = gamma,
                               device = self.device)
        
    def train_episode(self,is_render=False):
        total_reward = 0 
        obs,_ = self.env.reset() 
        obs = np.array(obs)
        while True:
            action = self.agent.get_behavior_action(obs) 
            next_obs, reward, terminated, _, _ = self.env.step(action) 
            total_reward += reward 
            next_obs = np.array(next_obs)
            self.agent.Q_approximation(obs,action,reward,next_obs,terminated)
            obs = next_obs
            if is_render:
                self.env.render()
                                
            if terminated:
                break
            
        return total_reward       

    def test_episode(self):
        total_reward = 0 
        obs,_ = self.env.reset() 
        obs = np.array(obs)
        while True:
            action = self.agent.get_target_action(obs) 
            next_obs, reward, terminated, _, _= self.env.step(action) 
            obs = np.array(next_obs)
            total_reward += reward
            self.env.render()
            if terminated: break
            
        return total_reward
            
            
    def train(self):        
        for e in range(self.episode_num):
            episode_reward = self.train_episode()
            print('Episode %s: Total Reward = %.2f'%(e,episode_reward)) 
            
            if e % 100 == 0: 
                test_reward = self.test_episode()
                print('Test Total Reward = %.2f'%(test_reward))

In [12]:
if __name__ == "__main__":
    env = gym.make('CartPole-v1')
    Manger = TrainManager(env = env,
                        episode_num = 1000,
                        lr = 0.001,
                        gamma = 0.9,
                        epsilon = 0.1
                        )
    Manger.train()

Episode 0: Total Reward = 12.00
Test Total Reward = 10.00
Episode 1: Total Reward = 9.00
Episode 2: Total Reward = 9.00
Episode 3: Total Reward = 11.00
Episode 4: Total Reward = 9.00
Episode 5: Total Reward = 11.00
Episode 6: Total Reward = 9.00
Episode 7: Total Reward = 10.00
Episode 8: Total Reward = 8.00
Episode 9: Total Reward = 8.00
Episode 10: Total Reward = 13.00
Episode 11: Total Reward = 9.00
Episode 12: Total Reward = 10.00
Episode 13: Total Reward = 10.00
Episode 14: Total Reward = 10.00
Episode 15: Total Reward = 9.00
Episode 16: Total Reward = 10.00
Episode 17: Total Reward = 9.00
Episode 18: Total Reward = 8.00
Episode 19: Total Reward = 11.00
Episode 20: Total Reward = 9.00
Episode 21: Total Reward = 9.00
Episode 22: Total Reward = 8.00
Episode 23: Total Reward = 9.00
Episode 24: Total Reward = 9.00
Episode 25: Total Reward = 10.00


  gym.logger.warn(


Episode 26: Total Reward = 10.00
Episode 27: Total Reward = 10.00
Episode 28: Total Reward = 10.00
Episode 29: Total Reward = 12.00
Episode 30: Total Reward = 10.00
Episode 31: Total Reward = 8.00
Episode 32: Total Reward = 9.00
Episode 33: Total Reward = 10.00
Episode 34: Total Reward = 10.00
Episode 35: Total Reward = 11.00
Episode 36: Total Reward = 8.00
Episode 37: Total Reward = 10.00
Episode 38: Total Reward = 10.00
Episode 39: Total Reward = 9.00
Episode 40: Total Reward = 12.00
Episode 41: Total Reward = 9.00
Episode 42: Total Reward = 9.00
Episode 43: Total Reward = 12.00
Episode 44: Total Reward = 10.00
Episode 45: Total Reward = 11.00
Episode 46: Total Reward = 10.00
Episode 47: Total Reward = 11.00
Episode 48: Total Reward = 12.00
Episode 49: Total Reward = 10.00
Episode 50: Total Reward = 8.00
Episode 51: Total Reward = 12.00
Episode 52: Total Reward = 9.00
Episode 53: Total Reward = 9.00
Episode 54: Total Reward = 9.00
Episode 55: Total Reward = 10.00
Episode 56: Total Re

KeyboardInterrupt: 