# Part III: Solving Snake using Weighted Importance Sampling (an off-policy method)

- Here you will solve the game of Snake using Weighted Importance Sampling (an off-policy method). Refer to [Sutton and Barto](../SuttonBarto.pdf) for an explanation of this algorithm.

- Again, you will make an engine that can solve the game of Snake.

- You are not required to make a GUI for this game (although bonus points for those who do!)

- The snake will be present in a 100x100 pixel grid, and will start at the center. You can choose whether it dies when hitting a wall or whether it can pass through walls. Food will appear at random points on the grid. Whenever the snake eats the food, it gets a reward of $+1$ and its length increases by one unit. If the snake intersects itself, it dies and the game is over. Refer to [this article](https://en.wikipedia.org/wiki/Snake_(video_game_genre)) for all the rules.

- You can use the "WASD" keys (W = up, A = left, S = down, D = right) to represent the actions that can be taken at any point in the game.

In [None]:
import numpy as np
import random

class SnakeEngine:
    """Minimal Snake simulator on a square grid with solid walls.

    Cells are ``(x, y)`` tuples, where ``x`` indexes the first axis of the
    state array and ``y`` the second.  ``'UP'`` decreases ``x``, ``'DOWN'``
    increases it, ``'LEFT'`` decreases ``y`` and ``'RIGHT'`` increases it.
    ``self.snake`` is a list of cells with the head at index 0.

    An episode ends (reward ``-1``) when the head leaves the grid, runs into
    the body, or ``max_steps_since_food`` moves pass without eating.
    """

    # Head displacement (dx, dy) for each action.
    _DELTAS = {
        'UP': (-1, 0),
        'DOWN': (1, 0),
        'LEFT': (0, -1),
        'RIGHT': (0, 1),
    }
    # Keyboard aliases accepted in addition to the four action names.
    _ALIASES = {'W': 'UP', 'S': 'DOWN', 'A': 'LEFT', 'D': 'RIGHT'}

    def __init__(self, grid_size=100, initial_length=3):
        """Create an engine and start the first episode.

        Args:
            grid_size: Side length of the square grid.
            initial_length: Number of segments the snake starts with.  The
                body is truncated if it would not fit between the centre
                cell and the wall.
        """
        self.grid_size = grid_size
        self.initial_length = initial_length
        self.max_steps_since_food = grid_size * 2
        self.reset()

    def reset(self):
        """Start a new episode and return its initial state array.

        The head is placed at the centre, a random heading is drawn, and the
        body is laid out in a straight line behind the head so that the
        first move along the heading is always legal.
        """
        centre = self.grid_size // 2
        self.direction = random.choice(['UP', 'DOWN', 'LEFT', 'RIGHT'])
        dx, dy = self._DELTAS[self.direction]

        self.snake = [(centre, centre)]
        for offset in range(1, self.initial_length):
            segment = (centre - offset * dx, centre - offset * dy)
            if not self._in_bounds(segment):
                break  # grid too small for the requested length
            self.snake.append(segment)

        self.food = self.generate_food()
        self.steps_since_food = 0
        return self._get_state()

    def _in_bounds(self, cell):
        """Return True if ``cell`` lies inside the grid."""
        x, y = cell
        return 0 <= x < self.grid_size and 0 <= y < self.grid_size

    def generate_food(self):
        """Return a uniformly random cell that is not occupied by the snake.

        Uses rejection sampling, so it only terminates while at least one
        cell of the grid is free.
        """
        while True:
            food = (random.randint(0, self.grid_size - 1), random.randint(0, self.grid_size - 1))
            if food not in self.snake:
                return food

    def move(self, action):
        """Advance the game by one step.

        Args:
            action: ``'UP'``, ``'DOWN'``, ``'LEFT'`` or ``'RIGHT'`` (or the
                aliases ``'W'``, ``'S'``, ``'A'``, ``'D'``).  Any other value
                keeps the snake moving along its current heading.

        Returns:
            ``(state, reward, done)`` where ``state`` is the grid from
            ``_get_state``, ``reward`` is ``1`` for eating, ``-1`` for dying
            or starving, ``0`` otherwise, and ``done`` flags episode end.
        """
        action = self._ALIASES.get(action, action)
        if action in self._DELTAS:
            self.direction = action
        dx, dy = self._DELTAS[self.direction]

        head_x, head_y = self.snake[0]
        new_head = (head_x + dx, head_y + dy)
        eating = new_head == self.food

        # On a non-eating move the tail cell is vacated in the same step, so
        # the head may legally enter it.  For snakes of length <= 2 the tail
        # is the head or the neck, and stepping onto it is a reversal.
        if eating or len(self.snake) <= 2:
            blocked = self.snake
        else:
            blocked = self.snake[:-1]

        if not self._in_bounds(new_head) or new_head in blocked:
            return self._get_state(), -1, True

        self.snake.insert(0, new_head)

        if eating:
            # Keep the tail so the snake grows by one segment.
            self.food = self.generate_food()
            self.steps_since_food = 0
            return self._get_state(), 1, False

        self.snake.pop()
        self.steps_since_food += 1

        # Starvation limit stops an agent from circling forever.
        if self.steps_since_food >= self.max_steps_since_food:
            return self._get_state(), -1, True

        return self._get_state(), 0, False

    def _get_state(self):
        """Render the board as an int array: 0 empty, 1 snake, 2 food."""
        state = np.zeros((self.grid_size, self.grid_size), dtype=int)
        for x, y in self.snake:
            state[x, y] = 1
        state[self.food] = 2
        return state

class WeightedImportanceSampling:
    """Tabular action-value learner with a weighted-importance-sampling step.

    ``Q`` maps a state key to a dict of action values and ``C`` maps a
    ``(state key, action)`` pair to the cumulative importance weight seen
    for that pair.  NumPy array states are converted to a hashable key
    internally; any other state is used as the key unchanged and must be
    hashable.
    """

    def __init__(self):
        self.Q = dict()
        self.C = dict()

    @staticmethod
    def _state_key(state):
        """Return a hashable dictionary key for ``state``.

        NumPy arrays are not hashable, so they are replaced by a tuple of
        their shape, dtype string and raw bytes; equal arrays therefore map
        to the same key.
        """
        if isinstance(state, np.ndarray):
            return (state.shape, state.dtype.str, state.tobytes())
        return state

    def choose_action(self, state, epsilon=0.1):
        """Pick an action epsilon-greedily with respect to ``Q``.

        A uniformly random action is returned with probability ``epsilon``
        or when ``state`` has no entry in ``Q`` yet; otherwise the action
        with the highest stored value is returned.
        """
        key = self._state_key(state)
        if random.random() < epsilon or key not in self.Q:
            return random.choice(['UP', 'DOWN', 'LEFT', 'RIGHT'])
        values = self.Q[key]
        return max(values, key=values.get)

    def update_q_value(self, state, action, reward, next_state, alpha=0.1, gamma=0.9, rho=1.0):
        """Update ``Q[state][action]`` from one observed transition.

        The cumulative weight ``C`` for the pair is increased by ``rho`` and
        the value is moved towards the one-step target
        ``reward + gamma * max_a Q[next_state][a]`` by
        ``alpha * rho / C`` times the error.  The ratio ``rho / C`` never
        exceeds 1 for non-negative weights, so a single update cannot
        overshoot the target.

        Args:
            state: State in which ``action`` was taken.
            action: One of ``'UP'``, ``'DOWN'``, ``'LEFT'``, ``'RIGHT'``.
            reward: Reward received for the transition.
            next_state: State reached after the transition.
            alpha: Extra scale factor applied to the step.
            gamma: Discount factor for the bootstrapped value.
            rho: Importance-sampling ratio (target-policy probability over
                behaviour-policy probability) for this transition.
        """
        key = self._state_key(state)
        next_key = self._state_key(next_state)

        if key not in self.Q:
            self.Q[key] = {'UP': 0, 'DOWN': 0, 'LEFT': 0, 'RIGHT': 0}

        if next_key not in self.Q:
            self.Q[next_key] = {'UP': 0, 'DOWN': 0, 'LEFT': 0, 'RIGHT': 0}

        pair = (key, action)
        cumulative = self.C.get(pair, 0.0) + rho
        self.C[pair] = cumulative

        # A zero cumulative weight means the pair has carried no weight so
        # far; there is nothing to learn and W / C would be undefined.
        if cumulative <= 0:
            return

        target = reward + gamma * max(self.Q[next_key].values())
        error = target - self.Q[key][action]

        self.Q[key][action] += alpha * (rho / cumulative) * error

# Training driver: run 1000 environment steps, restarting the game whenever
# an episode ends.
engine = SnakeEngine()
wis = WeightedImportanceSampling()

# States are handed to the learner as bytes snapshots of the grid: its tables
# are dicts, and NumPy arrays cannot be used as dict keys.
state = engine._get_state().tobytes()

for _ in range(1000):
    action = wis.choose_action(state)
    next_grid, reward, done = engine.move(action)
    next_state = next_grid.tobytes()
    wis.update_q_value(state, action, reward, next_state)

    if done:
        # The snake died or starved; begin a fresh episode instead of
        # continuing to step a finished game.
        engine = SnakeEngine()
        state = engine._get_state().tobytes()
    else:
        state = next_state