# Snake game using search algorithms

In [1]:
import enum
import time
from collections import deque
import heapq
import sys

MAP_WIDTH: int
MAP_HEIGHT: int
test_file_pathes = ["test1.txt", "test2.txt", "test3.txt"]

In [2]:
def get_input(test_file_path):
    global MAP_WIDTH
    global MAP_HEIGHT
    f = open(test_file_path, "r")
    MAP_HEIGHT, MAP_WIDTH = list(map(int, f.readline().split(",")))
    starting_point = tuple(map(int, f.readline().split(",")))
    seeds_count = int(f.readline())
    seeds = []
    for _ in range(seeds_count):
        seeds.append(tuple(map(int, f.readline().split(","))))
    f.close()
    return starting_point, seeds

### - Initial State: all seeds are left and the snake is at the given starting position
### - Actions: Up, Left, Down, Right
### - Transition model: it's implented in the state class and generates a new state or false from current state and action
### - Goal State: no seed is left

### A helper class for actions

In [3]:
class Direction(enum.IntEnum):
    Up = 0
    Right = 1
    Down = 2
    Left = 3

    @staticmethod
    def actions():
        return [Direction.Right, Direction.Down, Direction.Left, Direction.Up]

    @staticmethod
    def delta_positions():
        return {
            Direction.Up: (-1, 0),
            Direction.Right: (0, 1),
            Direction.Down: (1, 0),
            Direction.Left: (0, -1),
        }

    @staticmethod
    def opposite_directions():
        return {
            Direction.Up: Direction.Down,
            Direction.Right: Direction.Left,
            Direction.Down: Direction.Up,
            Direction.Left: Direction.Right,
        }

### State class  
We keep snake positions in a list(ordered) and seed positions in another list  
there is also a boolean named ate_before to check whether seed was eaten in last step or not  
the transition method generates a new State from a given action and current State

In [4]:
class State:
    def __eq__(self, other: "State"):
        return (self.snake_positions == other.snake_positions) and (
            set(self.seed_positions) == set(other.seed_positions)
        )

    def __hash__(self):
        return hash((tuple(self.snake_positions), tuple(self.seed_positions)))

    def __init__(
        self, _snake_positions: list, _seed_positions: list, _ate_before: bool
    ):
        self.snake_positions = _snake_positions
        self.seed_positions = _seed_positions
        self.ate_before = _ate_before

    def is_goal(self):
        return len(self.seed_positions) == 0

    def correct_wall_effect(self, head: tuple):
        if head[0] == -1:
            return (MAP_HEIGHT - 1, head[1])
        elif head[0] == MAP_HEIGHT:
            return (0, head[1])
        elif head[1] == -1:
            return (head[0], MAP_WIDTH - 1)
        elif head[1] == MAP_WIDTH:
            return (head[0], 0)
        else:
            return head

    def calculate_heuristic(self, heuristic_func):
        return heuristic_func(self)

    def transition(self, action: Direction, parent_action):
        if (len(self.snake_positions) > 1) and (
            action == Direction.opposite_directions()[parent_action]
        ):
            return False
        delta_position = Direction.delta_positions()[action]
        old_head = self.snake_positions[-1]
        new_head = (old_head[0] + delta_position[0], old_head[1] + delta_position[1])
        new_head = self.correct_wall_effect(new_head)
        updated_positions = (
            self.snake_positions[:] if self.ate_before else self.snake_positions[1:]
        )
        if new_head in updated_positions:
            return False
        seeds_just_positions = list(
            map(lambda seed: (seed[0], seed[1]), self.seed_positions)
        )
        new_snake_positions = []
        if self.ate_before:
            new_snake_positions = self.snake_positions[:] + [new_head]
        else:
            new_snake_positions = self.snake_positions[1:] + [new_head]
        if new_head not in seeds_just_positions:
            return State(new_snake_positions, self.seed_positions[:], False)
        else:
            index = seeds_just_positions.index(new_head)
            seed = self.seed_positions[index]
            if seed[2] == 2:
                new_seed_positions = self.seed_positions[:]
                del new_seed_positions[index]
                new_seed_positions += [(seed[0], seed[1], 1)]
                return State(new_snake_positions, new_seed_positions, True)
            else:
                new_seed_positions = self.seed_positions[:]
                del new_seed_positions[index]
                return State(new_snake_positions, new_seed_positions, True)

### Node class
this class stores it's state, parent, action which lead to this node, path cost until this node

In [5]:
class Node:
    def __init__(
        self, _state: State, _parent: "Node", _action: Direction, _path_cost: int
    ):
        self.state = _state
        self.parent = _parent
        self.action = _action
        self.path_cost = _path_cost

    def __lt__(self, other):
        return self

### Helper for running searches

