In [54]:
import numpy as np
import pandas as pd

import gym
import timeit
from gym import wrappers

In [55]:
max_iterations = 10000000
t_max = 10000
epsilon = 1e-10

env_name = 'Taxi-v2'
env = gym.make(env_name)
env.seed(0)
np.random.seed(0)

In [56]:
# https://medium.com/@m.alzantot/deep-reinforceme
# nt-learning-demysitifed-episode-2-policy-iteration
# -value-iteration-and-q-978f9e89ddaa

def run_episode(env, policy, gamma = 1.0, render = False):
    """ Evaluates policy by using it to run an episode and finding its
    total reward.
    args:
    env: gym environment.
    policy: the policy to be used.
    gamma: discount factor.
    render: boolean to turn rendering on/off.
    returns:
    total reward: real value of the total reward recieved by agent under policy.
    """
    obs = env.reset()
    total_reward = 0
    step_idx = 0
    while True:
        if render:
            env.render()
        obs, reward, done , _ = env.step(int(policy[obs]))
        taxirow, taxicol, passidx, destidx = env.env.decode(env.env.s)
        total_reward += (gamma ** step_idx * reward)
        step_idx += 1
        if done:
            break
    return total_reward


def evaluate_policy(env, policy, gamma = 1.0,  n = 1000):
    """ Evaluates a policy by running it n times.
    returns:
    average total reward
    """
    scores = [
            run_episode(env, policy, gamma = gamma, render = False)
            for _ in range(n)]
    return np.mean(scores)

In [57]:
def value_iteration(env, gamma = 1, v = None):
    v = np.zeros(env.env.nS)
    
    for i in range(max_iterations):
        old_v = np.copy(v)
        for s in range(env.env.nS):
            q_sa = [sum([gamma * p * (r + old_v[s_]) for p, s_, r, _ in env.env.P[s][a]]) for a in range(env.env.nA)]
            v[s] = max(q_sa)
        if np.sum(np.fabs(v - old_v)) <= epsilon:
            print ('Value-iteration converged at iteration #%d.' %(i+1))
            break
    
    return v

start = timeit.default_timer()
values = value_iteration(env, gamma = 0.9)
froz_lake_vi = timeit.default_timer() - start

pd.Series(values, name="value").to_csv("value_iteration_taxi.csv", header=True)
froz_lake_vi

Value-iteration converged at iteration #306.


1.6660783950064797

In [58]:
def extract_policy(v, gamma = 1.0):
    """ Extract the policy given a value-function """
    policy = np.zeros(env.env.nS)
    for s in range(env.env.nS):
        q_sa = np.zeros(env.env.action_space.n)
        for a in range(env.env.action_space.n):
            for next_sr in env.env.P[s][a]:
                # next_sr is a tuple of (probability, next state, reward, done)
                p, s_, r, _ = next_sr
                q_sa[a] += (p * (r + gamma * v[s_]))
        policy[s] = np.argmax(q_sa)
    return policy

policy = extract_policy(values)
pd.Series(policy, name="policy").to_csv("vi_policy_taxi.csv", header=True)

In [59]:
evaluate_policy(env, policy)

8.445

# Policy Iteration

In [16]:
def compute_policy_v(env, policy, gamma=1):
    """ Iteratively evaluate the value-function under policy.
    Alternatively, we could formulate a set of linear equations in iterms of v[s] 
    and solve them to find the value function.
    """
    v = np.zeros(env.env.nS)
    eps = 1e-10
    while True:
        prev_v = np.copy(v)
        for s in range(env.env.nS):
            policy_a = policy[s]
            v[s] = sum([p * (r + gamma * prev_v[s_]) for p, s_, r, _ in env.env.P[s][policy_a]])
        if (np.sum((np.fabs(prev_v - v))) <= eps):
            # value converged
            break
    return v

def policy_iteration(env, gamma = 1):
    """ Policy-Iteration algorithm """
    policy = np.random.choice(env.env.nA, size=(env.env.nS))  # initialize a random policy
    max_iterations = 200000
    for i in range(max_iterations):
        old_policy_v = compute_policy_v(env, policy, gamma)
        new_policy = extract_policy(old_policy_v, gamma)
        if (np.all(policy == new_policy)):
            print ('Policy-Iteration converged at step %d.' %(i+1))
            break
        policy = new_policy
    return policy

