In [1]:
import numpy as np 

m = 3
m2 = m ** 2
q = np.zeros(m2)
q[m2 // 2] = 1
q

array([0., 0., 0., 0., 1., 0., 0., 0., 0.])

In [2]:
m2 = m ** 2
P = np.zeros((m2, m2))
ix_map = {i + 1: (i // m, i % m) for i in range(m2)}
ix_map

{1: (0, 0),
 2: (0, 1),
 3: (0, 2),
 4: (1, 0),
 5: (1, 1),
 6: (1, 2),
 7: (2, 0),
 8: (2, 1),
 9: (2, 2)}

In [4]:
def get_P(m, p_up, p_down, p_left, p_right):
    m2 = m ** 2
    P = np.zeros((m2, m2))
    ix_map = {i + 1: (i // m, i % m) for i in range(m2)}
    for i in range(m2):
        for j in range(m2):
            r1, c1 = ix_map[i + 1]
            r2, c2 = ix_map[j + 1]
            rdiff = r1 - r2
            cdiff = c1 - c2
            if rdiff == 0:
                if cdiff == 1:
                    P[i, j] = p_left
                elif cdiff == -1:
                    P[i, j] = p_right
                elif cdiff == 0:
                    if r1 == 0:
                        P[i, j] += p_down
                    elif r1 == m - 1:
                        P[i, j] += p_up
                    if c1 == 0:
                        P[i, j] += p_left
                    elif c1 == m - 1:
                        P[i, j] += p_right
            elif rdiff == 1:
                if cdiff == 0:
                    P[i, j] = p_down
            elif rdiff == -1:
                if cdiff == 0:
                    P[i, j] = p_up
    return P

In [5]:
P = get_P(3, 0.2, 0.3, 0.25, 0.25)

In [6]:
n = 1
Pn = np.linalg.matrix_power(P, n)
np.matmul(q, Pn)

array([0.  , 0.3 , 0.  , 0.25, 0.  , 0.25, 0.  , 0.2 , 0.  ])

In [7]:
n = 3
Pn = np.linalg.matrix_power(P, n)
np.round(np.matmul(q, Pn), 3)

array([0.124, 0.176, 0.124, 0.137, 0.061, 0.137, 0.068, 0.107, 0.068])

In [8]:
n = 10
Pn = np.linalg.matrix_power(P, n)
np.round(np.matmul(q, Pn), 3)

array([0.156, 0.156, 0.156, 0.105, 0.106, 0.105, 0.072, 0.072, 0.072])

In [9]:
n = 100
Pn = np.linalg.matrix_power(P, n)
np.round(np.matmul(q, Pn), 3)

array([0.158, 0.158, 0.158, 0.105, 0.105, 0.105, 0.07 , 0.07 , 0.07 ])

## Ergodic MC

In [12]:
# itemfreq deprecated
# https://stackoverflow.com/questions/73556675/issue-with-importing-textexplainer-from-eli5-lime-package-relating-to-deprecated
import scipy
import numpy as np
def monkeypath_itemfreq(sampler_indices):
    return zip(*np.unique(sampler_indices, return_counts=True))

scipy.stats.itemfreq=monkeypath_itemfreq

In [13]:
from scipy.stats import itemfreq

s = 4
n = 10 ** 6
visited = [s]

In [14]:
for t in range(n):
    s = np.random.choice(m2, p=P[s, :])
    visited.append(s)

In [32]:
# itemfreq(visited)
for i in range(len(q)):
    print(f"{i}: {visited.count(i)}")

0: 157505
1: 158194
2: 158284
3: 105487
4: 105123
5: 105345
6: 70211
7: 69793
8: 70059


In [16]:
P

array([[0.55, 0.25, 0.  , 0.2 , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.25, 0.3 , 0.25, 0.  , 0.2 , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.25, 0.55, 0.  , 0.  , 0.2 , 0.  , 0.  , 0.  ],
       [0.3 , 0.  , 0.  , 0.25, 0.25, 0.  , 0.2 , 0.  , 0.  ],
       [0.  , 0.3 , 0.  , 0.25, 0.  , 0.25, 0.  , 0.2 , 0.  ],
       [0.  , 0.  , 0.3 , 0.  , 0.25, 0.25, 0.  , 0.  , 0.2 ],
       [0.  , 0.  , 0.  , 0.3 , 0.  , 0.  , 0.45, 0.25, 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.3 , 0.  , 0.25, 0.2 , 0.25],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.3 , 0.  , 0.25, 0.45]])

# MRP

## Modify P

In [33]:
P = np.zeros((m2 + 1, m2 + 1))
P[:m2, :m2] = get_P(3, 0.2, 0.3, 0.25, 0.25)
for i in range(m2):
    P[i, m2] = P[i, i]
    P[i, i] = 0
P[m2, m2] = 1
P

array([[0.  , 0.25, 0.  , 0.2 , 0.  , 0.  , 0.  , 0.  , 0.  , 0.55],
       [0.25, 0.  , 0.25, 0.  , 0.2 , 0.  , 0.  , 0.  , 0.  , 0.3 ],
       [0.  , 0.25, 0.  , 0.  , 0.  , 0.2 , 0.  , 0.  , 0.  , 0.55],
       [0.3 , 0.  , 0.  , 0.  , 0.25, 0.  , 0.2 , 0.  , 0.  , 0.25],
       [0.  , 0.3 , 0.  , 0.25, 0.  , 0.25, 0.  , 0.2 , 0.  , 0.  ],
       [0.  , 0.  , 0.3 , 0.  , 0.25, 0.  , 0.  , 0.  , 0.2 , 0.25],
       [0.  , 0.  , 0.  , 0.3 , 0.  , 0.  , 0.  , 0.25, 0.  , 0.45],
       [0.  , 0.  , 0.  , 0.  , 0.3 , 0.  , 0.25, 0.  , 0.25, 0.2 ],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.3 , 0.  , 0.25, 0.  , 0.45],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 1.  ]])

In [34]:
# add rewards
n = 10 ** 5
avg_rewards = np.zeros(m2)
for s in range(9):
    for i in range(n):
        crashed = False
        s_next = s
        episode_reward = 0
        while not crashed:
            s_next = np.random.choice(m2 + 1, p=P[s_next, :])
            if s_next < m2:
                episode_reward += 1
            else:
                crashed = True
        avg_rewards[s] += episode_reward
avg_rewards /= n

In [35]:
np.round(avg_rewards, 2)

array([1.48, 2.13, 1.48, 2.45, 3.43, 2.45, 2.  , 2.82, 1.99])

## Analytically calculate the state values

In [38]:
R = np.ones(m2 + 1)
R[-1] = 0
R

array([1., 1., 1., 1., 1., 1., 1., 1., 1., 0.])

In [39]:
inv = np.linalg.inv(np.eye(m2 + 1) - 0.9999 * P)
v = np.matmul(inv, np.matmul(P, R))
print(np.round(v, 2))

[1.47 2.12 1.47 2.44 3.42 2.44 1.99 2.82 1.99 0.  ]


## Estimating state values

In [43]:
def estimate_state_values(P, m2, threshold):
    v = np.zeros(m2 + 1)
    max_change = threshold
    terminal_state = m2
    while max_change >= threshold:
        max_change = 0
        for s in range(m2 + 1):
            v_new = 0
            for s_next in range(m2 + 1):
                r = 1 * (s_next != terminal_state)
                v_new += P[s, s_next] * (r + v[s_next])
                print(f"s_next: {s_next}, r: {r}, v_new: {v_new}")
            max_change = max(max_change, np.abs(v[s] - v_new))
            v[s] = v_new
                
    return np.round(v, 2)

In [44]:
estimate_state_values(P, m2, 0.005)

s_next: 0, r: 1, v_new: 0.0
s_next: 1, r: 1, v_new: 0.25
s_next: 2, r: 1, v_new: 0.25
s_next: 3, r: 1, v_new: 0.45
s_next: 4, r: 1, v_new: 0.45
s_next: 5, r: 1, v_new: 0.45
s_next: 6, r: 1, v_new: 0.45
s_next: 7, r: 1, v_new: 0.45
s_next: 8, r: 1, v_new: 0.45
s_next: 9, r: 0, v_new: 0.45
s_next: 0, r: 1, v_new: 0.3625
s_next: 1, r: 1, v_new: 0.3625
s_next: 2, r: 1, v_new: 0.6125
s_next: 3, r: 1, v_new: 0.6125
s_next: 4, r: 1, v_new: 0.8125
s_next: 5, r: 1, v_new: 0.8125
s_next: 6, r: 1, v_new: 0.8125
s_next: 7, r: 1, v_new: 0.8125
s_next: 8, r: 1, v_new: 0.8125
s_next: 9, r: 0, v_new: 0.8125
s_next: 0, r: 1, v_new: 0.0
s_next: 1, r: 1, v_new: 0.453125
s_next: 2, r: 1, v_new: 0.453125
s_next: 3, r: 1, v_new: 0.453125
s_next: 4, r: 1, v_new: 0.453125
s_next: 5, r: 1, v_new: 0.653125
s_next: 6, r: 1, v_new: 0.653125
s_next: 7, r: 1, v_new: 0.653125
s_next: 8, r: 1, v_new: 0.653125
s_next: 9, r: 0, v_new: 0.653125
s_next: 0, r: 1, v_new: 0.435
s_next: 1, r: 1, v_new: 0.435
s_next: 2, r: 1,

array([1.47, 2.12, 1.47, 2.44, 3.42, 2.44, 1.99, 2.82, 1.99, 0.  ])

## Estimating state values