In [30]:
class Node:
    """One grid cell together with the per-cell bookkeeping the A* search uses.

    Attributes (set in ``__init__``):
        row, col: the cell's coordinates in the grid.
        value: the raw grid value; ``aStarAlgorithm`` treats 1 as an obstacle.
        id: ``"row-col"`` string used by ``MinHeap`` as the position-table key.
        distanceFromStart: best known cost from the start (``inf`` until reached).
        distanceFromEnd: priority used by ``MinHeap``; ``aStarAlgorithm`` stores
            the f-score (distanceFromStart + heuristic) here, not just the heuristic.
        prev: predecessor on the best known path, or ``None``.
    """

    def __init__(self, row, col, value):
        self.id = f"{row!s}-{col!s}"
        self.row, self.col, self.value = row, col, value
        unreached = float('inf')
        self.distanceFromStart = unreached
        self.distanceFromEnd = unreached
        self.prev = None
    
def aStarAlgorithm(startRow, startCol, endRow, endCol, graph):
    """Return the shortest path from (startRow, startCol) to (endRow, endCol).

    ``graph`` is a 2-D grid; cells whose value equals 1 are obstacles and every
    other cell can be entered at a cost of 1. Moves are up/down/left/right
    (see ``getNeighbours``), and the heuristic is the Manhattan distance
    computed by ``calculateDistance``.

    Returns:
        A list of ``[row, col]`` pairs from start to end inclusive, ``[]`` when
        the end cannot be reached, or ``[[startRow, startCol]]`` when start and
        end are the same cell.
    """
    nodes = createNodes(graph)
    startNode = nodes[startRow][startCol]
    endNode = nodes[endRow][endCol]

    # A path from a cell to itself is that single cell. Without this guard
    # endNode has no predecessor and constructPath would report "no path".
    if startNode is endNode:
        return [[startRow, startCol]]

    startNode.distanceFromStart = 0
    # NOTE: despite its name, ``distanceFromEnd`` holds the A* f-score
    # (g + h); MinHeap orders nodes by this attribute.
    startNode.distanceFromEnd = calculateDistance(startNode, endNode)
    openSet = MinHeap([startNode])
    while not openSet.isEmpty():
        currentNode = openSet.remove()
        if currentNode is endNode:
            break
        for neighbour in getNeighbours(currentNode, nodes):
            if neighbour.value == 1:  # obstacle
                continue
            tentativeDistance = currentNode.distanceFromStart + 1
            if tentativeDistance >= neighbour.distanceFromStart:
                continue  # not an improvement over the best known route
            neighbour.prev = currentNode
            neighbour.distanceFromStart = tentativeDistance
            neighbour.distanceFromEnd = tentativeDistance + calculateDistance(neighbour, endNode)

            if openSet.contains(neighbour):
                openSet.update(neighbour)
            else:
                openSet.insert(neighbour)

    return constructPath(endNode)

def getNeighbours(node, nodes):
    """List the in-bounds orthogonal neighbours of ``node`` in ``nodes``.

    Order is: above, below, left, right. The column bound is taken from the
    first row, so ``nodes`` is assumed to be rectangular.
    """
    row, col = node.row, node.col
    lastRow = len(nodes) - 1
    lastCol = len(nodes[0]) - 1
    candidates = (
        (row - 1, col, row > 0),
        (row + 1, col, row < lastRow),
        (row, col - 1, col > 0),
        (row, col + 1, col < lastCol),
    )
    return [nodes[r][c] for r, c, inBounds in candidates if inBounds]

def createNodes(graph):
    """Wrap every cell of ``graph`` in a ``Node``, keeping the grid's shape.

    ``result[r][c]`` is the Node built from ``graph[r][c]``.
    """
    return [
        [Node(rowIdx, colIdx, cellValue) for colIdx, cellValue in enumerate(rowValues)]
        for rowIdx, rowValues in enumerate(graph)
    ]

def calculateDistance(node1, node2):
    """Return the Manhattan (L1) distance between two nodes' grid positions."""
    rowGap = abs(node1.row - node2.row)
    colGap = abs(node1.col - node2.col)
    return rowGap + colGap