In [17]:
start = timeit.default_timer()
pi_policy = policy_iteration(env, gamma = 0.9)
froz_lake_pi = timeit.default_timer() - start

pd.Series(pi_policy, name="policy").to_csv("pi_policy_taxi.csv", header=True)

Policy-Iteration converged at step 16.


In [18]:
froz_lake_pi

5.893024926997896

In [48]:
def q_learning(env, gamma = 1.0, alpha = 0.9):
    # np.abs(np.random.randn())
    first_reward = True
    Q = np.zeros((env.observation_space.n, env.action_space.n))
    eps = 0.99999999
    total_reward = 0
    for i in range(max_iterations):
        
        if (i + 1) % 10000 == 0:
            print((i + 1) / 10000, "% complete -- alpha: ", round(alpha, 2),
                  "-- epsilon: ", round(eps, 2), "-- reward:", round(total_reward, 2))
            total_reward = 0
            
        obs = env.reset()
        for t in range(t_max):
            if np.random.uniform(0, 1) < eps:
                action = np.random.choice(env.action_space.n)
            else:
                action = np.argmax(Q[obs])
            
            old_obs = obs
            obs, reward, done, _ = env.step(action)
                                    
            predict = Q[old_obs, action]
            target = reward + gamma * np.max(Q[obs])
            Q[old_obs, action] = predict + alpha * (target - predict)
            
            total_reward += reward
            
            if reward != 0 and first_reward:
                first_reward = False
                print("first reward at iteration ", i, ". reward: ", reward)
            
            if done:
                break
        
        alpha = alpha * .999999
        eps = eps * 0.999999
    
    return Q

def extract_q_policy(env, Q):
    policy = np.zeros(env.observation_space.n)
    for state in range(env.observation_space.n):
        policy[state] = np.argmax(Q[state])
        
    return policy

In [49]:
Q = q_learning(env)

first reward at iteration  0 . reward:  -10
1.0 % complete -- alpha:  0.89 -- epsilon:  0.99 -- reward: -7619145
2.0 % complete -- alpha:  0.88 -- epsilon:  0.98 -- reward: -7449195
3.0 % complete -- alpha:  0.87 -- epsilon:  0.97 -- reward: -7289447
4.0 % complete -- alpha:  0.86 -- epsilon:  0.96 -- reward: -7139762
5.0 % complete -- alpha:  0.86 -- epsilon:  0.95 -- reward: -6953705
6.0 % complete -- alpha:  0.85 -- epsilon:  0.94 -- reward: -6709949
7.0 % complete -- alpha:  0.84 -- epsilon:  0.93 -- reward: -6487196
8.0 % complete -- alpha:  0.83 -- epsilon:  0.92 -- reward: -6226355
9.0 % complete -- alpha:  0.82 -- epsilon:  0.91 -- reward: -6037428
10.0 % complete -- alpha:  0.81 -- epsilon:  0.9 -- reward: -5782274
11.0 % complete -- alpha:  0.81 -- epsilon:  0.9 -- reward: -5516212
12.0 % complete -- alpha:  0.8 -- epsilon:  0.89 -- reward: -5245071
13.0 % complete -- alpha:  0.79 -- epsilon:  0.88 -- reward: -4955737
14.0 % complete -- alpha:  0.78 -- epsilon:  0.87 -- rewar

119.0 % complete -- alpha:  0.27 -- epsilon:  0.3 -- reward: -143553
120.0 % complete -- alpha:  0.27 -- epsilon:  0.3 -- reward: -137786
121.0 % complete -- alpha:  0.27 -- epsilon:  0.3 -- reward: -135973
122.0 % complete -- alpha:  0.27 -- epsilon:  0.3 -- reward: -130053
123.0 % complete -- alpha:  0.26 -- epsilon:  0.29 -- reward: -126346
124.0 % complete -- alpha:  0.26 -- epsilon:  0.29 -- reward: -124835
125.0 % complete -- alpha:  0.26 -- epsilon:  0.29 -- reward: -120831
126.0 % complete -- alpha:  0.26 -- epsilon:  0.28 -- reward: -119051
127.0 % complete -- alpha:  0.25 -- epsilon:  0.28 -- reward: -114391
128.0 % complete -- alpha:  0.25 -- epsilon:  0.28 -- reward: -111368
129.0 % complete -- alpha:  0.25 -- epsilon:  0.28 -- reward: -109062
130.0 % complete -- alpha:  0.25 -- epsilon:  0.27 -- reward: -108729
131.0 % complete -- alpha:  0.24 -- epsilon:  0.27 -- reward: -100765
132.0 % complete -- alpha:  0.24 -- epsilon:  0.27 -- reward: -101489
133.0 % complete -- alph

