In [96]:
import numpy as np
# Grid-world layout, one string per row:
#   '*' = free cell, '#' = forbidden area, 'x' = target area.
# Each row string is split into single characters so the result is a
# 2-D array of one-character cells.
env = np.array([list(row) for row in (
    "*****",
    "*##**",
    "**#**",
    "*#x#*",
    "*#***",
)])

class mc_epsilon_greedy:
    """Every-visit Monte Carlo control with an epsilon-greedy policy on a grid world.

    The grid is a 2-D array of single-character cells: ``'#'`` marks a
    forbidden cell, ``'x'`` the target cell, and any other character a free
    cell.  Actions are indexed 0-4 as up, right, down, left, stay.

    Rewards (see ``next_state``):
        * +10 for moving into (or staying in) the target cell,
        * -1 for trying to move off the grid (the agent stays put),
        * -10 for moving into a forbidden cell,
        * 0 otherwise.

    The target is not terminal: every episode runs for exactly ``step``
    transitions.
    """

    # Row / column offsets for each action: up, right, down, left, stay.
    _DX = (-1, 0, 1, 0, 0)
    _DY = (0, 1, 0, -1, 0)
    # Display symbol for each action, in the same order as the offsets above.
    _ACTION_SYMBOLS = "↑→↓←O"

    def __init__(self, env, _lambda, k, epsilon=0.1, step=100, threshold=1e-3):
        """
        Args:
            env: 2-D array of cell characters; must expose ``.shape`` and
                support element-wise comparison with ``'#'`` / ``'x'``.
            _lambda: discount rate applied to future rewards.
            k: number of episodes generated by ``mc_update``.
            epsilon: exploration probability of the epsilon-greedy policy.
            step: number of transitions in each episode.
            threshold: convergence threshold.  Stored but currently not used
                by any method; kept for interface compatibility.
        """
        self.env = env
        self._lambda = _lambda  # discount rate
        self.k = k  # number of episodes
        self.epsilon = epsilon  # exploration probability
        self.step = step  # transitions per episode
        self.threshold = threshold  # currently unused

        self.m, self.n = self.env.shape
        self.action_num = 5
        self.v = np.zeros((self.m, self.n))  # state value: max of q over actions
        self.q = np.zeros((self.m, self.n, self.action_num))  # action-value estimates
        self.policy = np.zeros((self.m, self.n), dtype=int)  # greedy action per cell
        self.returns = np.zeros((self.m, self.n, self.action_num))  # sum of observed returns
        self.num = np.zeros((self.m, self.n, self.action_num))  # count of observed returns

    def next_state(self, x, y, a):
        """Apply action ``a`` in cell ``(x, y)``.

        Returns:
            ``(x_next, y_next, reward)``.  A move off the grid keeps the agent
            in place with reward -1; entering the target gives +10, entering
            a forbidden cell gives -10, any other move gives 0.
        """
        x_next = x + self._DX[a]
        y_next = y + self._DY[a]
        # Moves off the grid are blocked: the agent stays where it is.
        if not (0 <= x_next < self.m and 0 <= y_next < self.n):
            return x, y, -1
        cell = self.env[x_next, y_next]
        if cell == 'x':
            return x_next, y_next, 10
        if cell == '#':
            return x_next, y_next, -10
        return x_next, y_next, 0

    def generate_episode(self):
        """Roll out one episode of ``self.step`` transitions under epsilon-greedy.

        The start cell is drawn uniformly among cells that are neither
        forbidden (``'#'``) nor the target (``'x'``).

        Returns:
            list of ``(x, y, a, reward)`` tuples, one per transition.

        Raises:
            ValueError: if the grid has no valid start cell; rejection
                sampling could never terminate in that case.
        """
        valid_start = (self.env != '#') & (self.env != 'x')
        # Checked up front (no RNG draws) so seeded runs are unaffected.
        if not np.any(valid_start):
            raise ValueError(
                "environment has no valid start cell (every cell is '#' or 'x')"
            )
        # Rejection-sample the start cell.
        while True:
            x = np.random.randint(0, self.m)
            y = np.random.randint(0, self.n)
            if valid_start[x, y]:
                break

        episode = []
        for _ in range(self.step):
            # With probability epsilon pick any action uniformly (the greedy
            # one included); otherwise act greedily on q (ties -> lowest index).
            if np.random.rand() < self.epsilon:
                a = np.random.randint(0, self.action_num)
            else:
                a = np.argmax(self.q[x, y])
            x_next, y_next, reward = self.next_state(x, y, a)
            episode.append((x, y, a, reward))
            x, y = x_next, y_next
            # The target is deliberately non-terminal: the episode always
            # runs for the full `step` transitions.
        return episode

    def mc_update(self):
        """Run ``self.k`` episodes of every-visit Monte Carlo control.

        For each visited (state, action) pair the discounted return is added
        to a running sum and ``q`` is set to the sample mean.  After every
        episode ``policy`` (argmax of ``q``) and ``v`` (max of ``q``) are
        refreshed for all cells.
        """
        for _ in range(self.k):
            episode = self.generate_episode()
            g = 0
            # Walk backwards so g accumulates G_t = r_t + _lambda * G_{t+1}.
            for x, y, a, r in reversed(episode):
                g = self._lambda * g + r
                self.returns[x, y, a] += g
                self.num[x, y, a] += 1
                self.q[x, y, a] = self.returns[x, y, a] / self.num[x, y, a]
            # Greedy improvement and state values for the whole grid at once;
            # in-place assignment keeps the existing arrays and their dtypes.
            self.policy[...] = np.argmax(self.q, axis=2)
            self.v[...] = np.max(self.q, axis=2)

    def show_policy(self):
        """Print the greedy action of every cell as a symbol grid, then ``v``."""
        for x in range(self.m):
            for y in range(self.n):
                print(self._ACTION_SYMBOLS[self.policy[x, y]], end=" ")
            print(" ")

        print(self.v)
        
if __name__ == "__main__":
    # Hyper-parameters for the demo run, passed to the agent as keywords.
    config = dict(
        _lambda=0.9,
        k=100,
        epsilon=0.1,
        step=1000,
        threshold=1e-3,
    )
    mc = mc_epsilon_greedy(env, **config)
    # Learn q / policy via Monte Carlo control, then print the result.
    mc.mc_update()
    mc.show_policy()
    

↓ ← → ↓ ←  
↓ ↑ ↑ ↓ ←  
↓ → ↓ ← ←  
↓ → O ← ↑  
→ → ↑ ← ←  
[[ 3.04359354  4.77366723 31.723744   11.45695049  1.89489065]
 [30.19418021 -1.95697473 32.1812796  46.44044628  5.451296  ]
 [28.77182639 57.47799374 84.05079716 60.90980335  4.46417226]
 [39.98382044 84.38613195 84.25548803 85.03014743  4.24082235]
 [54.6726656  73.6805658  82.91692703 71.46611113 36.24197265]]
