<a href="https://colab.research.google.com/github/LunaTic-Neon/2025-2-RL/blob/main/25_2_0910_%EA%B0%95%ED%99%94%ED%95%99%EC%8A%B5_2%EC%A3%BC%EC%B0%A8.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import random
import time

class GridWorld():
    """Deterministic rectangular grid-world environment.

    States are ``(row, column)`` tuples with ``(0, 0)`` in the top-left
    corner. Four actions (up, down, left, right) move the agent by one cell;
    a move that would leave the grid keeps the agent in place. Optional
    "warm hole" cells teleport the agent to a fixed destination with a fixed
    reward, whatever action is taken there.
    """

    def __init__(
        self,
        height=5,
        width=5,
        start_state=(0, 0),
        terminal_states=None,
        transition_reward=0.0,
        terminal_reward=1.0,
        outward_reward=0.0,
        warm_hole_states=None
    ):
        """Build the grid and store the reward configuration.

        Args:
            height: Number of rows.
            width: Number of columns.
            start_state: Cell the agent is placed on by ``reset()``.
            terminal_states: Iterable of terminal cells. ``None`` (the
                default) means ``[(4, 4)]``.
            transition_reward: Reward for an ordinary move to another cell.
            terminal_reward: Reward for a move that lands on a terminal cell.
            outward_reward: Reward for a move that leaves the agent in place
                (it tried to step off the grid).
            warm_hole_states: Optional sequence of
                ``(source_state, destination_state, reward)`` entries.

        Raises:
            ValueError: If a terminal state is not a cell of the grid, or is
                listed more than once.
        """
        # Copy instead of keeping a shared (possibly default) list, so that
        # mutating one environment's TERMINAL_STATES never leaks into another.
        if terminal_states is None:
            terminal_states = [(4, 4)]
        else:
            terminal_states = list(terminal_states)

        self.version = "0.0.1"
        self.HEIGHT = height
        self.WIDTH = width
        self.num_states = self.WIDTH * self.HEIGHT

        # STATES holds every non-terminal cell in row-major order.
        self.STATES = [
            (i, j) for i in range(self.HEIGHT) for j in range(self.WIDTH)
        ]
        for state in terminal_states:
            if state not in self.STATES:
                raise ValueError(
                    "terminal state {0} is not a cell of the {1}x{2} grid "
                    "or is listed more than once".format(
                        state, self.HEIGHT, self.WIDTH
                    )
                )
            self.STATES.remove(state)

        # Stays None until reset() or moveto() is called.
        self.current_state = None

        self.ACTION_UP = 0
        self.ACTION_DOWN = 1
        self.ACTION_LEFT = 2
        self.ACTION_RIGHT = 3

        self.ACTIONS = [
            self.ACTION_UP,
            self.ACTION_DOWN,
            self.ACTION_LEFT,
            self.ACTION_RIGHT
        ]

        # One glyph per action, indexed by the action id above.
        self.ACTION_SYMBOLS = ["↑", "↓", "←", "→"]

        self.START_STATE = start_state
        self.TERMINAL_STATES = terminal_states
        self.WARM_HOLE_STATES = warm_hole_states
        self.transition_reward = transition_reward
        self.terminal_reward = terminal_reward
        self.outward_reward = outward_reward
        self.NUM_ACTIONS = len(self.ACTIONS)

    def reset(self):
        """Place the agent on the start state and return that state."""
        self.current_state = self.START_STATE
        return self.current_state

    def moveto(self, state):
        """Place the agent directly on ``state`` without taking a step."""
        self.current_state = state

    def _find_warm_hole(self, state):
        """Return the warm-hole entry whose source is ``state``, or None."""
        if not self.WARM_HOLE_STATES:
            return None
        i, j = state
        for warm_hole_info in self.WARM_HOLE_STATES:
            source = warm_hole_info[0]
            if i == source[0] and j == source[1]:
                return warm_hole_info
        return None

    def is_warm_hole_state(self, state):
        """Return True if ``state`` is the source cell of a warm hole."""
        return self._find_warm_hole(state) is not None

    def get_next_state_warm_hole(self, state):
        """Return the destination of the warm hole at ``state``.

        Returns None when ``state`` is not a warm-hole source.
        """
        warm_hole_info = self._find_warm_hole(state)
        if warm_hole_info is None:
            return None
        return warm_hole_info[1]

    def get_reward_warm_hole(self, state):
        """Return the reward of the warm hole at ``state``.

        Returns None when ``state`` is not a warm-hole source.
        """
        warm_hole_info = self._find_warm_hole(state)
        if warm_hole_info is None:
            return None
        return warm_hole_info[2]

    def get_next_state(self, state, action):
        """Return the ``(row, column)`` reached by taking ``action`` in ``state``.

        Warm holes take precedence and ignore the action; terminal states are
        absorbing; otherwise the move is clamped to the grid boundary.

        Raises:
            ValueError: If ``action`` is not one of ``ACTIONS`` (only checked
                for ordinary, non-warm-hole, non-terminal states).
        """
        i, j = state

        warm_hole_info = self._find_warm_hole(state)
        if warm_hole_info is not None:
            destination = warm_hole_info[1]
            return destination[0], destination[1]

        if (i, j) in self.TERMINAL_STATES:
            return i, j

        if action == self.ACTION_UP:
            return max(i - 1, 0), j
        if action == self.ACTION_DOWN:
            return min(i + 1, self.HEIGHT - 1), j
        if action == self.ACTION_LEFT:
            return i, max(j - 1, 0)
        if action == self.ACTION_RIGHT:
            return i, min(j + 1, self.WIDTH - 1)
        raise ValueError("unknown action: {0!r}".format(action))

    def get_reward(self, state, next_state):
        """Return the reward for the transition ``state`` -> ``next_state``."""
        i, j = state
        next_i, next_j = next_state

        warm_hole_info = self._find_warm_hole(state)
        if warm_hole_info is not None:
            return warm_hole_info[2]

        if (next_i, next_j) in self.TERMINAL_STATES:
            return self.terminal_reward

        # An unchanged position means the move was blocked by the boundary.
        if i == next_i and j == next_j:
            return self.outward_reward
        return self.transition_reward

    def get_state_action_probability(self, state, action):
        """Return ``(next_state, reward, probability)`` for ``(state, action)``.

        Transitions are deterministic, so the probability is always 1.0.
        """
        next_i, next_j = self.get_next_state(state, action)
        reward = self.get_reward(state, (next_i, next_j))
        transition_prob = 1.0
        return (next_i, next_j), reward, transition_prob

    def step(self, action):
        """Apply ``action`` to the current state.

        Returns:
            A ``(next_state, reward, done, None)`` tuple; ``done`` is True
            when the new state is terminal. The fourth element is always None.
        """
        next_state = self.get_next_state(
            state=self.current_state,
            action=action
        )
        reward = self.get_reward(self.current_state, next_state)
        self.current_state = next_state

        done = self.current_state in self.TERMINAL_STATES
        return next_state, reward, done, None

    def render(self, mode='human'):
        """Print the text rendering of the grid (``mode`` is not used)."""
        print(self.__str__())

    def get_random_action(self):
        """Return one of ``ACTIONS`` chosen uniformly at random."""
        return random.choice(self.ACTIONS)

    def __str__(self):
        """Return a text picture of the grid.

        ``(*)`` marks the agent, ``(S)`` the start, ``(G)`` terminal cells and
        ``(W)`` warm-hole sources; other cells show their coordinates. Before
        ``reset()``/``moveto()`` no cell is marked as the agent.
        """
        # Both borders follow the number of columns so they always match.
        border = "---" * self.WIDTH + "--\n"

        if self.current_state is None:
            agent = None
        else:
            agent = (self.current_state[0], self.current_state[1])

        # Collect the warm-hole sources once rather than once per cell.
        if self.WARM_HOLE_STATES:
            warm_hole_sources = [info[0] for info in self.WARM_HOLE_STATES]
        else:
            warm_hole_sources = []

        pieces = [border]
        for i in range(self.HEIGHT):
            for j in range(self.WIDTH):
                cell = (i, j)
                if cell == agent:
                    pieces.append("| (*) ")
                elif cell == self.START_STATE:
                    pieces.append("| (S) ")
                elif cell in self.TERMINAL_STATES:
                    pieces.append("| (G) ")
                elif cell in warm_hole_sources:
                    pieces.append("| (W) ")
                else:
                    pieces.append("| ({0},{1})".format(i, j))
            pieces.append("|\n")
        pieces.append(border)
        return "".join(pieces)

