# Iterative Policy Evaluation

### Example 4.1

In [1]:
def p(s_n, r, s, a):
    """Return the transition probability p(s_n, r | s, a) on the 4x4 grid.

    Cells are numbered row-major, so cell ``s`` sits in row ``s // 4`` and
    column ``s % 4``. Every transition carries reward -1. An action moves
    the agent one cell in its direction; an action that would leave the
    grid keeps the agent in place.

    Args:
        s_n: candidate successor cell.
        r: candidate reward.
        s: current cell.
        a: one of 'left', 'right', 'up', 'down'.

    Returns:
        1 if taking ``a`` in ``s`` yields exactly ``(s_n, r)``, otherwise 0.
    """
    if r != -1:
        return 0
    if a not in ('left', 'right', 'up', 'down'):
        return 0

    row, col = divmod(s, 4)
    # For the chosen action: may it move, is it blocked by the edge, and
    # by how much does the cell index change when it moves.
    if a == 'right':
        can_move, at_edge, step = col < 3, col == 3, 1
    elif a == 'left':
        can_move, at_edge, step = col > 0, col == 0, -1
    elif a == 'up':
        can_move, at_edge, step = row > 0, row == 0, -4
    else:
        can_move, at_edge, step = row < 3, row == 3, 4

    if can_move:
        return int(s_n == s + step)
    if at_edge:
        return int(s_n == s)
    return 0


def π(a, s):
    """Return the probability of choosing action ``a`` in state ``s``.

    The policy ignores both arguments and always returns one quarter.
    """
    one_quarter = 1 / 4
    return one_quarter


# One value estimate per cell of the 4x4 grid (cell index = 4 * row + column).
V = 16 * [0]
# Sweeping stops once the largest change within a sweep falls below θ.
θ = 0.001

nonterminal = [*range(1, 15)]
terminal = [0, 15]  # never updated below, so their value stays 0
actions = ['left', 'right', 'up', 'down']
rewards = [-1]
γ = 1

Δ = float('inf')
while Δ >= θ:
    Δ = 0
    for s in nonterminal:
        v = V[s]
        # In-place backup: states later in the sweep already see this new V[s].
        V[s] = sum(
            π(a, s) * p(s_n, r, s, a) * (r + γ * V[s_n])
            for a in actions
            for s_n in nonterminal + terminal
            for r in rewards
        )
        Δ = max(Δ, abs(v - V[s]))

# Print the value estimates as a 4x4 table.
row_separator = '\n------·------·------·------\n'
print(row_separator.join(
    '|'.join(f'{V[4 * row + col]:6.2f}' for col in range(4))
    for row in range(4)
))

  0.00|-13.99|-19.99|-21.99
------·------·------·------
-13.99|-17.99|-19.99|-19.99
------·------·------·------
-19.99|-19.99|-17.99|-13.99
------·------·------·------
-21.99|-19.99|-13.99|  0.00


# Policy Iteration

In [77]:
# Maze layout: 'e' is the terminal cell, '.' a free cell, 'x' a wall.
map_str = r"""
e..x...
..xxx.x
x..x...
.....x.
x......
"""

# Every map cell as ((x, y), symbol), scanned row by row; x is the column
# and y the row index counted from the top line of the map.
_grid_cells = [
    ((x, y), symbol)
    for y, row in enumerate(map_str.split())
    for x, symbol in enumerate(row)
]
terminal_states = [cell for cell, symbol in _grid_cells if symbol == 'e']
nonterminal_states = [cell for cell, symbol in _grid_cells if symbol == '.']
states = terminal_states + nonterminal_states

# Action name -> (dx, dy) offset. y grows towards later map rows, so 'up'
# (dy = +1) moves to the next row of map_str and 'down' to the previous one.
actions = dict(left=(-1, 0), right=(1, 0), up=(0, 1), down=(0, -1))
rewards = [-1]


