In [1]:
import numpy as np     #只需要下载numpy库即可
import random

In [2]:
class GridWorld_v1(object):
    """First-draft grid world for policy iteration / value iteration experiments.

    The world has ``rows`` x ``columns`` cells. Each cell holds a reward:
    ``score`` for a target cell, ``forbiddenAreaScore`` for a forbidden cell
    and 0 otherwise. States are numbered row-major: ``row * columns + column``.

    Actions are integer indices:
        0: move upwards
        1: move rightwards
        2: move downwards
        3: move leftwards
        4: stay unchanged

    The class has no agent/episode logic; it only exposes the one-step
    reward and transition through ``getScore``.
    """

    stateMap = None  # rows x columns nested list; each entry is the state id of that cell
    scoreMap = None  # rows x columns numpy array; each entry is the reward of that cell
    score = 0             # reward of a target cell
    forbiddenAreaScore=0  # reward of a forbidden cell


    def __init__(self,rows=4, columns=5, forbiddenAreaNums=3, targetNums=1, seed = -1, score = 1, forbiddenAreaScore = -1, desc=None):
        """Build either a fixed world (from ``desc``) or a random one.

        Args:
            rows, columns: grid size of the random world (ignored if ``desc`` is given).
            forbiddenAreaNums: number of forbidden cells in the random world.
            targetNums: number of target cells in the random world.
            seed: passed to ``random.seed`` before shuffling the cell order;
                note that this reseeds the module-level ``random`` generator.
            score: reward stored in target cells.
            forbiddenAreaScore: reward stored in forbidden cells.
            desc: optional sequence of equally long rows; ``'#'`` marks a
                forbidden cell, ``'T'`` a target cell, anything else a plain cell.
                The column count is taken from the first row.
        """
        self.score = score
        self.forbiddenAreaScore = forbiddenAreaScore
        # `is not None` (instead of `!= None`) so that array-like descriptions,
        # whose `!=` is elementwise, do not break the truth test.
        if desc is not None:
            # Fixed layout described by the caller.
            self.rows = len(desc)
            self.columns = len(desc[0])
            layout = [
                [
                    forbiddenAreaScore if desc[i][j]=='#' else score if desc[i][j]=='T' else 0
                    for j in range(self.columns)
                ]
                for i in range(self.rows)
            ]
            self.scoreMap = np.array(layout)
            self.stateMap = [[i*self.columns+j for j in range(self.columns)] for i in range(self.rows)]
            return

        # Random layout.
        self.rows = rows
        self.columns = columns
        self.forbiddenAreaNums = forbiddenAreaNums
        self.targetNums = targetNums
        self.seed = seed

        random.seed(self.seed)
        l = [i for i in range(self.rows * self.columns)]
        random.shuffle(l)  # a shuffled cell order gives distinct random cells
        self.g = [0 for i in range(self.rows * self.columns)]
        for i in range(forbiddenAreaNums):
            self.g[l[i]] = forbiddenAreaScore        # first cells of the shuffle become forbidden
        for i in range(targetNums):
            self.g[l[forbiddenAreaNums+i]] = score   # the following cells become targets

        self.scoreMap = np.array(self.g).reshape(rows,columns)
        self.stateMap = [[i*self.columns+j for j in range(self.columns)] for i in range(self.rows)]

    def show(self):
        """Print the grid, one text line per row, one emoji per cell."""
        # Built once per call; on key collisions the later entry wins, as before.
        glyphs = {0:"⬜️",self.forbiddenAreaScore:"🚫",self.score:"✅"}
        for i in range(self.rows):
            print("".join(glyphs[self.scoreMap[i][j]] for j in range(self.columns)))

    def getScore(self, nowState, action):
        """Return ``(reward, nextState)`` for taking ``action`` in ``nowState``.

        A move that would leave the grid yields reward -1 and leaves the
        state unchanged; otherwise the reward is the destination cell's entry
        in ``scoreMap`` and the next state is that cell's id.

        Raises:
            ValueError: if ``nowState`` is not a valid state id or ``action``
                is not in ``0..4``.
        """
        nowx = nowState // self.columns
        nowy = nowState % self.columns

        # Fail loudly: continuing with a bad state/action would either wrap
        # around via negative indexing or read an unrelated cell.
        if(nowx<0 or nowy<0 or nowx>=self.rows or nowy>=self.columns):
            raise ValueError(f"coordinate error: ({nowx},{nowy})")
        if(action<0 or action>=5 ):
            raise ValueError(f"action error: ({action})")

        # up, right, down, left, stay
        actionList = [(-1,0),(0,1),(1,0),(0,-1),(0,0)]
        tmpx = nowx + actionList[action][0]
        tmpy = nowy + actionList[action][1]
        if(tmpx<0 or tmpy<0 or tmpx>=self.rows or tmpy>=self.columns):
            return -1,nowState
        return self.scoreMap[tmpx][tmpy],self.stateMap[tmpx][tmpy]