def main():
    """Run one random-action episode on the default GridWorld.

    Prints the grid after the reset and after every step, pausing one second
    between steps, until a terminal state is reached.
    """
    grid = GridWorld()
    grid.reset()
    print("reset")
    grid.render()

    step_count = 0
    while True:
        step_count += 1
        chosen_action = grid.get_random_action()
        _, reward, finished, _ = grid.step(chosen_action)

        report = "action: {0}, reward: {1}, done: {2}, total_steps: {3}".format(
            grid.ACTION_SYMBOLS[chosen_action],
            reward,
            finished,
            step_count
        )
        print(report)
        grid.render()
        time.sleep(1)

        # The pause above also runs after the final step.
        if finished:
            break

def main_warm_hole():
    """Run one random-action episode on a GridWorld with two warm holes.

    Cell A at (0, 1) sends the agent to (4, 1) with reward 10.0 and cell B at
    (0, 3) sends it to (2, 3) with reward 5.0. The grid is printed after the
    reset and after every step, with a one-second pause between steps.
    """
    # Each entry is (source cell, destination cell, reward).
    warm_holes = [
        ((0, 1), (4, 1), 10.0),  # A -> A'
        ((0, 3), (2, 3), 5.0),   # B -> B'
    ]

    grid = GridWorld(warm_hole_states=warm_holes)
    grid.reset()
    print("reset")
    grid.render()

    step_count = 0
    while True:
        step_count += 1
        chosen_action = grid.get_random_action()
        _, reward, finished, _ = grid.step(chosen_action)

        report = "action: {0}, reward: {1}, done: {2}, total_steps: {3}".format(
            grid.ACTION_SYMBOLS[chosen_action],
            reward,
            finished,
            step_count
        )
        print(report)
        grid.render()
        time.sleep(1)

        # The pause above also runs after the final step.
        if finished:
            break