239.0 % complete -- alpha:  0.08 -- epsilon:  0.09 -- reward: 37307
240.0 % complete -- alpha:  0.08 -- epsilon:  0.09 -- reward: 36187
241.0 % complete -- alpha:  0.08 -- epsilon:  0.09 -- reward: 35785
242.0 % complete -- alpha:  0.08 -- epsilon:  0.09 -- reward: 36845
243.0 % complete -- alpha:  0.08 -- epsilon:  0.09 -- reward: 37073
244.0 % complete -- alpha:  0.08 -- epsilon:  0.09 -- reward: 38126
245.0 % complete -- alpha:  0.08 -- epsilon:  0.09 -- reward: 38565
246.0 % complete -- alpha:  0.08 -- epsilon:  0.09 -- reward: 37738
247.0 % complete -- alpha:  0.08 -- epsilon:  0.08 -- reward: 39462
248.0 % complete -- alpha:  0.08 -- epsilon:  0.08 -- reward: 37521
249.0 % complete -- alpha:  0.07 -- epsilon:  0.08 -- reward: 40445
250.0 % complete -- alpha:  0.07 -- epsilon:  0.08 -- reward: 40692
251.0 % complete -- alpha:  0.07 -- epsilon:  0.08 -- reward: 40805
252.0 % complete -- alpha:  0.07 -- epsilon:  0.08 -- reward: 40856
253.0 % complete -- alpha:  0.07 -- epsilon:  0.

360.0 % complete -- alpha:  0.02 -- epsilon:  0.03 -- reward: 70466
361.0 % complete -- alpha:  0.02 -- epsilon:  0.03 -- reward: 70632
362.0 % complete -- alpha:  0.02 -- epsilon:  0.03 -- reward: 71599
363.0 % complete -- alpha:  0.02 -- epsilon:  0.03 -- reward: 71612
364.0 % complete -- alpha:  0.02 -- epsilon:  0.03 -- reward: 72611
365.0 % complete -- alpha:  0.02 -- epsilon:  0.03 -- reward: 71247
366.0 % complete -- alpha:  0.02 -- epsilon:  0.03 -- reward: 71472
367.0 % complete -- alpha:  0.02 -- epsilon:  0.03 -- reward: 71956
368.0 % complete -- alpha:  0.02 -- epsilon:  0.03 -- reward: 71755
369.0 % complete -- alpha:  0.02 -- epsilon:  0.02 -- reward: 72419
370.0 % complete -- alpha:  0.02 -- epsilon:  0.02 -- reward: 71708
371.0 % complete -- alpha:  0.02 -- epsilon:  0.02 -- reward: 72172
372.0 % complete -- alpha:  0.02 -- epsilon:  0.02 -- reward: 71948
373.0 % complete -- alpha:  0.02 -- epsilon:  0.02 -- reward: 72814
374.0 % complete -- alpha:  0.02 -- epsilon:  0.

481.0 % complete -- alpha:  0.01 -- epsilon:  0.01 -- reward: 80538
482.0 % complete -- alpha:  0.01 -- epsilon:  0.01 -- reward: 80512
483.0 % complete -- alpha:  0.01 -- epsilon:  0.01 -- reward: 80049
484.0 % complete -- alpha:  0.01 -- epsilon:  0.01 -- reward: 80504
485.0 % complete -- alpha:  0.01 -- epsilon:  0.01 -- reward: 80232
486.0 % complete -- alpha:  0.01 -- epsilon:  0.01 -- reward: 80646
487.0 % complete -- alpha:  0.01 -- epsilon:  0.01 -- reward: 80910
488.0 % complete -- alpha:  0.01 -- epsilon:  0.01 -- reward: 80881
489.0 % complete -- alpha:  0.01 -- epsilon:  0.01 -- reward: 80414
490.0 % complete -- alpha:  0.01 -- epsilon:  0.01 -- reward: 81007
491.0 % complete -- alpha:  0.01 -- epsilon:  0.01 -- reward: 80956
492.0 % complete -- alpha:  0.01 -- epsilon:  0.01 -- reward: 80744
493.0 % complete -- alpha:  0.01 -- epsilon:  0.01 -- reward: 80727
494.0 % complete -- alpha:  0.01 -- epsilon:  0.01 -- reward: 80668
495.0 % complete -- alpha:  0.01 -- epsilon:  0.

