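Consider the following one-dimensional "gridworld":
You are on a route consisting of $10$ states, numbered $0$ to $9$ from left to right.
The leftmost state is a terminal state with a reward of $+10$, and the rightmost state is also a terminal state, with a reward of $-5$; these rewards are received on the transition into the terminal state, and every other transition yields a reward of $0$.
In each non-terminal state you can move left or right.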

In [1]:
import numpy as np
from fractions import Fraction
import itertools

In [2]:
# Deterministic 1-D gridworld: states 0..9 on a line, 0 and 9 terminal.
# range() yields plain Python ints; np.arange would give numpy.int64,
# whose repr changes across NumPy versions and leaks into printed output.
states = set(range(10))
states_terminal = {0, 9}

# Both moves are available in every non-terminal state; terminal states
# have no actions. Use set() rather than {} (an empty *dict*) so every
# value of `actions` has the same type.
actions = {s: {'left', 'right'} for s in states}
for s in states_terminal:
    actions[s] = set()

# Every reward the environment can emit: 0 for an ordinary step,
# +10 on entering state 0, -5 on entering state 9.
rewards = {0, 10, -5}

# All (s', r, s, a) quadruples with p(s', r | s, a) = 1. The dynamics are
# deterministic: from a non-terminal state 1..8, 'left' moves to s - 1 and
# 'right' to s + 1; entering state 0 pays +10, entering state 9 pays -5,
# and every other move pays 0. Any quadruple not listed has probability 0.
_TRANSITIONS = frozenset(
    (origin + step, {0: 10, 9: -5}.get(origin + step, 0), origin, move)
    for origin in range(1, 9)
    for move, step in (('left', -1), ('right', +1))
)


def p(s_prime, r, s, a):
    """Return the MDP dynamics p(s', r | s, a) of the gridworld.

    Because every transition is deterministic, the probability is either
    1 (the quadruple is a possible transition) or 0 (it is not).
    """
    return 1 if (s_prime, r, s, a) in _TRANSITIONS else 0

Write an implementation of Dynamic Programming for estimating the values of the above states under the equiprobable random policy.

In [3]:
def initialization():
    """Draw a random starting value function and a random stochastic policy.

    Returns
    -------
    V : dict
        ``V[s]`` is drawn from ``np.random.random()`` for every state, then
        overwritten with 0 for terminal states.
    pi : dict
        ``pi[s][a]`` are random non-negative weights over ``actions[s]``,
        normalized to sum to 1; states without actions map to ``{}``.
    """
    # One draw per state (terminal states included, so the global RNG is
    # consumed in the same order), then terminal values are pinned to 0.
    V = {}
    for state in states:
        V[state] = np.random.random()
    for state in states_terminal:
        V[state] = 0

    pi = {}
    for state in states:
        weights = {action: np.random.random() for action in actions[state]}
        total = sum(weights.values())
        pi[state] = {action: weight / total for action, weight in weights.items()}

    return V, pi

In [4]:
def iterative_policy_evaluation(V = None, pi = None, gamma = 1, theta = 1e-6, debug = False):
    """Estimate v_pi by in-place iterative policy evaluation.

    Repeatedly sweeps over all states, replacing V[s] with the Bellman
    expectation backup

        V(s) <- sum_a pi(a|s) * sum_{s', r} p(s', r | s, a) * (r + gamma * V(s')),

    until the largest change within one sweep drops below ``theta``.

    Parameters
    ----------
    V : dict, optional
        Initial value estimates keyed by state; updated in place. If None,
        a fresh random estimate is obtained from ``initialization()``.
    pi : dict
        Policy to evaluate: ``pi[s][a]`` is the probability of action ``a``
        in state ``s``. Required.
    gamma : float
        Discount factor.
    theta : float
        Stop once the largest per-sweep change in any V[s] is below this.
    debug : bool
        If True, print the sweep index and its largest change after each sweep.

    Returns
    -------
    dict
        The value estimates (the same object as ``V`` when one was passed).
    """
    if V is None:
        V, _ = initialization()

    assert pi is not None, 'a policy pi must be provided'

    # np.inf (np.infty was removed in NumPy 2.0) guarantees at least one sweep.
    Delta = np.inf
    run = 0

    while not Delta < theta:

        Delta = 0

        for s in states:

            v = V[s]

            # In-place update: states later in this sweep already see the new V[s].
            V[s] = sum(
                pi[s][a] * sum(
                    p(s_prime, r, s, a) * (r + gamma * V[s_prime])
                    for s_prime, r in itertools.product(states, rewards)
                )
                for a in actions[s]
            )

            Delta = max(Delta, abs(v - V[s]))

        if debug:
            print(f'sweep = {run}, Delta = {Delta}')
            run += 1

    return V

In [5]:
# Equiprobable random policy: each of the two moves gets probability 0.5
# (terminal states have no actions and therefore an empty distribution).
pi = {state: dict.fromkeys(actions[state], 0.5) for state in states}

V = iterative_policy_evaluation(pi = pi)

In [34]:
# Show each value as the simplest nearby fraction (denominator <= 10 000),
# which hides the small residual error left by the iterative evaluation.
for state, value in V.items():
    as_fraction = Fraction(value).limit_denominator(max_denominator = 10_000)
    print(f'V({state}) =', str(as_fraction))

V(0) = 0
V(1) = 25/3
V(2) = 20/3
V(3) = 5
V(4) = 10/3
V(5) = 5/3
V(6) = 0
V(7) = -5/3
V(8) = -10/3
V(9) = 0


Then use policy improvement to find the optimal policy.

In [7]:
def argsmax(A):
    """Return the set of *all* keys/indices at which ``A`` attains its maximum.

    Unlike ``max(A, key=...)``, ties are preserved: every maximizing key is
    returned, so callers can spread probability over equally good choices.

    Parameters
    ----------
    A : dict, list or tuple
        For a mapping, its maximizing keys are returned; for a list or
        tuple, its maximizing indices.

    Returns
    -------
    set or None
        The maximizing keys/indices (an empty set for empty input), or
        ``None`` if ``A`` is of an unsupported type.
    """
    if isinstance(A, (list, tuple)):
        A = dict(enumerate(A))

    if not isinstance(A, dict):
        return None

    # max() of an empty sequence raises, so handle empty input explicitly.
    if not A:
        return set()

    # Compute the maximum once instead of re-evaluating it for every item.
    best = max(A.values())
    return {i for i, a in A.items() if a == best}

In [8]:
def policy_improvement(V, pi, gamma = 1, debug = False):
    """Make ``pi`` greedy with respect to ``V``, mutating ``pi`` in place.

    For every non-terminal state the action values

        q(s, a) = sum_{s', r} p(s', r | s, a) * (r + gamma * V(s'))

    are computed. The probability mass is split evenly over all actions
    that attain the maximum q, and every other action gets 0.

    Parameters
    ----------
    V : dict
        Current value estimates keyed by state.
    pi : dict
        Policy to improve; ``pi[s]`` maps action -> probability.
    gamma : float
        Discount factor.
    debug : bool
        Accepted for signature symmetry; not used.

    Returns
    -------
    (dict, bool)
        The same policy object, and True when no state's action
        probabilities changed (exact equality).
    """
    policy_stable = True

    non_terminal = (state for state in states if state not in states_terminal)
    for state in non_terminal:
        q = {}
        for action in actions[state]:
            q[action] = sum([
                p(s_next, reward, state, action) * (reward + gamma * V[s_next])
                for s_next, reward in itertools.product(states, rewards)
            ])

        best_actions = argsmax(q)
        improved = {
            action: (1 / len(best_actions) if action in best_actions else 0)
            for action in actions[state]
        }

        if improved != pi[state]:
            policy_stable = False
        pi[state] = improved

    return pi, policy_stable

In [18]:
def policy_iteration(gamma = 1, theta = 1e-6, debug = False):
    """Search for an optimal policy by policy iteration.

    Starting from ``initialization()`` (random values, random stochastic
    policy), alternate policy evaluation and greedy policy improvement
    until the improvement step leaves the policy unchanged.

    Parameters
    ----------
    gamma : float
        Discount factor.
    theta : float
        Convergence threshold forwarded to each policy-evaluation call.
    debug : bool
        If True, print V and pi after every iteration.

    Returns
    -------
    V : dict
        Value estimates of the final policy.
    pi : dict
        The final greedy policy (ties split uniformly).
    """
    V, pi = initialization()

    policy_stable = False
    run = 0

    # NOTE(review): stability is judged by exact float equality of the action
    # probabilities; if near-tied q-values flip between iterations this loop
    # may not terminate (cf. Sutton & Barto, Exercise 4.4).
    while not policy_stable:

        if debug:
            print(f'run = {run} ...')
            run += 1

        # Forward the caller's convergence threshold instead of a fixed 1e-6.
        V = iterative_policy_evaluation(V = V, pi = pi, gamma = gamma, theta = theta)
        if debug:
            print(f'V = {V}')

        pi, policy_stable = policy_improvement(V, pi, gamma = gamma)
        if debug:
            print(f'pi = {pi}')
            print(f'policy_stable = {policy_stable}')
            print()

    return V, pi

In [43]:
# `display` is an IPython/Jupyter builtin; fall back to pprint so this
# cell also runs as a plain Python script.
try:
    display
except NameError:
    from pprint import pprint as display

# Run policy iteration for several discount factors and show the resulting
# optimal values and policies.
for gamma in [0, 0.1, 0.5, 0.9, 1]:

    print('#', '-'*64, '#', '\n')
    print(f'gamma = {gamma} ...', '\n')

    V_ast, pi_ast = policy_iteration(gamma = gamma)

    print('V_ast =')
    display(V_ast)

    print('pi_ast =')
    display(pi_ast)

print('#', '-'*64, '#')

# ---------------------------------------------------------------- # 

gamma = 0 ... 

V_ast =


{0: 0, 1: 10.0, 2: 0.0, 3: 0.0, 4: 0.0, 5: 0.0, 6: 0.0, 7: 0.0, 8: 0.0, 9: 0}

pi_ast =


{0: {},
 1: {'right': 0, 'left': 1.0},
 2: {'right': 0.5, 'left': 0.5},
 3: {'right': 0.5, 'left': 0.5},
 4: {'right': 0.5, 'left': 0.5},
 5: {'right': 0.5, 'left': 0.5},
 6: {'right': 0.5, 'left': 0.5},
 7: {'right': 0.5, 'left': 0.5},
 8: {'right': 0, 'left': 1.0},
 9: {}}

# ---------------------------------------------------------------- # 

gamma = 0.1 ... 

V_ast =


{0: 0,
 1: 10.0,
 2: 1.0,
 3: 0.1,
 4: 0.010000000000000002,
 5: 0.0010000000000000002,
 6: 0.00010000000000000003,
 7: 1.0000000000000004e-05,
 8: 1.0000000000000004e-06,
 9: 0}

pi_ast =


{0: {},
 1: {'right': 0, 'left': 1.0},
 2: {'right': 0, 'left': 1.0},
 3: {'right': 0, 'left': 1.0},
 4: {'right': 0, 'left': 1.0},
 5: {'right': 0, 'left': 1.0},
 6: {'right': 0, 'left': 1.0},
 7: {'right': 0, 'left': 1.0},
 8: {'right': 0, 'left': 1.0},
 9: {}}

# ---------------------------------------------------------------- # 

gamma = 0.5 ... 

V_ast =


{0: 0,
 1: 10.0,
 2: 5.0,
 3: 2.5,
 4: 1.25,
 5: 0.625,
 6: 0.3125,
 7: 0.15625,
 8: 0.078125,
 9: 0}

pi_ast =


{0: {},
 1: {'right': 0, 'left': 1.0},
 2: {'right': 0, 'left': 1.0},
 3: {'right': 0, 'left': 1.0},
 4: {'right': 0, 'left': 1.0},
 5: {'right': 0, 'left': 1.0},
 6: {'right': 0, 'left': 1.0},
 7: {'right': 0, 'left': 1.0},
 8: {'right': 0, 'left': 1.0},
 9: {}}

# ---------------------------------------------------------------- # 

gamma = 0.9 ... 

V_ast =


{0: 0,
 1: 10.0,
 2: 9.0,
 3: 8.1,
 4: 7.29,
 5: 6.561,
 6: 5.9049000000000005,
 7: 5.3144100000000005,
 8: 4.7829690000000005,
 9: 0}

pi_ast =


{0: {},
 1: {'right': 0, 'left': 1.0},
 2: {'right': 0, 'left': 1.0},
 3: {'right': 0, 'left': 1.0},
 4: {'right': 0, 'left': 1.0},
 5: {'right': 0, 'left': 1.0},
 6: {'right': 0, 'left': 1.0},
 7: {'right': 0, 'left': 1.0},
 8: {'right': 0, 'left': 1.0},
 9: {}}

# ---------------------------------------------------------------- # 

gamma = 1 ... 

V_ast =


{0: 0,
 1: 10.0,
 2: 10.0,
 3: 10.0,
 4: 10.0,
 5: 10.0,
 6: 10.0,
 7: 10.0,
 8: 10.0,
 9: 0}

pi_ast =


{0: {},
 1: {'right': 0.5, 'left': 0.5},
 2: {'right': 0.5, 'left': 0.5},
 3: {'right': 0.5, 'left': 0.5},
 4: {'right': 0.5, 'left': 0.5},
 5: {'right': 0.5, 'left': 0.5},
 6: {'right': 0.5, 'left': 0.5},
 7: {'right': 0.5, 'left': 0.5},
 8: {'right': 0, 'left': 1.0},
 9: {}}

# ---------------------------------------------------------------- #
