In [1]:
import numpy as np
import gymnasium as gym
import Toy_Envs.gridworld as gw
import time

The core idea of Sarsa (a one-step TD algorithm) is to solve the **Bellman equation** in its action-value form by stochastic approximation, so that policy evaluation can be carried out from sampled transitions $(s_t, a_t, r_{t+1}, s_{t+1}, a_{t+1})$.

\begin{equation}
\begin{aligned}
q_{t+1}\left(s_t, a_t\right) & =q_t\left(s_t, a_t\right)-\alpha_t\left(s_t, a_t\right)\left[q_t\left(s_t, a_t\right)-\left[r_{t+1}+\gamma q_t\left(s_{t+1}, a_{t+1}\right)\right]\right] \\
q_{t+1}(s, a) & =q_t(s, a), \quad \text { for all }(s, a) \neq\left(s_t, a_t\right)
\end{aligned}
\end{equation}

After each update of q, the policy at $s_t$ is also updated so that it is $\epsilon$-greedy with respect to the new action values.

\begin{equation}
\begin{aligned}
& \pi_{t+1}\left(a \mid s_t\right)=1-\frac{\epsilon}{|\mathcal{A}(s_t)|}(|\mathcal{A}(s_t)|-1) \text { if } a=\arg \max _{a'} q_{t+1}\left(s_t, a'\right) \\
& \pi_{t+1}\left(a \mid s_t\right)=\frac{\epsilon}{|\mathcal{A}(s_t)|} \text { otherwise }
\end{aligned}
\end{equation}

In [2]:
class Sarsa_Agent():
    """Tabular Sarsa learner that acts with an epsilon-greedy policy.

    The discrete actions have been redefined as {0,1,2,3} by the wrapper file,
    so an action is represented by a plain integer. Action values are kept in
    a 2-D array ``Q`` that is indexed as ``Q[obs, action]``.
    """

    def __init__(self,
                 obs_dim:int,
                 action_dim:int,
                 epsilon:float = 0.1,
                 lr:float = 0.1,
                 gamma:float = 0.9) -> None:
        """Create the agent with an all-zero Q table.

        Args:
            obs_dim: number of rows of the Q table (one per observation index).
            action_dim: number of columns of the Q table (one per action index).
            epsilon: probability of drawing a uniformly random action in ``get_action``.
            lr: step size applied to the TD error.
            gamma: discount factor applied to the bootstrapped next action value.
        """
        # hyper-parameters
        self.gamma = gamma
        self.lr = lr
        self.epsilon = epsilon

        # table dimensions and the zero-initialised action values
        self.obs_dim, self.action_dim = obs_dim, action_dim
        self.Q = np.zeros((obs_dim, action_dim))

    def get_greedy_action(self,obs:int) -> int:
        """Return an action with the largest Q value in row ``obs``.

        Ties are broken uniformly at random.
        """
        q_row = self.Q[obs,:]
        # A plain np.argmax would always return action 0 for a row of equal
        # values such as [0,0,0,0]; sampling among all maximisers avoids that.
        best_actions = np.flatnonzero(q_row == q_row.max())
        return np.random.choice(best_actions)

    def get_action(self,obs:int) -> int:
        """Sample an action from the epsilon-greedy policy at ``obs``."""
        explore = np.random.uniform(0,1) < self.epsilon
        if explore:
            # exploration: every action is equally likely
            return np.random.choice(self.action_dim)
        # exploitation: follow the current greedy (improved) policy
        return self.get_greedy_action(obs)

    def policy_evaluation(self,
                          obs:int,
                          action:int,
                          reward:float,
                          next_obs:int,
                          next_action:int,
                          done:bool) -> None:
        """Apply one Sarsa update to ``Q[obs, action]``.

        Args:
            obs: index of the state the action was taken in.
            action: index of the action that was taken.
            reward: reward received for the transition.
            next_obs: index of the state that was reached.
            next_action: index of the action chosen in ``next_obs``.
            done: when true the bootstrapped term is multiplied by zero, so the
                TD target reduces to ``reward``.
        """
        # (1 - done) switches the bootstrapped next-state value off when done is true
        bootstrap = (1-float(done)) * self.gamma * self.Q[next_obs,next_action]
        td_target = reward + bootstrap
        td_error = self.Q[obs,action] - td_target
        self.Q[obs,action] -= self.lr * td_error
        

