In [2]:
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt



In [3]:
class CliffWalkingEnv:
    def __init__(self, ncol, nrow):
        self.nrow = nrow
        self.ncol = ncol
        self.x = 0  # 记录当前智能体位置的横坐标
        self.y = self.nrow - 1  # 记录当前智能体位置的纵坐标

    def step(self, action):  # 外部调用这个函数来改变当前位置
        # 4种动作, change[0]:上, change[1]:下, change[2]:左, change[3]:右。坐标系原点(0,0)
        # 定义在左上角
        change = [[0, -1], [0, 1], [-1, 0], [1, 0]]
        self.x = min(self.ncol - 1, max(0, self.x + change[action][0]))
        self.y = min(self.nrow - 1, max(0, self.y + change[action][1]))
        next_state = self.y * self.ncol + self.x
        reward = -1
        done = False
        if self.y == self.nrow - 1 and self.x > 0:  # 下一个位置在悬崖或者目标
            done = True
            if self.x != self.ncol - 1:
                reward = -100
        return next_state, reward, done

    def reset(self):  # 回归初始状态,坐标轴原点在左上角
        self.x = 0
        self.y = self.nrow - 1
        return self.y * self.ncol + self.x

In [4]:
class Sarsa:

    def __init__(self, col, row, learning_rate, gamma, epsilon, action_num = 4):
        self.Q_table = np.zeros([col*row, action_num])
        self.action_num = action_num
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.epsilon = epsilon

    def take_action(self, state):
        if np.random.random() < self.epsilon:
            action = np.random.randint(self.action_num)
        else:
            action = np.argmax(self.Q_table[state])
        return action
    
    def get_policy(self, state):
        Q_max = np.argmax(self.Q_table[state])
        a = [0 for i in range(self.action_num)]
        for i in range(self.action_num):
            if self.Q_table[state, i] == Q_max:
                a[i] = 1
        return a
    
    def sarsa_update(self, s0, a0, r, s1, a1):
        td_error = r + self.gamma * self.Q_table[s1, a1] - self.Q_table[s0, a0]
        self.Q_table[s0, a0] += self.learning_rate * td_error



In [20]:
col = 12
row = 4
env = CliffWalkingEnv(col, row)
np.random.seed(0)
epsilon = 0.1
learning_rate = 0.1
gamma = 0.9
agent = Sarsa(col, row, learning_rate, gamma, epsilon)

num_episodes = 500

In [22]:
return_list = []
for i in range(10):
    with tqdm(total=int(num_episodes/10), desc="Iteration %d" %i) as pbar:
        for i_episode in range(int(num_episodes/10)):
            episode_return = 0
            state = env.reset()
            action = agent.take_action(state)
            done = False
            while not done:
                next_state, reward, done = env.step(action)
                next_action = agent.take_action(next_state)
                episode_return += reward
                agent.sarsa_update(state, action, reward, next_state, next_action)
                state = next_state
                action = next_action
            return_list.append(episode_return)
        if (i_episode + 1) % 10 == 0:
            pbar.set_postfix({
                    'episode':
                    '%d' % (num_episodes / 10 * i + i_episode + 1),
                    'return':
                    '%.3f' % np.mean(return_list[-10:])
            })
        pbar.update(1)


Iteration 0:   2%|▏         | 1/50 [00:00<00:00, 100.23it/s, episode=50, return=-18.200]
Iteration 1:   2%|▏         | 1/50 [00:00<00:00, 83.53it/s, episode=100, return=-26.100]
Iteration 2:   2%|▏         | 1/50 [00:00<00:00, 83.53it/s, episode=150, return=-16.000]
Iteration 3:   2%|▏         | 1/50 [00:00<00:00, 73.96it/s, episode=200, return=-17.100]
Iteration 4:   2%|▏         | 1/50 [00:00<00:00, 66.86it/s, episode=250, return=-15.900]
Iteration 5:   2%|▏         | 1/50 [00:00<00:00, 77.16it/s, episode=300, return=-25.000]
Iteration 6:   2%|▏         | 1/50 [00:00<00:00, 86.41it/s, episode=350, return=-18.500]
Iteration 7:   2%|▏         | 1/50 [00:00<00:00, 58.97it/s, episode=400, return=-27.100]
Iteration 8:   2%|▏         | 1/50 [00:00<00:00, 77.15it/s, episode=450, return=-28.200]
Iteration 9:   2%|▏         | 1/50 [00:00<00:00, 66.86it/s, episode=500, return=-19.200]
