In [1]:
import numpy as np
import copy

# MDP setup: three states (0, 1, 2) and two actions.
#   action 0: from state i, move to (i + 1) % 3 with probability p[i],
#             otherwise move to (i + 2) % 3.
#   action 1: stay in state i.
p = [0.8, 0.5, 1.0]

# Discount factor.
gamma = 0.95

# Expected rewards: r[s, s_next, a] is the expected reward for the
# transition s -> s_next taken under action a.
r = np.zeros((3, 3, 2))
r[0, 1, 0] = 1.0
r[0, 2, 0] = 2.0
r[0, 0, 1] = 0.0
r[1, 0, 0] = 1.0
r[1, 2, 0] = 2.0
r[1, 1, 1] = 1.0
r[2, 0, 0] = 1.0
r[2, 1, 0] = 0.0
r[2, 2, 1] = -1.0

# Convergence settings: stop once the largest per-state change of the value
# function falls below `tol`, or after `max_steps` sweeps.  Because the
# Bellman optimality operator is a gamma-contraction in the sup norm, the
# remaining error is at most gamma / (1 - gamma) * tol.
tol = 1e-10
max_steps = 1000

# Value function initialization.
v = np.zeros(3)

# Action-value function initialization: q[s, a].
q = np.zeros((3, 2))

# Policy initialization: pi[s] is the probability of choosing action 0 in s.
pi = [0.5, 0.5, 0.5]

# Value iteration.
for step in range(max_steps):

    for i in range(3):
        nxt, alt = (i + 1) % 3, (i + 2) % 3

        # Action values under the current value estimate v.
        q[i, 0] = p[i] * (r[i, nxt, 0] + gamma * v[nxt]) + (1 - p[i]) * (
            r[i, alt, 0] + gamma * v[alt]
        )
        q[i, 1] = r[i, i, 1] + gamma * v[i]

        # Greedy policy with respect to q (an exact tie splits 50/50).
        if q[i, 0] > q[i, 1]:
            pi[i] = 1
        elif q[i, 0] == q[i, 1]:
            pi[i] = 0.5
        else:
            pi[i] = 0

    # Bellman optimality backup: value of the greedy action in each state.
    # np.max returns a fresh array, so no defensive copy is needed.
    v_new = np.max(q, axis=-1)

    # Sup-norm change between successive estimates.  Unlike testing whether
    # every state "improved", this works whether the values rise or fall.
    delta = np.max(np.abs(v_new - v))

    # Update the value function so it reflects the latest backup.
    v = v_new

    # Show the value function and policy for the current step.
    print('step:', step, ' value:', v, ' policy:', pi)

    if delta < tol:
        break
else:
    # Reached only when the loop exhausts max_steps without breaking.
    print('value iteration did not converge within', max_steps, 'steps')


step: 0  value: [1.2 1.5 1. ]  policy: [1, 1, 1]
step: 1  value: [2.53  2.545 2.14 ]  policy: [1, 1, 1]
step: 2  value: [3.5408  3.71825 3.4035 ]  policy: [1, 1, 1]
step: 3  value: [4.672535  4.7985425 4.36376  ]  policy: [1, 1, 1]
step: 4  value: [5.6760067  5.79224012 5.43890825]  policy: [1, 1, 1]
step: 5  value: [6.63549506 6.7795846  6.39220636]  policy: [1, 1, 1]
step: 6  value: [7.56700351 7.68815818 7.30372031]  policy: [1, 1, 1]
step: 7  value: [8.43070707 8.56359381 8.18865333]  policy: [1, 1, 1]
step: 8  value: [9.26417543 9.39419619 9.00917172]  policy: [1, 1, 1]
step: 9  value: [10.05133173 10.1798399   9.80096666]  policy: [1, 1, 1]
step: 10  value: [10.79886199 10.92984174 10.54876515]  policy: [1, 1, 1]
step: 11  value: [11.5109451  11.64012289 11.25891889]  policy: [1, 1, 1]
step: 12  value: [12.18568798 12.31568539 11.93539784]  policy: [1, 1, 1]
step: 13  value: [12.82764649 12.95751577 12.57640358]  policy: [1, 1, 1]
step: 14  value: [13.43722866 13.56692378 13.1862