# n-step-Sarsa 对应《强化学习的数学原理》第7章

# n-step-Sarsa更新公式的推导
具体数学推导请看《强化学习的数学原理》第7章
<br>代码对应《动手学强化学习》第5章 [请点击这里](https://hrl.boyuai.com/chapter/1/%E6%97%B6%E5%BA%8F%E5%B7%AE%E5%88%86%E7%AE%97%E6%B3%95)

<img src="./picture5_2.png" alt="插入图片哈哈" width="50%">

# 环境
还是悬崖漫步环境

In [37]:
import numpy as np
class CliffWalkingEnv:
    def __init__(self, ncol, nrow):
        self.nrow = nrow
        self.ncol = ncol
        self.x = 0              # 记录当前智能体位置的横坐标  。初始位置在左下角
        self.y = self.nrow - 1  # 记录当前智能体位置的纵坐标

    def step(self, action):  # 外部调用这个函数来改变当前位置
        # 4种动作, change[0]:上, change[1]:下, change[2]:左, change[3]:右。坐标系原点(0,0)
        # 定义在左上角
        change = [[0, -1], [0, 1], [-1, 0], [1, 0]]
        self.x = min(self.ncol - 1, max(0, self.x + change[action][0]))
        self.y = min(self.nrow - 1, max(0, self.y + change[action][1]))
        next_state = self.y * self.ncol + self.x
        reward = -1
        done = False
        if self.y == self.nrow - 1 and self.x > 0:  # 下一个位置在悬崖或者目标
            done = True
            if self.x != self.ncol - 1:
                reward = -100
        return next_state, reward, done

    def reset(self):  # 回归初始状态,坐标轴原点在左上角
        self.x = 0
        self.y = self.nrow - 1
        return self.y * self.ncol + self.x

## n-step-Sarsa

In [38]:
class nstep_Sarsa:
    """ n步Sarsa算法 """
    def __init__(self, n, ncol, nrow, epsilon, alpha, gamma, n_action=4):
        self.Q_table = np.zeros([nrow * ncol, n_action])
        self.n_action = n_action
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.n = n             # 采用n步Sarsa算法
        self.state_list = []   # n-step-Sarsa,用来保存这n步的状态
        self.action_list = []  # 保存动作
        self.reward_list = []  # 保存奖励

    def take_action(self, state):
        if np.random.random() < self.epsilon:
            action = np.random.randint(self.n_action)
        else:
            action = np.argmax(self.Q_table[state])
        return action

    def best_action(self, state):  # 用于打印策略，训练的过程中不用
        Q_max = np.max(self.Q_table[state])
        a = [0 for _ in range(self.n_action)]
        for i in range(self.n_action):
            if self.Q_table[state, i] == Q_max:
                a[i] = 1
        return a

    def update(self, s0, a0, r, s1, a1, done):
        self.state_list.append(s0)
        self.action_list.append(a0)
        self.reward_list.append(r)
        if len(self.state_list) == self.n:  # 若保存的数据可以进行n步更新，state_list存够n个了才进行
            G = self.Q_table[s1, a1]        # 得到Q(s_{t+n}, a_{t+n})
            for i in reversed(range(self.n)):
                G = self.gamma * G + self.reward_list[i]  # 不断向前计算每一步的回报

                if done == True and i > 0:  # 如果这里episode已经结束了，但是i还没有到0，说明什么，后面不够n步但是episode已经结束了
                    s = self.state_list[i]  # 那么就用剩下的那几步来更新
                    a = self.action_list[i]
                    self.Q_table[s, a] += self.alpha * (G - self.Q_table[s, a])

            # 更新q
            s = self.state_list[0]
            a = self.action_list[0]
            self.Q_table[s, a] += self.alpha * (G - self.Q_table[s, a])

            # 删除已经更新的，即多步中的第一个
            self.state_list.pop(0)
            self.action_list.pop(0)
            self.reward_list.pop(0)

        if done == True:  # 如果到达终止状态,即将开始下一条序列,则将列表全清空
            self.state_list = []
            self.action_list = []
            self.reward_list = []

# 训练

## 实例化

In [39]:
ncol = 12
nrow = 4
env = CliffWalkingEnv(ncol, nrow)
np.random.seed(0)
n_step = 5    # 5步Sarsa算法
alpha = 0.1
epsilon = 0.1
gamma = 0.9
agent = nstep_Sarsa(n_step, ncol, nrow, epsilon, alpha, gamma)

## 开始训练

In [40]:
num_episodes = 2000
return_list = []
for i in range(num_episodes):
    episode_return = 0
    state = env.reset()
    action = agent.take_action(state)
    done = False
    while done == False:
        next_state, reward, done = env.step(action)
        next_action = agent.take_action(next_state)
        episode_return += reward
        agent.update(state, action, reward, next_state, next_action,done)
        #调用update,就开始把s,a,r传进列表，所以这个对应的是新的，也就是后面的
        #一开始不够n个时，就不计算不更新，直到传够5个了再开始算，所以当不够n个时，update就只传
        #此时还在while循环里，直到传够了，开始算了才会出现done=Ture
        state = next_state
        action = next_action
    return_list.append(episode_return)

## 训练结束

### 打印q表

In [41]:
agent.Q_table

array([[ -8.67124173,  -8.94154735,  -8.83455162,  -8.00642999],
       [ -8.55453599,  -8.51395348,  -8.5663023 ,  -7.83402992],
       [ -8.32748643,  -8.89485253,  -8.45162647,  -7.57242019],
       [ -7.75715682,  -8.11546513,  -8.00649119,  -7.10185879],
       [ -8.531435  ,  -8.26629271,  -7.91088849,  -6.78081316],
       [ -7.22722434,  -8.32551707,  -7.93510965,  -6.30315173],
       [ -7.12117162,  -7.24727071,  -7.23725892,  -5.86014523],
       [ -6.69694967,  -7.71968737,  -6.9547111 ,  -5.36068139],
       [ -6.45075368,  -6.28342993,  -6.43214199,  -4.77878561],
       [ -5.01249857,  -6.3181405 ,  -6.31587254,  -4.18353934],
       [ -4.96363731,  -3.97572523,  -5.355453  ,  -3.54703933],
       [ -4.270789  ,  -2.82146915,  -4.63821007,  -4.20049381],
       [ -8.26545803,  -9.75824778,  -8.94944896,  -9.48025437],
       [ -8.2327698 , -10.37261831, -15.06644406, -15.18647659],
       [ -7.99701568, -22.25696147, -11.2487198 , -10.73462018],
       [ -7.76388045, -11

### 策略可视化

In [42]:
def print_agent(agent, env, action_meaning, disaster=[], end=[]):
    for i in range(env.nrow):
        for j in range(env.ncol):
            if (i * env.ncol + j) in disaster:
                print('****', end=' ')
            elif (i * env.ncol + j) in end:
                print('EEEE', end=' ')
            else:
                a = agent.best_action(i * env.ncol + j)   #根据q-table输出策略
                pi_str = ''
                for k in range(len(action_meaning)):
                    pi_str += action_meaning[k] if a[k] > 0 else 'o'
                print(pi_str, end=' ')
        print()

In [43]:
action_meaning = ['^', 'v', '<', '>']
print('5步Sarsa算法最终收敛得到的策略为：')
print_agent(agent, env, action_meaning, list(range(37, 47)), [47])

5步Sarsa算法最终收敛得到的策略为：
ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ovoo 
^ooo ^ooo ^ooo ^ooo ooo> ooo> ^ooo ooo> ooo> ^ooo ooo> ovoo 
^ooo oo<o ^ooo ooo> ^ooo ^ooo oo<o ooo> ^ooo ^ooo ooo> ovoo 
^ooo **** **** **** **** **** **** **** **** **** **** EEEE 