In [6]:
def run(algorithm_func, start_node, *args, **kwargs):
    tic = time.time()
    final, states_met, distinct_states_met = algorithm_func(start_node, *args, **kwargs)
    toc = time.time()
    tic1 = time.time()
    final, states_met, distinct_states_met = algorithm_func(start_node, *args, **kwargs)
    toc1 = time.time()
    answer_actions = []
    answer_path_distance = -1
    while final is not None:
        answer_path_distance += 1
        answer_actions = [final.action] + answer_actions
        final = final.parent
    print("Distance: %d" %(answer_path_distance))
    print("Path: ", end="")
    for action in answer_actions[1:]:
        print(action.name, end=" ")
    print()
    print("States met: " + str(states_met))
    print("Distinct States met: " + str(distinct_states_met))
    print("Average Time taken: " + str((toc - tic + toc1 - tic1)/2))

In [7]:
def run_for_tests(algorithm_func, *args, **kwargs):
    i = 1
    for test_path in test_file_pathes:
        print("TEST NUMBER %d" %(i))
        starting_point, seeds = get_input(test_path)
        start_node = Node(State([starting_point], seeds, False), None, None, 0)
        run(algorithm_func, start_node, *args, **kwargs)
        print("--------------------------------------------------------------------------------------------------------------------------")
        i += 1

## Uninformed Search

### BFS
we use a queue for implementing the frontier and in each state we expand the shallowest unexpanded node  
bfs is optimal when cost is 1  
time complexity: O(b^d)  
space complexity: O(b^d)  
It's space is expensive  

In [8]:
def bfs(start_node: Node):
    states_met = 0
    distinct_states_met = 0
    if start_node.state.is_goal():
        return start_node
    frontier = deque([start_node])
    frontier_state_set = set()
    frontier_state_set.add(start_node.state)
    explored = set()
    while True:
        if not frontier:
            return False
        current = frontier.pop()
        frontier_state_set.remove(current.state)
        explored.add(current.state)
        for action in Direction.actions():
            new_state = current.state.transition(action, current.action)
            if new_state is not False:
                states_met += 1
                child = Node(new_state, current, action, current.path_cost + 1)
                if (
                    child.state not in explored
                    and child.state not in frontier_state_set
                ):
                    distinct_states_met += 1
                    if child.state.is_goal():
                        return child, states_met, distinct_states_met
                    frontier.appendleft(child)
                    frontier_state_set.add(child.state)

In [9]:
print("BFS:")
run_for_tests(bfs)

BFS:
TEST NUMBER 1
Distance: 12
Path: Down Left Right Right Down Down Right Right Down Right Down Down 
States met: 8670
Distinct States met: 4325
Average Time taken: 0.06652545928955078
--------------------------------------------------------------------------------------------------------------------------
TEST NUMBER 2
Distance: 15
Path: Right Left Left Up Right Up Left Left Up Up Left Left Left Left Up 
States met: 101649
Distinct States met: 46072
Average Time taken: 0.8783884048461914
--------------------------------------------------------------------------------------------------------------------------
TEST NUMBER 3
Distance: 25
Path: Right Up Right Down Down Down Right Right Down Right Right Right Down Down Right Right Up Left Left Down Left Left Left Up Up 
States met: 472218
Distinct States met: 213962
Average Time taken: 4.549342155456543
--------------------------------------------------------------------------------------------------------------------------


### Ids
we use a stack for implementing the frontier and in each state we expand from top  in dfs
dfs is not optimal  
time complexity: O(b^m), m is the maximum depth  
space complexity: O(bm)  
It's space is less expensive than bfs    

In IDS we give a maximum depth to our dfs so it does not go to the end like dfs  
ids is optimal if cost is 1  
time complexity: O(b^d)  
space complexity: O(b*d)  
it somehow keeps the best charecteristics of bfs and dfs

In [10]:
def dfs(start_node: Node, depth: int):
    states_met = 0
    distinct_states_met = 0
    if start_node.state.is_goal():
        return start_node
    frontier = deque([start_node])
    explored = dict()
    while True:
        if not frontier:
            return False, 0, 0
        current = frontier.pop()
        if current.path_cost == depth:
            continue
        explored[current.state] = current.path_cost
        for action in Direction.actions():
            new_state = current.state.transition(action, current.action)
            if new_state is not False:
                states_met += 1
                child = Node(new_state, current, action, current.path_cost + 1)
                if (
                    child.state not in explored
                    or child.path_cost < explored[child.state]
                ):
                    distinct_states_met += 1
                    if child.state.is_goal():
                        return child, states_met, distinct_states_met
                    frontier.append(child)
def ids(start_node: Node):
    states_met = 0
    distinct_states_met = 0
    for depth in range(sys.maxsize):
        current_states = 0
        node, current_states, distinct_states_met = dfs(start_node, depth)
        states_met = states_met + current_states
        if node != False:
            return node, states_met, distinct_states_met
    return False

