In [1]:
import numpy as np
import sys
if "../" not in sys.path:
    sys.path.append("../")
from lib.envs.GridEnv import Env
from lib.utils.draw import show_grid,show_policy


In [2]:
def policy_evaluation(policy, env, gamma=1.0, theta=0.001):
    """
    Iterative policy evaluation of the action-value function Q^pi.

    Sweeps every (state, action) pair and applies the Bellman expectation backup
        Q(s, a) <- r + gamma * sum_a' pi(a' | s') * Q(s', a')
    in place, until the largest change during one sweep drops below ``theta``.
    Transitions flagged ``done`` end the episode, so they contribute only their
    immediate reward (no bootstrapping from the next state).

    arg:
        policy:  policy table; policy[s][a] is the probability of taking action a in state s
        env:     environment exposing observation_size, action_size and the transition
                 table P, where P[s][a] is a sequence of (next_state, reward, done) tuples
        gamma:   discount factor
        theta:   stop once max |Q_{k+1}(s, a) - Q_k(s, a)| over one sweep is below this
    return:
        Q: float array of shape (env.observation_size, env.action_size)
    """
    Q = np.zeros((env.observation_size, env.action_size))
    while True:
        delta = 0
        for s in range(env.observation_size):
            for a in range(env.action_size):
                q = 0
                # NOTE(review): transitions carry no probability, so this assumes one
                # deterministic transition per (s, a); if P[s][a] held several entries,
                # only the last one would count -- confirm against GridEnv.
                for next_state, reward, done in env.P[s][a]:
                    if done:
                        # Episode ends here: nothing to bootstrap from.
                        future = 0.0
                    else:
                        # Expected next action value under the policy: sum_a' pi(a'|s') Q(s', a').
                        future = np.dot(policy[next_state], Q[next_state])
                    q = reward + gamma * future
                delta = max(delta, abs(q - Q[s, a]))
                # In-place update: later pairs in this sweep already see the new value.
                Q[s, a] = q
        if delta < theta:
            break
    return Q


In [3]:
def policy_iterator_action(env, gamma=1.0, theta=0.00001):
    """
    Policy iteration driven by action values Q(s, a).

    Starts from the uniform random policy and alternates
      1. policy evaluation:  Q = policy_evaluation(policy, env, gamma, theta)
      2. greedy improvement: each state gets a one-hot policy on argmax_a Q(s, a)
    until an improvement step leaves the policy unchanged.

    arg:
        env:    environment exposing observation_size and action_size (plus whatever
                policy_evaluation reads from it)
        gamma:  discount factor
        theta:  convergence threshold forwarded to policy_evaluation
    return:
        policy: array (observation_size, action_size), one-hot greedy policy
        V:      array (observation_size,), V[s] = Q(s, chosen action) from the last
                evaluation, i.e. the value of the returned policy
    """
    n_states, n_actions = env.observation_size, env.action_size
    one_hot = np.eye(n_actions)

    # Uniform random initial policy.
    policy = np.ones((n_states, n_actions)) / n_actions
    V = np.zeros(n_states)
    while True:
        Q = policy_evaluation(policy, env, gamma, theta)

        new_policy = np.zeros_like(policy)
        for s in range(n_states):
            action_value = Q[s]
            best_action = np.argmax(action_value)
            current_action = np.argmax(policy[s])
            # On an exact tie keep the current action, so equally good actions
            # cannot make the policy flip back and forth forever.
            if action_value[current_action] == action_value[best_action]:
                best_action = current_action
            V[s] = action_value[best_action]
            new_policy[s] = one_hot[best_action]

        # Stable only if the improvement step changed nothing. Comparing full rows
        # (not just argmax) matters on the first pass: the uniform start has argmax 0,
        # so an argmax comparison could stop before the greedy policy is ever evaluated.
        policy_stable = np.array_equal(new_policy, policy)
        policy = new_policy
        if policy_stable:
            return policy, V
    

In [4]:
# Build a 5x5 grid world (fixed seed) and solve it with action-value policy iteration.
# NOTE(review): meaning of `p` and `punish` comes from GridEnv -- confirm there.
env = Env((5, 5), p=0.5, seed=5, punish=-10)
policy, V = policy_iterator_action(env, gamma=0.9)

# Raw grid layout.
print("原始图像")
show_grid(env, env.grid)

# State values returned by policy iteration.
print("state value function")
show_grid(env, V)

# Greedy policy returned by policy iteration.
print("policy")
show_policy(env, policy)