In [1]:
import numpy as np

In [6]:
probas = np.zeros((3,2,3))  # states, actions, states
probas[0, 0, 1] = 0.3
probas[0, 0, 2] = 0.7
probas[0, 1, 2] = 1
probas[1, 0, 0] = 1
probas[2, 0, 1] = 1
probas[2, 1, 2] = 1

rewards = np.zeros((3,2))   # states, actions
rewards[0, 0] = 0
rewards[0, 1] = -1
rewards[1, 0] = 2
rewards[2, 0] = 2
rewards[2, 1] = -3

gamma = 0.9
threshold = 1e-2

In [45]:
def value_iteration(P, R, gamma, theta):
    V = np.zeros(3)
    Pol = np.zeros(3, dtype=int)   # only to track policies
    while True:
        delta = 0
        for s in [0, 1, 2]:
            S = np.unique(np.argwhere(P[s,:,:] != 0)[:,1])
            v = V[s]
            V[s] = np.max([np.sum([P[s, a, s_]*(R[s, a] + gamma*V[s_]) for s_ in S]) for a in [0,1]])
            Pol[s] = np.argmax([np.sum([P[s, a, s_]*(R[s, a] + gamma*V[s_]) for s_ in S]) for a in [0,1]])
            delta = max(delta, np.abs(v-V[s]))
        if delta < theta:
            return V, Pol

In [46]:
val_values, val_actions = value_iteration(probas, rewards, gamma, threshold)
val_values, val_actions

(array([11.86962133, 12.6826592 , 13.41439328]), array([0, 0, 0]))

In [40]:
def policy_iteration(P, R, gamma, theta):
    # 1
    V = np.zeros(3)
    Pol = np.ones(3, dtype=int)

    while True:
        # 2
        while True:
            delta = 0
            for s in [0, 1, 2]:
                S = np.unique(np.argwhere(P[s,:,:] != 0)[:,1])
                v = V[s]
                V[s] = np.sum([P[s, Pol[s], s_]*(R[s, Pol[s]] + gamma*V[s_]) for s_ in S])
                Pol[s] = np.argmax([np.sum([P[s, a, s_]*(R[s, a] + gamma*V[s_]) for s_ in S]) for a in [0,1]])
                delta = max(delta, np.abs(v-V[s]))
            if delta < theta:
                break
        # 3
        policy_stable = True
        for s in [0, 1, 2]:
            old_action = Pol[s]
            Pol[s] = np.argmax([np.sum([P[s, a, s_]*(R[s, a] + gamma*V[s_]) for s_ in S]) for a in [0,1]])
            if old_action != Pol[s]:
                policy_stable = False
        if policy_stable:
            return V, Pol

In [41]:
pol_values, pol_actions = policy_iteration(probas, rewards, gamma, threshold)
pol_values, pol_actions

(array([11.86590806, 12.67931725, 13.41138553]), array([0, 0, 0]))

In [50]:
print('-------------- Value Iteration --------------')
print('Optimum Actions:')
[print(f'S{i+1}:',a+1) for i, a in enumerate(val_actions)];
print('Max Values:')
[print(f'S{i+1}:',round(v,3)) for i, v in enumerate(val_values)];
print('-------------- Policy Iteration --------------')
print('Optimum Actions:')
[print(f'S{i+1}:',a+1) for i, a in enumerate(pol_actions)];
print('Max Values:')
[print(f'S{i+1}:',round(v,3)) for i, v in enumerate(pol_values)];


-------------- Value Iteration --------------
Optimum Actions:
S1: 1
S2: 1
S3: 1
Max Values:
S1: 11.87
S2: 12.683
S3: 13.414
-------------- Policy Iteration --------------
Optimum Actions:
S1: 1
S2: 1
S3: 1
Max Values:
S1: 11.866
S2: 12.679
S3: 13.411