def constructPath(node):
    """Follow ``prev`` links back from ``node`` and return the path start-first.

    Each step is a ``[row, col]`` list. A node without a predecessor yields
    ``[]``.
    """
    if node.prev is None:
        return []
    backwards = []
    cursor = node
    while cursor is not None:
        backwards.append([cursor.row, cursor.col])
        cursor = cursor.prev
    backwards.reverse()
    return backwards




In [31]:
class MinHeap:
    """Binary min-heap of nodes ordered by their ``distanceFromEnd`` attribute.

    Items must expose a hashable ``id`` (unique among the items in the heap)
    and a comparable ``distanceFromEnd`` priority. ``position`` maps each item's
    ``id`` to its current index in ``heap``. This makes ``contains`` a dict
    lookup and lets ``update`` find an item without scanning.
    """

    def __init__(self, array):
        """Build a heap from the items of ``array``.

        The items are copied into a new list, so the caller's list is not
        reordered as a side effect of heapifying.
        """
        self.heap = list(array)
        self.position = {node.id: idx for idx, node in enumerate(self.heap)}
        self.buildHeap()

    def buildHeap(self):
        """Establish heap order by sifting down every parent, last parent first."""
        firstParent = (len(self.heap) - 2) // 2
        for i in reversed(range(firstParent + 1)):
            self.shiftDown(i)

    def shiftDown(self, index):
        """Move the item at ``index`` down while a child has a smaller key."""
        m = len(self.heap)
        left = 2 * index + 1
        while left < m:
            right = left + 1
            # Prefer the right child only when it exists and is strictly smaller.
            if right < m and self.heap[right].distanceFromEnd < self.heap[left].distanceFromEnd:
                indexToSwap = right
            else:
                indexToSwap = left
            if self.heap[index].distanceFromEnd > self.heap[indexToSwap].distanceFromEnd:
                self.swap(indexToSwap, index)
                index = indexToSwap
                left = 2 * index + 1
            else:
                return

    def shiftUp(self, index):
        """Move the item at ``index`` up while its parent has a larger key."""
        parentIndex = (index - 1) // 2
        while parentIndex >= 0 and self.heap[parentIndex].distanceFromEnd > self.heap[index].distanceFromEnd:
            self.swap(parentIndex, index)
            index = parentIndex
            parentIndex = (index - 1) // 2

    def isEmpty(self):
        """Return True when the heap holds no items."""
        return len(self.heap) == 0

    def remove(self):
        """Pop and return the item with the smallest key, or None if the heap is empty."""
        if self.isEmpty():
            return None
        self.swap(0, len(self.heap) - 1)
        node = self.heap.pop()
        del self.position[node.id]
        self.shiftDown(0)
        return node

    def insert(self, node):
        """Add ``node`` and restore heap order."""
        self.heap.append(node)
        self.position[node.id] = len(self.heap) - 1
        self.shiftUp(len(self.heap) - 1)

    def update(self, node):
        """Restore heap order after ``node``'s key changed, in either direction.

        ``node`` must already be in the heap; otherwise the ``position`` lookup
        raises KeyError.
        """
        index = self.position[node.id]
        self.heap[index] = node
        self.shiftUp(index)
        # If the key increased, shiftUp was a no-op and the node may need to sink.
        # If it decreased, the node is now above larger items and this is a no-op.
        self.shiftDown(self.position[node.id])

    def contains(self, node):
        """Return True if an item with ``node.id`` is currently in the heap."""
        return node.id in self.position

    def swap(self, index1, index2):
        """Exchange two heap slots and keep ``position`` in sync."""
        self.position[self.heap[index1].id] = index2
        self.position[self.heap[index2].id] = index1
        self.heap[index1], self.heap[index2] = self.heap[index2], self.heap[index1]

In [32]:
# Demo: route from (0, 1) to (4, 3) around the obstacles (cells equal to 1).
startRow, startCol = 0, 1
endRow, endCol = 4, 3
graph = [
    [0, 0, 0, 0, 0],
    [0, 1, 1, 1, 0],
    [0, 0, 0, 0, 0],
    [1, 0, 1, 1, 1],
    [0, 0, 0, 0, 0],
]
aStarAlgorithm(startRow, startCol, endRow, endCol, graph)

[[0, 1], [0, 0], [1, 0], [2, 0], [2, 1], [3, 1], [4, 1], [4, 2], [4, 3]]