In [1]:
import numpy as np
import matplotlib.pyplot as plt

In [2]:
class MDP():
    """Randomly generated finite MDP with a small gym-like interface.

    Attributes:
        n_s: number of states.
        n_a: number of actions.
        p: array of shape (n_s, n_a, n_s); p[s, a, s'] is the probability of
            moving from state s to state s' under action a.
        r: array of shape (n_s, n_a); r[s, a] is the reward for taking
            action a in state s.
        state_space: np.arange(n_s).
        action_space: np.arange(n_a).
        state: current state index, or None until reset() is called.
        reward: reward produced by the most recent step(), or None.

    All randomness comes from the global np.random state, so seed it before
    constructing the object if reproducible instances are needed.
    """

    def __init__(self, n_s, n_a):
        self.n_a = n_a
        self.n_s = n_s
        # Draw order (transitions first, then rewards) is kept stable so a
        # given global seed keeps producing the same draws.
        self.p = self.trasition_prob_matrix()
        self.r = self.reward_distribution(0,1)
        self.state_space = np.arange(n_s)
        self.action_space = np.arange(n_a)
        self.state = None
        self.reward = None

    def trasition_prob_matrix(self):
        """Build a random transition tensor of shape (n_s, n_a, n_s).

        For every (state, action) pair the next-state distribution is a
        mixture of k Gaussian-shaped bumps with random weights, evaluated on
        an even grid over [0, 1) with one grid point per state. Each row is
        normalised, floored at 0.001 and normalised again, so every
        transition has strictly positive probability.
        """
        k = 4  # number of bumps in each mixture
        pts = np.random.uniform(1,1000,(self.n_s,self.n_a,k))
        # Integer arange divided by n_s always has exactly n_s points; a float
        # step (np.arange(0, 1, 1/n_s)) can produce n_s + 1, e.g. for n_s = 49.
        grid = np.arange(self.n_s) / self.n_s
        centers = np.arange(k) + 0.5
        # bumps[x, j] = exp(-(k * grid[x] - centers[j]) ** 2)
        bumps = np.exp(-(k * grid[:, None] - centers[None, :]) ** 2)
        # Weighted sum over the k bumps for every (s, a) at once.
        tpm = pts @ bumps.T
        tpm /= tpm.sum(axis=-1, keepdims=True)
        tpm = np.clip(tpm, 0.001, 1)
        tpm /= tpm.sum(axis=-1, keepdims=True)
        return tpm

    def reward_distribution(self, r_min, r_max):
        """Return an (n_s, n_a) array of rewards drawn uniformly from [r_min, r_max)."""
        return np.random.uniform(r_min, r_max, size=(self.n_s, self.n_a))

    def reset(self):
        """Set the current state to a uniformly random state and return it."""
        self.state = np.random.randint(self.n_s)
        return self.state

    def step(self, action):
        """Take `action` in the current state.

        Returns:
            (next_state, reward), where reward is the scalar r[s, action] for
            the state s the action was taken in.

        Raises:
            ValueError: if reset() has not been called yet.
        """
        if self.state is None:
            raise ValueError("MDP.step() called before MDP.reset()")
        s = self.state
        self.state = np.random.choice(self.state_space,p=self.p[s,action])
        # Index by the action as well; r[s] alone is the reward row for all actions.
        self.reward = self.r[s, action]
        return self.state, self.reward

    def sample_action(self):
        """Return a uniformly random action index in [0, n_a)."""
        # Use the instance's own action count, not a notebook-level global.
        return np.random.randint(self.n_a)

In [3]:
# Problem size: number of states and number of actions.
n_s, n_a = 10, 8
# One random MDP instance of that size.
M = MDP(n_s=n_s, n_a=n_a)

In [4]:
def value_iteration(MDP, tol=1e-10, max_iter=100000):
    """Relative value iteration for an undiscounted (average-reward) MDP.

    Each sweep applies the Bellman optimality backup to the current relative
    values H and then subtracts the backed-up value of the reference state
    (state 0), which keeps the iterates bounded.

    Args:
        MDP: object exposing n_s, n_a, p with shape (n_s, n_a, n_s) and
            r with shape (n_s, n_a).
        tol: the sweep loop stops once every entry of H changes by less than
            this between consecutive sweeps.
        max_iter: maximum number of sweeps.

    Returns:
        (policy, gain): policy is an int32 array holding the greedy action
        for each state; gain is the backed-up value of the reference state
        on the final sweep.

    Raises:
        RuntimeError: if H has not converged within max_iter sweeps (this
            happens, for example, on periodic chains).
    """
    p = np.asarray(MDP.p, dtype=float)
    r = np.asarray(MDP.r, dtype=float)
    k = 0  # reference state whose value is subtracted every sweep
    H = np.zeros(MDP.n_s)
    for _ in range(max_iter):
        # Q[s, a] = r[s, a] + sum_s' p[s, a, s'] * H[s']
        Q = r + p @ H
        V = Q.max(axis=1)
        # Normalise after the whole sweep so the result does not depend on
        # the reference state being the first one visited.
        H_new = V - V[k]
        if np.all(np.abs(H_new - H) < tol):
            # argmax returns the first maximiser, the same tie-break as the
            # dict-based max used previously.
            policy = Q.argmax(axis=1).astype(np.int32)
            return policy, V[k]
        H = H_new
    raise RuntimeError(
        "value_iteration did not converge within {} sweeps".format(max_iter))

