We find the best policy for the random walk problem using tabular n-step Sarsa (with a Q table).

The random walk problem is defined in the book (Section 6.2, page 133, Example 6.2). We make the chain a bit bigger: a path of 9 non-terminal states with a terminal state on each side. The reward is 1 for reaching the rightmost terminal state, -1 for reaching the leftmost one, and 0 elsewhere.

For the solution, we follow the n-step Sarsa pseudocode from Section 7.2, page 157.

In [1]:
from __future__ import division
from collections import defaultdict
import random
import sys

Defining helper functions.

In [2]:
def arg_max_with_tol(arr, tol=0.0001):
    """
    Return the index of a maximal element of ``arr``.

    Every element whose value lies within ``tol`` of the maximum is treated as
    tied with it, which absorbs small floating-point differences. Among the
    tied indices one is chosen uniformly at random.
    An empty ``arr`` raises ValueError (from ``max``).
    """
    best = max(arr)
    candidates = [idx for idx, value in enumerate(arr) if abs(value - best) <= tol]
    return candidates[random.randint(0, len(candidates) - 1)]

def choose_action_idx(actions, state_q_table, epsilon=0.1):
    """
    Pick an action index for a single state with an epsilon-greedy rule.

    With probability ``epsilon`` a uniformly random index into ``actions`` is
    returned (exploration). Otherwise the index of the largest entry of
    ``state_q_table`` is returned, ties broken at random (exploitation).
    """
    explore = random.random() < epsilon
    if not explore:
        return arg_max_with_tol(state_q_table)
    return random.randint(0, len(actions) - 1)

def get_greedy_policy(s, q_table, actions, terminals=None, max_steps=None):
    """
    Return the sequence of states visited from ``s`` when acting greedily.

    Actions are chosen with ``choose_action_idx(..., epsilon=0)``, so ties
    between equally valued actions are broken at random and repeated calls
    may return different paths.

    Args:
        s: the starting state.
        q_table: mapping from state to a list of action values. If it is a
            defaultdict, reading an unseen state inserts its default entry.
        actions: list of state offsets; taking action ``a`` moves the walker
            from ``s`` to ``s + actions[a]``.
        terminals: collection of terminal states. Defaults to the global
            ``Ts`` (the original behaviour).
        max_steps: optional cap on the number of steps. A purely greedy
            policy can cycle forever between states; when a cap is given,
            RuntimeError is raised instead of looping. ``None`` (default)
            means no cap.

    Returns:
        The list of visited states, starting with ``s`` and ending with the
        first terminal state reached (just ``[s]`` if ``s`` is terminal).
    """
    if terminals is None:
        terminals = Ts
    states = [s]
    while s not in terminals:
        # len(states) - 1 is the number of steps taken so far.
        if max_steps is not None and len(states) - 1 >= max_steps:
            raise RuntimeError(
                "greedy policy did not terminate within %d steps from state %r"
                % (max_steps, states[0]))
        a = choose_action_idx(actions, q_table[s], epsilon=0)
        s = s + actions[a]
        states.append(s)

    return states

Initializing the world with $\alpha = 0.3$, $n = 3$, $\gamma = 1$.

In [3]:
# The chain has states 0..10; both ends are terminal.
Ts = {0, 10}
n = 3              # number of steps in the n-step return
alpha = 0.3        # step size of the Q update
gamma = 1          # no discounting
actions = [-1, 1]  # move one state left / one state right
# Reward for entering each state: -1 at the left terminal (0),
# +1 at the right terminal (10), 0 everywhere else.
Rs = [0] * 11
Rs[0], Rs[-1] = -1, 1
# Q values per state, created lazily as all zeros on first access.
q_table = defaultdict(lambda: [0] * len(actions))

First let's run a single episode.

