In [1]:
%load_ext autoreload
%autoreload 2

# Planning

## Small Grid World
Below is the definition of a small grid world as described in Silver's lecture.
  * Undiscounted episodic MDP ($\gamma = 1$)
  * 4x4 grid
  * 2 terminal states: (0, 0) and (3, 3)
  * Actions leading out of the grid leave the state unchanged
  * Reward is -1 until the terminal state is reached
  * Agent follows uniform random policy

$$
\pi(n \vert \cdot) = \pi(e \vert \cdot) = \pi(s \vert \cdot) = \pi(w \vert \cdot) = \frac{1}{4}
$$

## State-Action Value Iteration
In the `state_values` notebook, I computed everything as a function of state values, i.e., other state values as well as state-action values were computed as function of state values. In this notebook, I'll compute everything as a function of state-action values.

$$
\begin{align}
q(s, a) &= r(s, a) + \gamma \mathbb E_{\pi, P}\left[ q(s', a') \right] \\
&= r(s, a) + \gamma \sum_{s' \in S} P(s' \vert s, a) \sum_{a' \in A} \pi(a' \vert s') \; q(s', a')
\end{align}
$$

I've never seen this anywhere, but I am guessing it will be possible to calculate q-values using dynamic programming as well.

$$
q_{k+1}(s, a) = r(s, a) + \gamma \sum_{s' \in S} P(s' \vert s, a) \sum_{a' \in A} \pi(a' \vert s') \; q_k(s', a')
$$

In [24]:
from collections import defaultdict
from small_grid_world import SmallGridWorld, State, Action, Policy
import numpy as np
from itertools import product

QTable = dict[tuple[State, Action], float]

MAX_ITERS = 1000


def has_converged(qvals: QTable, qvals_next: QTable) -> bool:
    if all(v == 0 for v in qvals.values()):
        return False
    if all(v == 0 for v in qvals_next.values()):
        return False
    return np.allclose([v for v in qvals.values()], [v for v in qvals_next.values()])


def calc_qvals(mdp: SmallGridWorld, pi: Policy) -> QTable:
    qvals: QTable = defaultdict(float)
    qvals_next: QTable = defaultdict(float)
    r = mdp.reward
    p = mdp.prob
    γ = mdp.gamma
    i = 0
    while i < MAX_ITERS and not has_converged(qvals, qvals_next):
        for s, a in product(mdp.states(), mdp.actions()):
            qvals_next[(s, a)] = r(s, a) + γ * sum(
                p(s_, given=(s, a))
                * sum(pi(a_, given=s_) * qvals[(s_, a_)] for a_ in mdp.actions())
                for s_ in mdp.states()
            )
        qvals = qvals_next
        qvals_next = defaultdict(float)
        i += 1
    for s, a in qvals.keys():
        qvals[(s, a)] = round(qvals[(s, a)])
    return qvals

In [25]:
def generate_uniform_random_policy(mdp):
    def policy(a, given):
        return 0 if mdp.is_terminal(given) else 0.25

    return policy

In [31]:
mdp = SmallGridWorld()
qvals = calc_qvals(mdp, generate_uniform_random_policy(mdp))
for s, a in product(mdp.states(), mdp.actions()):
    if qvals[(s, a)] != 0:
        print(f"Q({s}, {a}) = {qvals[(s, a)]}")

Q((0, 1), ↑) = -15
Q((0, 1), ↓) = -19
Q((0, 1), ←) = -1
Q((0, 1), →) = -21
Q((0, 2), ↑) = -21
Q((0, 2), ↓) = -21
Q((0, 2), ←) = -15
Q((0, 2), →) = -23
Q((0, 3), ↑) = -23
Q((0, 3), ↓) = -21
Q((0, 3), ←) = -21
Q((0, 3), →) = -23
Q((1, 0), ↑) = -1
Q((1, 0), ↓) = -21
Q((1, 0), ←) = -15
Q((1, 0), →) = -19
Q((1, 1), ↑) = -15
Q((1, 1), ↓) = -21
Q((1, 1), ←) = -15
Q((1, 1), →) = -21
Q((1, 2), ↑) = -21
Q((1, 2), ↓) = -19
Q((1, 2), ←) = -19
Q((1, 2), →) = -21
Q((1, 3), ↑) = -23
Q((1, 3), ↓) = -15
Q((1, 3), ←) = -21
Q((1, 3), →) = -21
Q((2, 0), ↑) = -15
Q((2, 0), ↓) = -23
Q((2, 0), ←) = -21
Q((2, 0), →) = -21
Q((2, 1), ↑) = -19
Q((2, 1), ↓) = -21
Q((2, 1), ←) = -21
Q((2, 1), →) = -19
Q((2, 2), ↑) = -21
Q((2, 2), ↓) = -15
Q((2, 2), ←) = -21
Q((2, 2), →) = -15
Q((2, 3), ↑) = -21
Q((2, 3), ↓) = -1
Q((2, 3), ←) = -19
Q((2, 3), →) = -15
Q((3, 0), ↑) = -21
Q((3, 0), ↓) = -23
Q((3, 0), ←) = -23
Q((3, 0), →) = -21
Q((3, 1), ↑) = -21
Q((3, 1), ↓) = -21
Q((3, 1), ←) = -23
Q((3, 1), →) = -15
Q((3, 2), ↑) = 

In [28]:
def qval(mdp, v, s, a):
    r = mdp.reward
    p = mdp.prob
    γ = mdp.gamma

    return r(s, a) + γ * sum(p(s_, given=(s, a)) * v[*s_] for s_ in mdp.states())

In [None]:
vals = np.array(
    [
        [0.0, -14.0, -20.0, -22.0],
        [-14.0, -18.0, -20.0, -20.0],
        [-20.0, -20.0, -18.0, -14.0],
        [-22.0, -20.0, -14.0, 0.0],
    ]
)

for s, a in product(mdp.states(), mdp.actions()):
    expected_qval = qval(mdp, vals, s, a)
    actual_qval = qvals[(s, a)]
    assert (
        expected_qval == actual_qval
    ), f"Expected {expected_qval} but got {actual_qval} for Q({s}, {a})"

## Optimal State-Action Values

$$
\begin{align}
q_*(s, a) &= r(s, a) + \gamma \mathbb E_P \left[ \underset{a'}{max}(q_*(s', a')) \right] \\
&= r(s, a) + \gamma \sum_{s' \in S} P\left(s' \vert s, a\right) \; \underset{a'}{max}(q_*(s', a')) \\
\end{align}
$$

Applying dynamic programming -

$$
q^*_{k+1} = r(s, a) + \gamma \sum_{s' \in S} P(s' \vert s, a) \; \underset{a'}{max}(q^*_k(s', a'))
$$

In [33]:
def optimal_qvals(mdp: SmallGridWorld) -> QTable:
    qvals: QTable = defaultdict(float)
    qvals_next: QTable = defaultdict(float)
    r = mdp.reward
    p = mdp.prob
    γ = mdp.gamma
    i = 0
    while i < MAX_ITERS and not has_converged(qvals, qvals_next):
        for s, a in product(mdp.states(), mdp.actions()):
            qvals_next[(s, a)] = r(s, a) + γ * sum(
                p(s_, given=(s, a)) * max(qvals[(s_, a_)] for a_ in mdp.actions())
                for s_ in mdp.states()
            )
        qvals = qvals_next
        qvals_next = defaultdict(float)
        i += 1
    for s, a in qvals.keys():
        qvals[(s, a)] = round(qvals[(s, a)])
    return qvals

In [34]:
qvals_star = optimal_qvals(mdp)
for s, a in product(mdp.states(), mdp.actions()):
    if qvals_star[(s, a)] != 0:
        print(f"Q*({s}, {a}) = {qvals_star[(s, a)]}")

Q*((0, 1), ↑) = -2
Q*((0, 1), ↓) = -3
Q*((0, 1), ←) = -1
Q*((0, 1), →) = -3
Q*((0, 2), ↑) = -3
Q*((0, 2), ↓) = -4
Q*((0, 2), ←) = -2
Q*((0, 2), →) = -4
Q*((0, 3), ↑) = -4
Q*((0, 3), ↓) = -3
Q*((0, 3), ←) = -3
Q*((0, 3), →) = -4
Q*((1, 0), ↑) = -1
Q*((1, 0), ↓) = -3
Q*((1, 0), ←) = -2
Q*((1, 0), →) = -3
Q*((1, 1), ↑) = -2
Q*((1, 1), ↓) = -4
Q*((1, 1), ←) = -2
Q*((1, 1), →) = -4
Q*((1, 2), ↑) = -3
Q*((1, 2), ↓) = -3
Q*((1, 2), ←) = -3
Q*((1, 2), →) = -3
Q*((1, 3), ↑) = -4
Q*((1, 3), ↓) = -2
Q*((1, 3), ←) = -4
Q*((1, 3), →) = -3
Q*((2, 0), ↑) = -2
Q*((2, 0), ↓) = -4
Q*((2, 0), ←) = -3
Q*((2, 0), →) = -4
Q*((2, 1), ↑) = -3
Q*((2, 1), ↓) = -3
Q*((2, 1), ←) = -3
Q*((2, 1), →) = -3
Q*((2, 2), ↑) = -4
Q*((2, 2), ↓) = -2
Q*((2, 2), ←) = -4
Q*((2, 2), →) = -2
Q*((2, 3), ↑) = -3
Q*((2, 3), ↓) = -1
Q*((2, 3), ←) = -3
Q*((2, 3), →) = -2
Q*((3, 0), ↑) = -3
Q*((3, 0), ↓) = -4
Q*((3, 0), ←) = -4
Q*((3, 0), →) = -3
Q*((3, 1), ↑) = -4
Q*((3, 1), ↓) = -3
Q*((3, 1), ←) = -4
Q*((3, 1), →) = -2
Q*((3, 2), ↑

In [35]:
vals = np.array(
    [
        [0.0, -1.0, -2.0, -3.0],
        [-1.0, -2.0, -3.0, -2.0],
        [-2.0, -3.0, -2.0, -1.0],
        [-3.0, -2.0, -1.0, 0.0],
    ]
)

for s, a in product(mdp.states(), mdp.actions()):
    expected_qval = qval(mdp, vals, s, a)
    actual_qval = qvals_star[(s, a)]
    assert (
        expected_qval == actual_qval
    ), f"Expected {expected_qval} but got {actual_qval} for Q({s}, {a})"

In [38]:
from small_grid_world import argmax


def greedy_policy(
    mdp: SmallGridWorld, q_star: QTable, s: State, dbg=False
) -> Action | None:
    if mdp.is_terminal(s):
        return None

    best_actions = argmax(mdp.actions(), key=lambda a: q_star[(s, a)])
    best_action = best_actions[0]
    if dbg:
        if len(best_actions) > 1:
            best_actions_str = ", ".join(str(a) for a in best_actions)
            print(
                f"\tDEBUG: pi({s}): Got {len(best_actions)} best actions: {best_actions_str}, choosing {best_action}"
            )
    return best_action

In [53]:
expected_policy = {
    (0, 1): "←",
    (0, 2): "←",
    (0, 3): "↓",
    (1, 0): "↑",
    (1, 1): "↑",
    (1, 2): "↓",
    (1, 3): "↓",
    (2, 0): "↑",
    (2, 1): "↑",
    (2, 2): "↓",
    (2, 3): "↓",
    (3, 0): "↑",
    (3, 1): "→",
    (3, 2): "→",
}

In [54]:
for s in mdp.states():
    if mdp.is_terminal(s):
        continue
    expected_a = expected_policy[s]
    actual_a = str(greedy_policy(mdp, qvals_star, s, dbg=True))
    if expected_a != actual_a:
        print(f"Expected {expected_a} but got {actual_a} for pi({s})")

	DEBUG: pi((0, 3)): Got 2 best actions: ↓, ←, choosing ↓
	DEBUG: pi((1, 1)): Got 2 best actions: ↑, ←, choosing ↑
	DEBUG: pi((1, 2)): Got 4 best actions: ↑, ↓, ←, →, choosing ↑
Expected ↓ but got ↑ for pi((1, 2))
	DEBUG: pi((2, 1)): Got 4 best actions: ↑, ↓, ←, →, choosing ↑
	DEBUG: pi((2, 2)): Got 2 best actions: ↓, →, choosing ↓
	DEBUG: pi((3, 0)): Got 2 best actions: ↑, →, choosing ↑
