In [188]:
import numpy as np

class ENV(object):
    '''Two-state toy environment.

    States: 0 = meeting (相遇), 1 = selling crutches (卖拐).
    Actions: 0 = chat (攀谈), any other value = leave (离开).
    '''

    def __init__(self):
        # Episode-finished flag and the current state index.
        self.terminal = False
        self.state = 0

    def reset(self):
        '''Begin a new episode; returns the initial state 0.'''
        self.terminal, self.state = False, 0
        return self.state

    def step(self, action):
        '''Apply `action` and return (state, reward, terminal).

        Meeting + chat  -> move to the selling state, reward 0.
        Meeting + leave -> reward 10, episode ends.
        Selling + chat  -> reward -100, episode ends.
        Selling + leave -> reward 100, episode ends.
        '''
        chat = (action == 0)
        if self.state != 0:
            # Selling state: whatever the agent does, the episode is over.
            self.reward = -100 if chat else 100
            self.terminal = True
        elif chat:
            # Meeting state, agent engages: advance to the selling state.
            self.reward = 0
            self.state = 1
        else:
            # Meeting state, agent walks away: small reward, episode ends.
            self.reward = 10
            self.terminal = True
        return self.state, self.reward, self.terminal

In [173]:
class SarsaAgent(object):
    '''Tabular SARSA agent with an epsilon-greedy behaviour policy.

    The Q table is a numpy array of shape (state_dim, action_dim) initialised
    to zeros; states and actions are used directly as row / column indices.
    '''

    def __init__(self, state_dim, action_dim, learning_rate=0.1, gamma=0.9, epsilon=0.1):
        '''
        state_dim: number of states (rows of the Q table).
        action_dim: number of actions (columns of the Q table).
        learning_rate: step size alpha in the update rule.
        gamma: discount factor applied to the next-step Q value.
        epsilon: probability of taking a uniformly random action.
        '''
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.alpha = learning_rate
        self.gamma = gamma
        self.epsilon = epsilon
        self.Q = np.zeros([state_dim, action_dim])  # one row per state, one column per action

    def getAction(self, current_state):
        '''Epsilon-greedy action selection for `current_state`.

        With probability 1 - epsilon the greedy action from the Q table is
        returned; otherwise an action is drawn uniformly from range(action_dim).
        '''
        # Uniform draw on [0, 1) decides between exploiting and exploring.
        random_p = np.random.uniform(0, 1)
        if random_p < (1 - self.epsilon):
            return self.getActionByQ(current_state)  # exploit
        return np.random.choice(self.action_dim)  # explore

    def getActionByQ(self, current_state):
        '''Return a greedy action for `current_state`, breaking ties uniformly at random.'''
        Q_row = self.Q[current_state]
        maxQ = np.max(Q_row)
        # All column indices attaining the maximum are equally good candidates.
        actions = np.where(Q_row == maxQ)[0]
        return np.random.choice(actions)

    def learn(self, current_state, action, reward, next_state, next_action, terminal):
        '''SARSA update of the Q-table entry for (current_state, action):

        Q(S_t, A_t) <- Q(S_t, A_t) + alpha * [R_{t+1} + gamma * Q(S_{t+1}, A_{t+1}) - Q(S_t, A_t)]

        On a terminal transition the bootstrap term is dropped and the target is
        just `reward`; next_state / next_action are then not read.
        '''
        Q_sa = self.Q[current_state, action]
        if terminal:
            target = reward
        else:
            target = reward + self.gamma * self.Q[next_state, next_action]
        self.Q[current_state, action] = Q_sa + self.alpha * (target - Q_sa)

    def showQTable(self):
        '''Print the current Q table.'''
        print(self.Q)
        
