# =========================Libraries ============================

# In [1]:
import gym 
import itertools 
import matplotlib 
import matplotlib.style 
import numpy as np 
import pandas as pd 
import sys

from collections import defaultdict
from lib.envs.windy_gridworld import WindyGridworldEnv
from lib import plotting 
  
matplotlib.style.use("ggplot")  # use matplotlib's built-in "ggplot" style sheet for all later plots

# ======================== Parameters ==========================

# In [19]:
# Training schedule.
num_episodes = 200  # number of episodes to run
max_steps = 10000  # passed to sarsa() as its max_steps argument

# Learning hyper-parameters.
gamma = 0.9  # discount factor
alpha = 0.9  # TD learning rate (step size)
epsilon = 0.1  # exploration probability of the epsilon-greedy policy

# ======================== Functions ==========================

# In [20]:
def make_epsilon_greedy_policy(Q, epsilon, nA):
    """
    Build an epsilon-greedy policy from an action-value table.

    Args:
        Q: Mapping from state -> numpy array of nA action values.
        epsilon: Probability of choosing a uniformly random action; a float in [0, 1].
        nA: Number of actions in the environment.

    Returns:
        A callable ``policy(observation)`` that returns a length-nA numpy array
        of action probabilities. Every action receives epsilon / nA, and the
        greedy action (the first argmax of ``Q[observation]``) additionally
        receives 1 - epsilon.
    """
    def action_probabilities(observation):
        # Uniform exploration mass spread over all actions.
        probs = np.full(nA, epsilon, dtype=float) / nA
        # Remaining mass goes to the greedy action; np.argmax breaks ties
        # in favour of the lowest index.
        greedy_action = np.argmax(Q[observation])
        probs[greedy_action] += 1.0 - epsilon
        return probs

    return action_probabilities

# In [25]:
def sarsa(env, num_episodes, max_steps, gamma, alpha, epsilon):
    """
    SARSA algorithm: on-policy TD control. Learns Q for an epsilon-greedy policy.

    Args:
        env: Gym-style environment. ``env.reset()`` must return the initial
            state and ``env.step(action)`` must return
            ``(next_state, reward, done, info)``.
        num_episodes: Number of episodes to run for.
        max_steps: Maximum number of steps per episode. An episode that has
            not signalled ``done`` after this many steps is cut off.
        gamma: Discount factor.
        alpha: TD learning rate.
        epsilon: Probability of sampling a random action. Float between 0 and 1.

    Returns:
        A tuple (Q, stats).
        Q is the learned action-value function, a dictionary mapping
        state -> numpy array of action values.
        stats is an EpisodeStats object with two numpy arrays,
        episode_lengths and episode_rewards.
    """
    nA = env.action_space.n
    Q = defaultdict(lambda: np.zeros(nA))

    # Keeps track of useful statistics
    stats = plotting.EpisodeStats(
        episode_lengths=np.zeros(num_episodes),
        episode_rewards=np.zeros(num_episodes),
    )

    # The policy we're following; it reads Q live, so it improves as Q is updated.
    policy = make_epsilon_greedy_policy(Q, epsilon, nA)

    def choose_action(s):
        """Sample an action for state ``s`` from the epsilon-greedy policy."""
        probs = policy(s)
        return np.random.choice(np.arange(len(probs)), p=probs)

    for ep_idx in range(num_episodes):
        state = env.reset()
        action = choose_action(state)

        # Bounded by max_steps so a policy that never reaches the goal
        # cannot hang the run.
        for t in range(max_steps):
            next_state, reward, done, _ = env.step(action)

            stats.episode_rewards[ep_idx] += reward
            # 0-based index of the last step taken in this episode.
            stats.episode_lengths[ep_idx] = t

            if done:
                # Terminal transition: there is no successor state to bootstrap from.
                td_target = reward
            else:
                # On-policy: the next action is drawn from the policy at the
                # NEXT state and is the one actually executed on the next step.
                next_action = choose_action(next_state)
                td_target = reward + gamma * Q[next_state][next_action]

            Q[state][action] += alpha * (td_target - Q[state][action])

            if done:
                break

            # Advance the agent: (S, A) <- (S', A').
            state, action = next_state, next_action

    return Q, stats

# =========================== Main =============================

# In [26]:
env = WindyGridworldEnv()

# In [27]:
# Train SARSA with the hyper-parameters defined in the Parameters section.
Q, stats = sarsa(env, num_episodes, max_steps, gamma, alpha, epsilon)

# Earlier output of this cell: KeyboardInterrupt (the run was interrupted
# before sarsa() returned).

# In [None]:
# Plot per-episode statistics collected during training.
plotting.plot_episode_stats(stats)