# In[ ]:
import heapq

class Node:
    """A node in the A* search tree.

    Nodes are ordered by ``total_cost()`` (f = g + h) so they can be stored
    directly in a ``heapq`` priority queue.
    """

    def __init__(self, state, parent=None, action=None, cost=0, heuristic=0):
        self.state = state  # The state represented by this node
        self.parent = parent  # Parent node in the search tree
        self.action = action  # Action that led to this state
        self.cost = cost  # g: cost from the initial state to this state
        self.heuristic = heuristic  # h: heuristic estimate of cost to goal state

    def total_cost(self):
        """Return f = g + h, the priority used by the open list."""
        return self.cost + self.heuristic

    def __lt__(self, other):
        # heapq only needs "<"; order nodes by their f-value.
        return self.total_cost() < other.total_cost()

def astar_search(initial_state, goal_state, actions, transition, heuristic, step_cost=None):
    """Find a path from *initial_state* to *goal_state* with A* search.

    Args:
        initial_state: Start state. States must be hashable (they are stored
            in a set and a dict) and comparable with ``==``.
        goal_state: Target state; also passed to *heuristic*.
        actions: ``actions(state)`` -> iterable of actions applicable in *state*.
        transition: ``transition(state, action)`` -> successor state.
        heuristic: ``heuristic(state, goal_state)`` -> estimated remaining cost.
        step_cost: Optional ``step_cost(state, action, next_state)`` -> cost of
            one move. When omitted every move costs 1 (the previous,
            hard-coded behavior).

    Returns:
        The result of ``build_path`` for the goal node (a list of
        ``(state, action)`` pairs from start to goal), or ``None`` when every
        reachable state was expanded without reaching the goal.

    Notes:
        Closed states are never reopened. The returned path is therefore
        guaranteed cheapest only for a consistent heuristic and non-negative
        step costs. If the reachable state space is infinite and the goal is
        unreachable, the loop never terminates.
    """
    if step_cost is None:
        cost_fn = lambda state, action, next_state: 1  # unit cost per move
    else:
        cost_fn = step_cost

    open_list = []  # Priority queue (heap) of Nodes ordered by f = g + h
    closed_set = set()  # States already expanded
    best_cost = {initial_state: 0}  # Cheapest g queued so far for each state

    # Create the initial node
    initial_node = Node(initial_state, None, None, 0, heuristic(initial_state, goal_state))
    heapq.heappush(open_list, initial_node)

    while open_list:
        current_node = heapq.heappop(open_list)
        current_state = current_node.state

        # heapq has no decrease-key, so a state can be queued more than once.
        # Only its first (cheapest) pop is expanded; later copies are stale.
        if current_state in closed_set:
            continue

        if current_state == goal_state:
            return build_path(current_node)

        closed_set.add(current_state)

        for action in actions(current_state):
            new_state = transition(current_state, action)
            if new_state in closed_set:
                continue

            new_cost = current_node.cost + cost_fn(current_state, action, new_state)
            # Don't queue a successor unless it beats the best entry already
            # queued for the same state; this avoids redundant heap entries.
            if new_cost >= best_cost.get(new_state, float("inf")):
                continue
            best_cost[new_state] = new_cost

            new_node = Node(new_state, current_node, action, new_cost, heuristic(new_state, goal_state))
            heapq.heappush(open_list, new_node)

    return None  # Open list exhausted: no path found

# Helper used by astar_search to turn the goal node into a path.


def build_path(node):
    """Reconstruct the path ending at *node* by following parent links.

    Returns a list of ``(state, action)`` pairs ordered from the root of the
    search tree to *node*. The first pair carries the root's stored action
    (``None`` for the start node created by ``astar_search``). A falsy *node*
    (e.g. ``None``) yields an empty list.
    """
    steps = []
    cursor = node
    while cursor:
        steps.append((cursor.state, cursor.action))
        cursor = cursor.parent
    steps.reverse()  # collected goal -> root; callers expect root -> goal
    return steps

# Example heuristic function (Manhattan distance)
def heuristic(state, goal_state):
    """Return the Manhattan (L1) distance between *state* and *goal_state*.

    Coordinates are compared pairwise, so states of any dimension work. The
    previous version read only indices 0 and 1; for 2-tuples the result is
    identical. If the sequences differ in length, ``zip`` ignores the
    trailing extra coordinates of the longer one.
    """
    return sum(abs(s - g) for s, g in zip(state, goal_state))

# Example usage:
initial_state, goal_state = (0, 0), (4, 4)  # start cell and target cell of the example grid

def actions(state):
    """Return the four axis-aligned neighbours of the 2-D cell *state*.

    In this example an "action" is simply the destination cell, listed in the
    order +x, -x, +y, -y. The grid has no bounds, so a search for an
    unreachable goal over these actions would never exhaust its open list.
    """
    x, y = state
    offsets = ((1, 0), (-1, 0), (0, 1), (0, -1))
    return [(x + dx, y + dy) for dx, dy in offsets]

def transition(state, action):
    """Return the state reached by applying *action* in *state*.

    The example's actions already are destination cells, so the current
    *state* is unused and the action itself is the next state.
    """
    next_state = action
    return next_state

# Run A* on the example grid and show the resulting (state, action) path.
path = astar_search(
    initial_state,
    goal_state,
    actions=actions,
    transition=transition,
    heuristic=heuristic,
)
print(path)

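# Sample output (one shortest 8-move path; ties between equally short paths
# may resolve to a different one):
# [((0, 0), None), ((0, 1), (0, 1)), ((0, 2), (0, 2)), ((1, 2), (1, 2)), ((1, 3), (1, 3)), ((2, 3), (2, 3)), ((3, 3), (3, 3)), ((4, 3), (4, 3)), ((4, 4), (4, 4))]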
