In [None]:
import numpy as np
import gymnasium as gym
import torch 
import collections
import random
import copy

In [None]:
class DQN_Agent():
    """Deep Q-Network agent with experience replay and a periodically synced target network.

    Two copies of the Q function are kept:

    * ``main_Q_func``   - the online network; it is the one being optimized and
      the one used to pick actions.
    * ``target_Q_func`` - a deep copy of the online network that is only used to
      build the bootstrapped TD target and is refreshed every
      ``target_sync_frequent`` stored transitions.

    Args:
        Q_func: torch module mapping observations to one Q value per action.
        action_size: number of discrete actions.
        optimizer: optimizer stepped on every replay update (expected to hold
            the parameters of ``Q_func``).
        replay_buffer: object providing ``append(exp)``, ``sample(batch_size)``
            and ``len()``.
        replay_start_size: the buffer must hold more than this many transitions
            before any replay update is performed.
        batch_size: number of transitions per replay update.
        replay_frequent: a replay update is attempted every this many stored
            transitions.
        target_sync_frequent: the target network is synced every this many
            stored transitions.
        epsilon: exploration probability of the behavior policy.
        gamma: discount factor.
        device: device on which observation tensors are created.
    """

    def __init__(self,
                 Q_func,
                 action_size,
                 optimizer,
                 replay_buffer,
                 replay_start_size,
                 batch_size,
                 replay_frequent,
                 target_sync_frequent,
                 epsilon=0.1,
                 gamma=0.9,
                 device='cpu'):
        self.device = device
        self.action_size = action_size

        # Number of transitions stored so far; drives the replay / sync schedules.
        self.exp_counter = 0

        self.replay_buffer = replay_buffer
        self.replay_start_size = replay_start_size
        self.batch_size = batch_size
        self.replay_frequent = replay_frequent

        self.target_update_frequent = target_sync_frequent

        self.main_Q_func = Q_func
        self.target_Q_func = copy.deepcopy(Q_func)

        self.criteria = torch.nn.MSELoss()
        self.optimizer = optimizer

        self.epsilon = epsilon
        self.gamma = gamma

    def get_target_action(self,obs):
        """Return the greedy action (as a Python int) for ``obs``.

        The greedy ("target") policy is taken with respect to the online
        network ``main_Q_func`` so that acting always reflects the latest
        learned Q values; the lagged ``target_Q_func`` is reserved for the TD
        target. Used for evaluation and as the greedy branch of
        ``get_behavior_action``.
        """
        obs = torch.tensor(obs,dtype=torch.float32,device=self.device)
        # Pure inference: no autograd graph is needed to pick an action.
        with torch.no_grad():
            Q_list = self.main_Q_func(obs)
        return int(torch.argmax(Q_list).item())

    def get_behavior_action(self,obs):
        """Return an epsilon-greedy action (as a Python int) for ``obs``.

        With probability ``epsilon`` a uniformly random action is drawn,
        otherwise the greedy action from ``get_target_action`` is used.
        """
        if np.random.uniform(0,1) < self.epsilon:
            # int(...) so both branches return the same type.
            action = int(np.random.choice(self.action_size))
        else:
            action = self.get_target_action(obs)

        return action

    def sync_target_Q_func(self):
        """Overwrite the target network with the current online network."""
        # load_state_dict copies parameters and registered buffers alike.
        self.target_Q_func.load_state_dict(self.main_Q_func.state_dict())

    def batch_Q_approximation(self,batch_obs,batch_action,batch_reward,batch_next_obs,batch_terminated):
        """Run one gradient step on a mini-batch of transitions.

        ``batch_action`` is used as the index of ``torch.gather`` along dim 1,
        so it must be an integer tensor with one column per row of
        ``batch_obs``. ``batch_terminated`` holds 1/True where the transition
        ended in a terminal state.
        """
        # Each row of the network output holds a Q value per action; gather
        # picks the Q value of the action that was actually taken.
        batch_current_Q = torch.gather(self.main_Q_func(batch_obs),1,batch_action).squeeze(1)

        # The TD target is a constant for the regression: build it without
        # autograd so no gradients are accumulated in the target network.
        with torch.no_grad():
            next_Q_max = self.target_Q_func(batch_next_obs).max(1)[0]
            # For terminal transitions there is no next state to bootstrap
            # from, so the target reduces to the reward.
            not_terminated = 1 - batch_terminated.to(next_Q_max.dtype)
            batch_TD_target = batch_reward + not_terminated * self.gamma * next_Q_max

        loss = self.criteria(batch_current_Q,batch_TD_target)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def Q_approximation(self,obs,action,reward,next_obs,terminated):
        """Store one transition and run the scheduled target sync / replay update."""
        self.exp_counter += 1
        self.replay_buffer.append((obs,action,reward,next_obs,terminated))
        if self.exp_counter%self.target_update_frequent == 0:
            self.sync_target_Q_func()

        buffer_len = len(self.replay_buffer)
        # The second check avoids sampling more transitions than are stored.
        buffer_ready = buffer_len > self.replay_start_size and buffer_len >= self.batch_size
        if buffer_ready and self.exp_counter%self.replay_frequent == 0:
            self.batch_Q_approximation(*self.replay_buffer.sample(self.batch_size))

