# Q-Learning snake

## 스네이크 게임

In [8]:
import numpy as np
import random
import time

from IPython.display import clear_output


class SnakeGame:
    """Minimal grid-based snake game with a text renderer.

    Cells are addressed as ``(x, y)`` tuples with ``0 <= x < width`` and
    ``0 <= y < height``.  ``y`` grows downward, so moving ``'UP'`` decreases
    ``y``.  ``snake`` is a list of cells whose first element is the head.
    """

    # Offset applied to the head for each direction name.
    _DELTAS = {
        'UP': (0, -1),
        'DOWN': (0, 1),
        'LEFT': (-1, 0),
        'RIGHT': (1, 0),
    }

    def __init__(self, width=10, height=10):
        """Create a board of ``width`` x ``height`` cells and start a game."""
        self.width = width
        self.height = height
        self.reset()

    def reset(self):
        """Start a new game: one-cell snake at the board centre, score 0."""
        self.snake = [(self.width // 2, self.height // 2)]
        self.score = 0
        self.direction = 'UP'
        # Cleared before place_food() because that call ends the game when
        # the board has no free cell left.
        self.game_over = False
        self.food = None
        self.place_food()

    def place_food(self):
        """Put the food on a uniformly chosen cell not covered by the snake.

        A new position is always chosen, even if food is already present;
        this is what lets ``move`` respawn the food after it has been eaten.
        If the snake covers every cell, ``food`` becomes ``None`` and the
        game is marked as over instead of searching forever.
        """
        occupied = set(self.snake)
        free_cells = [
            (x, y)
            for x in range(self.width)
            for y in range(self.height)
            if (x, y) not in occupied
        ]
        if free_cells:
            self.food = random.choice(free_cells)
        else:
            self.food = None
            self.game_over = True

    def move(self, direction):
        """Advance the snake one cell in ``direction``.

        ``direction`` is ``'UP'``, ``'DOWN'``, ``'LEFT'`` or ``'RIGHT'``; any
        other value is treated as ``'RIGHT'``.  Moving into a wall or into
        any cell currently occupied by the snake ends the game and leaves
        the snake where it was.  Eating the food grows the snake by one
        cell, adds one point and respawns the food.

        Returns:
            A ``(score, game_over)`` tuple describing the game after the move.
        """
        dx, dy = self._DELTAS.get(direction, self._DELTAS['RIGHT'])
        head_x, head_y = self.snake[0]
        next_head = (head_x + dx, head_y + dy)

        inside_board = (0 <= next_head[0] < self.width
                        and 0 <= next_head[1] < self.height)
        if next_head in self.snake or not inside_board:
            self.game_over = True
            return self.score, self.game_over

        self.snake.insert(0, next_head)

        if next_head == self.food:
            self.score += 1
            self.place_food()
        else:
            # No food eaten: drop the tail so the length stays the same.
            self.snake.pop()

        return self.score, self.game_over

    def display(self):
        """Redraw the board in the notebook output area.

        ``S`` marks the head, ``s`` the body, ``F`` the food and ``.`` an
        empty cell; the score is printed underneath.
        """
        clear_output(wait=True)
        head = self.snake[0]
        body = set(self.snake)
        for y in range(self.height):
            row = []
            for x in range(self.width):
                cell = (x, y)
                if cell == head:
                    row.append('S')
                elif cell in body:
                    row.append('s')
                elif cell == self.food:
                    row.append('F')
                else:
                    row.append('.')
            print(''.join(row))
        print(f'Score: {self.score}')

### Q-learning 파라미터

#### 탐색 공간 정의

In [21]:
from collections import namedtuple

# Codes describing what occupies a cell next to the snake's head.
EMPTY, WALL, FOOD, BODY = range(4)

# State: an index plus the cell code seen in each of the four directions.
State = namedtuple("state", "idx up down left right")
# Node: one transition (state, action taken, reward, index of the next state).
Node = namedtuple("node", "state action reward next_state_idx")

# Example transition; the value is built and discarded.
Node(State(0, EMPTY, EMPTY, EMPTY, EMPTY), "UP", 0, 1)

initial_state = State(idx=0, up=EMPTY, down=EMPTY, left=EMPTY, right=EMPTY)

# Seed the environment with the four transitions leaving the initial state;
# each action leads to its own successor index (1 through 4), reward 0.
env = [
    Node(state=initial_state, action=move, reward=0, next_state_idx=target)
    for target, move in enumerate(("UP", "DOWN", "LEFT", "RIGHT"), start=1)
]

In [22]:
def calculate_reward(state: State, action):
    """Return the reward for taking ``action`` from ``state``.

    Only the cell in the direction of ``action`` matters:

    * ``BODY`` or ``WALL`` there gives the death penalty (-10),
    * ``FOOD`` there gives the food reward (100),
    * anything else, including an unrecognised action, gives the default
      step reward (-1).
    """
    neighbours = (
        ("UP", state.up),
        ("DOWN", state.down),
        ("LEFT", state.left),
        ("RIGHT", state.right),
    )

    for direction, cell in neighbours:
        if action == direction:
            if cell == BODY or cell == WALL:
                return -10  # death penalty
            if cell == FOOD:
                return 100  # food reward
            break  # direction names are unique, nothing else can match

    return -1  # default per-step reward


In [24]:
def investigate_env(snake_position, food_position, current_state_idx, width=10, height=10):
    """Describe the four cells adjacent to the snake's head as a ``State``.

    Args:
        snake_position: sequence of ``(x, y)`` cells; element 0 is the head.
        food_position: ``(x, y)`` cell of the food.
        current_state_idx: value stored unchanged in the ``idx`` field.
        width: board width used for the right-edge check.
        height: board height used for the bottom-edge check.

    Returns:
        A ``State`` whose ``up``/``down``/``left``/``right`` fields hold
        ``BODY``, ``WALL``, ``FOOD`` or ``EMPTY``, tested in that order.
    """
    head_x, head_y = snake_position[0]
    # Everything except the head counts as body.
    body = snake_position[1:]

    def classify(cell, beyond_edge):
        """Return the code for ``cell``; ``beyond_edge`` flags a wall hit."""
        if cell in body:
            return BODY
        if beyond_edge:
            return WALL
        if cell == food_position:
            return FOOD
        return EMPTY

    return State(
        idx=current_state_idx,
        up=classify((head_x, head_y - 1), head_y - 1 < 0),
        down=classify((head_x, head_y + 1), head_y + 1 >= height),
        left=classify((head_x - 1, head_y), head_x - 1 < 0),
        right=classify((head_x + 1, head_y), head_x + 1 >= width),
    )


In [25]:
# Q-table: 4**4 rows by 4 columns, all zeros. The training loop indexes rows
# by state index and columns by action position.
q_table = np.zeros(shape=(4 ** 4, 4))

# alpha: step size of each Q-value update.
# gamma: weight of the best next-state value in the update target.
# epsilon: probability of picking a random action instead of the greedy one.
alpha, gamma, epsilon = 0.1, 0.99, 0.1

# Number of games played during training.
episodes = 1000

In [None]:
# Column order of the Q-table: column i holds the value of ACTIONS[i].
ACTIONS = ["UP", "DOWN", "LEFT", "RIGHT"]


def state_to_index(state):
    """Encode the four neighbour cells of ``state`` as a Q-table row index.

    Each of ``up``/``down``/``left``/``right`` is a cell code in 0..3, so the
    four together form a base-4 number in the range 0..255, matching the
    ``4**4`` rows of ``q_table``.
    """
    return ((state.up * 4 + state.down) * 4 + state.left) * 4 + state.right


for episode in range(episodes):
    # The constructor already resets the board, so no extra reset() is needed.
    game = SnakeGame()

    # Observe the real starting state instead of assuming index 0: the food
    # may spawn right next to the head.
    current_state = investigate_env(game.snake, game.food, 0, game.width, game.height)
    current_state_idx = state_to_index(current_state)
    done = False

    while not done:
        # Choose an action with the epsilon-greedy policy.
        if random.uniform(0, 1) < epsilon:
            action = random.choice(ACTIONS)
        else:
            action = ACTIONS[np.argmax(q_table[current_state_idx])]
        action_idx = ACTIONS.index(action)

        # The reward depends on what was in the chosen direction *before*
        # the move, so compute it from the pre-move state.
        reward = calculate_reward(current_state, action)

        # Execute the chosen action.
        score, done = game.move(action)

        # Observe the state that resulted from the move.
        next_state = investigate_env(game.snake, game.food, current_state_idx, game.width, game.height)
        next_state_idx = state_to_index(next_state)

        # Q-table update. A terminal transition has no future return, so the
        # bootstrap term is dropped once the game is over.
        old_value = q_table[current_state_idx, action_idx]
        next_max = 0.0 if done else np.max(q_table[next_state_idx])
        new_value = old_value + alpha * (reward + gamma * next_max - old_value)
        q_table[current_state_idx, action_idx] = new_value

        # Advance to the next state.
        current_state = next_state
        current_state_idx = next_state_idx

    if episode % 100 == 0:
        print(f"Episode: {episode}, Score: {score}")
