# Policy Iteration
## Policy Improvement

Knowing the value function $v_\pi$ of a deterministic policy $\pi$, we want to know whether changing the policy at some state would give a better policy. One way to find out is to select an action $a$ in state $s$ and follow policy $\pi$ thereafter. The value of behaving this way is:
$$
\begin{align}
    q_\pi(s, a) &= \mathbb{E}\big[R_{t+1} + \gamma v_\pi(S_{t+1}) | S_t=s, A_t=a\big]\\
    &= \sum_{s^\prime, r}p(s^\prime, r | s, a)\big[r + \gamma v_\pi(s^\prime)\big]
\end{align}
$$
Now, if $q_\pi(s, a)$ is greater than $v_\pi(s)$, then taking action $a$ once in state $s$ and following $\pi$ thereafter is better than following $\pi$ all the time. In fact, it is then better to take action $a$ every time state $s$ is encountered.
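As a small illustration of the comparison between $q_\pi(s, a)$ and $v_\pi(s)$, here is a minimal sketch on a hypothetical three-state chain that is not part of the grid world used later. The names `p`, `q_value`, `LEFT`, `RIGHT` and the discount `GAMMA = 0.9` are assumptions made only for this example, and `v_pi` is solved by hand for the policy $\pi(0)=$ `LEFT`, $\pi(1)=$ `RIGHT`.

```python
LEFT, RIGHT = 0, 1
GAMMA = 0.9  # assumed discount factor for this toy example

# Hypothetical chain MDP: p[(s, a)] lists (next_state, reward, probability) triples.
# States 0 and 1 are non-terminal; state 2 is terminal (absorbing, reward 0).
p = {
    (0, LEFT): [(0, -1.0, 1.0)], (0, RIGHT): [(1, -1.0, 1.0)],
    (1, LEFT): [(0, -1.0, 1.0)], (1, RIGHT): [(2, -1.0, 1.0)],
    (2, LEFT): [(2, 0.0, 1.0)], (2, RIGHT): [(2, 0.0, 1.0)],
}


def q_value(p, v, s, a, gamma):
    """q_pi(s, a) = sum over (s', r) of p(s', r | s, a) * (r + gamma * v_pi(s'))."""
    return sum(prob * (r + gamma * v[s_next]) for s_next, r, prob in p[(s, a)])


# v_pi for pi(0) = LEFT, pi(1) = RIGHT, solved by hand:
# v_pi(1) = -1 and v_pi(0) = -1 + 0.9 * v_pi(0), so v_pi(0) = -10.
v_pi = [-10.0, -1.0, 0.0]

q_left = q_value(p, v_pi, 0, LEFT, GAMMA)    # matches v_pi(0), since pi(0) = LEFT
q_right = q_value(p, v_pi, 0, RIGHT, GAMMA)  # -1 + 0.9 * (-1), about -1.9
print(q_left, q_right)
```

Here $q_\pi(0, \text{RIGHT}) \gt v_\pi(0)$, so switching to `RIGHT` in state 0 looks like an improvement over following $\pi$.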

### Policy Improvement theorem
Let $\pi$ and $\pi^\prime$ be any pair of deterministic policies such that, for all $s \in \mathcal{S}$:
$$
\begin{equation}
    q_\pi(s, \pi^\prime(s)) \ge v_\pi(s)
\end{equation}
$$
then policy $\pi^\prime$ must be as good as or better than policy $\pi$. The condition says that following $\pi$ from state $s$ yields $v_\pi(s)$, which is no more than what we get by taking action $\pi^\prime(s)$ at state $s$ and following $\pi$ thereafter. The conclusion is that $\pi^\prime$ obtains at least as much expected return from every state $s \in \mathcal{S}$:
$$
\begin{align}
    v_{\pi^\prime}(s) \ge v_\pi(s)
\end{align}
$$
So, starting from any state $s$, following policy $\pi^\prime$ is at least as good as following $\pi$.<br>
As a special case, consider two policies $\pi^\prime$ and $\pi$ that are identical except at a single state $s$, where $\pi^\prime(s)=a\neq \pi(s)$. If the inequality is strict, $q_\pi(s, a) \gt v_\pi(s)$, then the changed policy $\pi^\prime$ is better than policy $\pi$.

#### Proof
$$
\begin{align}
    v_\pi(s) &\le q_\pi(s, \pi^\prime(s))\\
    &= \mathbb{E}\big[R_{t+1} + \gamma v_\pi(S_{t+1}) | S_t=s, A_t=\pi^\prime(s)\big]\\
    &= \mathbb{E}_{\pi^\prime}\big[R_{t+1} + \gamma v_\pi(S_{t+1}) | S_t=s\big]\\
    &\le \mathbb{E}_{\pi^\prime}\big[R_{t+1} + \gamma q_\pi(S_{t+1}, \pi^\prime(S_{t+1})) | S_t=s\big]\\
    &= \mathbb{E}_{\pi^\prime}\bigg[R_{t+1} + \gamma \mathbb{E}\big[R_{t+2} + \gamma v_\pi(S_{t+2}) | S_{t+1}, A_{t+1}=\pi^\prime(S_{t+1})\big] | S_t=s\bigg]\\
    &= \mathbb{E}_{\pi^\prime}\big[R_{t+1} + \gamma R_{t+2} + \gamma^2 v_\pi(S_{t+2}) | S_t=s\big]\\
    &\le \mathbb{E}_{\pi^\prime}\big[R_{t+1} + \gamma R_{t+2} + \gamma^2 R_{t+3} + \gamma^3 v_\pi(S_{t+3}) | S_t=s\big]\\
    &\vdots\\
    &\le \mathbb{E}_{\pi^\prime}\big[R_{t+1} + \gamma R_{t+2} + \gamma^2 R_{t+3} + \gamma^3 R_{t+4} + \dots | S_t=s\big]\\
    &= v_{\pi^\prime}(s)
\end{align}
$$
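The theorem can also be checked numerically. The sketch below reuses a hypothetical three-state chain (an assumption for illustration, not the grid world used later), evaluates an original policy `pi` and a changed policy `pi_new` with a small iterative evaluator, and tests both the premise and the conclusion. The helper `evaluate` and the value `GAMMA = 0.9` are illustrative choices.

```python
LEFT, RIGHT = 0, 1
GAMMA = 0.9  # assumed discount factor for this toy example

# Hypothetical chain MDP: p[(s, a)] lists (next_state, reward, probability) triples.
p = {
    (0, LEFT): [(0, -1.0, 1.0)], (0, RIGHT): [(1, -1.0, 1.0)],
    (1, LEFT): [(0, -1.0, 1.0)], (1, RIGHT): [(2, -1.0, 1.0)],
    (2, LEFT): [(2, 0.0, 1.0)], (2, RIGHT): [(2, 0.0, 1.0)],
}


def evaluate(p, policy, gamma, threshold=1e-12):
    """Iterative policy evaluation for a deterministic policy given as a list of actions."""
    v = [0.0] * len(policy)
    while True:
        delta = 0.0
        for s, a in enumerate(policy):
            new_v = sum(prob * (r + gamma * v[s2]) for s2, r, prob in p[(s, a)])
            delta = max(delta, abs(new_v - v[s]))
            v[s] = new_v
        if delta <= threshold:
            return v


pi = [LEFT, RIGHT, LEFT]       # original policy
pi_new = [RIGHT, RIGHT, LEFT]  # differs from pi only at state 0

v_old = evaluate(p, pi, GAMMA)
v_new = evaluate(p, pi_new, GAMMA)

# Premise: q_pi(s, pi_new(s)) >= v_pi(s) for every state s (small tolerance for rounding)
premise = all(
    sum(prob * (r + GAMMA * v_old[s2]) for s2, r, prob in p[(s, pi_new[s])]) >= v_old[s] - 1e-9
    for s in range(len(pi))
)
# Conclusion: v_pi_new(s) >= v_pi(s) for every state s
conclusion = all(v_new[s] >= v_old[s] - 1e-9 for s in range(len(pi)))
print(premise, conclusion)
```

Both checks hold on this example. That is consistent with the theorem but is of course not a proof of it.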

Given the explanation above, it is natural to consider changes at all states and to all possible actions, selecting at each state the action that appears best according to $q_\pi(s, a)$:

$$
\begin{align}
    \pi^\prime(s) &= \underset{a}{\operatorname{argmax}}\ q_\pi(s, a)\\
    &= \underset{a}{\operatorname{argmax}}\ \mathbb{E}\big[R_{t+1} + \gamma v_\pi(S_{t+1}) | S_t=s, A_t=a\big]\\
    &= \underset{a}{\operatorname{argmax}}\ \sum_{s^\prime, r}p(s^\prime, r | s, a)\big[r + \gamma v_\pi(s^\prime)\big]
\end{align}
$$

Although the greedy policy only takes the action that looks best after one step of lookahead according to $v_\pi$, it meets the conditions of the policy improvement theorem by construction, so we know that it is as good as or better than the original policy.
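The greedy step can be sketched in a few lines. The example below is a minimal sketch on a hypothetical three-state chain (an assumption for illustration). The function `greedy_policy` and the hand-computed `v_pi` are not from the notebook's own cells, and ties are broken in favour of the first action, which is one choice among several.

```python
LEFT, RIGHT = 0, 1
GAMMA = 0.9  # assumed discount factor for this toy example

# Hypothetical chain MDP: p[(s, a)] lists (next_state, reward, probability) triples.
p = {
    (0, LEFT): [(0, -1.0, 1.0)], (0, RIGHT): [(1, -1.0, 1.0)],
    (1, LEFT): [(0, -1.0, 1.0)], (1, RIGHT): [(2, -1.0, 1.0)],
    (2, LEFT): [(2, 0.0, 1.0)], (2, RIGHT): [(2, 0.0, 1.0)],
}


def greedy_policy(p, v, n_states, actions, gamma):
    """Return pi'(s) = argmax_a q_pi(s, a) for each state; ties go to the first action."""
    def q(s, a):
        return sum(prob * (r + gamma * v[s2]) for s2, r, prob in p[(s, a)])

    return [max(actions, key=lambda a: q(s, a)) for s in range(n_states)]


# Value of the policy pi(0) = LEFT, pi(1) = RIGHT, solved by hand
v_pi = [-10.0, -1.0, 0.0]

pi_greedy = greedy_policy(p, v_pi, 3, [LEFT, RIGHT], GAMMA)
print(pi_greedy)  # [1, 1, 0]: RIGHT in both non-terminal states
```

In the terminal state both actions have the same value, so the action reported there is only an artifact of the tie-breaking rule.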

#### Policy improvement converges to $\pi_*$
We have shown that improving policy $\pi$ greedily gives a new policy $\pi^\prime$ that is always as good as or better than the original policy.
Suppose the new greedy policy $\pi^\prime$ is as good as, but not better than, the old policy $\pi$. Then $v_\pi = v_{\pi^\prime}$, and from the greedy construction it follows that for all $s \in \mathcal{S}$:
$$
\begin{align}
    v_{\pi^\prime}(s) &= \max_{a}\ \mathbb{E}\big[R_{t+1} + \gamma v_{\pi^\prime}(S_{t+1}) | S_t=s, A_t=a\big]\\
    &= \max_{a}\ \sum_{s^\prime, r}p(s^\prime, r | s, a)\big[r + \gamma v_{\pi^\prime}(s^\prime)\big]
\end{align}
$$
This is the same as the Bellman optimality equation.<br>
Thus, $v_{\pi^\prime}$ must be $v_*$, and both $\pi$ and $\pi^\prime$ must be optimal policies.
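One way to see this stopping condition in practice is to measure how far a value function is from satisfying the Bellman optimality equation. The sketch below uses a hypothetical three-state chain and hand-computed value functions; the helper name `bellman_optimality_residual` and the value `GAMMA = 0.9` are assumptions for illustration.

```python
LEFT, RIGHT = 0, 1
GAMMA = 0.9  # assumed discount factor for this toy example

# Hypothetical chain MDP: p[(s, a)] lists (next_state, reward, probability) triples.
p = {
    (0, LEFT): [(0, -1.0, 1.0)], (0, RIGHT): [(1, -1.0, 1.0)],
    (1, LEFT): [(0, -1.0, 1.0)], (1, RIGHT): [(2, -1.0, 1.0)],
    (2, LEFT): [(2, 0.0, 1.0)], (2, RIGHT): [(2, 0.0, 1.0)],
}


def bellman_optimality_residual(p, v, n_states, actions, gamma):
    """Largest |max_a q(s, a) - v(s)| over states; zero when v satisfies the optimality equation."""
    def q(s, a):
        return sum(prob * (r + gamma * v[s2]) for s2, r, prob in p[(s, a)])

    return max(abs(max(q(s, a) for a in actions) - v[s]) for s in range(n_states))


v_pi = [-10.0, -1.0, 0.0]     # value of pi(0) = LEFT, pi(1) = RIGHT (solved by hand)
v_greedy = [-1.9, -1.0, 0.0]  # value of the policy that goes RIGHT in states 0 and 1

res_old = bellman_optimality_residual(p, v_pi, 3, [LEFT, RIGHT], GAMMA)
res_new = bellman_optimality_residual(p, v_greedy, 3, [LEFT, RIGHT], GAMMA)
print(res_old, res_new)
```

The residual is large for `v_pi`, which can still be improved, and zero up to rounding for `v_greedy`, which a further greedy step would leave unchanged.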

In [1]:
import numpy as np
import matplotlib.pyplot as plt

In [2]:
def iterative_policy_evaluation(grid_world_shape, rewards, transition, policy, V, gamma=1.0, 
                                threshold=1e-12):
    """
    args:
        grid_world_shape: shape of the grid world in 2D
        rewards:          reward for taking action a in state s, indexed as [i, j, a] (SxA)
        transition:       function denoting transition probability from state s to s_prime, SxAxS
        policy:           the policy (pi) to be evaluated
        V:                estimated value at each state
        gamma:            discounting factor
        threshold:        determines the accuracy of estimation
    returns:
        v_pi:             approximation of policy evaluation
    """

    # Take the number of rows and columns in grid world shape
    rows, columns = grid_world_shape
    
    while True:
        delta = 0

        # i and j together denote state s
        for i in range(rows):
            for j in range(columns):
                v = V[i, j]
                # for each action find its expected return given action a is taken
                new_vs = 0
                for a in range(4):
                    # i_prime and j_prime denote state s_prime
                    for i_prime in range(rows):
                        for j_prime in range(columns):
                            # Use the `rewards` argument, not the global `reward` array
                            new_vs += (policy[a, i, j] * transition[i, j, a, i_prime, j_prime]
                                       * (rewards[i, j, a] + gamma * V[i_prime, j_prime]))
                # Update the value for state s
                V[i, j] = new_vs 

                # Storing the difference that each update makes
                delta = max(delta, np.abs(v - V[i, j]))

        # If the maximum difference made is less than the threshold break
        if delta <= threshold:
            break
    return V

In [3]:
def policy_improvement(grid_world_shape, rewards, transition, policy, value_function, gamma=1.0):
    """
    args:
        grid_world_shape: shape of the grid world in 2D
        rewards:          reward for taking action a in state s, indexed as [i, j, a] (SxA)
        transition:       function denoting transition probability from state s to s_prime, SxAxS
        policy:           the policy (pi) to be improved; it is updated in place
        value_function:   estimated value at each state
        gamma:            discounting factor
    returns:
        policy:           improved policy
    """
    # Take the number of rows and columns in grid world shape
    rows, columns = grid_world_shape
    
    for i in range(rows):
        for j in range(columns):
            # actions[a] accumulates q_pi(s, a) for each of the four actions in state s = (i, j)
            actions = np.zeros((4,))
            for i_prime in range(rows):
                for j_prime in range(columns):
                    # calculates q_pi for each action
                    actions += transition[i, j, :, i_prime, j_prime] * (rewards[i, j] + gamma * value_function[i_prime, j_prime])
            policy[:, i, j] = np.zeros((4,))
            # Improves the policy
            policy[np.argmax(actions), i, j] = 1
                
    return policy

In [4]:
def policy_iteration(grid_world_shape, rewards, transition, gamma=1.0, threshold=1e-12):
    """
    args:
        grid_world_shape: shape of the grid world in 2D
        rewards:          reward for taking action a in state s, indexed as [i, j, a] (SxA)
        transition:       function denoting transition probability from state s to s_prime, SxAxS
        gamma:            discounting factor
        threshold:        determines the accuracy of estimation
    returns:
        v_star:           approximation of the optimal value function
        pi_star:          optimal policy
    """
    # Initialization
    V = np.zeros(grid_world_shape)
    policy = np.ones((4,) + grid_world_shape) / 4

    # Tracks whether the greedy policy has stopped changing
    policy_stable = False
    
    while not policy_stable:
        policy_stable = True
        V = iterative_policy_evaluation(grid_world_shape, rewards, transition, policy, V, gamma, threshold)
        # Store the old policy's greedy actions to detect whether the policy has changed
        old_policy = np.argmax(policy, axis=0)
        # Improvement
        policy = policy_improvement(grid_world_shape, rewards, transition, policy, V, gamma)
        # If policy has improved, it means that policy is not stable yet
        if not np.array_equal(old_policy, np.argmax(policy, axis=0)):
            policy_stable = False

    return V, policy

In [5]:
# Defining the shape of the grid world
grid_world_shape = (4, 4)


# Defining the reward function, SxA (indexed as [i, j, a])
reward = np.zeros(grid_world_shape + (4,)) - 1
reward[0, 0, :] = 0                                 # Terminal state at top left corner
reward[-1, -1, :] = 0                               # Terminal state at bottom right corner


# Defining the transition function, indexed as [i, j, a, i_prime, j_prime]
transition = np.zeros(grid_world_shape + (4,) + grid_world_shape)
for a in range(4):                                  # let's denote 0: up, 1: right, 2: down, 3: left
    for i in range(grid_world_shape[0]):
        for j in range(grid_world_shape[1]):
            # Moves that would leave the grid keep the agent in the same cell
            if a == 0:
                transition[i, j, 0, max(0, i - 1), j] = 1
            elif a == 1:
                transition[i, j, 1, i, min(grid_world_shape[1] - 1, j + 1)] = 1
            elif a == 2:
                transition[i, j, 2, min(grid_world_shape[0] - 1, i + 1), j] = 1
            else:
                transition[i, j, 3, i, max(j - 1, 0)] = 1
# Change the transition function for terminal states
transition[0, 0, :, :, :] = 0
transition[0, 0, :, 0, 0] = 1
transition[-1, -1, :, :, :] = 0
transition[-1, -1, :, -1, -1] = 1

In [6]:
V, policy = policy_iteration(grid_world_shape, reward, transition)

In [7]:
V

array([[ 0., -1., -2., -3.],
       [-1., -2., -3., -2.],
       [-2., -3., -2., -1.],
       [-3., -2., -1.,  0.]])

In [8]:
np.argmax(policy, axis=0)

array([[0, 3, 3, 2],
       [0, 0, 0, 2],
       [0, 0, 1, 2],
       [0, 1, 1, 0]])
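The action indices above are easier to read as arrows. The following is a minimal sketch that hard-codes the array printed above and uses the encoding from the transition cell (0: up, 1: right, 2: down, 3: left). The helper `render_policy`, the ASCII symbols and the `T` marker for terminal states are choices made for this illustration.

```python
# ASCII symbols for the action encoding 0: up, 1: right, 2: down, 3: left
ARROWS = {0: "^", 1: ">", 2: "v", 3: "<"}

# Greedy actions copied from the output above
greedy_actions = [
    [0, 3, 3, 2],
    [0, 0, 0, 2],
    [0, 0, 1, 2],
    [0, 1, 1, 0],
]
# Terminal states of the 4x4 grid world (top-left and bottom-right corners)
terminals = {(0, 0), (3, 3)}


def render_policy(actions, terminals):
    """Return a text grid with one arrow per state and 'T' for terminal states."""
    lines = []
    for i, row in enumerate(actions):
        cells = ["T" if (i, j) in terminals else ARROWS[a] for j, a in enumerate(row)]
        lines.append(" ".join(cells))
    return "\n".join(lines)


rendered = render_policy(greedy_actions, terminals)
print(rendered)
# T < < v
# ^ ^ ^ v
# ^ ^ > v
# ^ > > T
```

Since `np.argmax` returns the first maximal index, each cell shows only one of possibly several equally good actions.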