In [4]:
# One episode of n-step Sarsa starting from the middle state (5).
# S[t] and A[t] are the state and action at time t; the reward R[t] is
# Rs[S[t]], i.e. the reward for entering state S[t].
S, A = [], []
S0 = 5
S.append(S0)
A0 = choose_action_idx(actions, q_table[S0])
A.append(A0)
T = sys.maxint  # episode length; unknown until a terminal state is reached
for t in xrange(sys.maxint):    # Question: what would happen if you used range instead of xrange?
    if t < T:
        # Take action A[t] and observe the next state S[t+1].
        next_s = S[t] + actions[A[t]]
        S.append(next_s)

        # Reaching a terminal state fixes the episode length T.
        if next_s in Ts:
            T = t + 1
        else:
            # Choose A[t+1] epsilon-greedily from the current Q values.
            next_a = choose_action_idx(actions, q_table[next_s])
            A.append(next_a)

    # tau is the time step whose estimate is updated (n - 1 steps behind t).
    # No clamping is needed: tau grows by one per iteration and the loop only
    # stops once tau == T - 1, so every step 0..T-1 is updated exactly once,
    # even when the episode is shorter than n. (Clamping tau to 0 would
    # update step 0 repeatedly whenever T < n.)
    tau = t - n + 1

    if tau >= 0:
        # n-step return: discounted rewards R[tau+1] .. R[min(tau+n, T)] ...
        G = sum(gamma ** (i - tau - 1) * Rs[S[i]] for i in range(tau + 1, min(tau + n, T) + 1))
        if tau + n < T:
            # ... bootstrapped with Q(S[tau+n], A[tau+n]) if the episode is still running then.
            G += gamma ** n * q_table[S[tau + n]][A[tau + n]]

        old_q = q_table[S[tau]][A[tau]]
        q_table[S[tau]][A[tau]] = old_q + alpha * (G - old_q)

    if tau == T - 1:
        break

Printing the greedy policy for different states. 

### Question:
For which states can we be sure that the greedy policy is already correct at this point? Why?

In [5]:
# Greedy rollouts from a few start states (ties are broken at random).
print(get_greedy_policy(9, q_table, actions))
print(get_greedy_policy(8, q_table, actions))
print(get_greedy_policy(7, q_table, actions))
print(get_greedy_policy(5, q_table, actions))

[9, 8, 9, 8, 9, 10]
[8, 9, 10]
[7, 8, 9, 10]
[5, 6, 7, 6, 5, 6, 7, 8, 9, 8, 9, 10]


### Question:
How can we check for convergence?

In [7]:
def expected_num_steps(s, q_table, actions, size=10000):
    """
    Estimate the average number of steps the greedy policy takes to terminate from ``s``.

    Greedy rollouts break ties at random, so the path length can vary; the
    estimate is the mean over ``size`` rollouts.
    """
    total_steps = sum(len(get_greedy_policy(s, q_table, actions)) - 1 for _ in range(size))
    return total_steps / size

In [8]:
# Estimated average greedy episode length from state 5 (shown as the cell output).
expected_num_steps(5, q_table, actions, size=10000)

17.2716

Running until convergence.

In [9]:
# Run n-step Sarsa episodes from state 5 until the estimated greedy episode
# length from state 5 stops changing between episodes.
exp_num_steps = float('inf')
converged = False
while not converged:
    S, A = [], []
    S0 = 5
    S.append(S0)
    A0 = choose_action_idx(actions, q_table[S0])
    A.append(A0)
    T = sys.maxint  # episode length; unknown until a terminal state is reached
    for t in xrange(sys.maxint):    # Question: what would happen if you used range instead of xrange?
        if t < T:
            # Take action A[t] and observe the next state S[t+1].
            next_s = S[t] + actions[A[t]]
            S.append(next_s)

            # Reaching a terminal state fixes the episode length T.
            if next_s in Ts:
                T = t + 1
            else:
                # Choose A[t+1] epsilon-greedily from the current Q values.
                next_a = choose_action_idx(actions, q_table[next_s])
                A.append(next_a)

        # tau is the time step whose estimate is updated. No clamping is
        # needed: the loop only stops once tau == T - 1, so every step
        # 0..T-1 is updated exactly once (clamping tau to 0 would update
        # step 0 repeatedly whenever T < n).
        tau = t - n + 1

        if tau >= 0:
            # n-step return, bootstrapped from Q if the episode is still running at tau + n.
            G = sum(gamma ** (i - tau - 1) * Rs[S[i]] for i in range(tau + 1, min(tau + n, T) + 1))
            if tau + n < T:
                G += gamma ** n * q_table[S[tau + n]][A[tau + n]]

            old_q = q_table[S[tau]][A[tau]]
            q_table[S[tau]][A[tau]] = old_q + alpha * (G - old_q)

        if tau == T - 1:
            break

    # Checking for convergence (builtin abs: numpy is not imported here).
    tmp = expected_num_steps(5, q_table, actions)
    if abs(exp_num_steps - tmp) < 0.0001:
        converged = True

    exp_num_steps = tmp