In [3]:
def showPolicy(gridworld : "GridWorld_v1", policy):
    """Print a policy as an emoji grid, one text line per grid row.

    Args:
        gridworld: object exposing ``rows``, ``columns``, ``scoreMap``,
            ``score`` and ``forbiddenAreaScore``.
        policy: sequence indexed by state id whose entries are action
            indices 0..4 (up, right, down, left, stay).

    Target cells are drawn as a check mark, plain cells (reward 0) with a
    single arrow and forbidden cells with a double arrow. Exactly one glyph
    is emitted per cell: if reward values coincide, the first matching
    category in the order target, plain, forbidden wins, so rows stay aligned.
    """
    columns = gridworld.columns
    plainArrows = {0:"⬆️",1:"➡️",2:"⬇️",3:"⬅️",4:"🔄"}
    forbiddenArrows = {0:"⏫️",1:"⏩️",2:"⏬",3:"⏪",4:"🔄"}
    s = ""
    for i in range(gridworld.rows * columns):
        nowx, nowy = divmod(i, columns)
        cell = gridworld.scoreMap[nowx][nowy]
        if cell == gridworld.score:
            s = s + "✅"
        elif cell == 0:
            s = s + plainArrows[policy[i]]
        elif cell == gridworld.forbiddenAreaScore:
            s = s + forbiddenArrows[policy[i]]
        if nowy == columns-1:
            # End of a grid row: flush the accumulated line.
            print(s)
            s = ""

In [4]:
gamma = 0.9   # discount factor; the closer to 0, the more short-sighted the agent

# Alternative worlds to experiment with:
# gridworld = GridWorld_v1(rows=5, columns=5, forbiddenAreaNums=4, targetNums=2, seed = random.randint(1,1000))
# gridworld = GridWorld_v1(desc = [".#",".T"])             # Teacher Zhao's example 4-1
# gridworld = GridWorld_v1(desc = ["##.T","...#","...."])  # an arbitrary example
gridworld = GridWorld_v1(forbiddenAreaScore=-10, score=1,desc = [".....",".##..","..#..",".#T#.",".#..."])
gridworld.show()

# Read the dimensions back from the world instead of hard-coding them, so
# they can never drift out of sync with the layout chosen above.
rows = gridworld.rows
columns = gridworld.columns

value = np.zeros(rows*columns)       # state values; any initialisation works, zeros included
qtable = np.zeros((rows*columns,5))  # one column per action; only the shape matters, contents get overwritten
policy = np.argmax(qtable,axis=1)    # initial policy (argmax of an all-zero table picks action 0 everywhere)
showPolicy(gridworld,policy)

⬜️⬜️⬜️⬜️⬜️
⬜️🚫🚫⬜️⬜️
⬜️⬜️🚫⬜️⬜️
⬜️🚫✅🚫⬜️
⬜️🚫⬜️⬜️⬜️
⬆️⬆️⬆️⬆️⬆️
⬆️⏫︎⏫︎⬆️⬆️
⬆️⬆️⏫︎⬆️⬆️
⬆️⏫︎✅⏫︎⬆️
⬆️⏫︎⬆️⬆️⬆️


In [5]:
# Draw a random deterministic policy: one action index in [0, 5) per state.
# Strictly speaking a policy may be stochastic (e.g. 0.5 up, 0.5 right);
# that case is left out for now because it is awkward to draw, and the
# implementation logic would be much the same.
policy = np.random.randint(low=0, high=5, size=rows * columns)

In [6]:
# Solve the Bellman equation by truncated policy iteration: alternate a
# bounded number of evaluation sweeps with a greedy improvement step.
gridworld.show()                     # print the grid world
showPolicy(gridworld, policy)        # print the starting policy
print("random policy")