In [11]:
print("IDS:")
run_for_tests(ids)

IDS:
TEST NUMBER 1
Distance: 12
Path: Left Down Up Up Left Up Left Left Up Up Left Left 
States met: 18298
Distinct States met: 11708
Average Time taken: 0.32753264904022217
--------------------------------------------------------------------------------------------------------------------------
TEST NUMBER 2
Distance: 15
Path: Up Right Down Left Left Up Up Up Up Left Up Left Left Left Left 
States met: 121842
Distinct States met: 74125
Average Time taken: 3.8491894006729126
--------------------------------------------------------------------------------------------------------------------------
TEST NUMBER 3
Distance: 25
Path: Up Right Down Down Down Right Down Right Right Down Down Right Right Right Up Right Right Down Left Left Left Up Up Left Left 
States met: 792297
Distinct States met: 455397
Average Time taken: 24.947928190231323
--------------------------------------------------------------------------------------------------------------------------


## Informed

### A*
we use a list and keep it sorted using heapq  
we keep the list order by g(n) + heuristic(n)  
A* is optimal if heuristic is admissible and non negative in tree search   
in graph search A* is optimal if heuristic is consistent  

In [12]:
def a_star(start_node: Node, heuristic_func, weight=1):
    states_met = 0
    distinct_states_met = 0
    if start_node.state.is_goal():
        return start_node
    frontier = []
    heapq.heappush(
        frontier,
        (
            start_node.path_cost
            + weight * start_node.state.calculate_heuristic(heuristic_func),
            start_node,
        ),
    )
    frontier_state_set = set()
    frontier_state_set.add(start_node.state)
    explored = set()
    while True:
        if not frontier:
            return False, 0, 0
        current = heapq.heappop(frontier)
        frontier_state_set.remove(current[1].state)
        explored.add(current[1].state)
        for action in Direction.actions():
            new_state = current[1].state.transition(action, current[1].action)
            if new_state is not False:
                child = Node(
                    new_state, current[1], action, current[1].path_cost + 1
                )
                states_met += 1
                if (
                    child.state not in explored
                    and child.state not in frontier_state_set
                ):
                    distinct_states_met += 1
                    if child.state.is_goal():
                        return child, states_met, distinct_states_met
                    heapq.heappush(
                        frontier,
                        (
                            child.path_cost
                            + weight
                            * child.state.calculate_heuristic(heuristic_func),
                            child,
                        ),
                    )
                    frontier_state_set.add(child.state)

#### Heuristics

In [13]:
def heuristic1(state: State):
    return len(state.seed_positions)

def heuristic2(state: State):
    return len(state.seed_positions) - len(state.snake_positions)

#### Heuristics explained
heuristic 1 is consistent because it's always non encreasing in it's path  
so because of it's consistency it causes optimal answers  
heuristic 2 is admissible because it never overestimates the cost  
but heuristic 2 is negative and it may not always make the answer optimal  
as we can see less states are met in heuristic2 and it was a more realistic estimate

In [14]:
print("A* with heuristic1:")
run_for_tests(a_star, heuristic1)

A* with heuristic1:
TEST NUMBER 1
Distance: 12
Path: Left Down Up Up Left Left Up Left Up Up Left Left 
States met: 4585
Distinct States met: 2588
Average Time taken: 0.04141724109649658
--------------------------------------------------------------------------------------------------------------------------
TEST NUMBER 2
Distance: 15
Path: Left Right Right Up Left Left Left Up Up Up Left Left Up Left Left 
States met: 51517
Distinct States met: 26797
Average Time taken: 0.4773130416870117
--------------------------------------------------------------------------------------------------------------------------
TEST NUMBER 3
Distance: 25
Path: Up Right Down Down Right Down Down Right Right Right Down Right Down Right Up Right Right Down Left Left Left Left Left Up Up 
States met: 269463
Distinct States met: 132120
Average Time taken: 2.9520214796066284
--------------------------------------------------------------------------------------------------------------------------


In [15]:
print("A* with heuristic2:")
run_for_tests(a_star, heuristic2)

A* with heuristic2:
TEST NUMBER 1
Distance: 12
Path: Left Down Up Left Up Up Left Left Left Up Up Left 
States met: 2291
Distinct States met: 1595
Average Time taken: 0.027997851371765137
--------------------------------------------------------------------------------------------------------------------------
TEST NUMBER 2
Distance: 15
Path: Up Right Down Left Left Up Up Left Up Up Left Up Left Left Left 
States met: 16124
Distinct States met: 10174
Average Time taken: 0.15652191638946533
--------------------------------------------------------------------------------------------------------------------------
TEST NUMBER 3
Distance: 25
Path: Up Right Down Down Right Down Right Right Down Right Right Down Down Right Right Right Up Left Left Down Left Left Up Up Left 
States met: 101398
Distinct States met: 54149
Average Time taken: 1.0170587301254272
--------------------------------------------------------------------------------------------------------------------------


