In [None]:
import numpy as np
import matplotlib.pyplot as plt
from utils import GridWorld

%matplotlib

In [82]:
# Shared environment instance: passed to the training calls below and read as a
# global by plot_reward.
env = GridWorld()

In [83]:
# Setup
UP = 0
DOWN = 1
RIGHT = 2
LEFT = 3
actions = ['UP', 'DOWN', 'RIGHT', 'LEFT']

In [84]:
def select_greedy_policy(q_values, state, epsilon=0.1):
    '''
    Choose an action with an epsilon-greedy policy.

    With probability `epsilon` a uniformly random action index is returned;
    otherwise the index of the highest Q-value for `state` is returned
    (np.argmax picks the first one on ties).

    Args:
        q_values: table indexed as q_values[state][action].
        state: index of the current state in `q_values`.
        epsilon: probability of exploring instead of exploiting.

    Returns:
        Integer index of the selected action.
    '''
    if np.random.random() < epsilon:
        # Size the draw from the table row instead of hard-coding 4 actions,
        # so the policy also works for tables with a different action count.
        return np.random.choice(len(q_values[state]))
    return np.argmax(q_values[state])

In [85]:
def q_learning(env, num_episodes=1000, render=True, exploration_rate=0.1,
               learning_rate=0.5, gamma=0.9):
    '''
    Train a tabular Q-learning agent on `env`.

    Args:
        env: environment providing reset() -> state,
            step(action) -> (next_state, reward, done) and, when `render`
            is True, render(q_values, action=..., colorize_q=...).
        num_episodes: number of episodes to run.
        render: if True, draw the environment after every step.
        exploration_rate: epsilon passed to the epsilon-greedy policy.
        learning_rate: step size applied to the TD error.
        gamma: discount factor applied to the bootstrapped next-state value.

    Returns:
        (ep_rewards, q_values): list with the summed reward of each episode
        and the learned Q-table of shape (48, 4).
    '''
    # 4*12 states x 4 actions -- presumably matches GridWorld's layout; confirm in utils.
    q_values = np.zeros((4*12, 4))
    ep_rewards = []

    for _ in range(num_episodes):
        state = env.reset()
        done = False
        reward_sum = 0

        while not done:
            # Choose action
            action = select_greedy_policy(q_values, state, exploration_rate)
            # Do the action
            next_state, reward, done = env.step(action)
            reward_sum += reward
            # Off-policy target: bootstrap from the best next action. Uses the
            # `gamma` argument (it was previously ignored in favour of a literal 0.9).
            td_target = reward + gamma * np.max(q_values[next_state])
            td_error = td_target - q_values[state][action]
            q_values[state][action] += learning_rate * td_error
            # Update state
            state = next_state

            if render:
                env.render(q_values, action=actions[action], colorize_q=True)

        ep_rewards.append(reward_sum)

    return ep_rewards, q_values

In [86]:
# Train Q-learning with rendering enabled. With learning_rate=1 each update
# replaces the old estimate with the TD target.
q_learning_rewards, q_values = q_learning(env, gamma=0.9, learning_rate=1, render=True)

