### The construction of MDP
(1) statement space: S={A, B}<br>
(2) action space: {0, 1}<br>
(3) transition:<br>
* at A: action0->A(reward=1), action1->B(reward=0)<br>
* at B: absorb state(reward=0)<br>

(4) old policy<br>
* at A: pi(0)=0.5, pi(1)=0.5<br>
* at B: do not have meaning<br>

(5) new policy<br>
* at A: pi'(0)=1.0 (always take action 0)<br>

(6) discount factor: gamma = 0.9<br>


In [None]:
import numpy as np
gamma = 0.9

# States indexed as 0=A, 1=B
# Actions 0 or 1

# Transition probabilities
# P[s][a] = (next_state, reward)
P = {
    0: {  # state A
        0: (0, 1),  # action 0: stay in A, reward=1
        1: (1, 0),  # action 1: go to B, reward=0
    },
    1: {  # state B (terminal)
        0: (1, 0), # action 0: stay in B, reward = 0
        1: (1, 0) # action 1: stay in B, reward = 0
    }
}

In [2]:
# Old policy π
pi = {
    0: {0: 0.5, 1: 0.5},  # at A
    1: {0: 1.0, 1: 0.0},  # at B
}

# New policy π'
pi_prime = {
    0: {0: 1.0, 1: 0.0},  # always pick 0 at A
    1: {0: 1.0, 1: 0.0},
}

In [4]:
# Compute Jπ via fixed point iteration
def compute_J(pi):
    J = np.zeros(2)
    for _ in range(200):
        new_J = np.zeros(2)
        for s in [0,1]:
            val = 0
            for a, p_a in pi[s].items():
                s2, r = P[s][a]
                val += p_a * (r + gamma * J[s2])
            new_J[s] = val
        J = new_J
    return J

J_pi = compute_J(pi)
J_pi_prime = compute_J(pi_prime)

print("J_pi =", J_pi)
print("J_pi_prime =", J_pi_prime)


J_pi = [0.90909091 0.        ]
J_pi_prime = [9.99999999 0.        ]


In [5]:
# Compute Qπ and Aπ
Q_pi = np.zeros((2,2))
A_pi = np.zeros((2,2))

for s in [0,1]:
    for a in [0,1]:
        s2, r = P[s][a]
        Q_pi[s][a] = r + gamma * J_pi[s2]
        A_pi[s][a] = Q_pi[s][a] - J_pi[s]

print("\nQ_pi:\n", Q_pi)
print("\nA_pi:\n", A_pi)


Q_pi:
 [[1.81818182 0.        ]
 [0.         0.        ]]

A_pi:
 [[ 0.90909091 -0.90909091]
 [ 0.          0.        ]]


In [6]:
# Difference Lemma RHS: simulate π'
def DL_rhs(s0, steps=10):
    s = s0
    total = 0
    for k in range(steps):
        a_probs = pi_prime[s]
        a = 0 if a_probs[0] > 0.5 else 1
        total += (gamma**k) * A_pi[s][a]
        s, _ = P[s][a]
    return total

rhs = DL_rhs(0, steps=50)


print("\nDifference Lemma RHS =", rhs)
print("Actual difference =", J_pi_prime[0] - J_pi[0])


Difference Lemma RHS = 9.044056589024361
Actual difference = 9.090909083854015