def p(new_state, reward, state, action):
    """Return the transition probability p(new_state, reward | state, action).

    The dynamics are deterministic: the action's offset from ``actions`` is
    added to ``state``; if the resulting cell is not one of ``states`` the
    agent stays where it is.

    Args:
        new_state: candidate successor cell as ``(x, y)``.
        reward: candidate reward.
        state: current cell as ``(x, y)``.
        action: key of the module-level ``actions`` mapping.

    Returns:
        1 if ``new_state`` is where ``action`` leads from ``state`` and
        ``reward`` is a known reward, otherwise 0 (also for unknown actions).
    """
    if action not in actions or reward not in rewards:
        return 0

    dx, dy = actions[action]
    target = (state[0] + dx, state[1] + dy)
    # A move onto a cell outside `states` leaves the agent in place.
    landing = target if target in states else state
    return 1 if new_state == landing else 0

In [150]:
# Start from the uniform random policy: each of the four actions has
# probability 0.25 in every state, and every state value starts at 0.
policy = {state: dict.fromkeys(actions, 0.25) for state in states}
initial_value_function = dict.fromkeys(states, 0)


def evaluate_policy(policy, initial_value_function, error_tolerance=0.5):
    """Estimate the state values of ``policy`` by repeated in-place sweeps.

    Each sweep replaces the value of every nonterminal state with the
    expected ``reward + value of successor`` under ``policy`` and the
    dynamics ``p`` (no discounting is applied). Sweeping stops once the
    largest change within a sweep is below ``error_tolerance``.

    Args:
        policy: mapping state -> {action: probability}.
        initial_value_function: mapping state -> starting value; it is
            copied, not modified.
        error_tolerance: stop threshold for the largest per-sweep change.

    Returns:
        A new dict mapping every state to its estimated value.
    """
    values = initial_value_function.copy()
    largest_change = float('inf')

    while largest_change >= error_tolerance:
        largest_change = 0
        for state in nonterminal_states:
            backup = 0
            for successor in states:
                for action in actions:
                    action_prob = policy[state][action]
                    for reward in rewards:
                        transition_prob = p(successor, reward, state, action)
                        backup += action_prob * transition_prob * (reward + values[successor])

            largest_change = max(largest_change, abs(backup - values[state]))
            # Written back immediately so later states in this sweep use it.
            values[state] = backup

    return values


def improve_policy(policy, initial_value_function):
    """Alternate policy evaluation and greedy improvement until stable.

    Each round evaluates the current policy (starting from the previous
    round's values), then makes every nonterminal state greedy with respect
    to a one-step lookahead. Actions that tie for the best value share the
    probability equally. The loop ends after a round that changes no state.

    Args:
        policy: mapping state -> {action: probability}; the per-state dicts
            are copied, so the argument is not modified.
        initial_value_function: mapping state -> value used to start the
            first evaluation.

    Returns:
        ``(policy, value_function)``: the final policy and the values from
        its last evaluation.
    """
    def lookahead(state, action, values):
        # Expected reward plus successor value when taking `action` in `state`.
        total = 0
        for new_state in states:
            for reward in rewards:
                total += p(new_state, reward, state, action) * (reward + values[new_state])
        return total

    value_function = initial_value_function
    policy = {state: policy[state].copy() for state in states}
    stable = False
    while not stable:
        stable = True
        value_function = evaluate_policy(policy, initial_value_function=value_function)
        for state in nonterminal_states:
            greedy_actions = []
            greedy_value = float('-inf')
            for action in actions:
                candidate = lookahead(state, action, value_function)
                if candidate > greedy_value:
                    greedy_actions, greedy_value = [action], candidate
                elif candidate == greedy_value:
                    greedy_actions.append(action)

            share = 1 / len(greedy_actions) if greedy_actions else 0
            greedy_policy = {action: share if action in greedy_actions else 0 for action in actions}
            if greedy_policy != policy[state]:
                stable = False
                policy[state] = greedy_policy
    return policy, value_function


