In [1]:
import numpy as np
from queue import PriorityQueue
from enum import Enum
import math

In [4]:
start = (0, 0)
goal = (4, 4)
grid = np.array([
    [0, 1, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [0, 1, 0, 0, 0, 0],
    [0, 0, 0, 1, 1, 0],
    [0, 0, 0, 1, 0, 0],
])

In [2]:
# Implemention of class Action of type enum:
class Action(Enum):
    
    UP = (-1, 0, 1)
    DOWN = (1, 0, 1)
    LEFT = (0, -1, 1)
    RIGHT = (0, 1, 1)
    
    def __str__(self):
        
        if self == self.UP:
            return '^'
        
        elif self == self.DOWN:
            return 'V'
        
        elif self == self.LEFT:
            return '<'
        
        elif self == self.RIGHT:
            return '>'
    
    @property
    def cost(self):
        return self.value[2]
    
    @property
    def delta(self):
        return (self.value[0], self.value[1])

In [3]:
def valid_actions(grid, current_node):
    
    valid = [Action.UP, Action.DOWN, Action.LEFT, Action.RIGHT]
    n, m = grid.shape[0] - 1, grid.shape[1] - 1
    x, y = current_node
    
    if x - 1 < 0 or grid[x-1, y] == 1:
        valid.remove(Action.UP)
        
    if x + 1 > n or grid[x+1, y] == 1:
        valid.remove(Action.DOWN)
    
    if y - 1 < 0 or grid[x, y-1] == 1:
        valid.remove(Action.LEFT)
        
    if y + 1 > m or grid[x, y+1] == 1:
        valid.remove(Action.RIGHT)
        
    return valid

In [5]:
def visualize_path(grid, start, goal):
    
    sgrid = np.zeros(np.shape(grid), dtype = np.str)
    sgrid[:] = ' '
    sgrid[grid[:] == 1] = 'O'
    pos = start
    
    for a in path:
        
        da = a.value
        sgrid[pos[0], pos[1]] = str(a)
        pos = (pos[0] + da[0], pos[1] + da[1])
    sgrid[pos[0], pos[1]] = 'G'
    sgrid[start[0], start[1]] = 'S'
    return sgrid

In [6]:
def heuristic(position, goal_position):
    h = math.sqrt((position[0] - goal_position[0])**2 + (position[0] - goal_position[0])**2)
    return h

In [11]:
def a_star(grid, h, start, goal):
    
    path = []
    path_cost = 0
    q = PriorityQueue()
    q.put((0, start))
    visited = set()
    visited.add(start)
    
    branch = {}
    found = False
    
    while not q.empty():
        
        item = q.get()
        current_node = item[1]
        
        if current_node == start:
            current_cost = 0.0
            
        else:
            current_cost = branch[current_node][0]
        
        if current_node == goal:
            print("Found a path")
            found = True
            break
            
        else:
            
            for action in valid_actions(grid, current_node):
                
                da = action.delta
                next_node = (current_node[0] + da[0], current_node[1] + da[1])
                branch_cost = current_cost + action.cost
                queue_cost = branch_cost + h(next_node, goal)
                
                if next_node not in visited:
                    q.put((queue_cost, next_node))
                    visited.add(next_node)
                    branch[next_node] = (queue_cost, current_node, action)
                    
    
    if found:
        
        n = goal
        path_cost = branch[n][0]
        while branch[n][1] != start:
            path.append(branch[n][2])
            n = branch[n][1]
        path.append(branch[n][2])
        
    else:
        
        print('*********************')
        print('Failed to find a path')
        print('*********************')
        
    return path[::-1], path_cost

In [12]:
path, path_cost = a_star(grid, heuristic, start, goal)
print(path, path_cost)

Found a path
[<Action.DOWN: (1, 0, 1)>, <Action.RIGHT: (0, 1, 1)>, <Action.RIGHT: (0, 1, 1)>, <Action.DOWN: (1, 0, 1)>, <Action.RIGHT: (0, 1, 1)>, <Action.RIGHT: (0, 1, 1)>, <Action.RIGHT: (0, 1, 1)>, <Action.DOWN: (1, 0, 1)>, <Action.DOWN: (1, 0, 1)>, <Action.LEFT: (0, -1, 1)>] 35.45584412271572


In [13]:
path_rep = visualize_path(grid, start, goal)
print(path_rep)

[['S' 'O' ' ' ' ' ' ' ' ']
 ['>' '>' 'V' ' ' ' ' ' ']
 [' ' 'O' '>' '>' '>' 'V']
 [' ' ' ' ' ' 'O' 'O' 'V']
 [' ' ' ' ' ' 'O' 'G' '<']]