### Weighted A*  
The idea of this search is to speed up search at the expense of optimality  
the difference is that we use weight and multiply it by heuristic  
f = g + alpha * heuristic

In [16]:
print("A* with heuristic1 with weight=2.1:")
run_for_tests(a_star, heuristic1, 2.1)

A* with heuristic1 with weight=2.1:
TEST NUMBER 1
Distance: 12
Path: Left Down Up Left Up Up Left Left Left Up Up Left 
States met: 2441
Distinct States met: 1570
Average Time taken: 0.02336442470550537
--------------------------------------------------------------------------------------------------------------------------
TEST NUMBER 2
Distance: 15
Path: Right Up Left Left Down Left Up Up Up Up Left Left Up Left Left 
States met: 17170
Distinct States met: 10573
Average Time taken: 0.15696382522583008
--------------------------------------------------------------------------------------------------------------------------
TEST NUMBER 3
Distance: 25
Path: Up Right Down Down Down Right Down Right Right Down Down Right Right Right Up Right Right Down Left Left Left Up Left Up Left 
States met: 129620
Distinct States met: 64358
Average Time taken: 1.228269100189209
--------------------------------------------------------------------------------------------------------------------------


In [17]:
print("A* with heuristic1 with weight=5:")
run_for_tests(a_star, heuristic1, 5)

A* with heuristic1 with weight=5:
TEST NUMBER 1
Distance: 14
Path: Up Up Right Right Down Right Down Down Right Down Down Down Down Down 
States met: 2596
Distinct States met: 1237
Average Time taken: 0.03234517574310303
--------------------------------------------------------------------------------------------------------------------------
TEST NUMBER 2
Distance: 17
Path: Up Left Down Right Right Up Up Left Left Left Up Up Left Up Left Left Left 
States met: 1782
Distinct States met: 1495
Average Time taken: 0.01472020149230957
--------------------------------------------------------------------------------------------------------------------------
TEST NUMBER 3
Distance: 26
Path: Up Right Down Down Down Right Right Down Right Down Left Up Right Right Right Down Right Down Right Up Right Up Left Left Down Down 
States met: 34537
Distinct States met: 19853
Average Time taken: 0.3186228275299072
-------------------------------------------------------------------------------------------

In [18]:
print("A* with heuristic2 with weight=2.1:")
run_for_tests(a_star, heuristic2, 2.1)

A* with heuristic2 with weight=2.1:
TEST NUMBER 1
Distance: 14
Path: Up Up Right Right Down Right Down Right Down Down Down Down Down Down 
States met: 941
Distinct States met: 718
Average Time taken: 0.01059424877166748
--------------------------------------------------------------------------------------------------------------------------
TEST NUMBER 2
Distance: 17
Path: Up Left Down Right Right Up Left Up Up Up Left Left Left Left Left Up Left 
States met: 3623
Distinct States met: 2632
Average Time taken: 0.03089439868927002
--------------------------------------------------------------------------------------------------------------------------
TEST NUMBER 3
Distance: 25
Path: Up Right Down Down Down Right Right Down Right Right Right Down Right Down Right Right Up Left Left Down Left Up Left Up Left 
States met: 8405
Distinct States met: 5655
Average Time taken: 0.07089519500732422
--------------------------------------------------------------------------------------------------

In [19]:
print("A* with heuristic2 with weight=5:")
run_for_tests(a_star, heuristic2, 5)

A* with heuristic2 with weight=5:
TEST NUMBER 1
Distance: 14
Path: Up Up Right Right Down Right Down Right Down Down Down Down Down Down 
States met: 206
Distinct States met: 194
Average Time taken: 0.0031766891479492188
--------------------------------------------------------------------------------------------------------------------------
TEST NUMBER 2
Distance: 17
Path: Up Left Down Right Right Up Left Up Up Left Left Up Left Up Left Left Left 
States met: 1011
Distinct States met: 919
Average Time taken: 0.013422846794128418
--------------------------------------------------------------------------------------------------------------------------
TEST NUMBER 3
Distance: 30
Path: Up Right Down Down Down Right Right Down Right Right Down Left Up Left Left Left Left Left Down Left Left Down Right Right Down Left Down Left Up Up 
States met: 4499
Distinct States met: 3963
Average Time taken: 0.03813648223876953
---------------------------------------------------------------------------

As we see in the results of weighted A* 
when alpha was 2.1 the optimality is not changed a lot  
but when alpha is 5 it runs faster and states are less but the optimality is gone