In [89]:
def sarsa(env, num_episodes=1000, render=True, exploration_rate=0.1,
          learning_rate=0.5, gamma=0.9):
    '''
    Train a tabular SARSA agent on `env`.

    Args:
        env: environment providing reset() -> state,
            step(action) -> (next_state, reward, done) and, when `render`
            is True, render(q_values, action=..., colorize_q=...).
        num_episodes: number of episodes to run.
        render: if True, draw the environment after every step.
        exploration_rate: epsilon passed to the epsilon-greedy policy.
        learning_rate: step size applied to the TD error.
        gamma: discount factor applied to the next state-action value.

    Returns:
        (ep_rewards, q_values_sarsa): list with the summed reward of each
        episode and the learned Q-table of shape (48, 4).
    '''
    # 4*12 states x 4 actions -- presumably matches GridWorld's layout; confirm in utils.
    q_values_sarsa = np.zeros((4*12, 4))
    ep_rewards = []

    for _ in range(num_episodes):
        state = env.reset()
        done = False
        reward_sum = 0
        # Choose the first action before entering the loop
        action = select_greedy_policy(q_values_sarsa, state, exploration_rate)

        while not done:
            # Do the action
            next_state, reward, done = env.step(action)
            reward_sum += reward

            # Choose next action
            next_action = select_greedy_policy(q_values_sarsa, next_state, exploration_rate)
            # On-policy target: bootstrap from the action that will actually be taken next
            td_target = reward + gamma * q_values_sarsa[next_state][next_action]
            td_error = td_target - q_values_sarsa[state][action]
            # Update q value
            q_values_sarsa[state][action] += learning_rate * td_error

            if render:
                # Draw this function's own table (not a global `q_values`) and label
                # the frame with the action just executed, as q_learning does.
                env.render(q_values_sarsa, action=actions[action], colorize_q=True)

            # Update state and action
            state = next_state
            action = next_action

        ep_rewards.append(reward_sum)
    return ep_rewards, q_values_sarsa

# SARSA

In [90]:
# Train SARSA with rendering enabled, using gamma=0.99 (the Q-learning run above passes gamma=0.9).
sarsa_rewards, q_values_sarsa = sarsa(env, render=True, learning_rate=0.5, gamma=0.99)

## Visualization

In [87]:
def plot_reward(values, is_q_learning = True):
    '''
    Render a Q-table and plot the average per-episode reward of fresh training runs.

    `values` is only used for the initial render. The reward curve comes from
    new, unrendered training runs on the global `env`: 10 Q-learning runs when
    `is_q_learning` is True, otherwise 100 SARSA runs. The curve is the
    per-episode mean across runs; the dashed green line is its overall mean,
    which is also printed.
    '''
    env.render(values, colorize_q=True)

    # Each run returns (episode_rewards, q_table); only the rewards are kept.
    if is_q_learning:
        runs = [q_learning(env, render=False, exploration_rate=0.1, learning_rate=1)
                for _ in range(10)]
    else:
        runs = [sarsa(env, render=False, exploration_rate=0.2)
                for _ in range(100)]
    rewards = tuple(run_rewards for run_rewards, _ in runs)

    # Average over runs, episode by episode, then collapse to a single number.
    avg_rewards = np.mean(rewards, axis=0)
    overall_mean = np.mean(avg_rewards)

    fig, ax = plt.subplots()
    ax.plot(avg_rewards)
    ax.plot([overall_mean] * len(avg_rewards), 'g--')
    ax.set_xlabel('Episodes')
    ax.set_ylabel('Rewards')

    print('Mean Reward: {}'.format(overall_mean))

def play(q_values):
    '''
    Run one episode in a new GridWorld, always taking the best-valued action.

    The policy is called with epsilon 0.0, so it never explores. The
    environment is rendered after every step, labelled with the action taken.
    The loop ends only when the environment reports `done`.
    '''
    world = GridWorld()
    state = world.reset()
    finished = False

    while not finished:
        # Greedy choice from the supplied table
        chosen = select_greedy_policy(q_values, state, 0.0)
        # Advance the environment; the reward is not needed for playback
        state, _, finished = world.step(chosen)

        world.render(q_values=q_values, action=actions[chosen], colorize_q=True)

In [88]:
# Render the trained Q-learning table and plot the reward curve (plot_reward trains
# fresh runs for the curve), then replay the greedy policy from q_values.
plot_reward(q_values)
play(q_values)

Mean Reward: -38.8523


In [92]:
# Render the trained SARSA table and plot the reward curve (plot_reward trains
# fresh runs for the curve), then replay the greedy policy from q_values_sarsa.
plot_reward(values = q_values_sarsa, is_q_learning = False)
play(q_values_sarsa)

Mean Reward: -84.68666
