In [1]:
from env.windy_gridworld import WindyGridWorld
import numpy as np

In [2]:
# Grid dimensions and number of discrete actions; these also fix the shape of
# the Q and policy tables, (ROWS, COLUMNS, NUM_ACTIONS).
ROWS = 7
COLUMNS = 10
NUM_ACTIONS = 4
# One entry per column (COLUMNS values), handed to the environment as
# `wind_location` -- presumably the wind strength in each column; verify
# against WindyGridWorld.
WIND_LOC = [0, 0, 0, 1, 1, 1, 2, 2, 1, 0]
# Locations are [row, column] pairs
TARGET_LOC = [3, 7]
INIT_LOCATION = [3, 0]

# Training samples behaviour actions from NumPy's global generator, so seed it
# once here to make "Restart Kernel -> Run All" reproducible.
SEED = 0
np.random.seed(SEED)

# Set up the environment
env = WindyGridWorld(rows=ROWS, columns=COLUMNS, init_location=np.array(INIT_LOCATION))
env.target_location = np.array(TARGET_LOC)
env.wind_location = np.array(WIND_LOC)

## n-Step Tree Backup for estimating $Q \approx q_*$ or $q_\pi$

In [3]:
def n_step_tree_backup(
        Q_init: np.ndarray,
        policy: np.ndarray,
        env: "WindyGridWorld",
        n: int = 1,
        episodes: int = 1,
        alpha: float = 0.1,
        gamma: float = 1.0,
        rng=None,
):
    """Estimate action values with the off-policy n-step Tree Backup algorithm.

    Behaviour actions are drawn uniformly at random; the target policy is made
    greedy (ties split evenly) with respect to ``Q`` after every update.

    Args:
        Q_init: Initial action values, indexed as ``Q[*state, action]``; the
            last axis enumerates the actions. Not modified.
        policy: Initial target policy with the same shape as ``Q_init``; entry
            ``policy[*state, a]`` is the probability of ``a``. Not modified.
        env: Environment whose ``reset()`` returns a dict with an ``'agent'``
            array and whose ``step(action)`` returns a sequence starting with
            ``(observation_dict, reward, terminated)``.
        n: Number of backup steps (must be >= 1).
        episodes: Number of episodes to run.
        alpha: Step size.
        gamma: Discount factor.
        rng: Optional random source exposing ``choice(int)`` (for example a
            ``numpy.random.Generator``). Defaults to NumPy's global generator.

    Returns:
        Tuple ``(Q, policy)``: the learned action values (same dtype as
        ``Q_init``) and the greedy policy as a float32 array.

    Raises:
        ValueError: If ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError(f'n must be >= 1, got {n}')

    Q = Q_init.copy()
    policy = policy.copy()
    num_actions = Q.shape[-1]
    buffer_size = n + 1
    sampler = np.random if rng is None else rng

    for episode in range(episodes):
        # Circular buffer: slot i % (n + 1) holds S_i, A_i and R_i. It is
        # rebuilt every episode so nothing leaks in from the previous one.
        memory = [{} for _ in range(buffer_size)]
        memory[0]['state'] = tuple(env.reset()['agent'].tolist())
        memory[0]['action'] = int(sampler.choice(num_actions))

        T = np.inf
        t = 0
        while True:
            if t < T:
                feedback = env.step(memory[t % buffer_size]['action'])
                nxt = memory[(t + 1) % buffer_size]
                nxt['state'] = tuple(feedback[0]['agent'].tolist())
                nxt['reward'] = feedback[1]

                if feedback[2]:
                    T = t + 1
                else:
                    nxt['action'] = int(sampler.choice(num_actions))

            # Time step whose estimate is updated on this iteration.
            tau = t + 1 - n
            if tau >= 0:
                if t + 1 >= T:
                    # The return starts from the terminal reward R_T. Index
                    # with T, not t + 1: once t + 1 > T no new transition is
                    # stored, so slot (t + 1) would hold a stale reward.
                    G = memory[T % buffer_size]['reward']
                else:
                    # Bootstrap with the expected value of S_{t+1} under pi.
                    s_next = nxt['state']
                    G = nxt['reward'] + gamma * float(np.dot(policy[s_next], Q[s_next]))

                # Walk back from min(t, T - 1) to tau + 1: the action actually
                # taken contributes the return so far, every other action its
                # current estimate.
                for k in range(min(t, T - 1), tau, -1):
                    entry = memory[k % buffer_size]
                    s_k = entry['state']
                    backed_up = Q[s_k].astype(np.float64)
                    backed_up[entry['action']] = G
                    G = entry['reward'] + gamma * float(np.dot(policy[s_k], backed_up))

                head = memory[tau % buffer_size]
                sa = head['state'] + (head['action'],)
                Q[sa] += alpha * (G - Q[sa])

                # Make policy greedy with respect to Q, sharing probability
                # equally between (numerically) tied actions.
                q_max = Q.max(axis=-1, keepdims=True)
                policy = np.isclose(Q, q_max).astype(np.float32)
                policy = policy / policy.sum(axis=-1, keepdims=True)

            if tau == T - 1:
                break
            t += 1

        print(f'\rEpisode {episode:<6}', end='')

    return Q, policy

In [4]:
# Start from all-zero action values and a uniform target policy, then train.
Q_init = np.zeros(shape=(ROWS, COLUMNS, NUM_ACTIONS), dtype=np.float32)
n = 5
episodes = 100
# Derive the uniform probability from NUM_ACTIONS so the policy stays
# normalised if the action count ever changes.
policy = np.full((ROWS, COLUMNS, NUM_ACTIONS), 1.0 / NUM_ACTIONS, dtype=np.float32)
Q, policy = n_step_tree_backup(
    Q_init=Q_init,
    n=n,
    episodes=episodes,
    env=env,
    policy=policy,
    gamma=1.0,
    alpha=0.4
)

Episode 99    

In [5]:
# Follow the learned greedy policy from the start state and print the running
# step count. The rollout is capped because a greedy policy that has not
# converged can cycle forever and hang the kernel.
MAX_EVAL_STEPS = 1_000

state = env.reset()['agent'].tolist()
step = 0
while not env.is_terminated() and step < MAX_EVAL_STEPS:
    # Plain int action, matching what the training loop passes to env.step.
    action = int(policy[tuple(state)].argmax())
    feedback = env.step(action)
    state = feedback[0]['agent'].tolist()
    step += 1
    print(step)

if not env.is_terminated():
    print(f'Greedy policy did not reach the target within {MAX_EVAL_STEPS} steps')

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
