In [23]:
%matplotlib inline

import gym
import itertools
import matplotlib
import numpy as np
import pandas as pd
import sys

if "../" not in sys.path:
  sys.path.append("../") 

from collections import defaultdict
from lib.envs.cliff_walking import CliffWalkingEnv
from lib import plotting

matplotlib.style.use('ggplot')

In [24]:
env = CliffWalkingEnv()

In [None]:
def make_epsilon_greedy_policy(Q, epsilon, nA):
    """
    Creates an epsilon-greedy policy based on a given Q-function and epsilon.
    
    Args:
        Q: A dictionary that maps from state -> action-values.
            Each value is a numpy array of length nA (see below)
        epsilon: The probability to select a random action . float between 0 and 1.
        nA: Number of actions in the environment.
    
    Returns:
        A function that takes the observation as an argument and returns
        the probabilities for each action in the form of a numpy array of length nA.
    
    """
    def policy_fn(observation):
        A = np.ones(nA, dtype=float) * epsilon / nA        
        best_action = np.argmax(Q[observation])
        A[best_action] += (1.0 - epsilon)
        return A
    return policy_fn

In [120]:
def td_lambda(env , num_episodes, lamda = 0.8, discount_factor=1.0, alpha=0.5, epsilon=0.1):
    """
    Q-Learning algorithm: Off-policy TD control. Finds the optimal greedy policy
    while following an epsilon-greedy policy
    
    Args:
        env: OpenAI environment.
        num_episodes: Number of episodes to run for.
        discount_factor: Gamma discount factor.
        alpha: TD learning rate.
        epsilon: Chance the sample a random action. Float betwen 0 and 1.
    
    Returns:
        A tuple (Q, episode_lengths).
        Q is the optimal action-value function, a dictionary mapping state -> action values.
        stats is an EpisodeStats object with two numpy arrays for episode_lengths and episode_rewards.
    """
    
    # The final action-value function.
    # A nested dictionary that maps state -> (action -> action-value).
    Q = defaultdict(lambda: np.zeros(env.action_space.n))
    e = defaultdict(lambda: np.zeros(env.action_space.n))
    # Keeps track of useful statistics
    episode_lengths= []
    episode_rewards= [] 
    
    # The policy we're following
    policy = make_epsilon_greedy_policy(Q, epsilon, env.action_space.n)
    
    for i_episode in range(num_episodes):
        # Print out which episode we're on, useful for debugging.
        #if (i_episode + 1) % 100 == 0:
        print "\rEpisode {}/{}.".format(i_episode + 1, num_episodes)
        sys.stdout.flush()
        done = False
        s = env.reset()
        episode_l = 0
        episode_r = 0
        action = np.random.choice(np.arange(env.action_space.n), p=policy(s))
        e = defaultdict(lambda: np.zeros(env.action_space.n))
        visited_state_actions = []
        while not done:
            new_s, r, done, _ = env.step(action)
            a_max = np.argmax(Q[new_s])
            next_a = np.random.choice(np.arange(env.action_space.n), p=policy(new_s))
            td_error = r + discount_factor * Q[new_s][a_max] - Q[s][action]
            e[s][action] += 1
            visited_state_actions.append((s,action))
            for _state, _action in visited_state_actions:
                Q[_state][_action] += alpha * td_error * e[_state][_action]
                if a_max == next_a:
                    e[_state][_action] *= discount_factor * lamda
                else:
                    e[_state][_action] = 0
            action = next_a
            s = new_s
            episode_l += 1
            episode_r += r
            
    
        episode_r /= float(episode_l)
        print(episode_l)
        episode_lengths.append(episode_l)
        episode_rewards.append(episode_r)

    stats = plotting.EpisodeStats(episode_lengths=np.array(episode_lengths), episode_rewards=np.array(episode_rewards))

    return Q, stats

In [None]:
Q, stats = td_lambda(env, 800)

Episode 1/800.
104
Episode 2/800.
1
Episode 3/800.
12
Episode 4/800.
11
Episode 5/800.
86
Episode 6/800.
30
Episode 7/800.
64
Episode 8/800.
40
Episode 9/800.
24
Episode 10/800.
53
Episode 11/800.
130
Episode 12/800.
24
Episode 13/800.
28
Episode 14/800.
54
Episode 15/800.
31
Episode 16/800.
134
Episode 17/800.
90
Episode 18/800.
31
Episode 19/800.
25
Episode 20/800.
23
Episode 21/800.
19
Episode 22/800.
21
Episode 23/800.
17
Episode 24/800.
25
Episode 25/800.
23
Episode 26/800.
18
Episode 27/800.
19
Episode 28/800.
21
Episode 29/800.
17
Episode 30/800.
17
Episode 31/800.
19
Episode 32/800.
19
Episode 33/800.
22
Episode 34/800.
20
Episode 35/800.
18
Episode 36/800.
17
Episode 37/800.
18
Episode 38/800.
19
Episode 39/800.
34
Episode 40/800.
21
Episode 41/800.
22
Episode 42/800.
17
Episode 43/800.
17
Episode 44/800.
17
Episode 45/800.
21
Episode 46/800.
17
Episode 47/800.
18
Episode 48/800.
17
Episode 49/800.
18
Episode 50/800.
24
Episode 51/800.
23
Episode 52/800.
17
Episode 53/800.
18


In [None]:
plotting.plot_episode_stats(stats)

In [29]:
print np.random.random()

0.460058498157
