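### 问题描述

<img src="./gridworld.png" width=40%>

4×4 网格世界：状态按行从左到右编号为 0–15，其中 0 和 15 为终止状态；动作为左（l）、右（r）、上（u）、下（d）。在非终止状态每走一步奖励为 −1，移动出界时停在原地；折扣因子 γ = 1。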

In [67]:
# Model of the gridworld.
LENGTH, WIDTH = 4, 4  # LENGTH = horizontal size (columns), WIDTH = vertical size (rows)
# Terminal states are the first and last cells; derived from the grid size so
# they stay correct if LENGTH / WIDTH are changed.
TERMINAL = [0, LENGTH * WIDTH - 1]
# States are numbered row by row: state s is at row s // LENGTH, column s % LENGTH.
S = list(range(LENGTH * WIDTH))
A = ["l", "r", "u", "d"]  # left, right, up, down
GAMMA = 1.0  # discount factor (1.0 = no discounting)

# Offset each action adds to the state number (a row is LENGTH states wide).
ds_actions = {"l": -1, "r": 1, "u": -LENGTH, "d": LENGTH}

def dynamics(s, a):  # environment dynamics
    """Apply action ``a`` in state ``s`` and return ``(s_prime, reward, is_end)``.

    - ``reward`` is 0 if ``s`` is already terminal, otherwise -1.
    - A terminal state is absorbing: ``s_prime == s``.
    - A move that would leave the grid keeps the agent where it is.
    - ``is_end`` is True when ``s_prime`` is a terminal state.

    Reads the module-level LENGTH, WIDTH, TERMINAL and ds_actions.
    """
    reward = 0 if s in TERMINAL else -1
    col = s % LENGTH
    blocked = (
        s in TERMINAL
        or (a == "u" and s < LENGTH)                 # already in the top row
        or (a == "d" and s >= (WIDTH - 1) * LENGTH)  # already in the bottom row
        or (a == "l" and col == 0)                   # already in the left column
        or (a == "r" and col == LENGTH - 1)          # already in the right column
    )
    s_prime = s if blocked else s + ds_actions[a]
    return s_prime, reward, s_prime in TERMINAL

# Sanity check: print the (s_prime, reward, is_end) outcome of every action from state 12.
for action in A:
    print(dynamics(12, action))

def P(s, a, s_prime):  # state-transition function
    """Return True iff taking action ``a`` in state ``s`` leads to ``s_prime``.

    Because ``dynamics`` is deterministic, this is p(s_prime | s, a) expressed
    as a bool (True / False act as 1 / 0 in arithmetic).
    """
    next_state, _reward, _is_end = dynamics(s, a)
    return next_state == s_prime

# print(1*P(4, "u", 0))

def R(s, a):  # reward function
    """Return the immediate reward for taking action ``a`` in state ``s``."""
    _next_state, reward, _is_end = dynamics(s, a)
    return reward

(12, -1, False)
(13, -1, False)
(8, -1, False)
(12, -1, False)


In [68]:
# Policy definitions.
def uniform_pi(s=None, a=None, V=None):  # uniform random policy
    """Probability of choosing action ``a`` in state ``s`` under the uniform random policy.

    Every action in ``A`` is equally likely, so the result is ``1/len(A)``
    regardless of ``s``, ``a`` and ``V``. Those parameters exist so the
    signature matches ``greedy_pi`` and both can be called as ``pi(s, a, V)``.
    Returns 0.0 when ``A`` is empty.
    """
    n = len(A)
    # Always return a float, consistent with greedy_pi.
    return 1.0 / n if n else 0.0

def greedy_pi(s, a, V):  # greedy policy
    """Probability of choosing action ``a`` in state ``s`` under the policy greedy w.r.t. ``V``.

    For every action the one-step lookahead ``reward + GAMMA * V[s_prime]``
    is computed once. The probability mass is split evenly among all actions
    that attain the maximum: ``a`` gets ``1/k`` if it is one of the ``k`` best
    actions and 0.0 otherwise. Returns 0.0 when ``A`` is empty.

    NOTE(review): ties are detected with exact float equality, so values that
    differ only by rounding noise are not treated as ties.
    """
    q_values = []
    for a_opt in A:
        s_prime, reward, _ = dynamics(s, a_opt)
        q_values.append((a_opt, reward + GAMMA * V[s_prime]))
    if not q_values:
        return 0.0
    max_v = max(q for _, q in q_values)
    a_max_v = [a_opt for a_opt, q in q_values if q == max_v]
    return 1.0 / len(a_max_v) if a in a_max_v else 0.0

def get_pi(pi, s, a, V):  # common entry point for both policies
    """Return pi(a | s) from the policy callable ``pi``.

    ``V`` is forwarded unchanged so that value-dependent policies (such as
    ``greedy_pi``) and value-independent ones (``uniform_pi``) share one call form.
    """
    probability = pi(s, a, V)
    return probability

In [72]:
# Policy iteration.
def compute_q(s, a, V):
    """Return the action value q(s, a) = reward + GAMMA * V[s_prime].

    ``V`` is indexed by state number. ``dynamics`` is deterministic, so it is
    called once and both the reward and the successor state are taken from
    that single result. This is the same value as R(s, a) + GAMMA * V[s_prime].
    """
    s_prime, reward, _ = dynamics(s, a)
    return reward + GAMMA * V[s_prime]

# V = [0 for _ in range(LENGTH*WIDTH)]
# print(compute_q(4, "l", V))

def compute_v(pi, s, V):  # one Bellman expectation backup for state s
    """Return v(s) = sum over actions of pi(a | s) * q(s, a), using the current ``V``."""
    return sum(get_pi(pi, s, action, V) * compute_q(s, action, V) for action in A)

def update_V(pi, V):
    """Perform one in-place policy-evaluation sweep over every state in ``S``.

    Each new value is written into ``V`` immediately. Later states in the same
    sweep therefore already see the updated values of earlier states
    (in-place, Gauss-Seidel-style update).
    """
    for state in S:
        new_value = compute_v(pi, state, V)
        V[state] = new_value

def policy_evaluate(pi, n, V):
    """Run ``n`` in-place evaluation sweeps of policy ``pi``, mutating ``V``; returns None."""
    for _sweep in range(n):
        update_V(pi, V)
        
def policy_iterate(pi, m, n, V):  # generalized policy iteration (GPI)
    """Alternate ``m`` rounds of policy evaluation (``n`` sweeps each) with greedy improvement.

    The first round evaluates the supplied policy ``pi``. Every later round
    evaluates ``greedy_pi``, which recomputes its greedy choice from the
    current ``V`` on each call, so the improvement step is implicit.
    ``V`` is updated in place, and the same list object is returned.
    """
    current_policy = pi
    for _round in range(m):
        policy_evaluate(current_policy, n, V)
        current_policy = greedy_pi
    return V

In [73]:
# Value iteration.
def compute_v_from_max_q(s, V):
    """Return the Bellman optimality backup max over actions of q(s, a), using the current ``V``.

    Returns -inf when ``A`` is empty, as the original hand-written loop did.
    """
    return max((compute_q(s, a, V) for a in A), default=float("-inf"))

def update_V_without_pi(V):
    """One in-place value-iteration sweep: set V[s] to max over actions of q(s, a) for each s in ``S``."""
    for state in S:
        best_value = compute_v_from_max_q(state, V)
        V[state] = best_value

def value_iterate(n, V):
    """Run ``n`` in-place value-iteration sweeps, mutating ``V``; returns None."""
    for _sweep in range(n):
        update_V_without_pi(V)

In [76]:
# Verification.
def display_V(V):
    """Print ``V`` as a grid of WIDTH rows by LENGTH columns, followed by a blank line.

    Each value is right-aligned in a 6-character field with two decimals.
    Values in a row are separated by one space, and each line ends with a
    trailing space before the newline.
    """
    for row_start in range(0, LENGTH * WIDTH, LENGTH):
        row = ['{0:>6.2f}'.format(V[i]) for i in range(row_start, row_start + LENGTH)]
        print(" ".join(row), end=" \n")
    print()
        
# Evaluate the uniform random policy with 100 in-place sweeps from an all-zero V.
V = [0] * (LENGTH * WIDTH)
policy_evaluate(uniform_pi, 100, V)
display_V(V)

# Run 4 in-place value-iteration sweeps from a fresh all-zero V.
V = [0] * (LENGTH * WIDTH)
value_iterate(4, V)
display_V(V)

  0.00 -14.00 -20.00 -22.00 
-14.00 -18.00 -20.00 -20.00 
-20.00 -20.00 -18.00 -14.00 
-22.00 -20.00 -14.00   0.00 

  0.00  -1.00  -2.00  -3.00 
 -1.00  -2.00  -3.00  -2.00 
 -2.00  -3.00  -2.00  -1.00 
 -3.00  -2.00  -1.00   0.00 