In [None]:
class Q_Network(torch.nn.Module):
    """Fully connected Q-value approximator.

    Maps an observation vector of length ``obs_size`` to ``action_size``
    outputs (one Q value per action) through two ReLU-activated hidden layers.
    """

    def __init__(self,obs_size,action_size):
        super().__init__()
        self.network = self.mlp_network(obs_size,action_size)

    def mlp_network(self,obs_size,action_size):
        """Build the Linear -> ReLU -> Linear -> ReLU -> Linear stack."""
        hidden_width = 64
        layers = [
            torch.nn.Linear(obs_size,hidden_width),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_width,hidden_width),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_width,action_size),
        ]
        return torch.nn.Sequential(*layers)

    def forward(self,x):
        """Return the Q values produced by the MLP for input ``x``."""
        q_values = self.network(x)
        return q_values

In [None]:
class ReplayBuffer():
    """Bounded store of (obs, action, reward, next_obs, terminated) tuples.

    Supports appending single transitions and drawing random mini-batches
    that are returned as torch tensors on ``device``.
    """

    def __init__(self,capacity,device="cpu"):
        # A deque created with maxlen discards items from the opposite end
        # once it is full, so appending keeps only the newest `capacity` items.
        self.buffer = collections.deque(maxlen=capacity)
        self.device = device

    def __len__(self):
        return len(self.buffer)

    def append(self,exp_data):
        """Add one transition tuple to the buffer."""
        self.buffer.append(exp_data)

    def _to_tensor(self,values,dtype):
        # Small helper: every batch column ends up on the same device.
        return torch.tensor(values,dtype=dtype,device=self.device)

    def sample(self,batch_size):
        """Draw ``batch_size`` distinct stored transitions at random.

        Returns:
            Tuple ``(obs, action, reward, next_obs, terminated)`` of tensors.
            Observations and rewards are float32; actions are int64 with an
            extra trailing dimension (ready to be used as a ``gather`` index);
            ``terminated`` is int64.
        """
        # random.sample returns a list of transition tuples; zip(*...)
        # regroups them column-wise.
        transitions = random.sample(self.buffer,batch_size)
        states, actions, rewards, next_states, dones = zip(*transitions)

        # Observations are stacked into a single numpy array before the
        # tensor conversion.
        obs_batch = self._to_tensor(np.array(states),torch.float32)
        next_obs_batch = self._to_tensor(np.array(next_states),torch.float32)

        # gather needs an integer index tensor
        action_batch = self._to_tensor(actions,torch.int64).unsqueeze(1)
        reward_batch = self._to_tensor(rewards,torch.float32)
        terminated_batch = self._to_tensor(dones,torch.int64)

        return obs_batch, action_batch, reward_batch, next_obs_batch, terminated_batch