604.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 83789
605.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 82995
606.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 82948
607.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 83217
608.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 83470
609.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 83178
610.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 83125
611.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 82888
612.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 83017
613.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 83551
614.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 83954
615.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 82857
616.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 83352
617.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 83229
618.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 83152
619.0 % co

729.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 83887
730.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 83900
731.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 84333
732.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 83955
733.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 84149
734.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 83869
735.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 84052
736.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 84214
737.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 84095
738.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 83913
739.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 84044
740.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 83830
741.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 83693
742.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 83722
743.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 83824
744.0 % co

854.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 83883
855.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 84470
856.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 84311
857.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 84365
858.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 84326
859.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 84058
860.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 84018
861.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 84049
862.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 84256
863.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 84098
864.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 84237
865.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 83834
866.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 84182
867.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 83642
868.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 84748
869.0 % co

979.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 84080
980.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 84325
981.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 84255
982.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 84149
983.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 84325
984.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 84091
985.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 84446
986.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 84276
987.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 84573
988.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 84745
989.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 84075
990.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 84097
991.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 84130
992.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 84198
993.0 % complete -- alpha:  0.0 -- epsilon:  0.0 -- reward: 83779
994.0 % co

In [50]:
new_policy = extract_q_policy(env, Q)
pd.Series(new_policy, name="policy").to_csv("q_policy_taxi.csv", header=True)

In [51]:
evaluate_policy(env, new_policy)

8.367

In [52]:
Q

array([[      0.        ,       0.        ,       0.        ,
              0.        ,       0.        ,       0.        ],
       [4023836.22611431, 4020949.25634375, 4024207.77503775,
        4023757.80595487, 4072720.55457661, 4020833.03067751],
       [3987731.45568901, 3988463.5030436 , 3984651.34915117,
        3985780.67823743, 4037366.12047424, 3987763.83862317],
       ...,
       [3617751.09177682, 3608865.07293747, 3625779.33155269,
        4020915.02065667, 3589140.58446721, 3614230.21265619],
       [3649555.19731745, 3563909.0666153 , 3595261.24981797,
        3984747.00994965, 3580921.41137945, 3632006.00499502],
       [3769975.31674422, 3735349.52805652, 3675817.93745849,
        4019575.75572213, 3753098.13475314, 3752639.94063378]])

In [53]:
new_policy

array([0., 4., 4., 4., 2., 0., 2., 2., 0., 0., 0., 0., 0., 0., 2., 0., 5.,
       0., 0., 0., 0., 3., 3., 3., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 3., 0., 0., 0., 0., 0., 0., 0., 2., 0., 2., 2., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 2., 0., 0., 0., 3., 0., 3., 2., 0., 2., 2.,
       0., 0., 0., 3., 0., 0., 0., 0., 2., 2., 0., 2., 0., 3., 3., 3., 4.,
       0., 4., 4., 0., 0., 0., 0., 0., 3., 0., 0., 0., 5., 0., 0., 0., 1.,
       1., 1., 2., 0., 2., 2., 0., 0., 0., 0., 0., 0., 2., 0., 1., 2., 0.,
       0., 0., 3., 3., 3., 2., 0., 2., 2., 0., 0., 0., 0., 2., 2., 0., 0.,
       3., 2., 0., 0., 0., 3., 3., 3., 2., 0., 2., 2., 0., 0., 0., 0., 0.,
       0., 0., 0., 3., 2., 0., 2., 0., 3., 3., 0., 1., 0., 1., 1., 0., 3.,
       0., 0., 0., 0., 0., 0., 3., 1., 0., 0., 0., 3., 3., 3., 1., 0., 1.,
       1., 0., 3., 0., 0., 3., 3., 0., 0., 3., 3., 0., 3., 0., 1., 1., 1.,
       1., 0., 2., 2., 0., 0., 0., 0., 2., 2., 2., 0., 1., 1., 0., 2., 0.,
       1., 1., 1., 2., 0.