In [22]:
from typing import Callable, Optional, Tuple

import numpy as np

num_states = 7  # the corridor has 7 cells
S = np.arange(num_states)
A = np.array([0, 1])  # 0: left, 1: right
T = np.array([0, num_states - 1])  # terminal states: both ends of the corridor

# Transition tensor:
#   P[s, a, s_p, 0] = probability of moving from s to s_p under action a
#   P[s, a, s_p, 1] = reward received on that transition
P = np.zeros((len(S), len(A), len(S), 2))

# Every action is available in the non-terminal cells; moves are deterministic.
inner = S[1:-1]
P[inner, 0, inner - 1, 0] = 1.0  # "left" goes one cell left
P[inner, 1, inner + 1, 0] = 1.0  # "right" goes one cell right

# Only the two transitions into a terminal cell are rewarded.
P[1, 0, 0, 1] = -1.0
P[num_states - 2, 1, num_states - 1, 1] = 1.0


def reset() -> int:
    """Return the start state of a new episode: the middle cell of the corridor."""
    middle, _ = divmod(num_states, 2)
    return middle


def is_terminal(state: int) -> bool:
    """Tell whether `state` is one of the terminal states listed in `T`."""
    matches = T == state
    return bool(np.any(matches))


def step(
        state: int,
        a: int,
        rng: Optional[np.random.Generator] = None,
) -> Tuple[int, float, bool]:
    """Sample one transition of the corridor MDP from `state` under action `a`.

    Args:
        state: current state; must not be terminal.
        a: action index (0: left, 1: right).
        rng: optional NumPy Generator for reproducible sampling. When omitted,
            the legacy global ``np.random`` state is used, exactly as before.

    Returns:
        ``(next_state, reward, done)`` as plain Python int / float / bool.

    Raises:
        AssertionError: if called from a terminal state.
    """
    assert not is_terminal(state), f"step() called from terminal state {state}"
    # Row of arrival probabilities for (state, a); it sums to 1 for non-terminal states.
    transition_probs = P[state, a, :, 0]
    sampler = np.random if rng is None else rng
    next_state = int(sampler.choice(S, p=transition_probs))
    reward = float(P[state, a, next_state, 1])
    return next_state, reward, is_terminal(next_state)

In [30]:
def iterative_policy_evaluation(
        S: np.ndarray,
        A: np.ndarray,
        P: np.ndarray,
        T: np.ndarray,
        Pi: np.ndarray,
        gamma: float = 0.99,
        theta: float = 0.000001,  # convergence threshold
        V: Optional[np.ndarray] = None
) -> np.ndarray:
    """Evaluate policy `Pi` with in-place (Gauss-Seidel) Bellman expectation sweeps.

    Each sweep applies, for every state s,
        V(s) <- sum_a Pi[s, a] * sum_s' P[s, a, s', 0] * (P[s, a, s', 1] + gamma * V(s'))
    and immediately reuses updated values within the same sweep. Sweeping stops
    when the largest change observed in a sweep is below `theta`.

    Args:
        S: state indices.
        A: action indices.
        P: transition tensor; ``P[s, a, s', 0]`` is the transition probability and
            ``P[s, a, s', 1]`` the reward of that transition.
        T: terminal state indices, zeroed when `V` is randomly initialised.
        Pi: policy table; ``Pi[s, a]`` is the probability of action a in state s.
        gamma: discount factor in [0, 1].
        theta: positive threshold on the per-sweep maximum change.
        V: optional initial estimate. A floating-point array is updated in place
            and returned. A non-floating array (e.g. int) is first converted to
            float, because integer storage truncates every backup and could keep
            the loop from ever terminating.

    Returns:
        The estimated state-value function, indexed by state.
    """
    assert 0 <= gamma <= 1
    assert theta > 0

    if V is None:
        V = np.random.random((S.shape[0],))
        V[T] = 0.0
    else:
        V = np.asarray(V)
        if not np.issubdtype(V.dtype, np.floating):
            V = V.astype(float)

    while True:
        delta = 0.0
        for s in S:
            v_old = V[s]
            # (|A|, |S|, 2) slice of P for state s: [..., 0] probabilities, [..., 1] rewards.
            outcomes = P[s][A][:, S]
            probs, rewards = outcomes[..., 0], outcomes[..., 1]
            # Expected one-step return under Pi, summed over actions and successors.
            v_new = np.sum(Pi[s, A][:, None] * probs * (rewards + gamma * V[S]))
            V[s] = v_new
            delta = max(delta, abs(v_new - v_old))
        if delta < theta:
            break
    return V

In [19]:
def tabular_uniform_random_policy(space_size: int, action_size: int) -> np.ndarray:
    """Build a policy table that chooses every action with equal probability.

    Returns an array of shape (space_size, action_size) in which every entry
    equals 1 / action_size.
    """
    table = np.ones((space_size, action_size))
    table /= action_size
    return table

In [31]:
import time

# Evaluate the uniform random policy on the corridor MDP and report how long it took.
# perf_counter() is monotonic and high-resolution; time.time() follows the system
# clock, which can be adjusted mid-measurement, so it is unsuitable for intervals.
start_time = time.perf_counter()
Pi = tabular_uniform_random_policy(S.shape[0], A.shape[0])
V = iterative_policy_evaluation(S, A, P, T, Pi)
print("--- %s seconds ---" % (time.perf_counter() - start_time))
print(V)

--- 0.00805354118347168 seconds ---
[ 0.00000000e+00 -6.62272074e-01 -3.27823181e-01  2.22141125e-06
  3.27827074e-01  6.62274402e-01  0.00000000e+00]


In [34]:
# NOTE(review): this cell only prints an empty line; the saved output below it
# (8.44e-06) does not match the code and is left over from an earlier version.
print("")

8.44212070472139e-06


In [38]:
# Roll out one episode greedily with respect to the evaluated V.
# Action values come from a one-step lookahead over P, which accounts for the
# terminal rewards directly. V is not overwritten: doing so corrupted the evaluated
# value function and made this cell non-idempotent.
lookahead_gamma = 0.99  # should match the gamma used by iterative_policy_evaluation above
st = reset()
while not is_terminal(st):
    # Q(st, a) = sum_s' P[st, a, s', 0] * (P[st, a, s', 1] + gamma * V[s'])
    q_values = np.sum(P[st, A, :, 0] * (P[st, A, :, 1] + lookahead_gamma * V), axis=1)
    a = int(A[np.argmax(q_values)])  # ties resolve to the first action (left), as before
    st, r, term = step(st, a)
    print(st)

4
5
6