def print_2d(cells, layout=None):
    """Print per-cell values as a bordered text grid shaped like the map.

    Wall cells (``'x'`` in the map) are drawn as solid blocks; every other
    cell shows ``str(cells[x, y])`` centred in a column as wide as the
    longest value.

    Args:
        cells: mapping ``(x, y) -> value``; must contain every non-wall
            cell of the map.
        layout: map text with one whitespace-separated row per line. When
            ``None`` (the default) the module-level ``map_str`` is used.

    Raises:
        KeyError: if a non-wall cell of the map is missing from ``cells``.
    """
    rows = (map_str if layout is None else layout).split()
    if not rows:
        # Nothing to draw for an empty map.
        return

    max_len = max((len(str(value)) for value in cells.values()), default=0)
    line_strs = []
    for y, line in enumerate(rows):
        line_cells = []
        for x, symbol in enumerate(line):
            cell_value = max_len * '█' if symbol == 'x' else str(cells[x, y])
            line_cells.append(f'{cell_value:^{max_len}}')
        line_strs.append('|' + '|'.join(line_cells) + '|')
        line_strs.append('·' + '·'.join(len(line) * [max_len * '—']) + '·')
    # Reuse the first row's bottom border as the top border of the grid.
    line_strs.insert(0, line_strs[1][:])
    print('\n'.join(line_strs))

In [151]:
# Show the evaluated values of the starting policy, rounded to integers.
starting_policy_values = evaluate_policy(policy, initial_value_function=initial_value_function)
print_2d({state: round(value) for state, value in starting_policy_values.items()})

·————·————·————·————·————·————·————·
| 0  |-40 |-43 |████|-269|-267|-270|
·————·————·————·————·————·————·————·
|-38 |-72 |████|████|████|-260|████|
·————·————·————·————·————·————·————·
|████|-134|-159|████|-238|-250|-250|
·————·————·————·————·————·————·————·
|-170|-167|-181|-204|-222|████|-248|
·————·————·————·————·————·————·————·
|████|-180|-191|-207|-222|-233|-242|
·————·————·————·————·————·————·————·


In [152]:
# Values of the starting policy on its own.
value_function = evaluate_policy(policy, initial_value_function)
# Policy iteration from the same starting policy and starting values.
new_policy, new_value_function = improve_policy(policy, initial_value_function)

In [164]:
# Arrow drawn for each action. y grows towards later map rows, so the
# action named 'up' (dy = +1) moves down the printed grid and is drawn
# as '↓', and 'down' is drawn as '↑'.
sym = {
    'left': ' ← ',
    'right': ' → ',
    'up': ' ↓ ',
    'down': ' ↑ '
}
# One arrow per state: the first action with non-zero probability.
# Terminal states are left blank; they are taken from `terminal_states`
# rather than a fixed coordinate so the display follows the map.
print_2d({
    state: ' ' if state in terminal_states
    else [sym[action] for action in actions if action_probs[action] != 0][0]
    for state, action_probs in new_policy.items()
})
print_2d({state: round(value) for state, value in new_value_function.items()})

·———·———·———·———·———·———·———·
|   | ← | ← |███| → | ↓ | ← |
·———·———·———·———·———·———·———·
| ↑ | ← |███|███|███| ↓ |███|
·———·———·———·———·———·———·———·
|███| ↑ | ← |███| ↓ | ← | ← |
·———·———·———·———·———·———·———·
| → | ↑ | ← | ← | ← |███| ↓ |
·———·———·———·———·———·———·———·
|███| ↑ | ← | ← | ← | ← | ← |
·———·———·———·———·———·———·———·
·———·———·———·———·———·———·———·
| 0 |-1 |-2 |███|-12|-11|-12|
·———·———·———·———·———·———·———·
|-1 |-2 |███|███|███|-10|███|
·———·———·———·———·———·———·———·
|███|-3 |-4 |███|-8 |-9 |-10|
·———·———·———·———·———·———·———·
|-5 |-4 |-5 |-6 |-7 |███|-11|
·———·———·———·———·———·———·———·
|███|-5 |-6 |-7 |-8 |-9 |-10|
·———·———·———·———·———·———·———·