In [None]:
class TrainManager():
    """Builds a DQN agent for a discrete-action environment and runs the train / test loops.

    Args:
        env: environment following the Gymnasium API, i.e. ``reset()`` returns
            ``(obs, info)`` and ``step()`` returns
            ``(obs, reward, terminated, truncated, info)``; its action space
            must expose ``n``.
        episode_num: number of training episodes run by ``train``.
        lr: learning rate of the Adam optimizer.
        gamma: discount factor passed to the agent.
        epsilon: exploration probability passed to the agent.
        buffer_capacity: maximum number of transitions kept in the replay buffer.
        replay_start_size: buffer size that must be exceeded before learning starts.
        replay_frequent: a replay update is attempted every this many steps.
        target_sync_frequent: the target network is synced every this many steps.
        batch_size: mini-batch size of a replay update.
    """

    def __init__(self,
                 env,
                 episode_num = 1000,
                 lr = 0.001,
                 gamma = 0.9,
                 epsilon = 0.1,
                 buffer_capacity = 2000,
                 replay_start_size = 200,
                 replay_frequent = 4,
                 target_sync_frequent = 200,
                 batch_size = 32):
        # Use the GPU when one is available, otherwise fall back to the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.env = env
        self.episode_num = episode_num
        obs_size = gym.spaces.utils.flatdim(env.observation_space)
        action_size = env.action_space.n
        self.buffer = ReplayBuffer(capacity=buffer_capacity,device=self.device)
        Q_func = Q_Network(obs_size,action_size)
        Q_func.to(self.device)
        optimizer = torch.optim.Adam(Q_func.parameters(),lr=lr)
        self.agent = DQN_Agent(Q_func = Q_func,
                               action_size = action_size,
                               optimizer = optimizer,
                               replay_buffer = self.buffer,
                               replay_start_size = replay_start_size,
                               batch_size = batch_size,
                               replay_frequent = replay_frequent,
                               target_sync_frequent = target_sync_frequent,
                               epsilon = epsilon,
                               gamma = gamma,
                               device = self.device)

    def train_episode(self,is_render=False):
        """Run one training episode with the behavior policy and return its total reward."""
        total_reward = 0 # record total reward in one episode
        obs,_ = self.env.reset() # reset env and get initial state
        obs = np.array(obs)
        while True:
            action = self.agent.get_behavior_action(obs) # epsilon-greedy action
            next_obs, reward, terminated, truncated, _ = self.env.step(action)
            total_reward += reward
            next_obs = np.array(next_obs)
            # Only `terminated` is handed to the agent: a truncated episode
            # (e.g. a time limit) did not reach a terminal state, so the TD
            # target should still bootstrap from next_obs.
            self.agent.Q_approximation(obs,action,reward,next_obs,terminated)
            obs = next_obs
            if is_render:
                self.env.render()

            # The episode is over in both cases; the environment must be
            # reset before it is stepped again.
            if terminated or truncated:
                break

        return total_reward

    def test_episode(self):
        """Run one episode with the greedy policy (no learning) and return its total reward."""
        total_reward = 0 # record total reward in one episode
        obs,_ = self.env.reset() # reset env and get initial state
        obs = np.array(obs)
        while True:
            action = self.agent.get_target_action(obs) # greedy action
            next_obs, reward, terminated, truncated, _ = self.env.step(action)
            obs = np.array(next_obs)
            total_reward += reward
            self.env.render()
            # Stop on truncation as well, otherwise a policy that never fails
            # would keep this loop running forever.
            if terminated or truncated:
                break

        return total_reward

    def train(self):
        """Train for ``episode_num`` episodes, running a test episode every 100 episodes."""
        for e in range(self.episode_num):
            episode_reward = self.train_episode()
            print('Episode %s: Total Reward = %.2f'%(e,episode_reward))

            if e%100 == 0:
                test_reward = self.test_episode()
                print('Test Total Reward = %.2f'%(test_reward))

In [None]:
if __name__ == "__main__":
    # Hyper-parameters of this run, gathered in one place; anything not listed
    # here falls back to the defaults of TrainManager.
    hyper_params = dict(
        episode_num = 1000,
        lr = 0.001,
        gamma = 0.9,
        epsilon = 0.1,
        target_sync_frequent = 200,
    )

    env = gym.make('CartPole-v1')
    Manger = TrainManager(env = env, **hyper_params)
    Manger.train()