value = np.zeros(rows*columns)
value_pre = value.copy() + 1         # offset so the outer test is true on entry

cnt = 0
while ((value_pre - value)**2).sum() > 0.001:
    # ---- policy evaluation ----
    value_pre = value.copy()         # snapshot used to test convergence of the whole iteration

    # Start the inner comparison from an arbitrary different vector; the
    # original note says a fixed or a random start converges to the same result.
    value0 = value.copy() + 1
    # Sweep budget per evaluation phase. Outer-iteration counts noted by the
    # author for various budgets: 1 -> 50, 2 -> 26, 3 -> 18, 4 -> 14, 10 -> 6, 100 -> 2.
    truncatedCnt = 10
    while ((value0 - value)**2).sum() > 0.001:
        value0 = value.copy()

        # Capping the number of Bellman sweeps is what makes this
        # *truncated* policy iteration.
        truncatedCnt -= 1
        if truncatedCnt < 0:
            break

        # One sweep under the current policy: the policy picks its action
        # with probability one, so there is no need to loop over all 5 actions.
        for i in range(rows * columns):
            score, nextState = gridworld.getScore(i, policy[i])   # reward and next state id
            value[i] = score + value0[nextState] * gamma          # Bellman update

    # ---- policy improvement ----
    for i in range(rows * columns):
        for j in range(5):           # 5 actions
            score, nextState = gridworld.getScore(i, j)       # (reward, next state) for action j in state i
            qtable[i, j] = score + gamma * value[nextState]
    policy = qtable.argmax(axis=1)   # greedy policy with respect to the new action values

    # Progress report for this outer iteration.
    showPolicy(gridworld, policy)
    print(value.reshape(rows, columns))
    cnt += 1
    print(cnt)

    

⬜️⬜️⬜️⬜️⬜️
⬜️🚫🚫⬜️⬜️
⬜️⬜️🚫⬜️⬜️
⬜️🚫✅🚫⬜️
⬜️🚫⬜️⬜️⬜️
⬇️➡️⬅️⬇️🔄
⬇️⏩️⏬⬆️🔄
⬆️⬆️⏪⬆️⬇️
⬆️🔄✅🔄⬆️
⬇️⏬➡️➡️➡️
random policy
➡️➡️➡️➡️⬇️
⬆️⏫︎⏫︎⬆️⬆️
⬆️⬅️⏩️⬆️⬆️
⬆️⏪✅⏫︎⬆️
⬆️⏩️🔄⬅️⬆️
[[  0.           0.           0.           0.           0.        ]
 [  0.         -51.71774599 -46.3530511    0.           0.        ]
 [  0.         -53.05918699 -44.26648389   0.           0.        ]
 [  0.         -65.13215599 -65.13215599 -65.13215599   0.        ]
 [ -6.5132156   -6.5132156   -4.6132156   -5.5132156   -6.5132156 ]]
1
➡️➡️➡️➡️⬇️
⬆️⏫︎⏫︎⬆️⬆️
⬆️⬅️⏩️⬆️⬆️
⬆️⏫︎✅⏫︎⬆️
⬆️⏪⬆️➡️⬆️
[[ 0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.        ]
 [ 0.          0.         -1.60852882  0.          0.        ]
 [ 0.         -1.60852882 -1.60852882 -1.60852882  0.        ]]
2
➡️➡️➡️➡️⬇️
⬆️⏫︎⏫︎⬆️⬆️
⬆️⬅️⏬⬆️⬆️
⬆️⏩️✅⏪⬆️
⬆️⏩️⬆️⬅️⬆️
[[0.         0.         0.         0.         0.        ]
 [0.         0.

In [7]:
# Smoke test: confirm that the glyphs used by show() / showPolicy() render
# in this environment (grid cells, single arrows, double arrows).
print(
    '⬜️🚫',
    '⬜️✅',
    '⬆️➡️⬇️⬅️🔄',
    '⏫︎⏩️⏬⏪🔄✅',
    sep="\n",
)

tmp = "⏫︎⏩️⏬⏪🔄"

⬜️🚫
⬜️✅
⬆️➡️⬇️⬅️🔄
⏫︎⏩️⏬⏪🔄✅