In [5]:
# Solve the sample MDP; `value` is the value returned for the reference state (state 0).
policy, value = value_iteration(M)

In [6]:
# Show the action chosen for each state, followed by the returned value.
print(policy, value)

[4 3 2 0 5 1 7 7 3 5] 0.874112417387294


In [7]:
def policy_evaluation_iterative(MDP, policy, tol=1e-10, max_iter=100000):
    """Evaluate a fixed policy under the average-reward criterion.

    Repeats the backup V = r_pi + P_pi @ H followed by the normalisation
    H = V - V[0] until H stops changing.

    Args:
        MDP: object exposing n_s, p with shape (n_s, n_a, n_s) and r with
            shape (n_s, n_a).
        policy: integer sequence of length n_s giving the action per state.
        tol: the sweep loop stops once every entry of H changes by less than
            this between consecutive sweeps.
        max_iter: maximum number of sweeps.

    Returns:
        (gain, H): gain is the backed-up value of the reference state
        (state 0) on the final sweep; H is the array of values relative to
        that state, so H[0] is 0.

    Raises:
        RuntimeError: if H has not converged within max_iter sweeps (this
            happens, for example, on periodic chains).
    """
    p = np.asarray(MDP.p, dtype=float)
    r = np.asarray(MDP.r, dtype=float)
    policy = np.asarray(policy)
    states = np.arange(MDP.n_s)
    # Transition matrix and reward vector induced by the policy.
    p_pi = p[states, policy]
    r_pi = r[states, policy]
    k = 0  # reference state whose value is subtracted every sweep
    H = np.zeros(MDP.n_s)
    for _ in range(max_iter):
        V = r_pi + p_pi @ H
        # Normalise after the whole sweep so the result does not depend on
        # the reference state being the first one visited.
        H_new = V - V[k]
        converged = np.all(np.abs(H_new - H) < tol)
        H = H_new
        if converged:
            return V[k], H
    raise RuntimeError(
        "policy_evaluation_iterative did not converge within {} sweeps".format(max_iter))

In [8]:
# Evaluate the policy found above; `h` holds the per-state values relative to state 0.
value, h = policy_evaluation_iterative(M, policy)

In [9]:
# Returned value for the reference state, then the relative values (h[0] is 0 by construction).
print(value)
print(h)

0.874112417387294
[0.         0.57424818 0.31864832 0.59085252 0.42420043 0.5770066
 0.56172762 0.4097849  0.49208389 0.51953923]


In [10]:
def policy_iteration(MDP, improve_tol=1e-9):
    """Policy iteration for an undiscounted (average-reward) MDP.

    Starts from a random policy and alternates evaluation (via
    policy_evaluation_iterative) with greedy improvement on
    Q[s, a] = r[s, a] + sum_s' p[s, a, s'] * H[s'] until no state can be
    improved.

    Args:
        MDP: object exposing n_s, n_a, p with shape (n_s, n_a, n_s) and
            r with shape (n_s, n_a).
        improve_tol: an action replaces the current one only if its Q-value
            is larger by more than this margin. Keeping the current action
            on near-ties avoids flipping between equally good policies
            because of evaluation round-off.

    Returns:
        (policy, gain): the final policy as an integer array of length n_s,
        and the first value returned by policy_evaluation_iterative for it.
    """
    p = np.asarray(MDP.p, dtype=float)
    r = np.asarray(MDP.r, dtype=float)
    states = np.arange(MDP.n_s)
    pol = np.random.randint(MDP.n_a, size=MDP.n_s)
    while True:
        V, H = policy_evaluation_iterative(MDP, pol)
        Q = r + p @ np.asarray(H, dtype=float)
        greedy = Q.argmax(axis=1)
        improves = Q[states, greedy] > Q[states, pol] + improve_tol
        # A policy that cannot be improved anywhere is the stopping point;
        # this replaces exact float equality on the gain and a comparison of
        # H against itself.
        if not improves.any():
            break
        pol = np.where(improves, greedy, pol)
    # Return the policy computed here, not a same-named notebook global.
    return pol, V

In [11]:
# Solve the same MDP with policy iteration; this rebinds the notebook-level `policy` and `value`.
policy, value = policy_iteration(M)

In [12]:
# Show the result for comparison with the value-iteration output printed earlier.
print(policy, value)

[4 3 2 0 5 1 7 7 3 5] 0.874112417387294
