In [1]:
# Reference: https://github.com/seungeunrho/RLfrombasics
import random
import numpy as np

In [2]:
# Selects the update rule defined on QAgent and the training loop used by main():
#   1 : Monte Carlo (one update per episode, from the full episode history)
#   2 : SARSA (one update per step)
#   any other value (3 by convention) : Q-learning (one update per step)
METHOD = 3

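## MAP

`S` = start, `G` = goal, `X` = wall, `0` = free cell. Rows are `y` (0-4) and columns are `x` (0-6).

```
  0 1 2 3 4 5 6
0 S 0 X 0 0 0 0
1 0 0 X 0 0 0 0
2 0 0 X 0 X 0 0
3 0 0 0 0 X 0 0
4 0 0 0 0 X 0 G
```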

In [3]:
# environment class
class GridWorld():
    """Grid maze of 5 rows by 7 columns holding a single agent position.

    The position is kept in ``self.y`` (row) and ``self.x`` (column) and is
    always reported as the tuple ``(y, x)``. Action codes are 0 = left,
    1 = right, 2 = up, 3 = down. A move that would leave the grid or enter
    a blocked cell leaves the position unchanged.
    """

    def __init__(self):
        # The agent starts in the top-left cell.
        self.y = 0
        self.x = 0

    def step(self, a):
        """Apply action ``a`` and return ``((y, x), reward, done)``.

        The reward is -1 for every step. An action code other than 0-3
        moves nothing but still costs the step.
        """
        movers = (self.move_left, self.move_right, self.move_up, self.move_down)
        for code, mover in enumerate(movers):
            if a == code:
                mover()
                break

        finished = self.is_done()
        return (self.y, self.x), -1, finished

    def move_left(self):
        """Move one column to the left unless the grid edge or a wall is in the way."""
        at_edge = self.x == 0
        # Walls sit directly left of column 3 (rows 0-2) and column 5 (rows 2-4).
        wall_on_left = (
            (self.x == 3 and self.y in [0, 1, 2])
            or (self.x == 5 and self.y in [2, 3, 4])
        )
        if at_edge or wall_on_left:
            return
        self.x -= 1

    def move_right(self):
        """Move one column to the right unless the grid edge or a wall is in the way."""
        # Walls sit directly right of column 1 (rows 0-2) and column 3 (rows 2-4).
        wall_on_right = (
            (self.x == 1 and self.y in [0, 1, 2])
            or (self.x == 3 and self.y in [2, 3, 4])
        )
        at_edge = self.x == 6
        if wall_on_right or at_edge:
            return
        self.x += 1

    def move_up(self):
        """Move one row up unless the grid edge or a wall is in the way."""
        at_edge = self.y == 0
        # The only wall reachable from below is above cell (3, 2).
        wall_above = self.y == 3 and self.x == 2
        if at_edge or wall_above:
            return
        self.y -= 1

    def move_down(self):
        """Move one row down unless the grid edge or a wall is in the way."""
        at_edge = self.y == 4
        # The only wall reachable from above is below cell (1, 4).
        wall_below = self.y == 1 and self.x == 4
        if at_edge or wall_below:
            return
        self.y += 1

    def is_done(self):
        """Return True when the agent is on the goal cell (row 4, column 6)."""
        return bool(self.y == 4 and self.x == 6)

    def get_state(self):
        """Return the current position as ``(y, x)``."""
        return (self.y, self.x)

    def reset(self):
        """Put the agent back on the start cell and return ``(y, x)``."""
        self.y = 0
        self.x = 0
        return (self.y, self.x)
    


In [4]:
# Agent class
class QAgent():
    """Tabular agent with an epsilon-greedy policy over a 5x7x4 Q-table.

    The update rule bound to ``update_table`` is chosen by the module-level
    ``METHOD`` constant at the moment this class body is executed:

    * ``METHOD == 1``: Monte Carlo; ``update_table`` takes a whole episode
      history (a list of ``(s, a, r, s_prime)`` tuples).
    * ``METHOD == 2``: SARSA; ``update_table`` takes one transition.
    * anything else: Q-learning; ``update_table`` takes one transition.

    None of the rules applies a discount factor. Changing ``METHOD`` after
    the class has been defined has no effect until the class is re-defined.
    """

    def __init__(self):
        # Q-value table indexed as [row, column, action]; initialised to zero.
        self.q_table = np.zeros((5,7,4))
        # Probability of choosing a random action in select_action().
        self.eps = 0.9
        # Step size of the Monte Carlo update.
        self.alpha = 0.01
        # Step size of the SARSA and Q-learning updates.
        self.td_alpha = 0.1

    def select_action(self, s):
        """Choose an action for state ``s = (y, x)`` epsilon-greedily.

        With probability ``self.eps`` a uniformly random action in 0..3 is
        returned; otherwise the index of the largest Q-value for the state
        (the first such index when several are equal).
        """
        y, x = s
        coin = random.random()

        if coin < self.eps:
            # random.randint includes both end points.
            action = random.randint(0,3)
        else:
            action_val = self.q_table[y, x,:]
            # argmax gives the position of the largest value.
            action = np.argmax(action_val)

        return action

    if METHOD == 1:
        def update_table(self, history):
            """Monte Carlo update from one complete episode.

            ``history`` is the ordered list of ``(s, a, r, s_prime)``
            transitions of the episode. It is walked backwards so that the
            return following each step can be accumulated incrementally:
            Q(s, a) <- Q(s, a) + alpha * (G - Q(s, a)).
            """
            cum_reward = 0
            for transition in reversed(history):
                s, a, r, s_prime = transition
                y, x = s

                # The return of (s, a) includes the reward received for
                # taking a, so accumulate it before using it as the target.
                cum_reward = cum_reward + r
                self.q_table[y,x,a] = self.q_table[y,x,a] + self.alpha * (cum_reward - self.q_table[y,x,a])
    elif METHOD == 2:
        def update_table(self, transition):
            """SARSA update from a single ``(s, a, r, s_prime)`` transition.

            Q(s, a) <- Q(s, a) + td_alpha * (r + Q(s', a') - Q(s, a))
            """
            s, a, r, s_prime = transition
            y, x = s
            next_y, next_x = s_prime

            # Follows the Bellman expectation equation: a' is sampled from
            # the current epsilon-greedy policy at s'. It is only used for
            # the target here; it is not returned, so it is not necessarily
            # the action actually taken next.
            a_prime = self.select_action(s_prime)

            td_target = r + self.q_table[next_y,next_x,a_prime]
            self.q_table[y,x,a] = self.q_table[y,x,a] + self.td_alpha * (td_target - self.q_table[y,x,a])
    else:
        def update_table(self, transition):
            """Q-learning update from a single ``(s, a, r, s_prime)`` transition.

            Q(s, a) <- Q(s, a) + td_alpha * (r + max_a' Q(s', a') - Q(s, a))
            """
            s, a, r, s_prime = transition
            y, x = s
            next_y, next_x = s_prime

            # Follows the Bellman optimality equation: bootstrap from the
            # best Q-value available in the next state.
            td_target = r + np.amax(self.q_table[next_y,next_x,:])
            self.q_table[y,x,a] = self.q_table[y,x,a] + self.td_alpha * (td_target - self.q_table[y,x,a])

    def anneal_eps(self):
        """Lower the exploration rate by 0.03, never going below 0.1."""
        self.eps = max(self.eps - 0.03, 0.1)

    def show_table(self):
        """Print, for every cell, the index of the action with the highest Q-value.

        Cells whose four Q-values are all equal (for example cells that were
        never updated) are shown as 0.
        """
        # argmax over the action axis; cast to float to keep the printed
        # format of a float array.
        data = np.argmax(self.q_table, axis=2).astype(float)
        print('0:left, 1:right, 2:up, 3:down' )
        print(data)
       

In [5]:
def main(num_episodes=1000, seed=None):
    """Train a QAgent on GridWorld and print the greedy action of every cell.

    Args:
        num_episodes: Number of training episodes to run. Defaults to 1000.
        seed: When not None, passed to ``random.seed`` before training so
            the run can be repeated. The default leaves the generator state
            untouched.

    The module-level ``METHOD`` decides the loop shape: with ``METHOD == 1``
    the agent is updated once per episode from the collected history;
    otherwise it is updated after every step from the single transition.
    """
    if seed is not None:
        random.seed(seed)

    env = GridWorld()
    agent = QAgent()

    for _ in range(num_episodes):
        done = False

        s = env.reset()
        if METHOD == 1:
            # Collect the whole episode first, then update from it.
            history = []
            while not done:
                action = agent.select_action(s)
                s_prime, reward, done = env.step(action)
                history.append((s, action, reward, s_prime))
                s = s_prime

            agent.update_table(history)
        else:
            # Update after every single step.
            while not done:
                action = agent.select_action(s)
                s_prime, reward, done = env.step(action)
                agent.update_table((s, action, reward, s_prime))
                s = s_prime

        # Explore a little less after each episode.
        agent.anneal_eps()

    # Show the result of training.
    agent.show_table()

In [6]:
# Train with the configured METHOD and print the greedy action for each cell.
main()

0:left, 1:right, 2:up, 3:down
[[3. 3. 0. 2. 3. 3. 3.]
 [3. 3. 0. 1. 1. 3. 3.]
 [1. 3. 0. 2. 0. 1. 3.]
 [1. 1. 1. 2. 0. 3. 3.]
 [2. 3. 1. 2. 0. 1. 0.]]
