# In [None]:
import numpy as np

WIDTH = 5  # grid-world width (number of columns)
HEIGHT = 5  # grid-world height (number of rows)
TRANSITION_PROB = 1  # transition probability of every move (1 => deterministic)
# Action indices: 0 = up, 1 = right, 2 = down, 3 = left.
POSSIBLE_ACTIONS = list(range(4))
# (row offset, column offset) applied by each action index above.
ACTIONS = [(-1, 0), (0, 1), (1, 0), (0, -1)]
REWARDS = []  # NOTE(review): not referenced by the code visible here

class render_policy:
    """Print a text picture of a policy over a grid-world.

    Every grid cell is drawn as a 3x3 block of glyphs. Arrows on the block's
    edges mark the actions the policy gives non-zero probability; cells with
    a positive reward are filled with ``'--'`` and cells with a negative
    reward with ``'@@'``.
    """

    # (row offset, col offset) inside a cell's 3x3 block at which the arrow
    # for each action index is drawn: up, right, down, left.
    _ARROW_OFFSETS = ((0, 1), (1, 2), (2, 1), (1, 0))

    def __init__(self, env):
        """Remember ``env``.

        ``env`` must expose ``width``, ``height``, ``reward`` (indexed
        ``reward[row][col]``) and ``get_all_states()`` returning ``[row, col]``
        states.
        """
        self.arrays = ['↑', '→', '↓', '←']  # arrow glyph per action index
        self.table = [[]]
        self.env = env

    def init_table(self):
        """Reset ``self.table`` to a blank board with reward cells painted in."""
        n_rows, n_cols = 3 * self.env.height, 3 * self.env.width
        self.table = [['ㅤ'] * n_cols for _ in range(n_rows)]
        # Derive the markers from the environment's own reward table rather
        # than hard-coded coordinates, so the picture always matches the env
        # and grids smaller than 3x3 do not raise IndexError.
        for r, reward_row in enumerate(self.env.reward):
            for c, value in enumerate(reward_row):
                if value > 0:
                    marker = '--'
                elif value < 0:
                    marker = '@@'
                else:
                    continue
                for i in range(3):
                    self.table[3 * r + i][3 * c:3 * c + 3] = [marker] * 3

    def draw(self, policy, e):
        """Draw ``policy`` for every env state and print it labelled with episode ``e``.

        ``policy.get_policy(state)`` must return per-action probabilities in
        the same order as ``self.arrays``.
        """
        self.init_table()
        for state in self.env.get_all_states():
            probs = policy.get_policy(state)
            top, left = 3 * state[0], 3 * state[1]
            for idx, val in enumerate(probs):
                if val != 0:
                    d_row, d_col = self._ARROW_OFFSETS[idx]
                    self.table[top + d_row][left + d_col] = self.arrays[idx]
        self.render(e)

    def render(self, e):
        """Print the current table between separator lines, headed by ``Episode {e}``.

        A space follows every third glyph so each cell's 3x3 block lines up.
        """
        print('-----------------------------------')
        print(f'Episode {e}')
        for row in self.table:
            print(''.join(v + ' ' if i % 3 == 2 else v for i, v in enumerate(row)))
        print('-----------------------------------')

class Env:
    """Deterministic grid-world environment.

    A state is a two-element list ``[row, col]`` with ``0 <= row < HEIGHT``
    and ``0 <= col < WIDTH``. ``ACTIONS`` holds ``(row offset, col offset)``
    pairs, and the reward table is stored row-major as ``reward[row][col]``.
    """

    def __init__(self):
        self.transition_probability = TRANSITION_PROB
        self.width = WIDTH
        self.height = HEIGHT
        # HEIGHT rows of WIDTH columns -> indexed reward[row][col].
        self.reward = [[0] * WIDTH for _ in range(HEIGHT)]
        self.possible_actions = POSSIBLE_ACTIONS
        self.reward[2][2] = 1  # circle at (row 2, col 2): reward +1
        self.reward[1][2] = -1  # triangle at (row 1, col 2): reward -1
        self.reward[2][1] = -1  # triangle at (row 2, col 1): reward -1

        # Every cell, enumerated column by column (all rows of column 0,
        # then column 1, ...); each state is [row, col].
        self.all_state = [[y, x] for x in range(WIDTH) for y in range(HEIGHT)]

    def get_reward(self, state, action_idx):
        """Return the reward of the cell reached by taking ``action_idx`` from ``state``."""
        row, col = self.state_after_action(state, action_idx)
        # Index [row][col] to match how self.reward is built (HEIGHT x WIDTH);
        # [col][row] only works by accident when the layout is symmetric.
        return self.reward[row][col]

    def state_after_action(self, state, action_index):
        """Return the ``[row, col]`` reached from ``state`` by ``action_index``.

        Moves that would leave the grid are clamped onto its edge, so the
        agent stays in place when it pushes against a wall.
        """
        d_row, d_col = ACTIONS[action_index]
        return self.check_boundary([state[0] + d_row, state[1] + d_col])

    @staticmethod
    def check_boundary(state):
        """Clamp ``state`` (``[row, col]``) onto the grid in place and return it.

        The row is limited to ``[0, HEIGHT - 1]`` and the column to
        ``[0, WIDTH - 1]``, which matters as soon as the grid is not square.
        """
        state[0] = min(max(state[0], 0), HEIGHT - 1)
        state[1] = min(max(state[1], 0), WIDTH - 1)
        return state

    def get_transition_prob(self, state, action):
        """Return the transition probability; the same for every ``state`` and ``action``."""
        return self.transition_probability

    def get_all_states(self):
        """Return the list of all ``[row, col]`` states (the internal list, not a copy)."""
        return self.all_state