In [1]:
from env.windy_gridworld import WindyGridWorld
import numpy as np

In [2]:
ROWS = 7
COLUMNS = 10
NUM_ACTIONS = 4
# Presumably the upward wind strength applied in each of the COLUMNS columns
# (this matches Sutton & Barto's windy gridworld, Example 6.5) — confirm against WindyGridWorld.
WIND_LOC = [0, 0, 0, 1, 1, 1, 2, 2, 1, 0]
assert len(WIND_LOC) == COLUMNS, "expected one wind entry per column"
# Locations are [row, column] pairs
TARGET_LOC = [3, 7]
INIT_LOCATION = [3, 0]

# Seed NumPy's global RNG so the epsilon-greedy coin flips (np.random.rand) are
# reproducible on Restart & Run All.
# NOTE(review): env.action_space.sample() may draw from its own RNG; seed it too
# if WindyGridWorld exposes a way to.
SEED = 0
np.random.seed(SEED)

# Set up the environment
env = WindyGridWorld(rows=ROWS, columns=COLUMNS, init_location=np.array(INIT_LOCATION))
env.target_location = np.array(TARGET_LOC)
env.wind_location = np.array(WIND_LOC)

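## n-step SARSA

On-policy TD control (Sutton & Barto, 2nd ed., §7.2). For each time step $\tau$, the estimate $Q(S_\tau, A_\tau)$ is moved toward the $n$-step return

$$G_{\tau:\tau+n} = \sum_{i=\tau+1}^{\min(\tau+n,\,T)} \gamma^{\,i-\tau-1} R_i + \gamma^n Q(S_{\tau+n}, A_{\tau+n}),$$

where the bootstrap term is included only when $\tau + n < T$. Actions are chosen $\varepsilon$-greedily with respect to the current $Q$.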

In [3]:
def _epsilon_greedy(Q: np.ndarray, state, epsilon: float, env) -> int:
    """Choose an action for ``state`` epsilon-greedily with respect to ``Q``.

    Args:
        Q: Action-value table; ``Q[tuple(state)]`` holds the action values for ``state``.
        state: Sequence of integer coordinates of the state.
        epsilon: Probability of taking a random action instead of the greedy one.
        env: Environment whose ``action_space.sample()`` supplies random actions.

    Returns:
        The chosen action as a Python ``int``. Greedy ties go to the lowest
        action index (``np.argmax`` returns the first maximum).
    """
    if np.random.rand() < epsilon:
        return int(env.action_space.sample())
    return int(np.argmax(Q[tuple(state)]))


def n_step_sarsa(
    Q_init: np.ndarray,
    n: int,
    episodes: int,
    env: "WindyGridWorld",
    gamma: float = 1.0,
    alpha=0.1,
    epsilon=0.1,
):
    """Learn an action-value table with on-policy n-step SARSA.

    Follows the n-step SARSA control algorithm (Sutton & Barto, 2nd ed.,
    Sec. 7.2): each ``Q(S_tau, A_tau)`` is moved toward the n-step return

        G = sum_{i=tau+1}^{min(tau+n, T)} gamma**(i-tau-1) * R_i
            + gamma**n * Q(S_{tau+n}, A_{tau+n})   (bootstrap only if tau+n < T)

    Actions are chosen epsilon-greedily with respect to the current ``Q``.

    Args:
        Q_init: Initial action-value table, indexed as ``Q[*state, action]``.
            It is copied, never modified.
        n: Number of reward steps in the return; must be >= 1.
        episodes: Number of episodes to run.
        env: Environment. ``env.reset()['agent']`` must give the start state as
            a NumPy array of integer coordinates; ``env.step(a)`` must return a
            sequence whose items 0, 1 and 2 are the observation (with an
            ``'agent'`` entry), the reward and the termination flag;
            ``env.action_space.sample()`` must return a random action.
        gamma: Discount factor.
        alpha: Step size.
        epsilon: Exploration probability of the epsilon-greedy policy.

    Returns:
        The learned action-value table (a new array with the shape of ``Q_init``).

    Raises:
        ValueError: If ``n < 1``.
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")

    Q = Q_init.copy()
    # Ring buffer of S_t / A_t / R_t for the last n+1 time steps, slot = t mod (n+1).
    size = n + 1
    memory = [{} for _ in range(size)]

    for _ in range(episodes):
        state = env.reset()['agent']
        T = np.inf  # episode length; unknown until the environment terminates
        memory[0]['state'] = state.tolist()
        memory[0]['action'] = _epsilon_greedy(Q, state, epsilon, env)

        tau = 0
        t = 0
        while tau < T - 1:
            if t < T:
                # Take A_t and record S_{t+1}, R_{t+1}.
                feedback = env.step(memory[t % size]['action'])
                new_state = feedback[0]['agent']
                reward = feedback[1]
                terminated = feedback[2]

                nxt = memory[(t + 1) % size]
                nxt['state'] = new_state.tolist()
                nxt['reward'] = reward

                if terminated:
                    T = t + 1
                else:
                    # A_{t+1} must be chosen in the state just reached, S_{t+1}.
                    nxt['action'] = _epsilon_greedy(Q, new_state, epsilon, env)

            tau = t - n + 1  # the time step whose estimate is updated now
            if tau >= 0:
                G = 0
                for i in range(tau + 1, min(tau + n + 1, T + 1)):
                    # Discounting is relative to tau: R_{tau+1} is undiscounted.
                    G += (gamma ** (i - tau - 1)) * memory[i % size]['reward']

                if tau + n < T:
                    boot = memory[(tau + n) % size]
                    G += (gamma ** n) * Q[tuple(boot['state'] + [boot['action']])]

                cur = memory[tau % size]
                idx = tuple(cur['state'] + [cur['action']])
                Q[idx] += alpha * (G - Q[idx])

            t += 1

    return Q

In [4]:
# Train n-step SARSA from an all-zero table indexed as Q[row, column, action];
# gamma is left at its default of 1.0 (undiscounted).
n, episodes = 10, 10_000
Q_init = np.zeros((ROWS, COLUMNS, NUM_ACTIONS), dtype=np.float32)
Q = n_step_sarsa(Q_init, n, episodes, env, alpha=0.1, epsilon=0.1)