In [1]:
import numpy as np

# Iterative policy evaluation for a small, hand-specified 4-state MDP.
states = ["s1", "s2", "s3", "s4"]
gamma = 0.8

# Stochastic policy pi(a | s), keyed by (state, action).
policy = {
    ("s1", "a1"): 0.1, ("s1", "a2"): 0.9,
    ("s2", "a2"): 1.0,
    ("s3", "a1"): 0.1, ("s3", "a3"): 0.9,
    ("s4", "a1"): 1.0,
}

# Dynamics: (state, action) -> list of (next_state, probability, reward).
transitions = {
    ("s1", "a1"): [("s2", 0.1, 2.0), ("s3", 0.3, 3.0), ("s4", 0.6, -1.0)],
    ("s1", "a2"): [("s3", 1.0, 5.0)],
    ("s2", "a2"): [("s1", 1.0, 3.0)],
    ("s3", "a1"): [("s1", 1.0, -3.0)],
    ("s3", "a3"): [("s3", 0.2, 1.0), ("s4", 0.8, 6.0)],
    ("s4", "a1"): [("s1", 0.6, 5.0), ("s2", 0.4, -3.0)],
}

theta = 1e-12


def _group_actions(states, policy):
    """Return {state: [actions listed for it in policy]}, in policy order."""
    grouped = {s: [] for s in states}
    for s, a in policy:
        grouped[s].append(a)
    return grouped


def _check_distributions(policy, transitions, tol=1e-9):
    """Raise if pi(.|s) or any used p(.|s, a) does not sum to 1 (within tol).

    Raises:
        KeyError: if a (state, action) pair in policy has no transitions.
        ValueError: if a probability distribution does not sum to 1.
    """
    policy_mass = {}
    for (s, _a), prob in policy.items():
        policy_mass[s] = policy_mass.get(s, 0.0) + prob
    for s, total in policy_mass.items():
        if abs(total - 1.0) > tol:
            raise ValueError(
                f"policy probabilities for state {s!r} sum to {total}, not 1"
            )
    for key in policy:
        if key not in transitions:
            raise KeyError(f"no transitions defined for {key!r}")
        total = sum(p for _s2, p, _r in transitions[key])
        if abs(total - 1.0) > tol:
            raise ValueError(
                f"transition probabilities for {key!r} sum to {total}, not 1"
            )


def evaluate_policy(states, policy, transitions, gamma, theta=1e-12,
                    max_sweeps=10**6):
    """Return the state-value function V^pi of a fixed stochastic policy.

    Repeatedly applies the Bellman expectation backup
        V(s) <- sum_a pi(a|s) * sum_{s'} p(s'|s,a) * (r + gamma * V(s'))
    to each state in ``states`` order, updating V in place so later states
    in the same sweep already see this sweep's new values (Gauss-Seidel
    style). Stops as soon as the largest per-sweep change is below theta.

    Args:
        states: ordered iterable of state labels; V starts at 0.0 for each.
        policy: {(state, action): probability}.
        transitions: {(state, action): [(next_state, probability, reward)]}.
        gamma: discount factor.
        theta: convergence threshold on the max absolute change per sweep.
        max_sweeps: upper bound on the number of full sweeps.

    Returns:
        dict mapping each state to its estimated value.

    Raises:
        KeyError: if a policy (state, action) pair has no transitions.
        ValueError: if a policy or transition distribution does not sum to 1.
        RuntimeError: if not converged within ``max_sweeps`` sweeps, instead
            of silently returning an unconverged estimate.
    """
    _check_distributions(policy, transitions)
    actions = _group_actions(states, policy)
    V = {s: 0.0 for s in states}
    delta = float("inf")  # defined even if max_sweeps == 0
    for _ in range(max_sweeps):
        delta = 0.0
        for s in states:
            v_old = V[s]
            v_new = 0.0
            for a in actions[s]:
                pi = policy[(s, a)]
                for s2, p, r in transitions[(s, a)]:
                    v_new += pi * p * (r + gamma * V[s2])
            V[s] = v_new
            delta = max(delta, abs(v_new - v_old))
        if delta < theta:
            return V
    raise RuntimeError(
        f"policy evaluation did not converge within {max_sweeps} sweeps "
        f"(last max change {delta:.3g}, theta {theta:g})"
    )


# Kept as module-level names in case other cells reference them.
actions_by_state = _group_actions(states, policy)
V = evaluate_policy(states, policy, transitions, gamma, theta)

V_round = {s: round(V[s], 3) for s in states}
for s in states:
    print(f"V({s}) =", V_round[s])
print(V_round)


V(s1) = 18.718
V(s2) = 17.975
V(s3) = 17.784
V(s4) = 16.537
{'s1': 18.718, 's2': 17.975, 's3': 17.784, 's4': 16.537}