In [3]:
class TrainManager():
    """Runs training and a final greedy test of a tabular Sarsa agent on a discrete env."""

    def __init__(self,
                 env:gym.Env,
                 episode_num:int = 1000,
                 lr:float = 0.1,
                 gamma:float = 0.9,
                 epsilon:float = 0.1) -> None:
        """Store the environment and build the agent.

        Args:
            env: environment whose ``observation_space`` and ``action_space``
                both expose ``.n`` (the sizes of the Q table).
            episode_num: number of training episodes run by ``train``.
            lr: learning rate forwarded to the agent.
            gamma: discount factor forwarded to the agent.
            epsilon: exploration probability forwarded to the agent.
        """
        self.env = env
        self.episode_num = episode_num
        obs_dim = env.observation_space.n
        action_dim = env.action_space.n
        self.agent = Sarsa_Agent(
                    obs_dim = obs_dim,
                    action_dim = action_dim,
                    epsilon = epsilon,
                    lr = lr,
                    gamma = gamma
                )

    def train_episode(self,is_render:bool=False) -> float:
        """Run one training episode with Sarsa updates and return its total reward.

        Args:
            is_render: when true, render the env and sleep 0.1 s after every step.
        """
        total_reward = 0 # record total reward in one episode
        obs,_ = self.env.reset() # reset env and get initial state
        # Sarsa chooses the first action before the loop; every later action is
        # the a' that was already used in the previous update.
        action = self.agent.get_action(obs)
        while True:
            # Gymnasium / Gym >= 0.26 API: step returns terminated and truncated
            # separately. Older Gym versions return a single `done` flag instead,
            # so adapt this line to the version of Gym you use.
            next_obs, reward, terminated, truncated, _ = self.env.step(action)
            total_reward += reward
            # For Sarsa, a' MUST be drawn from the current policy ...
            next_action = self.agent.get_action(next_obs)
            # ... and only a real terminal state removes the bootstrap term.
            # A truncated episode (e.g. a time limit) did not reach a terminal
            # state, so its next-state value is still used in the TD target.
            self.agent.policy_evaluation(obs,action,reward,next_obs,next_action,terminated)
            # update state and action: the sampled a' is the action executed next
            obs, action = next_obs, next_action
            if is_render:
                self.env.render() # !! You can find the game window in the taskbar !!
                time.sleep(0.1) # set a speed for visualization

            # the episode ends on either flag
            if terminated or truncated:
                break

        return total_reward

    def test_episode(self) -> float:
        """Run one rendered episode with the greedy policy and return its total reward."""
        total_reward = 0 # record total reward in one episode
        obs,_ = self.env.reset() # reset env and get initial state
        while True:
            action = self.agent.get_greedy_action(obs) # get action using learned greedy policy
            obs, reward, terminated, truncated, _ = self.env.step(action)
            total_reward += reward
            self.env.render()
            time.sleep(0.1)
            if terminated or truncated:
                break

        return total_reward

    def train(self) -> None:
        """Train for ``episode_num`` episodes, printing each return, then run one test episode."""
        for e in range(self.episode_num):
            # Render the episode that follows every 50th one (e = 1, 51, 101, ...);
            # episode 0 is never rendered.
            is_render = e > 0 and (e - 1) % 50 == 0
            episode_reward = self.train_episode(is_render)
            print('Episode %s: Total Reward = %.2f'%(e,episode_reward))

        test_reward = self.test_episode()
        print('Test Total Reward = %.2f'%(test_reward))

In [4]:
if __name__ == "__main__":
    # Seed NumPy's global RNG so that everything drawn through np.random in
    # this notebook is reproducible across "Restart & Run All".
    # Any randomness inside the environment itself is not seeded here.
    SEED = 0
    np.random.seed(SEED)

    env = gym.make('CliffWalking-v0')
    env = gw.CliffWalkingWapper(env)

    # Here is another game you can try.

    # env = gym.make("FrozenLake-v1", is_slippery=False)
    # env = gw.FrozenLakeWapper(env)

    Manger = TrainManager(env=env,
                        episode_num=500,
                        lr=0.1,
                        gamma=0.9,
                        epsilon=0.1
                        )
    Manger.train()

Episode 0: Total Reward = -2313.00
Episode 1: Total Reward = -179.00
Episode 2: Total Reward = -329.00
Episode 3: Total Reward = -291.00
Episode 4: Total Reward = -193.00
Episode 5: Total Reward = -293.00
Episode 6: Total Reward = -483.00
Episode 7: Total Reward = -380.00
Episode 8: Total Reward = -69.00
Episode 9: Total Reward = -350.00
Episode 10: Total Reward = -132.00
Episode 11: Total Reward = -253.00
Episode 12: Total Reward = -95.00
Episode 13: Total Reward = -89.00
Episode 14: Total Reward = -271.00
Episode 15: Total Reward = -60.00
Episode 16: Total Reward = -57.00
Episode 17: Total Reward = -294.00
Episode 18: Total Reward = -183.00
Episode 19: Total Reward = -101.00
Episode 20: Total Reward = -50.00
Episode 21: Total Reward = -176.00
Episode 22: Total Reward = -69.00
Episode 23: Total Reward = -99.00
Episode 24: Total Reward = -222.00
Episode 25: Total Reward = -86.00
Episode 26: Total Reward = -65.00
Episode 27: Total Reward = -75.00
Episode 28: Total Reward = -128.00
Episo

: 