In [None]:
import gymnasium as gym
import logging
import numpy as np
import sys
from logger import get_logger

In [None]:
class SARSA:
    """Tabular on-policy SARSA control with an epsilon-greedy behaviour policy.

    Keeps an action-value table ``q_value`` of shape ``(obs_n, action_n)`` and a
    greedy policy ``pi`` (one action index per state) that is refreshed from
    ``q_value`` after every TD update.

    Args:
        env: Environment whose ``observation_space`` and ``action_space`` expose
            ``.n``, with ``reset() -> (obs, info)`` and
            ``step(a) -> (obs, reward, terminated, truncated, info)``.
        logger: Object with an ``info`` method; receives one line per evaluation.
        epsilon: Probability of taking a random action (``action_space.sample()``)
            during training instead of ``pi[obs]``.
        gamma: Discount factor used in the TD target.
        alpha_start: Numerator of the step size ``alpha_start / t``, where ``t``
            is the 1-based step index within the *current* episode.
    """

    def __init__(self, env, logger, epsilon=0.1, gamma=1.0, alpha_start=1.0):
        self.env = env
        self.logger = logger
        self.action_n = env.action_space.n
        self.obs_n = env.observation_space.n
        self.epsilon = epsilon
        self.gamma = gamma
        self.alpha_start = alpha_start
        self.q_value = np.zeros((self.obs_n, self.action_n))
        # Random initial greedy policy; entries are overwritten by argmax(Q) as states are visited.
        self.pi = np.random.randint(self.action_n, size=self.obs_n)

    def _select_action(self, obs):
        """Epsilon-greedy choice: a sampled action with probability ``epsilon``, else ``pi[obs]``."""
        if np.random.uniform(0, 1) < self.epsilon:
            return self.env.action_space.sample()
        return self.pi[obs]

    def evaluate(self, n_iter, eval_num=1):
        """Run the greedy policy ``pi`` for ``eval_num`` episodes and log the mean return.

        Args:
            n_iter: Label (e.g. the training episode index) included in the log line.
            eval_num: Number of evaluation episodes to average over; must be >= 1.

        Returns:
            The mean undiscounted return over the evaluation episodes.

        Raises:
            ValueError: If ``eval_num`` < 1 (the mean would be undefined).
        """
        if eval_num < 1:
            raise ValueError("eval_num must be >= 1, got {}".format(eval_num))
        total_return = 0
        for _ in range(eval_num):
            obs, _ = self.env.reset()
            done = False
            # A deterministic greedy policy can cycle forever; this loop relies on the
            # env eventually terminating or truncating (e.g. via a TimeLimit wrapper).
            while not done:
                obs, reward, terminated, truncated, _ = self.env.step(self.pi[obs])
                total_return += reward
                done = terminated or truncated
        mean_return = total_return / eval_num
        self.logger.info("[ITERATION {}]: The expected return is {}".format(n_iter, mean_return))
        return mean_return

    def run(self, num_episode=1000, eval_every=1):
        """Train with SARSA for ``num_episode`` episodes.

        Args:
            num_episode: Number of training episodes.
            eval_every: Evaluate (and log) the greedy policy after every
                ``eval_every``-th episode. ``1`` (default) evaluates after every
                episode; ``0`` or ``None`` disables evaluation.
        """
        for eps in range(num_episode):
            obs, _ = self.env.reset()
            self.starting = obs  # start state of the most recent episode
            action = self._select_action(obs)
            done = False
            t = 1
            while not done:
                obs_next, reward, terminated, truncated, _ = self.env.step(action)
                action_next = self._select_action(obs_next)
                # A terminal transition has no continuation value. A truncation
                # (time limit) leaves obs_next non-terminal, so we still bootstrap.
                if terminated:
                    bootstrap = 0.0
                else:
                    bootstrap = self.gamma * self.q_value[obs_next, action_next]
                alpha = self.alpha_start / t  # NOTE: t restarts at 1 every episode
                self.q_value[obs, action] += alpha * (reward + bootstrap
                                                      - self.q_value[obs, action])
                self.pi[obs] = np.argmax(self.q_value[obs, :])
                obs, action = obs_next, action_next
                done = terminated or truncated
                t += 1

            if eval_every and (eps + 1) % eval_every == 0:
                self.evaluate(eps)

In [None]:
# Cap episode length: a greedy policy can loop without ever reaching the goal.
# Passing max_episode_steps to gym.make overrides any registered default limit,
# whereas an extra outer TimeLimit wrapper could be shadowed by an inner one.
cliff_env = gym.make('CliffWalking-v0', max_episode_steps=2000)

# Seed every RNG the agent touches so Restart & Run All is reproducible:
# numpy's global RNG (initial policy, epsilon draws), the env, and its action space.
np.random.seed(0)
cliff_env.reset(seed=0)
cliff_env.action_space.seed(0)

sarsa = SARSA(env=cliff_env, logger=get_logger(name='sarsaLogger', fname='sarsa.log'))

In [None]:
# Train the CliffWalking agent; each episode's greedy-policy return goes to sarsa.log.
sarsa.run(3000)

In [None]:
# Greedy action per state laid out on the env's grid (observation index is row-major).
# Action codes per the Gymnasium CliffWalking docs: 0=up, 1=right, 2=down, 3=left.
sarsa.pi.reshape(cliff_env.unwrapped.shape)

In [None]:
# Roll out a single episode with the learned greedy policy and sum its rewards.
obs, _ = cliff_env.reset()
returns, done = 0, False
while not done:
    action = sarsa.pi[obs]
    obs, reward, terminated, truncated, _ = cliff_env.step(action)
    returns = returns + reward
    done = terminated or truncated

In [None]:
returns  # undiscounted return accumulated by the greedy rollout cell

In [None]:
# Taxi-v3 is registered with its own 200-step TimeLimit, so wrapping the result of
# gym.make in TimeLimit(max_episode_steps=5000) had no effect: the inner 200-step
# limit always truncated first. Passing max_episode_steps to gym.make replaces it.
taxi_env = gym.make("Taxi-v3", max_episode_steps=5000)

# Seed every RNG the agent touches (numpy global RNG, env reset, action sampling).
np.random.seed(0)
taxi_env.reset(seed=0)
taxi_env.action_space.seed(0)

sarsa_taxi = SARSA(env=taxi_env, logger=get_logger(name='sarsaTaxiLogger', fname='sarsaTaxi.log'))

In [None]:
# Train the Taxi agent; each episode's greedy-policy return goes to sarsaTaxi.log.
sarsa_taxi.run(5000)