# Script entry point: runs the plain random-walk demo when executed directly.
if __name__ == "__main__":
    main()
    # Alternative demo with warm holes; swap it in for main() to try it.
    # main_warm_hole()


reset
-----------------
| (*) | (0,1)| (0,2)| (0,3)| (0,4)|
| (1,0)| (1,1)| (1,2)| (1,3)| (1,4)|
| (2,0)| (2,1)| (2,2)| (2,3)| (2,4)|
| (3,0)| (3,1)| (3,2)| (3,3)| (3,4)|
| (4,0)| (4,1)| (4,2)| (4,3)| (G) |
-----------------

action: →, reward: 0.0, done: False, total_steps: 1
-----------------
| (S) | (*) | (0,2)| (0,3)| (0,4)|
| (1,0)| (1,1)| (1,2)| (1,3)| (1,4)|
| (2,0)| (2,1)| (2,2)| (2,3)| (2,4)|
| (3,0)| (3,1)| (3,2)| (3,3)| (3,4)|
| (4,0)| (4,1)| (4,2)| (4,3)| (G) |
-----------------

action: , reward: 0.0, done: False, total_steps: 2
-----------------
| (S) | (0,1)| (0,2)| (0,3)| (0,4)|
| (1,0)| (*) | (1,2)| (1,3)| (1,4)|
| (2,0)| (2,1)| (2,2)| (2,3)| (2,4)|
| (3,0)| (3,1)| (3,2)| (3,3)| (3,4)|
| (4,0)| (4,1)| (4,2)| (4,3)| (G) |
-----------------

action: , reward: 0.0, done: False, total_steps: 3
-----------------
| (S) | (0,1)| (0,2)| (0,3)| (0,4)|
| (*) | (1,1)| (1,2)| (1,3)| (1,4)|
| (2,0)| (2,1)| (2,2)| (2,3)| (2,4)|
| (3,0)| (3,1)| (3,2)| (3,3)| (3,4)|
| (4,0)| (4,1)| (4