class QLAgent(SarsaAgent):
    '''Tabular Q-learning agent.

    Reuses SarsaAgent's Q table and epsilon-greedy action selection; only the
    update rule differs: it bootstraps from the greedy value max_a Q(S_{t+1}, a)
    rather than from the value of the action actually taken next.
    '''

    def learn(self, current_state, action, reward, next_state, next_action, terminal):
        '''Q-learning update of the Q-table entry for (current_state, action):

        Q(S_t, A_t) <- Q(S_t, A_t) + alpha * [R_{t+1} + gamma * max_a Q(S_{t+1}, a) - Q(S_t, A_t)]

        next_action is unused; it is accepted so the signature matches
        SarsaAgent.learn and the same episode loop can drive either agent.
        On a terminal transition the bootstrap term is dropped (target = reward).
        '''
        Q_sa = self.Q[current_state, action]
        if terminal:
            target = reward
        else:
            target = reward + self.gamma * np.max(self.Q[next_state])
        self.Q[current_state, action] = Q_sa + self.alpha * (target - Q_sa)

In [190]:
def episode(env, LF):
    '''Play one full episode between agent LF and env, updating LF's Q table online.

    env: environment object providing reset() -> state and
         step(action) -> (next_state, reward, terminal).
    LF:  the agent object ("LF" is the pinyin abbreviation of Lao Fan), providing
         getAction(state) and learn(state, action, reward, next_state, next_action, terminal).

    Returns the sum of rewards collected during the episode.
    '''
    state = env.reset()
    action = LF.getAction(state)
    episode_return = 0
    terminal = False

    while not terminal:
        # Act, pick the follow-up action, then learn from the full SARSA tuple.
        next_state, reward, terminal = env.step(action)
        next_action = LF.getAction(next_state)
        LF.learn(state, action, reward, next_state, next_action, terminal)
        episode_return += reward
        state, action = next_state, next_action

    return episode_return

# Training configuration.
SEED = 0                # fixed seed so Restart & Run All reproduces the printed results
N_STATES = 2            # 0 = meeting, 1 = selling crutches
N_ACTIONS = 2           # 0 = chat, 1 = leave
N_EPISODES = 200        # episodes are numbered 1..N_EPISODES inclusive
LOG_EVERY = 10          # print the latest reward and the Q table every LOG_EVERY episodes
USE_Q_LEARNING = False  # False: Lao Fan thinks with SARSA; True: with Q-learning

np.random.seed(SEED)

env = ENV()
LF = QLAgent(N_STATES, N_ACTIONS) if USE_Q_LEARNING else SarsaAgent(N_STATES, N_ACTIONS)

LF.showQTable()

for ep in range(1, N_EPISODES + 1):
    total_reward = episode(env, LF)
    if ep % LOG_EVERY == 0:
        print("Episode {}, Total reward {}".format(ep, total_reward))
        LF.showQTable()
        

[[0. 0.]
 [0. 0.]]
Episode 10, Total reward 10
[[ 0.9        5.6953279]
 [ 0.        19.       ]]
Episode 20, Total reward 100
[[ 2.52        8.33228183]
 [ 0.         27.1       ]]
Episode 30, Total reward 10
[[ 2.52        9.41850263]
 [ 0.         27.1       ]]
Episode 40, Total reward 10
[[ 2.52       9.7972444]
 [ 0.        27.1      ]]
Episode 50, Total reward 10
[[ 2.52       9.9293035]
 [ 0.        27.1      ]]
Episode 60, Total reward 10
[[ 2.52        9.97534965]
 [ 0.         27.1       ]]
Episode 70, Total reward 10
[[ 4.707       9.99044995]
 [ 0.         34.39      ]]
Episode 80, Total reward 10
[[ 4.707      9.9966701]
 [ 0.        34.39     ]]
Episode 90, Total reward 10
[[ 4.707       9.99883894]
 [ 0.         34.39      ]]
Episode 100, Total reward 10
[[ 7.3314      9.99955018]
 [ 0.         40.951     ]]
Episode 110, Total reward 10
[[ 7.3314      9.99984316]
 [ 0.         40.951     ]]
Episode 120, Total reward 10
[[ 7.3314      9.99994531]
 [ 0.         40.951     