In [2]:
import gymnasium as gym
from gymnasium import spaces
from collections import defaultdict
import numpy as np

class NegativeHoleRewardWrapper(gym.RewardWrapper):
    def __init__(self, env):
        super().__init__(env)

    def step(self, action):
        observation, reward, terminated, truncated, info = self.env.step(action)
        if reward == 0 and terminated:
            reward = -1
        
        return observation, reward, terminated, truncated, info

env = gym.make('FrozenLake-v1', desc=["SFFF", "FHFH", "FFFH", "HFFG"], is_slippery=False)
wrapped_env = NegativeHoleRewardWrapper(env)
observation, info = wrapped_env.reset(seed=42)

In [65]:
def epsilon_greedy(Q_func, S, epsilon):
    r = np.random.random()
    if r > epsilon:
        return np.argmax(Q_func[S])
    else:
        return np.random.choice([0, 1, 2, 3])

def Sarsa(policy, _env, n_ep):
    Q_func = defaultdict(lambda: [0] * 4)
    N = defaultdict(lambda: [0] * 4)
    N_s = defaultdict(int)
    N_0 = 100
    successes = 0
    _gamma = 0.5
    _alpha = 0.5

    for i in range(n_ep):
        observation, info = _env.reset()

        N_s[observation] += 1
        epsilon = N_0/(N_0 + N_s[observation])
        action = policy(Q_func, observation, epsilon)
    
        while True:
            S_t = observation
            A_t = action
            
            observation, reward, terminated, truncated, info = _env.step(action)
            R_t_plus_1 = reward
            S_t_plus_1 = observation

            if not (terminated or truncated):
                N_s[observation] += 1
                epsilon = N_0/(N_0 + N_s[observation])
                A_t_plus_1 = policy(Q_func, S_t_plus_1, epsilon)      
            else:
                A_t_plus_1 = None

            if not (terminated or truncated):
                Q_func[S_t][A_t] += _alpha * (R_t_plus_1 + _gamma * Q_func[S_t_plus_1][A_t_plus_1] - Q_func[S_t][A_t])
            else:
                Q_func[S_t][A_t] += _alpha * (R_t_plus_1 - Q_func[S_t][A_t])
                
            observation = S_t_plus_1
            action = A_t_plus_1
            
            if terminated or truncated:
                if reward == 1:
                    successes += 1
                break
                
    print("Number of successes:", successes)
    return Q_func

Q_func = Sarsa(epsilon_greedy, env, 1000) 
# still very inconsistent lol
env.close()

Number of successes: 587


In [None]:
def Sarsa_lambda(policy, _env, n_ep):
    Q_func = defaultdict(lambda: [0] * 4)
    N_s = defaultdict(int)
    N_0 = 100
    successes = 0

    E_trace = defaultdict(lambda: [0] * 4)
    _gamma = 0.5
    _lambda = 0.9
    _alpha = 0.5

    for i in range(n_ep):
        observation, info = _env.reset()

        N_s[observation] += 1
        epsilon = N_0/(N_0 + N_s[observation])
        action = policy(Q_func, observation, epsilon)
    
        while True:
            S_t = observation
            A_t = action
            
            observation, reward, terminated, truncated, info = _env.step(action)
            R_t_plus_1 = reward
            S_t_plus_1 = observation

            if not (terminated or truncated):
                N_s[observation] += 1
                epsilon = N_0/(N_0 + N_s[observation])
                A_t_plus_1 = policy(Q_func, S_t_plus_1, epsilon)
                TD_error = R_t_plus_1 + (_gamma * Q_func[S_t_plus_1][A_t_plus_1]) - Q_func[S_t][A_t]
            else:
                A_t_plus_1 = None

            E_trace[S_t][A_t] += 1
            
            if not (terminated or truncated):
                for S in Q_func:
                    for A in range(_env.action_space.n):
                        Q_func[S][A] += _alpha * TD_error * E_trace[S][A]
                        E_trace[S][A] = _gamma * _lambda * E_trace[S][A]    
            else:
                Q_func[S_t][A_t] += _alpha * (R_t_plus_1 - Q_func[S_t][A_t])

            observation = S_t_plus_1
            action = A_t_plus_1
            
            if terminated or truncated:
                if reward == 1:
                    successes += 1
                break
            
    print("Number of successes:", successes)
    return Q_func

In [67]:
def greedy(Q_func, S):
    return np.argmax(Q_func[S])

def generate_episode(policy, _env, Q_func):
    observation, info = _env.reset()

    while True:
        action = policy(Q_func, observation)
        observation, reward, terminated, truncated, info = _env.step(action)
        
        if terminated or truncated:
            break

env_1 = gym.make('FrozenLake-v1', desc=["SFFF", "FHFH", "FFFH", "HFFG"], is_slippery=False, render_mode='human')
observation, info = wrapped_env.reset(seed=42)

for _ in range(10):
    generate_episode(greedy, env_1, Q_func)
    
env_1.close()