In [10]:
# Greedy path from state 5 and its estimated average length.
print(get_greedy_policy(5, q_table, actions))
print(expected_num_steps(5, q_table, actions))

[5, 6, 7, 8, 9, 10]
5.0


### Question:
Are we guaranteed to learn anything for state 2?

In [11]:
# Several greedy rollouts from state 2; random tie-breaking can make them differ.
print(get_greedy_policy(2, q_table, actions))
print(get_greedy_policy(2, q_table, actions))
print(get_greedy_policy(2, q_table, actions))
print(get_greedy_policy(2, q_table, actions))

[2, 1, 2, 1, 2, 3, 2, 3, 4, 5, 6, 7, 8, 9, 10]
[2, 3, 2, 3, 4, 5, 6, 7, 8, 9, 10]
[2, 3, 4, 3, 2, 3, 4, 3, 2, 1, 0]
[2, 3, 2, 3, 2, 3, 2, 1, 2, 3, 2, 1, 2, 3, 2, 1, 2, 3, 2, 1, 0]


Unless we increase epsilon (which we don't really want to do -- it's too high already), we are not guaranteed to visit state 2 (or 3 or 4) when every episode starts at state 5. That means the q_table value for state 2 can still be [0, 0].

In [12]:
# q_table is a defaultdict, so reading a state that was never updated
# inserts and returns its default all-zero action values.
q_table[2]

[0, 0]

We can fix that by sampling the initial state of each episode and checking for convergence starting at state 2.

In [14]:
# Run n-step Sarsa episodes from random start states in 2..9 (so states left
# of the centre get visited) until the estimated greedy episode length from
# state 2 stops changing between episodes.
exp_num_steps = float('inf')
converged = False
while not converged:
    S, A = [], []
    S0 = random.randint(2, 9)   # Sampling the starting point (inclusive range)
    S.append(S0)
    A0 = choose_action_idx(actions, q_table[S0])
    A.append(A0)
    T = sys.maxint  # episode length; unknown until a terminal state is reached
    for t in xrange(sys.maxint):    # Question: what would happen if you used range instead of xrange?
        if t < T:
            # Take action A[t] and observe the next state S[t+1].
            next_s = S[t] + actions[A[t]]
            S.append(next_s)

            # Reaching a terminal state fixes the episode length T.
            if next_s in Ts:
                T = t + 1
            else:
                # Choose A[t+1] epsilon-greedily from the current Q values.
                next_a = choose_action_idx(actions, q_table[next_s])
                A.append(next_a)

        # tau is the time step whose estimate is updated. No clamping is
        # needed: the loop only stops once tau == T - 1, so every step
        # 0..T-1 is updated exactly once. Starting near the right end can
        # give episodes shorter than n (e.g. 8 -> 9 -> 10), where clamping
        # tau to 0 would update step 0 more than once.
        tau = t - n + 1

        if tau >= 0:
            # n-step return, bootstrapped from Q if the episode is still running at tau + n.
            G = sum(gamma ** (i - tau - 1) * Rs[S[i]] for i in range(tau + 1, min(tau + n, T) + 1))
            if tau + n < T:
                G += gamma ** n * q_table[S[tau + n]][A[tau + n]]

            old_q = q_table[S[tau]][A[tau]]
            q_table[S[tau]][A[tau]] = old_q + alpha * (G - old_q)

        if tau == T - 1:
            break

    # Checking for convergence (builtin abs: numpy is not imported here).
    tmp = expected_num_steps(2, q_table, actions)
    if abs(exp_num_steps - tmp) < 0.0001:
        converged = True

    exp_num_steps = tmp

In [15]:
# Greedy path from state 2 and its estimated average length.
print(get_greedy_policy(2, q_table, actions))
print(expected_num_steps(2, q_table, actions))

[2, 3, 4, 5, 6, 7, 8, 9, 10]
8.0


### Question:
Would this solution work for state 1 as well?